## Exercise 3: Transfer Learning (30 points)

In practice, a CNN is rarely trained entirely from scratch, because it is relatively rare to have a dataset of sufficient size (or sufficient computational power). Instead, it is common to pretrain a CNN on a very large dataset and then use it either as an initialization or as a fixed feature extractor for the task of interest.

In this task, you will learn how to use a pretrained CNN for CIFAR-10 classification.

### Task 1: Load pretrained model

`torchvision.models` (https://pytorch.org/vision/stable/models.html) contains definitions of models for addressing different tasks, including: image classification, pixelwise semantic segmentation, object detection, instance segmentation, person keypoint detection and video classification.

First, use `torchvision.models` to load the **pretrained** ResNet-18, which has already been trained on [ImageNet](https://www.image-net.org/). If you are interested in more details about ResNet-18, read the paper at https://arxiv.org/pdf/1512.03385.pdf.

In [None]:
import torchvision.models as models

# Load ResNet-18 with ImageNet-pretrained weights and print its architecture.
resnet18 = models.resnet18(pretrained=True)
print(resnet18)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

### Task 2: Create data loaders for CIFAR-10

Next, create data loaders for CIFAR-10. Note that the model you loaded was trained on **ImageNet** and expects inputs as mini-batches of 3-channel RGB images of shape (3 x H x W), where H and W are expected to be **at least** 224. CIFAR-10 images are only 32 x 32, so you need to preprocess them to the required height and width. See [`torchvision.transforms.Resize`](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.Resize).
You will probably want to add this transform to the `transform` you created in a previous task.


In [None]:
import torch
import torchvision

batch_size = 100
# ToTensor scales pixels to [0, 1], Normalize maps each channel to [-1, 1],
# and Resize upsamples the 32x32 CIFAR-10 images to 224x224.
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),
                                            torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                            torchvision.transforms.Resize((224, 224))])

# download=True on the test split fetches the archive that both splits share.
test_cifar10_dataset = torchvision.datasets.CIFAR10('./data', train=False, download=True, transform=transform)
train_cifar10_dataset = torchvision.datasets.CIFAR10('./data', train=True, transform=transform)
# Shuffle only the training data; evaluation order does not matter.
test_cifar10_loader = torch.utils.data.DataLoader(test_cifar10_dataset, batch_size=batch_size, shuffle=False)
train_cifar10_loader = torch.utils.data.DataLoader(train_cifar10_dataset, batch_size=batch_size, shuffle=True)

Files already downloaded and verified


### Task 3: Classify test data on pretrained model

Use the model you loaded to classify the **test** CIFAR-10 data and print out the test accuracy.

Don't be surprised if the accuracy is bad!

In [None]:
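# Defined here in case `device` is not already in scope from an earlier exercise.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def test():
  resnet18.to(device)
  resnet18.eval()  # use running BatchNorm statistics during evaluation
  with torch.no_grad():
    correct = 0
    tot = 0
    for images, labels in test_cifar10_loader:
      images = images.to(device)
      labels = labels.to(device)
      outputs = resnet18(images)
      # Index of the highest-scoring class for each image.
      prediction = outputs.argmax(dim=1, keepdim=True)
      correct += prediction.eq(labels.view_as(prediction)).sum()
      tot += len(images)
      # Running accuracy over the batches seen so far.
      print(f'Accuracy: {correct/tot}')

  print(f'Test Accuracy: {100.*correct/len(test_cifar10_dataset)}%')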

In [None]:
# Test the pretrained ResNet-18 model on the CIFAR-10 dataset
test()

Accuracy: 0.0
Accuracy: 0.0
Accuracy: 0.0
Accuracy: 0.0
Accuracy: 0.0
Accuracy: 0.0
Accuracy: 0.0
Accuracy: 0.0
Accuracy: 0.0
Accuracy: 0.0
Accuracy: 0.0
Accuracy: 0.0
Accuracy: 0.0
Accuracy: 0.0
Accuracy: 0.0
Accuracy: 0.0
Accuracy: 0.0
Accuracy: 0.0
Accuracy: 0.0
Accuracy: 0.0005000000237487257
Accuracy: 0.0004761904710903764
Accuracy: 0.00045454545761458576
Accuracy: 0.00043478261795826256
Accuracy: 0.00041666667675599456
Accuracy: 0.00039999998989515007
Accuracy: 0.0003846153849735856
Accuracy: 0.000370370369637385
Accuracy: 0.0003571428533177823
Accuracy: 0.00034482759656384587
Accuracy: 0.00033333332976326346
Accuracy: 0.0003225806576665491
Accuracy: 0.0003124999930150807
Accuracy: 0.0003030303050763905
Accuracy: 0.00029411763534881175
Accuracy: 0.0002857142826542258
Accuracy: 0.00027777778450399637
Accuracy: 0.0002702702768146992
Accuracy: 0.0002631578827276826
Accuracy: 0.00025641024694778025
Accuracy: 0.0002500000118743628
Accuracy: 0.0002439024392515421
Accuracy: 0.0002380952

### Task 4: Update model for CIFAR-10

Now try to improve the test accuracy. We offer several possible solutions:

(1) You can continue training the loaded model directly on the CIFAR-10 training data.

(2) For efficiency, you can freeze part of the parameters of the loaded model. For example, you can first freeze all parameters by

```
for param in model.parameters():
    param.requires_grad = False
```
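and then unfreeze the last few layers by setting `requires_grad = True` on their parameters, for example with `somelayer.requires_grad_(True)`. (Assigning `somelayer.requires_grad = True` on a module only sets a plain attribute and does not affect its parameters.)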

You are also welcome to try any other approach you can think of.


**Note:** Print out the test accuracy. To get full credit, the test accuracy should be at least **80%**.

In [None]:
import torch.optim as optim
import torch.nn.functional as F

# Freeze the first six top-level children of ResNet-18
# (conv1, bn1, relu, maxpool, layer1, layer2); layer3, layer4 and fc stay trainable.
for idx, child in enumerate(resnet18.children()):
  if idx < 6:
    for param in child.parameters():
      param.requires_grad = False

# Pass only the trainable parameters to the optimizer.
trainable_params = [p for p in resnet18.parameters() if p.requires_grad]
optimizer = optim.SGD(trainable_params, lr=0.01, momentum=0.9)

def train(epoch):
  resnet18.to(device)
  resnet18.train()

  for batch_idx, (images, targets) in enumerate(train_cifar10_loader):
    images = images.to(device)
    targets = targets.to(device)
    optimizer.zero_grad()  # clear gradients from the previous step
    outputs = resnet18(images)
    loss = F.cross_entropy(outputs, targets)
    loss.backward()
    optimizer.step()

    # Log the loss every 10 batches.
    if batch_idx % 10 == 0:
      print(f'EPOCH: {epoch} [{batch_idx*len(images)}/{len(train_cifar10_dataset)}] Loss: {loss.item()}')

In [None]:
# Train Resnet for 1 epoch using CIFAR-10 Dataset
train(1)

EPOCH: 1 [0/50000] Loss: 0.660830020904541
EPOCH: 1 [1000/50000] Loss: 0.5830726027488708
EPOCH: 1 [2000/50000] Loss: 0.5147011876106262
EPOCH: 1 [3000/50000] Loss: 0.5107433795928955
EPOCH: 1 [4000/50000] Loss: 0.478744238615036
EPOCH: 1 [5000/50000] Loss: 0.39798057079315186
EPOCH: 1 [6000/50000] Loss: 0.31346216797828674
EPOCH: 1 [7000/50000] Loss: 0.25302502512931824
EPOCH: 1 [8000/50000] Loss: 0.39963868260383606
EPOCH: 1 [9000/50000] Loss: 0.4802074730396271
EPOCH: 1 [10000/50000] Loss: 0.23381084203720093
EPOCH: 1 [11000/50000] Loss: 0.41112786531448364
EPOCH: 1 [12000/50000] Loss: 0.30601832270622253
EPOCH: 1 [13000/50000] Loss: 0.24655316770076752
EPOCH: 1 [14000/50000] Loss: 0.39226046204566956
EPOCH: 1 [15000/50000] Loss: 0.29811060428619385
EPOCH: 1 [16000/50000] Loss: 0.2812885046005249
EPOCH: 1 [17000/50000] Loss: 0.3019710183143616
EPOCH: 1 [18000/50000] Loss: 0.23746557533740997
EPOCH: 1 [19000/50000] Loss: 0.43034327030181885
EPOCH: 1 [20000/50000] Loss: 0.300741374492

In [None]:
# Test the ResNet-18 model on the CIFAR-10 dataset after training
test()

Accuracy: 0.9199999570846558
Accuracy: 0.9149999618530273
Accuracy: 0.9100000262260437
Accuracy: 0.9074999690055847
Accuracy: 0.9000000357627869
Accuracy: 0.9100000262260437
Accuracy: 0.9142857193946838
Accuracy: 0.9112499952316284
Accuracy: 0.9100000262260437
Accuracy: 0.9130000472068787
Accuracy: 0.9154545664787292
Accuracy: 0.9183333516120911
Accuracy: 0.9169231057167053
Accuracy: 0.9178571105003357
Accuracy: 0.9160000085830688
Accuracy: 0.9181249737739563
Accuracy: 0.9182352423667908
Accuracy: 0.9166666865348816
Accuracy: 0.9157894253730774
Accuracy: 0.9160000681877136
Accuracy: 0.9142857193946838
Accuracy: 0.9127272963523865
Accuracy: 0.9095652103424072
Accuracy: 0.909583330154419
Accuracy: 0.9107999801635742
Accuracy: 0.9084615111351013
Accuracy: 0.9096296429634094
Accuracy: 0.9092857241630554
Accuracy: 0.9093104004859924
Accuracy: 0.909333348274231
Accuracy: 0.9090322852134705
Accuracy: 0.9096874594688416
Accuracy: 0.9115151762962341
Accuracy: 0.9097058176994324
Accuracy: 0.9091