# In [None]:


import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt



# Per-channel mean / std handed to Normalize in both pipelines below.
cifar_mean = (0.4914, 0.4822, 0.4465)
cifar_std = (0.2023, 0.1994, 0.2010)

# Data augmentation for training (recipe taken from a GitHub example; the
# author notes that a reference to the link was added): random horizontal
# flip, random 32x32 crop after 4 pixels of padding, then tensor conversion
# and normalisation.
train_transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, padding=4),
        transforms.ToTensor(),
        transforms.Normalize(cifar_mean, cifar_std),
    ]
)

# Test pipeline: no augmentation, only tensor conversion and normalisation.
test_transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(cifar_mean, cifar_std),
    ]
)


# CIFAR-10 datasets (downloaded into ./data when missing), each paired with
# its own transform pipeline.
trainset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=train_transform,
)
testset = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=test_transform,
)

# Mini-batch iterators: training batches are reshuffled, test batches keep
# the dataset order.
trainIter = DataLoader(dataset=trainset, batch_size=256, shuffle=True)
testIter = DataLoader(dataset=testset, batch_size=256, shuffle=False)


# First layer processes the image and extracts the features
class Stem(nn.Module):
    """Entry layer: a bias-free 3x3 convolution on a 3-channel input, then ReLU.

    The convolution uses stride 1 and padding 1, so the spatial size of the
    input is preserved.

    Args:
        num_outputs: number of feature channels produced by the convolution.
    """

    def __init__(self, num_outputs):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=3,
            out_channels=num_outputs,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ReLU(conv(x))."""
        features = self.conv(x)
        return self.relu(features)


# expert branch: generates the k softmax mixing weights used to combine the conv branch outputs
class Expert_branch(nn.Module):
    """Produce k softmax-normalised weights per sample from a feature map.

    Pipeline: global average pool -> flatten -> Linear(input, input // r)
    -> ReLU -> Linear(input // r, k) -> softmax over the k outputs.

    Args:
        input: number of channels of the incoming feature map.
        r: reduction factor for the hidden layer width (input // r).
        k: number of weights to emit per sample.
    """

    def __init__(self, input, r, k):
        super().__init__()
        hidden_width = input // r
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(input, hidden_width)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_width, k)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return a (batch, k) tensor whose rows sum to 1."""
        pooled = torch.flatten(self.avgpool(x), 1)
        hidden = self.relu(self.fc1(pooled))
        return self.softmax(self.fc2(hidden))

# conv branch with k convs
class ConvBranch(nn.Module):
    """k parallel (3x3 conv -> batch norm) paths combined by per-sample weights.

    Args:
        input: number of input channels of every path.
        output: number of output channels of every path.
        k: number of parallel paths.
    """

    def __init__(self, input, output, k):
        super().__init__()
        self.k = k
        self.convs = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv2d(input, output, kernel_size=3, padding=1),
                    nn.BatchNorm2d(output),
                )
                for _ in range(k)
            ]
        )

    def forward(self, x, a):
        """Return sum_i a[:, i] * convs[i](x).

        Column i of `a` is reshaped to (-1, 1, 1, 1) so that each sample's
        scalar weight broadcasts over that sample's whole feature map.
        """
        return sum(
            a[:, idx].view(-1, 1, 1, 1) * path(x)
            for idx, path in enumerate(self.convs)
        )

# combines expert + conv and adds skip connection(like ResNet from the lectures)

class Block(nn.Module):
    """Residual block: expert-weighted conv branch plus a skip path, then ReLU.

    Args:
        input: number of input channels.
        output: number of output channels.
        k: number of parallel conv paths (and of expert weights).
        r: reduction factor forwarded to the expert branch.
    """

    def __init__(self, input, output, k, r):
        super().__init__()
        self.expert = Expert_branch(input, r, k)
        self.conv = ConvBranch(input, output, k)
        self.relu = nn.ReLU()
        # The shortcut needs a 1x1 conv only when the channel count changes;
        # otherwise the input is passed through untouched.
        self.skip = (
            nn.Identity()
            if input == output
            else nn.Conv2d(input, output, kernel_size=1)
        )

    def forward(self, x):
        """Return ReLU(conv(x, expert(x)) + skip(x))."""
        mixing_weights = self.expert(x)
        mixed = self.conv(x, mixing_weights)
        shortcut = self.skip(x)  # residual connection
        return self.relu(mixed + shortcut)

