<a href="https://colab.research.google.com/github/akshayjain777/TSAI/blob/master/ResNet_Architecture.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F ## activation function
import torchvision ##contains dataset
import pdb

In [0]:
class baseBlock_layer_1_3(torch.nn.Module):
  
  def __init__(self, input_planes, planes, stride=1):
    super(baseBlock_layer_1_3, self).__init__()
    
    #declare convolution layers with batch norms
    self.baseconv = nn.Sequential(
        nn.Conv2d(input_planes,planes,stride=stride,kernel_size=3,padding=1),
        nn.MaxPool2d(2,2),
        nn.BatchNorm2d(planes),
        nn.ReLU(),
    )
    

    self.conv1 = torch.nn.Conv2d(planes,planes,stride=stride,kernel_size=3,padding=1)
    self.bn1= torch.nn.BatchNorm2d(planes)
    self.conv2 = torch.nn.Conv2d(planes,planes,stride=1,kernel_size=3,padding=1)
    self.bn2=torch.nn.BatchNorm2d(planes)
    
  def forward(self,x):
    x = self.baseconv(x)

    output = F.relu(self.bn1(self.conv1(x)))
    output = F.relu(self.bn2(self.conv2(output)))

    x += output
    
    return x

In [0]:
class baseBlock_layer_2(torch.nn.Module):
  
  def __init__(self, input_planes, planes, stride=1):
    super(baseBlock_layer_2, self).__init__()
    
    self.conv1 = torch.nn.Conv2d(input_planes,planes,stride=stride,kernel_size=3,padding=1)
    self.pool1 = torch.nn.MaxPool2d(2,2) 
    self.bn1= torch.nn.BatchNorm2d(planes)
    
  def forward(self,x):
    output = self.pool1(F.relu(self.bn1(self.conv1(x))))
    return output

In [0]:
class ResNet(torch.nn.Module):
  def __init__(self, block1_3, block2, classes=10):
    super(ResNet,self).__init__()
    self.in_planes = 64
    
    self.conv1= torch.nn.Conv2d(3, self.in_planes, kernel_size=3,stride=1,padding=1)
    self.bn1= torch.nn.BatchNorm2d(64)

    
    self.layer1 = self._layer(block1_3, self.in_planes, 128, stride=1)
    self.layer2 = self._layer(block2, self.in_planes, 256, stride=1)
    self.layer3 = self._layer(block1_3, self.in_planes, 512, stride=1)

    self.fc = nn.Linear(self.in_planes, classes)
    
  def _layer(self,block,in_planes,planes,stride=1):
    netLayers= []
    netLayers.append(block(in_planes, planes, stride=stride))
    self.in_planes = planes
    return torch.nn.Sequential(*netLayers)
  
  def forward(self,x):
    x= F.relu(self.bn1(self.conv1(x)))
    
    print(x.shape)

    x=self.layer1(x)
    print(x.shape)

    x=self.layer2(x)
    print(x.shape)

    x=self.layer3(x)
    print(x.shape)

    x = F.max_pool2d(x,4)
    print(x.shape)

    x=x.view(-1, x.size(1))
    print(x.shape)

    x=self.fc(x)
    print(x.shape)

    x= F.log_softmax(x,dim=-1)
    

    return x
  

In [0]:
def ResNet18():
    return ResNet(baseBlock_layer_1_3, baseBlock_layer_2)

In [0]:
def test():
    net = ResNet18()
    y = net(torch.randn(3, 3, 32, 32))
    print(y.size())

In [0]:
test()

torch.Size([3, 64, 32, 32])
torch.Size([3, 128, 16, 16])
torch.Size([3, 256, 8, 8])
torch.Size([3, 512, 4, 4])
torch.Size([3, 512, 1, 1])
torch.Size([3, 512])
torch.Size([3, 10])
torch.Size([3, 10])
