In [1]:
import sys
sys.path.append("..")
from utils.dataset import FerDataset

import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import torchvision

from PIL import Image

import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
class ResidualUnit(nn.Module):
    """Basic two-convolution residual block with an un-projected skip connection.

    The residual branch is ``Conv3x3 -> BatchNorm -> ReLU -> Conv3x3 -> BatchNorm``,
    with stride 1 and padding 1 on both convolutions, so the spatial size of the
    input is preserved. The block input is added to the branch output and a
    ReLU is applied to the sum.

    Args:
        depth_in: Number of channels of the input activation volume.
        depth_out: Number of channels produced by the residual branch.

    Note:
        The skip connection adds the input as-is (no projection), so the input
        must be shape-compatible with the branch output for the in-place add;
        in the usual case that means ``depth_in == depth_out``.
    """

    def __init__(self, depth_in: int, depth_out: int):
        super().__init__()
        self.stride = 1

        self.resBlock = nn.Sequential(
            nn.Conv2d(depth_in, depth_out, kernel_size=(3, 3), stride=self.stride, padding=1),
            nn.BatchNorm2d(depth_out),
            nn.ReLU(),
            nn.Conv2d(depth_out, depth_out, kernel_size=(3, 3), stride=self.stride, padding=1),
            nn.BatchNorm2d(depth_out),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``relu(resBlock(x) + x)``.

        Args:
            x: Input tensor accepted by ``nn.Conv2d`` with ``depth_in`` channels.

        Returns:
            The activated sum of the residual branch and the identity path.
        """
        identity = x

        out = self.resBlock(x)
        # In-place add touches only the freshly computed branch output, never
        # the caller's tensor.
        out += identity

        # Apply the activation functionally. `nn.ReLU(out)` would merely
        # construct a module (treating `out` as the `inplace` flag) and return
        # that module instead of a tensor.
        return torch.relu(out)

In [3]:
class DownsampleResidualUnit(nn.Module):
    """Residual block whose first convolution downsamples with stride 2.

    The residual branch is ``Conv3x3(stride 2) -> BatchNorm -> ReLU ->
    Conv3x3(stride 1) -> BatchNorm``. Because the first convolution changes
    both the spatial size and (possibly) the channel count, the identity path
    is passed through a 1x1 stride-2 convolution followed by BatchNorm so that
    it matches the branch output before the two are summed.

    Args:
        input_depth: Number of channels of the input activation volume.
        output_depth: Number of channels produced by the block.
    """

    def __init__(self, input_depth: int, output_depth: int):
        super().__init__()
        self.stride = 2

        self.resBlock = nn.Sequential(
            nn.Conv2d(input_depth, output_depth, kernel_size=(3, 3), stride=self.stride, padding=1),
            nn.BatchNorm2d(output_depth),
            nn.ReLU(),
            nn.Conv2d(output_depth, output_depth, kernel_size=(3, 3), stride=1, padding=1),
            nn.BatchNorm2d(output_depth),
        )

        # Projection shortcut: required to match the dimensions of the
        # identity x with F(x), because in this block the first of the two
        # convolutional layers performs downsampling and therefore changes
        # the dimensions of the activation volume.
        self.matchDim = nn.Sequential(
            nn.Conv2d(input_depth, output_depth, kernel_size=(1, 1), stride=self.stride, padding=0),
            nn.BatchNorm2d(output_depth),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``relu(resBlock(x) + matchDim(x))``.

        Args:
            x: Input tensor accepted by ``nn.Conv2d`` with ``input_depth`` channels.

        Returns:
            The activated sum of the downsampling residual branch and the
            projected identity path.
        """
        identity = self.matchDim(x)

        out = self.resBlock(x)
        out += identity

        # Apply the activation functionally. `nn.ReLU(out)` would merely
        # construct a module (treating `out` as the `inplace` flag) and return
        # that module instead of a tensor.
        return torch.relu(out)

In [None]:
# https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py