### BOGOMOL Neural Network

This notebook was created to test the concept of BOGOMOL convolutional networks.

Importing all needed libraries

In [None]:
import math
import os

import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from torchsummary import summary

import data.transformations as tfms
from data.datasets import CifarDataset
from bogomol.models import ImageClassifier
from utils.train_steps import training_step, validation_step
from utils.optimizer import configure_optimizer

Selecting CUDA when it is available and falling back to the CPU otherwise. Note: this notebook does not change the multiprocessing start method (for example from 'fork' to 'spawn'), so the DataLoader workers use the platform default.

In [None]:
# Run on the GPU when PyTorch can see one; otherwise stay on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

##### Pre-train on CIFAR-10

In [None]:
# Load CIFAR-10 through `datasets.load_dataset`, caching files under ./.datasets.
dataset = load_dataset('uoft-cs/cifar10', cache_dir='./.datasets')
# The 'test' split is used as the validation set throughout this notebook.
train_set, val_set = dataset['train'], dataset['test']
# The number of classes is the length of the label feature's name list.
num_classes = len(train_set.features['label'].names)

In [None]:
IMG_SIZE = 64
# Both pipelines take the same (IMG_SIZE, IMG_SIZE) size tuple and the CIFAR
# mean/std constants; only the factory differs (train first, then validation).
# NOTE(review): the size is presumably a resize target - confirm in data/transformations.
train_transforms, val_transforms = (
    make_transform((IMG_SIZE, IMG_SIZE), tfms.CIFAR_MEAN, tfms.CIFAR_STD)
    for make_transform in (tfms.train_transform, tfms.val_transform)
)

In [None]:
# Pair each split with its own transform pipeline.
train_ds, val_ds = (
    CifarDataset(train_set, train_transforms),
    CifarDataset(val_set, val_transforms),
)

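Creating the training and validation dataloaders. Training batches are additionally passed through `cutmix_or_mixup` in the collate function.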

In [None]:
# Samples per batch; used by both the training and the validation DataLoader.
BATCH_SIZE = 64

# Batch-level transform built by the project's transformations module for
# `num_classes` classes. NOTE(review): presumably applies CutMix or MixUp to a
# collated batch - confirm in data/transformations.
cutmix_or_mixup = tfms.cutmix_or_mixup(num_classes)

def collate_fn(batch):
    """Collate ``(image, label)`` samples into a pair of batched tensors.

    Args:
        batch: Iterable of ``(image, label)`` pairs. Each image must be
            convertible by ``torch.as_tensor`` and all images must share one
            shape so they can be stacked.

    Returns:
        A tuple ``(images, labels)``. ``images`` is a float32 tensor with the
        samples stacked along a new leading dimension; ``labels`` is a 1-D
        int64 tensor.
    """
    images = []
    labels = []
    for image, label in batch:
        # `torch.as_tensor` accepts any input dtype (e.g. uint8 tensors), whereas
        # the legacy `torch.Tensor(...)` constructor rejects non-float tensors.
        images.append(torch.as_tensor(image, dtype=torch.float32))
        labels.append(label)
    images = torch.stack(images, dim=0)
    # Let torch infer the label dtype (int64 for Python ints) instead of going
    # through float32, which cannot represent integers above 2**24 exactly.
    labels = torch.as_tensor(labels).long()
    return images, labels

def train_collate_fn(batch):
    """Collate a training batch and pass it through ``cutmix_or_mixup``.

    The ``(images, labels)`` tuple produced by ``collate_fn`` is handed to
    ``cutmix_or_mixup`` as a single argument and its result is returned as is.
    """
    collated = collate_fn(batch)
    return cutmix_or_mixup(collated)

# Training loader: reshuffled every epoch, augmented via `train_collate_fn`.
train_dataloader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=8,
    pin_memory=True,
    collate_fn=train_collate_fn,
)
# Validation loader: default (non-shuffled) order, plain collation.
val_dataloader = DataLoader(
    val_ds,
    batch_size=BATCH_SIZE,
    num_workers=4,
    pin_memory=True,
    collate_fn=collate_fn,
)

Specifying Image Classification model

In [None]:
entities = 64
# Build the classifier first, then move it to the selected device.
model = ImageClassifier(3, entities, num_classes, 4, 5, (IMG_SIZE, IMG_SIZE))
model = model.to(device)
# NOTE(review): the compiled wrapper is bound to `model_architecture`, but the
# visible cells keep using the eager `model` for the summary, training and
# checkpointing - confirm whether the compiled module was meant to be trained.
model_architecture = torch.compile(model)

In [None]:
# Count only the parameters that require gradients.
num_parameters = sum(
    param.numel()
    for param in filter(lambda param: param.requires_grad, model.parameters())
)
# The name encodes the parameter count in millions and the `entities` setting;
# it is reused below for checkpoint, plot and report file names.
model_name = f"bogomol-{num_parameters/1e+6:.1f}M({entities})"

print(model_name)
summary(model, (3, IMG_SIZE, IMG_SIZE))

In [None]:
EPOCHS = 300
optimizer, scheduler = configure_optimizer(
    model,
    base_lr=3e-4,
    weight_decay=0.05,
    warmup_epochs=5,
    total_epochs=EPOCHS,
    steps_per_epoch=len(train_dataloader),
    num_cycles=3,
)
# GradScaler documents its `device` argument as a string, so pass the device
# type ('cuda' / 'cpu') rather than the torch.device object.
scaler = torch.GradScaler(device.type)
checkpoint_pth = './models'
# torch.save does not create missing directories.
os.makedirs(checkpoint_pth, exist_ok=True)
best_val_loss = float('inf')
# Metric history handed to the train/validation steps. Each list starts with a
# sentinel (inf loss, 0 accuracy).
# NOTE(review): the steps presumably append one value per epoch - confirm in utils/train_steps.
history = {
    "train_loss" : [float('inf')],
    "train_acc" : [0],
    "val_loss" : [float('inf')],
    "val_acc" : [0]
}

for i in range(EPOCHS):
    print(f"Training epoch: {i+1}/{EPOCHS}")
    train_loss, train_acc = training_step(model, train_dataloader, optimizer, scheduler, scaler, history, accumulation_steps=4)
    print(f"Average training loss: {train_loss:.3f}, accuracy: {train_acc:.3f}")
    print(f"Learning rate: {scheduler.get_last_lr()[0]}, gradient scale: {scaler.get_scale()}")
    val_loss, val_acc = validation_step(model, val_dataloader, history)
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        # state_dict() references the live parameters, so the snapshot is
        # written to disk immediately; `best_params` itself will track later
        # updates to the model.
        best_params = model.state_dict()
        torch.save(best_params, os.path.join(checkpoint_pth, model_name+'(cifar).pth'))
    print(f"Average validation loss: {val_loss:.3f}, accuracy: {val_acc:.3f}")
    print("-"*80)
    # Stop once training has diverged. NaN must be detected by value: the
    # previous identity test (`is torch.nan`) never matches a computed NaN.
    if math.isnan(train_loss):
        break



In [None]:
plots_pth = './plots'
# savefig does not create missing directories.
os.makedirs(plots_pth, exist_ok=True)

# Use an explicit figure so the curves never land on a previously active one.
fig, ax = plt.subplots()
ax.plot(history['train_loss'], label="Train loss")
ax.plot(history['val_loss'], label="Validation loss")

ax.set_xlabel("Epochs")
ax.set_ylabel("Loss")
ax.legend()
ax.set_title(f"Model loss ({model_name})")
fig.savefig(os.path.join(plots_pth, f'{model_name}-loss(cifar10).png'))

In [None]:
# savefig does not create missing directories.
os.makedirs(plots_pth, exist_ok=True)

# Use a dedicated figure so the accuracy curve is not drawn over another plot.
fig, ax = plt.subplots()
ax.plot(history['val_acc'], label="Validation accuracy")

ax.set_xlabel("Epochs")
ax.set_ylabel("Accuracy")
ax.legend()
ax.set_title(f"Model accuracy({model_name})")
fig.savefig(os.path.join(plots_pth, f'{model_name}-acc(cifar10).png'))

In [None]:
reports_pth = './reports'
# open(..., 'w') does not create missing directories.
os.makedirs(reports_pth, exist_ok=True)

# Write the best (max accuracy / min loss) values found in `history`.
with open(os.path.join(reports_pth, f'{model_name}(cifar10).txt'), 'w') as f:
    f.write(f"Train accuracy : {max(history['train_acc'])}, loss : {min(history['train_loss'])}\n")
    f.write(f"Validation accuracy : {max(history['val_acc'])}, loss : {min(history['val_loss'])}\n")