# ResNet for MNIST in PyTorch 1.8
For details, see:

https://zablo.net/blog/post/pytorch-resnet-mnist-jupyter-notebook-2021/
---



In [1]:
# Uncomment to install the dependencies into the kernel's environment.
# (The PyPI distribution that provides `sklearn` is named `scikit-learn`.)
# %pip install torch torchvision pytorch-lightning scikit-learn tqdm

In [2]:
# Report the GPU, driver version and CUDA version visible on this machine.
!nvidia-smi

Sat Mar 06 01:02:41 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 452.56       Driver Version: 452.56       CUDA Version: 11.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name            TCC/WDDM | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  GeForce RTX 2060   WDDM  | 00000000:01:00.0 Off |                  N/A |
| N/A   49C    P8    10W /  N/A |    164MiB /  6144MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [3]:
import torch
# Show the installed PyTorch build as the cell output.
torch.__version__

'1.8.0+cu111'

In [4]:
from torchvision.models import resnet50
from torch import nn
from torch.utils.data import DataLoader

In [5]:
# torchvision ResNet-50 with a 10-way output layer (one output per MNIST digit class).
model = resnet50(num_classes = 10)

In [6]:
# Display the layer structure of the network.
model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

To adapt this architecture to MNIST, one more change is required: the input layer must accept a single channel instead of three (MNIST images are single-channel grayscale, whereas ImageNet images are 3-channel RGB).

In [7]:
# Replace the first convolution with one that takes 1 input channel. The remaining
# arguments repeat the values printed for conv1 above (64 filters, 7x7 kernel,
# stride 2, padding 3, no bias).
model.conv1 = nn.Conv2d(1, 64, kernel_size = (7, 7), stride = (2, 2), padding = (3, 3), bias=False)

In [8]:
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

In [9]:
# MNIST train and test splits rooted at the relative "mnist" directory.
# download=True lets torchvision fetch the files; ToTensor() is applied to each image.
train_ds = MNIST(root="mnist", train=True, transform=ToTensor(), download=True)
test_ds = MNIST(root="mnist", train=False, transform=ToTensor(), download=True)

In [10]:
# Batches of 64 for both splits; only the training loader shuffles.
train_dl = DataLoader(dataset=train_ds, batch_size=64, shuffle=True)
test_dl = DataLoader(dataset=test_ds, batch_size=64)

In [11]:
import pytorch_lightning as pl
from pytorch_lightning.core.decorators import auto_move_data

In [12]:
class ResNetMNIST(pl.LightningModule):
  """ResNet-50 classifier for single-channel images with 10 output classes.

  Wraps torchvision's ``resnet50`` and replaces its first convolution so the
  network takes 1 input channel instead of 3.
  """

  def __init__(self):
    super().__init__()
    backbone = resnet50(num_classes = 10)
    # Same arguments as the stock conv1 printed earlier in the notebook,
    # except for the single input channel.
    backbone.conv1 = nn.Conv2d(
      1, 64,
      kernel_size = (7, 7),
      stride = (2, 2),
      padding = (3, 3),
      bias = False,
    )
    self.model = backbone
    self.loss = nn.CrossEntropyLoss()

  @auto_move_data
  def forward(self, x):
    """Return the wrapped network's output for the batch ``x``."""
    return self.model(x)

  def training_step(self, batch, batch_no):
    """Return the cross-entropy loss for one ``(inputs, targets)`` batch.

    ``batch_no`` is accepted but not used.
    """
    inputs, targets = batch
    return self.loss(self(inputs), targets)

  def configure_optimizers(self):
    """Return an RMSprop optimizer over all parameters with ``lr = 0.005``."""
    optimizer = torch.optim.RMSprop(self.parameters(), lr = 0.005)
    return optimizer

In [13]:
# Instantiate the Lightning module. This rebinds `model`, which until now held
# the plain torchvision ResNet created earlier.
model = ResNetMNIST()

In [14]:
# One GPU, at most 20 epochs, progress bar refresh rate of 20.
trainer = pl.Trainer(gpus = 1, max_epochs = 20, progress_bar_refresh_rate = 20)

GPU available: True, used: True
TPU available: None, using: 0 TPU cores


In [15]:
# Train the Lightning module on the MNIST training loader.
trainer.fit(model, train_dl)


  | Name  | Type             | Params
-------------------------------------------
0 | model | ResNet           | 23.5 M
1 | loss  | CrossEntropyLoss | 0     
-------------------------------------------
23.5 M    Trainable params
0         Non-trainable params
23.5 M    Total params
94.089    Total estimated model params size (MB)
Epoch 29: 100%|██████████| 938/938 [00:48<00:00, 19.17it/s, loss=0.0097, v_num=6]


1

In [16]:
# Write a checkpoint of the trained module to disk.
# NOTE(review): the file name says "resnet18" although the network is a resnet50.
# The same path is used when the checkpoint is loaded later, so rename both together.
trainer.save_checkpoint("resnet18_mnist.pt")

In [17]:
def get_prediction(x, model: pl.LightningModule):
  """Classify the batch ``x`` with ``model``.

  Calls ``model.freeze()`` before the forward pass, then applies a softmax
  over dimension 1 of the model output.

  Returns:
    A ``(predicted_class, probabilities)`` tuple, where ``predicted_class``
    is the argmax of ``probabilities`` along dimension 1.
  """
  model.freeze() # prepares model for predicting
  logits = model(x)
  probabilities = logits.softmax(dim = 1)
  return probabilities.argmax(dim = 1), probabilities

In [18]:
from tqdm.autonotebook import tqdm

In [19]:
# Rebuild the module from the saved checkpoint, passing "cuda" as the map location.
inference_model = ResNetMNIST.load_from_checkpoint("resnet18_mnist.pt", map_location = "cuda")

In [20]:
# Collect ground-truth and predicted labels for the whole test set as plain Python ints.
true_y, pred_y = [], []
for x, y in tqdm(test_dl):  # tqdm reads the total from len(test_dl)
  preds, _ = get_prediction(x, inference_model)
  # .tolist() yields ints; extending a list with a tensor would instead append
  # one 0-dim tensor per label and rely on implicit conversion downstream.
  true_y.extend(y.tolist())
  pred_y.extend(preds.cpu().tolist())

100%|██████████| 157/157 [00:24<00:00,  6.35it/s]


In [21]:
from sklearn.metrics import classification_report

In [22]:
# Per-class precision, recall and F1 on the test set, printed with 4 decimal digits.
print(classification_report(true_y, pred_y, digits = 4))

              precision    recall  f1-score   support

           0     0.9859    0.9990    0.9924       980
           1     0.9277    0.9947    0.9600      1135
           2     0.9896    0.9196    0.9533      1032
           3     0.9786    0.9980    0.9882      1010
           4     0.9929    0.9949    0.9939       982
           5     0.9933    0.9910    0.9921       892
           6     0.9968    0.9864    0.9916       958
           7     0.9941    0.9893    0.9917      1028
           8     0.9969    0.9856    0.9912       974
           9     0.9960    0.9802    0.9880      1009

    accuracy                         0.9837     10000
   macro avg     0.9852    0.9839    0.9843     10000
weighted avg     0.9843    0.9837    0.9837     10000

