# Computational Intelligence Coursework - Ali

## Imports, Functions, Model, Data Loading

### Imports & Seed

In [1]:
# general imports
import numpy as np

# torch & data manipulation imports
import torch
from torch.utils.data import ConcatDataset, Subset, DataLoader
import torchvision
import torchvision.transforms as transforms

# model-related imports
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit

# seed for reproducibility
def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)

g = torch.Generator()
g.manual_seed(0)

device = torch.device("mps" if torch.backends.mps.is_available() else ("cuda" if torch.backends.cuda.is_availale() else "cpu"))



### Define Custom Preprocessing Functions

In [2]:
from torch.utils.data import Dataset

# DATA PREPROCESSING
class CustomDataset(Dataset):
    def __init__(self, dataset, indices, transform=None):
        self.dataset = dataset
        self.indices = indices
        self.transform = transform

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        image, label = self.dataset[self.indices[idx]]
        if self.transform:
            image = self.transform(image)
        return image, label

# calculate mean & standard deviation based on dataset
def calc_mean_std(dataset):
    dataloader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=2)
    mean_sum = 0.
    var_sum = 0.
    total_images_count = 0
    for images, _ in dataloader:
        batch_samples = images.size(0)
        images = images.view(batch_samples, images.size(1), -1)
        mean_sum += images.mean(2).sum(0)
        var_sum += images.var(2).sum(0)
        total_images_count += batch_samples

    mean = mean_sum / total_images_count
    var = var_sum / total_images_count
    std = np.sqrt(var)

    return mean, std

# define transformations for data augmentation
def train_transform(data, mean, std):
  data = normalize(data, mean, std)
  transform = transforms.Compose([
                                transforms.RandomHorizontalFlip(0.25),
                                transforms.RandomVerticalFlip(0.25),
                                transforms.RandomGrayscale(0.25),
                                transforms.RandomCrop(32, padding=4)
                                 ])
  return transform(data)

# define normalisation
def normalize(data, mean, std):
  transform = transforms.Compose([
                                transforms.Normalize(mean, std)
                                ])
  return transform(data)


### Define Model Architecture

In [3]:
class Net(nn.Module):
    def __init__(self, dropout_prob):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 64, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.conv4 = nn.Conv2d(64, 32, 3, padding=1)
        self.bn4 = nn.BatchNorm2d(32)
        self.fc1 = nn.Linear(32 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, 10)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.bn4(self.conv4(x)))
        x = F.max_pool2d(x, 2)
        x = x.view(-1, 32 * 4 * 4)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

    def freeze_all_but_last():
        for name, param in self.named_parameters():
            if 'fc2' not in name:
                param.requires_grad = False

    # Extract weights from the last layer
    def extract_weights():
        return [p.data.numpy() for p in self.fc2.parameters()]

### Checkpointing

In [4]:
def save_checkpoint(state, filename="checkpoint.pth.tar"):
    torch.save(state, filename)

### Data Loading & Preparation

In [5]:
# load CIFAR-10 dataset & convert to tensor
train_set = torchvision.datasets.CIFAR10(root='./data', train=True,
                                             download=True,
                                         transform=transforms.ToTensor())

test_set = torchvision.datasets.CIFAR10(root='./data', train=False,
                                            download=True,
                                        transform=transforms.ToTensor())

Files already downloaded and verified
Files already downloaded and verified


In [6]:
# dataset hyperparameters
num_folds = 10
test_size = 0.20

# combine train and test datasets for stratified splitting
combined_set = ConcatDataset([train_set, test_set])

# STRATIFIED SPLIT
# collect the labels
labels = [y for _, y in combined_set]

# stratified split subset indices
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=0)
train_idx, test_idx = next(sss.split(np.zeros(len(labels)), labels))

# subset using indices & collect associated labels
stratified_train_set, stratified_train_labels = Subset(combined_set, train_idx).dataset, [labels[i] for i in range(len(labels)) if i in train_idx]
stratified_test_set, stratified_test_labels = Subset(combined_set, test_idx).dataset, [labels[i] for i in range(len(labels)) if i in test_idx]

# create StratifiedKFold object for train set only
skf = StratifiedKFold(n_splits=num_folds)

## Baselines

### Gradient Based - Adam

