<a href="https://colab.research.google.com/github/DUNGTK2004/deep-models-from-scratch/blob/main/resnet_multitask.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Import dependencies

In [2]:
import torch
import torchvision
import torchvision.datasets as datasets
from torchvision.transforms import ToTensor
import torchvision.models as models
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import numpy as np
import time
import tqdm

In [None]:
print(torch.__version__)

2.5.1+cu124


### Load data

In [7]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: x.repeat(3, 1, 1)),])
batch_size = 128

Customize MNIST dataset to add 1 label for datasets

In [8]:
class CustomMnist(Dataset):
  def __init__(self, root, train_bol, download, transform=None):
    self.mnist = datasets.MNIST(root=root, train=train_bol, download=download, transform=transform)
    self.curve = [0, 2, 3, 5, 6, 8, 9]

  def __len__(self):
    return len(self.mnist)

  def __getitem__(self, idx):
    image, label = self.mnist[idx]
    label_curve = 1 if label in self.curve else 0
    return image, label, label_curve


In [9]:
mnist_trainset = CustomMnist(root='./data', train_bol=True, download=True, transform=transform)
mnist_testset = CustomMnist(root='./data', train_bol=False, download=True, transform=transform)

100%|██████████| 9.91M/9.91M [00:00<00:00, 16.1MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 481kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 4.46MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 5.76MB/s]


In [10]:
mnist_trainloader = torch.utils.data.DataLoader(mnist_trainset, batch_size=batch_size, shuffle=True, num_workers=2)
mnist_testloader = torch.utils.data.DataLoader(mnist_testset, batch_size=batch_size, shuffle=False, num_workers=2)

Visualize data

In [None]:
# If error, only need change 3 -> 1 in transforms
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
  sample_idx = torch.randint(len(mnist_trainset), size=(1,)).item()
  img, label = mnist_trainset[sample_idx]
  figure.add_subplot(rows, cols, i)
  plt.title(label)
  plt.axis("off")
  plt.imshow(img.squeeze(), cmap="gray")

In [None]:
# Model
net = models.resnet18(pretrained=False)
net.fc = nn.Linear(512, 10, bias=True)
print(net)

### Resnet from scratch

In [14]:
class BasicBlock(nn.Module):
  def __init__(self, in_channels, channels, stride=1):
    super(BasicBlock, self).__init__()
    self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=3, stride=stride, padding=1)
    self.bn1 = nn.BatchNorm2d(channels)
    self.relu1 = nn.ReLU(inplace=True)
    self.conv2 = nn.Conv2d(channels, channels, stride=1, kernel_size=3, padding=1)
    self.bn2 = nn.BatchNorm2d(channels)
    self.relu2 = nn.ReLU(inplace=True)
    self.shortcut = nn.Sequential()
    if in_channels != channels or stride == 2:
      self.shortcut = nn.Sequential(
          nn.Conv2d(in_channels, channels, kernel_size=1, stride=2),
          nn.BatchNorm2d(channels)
      )

  def forward(self, x):
    out = self.conv1(x)
    out = self.bn1(out)
    out = self.relu1(out)
    out = self.conv2(out)
    out = self.bn2(out)
    out += self.shortcut(x)
    out = self.relu2(out)
    return out


class resnetModel(nn.Module):
  def __init__(self, basic_block, num_blocks, zero_init_residual=False):
    super(resnetModel, self).__init__()
    self.in_channel = 64
    self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=7, stride=2, padding=3)
    self.bn1 = nn.BatchNorm2d(64)
    self.relu1 = nn.ReLU(inplace=True)
    self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
    self.layer1 = self._make_layer(1, num_blocks[0], 64, basic_block)
    self.layer2 = self._make_layer(2, num_blocks[1], 128, basic_block)
    self.layer3 = self._make_layer(2, num_blocks[2], 256, basic_block)
    self.layer4 = self._make_layer(2, num_blocks[3], 512, basic_block)
    self.averagePooling = nn.AdaptiveAvgPool2d((1, 1))
    self.fc = nn.Linear(in_features=512, out_features=10)
    self.fc2 = nn.Linear(in_features=512, out_features=2)

  def _make_layer(self, stride_first, num_block, channel, basic_block):
    strides = [stride_first] + [1] * (num_block-1)
    layer = []
    for i in range(num_block):
      stride = strides[i]
      layer.append(basic_block(self.in_channel, channel, stride))
      self.in_channel = channel
    return nn.Sequential(*layer)

  def forward(self, x):
    x = self.conv1(x)
    x = self.bn1(x)
    x = F.relu(x)
    x = self.maxpool(x)
    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)
    x = self.averagePooling(x)
    x = x.view(x.size(0), -1)
    x1 = self.fc(x)
    x2 = self.fc2(x)

    return x1, x2

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        torch.nn.init.xavier_uniform_(m.weight)
    elif classname.find('Linear') != -1:
        torch.nn.init.xavier_uniform_(m.weight)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.01)

model = resnetModel(BasicBlock, [2, 2, 2, 2])
model.apply(weights_init)


resnetModel(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu1): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu1): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu2): ReLU(inplace=True)
      (shortcut): Sequential()
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu

