# Transfer learning
Transfer learning (TL) of AlexNet with XGBoost as classifier, applied to a dataset of galaxy pictures.

In [1]:
from PIL import Image
import torch
from torch import nn
from torchvision import transforms
from torchvision import datasets
import torchvision.models as models

from sklearn.metrics import confusion_matrix, classification_report

import pandas as pd
import numpy as np
import os

import matplotlib.pyplot as plt
import seaborn as sns

### Train classifier

In [5]:
# Select the 'file_system' sharing strategy for torch.multiprocessing.
# NOTE(review): presumably chosen to avoid file-descriptor limits when data
# loader workers share tensors; num_workers is 0 below, so confirm it is needed.
torch.multiprocessing.set_sharing_strategy('file_system')

# Load the AlexNet architecture with pretrained weights and print its layers.
# NOTE(review): newer torchvision releases prefer the `weights=` argument over
# `pretrained=True` -- verify against the installed version before changing.
model = models.alexnet(pretrained=True)
print(model)

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
 

In [3]:
# Refit model class (subclasses nn.Module)
class RefitModel(nn.Module):
    """Network that reuses pretrained convolutional features with a new head.

    The convolutional ``features`` of the original model are kept and frozen;
    the classification layers are replaced by a freshly initialised stack
    whose last layer has ``num_classes`` outputs.

    Args:
        orgininal_model: Model exposing a ``features`` module whose output
            flattens to 256*6*6 values per sample (the size hard-coded in
            ``forward`` and in the first linear layer).
        num_classes: Number of output classes of the new classifier.
    """

    def __init__(self, orgininal_model, num_classes):
        super(RefitModel, self).__init__()

        # Keep the convolutional layers as they are.
        self.features = orgininal_model.features

        # Replace the classification layers. This must be an assignment: the
        # previous `==` comparison discarded the new head and silently kept
        # the pretrained one, ignoring `num_classes`.
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256*6*6, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )
        self.modelName = 'alexnet'

        # Freeze the feature weights. The attribute is `requires_grad`; the
        # previous `required_grad` only created an unused attribute and left
        # the features trainable.
        for w in self.features.parameters():
            w.requires_grad = False

    def forward(self, x):
        """Run the features, flatten to (batch, 256*6*6), then classify."""
        f = self.features(x)
        f = f.view(f.size(0), 256*6*6)
        y = self.classifier(f)
        return y
    
def preprocess(img):
    """Resize an image to 224x224, convert it to RGB and return it as a tensor.

    `img` must provide `resize` and `convert` (presumably a PIL image, since
    this is used as the `transform` of an ImageFolder dataset).
    """
    rgb_img = img.resize((224, 224)).convert('RGB')
    to_tensor = transforms.ToTensor()
    return to_tensor(rgb_img)

In [4]:
# Define training and test data directories
data_dir = 'data/'
train_dir = os.path.join(data_dir, 'train/')
test_dir = os.path.join(data_dir, 'test/')

# Classes are the sub-folders of the training directory. Keep directories only
# and sort them so the list mirrors the label order ImageFolder assigns.
classes = sorted(
    entry for entry in os.listdir(train_dir)
    if os.path.isdir(os.path.join(train_dir, entry))
)

# Refit model: size the output layer from the data instead of hard-coding a
# class count that can drift away from the folders on disk.
num_classes = len(classes)
refit_model = RefitModel(model, num_classes)

classes

['edge', 'other', 'smooth', 'spiral']

In [5]:
# Build the train/test datasets from their folders; every image goes through
# `preprocess` so it matches the input size used for AlexNet.
train = datasets.ImageFolder(root=train_dir, transform=preprocess)
test = datasets.ImageFolder(root=test_dir, transform=preprocess)

# Report how many images each split contains
for split_name, split in (('training', train), ('test', test)):
    print('Num %s images: ' % split_name, len(split))

Num training images:  352
Num test images:  48


In [6]:
# Batching configuration shared by both loaders
batch_size = 20
num_workers = 0
loader_options = dict(batch_size=batch_size, shuffle=True, num_workers=num_workers)

# Shuffled mini-batch iterators over the train and test datasets
train_loader = torch.utils.data.DataLoader(train, **loader_options)
test_loader = torch.utils.data.DataLoader(test, **loader_options)

In [None]:
# Train model
error = nn.CrossEntropyLoss()
learning_rate = 0.001
# Only hand trainable parameters to the optimizer; frozen ones never get gradients.
optimizer = torch.optim.SGD(
    [p for p in refit_model.parameters() if p.requires_grad],
    lr=learning_rate,
)
num_epochs = 2
# Report the mean loss every `log_every` batches. The previous condition
# `i%20 == 20` could never be true, and with 352 images in batches of 20 an
# epoch has only 18 batches, so a 20-batch interval would not fire either.
log_every = 5

refit_model.train()  # make sure dropout is active while fitting

for epoch in range(num_epochs):
    running_loss = 0.0

    for i, data in enumerate(train_loader, 0):
        inputs, labels = data

        # Clear gradients
        optimizer.zero_grad()

        # Forward propagation
        outputs = refit_model(inputs)

        # Calculate cross entropy loss
        loss = error(outputs, labels)

        # Calculating gradients
        loss.backward()

        # Update parameters
        optimizer.step()

        # Print stats
        running_loss += loss.item()

        if i % log_every == log_every - 1:
            print('epoch: %d, batch: %5d, loss: %.3f' % (epoch + 1, i + 1, running_loss / log_every))
            running_loss = 0.0

# Save model (the whole module is pickled, so loading it needs the RefitModel
# class to be defined in the loading session)
torch.save(refit_model, 'tl_model.pth')

### Test classifier

In [None]:
# Load the trained model saved by the training cell.
# NOTE(review): torch.load unpickles arbitrary objects -- only load files you
# trust. Recent PyTorch releases may need `weights_only=False` to load a fully
# pickled module; verify against the installed version.
model = torch.load('tl_model.pth')
model.eval()
y_test = []
y_pred = []

# Collect the true and predicted label of EVERY test image. The previous loop
# read the stale `labels` variable left over from training and kept only the
# first sample of each batch.
with torch.no_grad():  # inference only: no gradient tracking needed
    for inputs, labels in test_loader:
        out = model(inputs)
        _, top1_catid = torch.topk(out, 1)  # index of the highest score per row

        y_test.extend(labels.tolist())
        y_pred.extend(top1_catid.squeeze(1).tolist())

In [None]:
# classification_report returns a formatted string; print it so the table is
# rendered with real line breaks instead of an escaped repr.
print(classification_report(y_test, y_pred))

In [None]:
# Confusion matrix: rows are true labels, columns are predicted labels
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(
    confusion_matrix(y_test, y_pred),
    cmap='YlGnBu',
    linewidths=.2,
    linecolor='gray',
    cbar_kws={"shrink": .8},
    annot=True,
    fmt='d',
    ax=ax,  # draw on the axes created above rather than the implicit current axes
)
ax.set_xlabel('Predicted label')
ax.set_ylabel('True label')
ax.set_title('Confusion matrix (test set)');