In [7]:
# function for training and evaluating the model
def adam_train_and_validate(model, train_loader, test_loader, criterion, optimizer, mean, std, epochs=30):
    model.to(device)
    model.train()

    # early stopping parameters
    early_stopping_patience = 3  # number of epochs to wait for improvement before stopping
    early_stopping_counter = 0    # counter for epochs without improvement
    best_accuracy = 0             # track the best accuracy

    # train
    for epoch in range(epochs):
        running_loss = 0.0
        correct = 0
        total = 0
        i = 0

        for inputs, train_load_labels in train_loader:
            inputs = train_transform(inputs, mean, std)
            inputs, train_load_labels = inputs.to(device), train_load_labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, train_load_labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += train_load_labels.size(0)
            correct += (predicted == train_load_labels).sum().item()
            i += 1
            if i % 20 == 0:
              batch_accuracy = 100 * correct / total
              print(f'{i}th Batch Loss: {loss.item():.4f} Batch Accuracy: {batch_accuracy:.4f}')

        epoch_loss = running_loss / total
        epoch_accuracy = 100 * correct / total
        print(f'Epoch [{epoch + 1}/{epochs}] Loss: {loss.item():.4f} Epoch Accuracy: {epoch_accuracy:.4f}')

        model.eval()
        correct = 0
        total = 0

        # validate
        with torch.no_grad():
            for inputs, test_load_labels in test_loader:
                inputs = normalize(inputs, mean, std)
                inputs, test_load_labels = inputs.to(device), test_load_labels.to(device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += test_load_labels.size(0)
                correct += (predicted == test_load_labels).sum().item()

        validation_accuracy = 100 * correct / total

        # check if the current validation accuracy is better than the best recorded accuracy
        if validation_accuracy > best_accuracy:
            best_accuracy = validation_accuracy
            early_stopping_counter = 0  # Reset the counter
            # save the model checkpoint
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, filename=f"best_model_epoch_{epoch+1}.pth.tar")
        else:
            early_stopping_counter += 1

        print(early_stopping_counter)
        # check if early stopping should be triggered
        if early_stopping_counter >= early_stopping_patience:
            print("Early stopping triggered")
            break

        print(f'Validation Accuracy: {validation_accuracy:.2f}%')
    return model

# function for testing the model
def test(model, test_loader, mean, std):

    model.to(device)
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, test_load_labels in test_loader:
            inputs = normalize(inputs, mean, std)
            inputs, test_load_labels = inputs.to(device), test_load_labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += test_load_labels.size(0)
            correct += (predicted == test_load_labels).sum().item()


    accuracy = 100 * correct / total
    print(f'Test Accuracy: {accuracy:.2f}%')
    return model

In [8]:
# Model hyperparameters
dropout_prob = 0.30
num_epochs = 10
batch_size = 64

model = None

# call function
# Main loop for k-fold cross-validation
for fold, (train_fold_indices, val_fold_indices) in enumerate(skf.split(train_idx, stratified_train_labels)):
    print(f'Fold {fold + 1}/{num_folds}')
    mean, std = calc_mean_std(Subset(stratified_train_set, train_fold_indices))

    train_sampler = torch.utils.data.SubsetRandomSampler(train_fold_indices)
    val_sampler = torch.utils.data.SubsetRandomSampler(val_fold_indices)


    train_loader = torch.utils.data.DataLoader(
        dataset=stratified_train_set,
        batch_size=batch_size,
        sampler=train_sampler,
        worker_init_fn=seed_worker,
        generator=g)

    val_loader = torch.utils.data.DataLoader(
        dataset=stratified_train_set,
        batch_size=batch_size,
        sampler=val_sampler,
        worker_init_fn=seed_worker,
        generator=g)

    model = Net(dropout_prob).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    adam_train_and_validate(model, train_loader, val_loader, criterion, optimizer, mean, std, epochs=num_epochs)

Fold 1/10




20th Batch Loss: 1.9541 Batch Accuracy: 17.5781
40th Batch Loss: 2.0445 Batch Accuracy: 22.0312
60th Batch Loss: 1.7624 Batch Accuracy: 25.0781
80th Batch Loss: 1.6816 Batch Accuracy: 26.7578
100th Batch Loss: 1.7538 Batch Accuracy: 28.0469
120th Batch Loss: 1.3792 Batch Accuracy: 29.8438
140th Batch Loss: 1.7243 Batch Accuracy: 30.8482
160th Batch Loss: 1.6397 Batch Accuracy: 31.8262
180th Batch Loss: 1.7269 Batch Accuracy: 32.3177
200th Batch Loss: 1.4811 Batch Accuracy: 33.1406
220th Batch Loss: 1.6009 Batch Accuracy: 33.8778
240th Batch Loss: 1.5021 Batch Accuracy: 34.6810
260th Batch Loss: 1.6279 Batch Accuracy: 35.0361
280th Batch Loss: 1.6219 Batch Accuracy: 35.4018
300th Batch Loss: 1.5158 Batch Accuracy: 36.1250
320th Batch Loss: 1.9280 Batch Accuracy: 36.2207
340th Batch Loss: 1.6841 Batch Accuracy: 36.6360
360th Batch Loss: 1.5452 Batch Accuracy: 37.1094
380th Batch Loss: 1.3005 Batch Accuracy: 37.5493
400th Batch Loss: 1.3828 Batch Accuracy: 37.9961
420th Batch Loss: 1.4601



