# Laboratorio 1

## Libraries

In [None]:
import sys
import os

import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torchmetrics
import torch.optim as optim
import torchvision.transforms as T
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10

import wandb
from tqdm import tqdm

from model import ConvolutionalNeuralNetworks
from utils import (get_loaders, save_checkpoint, load_checkpoint, 
                   load_best_model, metrics, eval_fn, 
                   create_directory_if_does_not_exist, EarlyStopping, Lion)

## Models

In [None]:
class FirstConvLayer(nn.Module):
    """Entry layer of the network: one 2-D convolution.

    The convolution is kept inside an ``nn.Sequential`` stored as
    ``self.sequential``, so its parameters are registered under the
    ``sequential.0.*`` prefix.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels produced by the convolution.
        kernel_size: kernel size forwarded to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size):
        super().__init__()
        convolution = nn.Conv2d(in_channels, out_channels, kernel_size)
        self.sequential = nn.Sequential(convolution)

    def forward(self, x):
        """Apply the convolution to ``x`` and return the result."""
        features = self.sequential(x)
        return features


class ConvolutionalBlock(nn.Module):
    """Convolutional block with optional shortcut and optional downsampling.

    The forward pass is ``conv1`` followed by ``sequential``
    (BatchNorm -> ReLU -> 3x3 Conv -> BatchNorm -> ReLU, plus a pooling layer
    when one is appended). With a shortcut enabled, the block input is added
    to that result and passed through a final ReLU.

    Args:
        in_channels: channels of the block input.
        out_channels: channels of the block output.
        want_shortcut: build and use the shortcut branch.
        downsample: reduce the spatial size inside the block.
        last_layer: with ``downsample``, disable the shortcut and end the
            block with ``nn.AdaptiveMaxPool2d(2)``.
        pool_type: downsampling strategy for non-last blocks:
            ``'convolution'`` (strided ``conv1``), ``'kmax'`` (adaptive max
            pooling to a size looked up from ``in_channels``) or anything
            else (3x3 max pooling with stride 2).
    """

    def __init__(self, in_channels, out_channels, want_shortcut, downsample, last_layer, pool_type):
        super(ConvolutionalBlock, self).__init__()

        self.want_shortcut = want_shortcut
        if self.want_shortcut:
            # Fixed: this was `==` (a comparison), which never registered the
            # module and raised AttributeError on construction.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels,
                          kernel_size=1, stride=2, bias=False),
                nn.BatchNorm2d(out_channels)
            )

        # Fixed: the activation class is `nn.ReLU`; `nn.ReLu` does not exist.
        # NOTE(review): the first BatchNorm is sized with `in_channels` although
        # it receives the output of `conv1` (normally `out_channels`); this only
        # matches when both are equal or pool_type == 'convolution' - confirm.
        self.sequential = nn.Sequential(
            nn.BatchNorm2d(in_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                      kernel_size=3, padding='same', bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )

        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                               kernel_size=3, stride=1, padding=1, bias=False)

        if downsample:
            if last_layer:
                # The last block never uses the shortcut.
                self.want_shortcut = False
                self.sequential.append(nn.AdaptiveMaxPool2d(2))
            else:
                if pool_type == 'convolution':
                    # Downsample with a strided convolution that keeps the
                    # channel count; `sequential` then maps to out_channels.
                    self.conv1 = nn.Conv2d(
                        in_channels=in_channels, out_channels=in_channels, kernel_size=3, stride=2, padding=1, bias=False)
                elif pool_type == 'kmax':
                    # NOTE(review): `dimension` has three entries, so
                    # in_channels == 512 would raise IndexError, and 511 looks
                    # unusual next to 256/128 - confirm the intended sizes.
                    channels = [64, 128, 256, 512]
                    dimension = [511, 256, 128]
                    index = channels.index(in_channels)
                    self.sequential.append(
                        nn.AdaptiveMaxPool2d(dimension[index]))
                else:
                    self.sequential.append(nn.MaxPool2d(
                        kernel_size=3, stride=2, padding=1))

        self.relu = nn.ReLU()

    def forward(self, x):
        """Run the block on ``x``.

        With the shortcut enabled, the input is projected through
        ``self.shortcut`` only when its shape differs from the main branch
        output, then both are summed and passed through ReLU.
        """
        if self.want_shortcut:
            short = x
            out = self.conv1(x)
            out = self.sequential(out)
            if out.shape != short.shape:
                short = self.shortcut(short)
            out = self.relu(short + out)
            return out
        else:
            out = self.conv1(x)
            return self.sequential(out)


class FullyConnectedBlock(nn.Module):
    """Classifier head: three linear layers with ReLU in between.

    Maps a 2048-feature input to ``n_class`` raw scores (logits); no softmax
    is applied inside the block.

    Args:
        n_class: number of output classes.
    """

    def __init__(self, n_class):
        super(FullyConnectedBlock, self).__init__()
        # Fixed: the activation class is `nn.ReLU`; `nn.ReLu` does not exist
        # and raised AttributeError on construction.
        self.sequential = nn.Sequential(
            nn.Linear(2048, 1024),
            nn.ReLU(),
            nn.Linear(1024, 1024),
            nn.ReLU(),
            nn.Linear(1024, n_class),
            # nn.Softmax(dim=1)
        )

    def forward(self, x):
        """Return the class scores for ``x`` (last dimension must be 2048)."""
        return self.sequential(x)


## Utils

In [None]:
class Lion(optim.Optimizer):
    """Optimizer that moves each parameter by the sign of a momentum blend.

    Per parameter and step:
      1. the parameter is scaled by ``1 - lr * weight_decay``;
      2. ``update = exp_avg * beta1 + grad * (1 - beta1)``;
      3. the parameter moves by ``-lr * sign(update)``;
      4. ``exp_avg`` becomes ``exp_avg * beta2 + grad * (1 - beta2)``.

    Args:
        params: parameters or parameter groups to optimize.
        lr: step size, must not be negative.
        betas: pair ``(beta1, beta2)``, each in ``[0, 1)``.
        weight_decay: decay coefficient used in step 1.

    Raises:
        ValueError: if ``lr`` or one of the betas is out of range.
    """

    def __init__(self, params, lr=1e-4, betas=(0.9, 0.99), weight_decay=0.0):
        if not lr >= 0.0:
            raise ValueError('Invalid learning rate: {}'.format(lr))
        if not (betas[0] >= 0.0 and betas[0] < 1.0):
            raise ValueError('Invalid beta parameter at index 0: {}'.format(betas[0]))
        if not (betas[1] >= 0.0 and betas[1] < 1.0):
            raise ValueError('Invalid beta parameter at index 1: {}'.format(betas[1]))
        hyper_parameters = {'lr': lr, 'betas': betas, 'weight_decay': weight_decay}
        super().__init__(params, hyper_parameters)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform one optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss; it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` without a closure.
        """
        closure_result = None
        if closure is not None:
            with torch.enable_grad():
                closure_result = closure()

        for param_group in self.param_groups:
            for param in param_group['params']:
                # Parameters that received no gradient are left untouched.
                if param.grad is None:
                    continue

                # Weight decay is applied directly to the parameter values.
                keep_ratio = 1 - param_group['lr'] * param_group['weight_decay']
                param.data.mul_(keep_ratio)

                gradient = param.grad
                param_state = self.state[param]
                if not param_state:
                    # First step for this parameter: start the average at zero.
                    param_state['exp_avg'] = torch.zeros_like(param)

                momentum = param_state['exp_avg']
                beta1, beta2 = param_group['betas']

                # Only the direction (sign) of the blended value is used.
                blended = momentum * beta1 + gradient * (1 - beta1)
                param.add_(torch.sign(blended), alpha=-param_group['lr'])

                # The running average is refreshed after the parameter moved.
                momentum.mul_(beta2).add_(gradient, alpha=1 - beta2)

        return closure_result

