<a href="https://colab.research.google.com/github/DUNGTK2004/deep-models-from-scratch/blob/main/resnet_from_scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Import dependencies

In [1]:
import torch
import torchvision
import torchvision.datasets as datasets
from torchvision.transforms import ToTensor
import torchvision.models as models
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np
import time
import tqdm

In [2]:
print(torch.__version__)

2.5.1+cu124


### Load data

In [2]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: x.repeat(3, 1, 1)),])
batch_size = 128

In [3]:
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

In [4]:
mnist_trainloader = torch.utils.data.DataLoader(mnist_trainset, batch_size=batch_size, shuffle=True, num_workers=2)
mnist_testloader = torch.utils.data.DataLoader(mnist_testset, batch_size=batch_size, shuffle=False, num_workers=2)

Visualize data

In [None]:
# If error, only need change 3 -> 1 in transforms
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
  sample_idx = torch.randint(len(mnist_trainset), size=(1,)).item()
  img, label = mnist_trainset[sample_idx]
  figure.add_subplot(rows, cols, i)
  plt.title(label)
  plt.axis("off")
  plt.imshow(img.squeeze(), cmap="gray")

In [42]:
# Model
net = models.resnet18(pretrained=False)
net.fc = nn.Linear(512, 10, bias=True)
print(net)

### Resnet from scratch

In [76]:
class BasicBlock(nn.Module):
  def __init__(self, in_channels, channels, stride=1):
    super(BasicBlock, self).__init__()
    self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=3, stride=stride, padding=1)
    self.bn1 = nn.BatchNorm2d(channels)
    self.relu1 = nn.ReLU(inplace=True)
    self.conv2 = nn.Conv2d(channels, channels, stride=1, kernel_size=3, padding=1)
    self.bn2 = nn.BatchNorm2d(channels)
    self.relu2 = nn.ReLU(inplace=True)
    self.shortcut = nn.Sequential()
    if in_channels != channels or stride == 2:
      self.shortcut = nn.Sequential(
          nn.Conv2d(in_channels, channels, kernel_size=1, stride=2),
          nn.BatchNorm2d(channels)
      )

  def forward(self, x):
    out = self.conv1(x)
    out = self.bn1(out)
    out = self.relu1(out)
    out = self.conv2(out)
    out = self.bn2(out)
    out += self.shortcut(x)
    out = self.relu2(out)
    return out


class resnetModel(nn.Module):
  def __init__(self, basic_block, num_blocks, zero_init_residual=False):
    super(resnetModel, self).__init__()
    self.in_channel = 64
    self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3)
    self.bn1 = nn.BatchNorm2d(64)
    self.relu1 = nn.ReLU(inplace=True)
    self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
    self.layer1 = self._make_layer(1, num_blocks[0], 64, basic_block)
    self.layer2 = self._make_layer(2, num_blocks[1], 128, basic_block)
    self.layer3 = self._make_layer(2, num_blocks[2], 256, basic_block)
    self.layer4 = self._make_layer(2, num_blocks[3], 512, basic_block)
    self.averagePooling = nn.AdaptiveAvgPool2d((1, 1))
    self.fc = nn.Linear(in_features=512, out_features=10)

  def _make_layer(self, stride_first, num_block, channel, basic_block):
    strides = [stride_first] + [1] * (num_block-1)
    layer = []
    for i in range(num_block):
      stride = strides[i]
      layer.append(basic_block(self.in_channel, channel, stride))
      self.in_channel = channel
    return nn.Sequential(*layer)

  def forward(self, x):
    x = self.conv1(x)
    x = self.bn1(x)
    x = F.relu(x)
    x = self.maxpool(x)
    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)
    x = self.averagePooling(x)
    x = x.view(x.size(0), -1)
    x = self.fc(x)
    return x

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.xavier_uniform_(m.weight)
    elif classname.find('Linear') != -1:
        torch.nn.init.xavier_uniform_(m.weight)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.01)

