### The report with the experiments is available here:
https://wandb.ai/podcast-o-rybach-warsaw-university-of-technology/iml_lab2/reports/IML-Lab-2--Vmlldzo5OTk2OTA0?accessToken=podfkih4w73w3doiz9hjnd8ayb9fg0kw9jz8pfmut6y8h9wh39o1y6brzplvqoxh

# Training CNNs

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import os

import torch
import torchvision
import torchvision.transforms as transforms

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import wandb
wandb.login()

### Use GPU if available

In [None]:
# Run on the first CUDA device when one is visible, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(device)

### PyTorch transformations for loading the data

In [None]:
def load_dataset(name, input_resolution):
    """Create the train and validation datasets for a supported benchmark.

    Every image is converted to a tensor, resized to a square of side
    `input_resolution` and normalized with hard-coded per-channel statistics.

    Args:
        name: One of 'FashionMNIST', 'imagenette' or 'imagewoof'.
        input_resolution: Side length (in pixels) of the resized images.

    Returns:
        A tuple `(trainset, valset, is_grayscale, class_names)`.

    Raises:
        Exception: If `name` is not one of the supported datasets.
    """
    # Per-channel (mean, std) used by Normalize for each dataset.
    normalization_stats = {
        'FashionMNIST': ((0.286, ), (0.338, )),
        'imagenette': ((0.462, 0.458, 0.430), (0.270, 0.267, 0.290)),
        'imagewoof': ((0.486, 0.456, 0.394), (0.248, 0.241, 0.250)),
    }
    if name not in normalization_stats:
        raise Exception(f'Dataset {name} not found')
    mean, std = normalization_stats[name]

    target_size = (input_resolution, input_resolution)
    input_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(target_size),
        transforms.Normalize(mean, std)
    ])

    if name == 'FashionMNIST':
        # FashionMNIST is fetched into ./data on first use.
        trainset = torchvision.datasets.FashionMNIST(
            './data', train=True, download=True, transform=input_transform
        )
        valset = torchvision.datasets.FashionMNIST(
            './data', train=False, download=True, transform=input_transform
        )
        return trainset, valset, True, trainset.classes

    if name == 'imagewoof':
        # Imagewoof is read from an already extracted folder tree.
        trainset = torchvision.datasets.ImageFolder(
            './data/imagewoof2-160/train', transform=input_transform
        )
        valset = torchvision.datasets.ImageFolder(
            './data/imagewoof2-160/val', transform=input_transform
        )
        return trainset, valset, False, trainset.classes

    # Only 'imagenette' is left at this point.
    trainset = torchvision.datasets.Imagenette(
        './data/imagenette_train',
        split='train',
        size='160px',
        transform=input_transform
    )
    valset = torchvision.datasets.Imagenette(
        './data/imagenette_val',
        split='val',
        size='160px',
        transform=input_transform
    )
    # Each entry of `classes` is a sequence of names; keep the first one.
    class_names = [aliases[0] for aliases in trainset.classes]
    return trainset, valset, False, class_names


### Define the models

In [None]:
class OurCNN(nn.Module):
    """Baseline CNN: three conv/ReLU/max-pool stages and a three-layer MLP head.

    The first linear layer takes 128 * 4 * 4 features, which is what the
    feature extractor produces for an 80x80 input.

    Args:
        input_is_grayscale: True for 1-channel input, False for 3-channel.
        num_outputs: Number of output logits.
    """

    def __init__(self, input_is_grayscale, num_outputs):
        super().__init__()
        in_channels = 1 if input_is_grayscale else 3
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(64, 96, kernel_size=3)
        self.conv3 = nn.Conv2d(96, 128, kernel_size=3)
        self.pool_last = nn.MaxPool2d(kernel_size=4, stride=4)
        self.fc1 = nn.Linear(in_features=128 * 4 * 4, out_features=200)
        self.fc2 = nn.Linear(in_features=200, out_features=100)
        self.fc3 = nn.Linear(in_features=100, out_features=num_outputs)

    def forward(self, x):
        """Return the logits for a batch of images `x`."""
        # Convolutional feature extractor.
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = F.relu(self.conv3(x))
        x = self.pool_last(x)
        # Flatten to rows of fc1.in_features before the classifier head.
        x = x.view(-1, self.fc1.in_features)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

