# Chapter 4: Transfer Learning with ResNet

In [2]:
import torch
import torch.nn as nn
import torch.utils.data
import torch.nn.functional as F
import torchvision
import torchvision.models as models
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt

## Importing Pretrained Model

In [3]:
transfer_model = models.resnet50(pretrained=True)

Downloading: "https://download.pytorch.org/models/resnet50-19c8e357.pth" to /Users/dongdongdongdong/.cache/torch/checkpoints/resnet50-19c8e357.pth
100%|██████████| 97.8M/97.8M [00:11<00:00, 9.21MB/s]


In [5]:
transfer_model

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

## Freezing Parameters

In [6]:
for name, param in transfer_model.named_parameters():
    if ('bn' not in name):
        param.reguires_grad = False

## Replacing the Classifier (the Fully-Connected Layer)

In [34]:
transfer_model.fc = nn.Sequential(nn.Linear(2048, 500),
                                  nn.ReLU(),
                                  nn.Dropout(),
                                  nn.Linear(500,2)) 

## Training Again

In [52]:
def train(model, optimizer, loss_fn, train_loader, val_loader, epochs=5, device='cpu'):
    for epoch in range(epochs):
        training_loss = 0.0
        val_loss = 0.0
        model.train()
        for batch in train_loader:
            optimizer.zero_grad()
            inputs, targets = batch
            inputs = inputs.to(device)
            targets = targets.to(device)
            output = model(inputs)
            loss = loss_fn(output, targets)
            loss.backward()
            optimizer.step()
            training_loss += loss.data.item()
        training_loss /= len(train_loader.dataset)
        
        model.eval()
        num_correct = 0
        num_examples = 0
        for batch in val_loader:
            inputs, targets = batch
            inputs = inputs.to(device)
            targets = targets.to(device)
            output = model(inputs)
            loss = loss_fn(output, targets)
            val_loss += loss.data.item() * inputs.size(0)
            correct = torch.eq(torch.max(F.softmax(output), dim=1)[1], targets).view(-1)
            num_correct += torch.sum(correct).item()
            num_examples += correct.shape[0]
        val_loss /= len(val_loader.dataset)
        
        print('Epoch: {}, Training Loss: {:.2f}, Validation Loss: {:.2f}, Accuracy = {:.2f}'.format(epoch, training_loss, val_loss, num_correct/num_examples))
        

In [36]:
def check_image(path):
    try:
        im = Image.open(path)
        return True
    except:
        return False

In [37]:
img_transforms = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                          std=[0.229, 0.224, 0.225])
])

In [38]:
train_data_path = '/Users/dongdongdongdong/Desktop/images/train'
train_data = torchvision.datasets.ImageFolder(root=train_data_path,
                                              transform=img_transforms,
                                              is_valid_file=check_image)

val_data_path = '/Users/dongdongdongdong/Desktop/images/val'
val_data = torchvision.datasets.ImageFolder(root=val_data_path,
                                              transform=img_transforms,
                                              is_valid_file=check_image)

In [39]:
batch_size = 64

train_data_loader = torch.utils.data.DataLoader(train_data,
                                                batch_size=batch_size,
                                                shuffle=True)
val_data_loader = torch.utils.data.DataLoader(val_data,
                                                batch_size=batch_size,
                                                shuffle=True)

In [42]:
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [43]:
transfer_model.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [45]:
optimizer = torch.optim.Adam(transfer_model.parameters(), lr=0.001)

In [46]:
loss_fn = torch.nn.CrossEntropyLoss()

In [53]:
train(transfer_model, optimizer, loss_fn, train_data_loader, val_data_loader)



Epoch: 0, Training Loss: 0.00, Validation Loss: 0.45, Accuracy = 0.89
Epoch: 1, Training Loss: 0.00, Validation Loss: 0.41, Accuracy = 0.87
Epoch: 2, Training Loss: 0.00, Validation Loss: 0.66, Accuracy = 0.79
Epoch: 3, Training Loss: 0.00, Validation Loss: 1.21, Accuracy = 0.66
Epoch: 4, Training Loss: 0.00, Validation Loss: 0.37, Accuracy = 0.87