20th Batch Loss: 2.0649 Batch Accuracy: 18.2031
40th Batch Loss: 1.9758 Batch Accuracy: 22.8516
60th Batch Loss: 1.7973 Batch Accuracy: 26.5104
80th Batch Loss: 1.9143 Batch Accuracy: 27.8125
100th Batch Loss: 1.9717 Batch Accuracy: 28.8750
120th Batch Loss: 1.5834 Batch Accuracy: 29.7917
140th Batch Loss: 1.6843 Batch Accuracy: 30.6920
160th Batch Loss: 1.6858 Batch Accuracy: 31.4062
180th Batch Loss: 1.5582 Batch Accuracy: 32.3698
200th Batch Loss: 1.7960 Batch Accuracy: 33.3984
220th Batch Loss: 1.7438 Batch Accuracy: 33.8778
240th Batch Loss: 1.5983 Batch Accuracy: 34.7526
260th Batch Loss: 1.7247 Batch Accuracy: 35.4567
280th Batch Loss: 1.5850 Batch Accuracy: 35.9766
300th Batch Loss: 1.8746 Batch Accuracy: 36.4219
320th Batch Loss: 1.5473 Batch Accuracy: 36.9971
340th Batch Loss: 1.2830 Batch Accuracy: 37.4494
360th Batch Loss: 1.3993 Batch Accuracy: 37.8733
380th Batch Loss: 1.4981 Batch Accuracy: 38.4169
400th Batch Loss: 1.4524 Batch Accuracy: 39.0000
420th Batch Loss: 1.3246



20th Batch Loss: 2.1286 Batch Accuracy: 19.2188
40th Batch Loss: 1.7393 Batch Accuracy: 24.1797
60th Batch Loss: 1.8011 Batch Accuracy: 26.8490
80th Batch Loss: 2.0341 Batch Accuracy: 27.8125
100th Batch Loss: 1.8069 Batch Accuracy: 28.9062
120th Batch Loss: 1.5750 Batch Accuracy: 29.9479
140th Batch Loss: 1.5700 Batch Accuracy: 31.3616
160th Batch Loss: 1.5629 Batch Accuracy: 32.5586
180th Batch Loss: 1.7652 Batch Accuracy: 33.4462
200th Batch Loss: 1.5635 Batch Accuracy: 33.8906
220th Batch Loss: 1.8003 Batch Accuracy: 34.5455
240th Batch Loss: 1.3682 Batch Accuracy: 35.2799
260th Batch Loss: 1.5920 Batch Accuracy: 35.9555
280th Batch Loss: 1.8020 Batch Accuracy: 36.0770
300th Batch Loss: 1.7504 Batch Accuracy: 36.8177
320th Batch Loss: 1.3725 Batch Accuracy: 37.3291
340th Batch Loss: 1.4293 Batch Accuracy: 37.7711
360th Batch Loss: 1.3656 Batch Accuracy: 38.2161
380th Batch Loss: 1.2632 Batch Accuracy: 38.6308
400th Batch Loss: 1.4080 Batch Accuracy: 39.0117
420th Batch Loss: 1.1633



20th Batch Loss: 2.0683 Batch Accuracy: 19.0625
40th Batch Loss: 1.8711 Batch Accuracy: 23.1641
60th Batch Loss: 1.6365 Batch Accuracy: 26.4062
80th Batch Loss: 2.0086 Batch Accuracy: 28.3398
100th Batch Loss: 1.6532 Batch Accuracy: 29.1406
120th Batch Loss: 1.7631 Batch Accuracy: 30.0000
140th Batch Loss: 1.4994 Batch Accuracy: 30.7478
160th Batch Loss: 1.5679 Batch Accuracy: 31.5527
180th Batch Loss: 1.4073 Batch Accuracy: 32.1875
200th Batch Loss: 1.5987 Batch Accuracy: 33.0000
220th Batch Loss: 1.7278 Batch Accuracy: 33.5156
240th Batch Loss: 1.8046 Batch Accuracy: 34.2708
260th Batch Loss: 1.7831 Batch Accuracy: 34.8978
280th Batch Loss: 1.5956 Batch Accuracy: 35.6920
300th Batch Loss: 1.4687 Batch Accuracy: 35.9583
320th Batch Loss: 1.5782 Batch Accuracy: 36.5430
340th Batch Loss: 1.3880 Batch Accuracy: 37.1278
360th Batch Loss: 1.6288 Batch Accuracy: 37.6519
380th Batch Loss: 1.3515 Batch Accuracy: 38.1579
400th Batch Loss: 1.3508 Batch Accuracy: 38.7031
420th Batch Loss: 1.2730



