VGG13 with global average pooling instead of the expensive FC layers in the classifier

In [1]:
import sys
sys.path.append("..")
from utils.dataset import FerDataset

import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import torchvision

from PIL import Image

import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
class VGG13(nn.Module):
    """VGG13-style CNN for single-channel images with global average pooling.

    The convolutional trunk consists of five blocks of two 3x3 convolutions
    (each followed by a ReLU) and a 2x2 max pool. A final
    ``AdaptiveAvgPool2d(1)`` collapses every feature map to a single value,
    so the classifier always receives 512 features per sample, independent
    of the spatial size of the input.

    Args:
        num_classes: Number of output logits. Defaults to 10.
    """

    def __init__(self, num_classes=10):
        super(VGG13, self).__init__()

        # The shape comments below assume a 224 x 224 x 1 input.
        self.convnet = nn.Sequential(
            # 224 x 224 x 1
            # Kernel size F=3 and stride S=1: to retain the input size the
            # padding must be P = (F - 1) / 2 = 1.
            nn.Conv2d(1, 64, kernel_size=(3, 3), stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=(3, 3), stride=1, padding=1),
            nn.ReLU(),
            # Max pool with F=2 and S=2 keeps only the maximum of every 2x2
            # square, so 75% of the activations are discarded. It works on
            # every channel independently, so the depth remains unchanged.
            nn.MaxPool2d(kernel_size=(2, 2), stride=2),
            # 112 x 112 x 64

            nn.Conv2d(64, 128, kernel_size=(3, 3), stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=(3, 3), stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=2),
            # 56 x 56 x 128

            nn.Conv2d(128, 256, kernel_size=(3, 3), stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=(3, 3), stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=2),
            # 28 x 28 x 256

            nn.Conv2d(256, 512, kernel_size=(3, 3), stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=(3, 3), stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=2),
            # 14 x 14 x 512

            nn.Conv2d(512, 512, kernel_size=(3, 3), stride=1, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 512, kernel_size=(3, 3), stride=1, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=2),
            # 7 x 7 x 512

            # Global average pooling: 1 x 1 x 512
            nn.AdaptiveAvgPool2d(1),
        )

        self.fc = nn.Sequential(
            # The global average pool leaves exactly 512 features per sample
            # (it would be 512 * 7 * 7 without the pooling layer), so the
            # first linear layer must take 512 inputs.
            nn.Linear(512, 4096),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(4096, 4096),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        """Return the raw class logits for a batch of images.

        Args:
            x: Batch of single-channel images, shaped (batch, 1, H, W).

        Returns:
            Tensor of shape (batch, num_classes) holding unnormalized logits.
        """
        x = self.convnet(x)
        # (batch, 512, 1, 1) -> (batch, 512)
        x = torch.flatten(x, 1)
        return self.fc(x)

In [None]:
# FER+ training split; labels are requested via label='ferplus_votes'.
dataset = FerDataset(base_path='../../data',
                     data='ferplus',
                     mode='train',
                     label='ferplus_votes')
# Tiny batches loaded in the main process (num_workers=0).
dataloader = DataLoader(dataset, batch_size=6, shuffle=True, num_workers=0)
net = VGG13()
# KLDivLoss expects log-probabilities as its input, hence the LogSoftmax
# applied to the network's logits before the loss is computed.
log_softmax = nn.LogSoftmax(dim=-1)
# `size_average=False` is deprecated; `reduction='sum'` is the equivalent
# spelling and sums the loss over all elements of the batch.
criterion = nn.KLDivLoss(reduction='sum')
optimizer = torch.optim.SGD(net.parameters(), lr=0.0001, momentum=0.9)

In [None]:
def resize_img(x_batch, y_batch, size):
    """Resize every image of a batch to ``size`` x ``size`` pixels.

    Each sample is converted to a PIL image, resized, and converted back to
    a tensor. The labels are passed through untouched so the function can be
    applied directly to an ``(inputs, targets)`` pair.

    Args:
        x_batch: Image batch indexed as (batch, channels, height, width).
        y_batch: Targets belonging to ``x_batch``; returned unchanged.
        size: Edge length of the square output images.

    Returns:
        Tuple ``(resized, y_batch)`` where ``resized`` is a newly allocated
        tensor of shape (batch, channels, size, size).
    """
    to_pil = torchvision.transforms.ToPILImage()
    to_tensor = torchvision.transforms.ToTensor()

    n_samples, n_channels = x_batch.shape[0], x_batch.shape[1]
    resized = torch.zeros((n_samples, n_channels, size, size))

    for idx, sample in enumerate(x_batch):
        pil_image = to_pil(sample)
        pil_image = torchvision.transforms.functional.resize(pil_image, (size, size))
        resized[idx] = to_tensor(pil_image)

    return resized, y_batch

In [None]:
# Draw a single mini-batch from the loader and show its shape as loaded.
x_batch, y_batch = next(iter(dataloader))
print(x_batch.shape)

# Scale the images to 224 x 224 (the input size the shape comments in VGG13
# assume) and show the new shape.
x_batch, y_batch = resize_img(x_batch, y_batch, size=224)
print(x_batch.shape)

In [None]:
# Run 1000 optimization steps on the same single batch and record the loss
# of every step in `losses`.
losses = []
for step in range(1000):
    optimizer.zero_grad()

    # Forward pass: logits -> log-probabilities -> KL divergence to targets.
    logits = net(x_batch)
    log_probs = log_softmax(logits)
    loss = criterion(log_probs, y_batch)

    # Backward pass and parameter update.
    loss.backward()
    optimizer.step()

    losses.append(loss.item())
    # Overwrite the same output line with the current step counter.
    print(step, end='\r')