# ResNet for MNIST in PyTorch 1.7
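For details, see the accompanying blog post below (the outputs in this notebook were produced with PyTorch 1.12.1):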

https://zablo.net/blog/post/pytorch-resnet-mnist-jupyter-notebook-2021/
---



In [1]:
# Install dependencies if needed (%pip installs into the running kernel's environment)
# %pip install -r requirements.txt

In [2]:
!nvidia-smi

Thu Nov 17 19:58:45 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.85.02    Driver Version: 510.85.02    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  On   | 00000000:01:00.0  On |                  N/A |
| N/A   56C    P3    25W /  N/A |    471MiB /  8192MiB |     12%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+---------------------------------------------------------------------------

In [3]:
import torch
torch.__version__

'1.12.1+cu113'

In [4]:
from torchvision.models import resnet18
from torch import nn
from torch.utils.data import DataLoader
import tqdm as tqdm
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from sklearn.metrics import accuracy_score
from torch.utils.data import RandomSampler

In [5]:
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(device)

cuda:0


In [6]:
# Build the training split first, then the test split. Both are stored under
# the "mnist" directory (downloaded if necessary) and every image is passed
# through ToTensor().
train_ds, test_ds = (
    MNIST("mnist", train=is_train, download=True, transform=ToTensor())
    for is_train in (True, False)
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to mnist/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting mnist/MNIST/raw/train-images-idx3-ubyte.gz to mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to mnist/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting mnist/MNIST/raw/train-labels-idx1-ubyte.gz to mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to mnist/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to mnist/MNIST/raw



In [7]:
# Training batches of 64 are drawn through a RandomSampler. With its default
# arguments the sampler draws without replacement, so each epoch visits every
# training image once in a fresh random order (see the torch.utils.data docs).
# Note: DataLoader does not accept `shuffle=True` together with an explicit
# `sampler`, so only the sampler is passed here.
train_dl = DataLoader(
    dataset=train_ds,
    batch_size=64,
    sampler=RandomSampler(train_ds),
)

# Test batches of 64; no sampler or shuffling is requested for evaluation.
test_dl = DataLoader(test_ds, batch_size=64)

In [8]:
class ResNetMNIST(torch.nn.Module):
  """ResNet-18 classifier whose first convolution is replaced to fit the input channels.

  With the default arguments this is the original MNIST configuration:
  a 10-class ResNet-18 taking single-channel images.

  Args:
    num_classes: number of output classes passed to ``resnet18``.
    in_channels: number of channels in the input images (1 for MNIST).
  """

  def __init__(self, num_classes: int = 10, in_channels: int = 1):
    super().__init__()
    self.model = resnet18(num_classes=num_classes)  # Use ResNet18
    # Replace the first layer so it accepts `in_channels` input channels
    # (MNIST images have a single channel); the remaining hyper-parameters
    # are the ones the original notebook used for this layer.
    self.model.conv1 = nn.Conv2d(
      in_channels, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
    )

  def forward(self, x):
    """Return the wrapped ResNet's output for the batch ``x``."""
    return self.model(x)

# --- Training setup ---
NUM_EPOCHS = 4       # number of full passes over the training loader
LOG_INTERVAL = 100   # report the running loss every LOG_INTERVAL mini-batches

model = ResNetMNIST().to(device)
criterion = nn.CrossEntropyLoss(reduction='none') # return loss per datapoint rather than the mean.
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Be explicit about the mode: get_prediction() later switches the model to eval mode.
model.train()

for epoch in range(NUM_EPOCHS):  # loop over the dataset multiple times
    running_loss = 0.0
    print(f'==Epoch {epoch}==')
    for i, data in enumerate(train_dl, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = model(inputs.to(device))
        data_point_loss = criterion(outputs, labels.to(device))
        loss = data_point_loss.mean()  # reduce the per-sample losses to one scalar
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.item()
        # Count batches from 1 so each report averages exactly LOG_INTERVAL
        # losses (the previous `i > 0 and i % 100 == 0` check summed 101
        # batches in the first window but still divided by 100).
        batches_seen = i + 1
        if batches_seen % LOG_INTERVAL == 0:
            print(f'[{epoch}, {batches_seen:5d}] loss: {running_loss / LOG_INTERVAL:.3f}')
            running_loss = 0.0

print('Finished Training')

==Epoch 0==
[0,   100] loss: 0.735
[0,   200] loss: 0.210
[0,   300] loss: 0.149
[0,   400] loss: 0.114
[0,   500] loss: 0.104
[0,   600] loss: 0.086
[0,   700] loss: 0.097
[0,   800] loss: 0.083
[0,   900] loss: 0.085
==Epoch 1==
[1,   100] loss: 0.053
[1,   200] loss: 0.043
[1,   300] loss: 0.051
[1,   400] loss: 0.047
[1,   500] loss: 0.039
[1,   600] loss: 0.045
[1,   700] loss: 0.043
[1,   800] loss: 0.039
[1,   900] loss: 0.051
==Epoch 2==
[2,   100] loss: 0.023
[2,   200] loss: 0.021
[2,   300] loss: 0.022
[2,   400] loss: 0.024
[2,   500] loss: 0.028
[2,   600] loss: 0.025
[2,   700] loss: 0.021
[2,   800] loss: 0.022
[2,   900] loss: 0.023
==Epoch 3==
[3,   100] loss: 0.014
[3,   200] loss: 0.011
[3,   300] loss: 0.012
[3,   400] loss: 0.012
[3,   500] loss: 0.013
[3,   600] loss: 0.012
[3,   700] loss: 0.014
[3,   800] loss: 0.015
[3,   900] loss: 0.012
Finished Training


In [75]:
def get_prediction(x, model: nn.Module):
  """Classify a batch of inputs with ``model``.

  The model is switched to eval mode (and left in eval mode), and the forward
  pass runs with autograd disabled so no computation graph is built or kept
  alive by the returned tensors.

  Args:
    x: batch of inputs accepted by ``model``.
    model: network whose output has the class scores along dimension 1.

  Returns:
    A ``(predicted_class, probabilities)`` tuple: the arg-max class index per
    sample and the softmax over dimension 1 of the model output.
  """
  model.eval() # prepares model for predicting
  with torch.no_grad():  # inference only: skip gradient tracking
    probabilities = torch.softmax(model(x), dim=1)
  predicted_class = torch.argmax(probabilities, dim=1)
  return predicted_class, probabilities

In [76]:
# Collect true and predicted labels for the whole test loader as plain Python
# ints (rather than 0-dim tensors) so they can be handed straight to sklearn.
true_y, pred_y = [], []
for x, y in tqdm.tqdm(test_dl, total=len(test_dl)):
  # Only the class indices are needed here; the probabilities are ignored.
  preds, _probs = get_prediction(x.to(device), model)
  true_y.extend(y.tolist())
  pred_y.extend(preds.cpu().tolist())

100%|██████████| 157/157 [00:01<00:00, 100.84it/s]


In [77]:
# Accuracy over the labels and predictions gathered in the evaluation loop above.
print(accuracy_score(true_y, pred_y))

0.9866
