# In [None]:
import time
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt 


# Stem: the first stage of the network, extracting low-level features straight from the input batch.
# It is a single 3x3 convolution followed by batch normalisation and a ReLU activation.
# Batch normalisation keeps the feature scale stable; ReLU adds non-linearity so later layers can learn more complex features.
# THIS CODE IS ADAPTED FROM WEEK 7 ECS659U TUTORIAL SHEET "conv_example.ipynb"
class Stem(nn.Module):
    """Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU.

    Args:
        in_channels: number of channels in the input tensor.
        out_channels: number of feature maps produced.

    A 3x3 kernel with padding 1 (stride 1) keeps the spatial size, so an input
    of shape [B, in_channels, H, W] gives [B, out_channels, H, W].
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Submodules are registered in the same order (conv, batch_norm, relu)
        # so parameter ordering and state_dict keys are unchanged.
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.batch_norm = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # [B, in_channels, H, W] -> [B, out_channels, H, W] (32x32 for CIFAR-10)
        return self.relu(self.batch_norm(self.conv(x)))




# Expert branch creates the weights need to multiply by the corresponding convk
# uses global avg feature pooling, followed by two fully connected layers
# input from stem or previous block = [B, C, W, H]
# output = [B, K]
# THIS CODE IS ALSO ADAPTED FROM WEEK 7 ECS659U TUTORIAL SHEET "conv_example.ipynb"
class ExpertBranch(nn.Module):
    def __init__(self, channels, reduction, K):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1) #compress spatial info via avg -> [B, C, 1, 1]
        self.fc1 = nn.Linear(channels, channels // reduction) # first fully connected layer, as required reduce channels by a factor of R in my case reduction
        self.fc2 = nn.Linear(channels // reduction, K) # as required this produce -> [B, K]
        self.relu = nn.ReLU(inplace=True) # this represents g from the coursework outline, this adds non linearity to get more complex feature maps
        self.softmax = nn.Softmax(dim=1) # soft max to sum to one

    def forward(self, x):
        x = self.pool(x)
        x = torch.flatten(x, start_dim=1) # flattern the dimensions from [B, C, 1, 1] -> [B, C] needed as linear needs 2d
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x) # here dimensions go from [B, C] -> [B, K]
        a = self.softmax(x) # each batch sums to 1
        return a # [B, K]






# ConvLayers: K parallel convolution paths that all receive the same input.
# Each path is Conv(3x3) -> BatchNorm -> ReLU; a Block later mixes their outputs.
class ConvLayers(nn.Module):
    """K independent Conv2d(3x3, padding=1, bias=False) -> BatchNorm2d -> ReLU paths.

    Args:
        channels: input and output channel count of every path.
        K: number of parallel paths.

    forward(x) returns a Python list of K tensors, each with the same shape as x
    ([B, C, H, W], from the stem or a previous block).
    """

    def __init__(self, channels, K):
        super().__init__()
        paths = []
        for _ in range(K):
            paths.append(
                nn.Sequential(
                    # A conv bias would be redundant: BatchNorm2d re-centres the output.
                    nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False),
                    nn.BatchNorm2d(channels),  # normalise features per channel
                    nn.ReLU(inplace=True),  # non-linearity for more expressive features
                )
            )
        self.convs = nn.ModuleList(paths)

    def forward(self, x):
        # Every path sees the same x; results are returned in path order.
        return [path(x) for path in self.convs]





# Block: one expert branch plus K convolution paths. For each sample, the expert weight a_k
# scales the matching path output conv_k, the scaled outputs are summed, and a ReLU is applied.
class Block(nn.Module):
    """Expert-weighted mixture of K convolution paths.

    Args:
        channels: channel count C of both input and output.
        reduction: reduction factor r passed to the expert branch.
        K: number of convolution paths (and expert weights).

    forward(x) maps [B, C, H, W] to [B, C, H, W].
    """

    def __init__(self, channels, reduction, K):
        super().__init__()
        self.expert = ExpertBranch(channels, reduction, K)
        self.conv_layers = ConvLayers(channels, K)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        weights = self.expert(x)  # [B, K]
        path_outputs = self.conv_layers(x)  # list of K tensors, each [B, C, H, W]
        # unbind(dim=1) yields K tensors of shape [B]; viewing each as [B, 1, 1, 1]
        # lets a sample's weight broadcast over its channels and pixels.
        mixed = sum(
            w.view(-1, 1, 1, 1) * out
            for w, out in zip(weights.unbind(dim=1), path_outputs)
        )
        return self.relu(mixed)  # [B, C, H, W]




# Backbone: a chain of `num_blocks` Blocks (the coursework's 'n').
# Each block consumes the previous block's output; the first consumes the stem's output.
# This use of nn.ModuleList was inspired by https://pytorch.org/docs/stable/generated/torch.nn.ModuleList.html
class Backbone(nn.Module):
    """Sequential stack of Blocks that all share the same configuration.

    Args:
        num_blocks: number of blocks n.
        channels: channel count C passed to every Block.
        reduction: reduction factor r passed to every Block.
        K: number of convolution paths per Block.
    """

    def __init__(self, num_blocks, channels, reduction, K):
        super().__init__()
        self.blocks = nn.ModuleList(
            [Block(channels, reduction, K) for _ in range(num_blocks)]
        )

    def forward(self, x):
        features = x
        for stage in self.blocks:
            features = stage(features)  # each block's output feeds the next block
        return features




# Classifier head: turns the final high-level feature maps [B, C, H, W] into class scores.
class Classifier(nn.Module):
    """Global average pooling followed by a single linear layer.

    Args:
        channels: number of input channels C.
        num_classes: number of output classes (default 10, as in CIFAR-10).

    Returns raw logits of shape [B, num_classes]. No softmax is applied here
    because nn.CrossEntropyLoss, used as the training loss, applies
    log-softmax to the logits internally.
    """

    def __init__(self, channels, num_classes=10):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)  # average each feature map down to one value
        self.fc = nn.Linear(channels, num_classes)

    def forward(self, x):
        pooled = torch.flatten(self.pool(x), start_dim=1)  # [B, C, 1, 1] -> [B, C]
        return self.fc(pooled)  # logits [B, num_classes]





# Model: the full network, Stem -> Backbone -> Classifier.
class Model(nn.Module):
    """Complete classifier network.

    Args:
        num_blocks: number of Blocks in the backbone.
        in_channels: channels of the input images.
        output_channels: feature width C used by the stem, every block and the classifier.
        reduction: expert-branch reduction factor r.
        K: number of convolution paths per Block.
        num_classes: number of output classes.

    forward(x) maps [B, in_channels, H, W] to logits [B, num_classes].
    """

    def __init__(self, num_blocks, in_channels, output_channels, reduction, K, num_classes):
        super().__init__()
        self.stem = Stem(in_channels, output_channels)
        self.backbone = Backbone(num_blocks, output_channels, reduction, K)
        self.classifier = Classifier(output_channels, num_classes)

    def forward(self, x):
        features = self.stem(x)  # [B, output_channels, H, W] (32x32 for CIFAR-10)
        features = self.backbone(features)  # same shape; blocks preserve it
        return self.classifier(features)  # logits [B, num_classes]















# Load the CIFAR-10 dataset, with data augmentation on the training split.
# Normalisation mean/std (specific to CIFAR-10) from https://github.com/kuangliu/pytorch-cifar/issues/19
# Loading pattern follows https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html
# Augmentations follow https://juliusruseckas.github.io/ml/lightning.html
def load_data(batch_size, num_workers):
    """Build the CIFAR-10 train and test DataLoaders.

    Both datasets are stored under ./data and are downloaded there if needed
    (download=True).

    Args:
        batch_size: samples per batch for both loaders.
        num_workers: worker processes used by each DataLoader.

    Returns:
        (train_loader, test_loader). The training loader shuffles and applies
        random flip/crop/colour-jitter augmentation; the test loader keeps a
        fixed order and only converts to a tensor and normalises.
    """
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.247, 0.243, 0.261)

    augment_then_normalise = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    normalise_only = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    def _cifar10(is_train, transform):
        # Helper so both splits use the same root and download settings.
        return torchvision.datasets.CIFAR10(root='./data', train=is_train, download=True, transform=transform)

    train_set = _cifar10(True, augment_then_normalise)
    test_set = _cifar10(False, normalise_only)

    return (
        DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers),
        DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers),
    )




# Build the loss function, optimiser and learning-rate scheduler used for training.
def get_lossfunc_optimiser_scheduler(model, lr, weight_decay, num_epochs):
    """Return (loss function, optimiser, scheduler) for training `model`.

    - nn.CrossEntropyLoss: softmax regression on the classifier's logits.
    - AdamW: Adam with decoupled weight decay, chosen because plain Adam handles
      weight decay poorly (https://www.datacamp.com/tutorial/adamw-optimizer-in-pytorch).
    - CosineAnnealingLR(T_max=num_epochs): lowers the learning rate along a
      cosine curve over num_epochs scheduler steps, which softens the impact
      of a poorly chosen initial lr.
    """
    criterion = nn.CrossEntropyLoss()
    adamw = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    cosine_schedule = optim.lr_scheduler.CosineAnnealingLR(adamw, T_max=num_epochs)
    return criterion, adamw, cosine_schedule




"""
adapted train_one_epoch and evaluate from these
https://pytorch.org/tutorials/beginner/introyt/trainingyt.html
https://discuss.pytorch.org/t/on-running-loss-and-average-loss/107890
"""
def train_one_epoch(model, train_loader, lossfunc, optimiser, device):
    """Run one full pass over `train_loader`, updating `model` in place.

    Adapted from:
    https://pytorch.org/tutorials/beginner/introyt/trainingyt.html
    https://discuss.pytorch.org/t/on-running-loss-and-average-loss/107890

    Args:
        model: network to train; it is switched to training mode.
        train_loader: iterable of (images, labels) batches.
        lossfunc: criterion called as lossfunc(outputs, labels).
        optimiser: optimiser over the model's parameters.
        device: device each batch is moved to.

    Returns:
        (avg_loss, accuracy_percent). Each batch loss is multiplied by its batch
        size before averaging, so a smaller final batch counts proportionally.
        This assumes lossfunc returns the batch mean, which is the default for
        nn.CrossEntropyLoss.

    Raises:
        ValueError: if train_loader yields no samples. Without this check the
        final division would fail with ZeroDivisionError.
    """
    model.train()
    running_loss = 0.0
    correct = 0  # number of correct predictions
    total = 0  # number of samples seen

    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        optimiser.zero_grad()  # clear gradients from the previous step
        outputs = model(images)  # forward pass
        loss = lossfunc(outputs, labels)
        loss.backward()
        optimiser.step()

        batch_size = labels.size(0)
        running_loss += loss.item() * batch_size
        # Accuracy bookkeeping does not need to be tracked by autograd.
        predicted = outputs.detach().argmax(dim=1)  # class with the highest score
        total += batch_size
        correct += predicted.eq(labels).sum().item()

    if total == 0:
        raise ValueError("train_loader yielded no samples; cannot compute average loss/accuracy")
    return running_loss / total, 100.0 * correct / total





# getting test accuracy
def evaluate(model, test_loader, lossfunc, device):
    model.eval() # set model to eval instead  of train
    running_loss = 0.0
    correct = 0 # correct predictions
    total = 0
    with torch.no_grad(): # disable gradient comp
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = lossfunc(outputs, labels)
            running_loss += loss.item() * images.size(0)
            
            _, predicted = outputs.max(1) # get class with highest score
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    return running_loss / total, 100.0 * correct / total  # return avg loss and accuracy






"""
trains and evaluates a CNN on the cifar 10 dataset

sets up data, model, loss, optimiser, and scheduler. Runs training for multiple epochs,
logs accuracy and loss, and plots performance metrics

"""

def _plot_history(train_values, test_values, metric, ylabel):
    """Plot per-epoch train and test curves for one metric and show the figure."""
    plt.figure(figsize=(10, 5))
    plt.plot(train_values, label=f'Train {metric}')
    plt.plot(test_values, label=f'Test {metric}')
    plt.title(f'{metric} over Epochs')
    plt.xlabel('Epoch')
    plt.ylabel(ylabel)
    plt.legend()
    plt.grid(True)
    plt.show()


def main():
    """Train and evaluate the CNN on CIFAR-10.

    Sets up the data, model, loss, optimiser and scheduler. Trains for
    num_epochs epochs, evaluating on the test set after each epoch. Prints
    loss, accuracy and epoch time, then plots the accuracy and loss curves.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Using device:", device)

    train_loader, test_loader = load_data(batch_size=64, num_workers=2)
    num_epochs = 100

    model = Model(num_blocks=8, in_channels=3, output_channels=64, reduction=8, K=3, num_classes=10).to(device)
    # The cosine schedule's T_max must equal the number of scheduler.step()
    # calls (one per epoch), so pass num_epochs instead of repeating the literal.
    lossfunc, optimiser, scheduler = get_lossfunc_optimiser_scheduler(
        model, lr=0.0003, weight_decay=0.001, num_epochs=num_epochs
    )

    train_acc_history, test_acc_history, train_loss_history, test_loss_history = [], [], [], []

    for epoch in range(num_epochs):
        start_time = time.time()

        train_loss, train_acc = train_one_epoch(model, train_loader, lossfunc, optimiser, device)
        test_loss, test_acc = evaluate(model, test_loader, lossfunc, device)

        scheduler.step()  # once per epoch, after the optimiser steps

        train_loss_history.append(train_loss)
        test_loss_history.append(test_loss)
        train_acc_history.append(train_acc)
        test_acc_history.append(test_acc)

        duration = time.time() - start_time
        print(f"Epoch [{epoch+1}/{num_epochs}] Train Loss: {train_loss:.4f} Train Acc: {train_acc:.2f}% "
              f"Test Loss: {test_loss:.4f} Test Acc: {test_acc:.2f}% Time: {duration:.2f}s")

    _plot_history(train_acc_history, test_acc_history, 'Accuracy', 'Accuracy %')
    _plot_history(train_loss_history, test_loss_history, 'Loss', 'Loss')


if __name__ == '__main__':
    main()

