In [83]:
%matplotlib inline


Transfer Learning for Computer Vision Tutorial
==============================================
**Author**: `Sasank Chilamkurthy <https://chsasank.github.io>`_

In this tutorial, you will learn how to train a convolutional neural network for
image classification using transfer learning. You can read more about the transfer
learning at `cs231n notes <https://cs231n.github.io/transfer-learning/>`__

Quoting these notes,

    In practice, very few people train an entire Convolutional Network
    from scratch (with random initialization), because it is relatively
    rare to have a dataset of sufficient size. Instead, it is common to
    pretrain a ConvNet on a very large dataset (e.g. ImageNet, which
    contains 1.2 million images with 1000 categories), and then use the
    ConvNet either as an initialization or a fixed feature extractor for
    the task of interest.

These two major transfer learning scenarios look as follows:

-  **Finetuning the convnet**: Instead of random initializaion, we
   initialize the network with a pretrained network, like the one that is
   trained on imagenet 1000 dataset. Rest of the training looks as
   usual.
-  **ConvNet as fixed feature extractor**: Here, we will freeze the weights
   for all of the network except that of the final fully connected
   layer. This last fully connected layer is replaced with a new one
   with random weights and only this layer is trained.




In [105]:
# License: BSD
# Author: Sasank Chilamkurthy

from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
from skimage import io, transform

import copy
import os.path
import glob
from torch.utils.data import Dataset, DataLoader
import pandas as pd
plt.ion()   # interactive mode
from matplotlib.pyplot import imshow


Load Data
---------

We will use torchvision and torch.utils.data packages for loading the
data.

The problem we're going to solve today is to train a model to classify
**ants** and **bees**. We have about 120 training images each for ants and bees.
There are 75 validation images for each class. Usually, this is a very
small dataset to generalize upon, if trained from scratch. Since we
are using transfer learning, we should be able to generalize reasonably
well.

This dataset is a very small subset of imagenet.

.. Note ::
   Download the data from
   `here <https://download.pytorch.org/tutorial/hymenoptera_data.zip>`_
   and extract it to the current directory.



