# My Model

In [2]:
import torch
from torch.utils.data import ConcatDataset, Subset, DataLoader
import torchvision
import torchvision.transforms as transforms
from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit
import numpy as np

# Seed for Reproducibility

In [3]:
def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)

g = torch.Generator()
g.manual_seed(0)

<torch._C.Generator at 0x78726c9d71d0>

In [5]:
# define the CNN architecture
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self, dropout_prob):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.conv2 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(128)
        self.conv3 = nn.Conv2d(128, 64, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.conv4 = nn.Conv2d(64, 32, 3, padding=1)
        self.bn4 = nn.BatchNorm2d(32)
        self.fc1 = nn.Linear(32 * 32, 512)
        self.fc2 = nn.Linear(512, 10)
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.max_pool2d(x, 2)
        x = F.relu(self.bn4(self.conv4(x)))
        x = F.max_pool2d(x, 2)
        x = x.view(-1, 32 * 32)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x

In [6]:
from torch.utils.data import Dataset

# DATA PREPROCESSING
class CustomDataset(Dataset):
    def __init__(self, dataset, indices, transform=None):
        self.dataset = dataset
        self.indices = indices
        self.transform = transform

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        image, label = self.dataset[self.indices[idx]]
        if self.transform:
            image = self.transform(image)
        return image, label

# define normalisation
def calc_mean_std(dataset):
    dataloader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=2)
    mean_sum = 0.
    var_sum = 0.
    total_images_count = 0
    for images, _ in dataloader:
        batch_samples = images.size(0)
        images = images.view(batch_samples, images.size(1), -1)
        mean_sum += images.mean(2).sum(0)
        var_sum += images.var(2).sum(0)
        total_images_count += batch_samples

    mean = mean_sum / total_images_count
    var = var_sum / total_images_count
    std = np.sqrt(var)

    return mean, std

# Define transformations for data augmentation
def train_transform(data, mean, std):
  data = normalize(data, mean, std)
  transform = transforms.Compose([
                                transforms.RandomHorizontalFlip(0.25),
                                transforms.RandomVerticalFlip(0.25),
                                transforms.RandomGrayscale(0.25),
                                transforms.RandomCrop(32, padding=4)
                                 ])
  return transform(data)

# Define normalisation
def normalize(data, mean, std):
  transform = transforms.Compose([
                                transforms.Normalize(mean, std)
                                ])
  return transform(data)


In [7]:
def save_checkpoint(state, filename="checkpoint.pth.tar"):
    torch.save(state, filename)

In [8]:
device = 'cuda'

# Define function for training and evaluating the model
def train_and_evaluate(model, train_loader, test_loader, criterion, optimizer, mean, std, epochs=30):
    model.to(device)
    model.train()
    # train
    for epoch in range(num_epochs):
        running_loss = 0.0
        correct = 0
        total = 0

        for inputs, train_load_labels in train_loader:
            inputs = train_transform(inputs, mean, std)
            inputs, train_load_labels = inputs.to(device), train_load_labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, train_load_labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += train_load_labels.size(0)
            correct += (predicted == train_load_labels).sum().item()
            print(f'Batch Average Loss: {running_loss/total:.4f} Batch Accuracy {100*correct/total:.4f}')

        epoch_loss = running_loss / total
        epoch_accuracy = 100 * correct / total
        print(f'Epoch [{epoch + 1}/{num_epochs}] Loss: {loss.item():.4f} Accuracy {epoch_accuracy:.4f}')

    model.eval()
    correct = 0
    total = 0

    # validate
    # Early stopping parameters
    early_stopping_patience = 10  # Number of epochs to wait for improvement before stopping
    early_stopping_counter = 0    # Counter for epochs without improvement
    best_accuracy = 0             # Track the best loss

    with torch.no_grad():
        for inputs, test_load_labels in test_loader:
            inputs = normalize(inputs, mean, std)
            inputs, test_load_labels = inputs.to(device), test_load_labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += test_load_labels.size(0)
            correct += (predicted == test_load_labels).sum().item()
            validation_accuracy = 100 * correct / total

            # Check if the current validation loss is better (lower) than the best recorded loss
            if validation_accuracy > best_accuracy:
                best_accuracy = validation_accuracy
                early_stopping_counter = 0  # Reset the counter
                # Save the model checkpoint
                save_checkpoint({
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                }, filename=f"best_model_epoch_{epoch+1}.pth.tar")
            else:
                early_stopping_counter += 1

            # Check if early stopping is triggered
            if early_stopping_counter >= early_stopping_patience:
                print("Early stopping triggered")
                break

    accuracy = 100 * correct / total
    print(f'Test Accuracy: {accuracy:.2f}%')
    return model