20th Batch Loss: 2.0907 Batch Accuracy: 17.7344
40th Batch Loss: 1.7628 Batch Accuracy: 23.0078
60th Batch Loss: 1.8672 Batch Accuracy: 25.2604
80th Batch Loss: 1.9181 Batch Accuracy: 27.2461
100th Batch Loss: 1.4736 Batch Accuracy: 28.9531
120th Batch Loss: 1.6071 Batch Accuracy: 30.4948
140th Batch Loss: 1.6745 Batch Accuracy: 31.3281
160th Batch Loss: 1.6399 Batch Accuracy: 32.0508
180th Batch Loss: 1.6267 Batch Accuracy: 32.7431
200th Batch Loss: 1.4883 Batch Accuracy: 33.6719
220th Batch Loss: 1.4068 Batch Accuracy: 34.2259
240th Batch Loss: 1.6058 Batch Accuracy: 34.6680
260th Batch Loss: 1.5557 Batch Accuracy: 35.0841
280th Batch Loss: 1.3192 Batch Accuracy: 35.9096
300th Batch Loss: 1.1511 Batch Accuracy: 36.5833
320th Batch Loss: 1.6782 Batch Accuracy: 36.9531
340th Batch Loss: 1.2365 Batch Accuracy: 37.4632
360th Batch Loss: 1.3304 Batch Accuracy: 37.9774
380th Batch Loss: 1.2172 Batch Accuracy: 38.3594
400th Batch Loss: 1.3866 Batch Accuracy: 38.8281
420th Batch Loss: 1.1826



20th Batch Loss: 2.0818 Batch Accuracy: 19.9219
40th Batch Loss: 1.7757 Batch Accuracy: 21.9922
60th Batch Loss: 1.6884 Batch Accuracy: 24.8438
80th Batch Loss: 1.7717 Batch Accuracy: 26.9336
100th Batch Loss: 1.7133 Batch Accuracy: 27.7969
120th Batch Loss: 1.5870 Batch Accuracy: 28.9323
140th Batch Loss: 1.4024 Batch Accuracy: 30.0223
160th Batch Loss: 1.6976 Batch Accuracy: 30.1562
180th Batch Loss: 1.6227 Batch Accuracy: 31.0243
200th Batch Loss: 1.3720 Batch Accuracy: 32.0391
220th Batch Loss: 1.8053 Batch Accuracy: 32.5923
240th Batch Loss: 1.7242 Batch Accuracy: 33.3789
260th Batch Loss: 1.8640 Batch Accuracy: 33.9663
280th Batch Loss: 1.3022 Batch Accuracy: 34.7935
300th Batch Loss: 1.7208 Batch Accuracy: 35.4844
320th Batch Loss: 1.2388 Batch Accuracy: 36.0742
340th Batch Loss: 1.3369 Batch Accuracy: 36.3465
360th Batch Loss: 1.2111 Batch Accuracy: 36.8533
380th Batch Loss: 1.5502 Batch Accuracy: 37.2903
400th Batch Loss: 1.5669 Batch Accuracy: 37.8477
420th Batch Loss: 1.3933



20th Batch Loss: 1.8335 Batch Accuracy: 22.6562
40th Batch Loss: 1.9455 Batch Accuracy: 25.8984
60th Batch Loss: 1.8244 Batch Accuracy: 25.9896
80th Batch Loss: 1.9712 Batch Accuracy: 27.6953
100th Batch Loss: 1.6925 Batch Accuracy: 29.2812
120th Batch Loss: 1.6151 Batch Accuracy: 30.4948
140th Batch Loss: 1.6139 Batch Accuracy: 31.2835
160th Batch Loss: 1.3855 Batch Accuracy: 32.5586
180th Batch Loss: 1.4090 Batch Accuracy: 33.1163
200th Batch Loss: 1.4016 Batch Accuracy: 33.6172
220th Batch Loss: 1.4730 Batch Accuracy: 34.5384
240th Batch Loss: 1.5228 Batch Accuracy: 35.2799
260th Batch Loss: 1.6774 Batch Accuracy: 35.8714
280th Batch Loss: 1.4972 Batch Accuracy: 36.5290
300th Batch Loss: 1.6232 Batch Accuracy: 37.0990
320th Batch Loss: 1.5139 Batch Accuracy: 37.5781
340th Batch Loss: 1.4844 Batch Accuracy: 38.2169
360th Batch Loss: 1.5630 Batch Accuracy: 38.7196
380th Batch Loss: 1.3695 Batch Accuracy: 39.2640
400th Batch Loss: 1.5292 Batch Accuracy: 39.7695
420th Batch Loss: 1.3390



