# Transfer Learning

In this lab you will learn how to use a pre-trained network to save time and quickly solve your problem. Specifically, you will use a pre-trained ResNet50 model to classify images from the CIFAR-10 dataset.

In [None]:
import os
from collections import OrderedDict

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch import nn
from torch import optim
from torchvision import datasets, transforms, models
# Run on the first CUDA GPU when one is usable; otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
# Build the ResNet-50 architecture without fetching weights, then restore the
# pretrained parameters from a checkpoint file on disk.
RESNET50_WEIGHTS_PATH = 'relative path of resnet50-19c8e357.pth'

model = models.resnet50(pretrained=False)
model.load_state_dict(torch.load(RESNET50_WEIGHTS_PATH))

## TODO
Define the transformations for the training data and the testing data. Explore the different functions in `torchvision.transforms`.

In [None]:
data_dir = os.path.join('.', 'data')

########## TODO: Define transforms for the training data and testing data ########
# Per-channel mean/std applied to both splits. These are the statistics that
# torchvision documents for its ImageNet-pretrained models.
normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                 [0.229, 0.224, 0.225])

# Training pipeline: random augmentation, then tensor conversion + normalisation.
train_transforms = transforms.Compose([transforms.RandomRotation(30),
                                       transforms.RandomHorizontalFlip(),
                                       transforms.ToTensor(),
                                       normalize])

# Test pipeline: deterministic, no augmentation.
test_transforms = transforms.Compose([transforms.ToTensor(),
                                      normalize])

##################################################################################

# Pass transforms in here, then run the next cell to see how the transforms look.
# download=False: the CIFAR-10 files must already be present under `data_dir`.
train_data = datasets.CIFAR10(data_dir, download=False, train=True, transform=train_transforms)
test_data = datasets.CIFAR10(data_dir, download=False, train=False, transform=test_transforms)

# Only the training loader is shuffled; evaluation order does not matter.
trainloader = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True)
testloader = torch.utils.data.DataLoader(test_data, batch_size=64)

### View Image

In [None]:
def _undo_normalize(tensor, transform):
    """Reverse every `transforms.Normalize` step found in `transform`.

    `transform` may be a `Compose` (its `.transforms` list is scanned in
    reverse order) or a single transform. Tensors produced by a pipeline
    without a Normalize step are returned unchanged.
    """
    for step in reversed(getattr(transform, 'transforms', [transform])):
        if isinstance(step, transforms.Normalize):
            mean = torch.as_tensor(step.mean, dtype=tensor.dtype).view(-1, 1, 1)
            std = torch.as_tensor(step.std, dtype=tensor.dtype).view(-1, 1, 1)
            tensor = tensor * std + mean
    return tensor


img, lab = train_data[0]
to_pil = transforms.ToPILImage()
# ToPILImage scales float tensors by 255, so the values must lie in [0, 1]:
# undo any normalisation first and clamp away rounding overshoot.
img = to_pil(_undo_normalize(img, train_data.transform).clamp(0, 1))
img

In [None]:
class Identity(nn.Module):
    """Pass-through layer: `forward` hands back its input untouched.

    It has no parameters or buffers, so it can replace any layer of a network
    that should become a no-op.
    """

    def forward(self, x):
        """Return `x` itself (same object, no copy)."""
        return x

# Replace the global average pool with a pass-through layer.
# NOTE(review): the flattened features only match the pretrained fc layer while
# the last conv stage already emits 1x1 maps (expected for 32x32 inputs) —
# confirm before changing the input resolution.
model.avgpool = Identity()

# Freeze every pretrained weight so that only the new head below is trained.
# (The attribute is `requires_grad`; a misspelt name is silently accepted and
# freezes nothing.)
for param in model.parameters():
    param.requires_grad = False

########### TODO: Define the classifier #############
# Small head that maps the network's 1000 outputs to the 10 CIFAR-10 classes.
# It is created after the freeze loop, so its parameters remain trainable.
classifier = nn.Sequential(OrderedDict([
    ('relu1', nn.ReLU()),
    ('fc1', nn.Linear(1000, 128)),
    ('relu2', nn.ReLU()),
    ('fc2', nn.Linear(128, 10))
]))

#####################################################
# torchvision's ResNet.forward() finishes with `self.fc` and never reads a
# `classifier` attribute, so the head has to be chained after the pretrained
# fc layer to take part in the forward pass. The isinstance check keeps this
# cell idempotent when it is re-run.
pretrained_fc = model.fc[0] if isinstance(model.fc, nn.Sequential) else model.fc
model.fc = nn.Sequential(pretrained_fc, classifier)
# Alias kept so existing references to `model.classifier` keep working.
model.classifier = classifier

In [None]:
################ TODO: choose the optimizer and loss function ###############
# Cross-entropy on the model's raw class scores, minimised with momentum SGD.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=model.parameters(), lr=1e-3, momentum=0.9)
#############################################################################

In [None]:
def _evaluate(model, loader, criterion, device):
    """Return `(mean_loss, accuracy)` of `model` over every sample in `loader`.

    Both metrics are weighted by batch size so that a smaller final batch does
    not skew the averages. Assumes `criterion` returns the mean loss of a
    batch (the default reduction of `nn.CrossEntropyLoss`).
    The model is switched to eval mode for the pass and restored to whatever
    mode it was in before returning.
    """
    was_training = model.training
    model.eval()
    loss_sum, correct, seen = 0.0, 0, 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            logits = model(inputs)
            batch_size = labels.size(0)
            loss_sum += criterion(logits, labels).item() * batch_size
            # argmax on the raw scores picks the same class as a softmax would,
            # without exponentiating (and possibly overflowing) the logits.
            correct += (logits.argmax(dim=1) == labels).sum().item()
            seen += batch_size
    model.train(was_training)
    seen = max(seen, 1)  # guard against an empty loader
    return loss_sum / seen, correct / seen


model = model.to(device)
# Start in training mode even if a previous run was interrupted mid-evaluation.
model.train()
epochs = 10
steps = 0
running_loss = 0
print_every = 200
for epoch in range(epochs):
    for inputs, labels in trainloader:
        steps += 1
        # Move input and label tensors to the default device
        inputs, labels = inputs.to(device), labels.to(device)

        ######### TODO: training process ##########
        optimizer.zero_grad()

        logits = model(inputs)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        ############################################
        running_loss += loss.item()

        # Every `print_every` optimisation steps, report the running training
        # loss next to the loss/accuracy on the full test set.
        if steps % print_every == 0:
            test_loss, accuracy = _evaluate(model, testloader, criterion, device)

            print(f"Epoch {epoch+1}/{epochs}.. "
                  f"Train loss: {running_loss/print_every:.3f}.. "
                  f"Test loss: {test_loss:.3f}.. "
                  f"Test accuracy: {accuracy:.3f}")
            running_loss = 0