## Functions

In [None]:
def train_fn(epoch, loader, model, optimizer, scheduler, loss_fn, scaler, metric_collection, device):
    """Train ``model`` for one epoch with mixed precision.

    Args:
        epoch: zero-based epoch index, shown as ``epoch + 1`` in the progress bar.
        loader: iterable yielding ``(data, target)`` batches.
        model: network to train; switched to training mode.
        optimizer: optimizer stepped once per batch through ``scaler``.
        scheduler: learning-rate scheduler, stepped once per batch.
        loss_fn: callable ``loss_fn(prediction, target)`` returning the loss.
        scaler: gradient scaler exposing ``scale``, ``step`` and ``update``.
        metric_collection: metric collection updated with every batch; must
            expose a ``"MulticlassAccuracy"`` entry.
        device: device the batches are moved to.

    Returns:
        Tuple ``(train_loss, train_accuracy)``: the mean batch loss and the
        accuracy in percent.
    """
    model.train()
    running_loss = 0

    for (data, target) in tqdm(loader, desc=f"Epoch {epoch + 1}"):
        data = data.to(device, non_blocking=True)
        target = target.to(device, non_blocking=True)

        # forward
        with torch.cuda.amp.autocast():
            prediction = model(data)
            loss = loss_fn(prediction, target)

        # backward
        optimizer.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        # Fixed: the method is `update`; `updata` raised AttributeError.
        scaler.update()
        scheduler.step()
        running_loss += loss.item()
        metric_collection(prediction, target)

    train_loss = running_loss / len(loader)
    # Fixed: the metric is looked up by key. Calling the collection with a
    # string would run a metric update with a missing argument instead.
    train_accuracy = metric_collection["MulticlassAccuracy"].compute().cpu() * 100

    # Clear the accumulated state so the next pass starts from scratch.
    metric_collection.reset()

    return train_loss, train_accuracy