model = resnetModel(BasicBlock, [2, 2, 2, 2])
model.apply(weights_init)


resnetModel(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu1): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu1): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu2): ReLU(inplace=True)
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu

In [13]:
total_params = sum(p.numel() for p in net.parameters())
print(f"Total parameters: {total_params}")

In [40]:
def test_accuracy(data_iter, net):
  acc_sum, n = 0, 0
  for (imgs, labels) in data_iter:
    # send data to the GPU if cuda is available
    if torch.cuda.is_available():
      imgs = imgs.cuda()
      labels = labels.cuda()
    net.eval()
    with torch.no_grad():
      labels = labels.long()
      acc_sum += torch.sum((torch.argmax(net(imgs), dim=1) == labels)).float()
      n += labels.shape[0]

  return acc_sum.item()/n

Train / test

In [54]:
if torch.cuda.is_available():
 model =model.cuda()

opt = optim.SGD(model.parameters(), lr=0.01, weight_decay=0.0005)

criterion = nn.CrossEntropyLoss()

for epoch in range(0, 10):
  n, start = 0, time.time()
  train_l_sum = torch.tensor([0.0], dtype=torch.float32)
  train_acc_sum = torch.tensor([0.0], dtype=torch.float32)
  for i, (imgs, labels) in tqdm.tqdm(enumerate(mnist_trainloader)):
    model.train()
    # If training on GPU
    if torch.cuda.is_available():
      imgs = imgs.cuda()
      labels = labels.cuda()
      train_l_sum = train_l_sum.cuda()
      train_acc_sum = train_acc_sum.cuda()

    opt.zero_grad()

    # loss function
    output =model(imgs)

    loss = criterion(output, labels)
    loss.backward()
    opt.step()

    # Calculate training error
    model.eval()
    labels = labels.long()
    train_l_sum += loss.float()
    train_acc_sum += (torch.sum((torch.argmax(output, dim=1) == labels))).float()
    n += labels.shape[0]
  test_acc = test_accuracy(iter(mnist_testloader),model)
  print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec' \
        % (epoch + 1, train_l_sum/n, train_acc_sum/n, test_acc, time.time()-start))

469it [00:12, 36.15it/s]


epoch 1, loss 0.0011, train acc 0.959, test acc 0.984, time 15.1 sec


469it [00:13, 35.99it/s]


epoch 2, loss 0.0003, train acc 0.989, test acc 0.986, time 14.8 sec


469it [00:12, 36.26it/s]


epoch 3, loss 0.0001, train acc 0.995, test acc 0.989, time 14.6 sec


469it [00:12, 36.61it/s]


epoch 4, loss 0.0001, train acc 0.998, test acc 0.989, time 14.4 sec


469it [00:12, 36.26it/s]


epoch 5, loss 0.0000, train acc 0.999, test acc 0.990, time 14.6 sec


469it [00:12, 36.49it/s]


epoch 6, loss 0.0000, train acc 1.000, test acc 0.991, time 14.5 sec


469it [00:12, 36.31it/s]


epoch 7, loss 0.0000, train acc 1.000, test acc 0.991, time 15.2 sec


469it [00:12, 36.48it/s]


epoch 8, loss 0.0000, train acc 1.000, test acc 0.991, time 14.5 sec


469it [00:12, 36.28it/s]


epoch 9, loss 0.0000, train acc 1.000, test acc 0.990, time 14.6 sec


469it [00:12, 36.60it/s]


epoch 10, loss 0.0000, train acc 1.000, test acc 0.991, time 14.4 sec


### Pretrained Resnet

