<a href="https://colab.research.google.com/github/VishakBharadwaj94/Resnets_from_scratch/blob/master/CIFAR_Dataset_with_ResNet_from_Scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import torch
from torch import nn

In [0]:
class ResBlock(nn.Module):     
  expansion = 1
  def __init__(self, in_channels, out_channels, stride=1):
    super(ResBlock, self).__init__()

    # Conv Layer 1
    self.conv1 = nn.Conv2d(
        in_channels=in_channels, out_channels=out_channels,
        kernel_size=(3, 3), stride=stride, padding=1, bias=False
    )
    self.bn1 = nn.BatchNorm2d(out_channels)

    # Conv Layer 2
    self.conv2 = nn.Conv2d(
        in_channels=out_channels, out_channels=out_channels,
        kernel_size=(3, 3), stride=1, padding=1, bias=False
    )
    self.bn2 = nn.BatchNorm2d(out_channels)

    self.shortcut = nn.Sequential()
    if stride != 1 or in_channels != out_channels:
      self.shortcut = nn.Sequential(
          nn.Conv2d(
              in_channels=in_channels, out_channels=out_channels,
              kernel_size=(1, 1), stride=stride, bias=False
          ),
          nn.BatchNorm2d(out_channels)
      )

  def forward(self, x):
    out = nn.ReLU()(self.bn1(self.conv1(x)))
    out = self.bn2(self.conv2(out))
    out += self.shortcut(x)
    out = nn.ReLU()(out)
    return out

In [0]:
class ResNet(nn.Module):

  def __init__(self,classes=10):

    super(ResNet,self).__init__()

    self.conv1 = nn.Conv2d(in_channels=3,out_channels=64,kernel_size=(3,3),stride=1, padding=1,bias=False)
    self.bn1 = nn.BatchNorm2d(64)
    self.rel1 = nn.ReLU()

    self.block1 = self.create_block(64,64,1)
    self.block2 = self.create_block(64,128,2)
    self.block3 = self.create_block(128,256,2)
    self.block4 = self.create_block(256,512,2)

    self.lin = nn.Linear(512,classes)

    #1 block has two reblocks

  def create_block(self, in_channels, out_channels, stride):
    
    net = nn.Sequential(
        ResBlock(in_channels, out_channels, stride),
        ResBlock(out_channels, out_channels, 1)
    )
    return net 
  # def create_block(self,in,out,stride):
  #   return nn.Sequential(
  #   ResBlock(in, out, stride),
  #   ResBlock(out,out, 1))
        
  def forward(self,x):
    x = self.rel1(self.bn1(self.conv1(x)))
    x = nn.AvgPool2d(4)(self.block4(self.block3(self.block2(self.block1(x)))))
    x = x.view(x.size(0),-1) #conversion from 3D to 2D
    x = self.lin(x)

    return x


In [4]:
#downloading the data 
!wget http://pjreddie.com/media/files/cifar.tgz
!tar xzf cifar.tgz

--2019-12-17 17:25:16--  http://pjreddie.com/media/files/cifar.tgz
Resolving pjreddie.com (pjreddie.com)... 128.208.4.108
Connecting to pjreddie.com (pjreddie.com)|128.208.4.108|:80... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: https://pjreddie.com/media/files/cifar.tgz [following]
--2019-12-17 17:25:16--  https://pjreddie.com/media/files/cifar.tgz
Connecting to pjreddie.com (pjreddie.com)|128.208.4.108|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 168584360 (161M) [application/octet-stream]
Saving to: ‘cifar.tgz’


2019-12-17 17:25:26 (18.2 MB/s) - ‘cifar.tgz’ saved [168584360/168584360]



