This notebook demonstrates the effectiveness of residual (ResNet-style) connections in deep CNNs by achieving high accuracy even on small images.

In [None]:
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from PIL import Image 

In [None]:
# Prefer the first CUDA GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
  device = torch.device("cuda:0")
else:
  device = torch.device("cpu")

In [None]:
# Per-image preprocessing: resize to 28x28, then convert to a tensor.
transform = transforms.Compose(
  [transforms.Resize((28,28)), transforms.ToTensor()]
)

# MNIST training and held-out splits, stored under ./data (downloaded if requested files are missing).
train_data = datasets.MNIST(transform=transform, download=True, train=True, root='./data')
val_data = datasets.MNIST(transform=transform, download=True, train=False, root='./data')

# Mini-batches of 60 samples; only the training loader shuffles.
train_load = torch.utils.data.DataLoader(dataset=train_data, shuffle=True, batch_size=60)
val_load = torch.utils.data.DataLoader(dataset=val_data, shuffle=False, batch_size=60)

In [None]:
def convert(image_tensor):
  """Return an independent NumPy copy of `image_tensor`.

  The tensor is detached from the autograd graph and cloned, so the returned
  array never shares memory with the input. It is also moved to host memory
  first, because `Tensor.numpy()` only works on CPU tensors; without that step
  a tensor living on a CUDA device would raise.

  Args:
    image_tensor: a `torch.Tensor` of any shape, on any device.

  Returns:
    A `numpy.ndarray` with the same shape and values as `image_tensor`.
  """
  image = image_tensor.detach().clone().cpu().numpy()
  return image

In [None]:
# Pull one mini-batch from the training loader and display it as a labelled grid.
data_iteration = iter(train_load)
# Use the built-in next(): it works with any iterator, whereas the iterator's
# own `.next()` method is not available on current PyTorch DataLoader iterators.
images, labels = next(data_iteration)
figure = plt.figure(figsize = (20,20))

# The 6 x 10 grid has room for 60 images; cap the count at the batch length so
# a shorter batch cannot raise an IndexError.
for index in range(min(len(images), 60)):
  axis = figure.add_subplot(6, 10, index + 1, xticks = [], yticks = [])
  # Draw on the subplot that was just created rather than the implicit "current" axes.
  axis.imshow(np.squeeze(convert(images[index])))
  axis.set_title([labels[index].item()])

In [None]:
class resnet(nn.Module):
  """Deep residual CNN classifier built from stacked 4-convolution blocks.

  Layout of every block (block `b`, counted from 0, owns layers `4*b + 1`
  through `4*b + 4`):

    * an unpadded 3x3 "entry" convolution that changes the channel count and
      trims one pixel from each border, followed by ReLU; this activation is
      kept as the skip connection,
    * batch norm, then two padded 3x3 convolutions, each followed by ReLU and
      batch norm,
    * a third padded 3x3 convolution + ReLU whose output is summed with the
      skip connection before the block's last batch norm.

  Layers are registered as `conv1..convN` and `batch_norm_1..batch_norm_N`
  (N = 4 * number of blocks), so every convolution has its own batch-norm
  layer with a matching channel count.

  After the last block the feature map is adaptively average-pooled to 4x4,
  flattened, and passed through five fully connected layers with dropout
  between them. With the default widths the flattened size is
  512 * 4 * 4 = 4 * 2048. The pooling step makes the classifier input size
  independent of the image resolution.

  Args:
    in_channels: number of channels of the input images (default 1).
    num_classes: number of output scores per sample (default 10).
    block_channels: output channel count of each residual block, in order.
      The default gives 8 blocks / 32 convolutions.

  Raises:
    ValueError: if `block_channels` is empty.
  """

  # Side length of the pooled feature map that feeds the linear layers.
  _POOLED_SIZE = 4

  def __init__(self, in_channels=1, num_classes=10,
               block_channels=(16, 32, 64, 128, 256, 256, 512, 512)):
    super().__init__()

    block_channels = tuple(block_channels)
    if not block_channels:
      raise ValueError("block_channels must contain at least one width")
    self.num_blocks = len(block_channels)

    # Residual blocks. setattr() on an nn.Module registers the layer as a
    # submodule, exactly like a plain `self.convK = ...` assignment would.
    previous = in_channels
    for block, width in enumerate(block_channels):
      first = 4 * block + 1

      # Entry convolution: no padding, so height and width shrink by 2.
      setattr(self, "conv{}".format(first),
              nn.Conv2d(previous, width, kernel_size = (3,3), stride = (1,1)))
      setattr(self, "batch_norm_{}".format(first), nn.BatchNorm2d(width))

      # Three size-preserving convolutions inside the block.
      for index in range(first + 1, first + 4):
        setattr(self, "conv{}".format(index),
                nn.Conv2d(width, width, kernel_size = (3,3), stride = (1,1), padding = (1,1)))
        setattr(self, "batch_norm_{}".format(index), nn.BatchNorm2d(width))

      previous = width

    # Fixes the spatial size seen by linear_1 regardless of the input resolution.
    self.pool = nn.AdaptiveAvgPool2d((self._POOLED_SIZE, self._POOLED_SIZE))

    # Linear Block
    self.linear_1 = nn.Linear(previous * self._POOLED_SIZE * self._POOLED_SIZE, 2048)
    self.dropout_1 = nn.Dropout(0.5)
    self.linear_2 = nn.Linear(2048,512)
    self.dropout_2 = nn.Dropout(0.5)
    self.linear_3 = nn.Linear(512, 128)
    self.dropout_3 = nn.Dropout(0.5)
    self.linear_4 = nn.Linear(128, 32)
    self.dropout_4 = nn.Dropout(0.5)
    self.linear_5 = nn.Linear(32, num_classes)

  def forward(self, x):
    """Compute class scores for a batch of images.

    Args:
      x: tensor of shape (batch, in_channels, height, width). Every block
        trims 2 pixels from height and width, so both must be larger than
        2 * number of blocks (at least 17 for the default 8 blocks).

    Returns:
      Tensor of shape (batch, num_classes) holding unnormalised scores.
    """
    for block in range(self.num_blocks):
      first = 4 * block + 1

      x = F.relu(getattr(self, "conv{}".format(first))(x))
      shortcut = x  # skip connection, taken before the first batch norm
      x = getattr(self, "batch_norm_{}".format(first))(x)

      for index in range(first + 1, first + 3):
        x = F.relu(getattr(self, "conv{}".format(index))(x))
        x = getattr(self, "batch_norm_{}".format(index))(x)

      last = first + 3
      x = F.relu(getattr(self, "conv{}".format(last))(x))
      # The residual sum happens before the block's final normalisation.
      x = getattr(self, "batch_norm_{}".format(last))(x + shortcut)

    # Linear Block
    x = self.pool(x)
    # Flatten per sample; keeping the batch dimension explicit means a shape
    # mismatch can never be silently folded into the batch axis.
    x = x.reshape(x.size(0), -1)
    x = F.relu(self.linear_1(x))
    x = self.dropout_1(x)
    x = F.relu(self.linear_2(x))
    x = self.dropout_2(x)
    x = F.relu(self.linear_3(x))
    x = self.dropout_3(x)
    x = F.relu(self.linear_4(x))
    x = self.dropout_4(x)
    x = self.linear_5(x)

    return x

  def forward_prop(self, x):
    """Alias of `forward`, kept so existing `forward_prop` callers keep working."""
    return self.forward(x)

