# Finetuning a Convolutional Neural Network

In this exercise, you will have to finetune a pretrained CNN model on the CIFAR10 dataset. The data loading and model testing logic are already included in your code. You will have to create the model and the training loop.

**In this workspace you have a GPU available to help train the model, but it is best practice to DISABLE it while writing code and only ENABLE it when you are training.**

Here are the steps you need to do to complete this exercise:

1. Finish the `create_model()` function. You should use a pretrained model. You are free to choose any pre-trained model that you want to use. 
2. Finish the `train()` function. This function should validate the accuracy of the model during the training stage. You should stop the training when this validation accuracy stops increasing.
3. Save all your work and then **ENABLE** the GPU
4. Run the file to make sure that the model is training properly.
5. If it works, remember to **DISABLE** the GPU before moving to the next page. 

In case you get stuck, you can look at the solution by clicking the jupyter symbol at the top left and navigating to `finetune_a_cnn_solution.py`.

## Try It Out!
- See how your accuracy changes when using other pre-trained models.
- Play around with the number of layers and neurons in your model. How does the accuracy change? How long does it take to train the model?
- Can you create the same network in TensorFlow as well?

## Import libraries

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision
from torchvision import datasets, models, transforms

## Download and load data

In [None]:
import torch.utils.data

# Define data transformations.
# Both pipelines upscale the CIFAR10 images to 224 pixels and normalize with
# the channel mean/std conventionally used with torchvision's pretrained
# models.
training_transform = transforms.Compose([
    # Random flipping is augmentation and therefore training-only.
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Evaluation must be deterministic: use a plain resize rather than a random
# crop so every run scores the model on exactly the same, complete images.
testing_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Download datasets
trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=training_transform)
testset = datasets.CIFAR10(root='./data', train=False, download=True, transform=testing_transform)

batch_size = 32
# Shuffle only the training data; evaluation order does not matter.
train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
        shuffle=True)
test_loader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
        shuffle=False)

## Create and instantiate the model

In [None]:
def create_model():
    """Build a pretrained ResNet-18 whose classifier is replaced by a 10-way head.

    All pretrained parameters are frozen; the replacement fully connected
    layer is created afterwards, so it is the only part left trainable.

    Returns:
        The modified ResNet-18 module.
    """
    backbone = models.resnet18(pretrained=True)

    # Freeze every pretrained weight so that only the new head is updated.
    for weight in backbone.parameters():
        weight.requires_grad = False

    # Replace the original classifier with a fresh layer sized for 10 classes.
    head = nn.Sequential(nn.Linear(backbone.fc.in_features, 10))
    backbone.fc = head

    return backbone

In [None]:
# To run a pretrained model we might need the use of a GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Running on device {device}')

# Instantiate the model and move it to the selected device. The training and
# testing loops send every batch to whatever device the model lives on, so
# without this move everything would run on the CPU even when a GPU exists.
model = create_model().to(device)

## Define training and testing loops