# classifier with global pooling,mlp and dropouts
class Classifier(nn.Module):
    """Classification head: global average pool followed by a 3-layer MLP.

    Layout: pool -> flatten -> Linear(in_channels, 1024) -> ReLU ->
    Dropout(0.4) -> Linear(1024, 512) -> ReLU -> Dropout(0.3) ->
    Linear(512, num_classes).

    Args:
        in_channels: number of channels of the incoming feature map.
        num_classes: number of output logits (default 10).
    """

    def __init__(self, in_channels, num_classes=10):
        super().__init__()
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(in_channels, 1024)
        self.relu1 = nn.ReLU()
        self.drop1 = nn.Dropout(0.4)
        self.fc2 = nn.Linear(1024, 512)
        self.relu2 = nn.ReLU()
        self.drop2 = nn.Dropout(0.3)
        self.fc3 = nn.Linear(512, num_classes)

    def forward(self, x):
        """Return the raw (un-normalised) class scores for `x`."""
        pipeline = (
            self.avgpool,
            self.flatten,
            self.fc1,
            self.relu1,
            self.drop1,
            self.fc2,
            self.relu2,
            self.drop2,
            self.fc3,
        )
        for layer in pipeline:
            x = layer(x)
        return x




#model definition
class CIFAR10Model(nn.Module):
    """Full network: stem, four expert/conv blocks, and a 10-way classifier.

    Channel widths grow 32 -> 64 -> 128 -> 256 -> 512 across the blocks, and
    a 2x2 max-pool follows each of the first three blocks.
    """

    def __init__(self):
        super().__init__()
        self.stem = Stem(num_outputs=32)
        self.block1 = Block(input=32, output=64, k=4, r=8)
        self.pool1 = nn.MaxPool2d(kernel_size=2)
        self.block2 = Block(input=64, output=128, k=6, r=4)
        self.pool2 = nn.MaxPool2d(kernel_size=2)
        self.block3 = Block(input=128, output=256, k=6, r=4)
        self.pool3 = nn.MaxPool2d(kernel_size=2)
        self.block4 = Block(input=256, output=512, k=6, r=4)

        self.classifier = Classifier(in_channels=512, num_classes=10)

    def forward(self, x):
        """Run `x` through the feature stages in order, then classify."""
        feature_stages = (
            self.stem,
            self.block1,
            self.pool1,
            self.block2,
            self.pool2,
            self.block3,
            self.pool3,
            self.block4,
        )
        for stage in feature_stages:
            x = stage(x)
        return self.classifier(x)




# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Build the network, then move its parameters and buffers to that device.
model = CIFAR10Model()
model = model.to(device)



# weight initialisation
def init_weights(m):
    """Initialise one module in place; meant to be passed to `nn.Module.apply`.

    * `nn.Conv2d` / `nn.Linear`: weight drawn with Kaiming-normal
      initialisation for ReLU. Biases are left at their defaults.
    * `nn.BatchNorm2d`: scale set to 1 and shift set to 0.
    * Any other module is left untouched.

    Args:
        m: the module to initialise.
    """
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
    elif isinstance(m, nn.BatchNorm2d):
        # BatchNorm2d(affine=False) has weight/bias set to None; skip them
        # instead of crashing inside nn.init.constant_.
        if m.weight is not None:
            nn.init.constant_(m.weight, 1)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

# Run init_weights on the model and, recursively, on each of its submodules.
model.apply(init_weights)

# training loop