def test(model, test_loader, mean, std):

    model.to(device)
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, test_load_labels in test_loader:
            inputs = normalize(inputs, mean, std)
            inputs, test_load_labels = inputs.to(device), test_load_labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            total += test_load_labels.size(0)
            correct += (predicted == test_load_labels).sum().item()


    accuracy = 100 * correct / total
    print(f'Test Accuracy: {accuracy:.2f}%')
    return model

In [None]:
# dataset hyperparameters
num_folds = 10
test_size = 0.20
batch_size = 64

# load CIFAR-10 dataset & convert to tensor
train_set = torchvision.datasets.CIFAR10(root='./data', train=True,
                                             download=True,
                                         transform=transforms.ToTensor())

test_set = torchvision.datasets.CIFAR10(root='./data', train=False,
                                            download=True,
                                        transform=transforms.ToTensor())

# combine train and test datasets for stratified splitting
combined_set = ConcatDataset([train_set, test_set])

# STRATIFIED SPLIT
# collect the labels
labels = [y for _, y in combined_set]

# stratified split subset indices
sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=0)
train_idx, test_idx = next(sss.split(np.zeros(len(labels)), labels))

# subset using indices & collect associated labels
stratified_train_set, stratified_train_labels = Subset(combined_set, train_idx).dataset, [labels[i] for i in range(len(labels)) if i in train_idx]
stratified_test_set, stratified_test_labels = Subset(combined_set, test_idx).dataset, [labels[i] for i in range(len(labels)) if i in test_idx]

# create StratifiedKFold object for train set only
skf = StratifiedKFold(n_splits=num_folds)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:02<00:00, 62972565.96it/s]


Extracting ./data/cifar-10-python.tar.gz to ./data
Files already downloaded and verified


In [None]:
# load the best model
# best_checkpoint = torch.load("best_model_epoch_X.pth.tar")  # Replace X with the epoch number
# model.load_state_dict(best_checkpoint['state_dict'])

In [40]:
# Model hyperparameters
dropout_prob = 0.25
num_epochs = 20

# call function
# Main loop for k-fold cross-validation
for fold, (train_fold_indices, val_fold_indices) in enumerate(skf.split(train_idx, stratified_train_labels)):
    print(f'Fold {fold + 1}/{num_folds}')
    mean, std = calc_mean_std(Subset(stratified_train_set, train_fold_indices))

    train_sampler = torch.utils.data.SubsetRandomSampler(train_fold_indices)
    val_sampler = torch.utils.data.SubsetRandomSampler(val_fold_indices)


    train_loader = torch.utils.data.DataLoader(
        dataset=stratified_train_set,
        batch_size=batch_size,
        sampler=train_sampler,
        worker_init_fn=seed_worker,
        generator=g)

    val_loader = torch.utils.data.DataLoader(
        dataset=stratified_train_set,
        batch_size=batch_size,
        sampler=val_sampler,
        worker_init_fn=seed_worker,
        generator=g)

    model = Net(dropout_prob).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    train_and_evaluate(model, train_loader, val_loader, criterion, optimizer, mean, std, epochs=num_epochs)

Fold 1/10


RuntimeError: ignored

In [21]:
# test model
test_loader = torch.utils.data.DataLoader(
    dataset=stratified_test_set,
    batch_size=batch_size,
    worker_init_fn=seed_worker,
    generator=g)

test(model, test_loader, mean, std)



Test Accuracy: 76.85%


Net(
  (conv1): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn1): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (bn3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc1): Linear(in_features=1024, out_features=512, bias=True)
  (fc2): Linear(in_features=512, out_features=10, bias=True)
  (dropout): Dropout(p=0, inplace=False)
)

In [16]:
params = list(model.parameters())
print(len(params))
print(params[0].size())  # conv1's .weight

16
torch.Size([16, 3, 3, 3])