def main(wb, checkpoint_dir, weight_dir, device, num_workers):
    """Build the model and either evaluate it or run the training loop.

    Args:
        wb: run object exposing ``config`` (mapping of hyper-parameters),
            ``watch`` and ``log``; ``resumed`` is read when present.
        checkpoint_dir: directory prefix (ending with a path separator) where
            the training checkpoint is stored.
        weight_dir: directory prefix (ending with a path separator) where the
            best model is stored.
        device: device used for the model and the batches.
        num_workers: number of data-loading workers.

    The function always ends with ``sys.exit()``: after evaluation when
    ``config['evaluation']`` is true, otherwise once training finishes or
    early stopping triggers.
    """
    model = ConvolutionalNeuralNetworks(depth=wb.config['depth'],
                                        n_classes=wb.config['n_class'],
                                        want_shortcut=wb.config['want_shortcut'],
                                        pool_type=wb.config['pool_type']).to(device)

    optimizer = Lion(model.parameters(), lr=wb.config['learning_rate'], weight_decay=wb.config['weight_decay'])
    criterion = nn.CrossEntropyLoss()
    scaler = torch.cuda.amp.GradScaler()
    metric_collection = metrics(wb, device)

    # Fixed: save_checkpoint() writes "<directory>model.pth.tar", so loading
    # must read that file; passing the bare directory to torch.load failed.
    weight_file = "".join([weight_dir, "model.pth.tar"])
    checkpoint_file = "".join([checkpoint_dir, "model.pth.tar"])

    if wb.config['evaluation']:
        load_best_model(torch.load(weight_file), model)
        test_loader = get_loaders(wb.config['batch_size'], num_workers, training=False)
        eval_fn(test_loader, model, criterion, metric_collection, device)
        sys.exit()

    # Run objects without a `resumed` attribute are treated as fresh runs.
    if getattr(wb, 'resumed', False):
        start, monitored_value, count = load_checkpoint(torch.load(checkpoint_file), model, optimizer)
        patience = EarlyStopping('max', wb.config['patience'], count, monitored_value)
    else:
        start = 0
        patience = EarlyStopping('max', wb.config['patience'])

    train_loader, test_loader = get_loaders(wb.config['batch_size'], num_workers)
    # The schedule covers only the epochs that are still to be run.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer,
                                                    max_lr=wb.config['max_lr'],
                                                    steps_per_epoch=len(train_loader),
                                                    epochs=wb.config['num_epochs'] - start)

    wb.watch(model, log="all")

    for epoch in range(start, wb.config['num_epochs']):
        train_loss, train_accuracy = train_fn(epoch, train_loader, model, optimizer, scheduler, criterion, scaler,
                                              metric_collection, device)

        test_loss, test_accuracy = eval_fn(test_loader, model, criterion, metric_collection, device)

        wb.log({'train_loss': train_loss,
                'train_accuracy': train_accuracy,
                'test_loss': test_loss,
                'test_accuracy': test_accuracy
                })

        # save best model: the tracker returns True when accuracy improved
        if patience(test_accuracy):
            wb.log({
                "accuracy_epoch": epoch,
            })
            checkpoint = {"state_dict": model.state_dict()}
            save_checkpoint("=> Best model found", checkpoint, weight_dir)

        # save checkpoint with everything needed to resume
        checkpoint = {
            'start': epoch + 1,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'max_accuracy': patience.baseline,
            'count': patience.count,
        }
        save_checkpoint('=> Saving checkpoint', checkpoint, checkpoint_dir)

        # early stopping; `<=` also stops a run resumed with an exhausted counter
        if patience.count <= 0:
            print('=> Patience finished')
            break

    sys.exit()