def _runEpoch(model, data, lossFunction, device, optimizer=None):
    """Make one pass over `data` and return (mean loss, accuracy).

    When `optimizer` is given the model is put in training mode and updated
    after every batch; otherwise the model is put in eval mode and the pass
    runs with gradients disabled.

    Raises:
        ZeroDivisionError: if `data` yields no examples.
    """
    training = optimizer is not None
    model.train(training)

    lossSum = 0.0
    correct = 0
    seen = 0
    with torch.set_grad_enabled(training):
        for images, labels in data:
            images = images.to(device)
            labels = labels.to(device)

            if training:
                optimizer.zero_grad()
            outputs = model(images)
            loss = lossFunction(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            # The loss is a batch mean; weight it by the batch size so that a
            # smaller final batch does not skew the per-example average.
            batchSize = labels.size(0)
            lossSum += loss.item() * batchSize
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            seen += batchSize

    return lossSum / seen, correct / seen


def trainModel(model, trainData, testData, numEpochs, lr):
    """Train `model` and evaluate it on `testData` after every epoch.

    Uses cross-entropy with label smoothing 0.1, Adam with weight decay
    0.0005, and a cosine-annealed learning rate that spans the whole run.
    One summary line is printed per epoch.

    Args:
        model: the network to train; batches are moved to the device its
            parameters live on.
        trainData: iterable of (images, labels) batches used for training.
        testData: iterable of (images, labels) batches used for evaluation.
        numEpochs: number of passes over `trainData`.
        lr: initial learning rate for Adam.

    Returns:
        Four lists with one entry per epoch: train loss, test loss, train
        accuracy, test accuracy.

    Raises:
        ZeroDivisionError: if `trainData` or `testData` yields no examples.
    """
    # Follow the model's own device rather than relying on a module-level
    # global, so the function works wherever the caller placed the model.
    firstParam = next(model.parameters(), None)
    device = firstParam.device if firstParam is not None else torch.device("cpu")

    lossFunction = nn.CrossEntropyLoss(label_smoothing=0.1)
    optim = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=0.0005)
    # Anneal over the actual run length (previously fixed at 100), so the
    # schedule bottoms out at the last epoch for any numEpochs.
    lrScheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=max(numEpochs, 1))

    # lists to keep track of how things are going
    trainLossHistory = []
    testLossHistory = []
    trainAccuracy = []
    testAccuracy = []

    for currentEpoch in range(numEpochs):
        avgTrainLoss, trainAcc = _runEpoch(model, trainData, lossFunction, device, optimizer=optim)
        trainLossHistory.append(avgTrainLoss)
        trainAccuracy.append(trainAcc)

        avgTestLoss, testAcc = _runEpoch(model, testData, lossFunction, device)
        testLossHistory.append(avgTestLoss)
        testAccuracy.append(testAcc)

        lrScheduler.step()

        print(f"Epoch {currentEpoch+1}/{numEpochs} | train loss: {avgTrainLoss:.3f}, test loss: {avgTestLoss:.3f} | train acc: {trainAcc:.3f}, test acc: {testAcc:.3f}")

    return trainLossHistory, testLossHistory, trainAccuracy, testAccuracy





# Hyper-parameters for the run.
num_epochs = 100
learning_rate = 1e-3

# Train and keep the per-epoch loss / accuracy curves for plotting.
train_losses, test_losses, train_accs, test_accs = trainModel(
    model,
    trainIter,
    testIter,
    numEpochs=num_epochs,
    lr=learning_rate,
)

# Loss curves: one point per epoch for the training and the test set.
plt.figure()
for loss_curve, loss_label in (
    (train_losses, "Train Loss"),
    (test_losses, "Test Loss"),
):
    plt.plot(range(1, num_epochs + 1), loss_curve, label=loss_label)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Training vs Test Loss over epochs")
plt.legend()
plt.show()


# Plot training and test accuracies (one point per epoch)
plt.figure()
plt.plot(range(1, num_epochs+1), train_accs, label="Train Accuracy")
plt.plot(range(1, num_epochs+1), test_accs, label="Test Accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
# This figure shows accuracy, so the title says so (it used to say "Loss").
plt.title("Training vs Test Accuracy over epochs")
plt.legend()
plt.show()