In [None]:
# Build the network, then move it onto the selected device.
model = resnet()
model = model.to(device)
# Bare expression: as the last line of a notebook cell it displays the module.
model

In [None]:
# Optimizer (Adam over every model parameter, learning rate 1e-3) and the
# cross-entropy loss, both consumed by the training loop.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

In [None]:
def _run_epoch(model, loader, criterion, device, optimizer=None):
  """Run one full pass over `loader` and return `(mean_loss, accuracy)`.

  When `optimizer` is given the model is put in training mode and updated
  after every batch. Otherwise the model is put in evaluation mode (dropout
  off, batch norm using running statistics) and gradients are disabled.

  Args:
    model: the network to train or evaluate.
    loader: iterable of `(inputs, labels)` batches; must support `len()`.
    criterion: loss function called as `criterion(outputs, labels)`.
    device: device the batches are moved to before the forward pass.
    optimizer: optimizer to step, or None for a pure evaluation pass.

  Returns:
    mean_loss: sum of per-batch losses divided by the number of batches.
    accuracy: fraction of samples whose highest-scoring class equals the label.
  """
  training = optimizer is not None
  model.train(training)

  loss_sum = 0.0
  hits = 0
  samples = 0

  with torch.set_grad_enabled(training):
    for inputs, labels in loader:
      inputs = inputs.to(device)
      labels = labels.to(device)
      outputs = model(inputs)
      loss = criterion(outputs, labels)

      if training:
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

      _, prediction = torch.max(outputs, 1)
      # .item() converts to plain Python numbers so no tensors (or autograd
      # graphs) are kept alive across batches.
      loss_sum += loss.item()
      hits += torch.sum(prediction == labels).item()
      samples += labels.size(0)

  # max(..., 1) keeps an empty loader from causing a division by zero.
  return loss_sum / max(len(loader), 1), hits / max(samples, 1)


epochs = 100
# Per-epoch history: mean batch loss and accuracy (fraction in [0, 1]) as floats.
train_loss_h = []
train_hits_h = []
val_loss_h = []
val_hits_h = []

for i in range(epochs):

  loss_per_epoch, acc_per_epoch = _run_epoch(model, train_load, criterion, device, optimizer)
  val_loss_per_epoch, val_acc_per_epoch = _run_epoch(model, val_load, criterion, device)

  train_loss_h.append(loss_per_epoch)
  train_hits_h.append(acc_per_epoch)
  val_loss_h.append(val_loss_per_epoch)
  val_hits_h.append(val_acc_per_epoch)

  print('epoch :', (i+1))
  print('training loss: {:.4f}, acc {:.4f} '.format(loss_per_epoch, acc_per_epoch))
  print('validation loss: {:.4f}, validation acc {:.4f} '.format(val_loss_per_epoch, val_acc_per_epoch))