In [None]:
def train(model, train_loader, validation_loader, epochs, loss_fn, optimizer, patience=3, baseline=1e-6):
    """Train ``model`` with early stopping on the validation loss.

    Every epoch runs one pass over ``train_loader`` (with weight updates)
    followed by one pass over ``validation_loader`` (without gradients), and
    prints the mean loss and accuracy of each pass.

    Args:
        model: Module to train. Batches are moved to the device of its first
            parameter.
        train_loader: Iterable of ``(data, target)`` batches used for training.
        validation_loader: Iterable of ``(data, target)`` batches used to
            monitor the validation loss.
        epochs: Maximum number of epochs to run.
        loss_fn: Callable ``loss_fn(pred, target)`` returning a scalar tensor.
        optimizer: Optimizer that updates the trainable parameters.
        patience: Number of consecutive epochs without improvement after
            which training stops.
        baseline: Minimum decrease of the validation loss, relative to the
            best value seen so far, that counts as an improvement.

    Returns:
        The same ``model`` object, trained in place.
    """
    # Start from infinity so the first validation pass always sets the
    # reference value; seeding this with a tiny number would make every
    # epoch look like "no improvement" and stop training prematurely.
    best_loss = float('inf')
    loaders = {'train': train_loader, 'valid': validation_loader}
    epochs_without_improvement = 0
    # The model does not change device during training, so look it up once.
    device = next(model.parameters()).device

    for epoch in range(epochs):
        print(f'Epoch {epoch+1}/{epochs}')

        for phase in ['train', 'valid']:
            is_training = phase == 'train'
            if is_training:
                model.train()
            else:
                model.eval()

            epoch_loss = 0.0
            correct = 0  # Correctly classified samples in this pass
            total = 0  # Samples seen in this pass

            # 1. Loop through data
            for data, target in loaders[phase]:
                # Move data and target to the same device as the model
                data, target = data.to(device), target.to(device)

                # 2. Zero all gradients (nothing accumulates while validating)
                if is_training:
                    optimizer.zero_grad()

                # 3. Forward pass; gradients are tracked only while training
                with torch.set_grad_enabled(is_training):
                    pred = model(data)
                    # 4. Compute loss
                    loss = loss_fn(pred, target)

                    if is_training:
                        # 5. Backpropagation
                        loss.backward()
                        # 6. Update weights
                        optimizer.step()

                # Weight the batch loss by the batch size so that the epoch
                # value is a per-sample mean even if batches differ in size.
                epoch_loss += loss.item() * data.size(0)
                # The index of the largest score along dim 1 is the predicted class.
                _, predicted = torch.max(pred, 1)
                correct += (predicted == target).sum().item()
                total += target.size(0)

            # Average over the samples actually seen; the guard keeps an
            # empty loader from raising ZeroDivisionError.
            seen = max(total, 1)
            epoch_loss = epoch_loss / seen
            epoch_acc = correct / seen * 100

            print(f'{phase.capitalize()} Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%')

            # Early-stopping bookkeeping only applies to the validation pass.
            if is_training:
                continue

            if epoch_loss < best_loss - baseline:
                best_loss = epoch_loss
                epochs_without_improvement = 0
            else:
                epochs_without_improvement += 1
                if epochs_without_improvement >= patience:
                    # Report the 1-based epoch, matching the progress line above.
                    print(f'Early stopping at {epoch+1} epoch.')
                    return model

    return model
                

def test(model, test_loader, loss_fn):
    """Evaluate ``model`` on ``test_loader`` and print mean loss and accuracy.

    Args:
        model: Module to evaluate. Batches are moved to the device of its
            first parameter.
        test_loader: Iterable of ``(inputs, labels)`` batches.
        loss_fn: Callable ``loss_fn(outputs, labels)`` returning a scalar
            tensor.

    The printed loss is the per-sample mean and the printed accuracy is the
    percentage of correctly classified samples.
    """
    model.eval()
    device = next(model.parameters()).device
    running_loss = 0.0
    running_corrects = 0
    num_samples = 0

    # No weights are updated here, so skip gradient tracking entirely.
    with torch.no_grad():
        for inputs, labels in test_loader:
            # Move data to the same device as the model
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)
            _, preds = torch.max(outputs, 1)
            # Weight the batch loss by the batch size to get a per-sample sum.
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels).item()
            num_samples += labels.size(0)

    # Normalize by the number of samples, not the number of batches;
    # dividing by the batch count inflates both numbers by the batch size.
    # The guard keeps an empty loader from raising ZeroDivisionError.
    denominator = max(num_samples, 1)
    total_loss = running_loss / denominator
    total_acc = running_corrects / denominator
    print(f'Testing Loss: {total_loss}, Testing Accuracy: {100 * total_acc}')


# Despite its name this is an evaluation helper, not a pytest test case;
# the flag stops pytest from collecting it when the module is imported.
test.__test__ = False

In [None]:
# Training configuration
epochs = 50
lr = 0.001
loss_fn = nn.CrossEntropyLoss()
# Only the parameters of the model's fc layer are handed to the optimizer.
optimizer = optim.Adam(params=model.fc.parameters(), lr=lr)

# Run training.
# NOTE(review): the test loader doubles as the validation loader here, so
# early stopping is tuned on the same data that is later used for testing.
train(
    model=model,
    train_loader=train_loader,
    validation_loader=test_loader,
    epochs=epochs,
    loss_fn=loss_fn,
    optimizer=optimizer,
)

In [None]:
# Report loss and accuracy of the trained model on the test loader.
test(model=model, test_loader=test_loader, loss_fn=loss_fn)

## Remember to disable the GPU when you are done training.