def save_checkpoint(string, state, directory):
    """Print ``string`` and store ``state`` with ``torch.save``.

    The target path is ``directory`` immediately followed by
    ``model.pth.tar``; no separator is inserted, so ``directory`` is expected
    to already end with one.
    """
    print(string)
    target_path = directory + "model.pth.tar"
    torch.save(state, target_path)


def load_checkpoint(checkpoint, model, optimizer):
    """Restore ``model`` and ``optimizer`` from a checkpoint mapping.

    Args:
        checkpoint: mapping with the keys ``state_dict``, ``optimizer``,
            ``start``, ``max_accuracy`` and ``count``.
        model: object whose ``load_state_dict`` receives ``state_dict``.
        optimizer: object whose ``load_state_dict`` receives ``optimizer``.

    Returns:
        Tuple ``(start, max_accuracy, count)`` read from the checkpoint.
    """
    print("=> Loading checkpoint")
    # The model is restored first, then the optimizer.
    for component, key in ((model, 'state_dict'), (optimizer, 'optimizer')):
        component.load_state_dict(checkpoint[key])
    resume_fields = ('start', 'max_accuracy', 'count')
    return tuple(checkpoint[field] for field in resume_fields)


def create_directory_if_does_not_exist(dirs):
    """Create every path in ``dirs`` that does not exist yet.

    Missing parent directories are created as well. Paths that already exist
    are skipped.

    Args:
        dirs: iterable of directory paths.
    """
    for path in dirs:
        if not os.path.exists(path):
            # exist_ok avoids a crash when the directory appears between the
            # existence check and its creation.
            os.makedirs(path, exist_ok=True)


def load_best_model(checkpoint, model):
    """Load the weights stored under ``checkpoint['state_dict']`` into ``model``."""
    print("Loading best model")
    best_weights = checkpoint['state_dict']
    model.load_state_dict(best_weights)


def get_loaders(batch_size, num_workers, training=True):
    """Build the CIFAR-10 data loaders.

    The dataset is downloaded into ``./Dataset/`` when missing.

    Args:
        batch_size: batch size used by every loader.
        num_workers: number of worker processes per loader.
        training: when true, also build the training loader.

    Returns:
        ``(train_loader, test_loader)`` when ``training`` is true, otherwise
        only ``test_loader``.
    """
    normalize = T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    # Fixed: the test set must not go through random augmentation, otherwise
    # the evaluation is neither deterministic nor comparable between epochs.
    test_transform = T.Compose([
        T.ToTensor(),
        normalize])

    # persistent_workers is rejected by DataLoader when there are no workers.
    keep_workers = num_workers > 0

    test_ds = CIFAR10(
        root='./Dataset/',
        train=False,
        download=True,
        transform=test_transform
    )

    # Fixed: the shuffle flags were inverted (test shuffled, train not).
    test_loader = DataLoader(
        test_ds,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
        shuffle=False,
        persistent_workers=keep_workers,
    )

    if training:
        # Augmentation is applied to the training images only.
        train_transform = T.Compose([
            T.AutoAugment(policy=T.AutoAugmentPolicy.CIFAR10),
            T.ToTensor(),
            normalize])

        train_ds = CIFAR10(
            root='./Dataset/',
            train=True,
            download=True,
            transform=train_transform
        )

        train_loader = DataLoader(
            train_ds,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory=True,
            shuffle=True,
            persistent_workers=keep_workers,
        )
        return train_loader, test_loader

    return test_loader


