# Transfer Learning - Getting Started

In this notebook we will take a pre-trained model and use it to classify a dataset that is not part of the original ImageNet data. We will then go one step further and replace the last fully connected (fc) layer to showcase using the network as a fixed feature extractor.

In [None]:
from torchvision import datasets
from torchvision import models
from torchvision import transforms
import torch
import json
import re
from pprint import pprint
from torchsummary import summary
import pandas as pd
import numpy as np

# Ask cuDNN to use deterministic algorithms so repeated GPU runs are more repeatable.
# NOTE(review): no random seed is set in this cell, so weight initialisation,
# shuffling and augmentation can still differ between runs.
torch.backends.cudnn.deterministic = True
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Let's load a pretrained ResNet-101 model

In [None]:
# Build ResNet-101 with pretrained weights and move it to the selected device.
# NOTE(review): model.eval() is not called here; cells that run inference with
# this model should switch it to evaluation mode first.
model = models.resnet101(pretrained=True)
model = model.to(device)

### PyTorch Transforms

torchvision provides multiple transforms that can be used together. We call `Compose` to chain them.

In [None]:
# Minimal preprocessing for the raw flower images: resize, then convert the
# PIL image to a tensor. This is the default transform of FlowerDataset below.
# NOTE(review): unlike `data_transforms` further down, no Normalize step is
# applied here; the plotting cells rely on that when they convert the tensor
# straight back to a PIL image.
preprocess = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
])

### Getting our dataset
Let's get a flower dataset from Kaggle.

In [None]:
! kaggle datasets download -d olgabelitskaya/flower-color-images

In [None]:
! ls 

In [None]:
! unzip flower-color-images.zip

In [None]:
!rm *.h5 && rm *.zip && mv flower_images ../../data/

In [None]:
! ls ../../data/flower_images

In [None]:
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sb

In [None]:
df = pd.read_csv('../../data/flower_images/flower_labels.csv')

In [None]:
df.head(10)

In [None]:
df.groupby(df.label).count()

Let's declare the labels dictionary so we can map each numeric label to its flower name.

In [None]:
# Flower names in label order: the position in this tuple is the numeric label
# that the rest of the notebook uses to look the name up.
_FLOWER_NAMES = (
    'phlox',
    'rose',
    'calendula',
    'iris',
    'leucanthemum maximum',
    'bellflower',
    'viola',
    'rudbeckia laciniata (Goldquelle)',
    'peony',
    'aquilegia',
)

# {numeric label: flower name}
lbls = dict(enumerate(_FLOWER_NAMES))

Let's write our own dataset class for the images.

In [None]:
from PIL import Image
import os
class FlowerDataset(torch.utils.data.Dataset):
    """Dataset of flower images described by a CSV file.

    The first CSV column holds the image file name (relative to ``root_dir``)
    and the second column holds the label of that image.

    Args:
        csv_file: Path of the CSV file listing file names and labels.
        root_dir: Directory that contains the image files.
        transform: Callable applied to every RGB PIL image. ``None`` returns
            the PIL image unchanged.
        label_dict: Mapping from label to flower name. Accepted for API
            compatibility; this class does not use it.
        **kwargs: Accepted for API compatibility and ignored.
    """

    def __init__(self, csv_file, root_dir, transform=preprocess, label_dict=lbls, **kwargs):
        # Parse the CSV once and split the columns, instead of reading the
        # file from disk twice.
        frame = pd.read_csv(csv_file, usecols=range(2))
        self.img_frame = frame.iloc[:, [0]]  # single-column frame: file names
        self.lbl_frame = frame.iloc[:, [1]]  # single-column frame: labels
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        """Return the number of images listed in the CSV."""
        return len(self.img_frame)

    def __getitem__(self, idx):
        """Return ``{'image': <transformed image>, 'labels': <label>}`` for row ``idx``."""
        img_path = os.path.join(self.root_dir, self.img_frame.iloc[idx, 0])
        # convert() returns a fully loaded, independent image, so the file
        # handle can be released as soon as the conversion is done.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        label = self.lbl_frame.iloc[idx, 0]
        return {'image': img, 'labels': label}
        

Load the dataset

In [None]:
dataset = FlowerDataset(
                csv_file='../../data/flower_images/flower_labels.csv',
                root_dir='../../data/flower_images/')

Plot the first few examples

In [None]:
# Show up to the first four samples side by side, titled with their flower name.
n_preview = min(4, len(dataset))
fig = plt.figure()

for i in range(n_preview):
    sample = dataset[i]
    ax = plt.subplot(1, 4, i + 1)
    ax.set_title('{}'.format(lbls[sample['labels']]))
    ax.axis('off')
    ax.imshow(transforms.ToPILImage()(sample['image']))

# Lay out and render once, after every panel has been drawn. This also shows
# the figure when the dataset holds fewer than four samples.
plt.tight_layout()
plt.show()

Since we will be using a model trained on ImageNet, we need to load its class mapping.

In [None]:
# Build {class index: class name} from lines that start with "<index>:".
classes = dict()
with open('../../data/imagenet_labels.txt', 'r') as f:
    for line in f:
        if not line.strip():
            continue  # tolerate blank lines, e.g. a trailing newline at end of file
        # Split on the first colon only, so a name containing ':' is kept whole.
        parts = line.split(':', 1)
        classes[int(parts[0])] = re.sub("'|\n", '', parts[1])

To see how poorly the unmodified ImageNet model performs on this dataset, let's classify the first few examples.

In [None]:
# Inference only: put BatchNorm/Dropout layers into evaluation behaviour so the
# pretrained running statistics are used (and not overwritten), and skip
# autograd bookkeeping.
model.eval()
to_pil = transforms.ToPILImage()

with torch.no_grad():
    for idx, sample in enumerate(dataset):
        image = sample['image']
        batch = image.unsqueeze(0).to(device)  # add the batch dimension
        pred = model(batch)
        _, output = torch.max(pred, 1)
        # Softmax over the class dimension, expressed as a percentage.
        confidence = torch.nn.functional.softmax(pred, dim=1)[0] * 100

        fig = plt.figure()
        ax = plt.subplot(1, 1, 1)
        ax.set_title('Pred {} Conf: {}'.format(
            classes[int(output[0])],
            confidence[output[0]].item()))
        ax.axis('off')
        plt.imshow(to_pil(image))
        if idx == 3:
            break

# Render after the loop so the figures are shown even when the dataset has
# fewer than four samples.
plt.show()
            

# Transfer Learning with Fixed Encoder

First, let's reset our model.

In [None]:
model = models.resnet101(pretrained=True)

Then we turn off gradient computation for all parameters, so the pretrained layers stay frozen.

In [None]:
# Freeze every parameter currently in the model: with requires_grad disabled,
# no gradients are computed for them during the backward pass.
for param in model.parameters():
    param.requires_grad = False

Now let's replace the final fully connected layer.

In [None]:
# Replace the classifier head with a new linear layer that keeps the same input
# width and has one output per flower class in `lbls`.
in_features = model.fc.in_features
model.fc = torch.nn.Linear(in_features, len(lbls))

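Set up our loss function, optimizer and learning rate scheduler. Take note of what we are passing into the optimizer: only the parameters of the new `model.fc` layer. This is very important, because those are the only parameters we want to update.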

In [None]:
criterion = torch.nn.CrossEntropyLoss()
# Only the parameters of the replaced head (model.fc) are handed to the optimizer.
optimizer = torch.optim.Adam(model.fc.parameters(), lr=0.001)
# Decay LR by a factor of 0.1 every 3 epochs
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

Let's reorganize our flower image data to conform to the directory layout that torchvision's `ImageFolder` dataset expects (one sub-directory per class). That way we can leverage a lot of functionality out of the box.

In [None]:
import shutil
def organize_imgs(row, src_dir='../../data/flower_images',
                  base_path='../../data/flower_images_org', split=None):
    """Copy one image into the ``<base_path>/<split>/<label name>/`` layout.

    Args:
        row: Mapping (for example a DataFrame row) with a ``'file'`` entry
            naming the image and a ``'label'`` entry that is a key of ``lbls``.
        src_dir: Directory that currently holds the image files. Defaults to
            the directory the CSV and the images were loaded from earlier in
            this notebook.
        base_path: Root of the reorganized directory tree.
        split: Sub-directory to file the image under (``'train'`` or ``'val'``
            in this notebook). When ``None`` the module-level ``mode`` variable
            is used, which keeps ``frame.apply(organize_imgs, axis=1)`` working.

    Note:
        Files are never removed. Clear ``base_path`` before organizing a
        different train/val split, otherwise images from the previous split
        remain in place.
    """
    if split is None:
        split = mode
    dest = os.path.join(base_path, split, lbls[row['label']])
    # exist_ok avoids the check-then-create race and makes re-runs harmless.
    os.makedirs(dest, exist_ok=True)
    # Copy rather than move: the flat source directory is still read by the
    # dataset built earlier, and copying keeps this step safe to run again.
    shutil.copy2(os.path.join(src_dir, row['file']), os.path.join(dest, row['file']))

Randomly split the dataset into roughly 80% training and 20% validation data.

In [None]:
# Dedicated, seeded generator so the train/val split is identical on every run
# (an unseeded split would file different images under train/ and val/ each time).
split_rng = np.random.RandomState(42)
# Each row goes to the training set with probability 0.8, so the split is
# approximately, not exactly, 80/20.
msk = split_rng.rand(len(df)) < 0.8
train = df[msk]
val = df[~msk]

In [None]:
mode = 'train'
train.apply(organize_imgs, axis=1)

In [None]:
mode = 'val'
val.apply(organize_imgs, axis=1)

Since our dataset is very small, we apply some data augmentation.

In [None]:
# Per-channel mean and standard deviation handed to Normalize in both pipelines.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: random crop and random horizontal flip provide the
# augmentation, followed by tensor conversion and normalization.
_train_pipeline = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])

# Validation pipeline: no randomness, just resize, center crop, tensor
# conversion and the same normalization.
_val_pipeline = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])

data_transforms = {
    'train': _train_pipeline,
    'val': _val_pipeline,
}

Here we load the datasets and then wrap them in data loaders.

In [None]:
# Root of the reorganized images; expected to contain 'train' and 'val' sub-directories.
data_dir = '../../data/flower_images_org'

# One ImageFolder per split, reading <data_dir>/<split> with that split's transform pipeline.
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}

In [None]:
# One DataLoader per split: batches of 4, shuffled, loaded with 4 workers.
# NOTE(review): shuffling is enabled for 'val' as well, where it is not needed.
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                             shuffle=True, num_workers=4)
              for x in ['train', 'val']}

In [None]:
# Number of samples per split; train() divides by these to get per-sample averages.
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
# Class names as reported by the training ImageFolder.
class_names = image_datasets['train'].classes

Let's create the training function.

In [None]:
import time
def train(num_epochs=10):
    """Train and validate the module-level ``model`` for ``num_epochs`` epochs.

    Relies on the notebook globals ``model``, ``device``, ``dataloaders``,
    ``dataset_sizes``, ``criterion``, ``optimizer`` and ``scheduler``.
    Every epoch runs a training pass (weights are updated and the scheduler is
    stepped once) followed by a validation pass (gradients disabled), and
    prints the per-sample loss and accuracy of each phase. The best validation
    accuracy seen is printed at the end.

    Args:
        num_epochs: Number of passes over the training and validation data.

    Returns:
        None. Progress is reported through ``print``.
    """
    # Module.to() moves the parameters in place, so the global must not be
    # rebound here: `model = model.to(device)` would make `model` a local name
    # and raise UnboundLocalError before the call is even evaluated.
    model.to(device)
    since = time.time()
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        for phase in ['train', 'val']:
            is_train = phase == 'train'
            if is_train:
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history only in the training phase
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if is_train:
                        loss.backward()
                        optimizer.step()

                # statistics: criterion returns the batch mean, so weight it by
                # the batch size to accumulate a per-sample sum
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels).item()

            # One scheduler step per epoch, after the optimizer updates.
            if is_train:
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

In [None]:
train(num_epochs=10)

# Transfer Learning with Fine Tuning

The only big difference here is that we aren't freezing the parameters; instead, we fine-tune the entire network.

In [None]:
# Start again from the pretrained weights; this time every layer stays trainable.
model = models.resnet101(pretrained=True)
# Replace the pretrained classification head with one sized for our flower
# classes, as in the fixed-extractor section. Without this the network would be
# fine-tuned with the original ImageNet output layer instead of a 10-way head.
model.fc = torch.nn.Linear(model.fc.in_features, len(lbls))
model = model.to(device)
criterion = torch.nn.CrossEntropyLoss()
# All parameters are optimized now, not just those of model.fc.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Decay LR by a factor of 0.1 every 7 epochs
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

In [None]:
train(num_epochs=25)

# Exercises

1. Try out different hyper parameters. Can you get it to perform better?
2. Try out a different model. Does it perform better?