class OurCNNDropout(nn.Module):
    """Baseline CNN with dropout (p=0.5) in front of the last two linear layers.

    The first linear layer takes 128 * 4 * 4 features, which is what the
    feature extractor produces for an 80x80 input.

    Args:
        input_is_grayscale: True for 1-channel input, False for 3-channel.
        num_outputs: Number of output logits.
    """

    def __init__(self, input_is_grayscale, num_outputs):
        super().__init__()
        in_channels = 1 if input_is_grayscale else 3
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(64, 96, kernel_size=3)
        self.conv3 = nn.Conv2d(96, 128, kernel_size=3)
        self.pool_last = nn.MaxPool2d(kernel_size=4, stride=4)
        self.fc1 = nn.Linear(in_features=128 * 4 * 4, out_features=200)
        self.dropout1 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(in_features=200, out_features=100)
        self.dropout2 = nn.Dropout(0.5)
        self.fc3 = nn.Linear(in_features=100, out_features=num_outputs)

    def forward(self, x):
        """Return the logits for a batch of images `x`."""
        # Convolutional feature extractor.
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = F.relu(self.conv3(x))
        x = self.pool_last(x)
        # Flatten to rows of fc1.in_features before the classifier head.
        x = x.view(-1, self.fc1.in_features)
        x = F.relu(self.fc1(x))
        # Dropout is applied to the inputs of fc2 and fc3.
        x = self.dropout1(x)
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)
        return self.fc3(x)

class OurCNNBatchNormAfterConvs(nn.Module):
    """Baseline CNN with batch normalization after every convolution.

    The convolutions carry no bias because each one is followed directly by a
    BatchNorm2d layer. The first linear layer takes 128 * 4 * 4 features,
    which is what the feature extractor produces for an 80x80 input.

    Args:
        input_is_grayscale: True for 1-channel input, False for 3-channel.
        num_outputs: Number of output logits.
    """

    def __init__(self, input_is_grayscale, num_outputs):
        super().__init__()
        in_channels = 1 if input_is_grayscale else 3
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=5, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(64, 96, kernel_size=3, bias=False)
        self.bn2 = nn.BatchNorm2d(96)
        self.conv3 = nn.Conv2d(96, 128, kernel_size=3, bias=False)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool_last = nn.MaxPool2d(kernel_size=4, stride=4)
        self.fc1 = nn.Linear(in_features=128 * 4 * 4, out_features=200)
        self.fc2 = nn.Linear(in_features=200, out_features=100)
        self.fc3 = nn.Linear(in_features=100, out_features=num_outputs)

    def forward(self, x):
        """Return the logits for a batch of images `x`."""
        # Each stage: conv -> batch norm -> ReLU -> max pool.
        x = self.bn1(self.conv1(x))
        x = self.pool(F.relu(x))
        x = self.bn2(self.conv2(x))
        x = self.pool(F.relu(x))
        x = self.bn3(self.conv3(x))
        x = self.pool_last(F.relu(x))
        # Flatten to rows of fc1.in_features before the classifier head.
        x = x.view(-1, self.fc1.in_features)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

