In [1]:

# Imports
import torch
import torchvision # torch package for vision related things
import torch.nn.functional as F  # Parameterless functions, like (some) activation functions
import torchvision.datasets as datasets  # Standard datasets
import torchvision.transforms as transforms  # Transformations we can perform on our dataset for augmentation
import torchvision.transforms as transforms

from torch import optim  # For optimizers like SGD, Adam, etc.
from torch import nn  # All neural network modules
from torch.utils.data import DataLoader  # Gives easier dataset managment by creating mini batches etc.
from tqdm import tqdm  # For nice progress bar!

# Hand PyTorch's cached-but-unreferenced CUDA blocks back to the driver.
# The guard makes it explicit that nothing happens on a CPU-only machine.
if torch.cuda.is_available():
    torch.cuda.empty_cache()



In [2]:
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

cuda


In [3]:


class ResBlock(nn.Module):
    """Basic two-convolution residual block.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``.
    The second ReLU is applied only *after* the residual addition. If it
    were applied to the residual branch before the addition, the branch
    could only ever be non-negative.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by the block.
        stride: stride of the first 3x3 convolution and of the projection
            shortcut. ``None`` (default) infers it as before: 2 when
            ``out_channels > in_channels`` (the "dotted" shortcut in the
            paper), otherwise 1.

    The shortcut is a 1x1 convolution + BatchNorm projection whenever the
    stride or the channel count changes, so both paths always have the
    same shape. Otherwise it is the identity.
    """

    def __init__(self, in_channels, out_channels, stride=None):
        super(ResBlock, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        if stride is None:
            # A channel increase marks the first block of a new stage -> downsample.
            stride = 2 if out_channels > in_channels else 1
        self.first_stride = stride

        if stride != 1 or in_channels != out_channels:  ## The dotted case in the paper
            # Projection shortcut: reshape the skip path to match the main path.
            self.skip_connection = nn.Sequential(
                nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(out_channels))
        else:                                           ## The usual case in the paper
            # Identity shortcut: an empty Sequential returns its input unchanged.
            self.skip_connection = nn.Sequential()

        self.conv1 = nn.Conv2d(in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=3, stride=self.first_stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)

        self.conv2 = nn.Conv2d(in_channels=self.out_channels, out_channels=self.out_channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, x):
        """Apply the residual block; output has ``out_channels`` channels."""
        shortcut = self.skip_connection(x)

        out = F.relu(self.bn1(self.conv1(x)))
        # No ReLU here: the activation belongs after the residual addition.
        out = self.bn2(self.conv2(out))
        return F.relu(out + shortcut)



In [4]:
class ResNet(nn.Module):
    """ResNet-34-style image classifier built from ``ResBlock`` basic blocks.

    Layout:
    - a stem (7x7 stride-2 conv, 3x3 stride-2 max-pool, BatchNorm, ReLU);
    - 16 residual blocks ``layer1`` ... ``layer16`` in four stages of
      3/4/6/3 blocks with 64/128/256/512 channels;
    - global average pooling and a linear classifier.

    The layout follows the 224x224 network, but the head uses
    ``AdaptiveAvgPool2d(1)``, so other input sizes are accepted too. For
    example, a 28x28 input is reduced to 1x1 by the five stride-2 steps
    (stem conv, max-pool, first block of stages 2-4).

    Args:
        in_channels: number of channels in the input images.
        num_classes: number of output logits produced by ``fc``.
    """

    # (in_channels, out_channels) for layer1 ... layer16, in order. A channel
    # increase marks the first block of a stage, which ResBlock downsamples.
    _BLOCK_CHANNELS = (
        (64, 64), (64, 64), (64, 64),
        (64, 128), (128, 128), (128, 128), (128, 128),
        (128, 256), (256, 256), (256, 256), (256, 256), (256, 256), (256, 256),
        (256, 512), (512, 512), (512, 512),
    )

    def __init__(self, in_channels=3, num_classes=10):
        super(ResNet, self).__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes

        # NOTE(review): max-pool runs before BatchNorm here, whereas torchvision's
        # stem is conv -> BN -> ReLU -> max-pool. Kept as-is so the layer0.*
        # parameter names of existing checkpoints still match.
        self.layer0 = nn.Sequential(
            nn.Conv2d(in_channels=self.in_channels, out_channels=64, kernel_size=7, stride=2, padding=3),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64)
        )

        # Register the blocks under the original attribute names layer1..layer16
        # (in the same order) so state_dict keys are unchanged.
        for idx, (c_in, c_out) in enumerate(self._BLOCK_CHANNELS, start=1):
            setattr(self, f"layer{idx}", ResBlock(c_in, c_out))

        self.avgpool = nn.AdaptiveAvgPool2d(1)
        # Classifier width follows the channel count of the last block (512).
        self.fc = nn.Linear(self._BLOCK_CHANNELS[-1][1], self.num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes).

        ``x`` is assumed to be (batch, in_channels, H, W) -- TODO confirm with
        the data loader.
        """
        x = F.relu(self.layer0(x))
        for idx in range(1, len(self._BLOCK_CHANNELS) + 1):
            x = getattr(self, f"layer{idx}")(x)
        # No ReLU after pooling: every ResBlock already ends with a ReLU, so the
        # averaged features are non-negative and an extra F.relu is a no-op.
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)
        
        





In [5]:
# Single-channel input images, 10 output classes; parameters live on `device`.
model = ResNet(in_channels=1, num_classes=10).to(device)