def metrics(wb, device):
    """Create the metric collection and declare the run's summary metrics.

    Args:
        wb: run object exposing ``config`` and ``define_metric``.
        device: device the accuracy metric is moved to.

    Returns:
        A ``torchmetrics.MetricCollection`` holding one ``MulticlassAccuracy``.
    """
    # Fixed: the configurations in this file and the model construction in
    # main() use the key 'n_class'; reading only 'num_class' raised KeyError.
    # The old key is still accepted as a fallback.
    try:
        num_classes = wb.config['n_class']
    except KeyError:
        num_classes = wb.config['num_class']

    metric_collection = torchmetrics.MetricCollection([
        torchmetrics.classification.MulticlassAccuracy(num_classes=num_classes).to(device=device)
    ])
    wb.define_metric("test_loss", summary="min")
    wb.define_metric("test_accuracy", summary="max")
    wb.define_metric("accuracy_epoch")
    return metric_collection


def eval_fn(loader, model, criterion, metric_collection, device):
    """Evaluate ``model`` on ``loader`` without computing gradients.

    Args:
        loader: iterable yielding ``(data, target)`` batches.
        model: network to evaluate; switched to evaluation mode.
        criterion: callable ``criterion(prediction, target)`` returning the loss.
        metric_collection: metric collection updated with every batch; must
            expose a ``"MulticlassAccuracy"`` entry.
        device: device the batches are moved to.

    Returns:
        Tuple ``(loss, accuracy)``: the mean batch loss and the accuracy in
        percent. Both are also printed.
    """
    model.eval()
    running_loss = 0

    with torch.no_grad():
        for data, target in loader:
            data = data.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            prediction = model(data)
            loss = criterion(prediction, target)
            running_loss += loss.item()
            metric_collection(prediction, target)

    loss = running_loss / len(loader)
    # Fixed: the key was misspelled ('MulticlassAccurcay') and raised KeyError.
    accuracy = metric_collection['MulticlassAccuracy'].compute().cpu() * 100

    print(f"Got on test set --> Accuracy: {accuracy:.3f} and Loss: {loss:.3f}")

    # Clear the accumulated state so the next pass starts from scratch.
    metric_collection.reset()
    return loss, accuracy


def save_plot(train_l, train_a, test_l, test_a):
    """Save the accuracy and loss curves as images under ``result/``.

    Args:
        train_l: training loss per epoch.
        train_a: training accuracy per epoch.
        test_l: test loss per epoch.
        test_a: test accuracy per epoch.

    Two figures are written: ``result/accuracy`` and ``result/losses``.
    """
    # savefig does not create missing folders, and nothing else in this file
    # creates 'result', so make sure it exists before writing.
    os.makedirs('result', exist_ok=True)

    # (train series, test series, y label, title, output path)
    figures = (
        (train_a, test_a, "accuracy", "Train Vs. Valid Accuracy", 'result/accuracy'),
        (train_l, test_l, "losses", "Train Vs. Valid Losses", 'result/losses'),
    )

    for train_series, test_series, y_label, title, output_path in figures:
        plt.plot(train_series, "-")
        plt.plot(test_series, "-")
        plt.xlabel("epoch")
        plt.ylabel(y_label)
        plt.legend(["Train", "Valid"])
        plt.title(title)
        plt.savefig(output_path)
        # Close the figure so the next curves start on a clean canvas.
        plt.close()