class OurCNNVGGLike(nn.Module):
    """VGG-style CNN: stacks of 3x3 convolutions between pooling layers.

    The stacks have 2, 3 and 2 convolutions respectively and are followed by a
    dropout-regularized MLP head. The first linear layer takes 128 * 3 * 3
    features, which is what the feature extractor produces for an 80x80 input.

    NOTE(review): conv2, conv5 and conv7 are bias-free although no batch norm
    follows them in this variant - confirm this is intended.

    Args:
        input_is_grayscale: True for 1-channel input, False for 3-channel.
        num_outputs: Number of output logits.
    """

    def __init__(self, input_is_grayscale, num_outputs):
        super().__init__()
        in_channels = 1 if input_is_grayscale else 3
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, bias=False)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv3 = nn.Conv2d(64, 96, kernel_size=3)
        self.conv4 = nn.Conv2d(96, 96, kernel_size=3)
        self.conv5 = nn.Conv2d(96, 96, kernel_size=3, bias=False)
        self.conv6 = nn.Conv2d(96, 128, kernel_size=3)
        self.conv7 = nn.Conv2d(128, 128, kernel_size=3, bias=False)
        self.pool_last = nn.MaxPool2d(kernel_size=4, stride=4)
        self.fc1 = nn.Linear(in_features=128 * 3 * 3, out_features=200)
        self.dropout1 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(in_features=200, out_features=100)
        self.dropout2 = nn.Dropout(0.5)
        self.fc3 = nn.Linear(in_features=100, out_features=num_outputs)

    def forward(self, x):
        """Return the logits for a batch of images `x`."""
        # Stack 1: two convolutions, then 2x2 pooling.
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        # Stack 2: three convolutions, then 2x2 pooling.
        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = F.relu(self.conv5(x))
        x = self.pool(x)
        # Stack 3: two convolutions, then 4x4 pooling.
        x = F.relu(self.conv6(x))
        x = F.relu(self.conv7(x))
        x = self.pool_last(x)
        # Flatten to rows of fc1.in_features before the classifier head.
        x = x.view(-1, self.fc1.in_features)
        x = F.relu(self.fc1(x))
        x = self.dropout1(x)
        x = F.relu(self.fc2(x))
        x = self.dropout2(x)
        return self.fc3(x)

class OurCNNVGGLikeBatchNorm(nn.Module):
    """VGG-style CNN with batch normalization at the end of every conv stack.

    The stacks have 2, 3 and 2 convolutions respectively. Every convolution is
    followed by a ReLU; the last convolution of each stack is bias-free and is
    batch-normalized before its ReLU. A dropout-regularized MLP head follows.
    The first linear layer takes 128 * 3 * 3 features, which is what the
    feature extractor produces for an 80x80 input.

    Args:
        input_is_grayscale: True for 1-channel input, False for 3-channel.
        num_outputs: Number of output logits.
    """

    def __init__(self, input_is_grayscale, num_outputs):
        super(OurCNNVGGLikeBatchNorm, self).__init__()
        self.conv1 = nn.Conv2d(1 if input_is_grayscale else 3, 64, 3)
        self.conv2 = nn.Conv2d(64, 64, 3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv3 = nn.Conv2d(64, 96, 3)
        self.conv4 = nn.Conv2d(96, 96, 3)
        self.conv5 = nn.Conv2d(96, 96, 3, bias=False)
        self.bn2 = nn.BatchNorm2d(96)
        self.conv6 = nn.Conv2d(96, 128, 3)
        self.conv7 = nn.Conv2d(128, 128, 3, bias=False)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool_last = nn.MaxPool2d(4, 4)
        self.fc1 = nn.Linear(128 * 3 * 3, 200)
        self.dropout1 = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(200, 100)
        self.dropout2 = nn.Dropout(p=0.5)
        self.fc3 = nn.Linear(100, num_outputs)

    def forward(self, x):
        """Return the logits for a batch of images `x`."""
        # Stack 1: conv -> ReLU -> conv -> BN -> ReLU -> 2x2 pool.
        x = self.pool(F.relu(self.bn1(self.conv2(F.relu(self.conv1(x))))))
        # Stack 2: same pattern with three convolutions.
        x = self.pool(F.relu(self.bn2(self.conv5(F.relu(self.conv4(F.relu(self.conv3(x))))))))
        # Stack 3: the ReLU between conv6 and conv7 keeps the two convolutions
        # from collapsing into a single linear map, as in the other stacks.
        x = self.pool_last(F.relu(self.bn3(self.conv7(F.relu(self.conv6(x))))))
        # Flatten to rows of fc1.in_features before the classifier head.
        x = x.view(-1, self.fc1.in_features)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(self.dropout1(x)))
        x = self.fc3(self.dropout2(x))
        return x