20th Batch Loss: 1.8699 Batch Accuracy: 17.2656
40th Batch Loss: 2.0061 Batch Accuracy: 23.2031
60th Batch Loss: 1.8542 Batch Accuracy: 25.6510
80th Batch Loss: 1.8772 Batch Accuracy: 26.8164
100th Batch Loss: 1.7554 Batch Accuracy: 28.6094
120th Batch Loss: 1.6898 Batch Accuracy: 29.5573
140th Batch Loss: 1.7195 Batch Accuracy: 30.5804
160th Batch Loss: 1.3718 Batch Accuracy: 31.3574
180th Batch Loss: 1.6789 Batch Accuracy: 32.3872
200th Batch Loss: 1.6408 Batch Accuracy: 33.0938
220th Batch Loss: 1.6157 Batch Accuracy: 34.0909
240th Batch Loss: 1.3278 Batch Accuracy: 34.4596
260th Batch Loss: 1.5297 Batch Accuracy: 35.0962
280th Batch Loss: 1.2825 Batch Accuracy: 35.6864
300th Batch Loss: 1.6645 Batch Accuracy: 36.3021
320th Batch Loss: 1.6629 Batch Accuracy: 36.7236
340th Batch Loss: 1.2101 Batch Accuracy: 37.3300
360th Batch Loss: 1.2473 Batch Accuracy: 37.8472
380th Batch Loss: 1.4889 Batch Accuracy: 38.1003
400th Batch Loss: 1.1911 Batch Accuracy: 38.5273
420th Batch Loss: 1.5563



20th Batch Loss: 2.1188 Batch Accuracy: 19.6094
40th Batch Loss: 2.2026 Batch Accuracy: 23.4375
60th Batch Loss: 2.0386 Batch Accuracy: 26.6927
80th Batch Loss: 1.7685 Batch Accuracy: 28.3008
100th Batch Loss: 1.7875 Batch Accuracy: 29.8906
120th Batch Loss: 1.7373 Batch Accuracy: 31.1589
140th Batch Loss: 1.6983 Batch Accuracy: 32.2879
160th Batch Loss: 1.6247 Batch Accuracy: 33.3398
180th Batch Loss: 1.7404 Batch Accuracy: 33.9670
200th Batch Loss: 1.4221 Batch Accuracy: 34.5234
220th Batch Loss: 1.9029 Batch Accuracy: 35.2202
240th Batch Loss: 1.4830 Batch Accuracy: 35.8268
260th Batch Loss: 1.3839 Batch Accuracy: 36.1959
280th Batch Loss: 1.6980 Batch Accuracy: 36.9364
300th Batch Loss: 1.7003 Batch Accuracy: 37.4792
320th Batch Loss: 1.3932 Batch Accuracy: 38.0127
340th Batch Loss: 1.5975 Batch Accuracy: 38.5018
360th Batch Loss: 1.3461 Batch Accuracy: 38.7587
380th Batch Loss: 1.7119 Batch Accuracy: 39.2270
400th Batch Loss: 1.8222 Batch Accuracy: 39.7656
420th Batch Loss: 1.3702



20th Batch Loss: 1.9922 Batch Accuracy: 18.9844
40th Batch Loss: 1.8407 Batch Accuracy: 23.8281
60th Batch Loss: 1.5453 Batch Accuracy: 26.5885
80th Batch Loss: 1.8907 Batch Accuracy: 28.7500
100th Batch Loss: 1.6059 Batch Accuracy: 30.2344
120th Batch Loss: 1.7699 Batch Accuracy: 30.7292
140th Batch Loss: 1.6722 Batch Accuracy: 31.5625
160th Batch Loss: 1.7009 Batch Accuracy: 32.0117
180th Batch Loss: 1.5569 Batch Accuracy: 32.9167
200th Batch Loss: 1.7286 Batch Accuracy: 33.5625
220th Batch Loss: 1.6428 Batch Accuracy: 34.0909
240th Batch Loss: 1.5875 Batch Accuracy: 34.7526
260th Batch Loss: 1.5934 Batch Accuracy: 35.3005
280th Batch Loss: 1.5733 Batch Accuracy: 35.9431
300th Batch Loss: 1.2503 Batch Accuracy: 36.4271
320th Batch Loss: 1.2818 Batch Accuracy: 37.1680
340th Batch Loss: 1.2581 Batch Accuracy: 37.7803
360th Batch Loss: 1.3260 Batch Accuracy: 38.4071
380th Batch Loss: 1.3808 Batch Accuracy: 38.9309
400th Batch Loss: 1.5208 Batch Accuracy: 39.4141
420th Batch Loss: 1.5732

