In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
import matplotlib.pyplot as plt
%matplotlib inline

In [6]:
# Residual Block

class ResidualBlock(nn.Module):
    """Two 3x3 convolutions wrapped in an identity skip connection.

    Both convolutions keep the channel count and, with ``padding=1`` at the
    default stride, the spatial size, so the branch output can be added
    element-wise to the block's input.
    """

    def __init__(self, channels):
        """Build the conv -> ReLU -> conv branch.

        Args:
            channels: number of input channels; the output has the same count.
        """
        super().__init__()

        # The layers are created in the same order they sit in the Sequential,
        # so seeded parameter initialisation stays reproducible.
        first_conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        second_conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.net = nn.Sequential(first_conv, nn.ReLU(), second_conv)

    def forward(self, x):
        """Return ``relu(x + net(x))`` for the input tensor ``x``."""
        branch_out = self.net(x)
        skip_sum = x + branch_out
        return nn.functional.relu(skip_sum)

In [7]:
class ResNet(nn.Module):
    """Small residual CNN: two conv/pool stages, each followed by a ResidualBlock.

    The final ``nn.Linear`` expects 512 flattened features, which is what a
    28x28 input produces (32 channels * 4 * 4 after the second pooling stage).
    Other spatial sizes will not match that layer.

    Args:
        in_channels: number of channels in the input image (default 1).
        num_classes: size of the output logit vector (default 10).
    """

    def __init__(self, in_channels=1, num_classes=10):
        super(ResNet, self).__init__()
        self.net = nn.Sequential(
            # Stage 1: 5x5 conv (no padding) then 2x2 max-pool.
            nn.Conv2d(in_channels, 16, kernel_size=5), nn.ReLU(),
            nn.MaxPool2d(2),
            ResidualBlock(16),  # small net within big net

            # Stage 2: same pattern, widening to 32 channels.
            nn.Conv2d(16, 32, kernel_size=5), nn.ReLU(),
            nn.MaxPool2d(2),
            ResidualBlock(32),

            nn.Flatten(),
            # 512 = 32 * 4 * 4 for a 28x28 input.
            nn.Linear(512, num_classes)
        )

    def forward(self, x):
        """Run ``x`` through the network and return the class logits."""
        # The original computed this value but never returned it, so calling
        # the model produced None.
        return self.net(x)

In [8]:
# Trace a random 1x1x28x28 batch through a fresh ResNet, one layer at a time,
# printing each layer's class name and the shape of its output.
X = torch.rand(1, 1, 28, 28)
for stage in ResNet().net:
    X = stage(X)
    print(type(stage).__name__, 'output shape: \t', X.shape)

Conv2d output shape: 	 torch.Size([1, 16, 24, 24])
ReLU output shape: 	 torch.Size([1, 16, 24, 24])
MaxPool2d output shape: 	 torch.Size([1, 16, 12, 12])
ResidualBlock output shape: 	 torch.Size([1, 16, 12, 12])
Conv2d output shape: 	 torch.Size([1, 32, 8, 8])
ReLU output shape: 	 torch.Size([1, 32, 8, 8])
MaxPool2d output shape: 	 torch.Size([1, 32, 4, 4])
ResidualBlock output shape: 	 torch.Size([1, 32, 4, 4])
Flatten output shape: 	 torch.Size([1, 512])
Linear output shape: 	 torch.Size([1, 10])