In [30]:
# print(models.resnet18())
class ResNetFeatrueExtractor18(nn.Module):
    def __init__(self, pretrained = True):
        super(ResNetFeatrueExtractor18, self).__init__()
        model_resnet18 = models.resnet18(pretrained=pretrained)
        self.conv1 = model_resnet18.conv1
        self.bn1 = model_resnet18.bn1
        self.relu = model_resnet18.relu
        self.maxpool = model_resnet18.maxpool
        self.layer1 = model_resnet18.layer1
        self.layer2 = model_resnet18.layer2
        self.layer3 = model_resnet18.layer3
        self.layer4 = model_resnet18.layer4
        self.avgpool = model_resnet18.avgpool


    def forward(self, x):
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)

        return x

class ResClassifier(nn.Module):
    def __init__(self, dropout_p=0.5): #in_features=512
        super(ResClassifier, self).__init__()
        self.fc = nn.Linear(512, 10)
    def forward(self, x):
        out = self.fc(x)
        return out

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.xavier_uniform_(m.weight)
    elif classname.find('Linear') != -1:
        torch.nn.init.xavier_uniform_(m.weight)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.01)

# calculate test accuracy
def test_accuracy(data_iter, netG, netF):
    """Evaluate testset accuracy of a model."""
    acc_sum,n = 0,0
    for (imgs, labels) in data_iter:
        # send data to the GPU if cuda is availabel
        if torch.cuda.is_available():
            imgs = imgs.cuda()
            labels = labels.cuda()
        netG.eval()
        netF.eval()
        with torch.no_grad():
            labels = labels.long()
            acc_sum += torch.sum((torch.argmax(netF(netG(imgs)), dim=1) == labels)).float()
            n += labels.shape[0]
    return acc_sum.item()/n

In [None]:
netG = ResNetFeatrueExtractor18(pretrained = True)
netF = ResClassifier()

if torch.cuda.is_available():
    netG = netG.cuda()
    netF = netF.cuda()

# setting up optimizer for both feature generator G and classifier F.
opt_g = optim.SGD(netG.parameters(), lr=0.01, weight_decay=0.0005)
opt_f = optim.SGD(netF.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0005)

# loss function
criterion = nn.CrossEntropyLoss()

for epoch in range(0, 10):
    n, start = 0, time.time()
    train_l_sum = torch.tensor([0.0], dtype=torch.float32)
    train_acc_sum = torch.tensor([0.0], dtype=torch.float32)
    for i, (imgs, labels) in tqdm.tqdm(enumerate(iter(mnist_trainloader))):
        netG.train()
        netF.train()

        # train on GPU if possible
        if torch.cuda.is_available():
            imgs = imgs.cuda()
            labels = labels.cuda()
            train_l_sum = train_l_sum.cuda()
            train_acc_sum = train_acc_sum.cuda()

        opt_g.zero_grad()
        opt_f.zero_grad()

        # extracted feature
        bottleneck = netG(imgs)

        # predicted labels
        label_hat = netF(bottleneck)

        # loss function
        loss= criterion(label_hat, labels)
        loss.backward()
        opt_g.step()
        opt_f.step()

        # calcualte training error
        netG.eval()
        netF.eval()
        labels = labels.long()
        train_l_sum += loss.float()
        train_acc_sum += (torch.sum((torch.argmax(label_hat, dim=1) == labels))).float()
        n += labels.shape[0]
    test_acc = test_accuracy(iter(mnist_testloader), netG, netF)
    print('epoch %d, loss %.4f, train acc %.3f, test acc %.3f, time %.1f sec'\
        % (epoch + 1, train_l_sum/n, train_acc_sum/n, test_acc, time.time() - start))


469it [00:13, 35.55it/s]


epoch 1, loss 0.0012, train acc 0.954, test acc 0.986, time 15.0 sec


469it [00:13, 33.69it/s]


epoch 2, loss 0.0003, train acc 0.989, test acc 0.990, time 15.8 sec


469it [00:13, 35.84it/s]


epoch 3, loss 0.0002, train acc 0.993, test acc 0.990, time 14.8 sec


469it [00:14, 33.31it/s]


epoch 4, loss 0.0001, train acc 0.996, test acc 0.992, time 15.7 sec


