Run the cell below to import the necessary dependencies.

In [4]:
import torch
from dataset import CitiesDataset
from torchvision.models import resnet50
from torchvision.models import ResNet50_Weights
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch
from torch.utils.tensorboard import SummaryWriter
import numpy as np
from torch.utils.data import random_split
from torchvision import transforms
from torch.optim import lr_scheduler

In the cell below, we will define our network. 

In [5]:
class TransferLearning(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = resnet50(weights=ResNet50_Weights)
        for param in self.layers.parameters():
            param.grad_required = False
        linear_layers = torch.nn.Sequential(
            torch.nn.Linear(2048, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 10),
        )
        self.layers.fc = linear_layers
        # print(self.layers)

    def forward(self, x):
        return self.layers(x)


Load the dataset. This has been created for you in this instance.

In [10]:
def train(
    model,
    train_loader,
    val_loader,
    test_loader,
    lr=0.1,
    epochs=20,
    optimiser=torch.optim.SGD
):
    """
    Trains a neural network on a dataset and returns the trained model

    Parameters:
    - model: a pytorch model
    - dataloader: a pytorch dataloader

    Returns:
    - model: a trained pytorch model
    """

    # components of a ml algortithms
    # 1. data
    # 2. model
    # 3. criterion (loss function)
    # 4. optimiser

    writer = SummaryWriter()

    # initialise an optimiser
    optimiser = optimiser(model.parameters(), lr=lr, weight_decay=0.001)
    scheduler = lr_scheduler.MultiStepLR(optimiser, milestones=[5,15], gamma=0.1,verbose=True)
    batch_idx = 0
    epoch_idx= 0
    for epoch in range(epochs):  # for each epoch
        # 
        
        print('Epoch:', epoch_idx,'LR:', scheduler.get_lr())
        epoch_idx +=1
        
        for batch in train_loader:  # for each batch in the dataloader
            features, labels = batch
            prediction = model(features)  # make a prediction
            # compare the prediction to the label to calculate the loss (how bad is the model)
            loss = F.cross_entropy(prediction, labels)
            loss.backward()  # calculate the gradient of the loss with respect to each model parameter
            optimiser.step()  # use the optimiser to update the model parameters using those gradients
            print("Epoch:", epoch, "Batch:", batch_idx,
                  "Loss:", loss.item())  # log the loss
            optimiser.zero_grad()  # zero grad
            writer.add_scalar("Loss/Train", loss.item(), batch_idx)
            batch_idx += 1
            if batch_idx % 25 == 0:
                print('Evaluating on valiudation set')
                # evaluate the validation set performance
                val_loss, val_acc = evaluate(model, val_loader)
                writer.add_scalar("Loss/Val", val_loss, batch_idx)
                writer.add_scalar("Accuracy/Val", val_acc, batch_idx)

        scheduler.step()
    # evaluate the final test set performance
    
    print('Evaluating on test set')
    test_loss = evaluate(model, test_loader)
    # writer.add_scalar("Loss/Test", test_loss, batch_idx)
    model.test_loss = test_loss
    
    return model   # return trained model
    

def evaluate(model, dataloader):
    losses = []
    correct = 0
    n_examples = 0
    for batch in dataloader:
        features, labels = batch
        prediction = model(features)
        loss = F.cross_entropy(prediction, labels)
        losses.append(loss.detach())
        correct += torch.sum(torch.argmax(prediction, dim=1) == labels)
        n_examples += len(labels)
    avg_loss = np.mean(losses)
    accuracy = correct / n_examples
    print("Loss:", avg_loss, "Accuracy:", accuracy.detach().numpy())
    return avg_loss, accuracy







 

In [11]:
size = 128
transform = transforms.Compose([
    transforms.Resize(size),
    transforms.RandomCrop((size, size), pad_if_needed=True),
    transforms.ToTensor(),
    ])

dataset = CitiesDataset(transform=transform)
train_set_len = round(0.7*len(dataset))
val_set_len = round(0.15*len(dataset))
test_set_len = len(dataset) - val_set_len - train_set_len
split_lengths = [train_set_len, val_set_len, test_set_len]
# split the data to get validation and test sets
train_set, val_set, test_set = random_split(dataset, split_lengths)

batch_size = 32
train_loader = DataLoader(train_set, shuffle=True, batch_size=batch_size)
val_loader = DataLoader(val_set, batch_size=batch_size)
test_loader = DataLoader(test_set, batch_size=batch_size)
# nn = NeuralNetworkClassifier()
# cnn = CNN()
model = TransferLearning()

trained_model=train(
                model,
                train_loader,
                val_loader,
                test_loader,
                epochs=100,
                lr=0.0001,
                optimiser=torch.optim.AdamW
                )


Adjusting learning rate of group 0 to 1.0000e-04.
Epoch: 0 LR: [0.0001]
Epoch: 0 Batch: 0 Loss: 2.304605722427368
Epoch: 0 Batch: 1 Loss: 2.3003504276275635
Epoch: 0 Batch: 2 Loss: 2.2844228744506836
Epoch: 0 Batch: 3 Loss: 2.298607110977173
Epoch: 0 Batch: 4 Loss: 2.282902717590332
Epoch: 0 Batch: 5 Loss: 2.305506706237793
Epoch: 0 Batch: 6 Loss: 2.2964835166931152
Epoch: 0 Batch: 7 Loss: 2.2878315448760986
Epoch: 0 Batch: 8 Loss: 2.272757053375244
Epoch: 0 Batch: 9 Loss: 2.2604007720947266
Epoch: 0 Batch: 10 Loss: 2.2550315856933594
Adjusting learning rate of group 0 to 1.0000e-04.
Epoch: 1 LR: [0.0001]
Epoch: 1 Batch: 11 Loss: 2.147705078125
Epoch: 1 Batch: 12 Loss: 2.156311511993408
Epoch: 1 Batch: 13 Loss: 2.2031359672546387
Epoch: 1 Batch: 14 Loss: 2.119562864303589
Epoch: 1 Batch: 15 Loss: 2.148904323577881
Epoch: 1 Batch: 16 Loss: 2.069739580154419
Epoch: 1 Batch: 17 Loss: 2.1123878955841064
Epoch: 1 Batch: 18 Loss: 2.06619930267334
Epoch: 1 Batch: 19 Loss: 1.9806499481201172
E

KeyboardInterrupt: 