In [85]:
# Writing a custermized dataset for my data
class CommendDataset(Dataset):
    """Face Landmarks dataset."""

    def __init__(self, root_dir , label_dir  , phase, transform=None):
        """
        Args:
            csv_file (string): Path to the csv file with annotations.
            root_dir (string): Directory with all the images.
            transform (callable, optional): Optional transform to be applied
                on a sample.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.num_image = glob.glob(os.path.join(root_dir,'*.png'))
        self.df_label = pd.read_csv(os.path.join(label_dir,'MPM_Layer0013.csv'))
        self.phase = phase

    def __len__(self):
        return len(self.num_image)

    def __getitem__(self, idx):
        lab = self.df_label.iloc[idx]['label']
        if self.phase == 'val':
            idx = idx+1999
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = os.path.join(self.root_dir,
                                str(idx+1)+'.png')
        image = io.imread(img_name).transpose((2, 0, 1))

        if self.transform:
            image = self.transform(image)
        
        return image,lab

In [86]:
# Data augmentation and normalization for training
# Just normalization for validation
# data_transforms = {
#     'train': transforms.Compose([
#         transforms.RandomResizedCrop(224),
#         transforms.RandomHorizontalFlip(),
#         transforms.ToTensor(),
#         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
#     ]),
#     'val': transforms.Compose([
#         transforms.Resize(256),
#         transforms.CenterCrop(224),
#         transforms.ToTensor(),
#         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
#     ]),
# }

data_dir = './CMD_Layer0013/'
label_dir = './labels/'
image_datasets = {x: CommendDataset(root_dir = os.path.join(data_dir, x),label_dir = os.path.join(label_dir,x),phase = x) for x in ['train', 'val']}
# dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
#                                              shuffle=True, num_workers=4)
#               for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                             shuffle=True)
              for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Visualize a few images
^^^^^^^^^^^^^^^^^^^^^^
Let's visualize a few training images so as to understand the data
augmentations.



In [87]:
# def imshow(inp, title=None):
#     """Imshow for Tensor."""
#     inp = inp.numpy()
#     plt.imshow(inp)
#     if title is not None:
#         plt.title(title)
#     plt.pause(0.001)  # pause a bit so that plots are updated


# # Get a batch of training data
# inputs, classes = next(iter(dataloaders['train']))

# # Make a grid from batch
# out = torchvision.utils.make_grid(inputs)

# imshow(out, title=[class_names[x] for x in classes])

Training the model
------------------

Now, let's write a general function to train a model. Here, we will
illustrate:

-  Scheduling the learning rate
-  Saving the best model

In the following, parameter ``scheduler`` is an LR scheduler object from
``torch.optim.lr_scheduler``.



In [113]:
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 9999999

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            epoch_loss = 0
            epoch_acc = 0
            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                
                inputs = inputs.to(device, dtype=torch.float32)
                labels = labels.to(device, dtype=torch.float32)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                    
                epoch_loss += loss * inputs.size(0)
                # statistics
                # running_loss += loss.item() * inputs.size(0)
                # running_corrects += torch.sum(preds == labels.data)
            if phase == 'train':
                scheduler.step()

            epoch_loss = epoch_loss/dataset_sizes[phase]

            epoch_acc = epoch_loss.sqrt()

            print('{} MSE: {:.4f} RMSE: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # deep copy the model
            if phase == 'val' and epoch_acc < best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model


Visualizing the model predictions
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Generic function to display predictions for a few images




In [107]:
def visualize_model(model, num_images=6):
    was_training = model.training
    model.eval()
    images_so_far = 0
    fig = plt.figure()

    with torch.no_grad():
        for i, (inputs, labels) in enumerate(dataloaders['val']):
            inputs = inputs.to(device, dtype=torch.float32)
            labels = labels.to(device, dtype=torch.float32)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            for j in range(inputs.size()[0]):
                images_so_far += 1
                ax = plt.subplot(num_images//2, 2, images_so_far)
                ax.axis('off')
                ax.set_title('predicted: {}'.format('name'))
                imshow(inputs.cpu().data[j])

                if images_so_far == num_images:
                    model.train(mode=was_training)
                    return
        model.train(mode=was_training)

Finetuning the convnet
----------------------

Load a pretrained model and reset final fully connected layer.




In [94]:
model_ft = models.resnet18(pretrained=True)
num_ftrs = model_ft.fc.in_features
# Here the size of each output sample is set to 2.
# Alternatively, it can be generalized to nn.Linear(num_ftrs, len(class_names)).
model_ft.fc = nn.Linear(num_ftrs,1)

model_ft = model_ft.to(device)

criterion = nn.MSELoss()

# Observe that all parameters are being optimized
optimizer_ft = optim.SGD(model_ft.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)

In [97]:
model_ft = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler,
                       num_epochs=50)

Epoch 0/49
----------
train MSE: 162.0099 RMSE: 12.7283
val MSE: 15720.0332 RMSE: 125.3796

Epoch 1/49
----------
train MSE: 99.3222 RMSE: 9.9661
val MSE: 3923.1514 RMSE: 62.6351

Epoch 2/49
----------
train MSE: 90.3431 RMSE: 9.5049
val MSE: 1891.8464 RMSE: 43.4954

Epoch 3/49
----------
train MSE: 1658.9434 RMSE: 40.7301
val MSE: 19936.6582 RMSE: 141.1972

Epoch 4/49
----------
train MSE: 8860.9199 RMSE: 94.1325
val MSE: 4127.8096 RMSE: 64.2480

Epoch 5/49
----------
train MSE: 627.9932 RMSE: 25.0598
val MSE: 7996.9219 RMSE: 89.4255

Epoch 6/49
----------
train MSE: 635.4999 RMSE: 25.2091
val MSE: 14255.0762 RMSE: 119.3946

Epoch 7/49
----------
train MSE: 9685.6201 RMSE: 98.4156
val MSE: 5992.7324 RMSE: 77.4127

Epoch 8/49
----------
train MSE: 1555.2343 RMSE: 39.4365
val MSE: 2213.7583 RMSE: 47.0506

Epoch 9/49
----------
train MSE: 9251.8975 RMSE: 96.1868
val MSE: 3009.8491 RMSE: 54.8621

Epoch 10/49
----------
train MSE: 54.4036 RMSE: 7.3759
val MSE: 11508.1377 RMSE: 107.2760

Ep

In [108]:
visualize_model(model_ft)

TypeError: transpose() received an invalid combination of arguments - got (tuple), but expected one of:
 * (name dim0, name dim1)
 * (int dim0, int dim1)


ConvNet as fixed feature extractor
----------------------------------

Here, we need to freeze all the network except the final layer. We need
to set ``requires_grad == False`` to freeze the parameters so that the
gradients are not computed in ``backward()``.

You can read more about this in the documentation
`here <https://pytorch.org/docs/notes/autograd.html#excluding-subgraphs-from-backward>`__.




In [111]:
model_conv = torchvision.models.resnet18(pretrained=True)
for param in model_conv.parameters():
    param.requires_grad = False

# Parameters of newly constructed modules have requires_grad=True by default
num_ftrs = model_conv.fc.in_features
model_conv.fc = nn.Linear(num_ftrs, 1)

model_conv = model_conv.to(device)

criterion = nn.MSELoss()

# Observe that only parameters of final layer are being optimized as
# opposed to before.
optimizer_conv = optim.SGD(model_conv.fc.parameters(), lr=0.001, momentum=0.9)

# Decay LR by a factor of 0.1 every 7 epochs
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_conv, step_size=7, gamma=0.1)

Train and evaluate
^^^^^^^^^^^^^^^^^^

On CPU this will take about half the time compared to previous scenario.
This is expected as gradients don't need to be computed for most of the
network. However, forward does need to be computed.




In [112]:
model_conv = train_model(model_conv, criterion, optimizer_conv,
                         exp_lr_scheduler, num_epochs=50)

Epoch 0/49
----------
train MSE: 6535.7681 RMSE: 80.8441
val MSE: 108.8095 RMSE: 10.4312

Epoch 1/49
----------
train MSE: 6715.2686 RMSE: 81.9467
val MSE: 5485.1382 RMSE: 74.0617

Epoch 2/49
----------
train MSE: 2872.3582 RMSE: 53.5944
val MSE: 2153.4875 RMSE: 46.4057

Epoch 3/49
----------
train MSE: 6352.4570 RMSE: 79.7023
val MSE: 6195.6016 RMSE: 78.7121

Epoch 4/49
----------
train MSE: 3571.6323 RMSE: 59.7631
val MSE: 16527.4688 RMSE: 128.5592

Epoch 5/49
----------
train MSE: 615.6530 RMSE: 24.8124
val MSE: 3038.9360 RMSE: 55.1265

Epoch 6/49
----------
train MSE: 5119.3140 RMSE: 71.5494
val MSE: 14325.2461 RMSE: 119.6881

Epoch 7/49
----------
train MSE: 8484.2500 RMSE: 92.1100
val MSE: 2146.0859 RMSE: 46.3259

Epoch 8/49
----------
train MSE: 2395.1133 RMSE: 48.9399
val MSE: 9014.4736 RMSE: 94.9446

Epoch 9/49
----------
train MSE: 514.3694 RMSE: 22.6797
val MSE: 976.9918 RMSE: 31.2569

Epoch 10/49
----------
train MSE: 4620.6504 RMSE: 67.9754
val MSE: 12757.5117 RMSE: 112.94

In [None]:
visualize_model(model_conv)

plt.ioff()
plt.show()

Further Learning
-----------------

If you would like to learn more about the applications of transfer learning,
checkout our `Quantized Transfer Learning for Computer Vision Tutorial <https://pytorch.org/tutorials/intermediate/quantized_transfer_learning_tutorial.html>`_.