In [16]:

total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params}")

Total parameters: 11187468


In [17]:
def test_accuracy(data_iter, net):
  acc_sum_lb, acc_sum_curve, n = 0, 0, 0
  for (imgs, labels, label_curves) in data_iter:
    # send data to the GPU if cuda is available
    if torch.cuda.is_available():
      imgs = imgs.cuda()
      labels = labels.cuda()
      label_curves = label_curves.cuda()
    net.eval()
    with torch.no_grad():
      labels = labels.long()
      label_curves = label_curves.long()
      out_lb, out_curve = net(imgs)
      acc_sum_lb += torch.sum((torch.argmax(out_lb, dim=1) == labels)).float()
      acc_sum_curve += torch.sum((torch.argmax(out_curve, dim=1) == label_curves)).float()

      n += labels.shape[0]

  return acc_sum_lb.item()/n, acc_sum_curve.item()/n

Train / test

In [21]:
if torch.cuda.is_available():
 model =model.cuda()

opt = optim.SGD(model.parameters(), lr=0.01, weight_decay=0.0005)

criterion1 = nn.CrossEntropyLoss()
criterion2 = nn.CrossEntropyLoss()

for epoch in range(0, 10):
  n, start = 0, time.time()
  train_l_sum = torch.tensor([0.0], dtype=torch.float32)
  train_acc_sum_lb = torch.tensor([0.0], dtype=torch.float32)
  train_acc_sum_curve = torch.tensor([0.0], dtype=torch.float32)
  for i, (imgs, labels, label_curves) in tqdm.tqdm(enumerate(mnist_trainloader)):
    model.train()
    # If training on GPU
    if torch.cuda.is_available():
      imgs = imgs.cuda()
      labels = labels.cuda()
      label_curves = label_curves.cuda()
      train_l_sum = train_l_sum.cuda()
      train_acc_sum_lb = train_acc_sum_lb.cuda()
      train_acc_sum_curve = train_acc_sum_curve.cuda()

    opt.zero_grad()

    # loss function
    output_lb, output_curve =model(imgs)

    loss1 = criterion1(output_lb, labels)
    loss2 = criterion2(output_curve, label_curves)
    loss = loss1 + loss2
    loss.backward()
    opt.step()

    # Calculate training error
    model.eval()
    labels = labels.long()
    train_l_sum += loss.float()
    train_acc_sum_lb += (torch.sum((torch.argmax(output_lb, dim=1) == labels))).float()
    train_acc_sum_curve += (torch.sum((torch.argmax(output_curve, dim=1) == label_curves))).float()

    n += labels.shape[0]
  test_acc_lb, test_acc_curve = test_accuracy(iter(mnist_testloader), model)
  print('epoch %d, loss %.4f, train acc lb %.3f, train_acc_curve %.3f, test acc lb %.3f, test_acc_curve %.3f, time %.1f sec' \
        % (epoch + 1, train_l_sum/n, train_acc_sum_lb/n, train_acc_sum_curve/n, test_acc_lb, test_acc_curve, time.time()-start))

469it [00:13, 35.95it/s]


epoch 1, loss 0.0002, train acc lb 0.994, train_acc_curve 0.998, test acc lb 0.985, test_acc_curve 0.993, time 14.7 sec


469it [00:13, 35.60it/s]


epoch 2, loss 0.0001, train acc lb 0.997, train_acc_curve 0.999, test acc lb 0.987, test_acc_curve 0.996, time 14.8 sec


469it [00:13, 36.02it/s]


epoch 3, loss 0.0001, train acc lb 0.998, train_acc_curve 0.999, test acc lb 0.990, test_acc_curve 0.995, time 14.7 sec


469it [00:13, 35.68it/s]


epoch 4, loss 0.0000, train acc lb 0.999, train_acc_curve 1.000, test acc lb 0.991, test_acc_curve 0.996, time 14.8 sec


469it [00:12, 36.24it/s]


epoch 5, loss 0.0000, train acc lb 1.000, train_acc_curve 1.000, test acc lb 0.990, test_acc_curve 0.997, time 15.3 sec


469it [00:13, 35.67it/s]


epoch 6, loss 0.0000, train acc lb 1.000, train_acc_curve 1.000, test acc lb 0.991, test_acc_curve 0.997, time 14.8 sec


469it [00:13, 35.97it/s]


epoch 7, loss 0.0000, train acc lb 1.000, train_acc_curve 1.000, test acc lb 0.989, test_acc_curve 0.994, time 14.7 sec


469it [00:13, 35.33it/s]


epoch 8, loss 0.0000, train acc lb 1.000, train_acc_curve 1.000, test acc lb 0.991, test_acc_curve 0.996, time 14.9 sec


469it [00:13, 35.88it/s]


epoch 9, loss 0.0000, train acc lb 1.000, train_acc_curve 1.000, test acc lb 0.991, test_acc_curve 0.997, time 14.7 sec


469it [00:13, 35.52it/s]


epoch 10, loss 0.0000, train acc lb 1.000, train_acc_curve 1.000, test acc lb 0.992, test_acc_curve 0.996, time 15.5 sec
