# Transfer learning
TL of Alexnet applied to MNIST digits dataset

In [1]:
from PIL import Image
import torch
from torch import nn
from torchvision import transforms
from torchvision import datasets
import torchvision.models as models

from sklearn.metrics import confusion_matrix, classification_report

import pandas as pd
import numpy as np

import matplotlib.pyplot as plt
import seaborn as sns

### Train classifier

In [2]:
# Refit model class (to overload nn.Module class)
class RefitModel(nn.Module):
    def __init__(self, orgininal_model, num_classes):
        super(RefitModel, self).__init__()
        
        self.features = orgininal_model.features # to keep convolitionnal layers as is
        
        # modify only classification layers
        self.classifier = orgininal_model.classifier 
        self.classifier == nn.Sequential(
            nn.Dropout(),
            nn.Linear(256*6*6, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )
        self.modelName = 'alexnet'
        
        for w in self.features.parameters():
            w.required_grad = False # to freeze features weights
    
    # transfer features in classification layer
    def forward(self,x):
        f = self.features(x)
        f = f.view(f.size(0), 256*6*6)
        y = self.classifier(f)
        return y
    
# Image pre-processing
def preprocess(img):   
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    input_tensor = preprocess(img)
    return input_tensor


def preprocess(img):  
    img = img.resize((224,224)).convert('RGB')
    tensor = transforms.ToTensor()(img)
    return tensor

In [3]:
# Load model
model = models.alexnet(pretrained=True)

# Refit model
num_classes = 10
refit_model = RefitModel(model, num_classes)

# Load training data and fit to alexnet data
train = datasets.MNIST(root='data',train=True,transform=preprocess,download=True)
test = datasets.MNIST(root='data',train=False,transform=preprocess,download=True)

In [4]:
batch_size = 4
train_loader = torch.utils.data.DataLoader(train,batch_size=batch_size,shuffle=True,num_workers=1)
test_loader = torch.utils.data.DataLoader(test,batch_size=batch_size,shuffle=True,num_workers=1)

classes = tuple(i for i in range(10))

In [14]:
# Train model
error = nn.CrossEntropyLoss()
learning_rate = 0.001
optimizer = torch.optim.SGD(refit_model.parameters(), lr=learning_rate)
num_epochs = 2

for epoch in range(num_epochs):
    running_loss = 0.0
    
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        
        # Clear gradients
        optimizer.zero_grad()
        
        # Forward propagation
        outputs = refit_model(inputs)
        
        # Calculate cross entropy loss
        loss = error(outputs, labels)
        
        # Calculating gradients
        loss.backward()
        
        # Update parameters
        optimizer.step()
        
        # Print stats
        running_loss += loss.item()
        
        if i%1000 == 0:
            print('epoch: %d, i: %5d, loss: %.3f'%(epoch+1, i, running_loss/2000))
            running_loss = 0.0

# Save model
torch.save(refit_model, 'tl_model.pth')

epoch: 1, i:     0, loss: 0.001


KeyboardInterrupt: 

### Test classifier

In [None]:
model = torch.load('tl_model.pth')
model.eval()
y_test = []
y_pred = []

k = 5

for i, data in enumerate(test_loader, 0):
    inputs, label = data
    out = model(inputs)
    _, topk_catid = torch.topk(out, 1)
    
    y_test.append(int(labels[0]))
    y_pred.append(int(topk_catid[0]))

In [None]:
classification_report(y_test, y_pred)

In [None]:
# Confusion matrix
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(
    confusion_matrix(y_test, y_pred), 
    cmap='YlGnBu', 
    linewidths=.2,
    linecolor='gray',
    cbar_kws={"shrink": .8},
    annot = True,
    fmt='d'
)