# Transfer learning and other tricks

In [31]:
import torch
import torchvision
import torch.optim as optim
import torch.nn as nn
import torch.utils.data
import torch.nn.functional as F
from torchvision import models, transforms
from PIL import Image
import matplotlib.pyplot as plt


In [32]:
# Load ResNet-50 with ImageNet-pretrained weights. `weights=True` is not a valid
# weights value: torchvision's legacy shim maps it to IMAGENET1K_V1 with a
# deprecation warning. Passing the enum loads the same weights explicitly.
transfer_model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

## Freezing parameters (all except BatchNorm layers)

In [33]:
# Freeze every pretrained parameter except those owned by BatchNorm layers.
# Matching on the parameter *name* ("bn" in name) misses ResNet's shortcut
# BatchNorms, which torchvision registers as `layerX.Y.downsample.1`, so select
# by module type instead. recurse=False visits each parameter exactly once,
# via the module that directly owns it.
BATCHNORM_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)
for module in transfer_model.modules():
    if not isinstance(module, BATCHNORM_TYPES):
        for param in module.parameters(recurse=False):
            param.requires_grad = False
        

## Replacing the classifier (fc)

In [34]:
# Replace the pretrained classifier with a small two-class head:
# features -> 500 hidden units (ReLU + dropout) -> 2 logits.
num_backbone_features = transfer_model.fc.in_features
transfer_model.fc = nn.Sequential(
    nn.Linear(num_backbone_features, 500),
    nn.ReLU(),
    nn.Dropout(),
    nn.Linear(500, 2),
)

## Fine-tuning the model

In [35]:
# Prefer the GPU when one is visible to PyTorch; fall back to the CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device is: ", device)

# Move the model's parameters/buffers to the chosen device before building the
# optimizer. Frozen parameters never receive a gradient, so Adam skips them.
transfer_model.to(device)
optimizer = optim.Adam(transfer_model.parameters(), lr=1e-3)

Device is:  cuda


In [36]:
def check_image(path):
    """Return True if PIL can open ``path`` as an image, otherwise False.

    Intended as the ``is_valid_file`` callback of ``ImageFolder``, so files
    PIL cannot identify are skipped. ``Image.open`` is lazy (it identifies the
    file without decoding the pixel data), so a file with a valid header but
    corrupt image data will still pass this check.
    """
    try:
        # Close the file deterministically: Image.open keeps the file open
        # until the image is loaded or closed, so don't rely on GC here.
        with Image.open(path):
            return True
    except Exception:
        # Any failure to open/identify the file means "not a usable image".
        # Catch Exception (not a bare except) so Ctrl-C still interrupts the scan.
        return False

In [37]:
# Resize every image to a fixed 64x64, convert to a tensor and normalise with
# the standard ImageNet per-channel mean/std.
img_transforms = transforms.Compose(
    [
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# One ImageFolder per split; check_image filters out files PIL cannot open.
train_data_path = "./cat-dog-data/train"
val_data_path = "./cat-dog-data/val"
train_data = torchvision.datasets.ImageFolder(
    root=train_data_path,
    transform=img_transforms,
    is_valid_file=check_image,
)
val_data = torchvision.datasets.ImageFolder(
    root=val_data_path,
    transform=img_transforms,
    is_valid_file=check_image,
)



In [38]:
batch_size = 64
# Shuffle only the training data; the model should not see batches in a fixed order.
train_data_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
# Validation does not need shuffling: a fixed order keeps evaluation
# reproducible, and the metrics computed in `train` are sums, so order is irrelevant.
val_data_loader = torch.utils.data.DataLoader(val_data, batch_size=batch_size, shuffle=False)
print(len(val_data_loader.dataset))

200


In [41]:
def _run_training_epoch(model, optimizer, loss_fn, loader, device):
    """Run one optimisation pass over ``loader``; return (summed loss, sample count)."""
    model.train()
    loss_sum = 0.0
    sample_count = 0
    for inputs, targets in loader:
        inputs = inputs.to(device)
        targets = targets.to(device)
        optimizer.zero_grad()
        output = model(inputs)
        loss = loss_fn(output, targets)
        loss.backward()
        optimizer.step()
        # Assumes loss_fn returns a per-batch mean (e.g. CrossEntropyLoss's
        # default reduction), so re-weight by batch size before summing.
        loss_sum += loss.item() * inputs.size(0)
        sample_count += inputs.size(0)
    return loss_sum, sample_count


def _evaluate(model, loss_fn, loader, device):
    """Score ``model`` on ``loader`` with autograd disabled.

    Returns (summed loss, number of correct predictions, sample count).
    """
    model.eval()
    loss_sum = 0.0
    num_correct = 0
    num_examples = 0
    # No backward pass happens here, so don't build autograd graphs.
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            output = model(inputs)  # [batch, num_classes]
            loss_sum += loss_fn(output, targets).item() * inputs.size(0)
            # softmax is monotonic, so argmax over raw logits gives the same class.
            predictions = output.argmax(dim=1)
            num_correct += (predictions == targets).sum().item()
            num_examples += targets.size(0)
    return loss_sum, num_correct, num_examples


def _mean_or_nan(total, count):
    """Return ``total / count``, or NaN when ``count`` is zero (empty loader)."""
    return total / count if count else float("nan")


def train(model, optimizer, loss_fn, train_loader, val_loader, epochs, device):
    """Train ``model`` for ``epochs`` epochs, printing metrics after each one.

    Each epoch runs one optimisation pass over ``train_loader`` followed by an
    evaluation pass over ``val_loader``. Losses are averaged over the samples
    actually seen; accuracy is the fraction of argmax predictions equal to the
    targets. Metrics of an empty loader are reported as NaN.

    Args:
        model: the ``nn.Module`` to train; it is left in eval mode on return.
        optimizer: optimizer already bound to ``model``'s parameters.
        loss_fn: callable ``(output, targets) -> scalar loss tensor``.
        train_loader: iterable of ``(inputs, targets)`` training batches.
        val_loader: iterable of ``(inputs, targets)`` validation batches.
        epochs: number of epochs to run.
        device: device the batches are moved to (model is assumed to be there already).

    Returns:
        None; progress is printed to stdout.
    """
    for epoch in range(epochs):
        train_loss_sum, train_count = _run_training_epoch(model, optimizer, loss_fn, train_loader, device)
        valid_loss_sum, num_correct, num_examples = _evaluate(model, loss_fn, val_loader, device)

        training_loss = _mean_or_nan(train_loss_sum, train_count)
        valid_loss = _mean_or_nan(valid_loss_sum, num_examples)
        accuracy = _mean_or_nan(num_correct, num_examples)

        print("Epoch: {}, Training Loss: {:.2f}, Validation Loss: {:.2f}, accuracy: {:.2f}".format(epoch, training_loss, valid_loss, accuracy))

In [42]:
# Fine-tune for a few epochs with standard cross-entropy on the class logits.
num_epochs = 5
criterion = torch.nn.CrossEntropyLoss()
train(
    transfer_model,
    optimizer,
    criterion,
    train_data_loader,
    val_data_loader,
    epochs=num_epochs,
    device=device,
)

Epoch: 0, Training Loss: 0.20, Validation Loss: 0.57, accuracy: 0.79
Epoch: 1, Training Loss: 0.16, Validation Loss: 0.53, accuracy: 0.81
Epoch: 2, Training Loss: 0.24, Validation Loss: 0.84, accuracy: 0.77
Epoch: 3, Training Loss: 0.25, Validation Loss: 0.86, accuracy: 0.76
Epoch: 4, Training Loss: 0.17, Validation Loss: 0.72, accuracy: 0.78