class OurCNNResnetLike(nn.Module):
    """Small ResNet-style CNN built from three down-sampling residual stages.

    Each stage adds two branches and applies a ReLU to the sum:
      * shortcut: 1x1 stride-2 convolution followed by batch norm;
      * main: 3x3 stride-2 convolution -> ReLU -> 3x3 convolution -> batch norm.
    All three stages use this same layout. For an 80x80 input the stages
    produce 40x40, 20x20 and 10x10 maps; the final 10x10 max pool reduces the
    last one to the 256 features expected by the linear classifier.

    Args:
        input_is_grayscale: True for 1-channel input, False for 3-channel.
        num_outputs: Number of output logits.
    """

    def __init__(self, input_is_grayscale, num_outputs):
        super(OurCNNResnetLike, self).__init__()
        self.residual_reshape1 = nn.Conv2d(1 if input_is_grayscale else 3, 64, 1, stride=2, bias=False)
        self.conv1 = nn.Conv2d(1 if input_is_grayscale else 3, 64, 3, padding=1, stride=2)
        self.conv2 = nn.Conv2d(64, 64, 3, padding=1, bias=False)
        self.bn1_res = nn.BatchNorm2d(64)
        self.bn1 = nn.BatchNorm2d(64)
        self.residual_reshape2 = nn.Conv2d(64, 128, 1, stride=2, bias=False)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1, stride=2)
        self.conv4 = nn.Conv2d(128, 128, 3, padding=1, bias=False)
        self.bn2_res = nn.BatchNorm2d(128)
        self.bn2 = nn.BatchNorm2d(128)
        self.residual_reshape3 = nn.Conv2d(128, 256, 1, stride=2, bias=False)
        self.conv5 = nn.Conv2d(128, 256, 3, padding=1, stride=2)
        self.conv6 = nn.Conv2d(256, 256, 3, padding=1, bias=False)
        self.bn3_res = nn.BatchNorm2d(256)
        self.bn3 = nn.BatchNorm2d(256)
        self.pool_last = nn.MaxPool2d(10, 10)
        self.fc1 = nn.Linear(256, num_outputs)

    @staticmethod
    def _residual_stage(x, shortcut, shortcut_bn, conv_in, conv_out, main_bn):
        """Return ReLU(shortcut branch + main branch) for one residual stage."""
        skip = shortcut_bn(shortcut(x))
        main = main_bn(conv_out(F.relu(conv_in(x))))
        return F.relu(skip + main)

    def forward(self, x):
        """Return the logits for a batch of images `x`."""
        x = self._residual_stage(
            x, self.residual_reshape1, self.bn1_res, self.conv1, self.conv2, self.bn1
        )
        x = self._residual_stage(
            x, self.residual_reshape2, self.bn2_res, self.conv3, self.conv4, self.bn2
        )
        x = self._residual_stage(
            x, self.residual_reshape3, self.bn3_res, self.conv5, self.conv6, self.bn3
        )
        x = self.pool_last(x)
        # Flatten to rows of fc1.in_features before the classifier.
        x = x.view(-1, self.fc1.in_features)
        x = self.fc1(x)
        return x

# Model classes covered by the experiment sweep. Each class is called as
# MODEL(input_is_grayscale, num_outputs) when an experiment is started.
settings_models = [
    OurCNN,
    OurCNNDropout,
    OurCNNBatchNormAfterConvs,
    OurCNNVGGLike,
    OurCNNVGGLikeBatchNorm,
    OurCNNResnetLike
]

### Train. Save loss and accuracy for train and validation

In [None]:
def _evaluate(model, loader, criterion):
    """Return `(accuracy, mean per-batch loss)` of `model` over `loader`.

    The model is switched to eval mode and gradients are disabled. Accuracy is
    the fraction of correctly classified samples, so a smaller final batch is
    not over-weighted.
    """
    model.eval()
    total_loss = 0.0
    num_correct = 0
    num_samples = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            outputs = model(inputs)
            total_loss += criterion(outputs, labels).item()
            num_correct += (outputs.argmax(dim=1) == labels).sum().item()
            num_samples += labels.size(0)
    # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
    return num_correct / max(num_samples, 1), total_loss / max(len(loader), 1)