469it [00:13, 35.78it/s]


epoch 5, loss 0.0001, train acc 0.997, test acc 0.991, time 14.9 sec


469it [00:13, 34.44it/s]


epoch 6, loss 0.0001, train acc 0.998, test acc 0.991, time 15.8 sec


469it [00:13, 33.59it/s]


epoch 7, loss 0.0000, train acc 0.998, test acc 0.992, time 15.7 sec


469it [00:13, 34.82it/s]


epoch 8, loss 0.0000, train acc 0.999, test acc 0.992, time 15.2 sec


469it [00:13, 35.01it/s]


epoch 9, loss 0.0000, train acc 0.999, test acc 0.992, time 15.1 sec


469it [00:13, 33.79it/s]


epoch 10, loss 0.0000, train acc 0.999, test acc 0.993, time 15.6 sec


### Lightning

In [75]:
!pip install pytorch-lightning

In [7]:
import pytorch_lightning as pl

In [77]:
class ResNetMNIST(pl.LightningModule):
  def __init__(self):
    super().__init__()
    self.model = model
    self.loss = nn.CrossEntropyLoss()

  def forward(self, x):
    return self.model(x)

  def training_step(self, batch, batch_idx):
    x, y = batch
    logits = self(x)
    loss = self.loss(logits, y)
    self.log("train_loss", loss, on_epoch=True, prog_bar=True)
    return loss

  def validation_step(self, batch, batch_idx):
    x, y = batch
    logits = self(x)
    val_loss = self.loss(logits, y)
    self.log("val_loss", val_loss, on_epoch=True, prog_bar=True)
    return val_loss

  def configure_optimizers(self):
    return optim.SGD(model.parameters(), lr=0.01, weight_decay=0.0005)

In [78]:
model_lightning = ResNetMNIST()

In [79]:
trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    max_epochs=10,
    enable_progress_bar=True
)

INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [80]:
for param in model_lightning.parameters():
    param.requires_grad = True  # Bật lại trainable params

In [81]:
trainer.fit(model_lightning, mnist_trainloader, mnist_testloader)

INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name  | Type             | Params | Mode 
---------------------------------------------------
0 | model | resnetModel      | 11.2 M | train
1 | loss  | CrossEntropyLoss | 0      | train
---------------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.746    Total estimated model params size (MB)
82        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=10` reached.


In [82]:
trainer.save_checkpoint("resnet_mnist.pt")

In [83]:
def get_prediction(x, model):
  model.freeze()
  prob = torch.softmax(model(x), dim=1)
  pred_class = torch.argmax(prob, dim=1)
  return pred_class, prob

In [84]:
from tqdm.autonotebook import tqdm

In [85]:
inference_model = ResNetMNIST.load_from_checkpoint("resnet_mnist.pt")

In [86]:
true_y, pred_y = [], []
for batch in tqdm(iter(mnist_testloader), total=len(mnist_testloader)):
  x, y = batch
  true_y.extend(y)
  preds, probs = get_prediction(x, inference_model)
  pred_y.extend(preds.cpu())

  0%|          | 0/79 [00:00<?, ?it/s]

In [87]:
from sklearn.metrics import classification_report

In [88]:
print(classification_report(true_y, pred_y, digits=3))

              precision    recall  f1-score   support

           0      0.992     0.995     0.993       980
           1      0.996     0.997     0.997      1135
           2      0.988     0.992     0.990      1032
           3      0.992     0.993     0.993      1010
           4      0.992     0.992     0.992       982
           5      0.985     0.990     0.988       892
           6      0.995     0.990     0.992       958
           7      0.991     0.988     0.990      1028
           8      0.992     0.989     0.990       974
           9      0.989     0.987     0.988      1009

    accuracy                          0.991     10000
   macro avg      0.991     0.991     0.991     10000
weighted avg      0.991     0.991     0.991     10000