In [9]:
# test model
test_loader = torch.utils.data.DataLoader(
    dataset=stratified_test_set,
    batch_size=batch_size,
    worker_init_fn=seed_worker,
    generator=g)

test(model, test_loader, mean, std)

Test Accuracy: 80.82%


Net(
  (conv1): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv4): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc1): Linear(in_features=512, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=10, bias=True)
  (dropout): Dropout(p=0.3, inplace=False)
)

## Population Based - Genetic Algorithm

## Proposed - Adaptive Baldwinian-Lamarckian Memetic Algorithm
## Self-regularizing Adam-guided Adaptive-SL-PSO

### Imports, Preprocessing & Definitions

In [8]:
import operator
import random
from matplotlib import pyplot as plt
import math
from deap import base
from deap import benchmarks
from deap import creator
from deap import tools
from numba import jit, cuda
from numpy import genfromtxt

In [18]:
def neighbour_search(pop, individual, neighbours=10):
    distances = {}
    for i, part in enumerate(pop):
        if individual != part:
            distance = np.linalg.norm(np.array(part)-np.array(individual))
            if len(distances) < neighbours:
                distances[i] = distance
            else:
                copy = distances.copy()
                for e in copy.keys():
                    if distances[e] > distance:
                        distances[i] = distance
                        del distances[e]
                        break
    neighbourhood = [pop[i] for i in distances.keys()]
    sort_population(neighbourhood, potential=True)
    return neighbourhood

# Ali's functions for question 3 ----------------------------------------------------------------
def sort_population(population, potential=False):
    if potential:
        population.sort(key=lambda x: x.potential, reverse=True)
    else:
        population.sort(key=lambda x: x.fitness.values, reverse=True)

In [39]:
# Function to freeze all but the last layer
def freeze_all_but_last(model):
    for name, param in model.named_parameters():
        if 'fc2' not in name:
            param.requires_grad = False

# Extract weights from the last layer
def extract_weights_biases(layer):
    return [p.data.numpy() for p in layer.parameters()]

def generate_particle(dimension):
    part = creator.Particle([random.uniform(-1, 1) for _ in range(dimension)])
    part.speed = [random.uniform(-1, 1) for _ in range(dimension)]
    part.smin = -1
    part.smax = 1
    return part

# Define the fitness function
def evaluate_particle(model, particle, inputs, labels, potential=True, local_best=False):
    weights = np.asarray(particle)


    new_weights = torch.from_numpy(weights[:weights_len].reshape(weights_dim)).float().to(device)
    new_biases = torch.from_numpy(weights[weights_len:dimension]).float().to(device)
    model.fc2.weight = torch.nn.Parameter(new_weights)
    model.fc2.bias = torch.nn.Parameter(new_biases)

    inputs, labels = torch.FloatTensor(inputs), torch.Tensor(labels)
    inputs, labels, model = inputs.to(device), labels.to(device), model.to(device)
    outputs = model(inputs)  # input and predict based on images
    loss = criterion(outputs, labels)
    optimizer.zero_grad()  # clear gradients for next train
    if potential:
            loss.backward()  # backpropagation, compute gradients
            optimizer.step()  # apply gradients
            outputs = model(inputs)  # input and predict based on images
            loss = criterion(outputs, labels)
            if local_best:
                    demonstrator_weights = model.fc2.weight.data.to('cpu').reshape(-1)
                    demosntrator_biases = model.fc2.bias.data.to('cpu').reshape(-1)
                    return np.concatenate((demonstrator_weights, demosntrator_biases))
            optimizer.zero_grad
    return loss.item(),