def train(model, trainloader, valloader, criterion, optimizer, scheduler, max_num_epochs):
    """Train `model`, logging train and validation metrics to wandb.

    Training metrics are logged about five times per epoch and validation
    metrics once per epoch. The scheduler is stepped once per epoch. Training
    stops early after 10 consecutive epochs without a new best validation
    accuracy. `wandb.log` is called, so a wandb run must already be active.

    Args:
        model: Network to train; it is moved to the global `device`.
        trainloader: Iterable of `(inputs, labels)` training batches.
        valloader: Iterable of `(inputs, labels)` validation batches.
        criterion: Loss function called as `criterion(outputs, labels)`.
        optimizer: Optimizer over the model parameters.
        scheduler: Learning-rate scheduler, stepped after every epoch.
        max_num_epochs: Upper bound on the number of epochs.

    Returns:
        None.
    """
    model.to(device)
    # Clamp to at least 1: for loaders shorter than 10 batches the raw value is
    # 0 (modulo by zero) or negative (metrics would never be logged).
    PRINT_STEP = max(1, len(trainloader) // 5 - 1)
    epochs_without_val_acc_improvement = 0
    best_val_acc = 0.0

    for epoch in range(0, max_num_epochs):
        print(f'Epoch {epoch}')
        # Set train mode explicitly; validation below switches to eval mode.
        model.train()
        running_loss = 0.0
        running_correct = 0
        running_samples = 0
        for i, data in enumerate(trainloader):
            inputs, labels = data
            inputs = inputs.to(device)
            labels = labels.to(device)
            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Count correct samples so the window accuracy is sample-weighted.
            predictions = outputs.detach().argmax(dim=1)
            running_correct += (predictions == labels).sum().item()
            running_samples += labels.size(0)
            running_loss += loss.item()

            if i % PRINT_STEP == PRINT_STEP - 1:
                wandb.log({
                        "train/accuracy": running_correct / running_samples,
                        "train/loss": running_loss / PRINT_STEP
                    },
                    step=epoch * len(trainloader) + i
                )
                running_loss = 0.0
                running_correct = 0
                running_samples = 0

        val_accuracy, val_loss = _evaluate(model, valloader, criterion)
        wandb.log({
                "validation/accuracy": val_accuracy,
                "validation/loss": val_loss
            },
            # Logged at the step just past the epoch's last training batch.
            step=(epoch + 1) * len(trainloader)
        )
        if val_accuracy > best_val_acc:
            best_val_acc = val_accuracy
            epochs_without_val_acc_improvement = 0
        else:
            epochs_without_val_acc_improvement += 1
        model.train()
        scheduler.step()
        if epochs_without_val_acc_improvement >= 10:
            print("10 epochs without a val accuracy improvement. Stopping the train")
            return

    print('Finished Training')

In [None]:
# Sweep grid: every model class paired with each of the three learning rates.
settings = [(m, lr) for m in settings_models for lr in [0.01, 0.003, 0.03]]

### Run experiments

In [None]:
# Hyper-parameters shared by every run of the sweep.
LR_DECAY = 0.95
MOMENTUM = 0.9
INPUT_RESOLUTION = 80
OPTMIZER = optim.SGD

for DATASET_NAME in ['FashionMNIST', 'imagenette', 'imagewoof']:
    # The datasets and loaders depend only on the dataset name, so they are
    # built once per dataset instead of once per (model, learning rate) run.
    trainset, valset, input_is_grayscale, class_names = load_dataset(
        DATASET_NAME, INPUT_RESOLUTION
    )
    trainloader = torch.utils.data.DataLoader(
        trainset,
        batch_size=64,
        num_workers=2,
        shuffle=True
    )
    valloader = torch.utils.data.DataLoader(
        valset,
        batch_size=64,
        shuffle=False,
        num_workers=2
    )

    for MODEL, LEARNING_RATE in settings:
        print(f'Found {len(trainloader)} train and {len(valloader)} val batches')
        print(f'classes: {class_names}')

        model = MODEL(input_is_grayscale, len(class_names))
        criterion = nn.CrossEntropyLoss()

        optimizer = OPTMIZER(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)
        if MODEL is OurCNNResnetLike:
            # The residual model trains at 0.1x the learning rate for the first
            # 4 epochs before switching to the exponential decay.
            scheduler1 = optim.lr_scheduler.ConstantLR(optimizer, factor=0.1, total_iters=4)
            scheduler2 = optim.lr_scheduler.ExponentialLR(optimizer, gamma=LR_DECAY)
            scheduler = optim.lr_scheduler.SequentialLR(
                optimizer, schedulers=[scheduler1, scheduler2], milestones=[4]
            )
        else:
            scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=LR_DECAY)
        num_learnable_parameters = sum(
            p.numel() for p in model.parameters() if p.requires_grad
        )
        print(f'{num_learnable_parameters} learnable parameters')
        # One wandb run per (dataset, model, learning rate) combination.
        wandb.init(
            project="iml_lab2",
            config={
                "learning_rate": LEARNING_RATE,
                "learning_rate_decay": LR_DECAY,
                "momentum": MOMENTUM,
                "batch_size": trainloader.batch_size,
                "input_resolution": INPUT_RESOLUTION,
                "num_parameters": num_learnable_parameters,
                "optimizer": OPTMIZER.__name__,
                "architecture": MODEL.__name__,
                "dataset": DATASET_NAME
            }
        )
        train(model, trainloader, valloader, criterion, optimizer, scheduler, 100)