In [0]:
#preprocessing
import numpy as np
import random
def preprocess(image):
    image = np.array(image)
    
    if random.random() > 0.5:
        image = image[::-1,:,:]
    
    cifar_mean = np.array([0.4914, 0.4822, 0.4465]).reshape(1,1,-1)
    cifar_std  = np.array([0.2023, 0.1994, 0.2010]).reshape(1,1,-1)
    image = (image - cifar_mean) / cifar_std
    
    image = image.transpose(2,1,0)
    return image

In [0]:
from pathlib import Path
data = Path('./cifar')
train = data/'train'
test = data/'test'

In [7]:
train,test

(PosixPath('cifar/train'), PosixPath('cifar/test'))

In [0]:
with open(data/'labels.txt','r') as f:
  labels = f.read().split()
  label_mapping = dict(zip(labels, list(range(len(labels)))))
  

In [9]:
label_mapping

{'airplane': 0,
 'automobile': 1,
 'bird': 2,
 'cat': 3,
 'deer': 4,
 'dog': 5,
 'frog': 6,
 'horse': 7,
 'ship': 8,
 'truck': 9}

In [10]:
import os

os.listdir(train)[0]

'8619_horse.png'

In [0]:
#creating the dataset
#try later without creating a list from the files
from PIL import Image
import os

class Cifar10Dataset(torch.utils.data.Dataset):

  def __init__(self,train_path=train,labels=label_mapping,transform=None):
    files = os.listdir(train_path)
    files = [os.path.join(train_path,x) for x in files]

    self.files = files
    self.transform = transform

  def __len__(self):

    return len(self.files)

  def __getitem__(self,idx):
    image = Image.open(self.files[idx])
    image = preprocess(image)
    image = image.astype(np.float32)
    label = label_mapping[self.files[idx].split('_')[-1].split('.')[0]]

    if self.transform:
      image = self.tranform(image)

    return (image,label)




In [0]:
dataset = Cifar10Dataset()

In [0]:
trainloader = torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=True, num_workers=2)


testset = Cifar10Dataset(data/"test", transform=None)
testloader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=True, num_workers=2)

In [14]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")     #Check whether a GPU is present.

clf = ResNet()
clf.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (rel1): ReLU()
  (block1): Sequential(
    (0): ResBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (shortcut): Sequential()
    )
    (1): ResBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum

In [0]:
from torch import optim
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(clf.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[150, 200], gamma=0.1)

In [17]:
import time
for epoch in range(5):
  losses = []
  scheduler.step()

  start = time.time()
  for inputs, targets in trainloader:
    inputs, targets = inputs.to(device), targets.to(device)

    optimizer.zero_grad()                 # Zero the gradients

    outputs = clf(inputs)                 # Forward pass
    loss = criterion(outputs, targets)    # Compute the Loss
    loss.backward()                       # Compute the Gradients

    optimizer.step()                      # Updated the weights
    losses.append(loss.item())
    end = time.time()



  # Evaluate
  clf.eval()
  total = 0
  correct = 0

  with torch.no_grad():
    for inputs, targets in testloader:
      inputs, targets = inputs.to(device), targets.to(device)

      outputs = clf(inputs)
      _, predicted = torch.max(outputs.data, 1)
      total += targets.size(0)
      correct += predicted.eq(targets.data).cpu().sum()

  print(f'Epoch : {epoch+4}  Loss: {loss} Test Acc : {100.*correct/total}')
  print('--------------------------------------------------------------')
  clf.train()    


Epoch : 4  Loss: 0.7643682956695557 Test Acc : 69.63999938964844
--------------------------------------------------------------
Epoch : 5  Loss: 0.7147625684738159 Test Acc : 67.52999877929688
--------------------------------------------------------------
Epoch : 6  Loss: 0.7407170534133911 Test Acc : 66.08999633789062
--------------------------------------------------------------
Epoch : 7  Loss: 0.4705318510532379 Test Acc : 72.6500015258789
--------------------------------------------------------------
Epoch : 8  Loss: 0.7329400777816772 Test Acc : 71.58000183105469
--------------------------------------------------------------
