
Transfer Learning
===============




In [None]:

from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
import copy
%matplotlib inline

# Device configuration: use the first CUDA GPU when available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Seed the RNGs so that Restart & Run All reproduces the same shuffling order
# and the same initial weights for newly created layers.
# NOTE(review): depending on the torchvision version, the random transforms may
# draw from Python's `random` module, which is not seeded here — confirm.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

n_classes = 2   # number of output classes for the new classifier head
batch_size = 4  # samples per mini-batch, shared by both data loaders


In [None]:
# Training helpers
def get_trainable(model_params):
    """Return a lazy iterator over the parameters whose ``requires_grad`` is truthy."""
    return filter(lambda param: param.requires_grad, model_params)


def get_frozen(model_params):
    """Return a lazy iterator over the parameters whose ``requires_grad`` is falsy."""
    return filter(lambda param: not param.requires_grad, model_params)


def all_trainable(model_params):
    """Report whether every parameter in ``model_params`` requires gradients.

    Iteration stops at the first parameter that does not; an empty iterable
    yields ``True``.
    """
    for param in model_params:
        if not param.requires_grad:
            return False
    return True


def all_frozen(model_params):
    """Report whether no parameter in ``model_params`` requires gradients.

    Iteration stops at the first trainable parameter; an empty iterable
    yields ``True``.
    """
    for param in model_params:
        if param.requires_grad:
            return False
    return True


def freeze_all(model_params):
    """Switch off ``requires_grad`` on every parameter in ``model_params``, in place."""
    for parameter in model_params:
        parameter.requires_grad = False


# list(get_trainable(model.parameters()))
# list(get_frozen(model.parameters()))
# all_trainable(model.parameters())
# all_frozen(model.parameters())

## Dataset

In [None]:
# Data augmentation and normalization for training
# Just normalization for validation

# Per-channel mean / std shared by both pipelines, kept in one place so the
# train and validation normalization cannot drift apart.
# NOTE(review): these look like the usual ImageNet statistics — confirm.
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ]),
}

# One ImageFolder per split, read from <data_dir>/train and <data_dir>/val.
data_dir = 'mydata'
image_datasets = {
    split: datasets.ImageFolder(os.path.join(data_dir, split), data_transforms[split])
    for split in ['train', 'val']
}

# Class names as discovered by ImageFolder for the training split.
class_names = image_datasets['train'].classes
print(class_names)
print('\n train_dataset  \n \n',image_datasets['train'])
print('\n validation_dataset  \n \n',image_datasets['val'])

## DataLoader
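A `DataLoader` serves a dataset in mini-batches and supports multi-process loading (via `num_workers`) and different sampling strategies. Here we use the default single-process loading and shuffle only the training split.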

In [None]:

# Data loader
# Training batches are drawn in shuffled order; validation keeps the dataset order.
train_loader = torch.utils.data.DataLoader(
    image_datasets['train'], batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    image_datasets['val'], batch_size=batch_size, shuffle=False)


In [None]:
def imshow(inp, title=None, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Display a normalized image tensor with matplotlib.

    Args:
        inp: Image tensor. It is transposed with ``(1, 2, 0)``, so it is
            presumably laid out as (channels, height, width) — confirm with callers.
        title: Optional title passed to ``plt.title``.
        mean: Per-channel mean that was subtracted when the image was normalized.
        std: Per-channel standard deviation the image was divided by.
    """
    # detach()/cpu() let this accept tensors that live on the GPU or are part
    # of the autograd graph; a plain CPU tensor passes through unchanged.
    inp = inp.detach().cpu().numpy().transpose((1, 2, 0))
    # Undo the normalization: x_norm = (x - mean) / std  =>  x = std * x_norm + mean
    inp = np.array(std) * inp + np.array(mean)
    # Keep the values inside the [0, 1] range before plotting.
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # pause a bit so that plots are updated


# Pull a single mini-batch from the training loader to preview the augmented images.
inputs, classes = next(iter(train_loader))

# Tile the batch into one grid image and caption it with each sample's class name.
out = torchvision.utils.make_grid(inputs)
imshow(out, title=[class_names[label] for label in classes.tolist()])

# The Model
PyTorch (via `torchvision`) offers quite a few [pre-trained networks](https://pytorch.org/vision/stable/models.html) for you to use:
- AlexNet
- VGG
- ResNet
- SqueezeNet
- DenseNet
- Inception v3

And there are more available via [pretrained-models.pytorch](https://github.com/Cadene/pretrained-models.pytorch):
- NASNet
- ResNeXt
- InceptionV4
- InceptionResNetV2
- Xception
- DPN
- ...

In [None]:
from torchvision import models

# Load a ResNet-18 initialised with pretrained weights.
# NOTE(review): recent torchvision releases deprecate `pretrained=` in favour of
# `weights=` — confirm against the installed version.
model = models.resnet18(pretrained=True)

In [None]:
# Display the module tree to inspect the layers; the last one, `model.fc`, is replaced below.
model

In [None]:
# Freeze all parameters
# Turning off requires_grad excludes these weights from gradient computation.
for param in model.parameters():
    param.requires_grad = False

In [None]:
# Same freeze via the helper, then verify that no parameter still requires gradients.
freeze_all(model.parameters())
assert all_frozen(model.parameters())

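Replace the final fully connected layer (`model.fc`) with a new linear layer that has one output per class. Newly created layers have `requires_grad = True` by default, so this new layer is the only part of the model that will be trained.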

In [None]:
# Size the new head from the layer it replaces instead of hard-coding 512, so
# the cell keeps working if the backbone is swapped for one with a wider `fc`.
model.fc = nn.Linear(model.fc.in_features, n_classes)

In [None]:
# Expected to be False now: the freshly created `fc` layer is trainable.
all_frozen(model.parameters())

In [None]:
def get_model(n_classes=2):
    """Build a pretrained ResNet-18 with a fresh, trainable classifier head.

    Every pretrained parameter is frozen. The replacement ``fc`` layer is
    created after freezing, so it is the only part that requires gradients.

    Args:
        n_classes: Number of output units of the new final layer.

    Returns:
        The modified model; it is not moved to any device here.
    """
    model = models.resnet18(pretrained=True)
    freeze_all(model.parameters())
    # Derive the input width from the replaced layer rather than hard-coding 512.
    model.fc = nn.Linear(model.fc.in_features, n_classes)
    return model


# Pass the configured class count explicitly so the head follows `n_classes`
# instead of silently relying on get_model's default.
model = get_model(n_classes).to(device)

# The Loss

In [None]:
# Cross-entropy between the raw model outputs and the target class labels.
criterion = nn.CrossEntropyLoss()

# The Optimizer

In [None]:
# Optimize only the parameters that still require gradients (the new `fc` head);
# here this is equivalent to passing model.fc.parameters().
optimizer = torch.optim.Adam(params=get_trainable(model.parameters()), lr=0.001)

# The Train Loop

In [None]:
# Number of full passes over the training set (each followed by a validation pass).
N_EPOCHS = 5

In [None]:
for epoch in range(N_EPOCHS):
    print(f"Epoch {epoch+1}/{N_EPOCHS}")

    # Train
    # NOTE(review): train() also switches any BatchNorm/Dropout layers of the
    # frozen backbone to training behaviour — confirm that is intended.
    model.train()  # IMPORTANT

    running_loss, correct = 0.0, 0
    for X, y in train_loader:
        X, y = X.to(device), y.to(device)

        # Clear gradients left over from the previous step before the new backward pass.
        optimizer.zero_grad()
        y_ = model(X)
        loss = criterion(y_, y)

        loss.backward()
        optimizer.step()

        # Statistics
        # Read the scalar loss once and reuse it for both the log line and the running sum.
        batch_loss = loss.item()
        print(f"    batch loss: {batch_loss:0.3f}")
        # The predicted label is the index of the largest output per sample.
        _, y_label_ = torch.max(y_, 1)
        correct += (y_label_ == y).sum().item()
        # The criterion returns the batch mean, so weight it by the batch size;
        # this keeps the epoch average exact when the last batch is smaller.
        running_loss += batch_loss * X.shape[0]

    print(f"  Train Loss: {running_loss / len(train_loader.dataset)}")
    print(f"  Train Acc:  {correct / len(train_loader.dataset)}")

    # Eval
    model.eval()  # IMPORTANT

    running_loss, correct = 0.0, 0
    with torch.no_grad():  # IMPORTANT
        for X, y in test_loader:
            X, y = X.to(device), y.to(device)

            y_ = model(X)

            _, y_label_ = torch.max(y_, 1)
            correct += (y_label_ == y).sum().item()

            loss = criterion(y_, y)
            running_loss += loss.item() * X.shape[0]

    print(f"  Valid Loss: {running_loss / len(test_loader.dataset)}")
    print(f"  Valid Acc:  {correct / len(test_loader.dataset)}")
    print()