### Calculate means and standard deviations for Normalize

In [None]:
# Per-channel mean and standard deviation of each training set.
#
# The statistics are accumulated exactly from per-channel sums and sums of
# squares over all pixels; averaging per-batch standard deviations would ignore
# the variation of the batch means and is not the dataset standard deviation.
#
# NOTE(review): load_dataset already applies Normalize with the hard-coded
# statistics, so as written the printed values describe the normalized tensors.
# To recompute the raw statistics the Normalize step has to be bypassed.
INPUT_RESOLUTION = 80
for DATASET_NAME in ['FashionMNIST', 'imagenette', 'imagewoof']:
    trainset, valset, input_is_grayscale, class_names = load_dataset(
        DATASET_NAME, INPUT_RESOLUTION
    )
    trainloader = torch.utils.data.DataLoader(
        trainset,
        batch_size=64,
        num_workers=2,
        shuffle=False  # order does not matter for the statistics
    )
    channel_sum = 0.0
    channel_sq_sum = 0.0
    num_pixels = 0
    for inputs, labels in trainloader:
        # Accumulate in float64 to limit rounding error over many batches.
        batch = inputs.numpy().astype(np.float64)
        # Axes (0, 2, 3) leave only axis 1, the channel axis of the image batch.
        channel_sum = channel_sum + batch.sum(axis=(0, 2, 3))
        channel_sq_sum = channel_sq_sum + np.square(batch).sum(axis=(0, 2, 3))
        num_pixels += batch.shape[0] * batch.shape[2] * batch.shape[3]
    mean = channel_sum / num_pixels
    # Var[x] = E[x^2] - E[x]^2; clamp tiny negative values caused by rounding.
    std = np.sqrt(np.maximum(channel_sq_sum / num_pixels - mean ** 2, 0.0))
    print(f'{DATASET_NAME}:')
    print(mean)
    print(std)


### Save example images

In [None]:
# Log a wandb table with up to 10 example training images per class.
for DATASET_NAME in ['imagenette', 'imagewoof']:
    table_data = []
    if DATASET_NAME == 'imagenette':
        train_dir = 'data/imagenette_train/imagenette2-160/train'
    elif DATASET_NAME == 'imagewoof':
        train_dir = 'data/imagewoof2-160/train'
    trainset, valset, input_is_grayscale, actual_class_names = load_dataset(
        DATASET_NAME, 80
    )
    # os.listdir returns entries in arbitrary order, but the class name is
    # looked up by position, so the class directories are sorted first.
    # ImageFolder assigns class indices to the sorted directory names;
    # NOTE(review): Imagenette is presumed to order its classes the same way -
    # confirm against torchvision.
    class_dirs = sorted(
        entry for entry in os.listdir(train_dir)
        if os.path.isdir(os.path.join(train_dir, entry))
    )
    for class_index, class_name in enumerate(class_dirs):
        # Filter before truncating so non-image files do not use up the
        # 10-image budget; sorting makes the selection reproducible.
        jpeg_names = sorted(
            file_name for file_name in os.listdir(f'{train_dir}/{class_name}')
            if file_name.endswith('JPEG')
        )
        for file_name in jpeg_names[:10]:
            filepath = f'{train_dir}/{class_name}/{file_name}'
            table_data.append(
                [file_name, actual_class_names[class_index], wandb.Image(filepath)]
            )
    print(table_data)
    columns = ["filename", "class", "image"]
    table = wandb.Table(data=table_data, columns=columns)
    wandb.log({f'example_samples_{DATASET_NAME}': table})