# social learning in a neighbourhood of size
def behaviour_learning(inputs, labels, model, gamma, gbest, part, pop, epsilon, mu, neighbours=10):
    i = pop.index(part)
    neighbour_pop = neighbour_search(pop, part, neighbours)

    index = None
    for j in range(len(neighbour_pop)-1, -1, -1):
        if neighbour_pop[j].potential[0] <= part.potential[0]:
            index = j
            break
        

    demonstrator = None

    if (i != 0) & (index is not None):
        k = math.floor(random.randrange(0, index))
        demonstrator = neighbour_pop[k]
    else:
        demonstrator = evaluate_particle(model, part, inputs, labels, potential=True, local_best=True)

    r1 = (random.uniform(0, 1) for _ in range(len(part)))
    r2 = (random.uniform(0, 1) for _ in range(len(part)))
    r3 = (random.uniform(0, 1) for _ in range(len(part)))
    ones = [1] * len(part)
    one_minus_gamma = np.asarray(ones) - gamma

    v_r0 = list(map(operator.mul, r1, part.speed))
    v_r1 = list(map(operator.mul, r2, map(operator.sub, demonstrator, part))) # local best
    v_r2 = list(map(operator.mul,r3, map(operator.mul, [epsilon*x for x in mu], part))) # global best

    sl_speed = list(map(operator.add, v_r1, v_r2))
    exploitation_speed = gamma*np.asarray(gbest)
    exploration_speed = list(map(operator.mul, one_minus_gamma, sl_speed))

    part.speed = list(map(operator.add, v_r0 , map(operator.add, exploitation_speed, exploration_speed)))
    part[:] = list(map(operator.add, part, part.speed))

In [48]:
def pso_optimize(model, toolbox, pop, inputs, labels, g):
    interval        = 10
    iterations      = 100
    neighbours = 10
    beta = 0.01
    alpha = 0.5

    m = populationSize + math.floor(dimension/10)
    epsilon = beta * (dimension/populationSize)


    gbest = None

    # eval current fitness
    for part in pop:
        part.fitness.values = toolbox.evaluate(model, part, inputs, labels, potential=False) #actually only one fitness value

    # Begin the evolution
    #for g in range(iterations):


    # A new Search
    #print("-- Search %i --" % g)

    # find the global best - lamarckian search party lead - gradient descent
    sort_population(pop, potential=False)
    gbest = pop[0]

    # eval potential (after one step of Adam descent) of all search parties
    for part in pop:
        part.potential = toolbox.evaluate(model, part, inputs, labels, potential=True) #actually only one fitness value

    # sort the the baldwinian search participants by their potential, leave the current global best as the lamarckian search-lead
    sort_population(pop[1:], potential=True)

    # parameter setting - variable
    mu = [sum(np.asarray(pop)[:,x])/populationSize for x in range(dimension)]
    gamma = 1/(1+math.exp(3 - 6*min(1-abs(gbest.potential[0]/gbest.fitness.values[0]),1)))
    i = 0

    # evolve the lamarckian lead
    pop[0].fitness.values = toolbox.evaluate(model, pop[0], inputs, labels, potential=True) #actually only one fitness value

    # evolve the local-search-groups via SL-PSO algorithm
    for part in pop[1:]:
        i = i + 1
        learn_prob = (1 - (i-1)/m)**(alpha*math.log(math.ceil(dimension/m)))
        if random.random() < learn_prob:
            toolbox.learn(inputs, labels, model, gamma, gbest, part, pop[1:], epsilon, mu, neighbours)


        #update global best
        if (not gbest) or gbest.fitness < part.fitness:
            gbest = creator.Particle(part)
            gbest.fitness.values = part.fitness.values

    # set weights to best individual
    weights = np.asarray(gbest)
    new_weights = torch.from_numpy(weights[:weights_len].reshape(weights_dim)).float().to(device)
    new_biases = torch.from_numpy(weights[weights_len:dimension]).float().to(device)
    model.fc2.weight = torch.nn.Parameter(new_weights)
    model.fc2.bias = torch.nn.Parameter(new_biases)

    # Gather all the fitnesses in one list and print the stats
    # print every interval
    fits.append(gbest.fitness.values[0])
    if g%interval==0: # interval
        logbook.record(gen=g, evals=len(pop), **stats.compile(pop))
        print(logbook.stream)
        length = len(pop)
        mean = sum(fits) / length
        sum2 = sum(x*x for x in fits)
        std = abs(sum2 / length - mean**2)**0.5


        print("  Min %s" % min(fits))
        print("  Max %s" % max(fits))
        print("  Avg %s" % mean)
        print("  Std %s" % std)
        plt.plot(fits)