class EarlyStopping:
    """Track a monitored value and count down while it does not improve.

    Calling the instance with a new value returns ``True`` when the value
    improved on ``baseline`` (which is then updated and ``count`` is reset to
    ``patience``), otherwise decrements ``count`` and returns ``False``.
    Callers stop training once ``count`` reaches zero.

    Args:
        mod: ``'max'`` when larger values are better, ``'min'`` when smaller
            values are better.
        patience: number of non-improving calls tolerated.
        count: remaining calls when resuming; defaults to ``patience``.
        baseline: best value seen so far when resuming; defaults to ``0`` for
            ``'max'`` and to positive infinity for ``'min'``.

    Raises:
        ValueError: if ``mod`` is neither ``'max'`` nor ``'min'``.
    """

    def __init__(self, mod, patience, count=None, baseline=None):
        self.patience = patience
        self.count = patience if count is None else count
        if mod == 'max':
            # Fixed: the supplied baseline was ignored and reset to 0, so a
            # resumed run treated its first result as a new best.
            self.baseline = 0 if baseline is None else baseline
            self.operation = self.max
        elif mod == 'min':
            # Fixed: a missing baseline was left as None, which made the
            # first comparison raise TypeError.
            self.baseline = float('inf') if baseline is None else baseline
            self.operation = self.min
        else:
            # Fail here instead of with an AttributeError on the first call.
            raise ValueError("mod must be 'max' or 'min', got {!r}".format(mod))

    def max(self, monitored_value):
        """Update for a value where larger is better; return True on improvement."""
        if monitored_value > self.baseline:
            self.baseline = monitored_value
            self.count = self.patience
            return True
        else:
            self.count -= 1
            return False

    def min(self, monitored_value):
        """Update for a value where smaller is better; return True on improvement."""
        if monitored_value < self.baseline:
            self.baseline = monitored_value
            self.count = self.patience
            return True
        else:
            self.count -= 1
            return False

    def __call__(self, monitored_value):
        """Dispatch to the comparison selected by ``mod``."""
        return self.operation(monitored_value)



## Main

In [None]:
""" 
wab = wandb.init(
    # set the wandb project where this run will be logged
    project="DLA - Lab1",
    # group="Experiment",
    tags=[],
    resume=False,
    name="depth-48-skip",
    config={
        # model parameters
        "architecture": "Convolutional Neural Networks",
        "depth": 48,
        "n_class": 10,
        "want_shortcut": True,
        "pool_type": "max",

        # datasets
        "dataset": "CIFAR-10",

        # hyperparameters
        "learning_rate": 5e-5,
        "batch_size": 2048,
        "optimizer": "Lion",
        "weight_decay": 1e-2,
        "scheduler": "One Cycle Learning",
        "max_lr": 5e-4,
        "num_epochs": 200,
        "patience": 20,

        # run type
        "evaluation": False,
    })
 """

In [None]:
class Parameters:
    """Offline stand-in for the run object passed to ``main``.

    ``main`` and ``metrics`` read ``config`` and ``resumed`` and call
    ``define_metric``, ``watch`` and ``log`` on the run object. This class
    holds the configuration and accepts those calls without doing anything,
    so training can run without an experiment tracker.

    Args:
        dict: mapping of configuration values, exposed as ``self.config``.
    """

    # A locally created configuration never continues an earlier run.
    resumed = False

    def __init__(self, dict):
        self.config = dict

    def define_metric(self, *args, **kwargs):
        """Accept a metric declaration and ignore it."""

    def watch(self, *args, **kwargs):
        """Accept a model-watching request and ignore it."""

    def log(self, *args, **kwargs):
        """Accept logged values and ignore them."""

In [None]:
# Run configuration wrapped by `Parameters` and read through `wb.config`.
parameters = dict(
    # model parameters
    architecture="CNN",
    depth=48,
    n_class=10,
    want_shortcut=True,
    pool_type="max",

    # dataset
    dataset="CIFAR-10",

    # hyperparameters
    learning_rate=5e-5,
    batch_size=2048,
    optimizer="Lion",
    weight_decay=1e-2,
    scheduler="One Cycle Learning",
    max_lr=5e-4,
    num_epochs=200,
    patience=20,

    # run type
    evaluation=False,
)


In [None]:
# Entry point: wrap the configuration, prepare the output folders, pick the
# device and start the run.
wab = Parameters(parameters)
n_workers = 16

if torch.cuda.is_available():
    dev = "cuda"
else:
    dev = "cpu"

# Both values are directory prefixes that end with a separator.
checkpoint_dir = "checkpoint/"
results_dir = "results/" + wab.config['architecture'] + '/'
create_directory_if_does_not_exist([checkpoint_dir, results_dir])

main(wab, checkpoint_dir, results_dir, dev, n_workers)