In [52]:
# function for training and evaluating the model
def memetic_train_and_validate(model, toolbox, pop, train_loader, test_loader, criterion, optimizer, mean, std, epochs=30):
    model.to(device)
    model.train()

    # early stopping parameters
    early_stopping_patience = 3  # number of epochs to wait for improvement before stopping
    early_stopping_counter = 0    # counter for epochs without improvement
    best_accuracy = 0             # track the best accuracy

    # train
    for epoch in range(epochs):
        running_loss = 0.0
        correct = 0
        total = 0
        i = 0

        for inputs, train_load_labels in train_loader:
            inputs = train_transform(inputs, mean, std)
            #inputs, train_load_labels = inputs.to(device), train_load_labels.to(device)
            i += 1
            pso_optimize(model, toolbox, pop, inputs, train_load_labels, i)
            inputs, train_load_labels = inputs.to(device), train_load_labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, train_load_labels)

            _, predicted = torch.max(outputs.data, 1)
            running_loss += loss.item()
            total += train_load_labels.size(0)
            correct += (predicted == train_load_labels).sum().item()
            if i % 20 == 0:
              batch_accuracy = 100 * correct / total
              print(f'{i}th Batch Loss: {loss.item():.4f} Batch Accuracy: {batch_accuracy:.4f}')

        epoch_accuracy = 100 * correct / total
        print(f'Epoch [{epoch + 1}/{epochs}] Loss: {loss.item():.4f} Epoch Accuracy: {epoch_accuracy:.4f}')

        model.eval()
        correct = 0
        total = 0

        # validate
        with torch.no_grad():
            for inputs, test_load_labels in test_loader:
                #inputs = normalize(inputs, mean, std)
                inputs, test_load_labels = inputs.to(device), test_load_labels.to(device)
                outputs = model(inputs)
                _, predicted = torch.max(outputs.data, 1)
                total += test_load_labels.size(0)
                correct += (predicted == test_load_labels).sum().item()

        validation_accuracy = 100 * correct / total

        # check if the current validation accuracy is better than the best recorded accuracy
        if validation_accuracy > best_accuracy:
            best_accuracy = validation_accuracy
            early_stopping_counter = 0  # Reset the counter
            # save the model checkpoint
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }, filename=f"best_model_epoch_{epoch+1}.pth.tar")
        else:
            early_stopping_counter += 1

        print(early_stopping_counter)
        # check if early stopping should be triggered
        if early_stopping_counter >= early_stopping_patience:
            print("Early stopping triggered")
            break

        print(f'Validation Accuracy: {validation_accuracy:.2f}%')
    return model

In [53]:
best_checkpoint = torch.load("best_model_epoch_10.pth.tar")  # Replace X with the epoch number
model = Net(0)
model.load_state_dict(best_checkpoint['state_dict'])

<All keys matched successfully>

In [54]:
fc2_weights = model.fc2.weight.data
weights_dim = fc2_weights.shape
weights_len = len(fc2_weights.reshape(-1))
fc2_bias = model.fc2.bias.data
bias_dim = fc2_bias.shape
bias_len = len(fc2_bias.reshape(-1))
populationSize  = 100
dimension = weights_len + bias_len

# Freeze all layers except the last
freeze_all_but_last(model)

# DEAP inits
creator.create("FitnessMin", base.Fitness, weights=(-1.0,))
creator.create("Particle", list, fitness=creator.FitnessMin, speed=list,
                smin=None, smax=None, best=None, potential=None)

toolbox = base.Toolbox()
toolbox.register("particle", generate_particle, dimension)
toolbox.register("population", tools.initRepeat, list, toolbox.particle)
toolbox.register("evaluate", evaluate_particle)
toolbox.register("learn", behaviour_learning)

# Model hyperparameters
dropout_prob = 0.30
num_epochs = 10
batch_size = 64

# call function
# Main loop for k-fold cross-validation
for fold, (train_fold_indices, val_fold_indices) in enumerate(skf.split(train_idx, stratified_train_labels)):
    print(f'Fold {fold + 1}/{num_folds}')
    mean, std = calc_mean_std(Subset(stratified_train_set, train_fold_indices))

    train_sampler = torch.utils.data.SubsetRandomSampler(train_fold_indices)
    val_sampler = torch.utils.data.SubsetRandomSampler(val_fold_indices)

    # create an initial population of individuals
    pop = toolbox.population(n=populationSize)
    stats = tools.Statistics(lambda ind: ind.fitness.values)
    stats.register("avg", np.mean)
    stats.register("std", np.std)
    stats.register("min", np.min)
    stats.register("max", np.max)


    logbook = tools.Logbook()
    logbook.header = ["gen", "evals"] + stats.fields

    fits = []


    train_loader = torch.utils.data.DataLoader(
        dataset=stratified_train_set,
        batch_size=batch_size,
        sampler=train_sampler,
        worker_init_fn=seed_worker,
        generator=g)

    val_loader = torch.utils.data.DataLoader(
        dataset=stratified_train_set,
        batch_size=batch_size,
        sampler=val_sampler,
        worker_init_fn=seed_worker,
        generator=g)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    memetic_train_and_validate(model, toolbox, pop, train_loader, val_loader, criterion, optimizer, mean, std, epochs=num_epochs)

Fold 1/10


