In [None]:
!pip install --upgrade -q wandb

In [None]:
import random, os
import numpy as np
import wandb

import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize
import torch.nn.functional as F
from torch.utils.data import default_collate

import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18
from torchvision.ops import MLP

from tqdm import tqdm, trange
from kaggle_secrets import UserSecretsClient
from typing import Optional, List, Dict, Any, Tuple

import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# Log in to Weights & Biases with the key stored under "wandb_key" in the
# Kaggle secrets. A failure is reported but not re-raised.
try:
    user_secrets = UserSecretsClient()
    api_key = user_secrets.get_secret("wandb_key")
    wandb.login(key=api_key)
except Exception as e:
    print('Ups')
    # Python exceptions have no `.what()` method; calling it raised an
    # AttributeError from inside the handler. Printing the exception itself
    # shows its message.
    print(e)

### CFG

In [None]:
class CFG:
    """Settings shared by the cells of this notebook (read as class attributes)."""

    # Reproducibility: value handed to seed_everything().
    seed = 42

    # Hardware: first CUDA device when one is available, otherwise the CPU.
    device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

    # Data loading: used for both the train and the test DataLoader.
    batch_size = 512
    num_workers = 1

    # Optimisation: forwarded to train_one_cycle().
    epochs = 20
    max_lr = 0.05
    weight_decay = 5e-4
    grad_clip = 0.1

### Seed

In [None]:
def seed_everything(seed: int):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU and
    all CUDA devices), exports ``PYTHONHASHSEED`` and configures cuDNN for
    deterministic behaviour.

    Args:
        seed: Value used for all generators.
    """
    random.seed(seed)
    # Hash randomisation is fixed when an interpreter starts, so this only
    # influences Python processes launched after this point.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # `manual_seed_all` seeds every CUDA device, not only the current one;
    # it is silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN auto-tuner may select different kernels from run to run,
    # which would undermine the deterministic flag above.
    torch.backends.cudnn.benchmark = False
    
# Apply the configured seed once, ahead of the dataset and model cells below.
seed_everything(CFG.seed)

### Load CIFAR10

In [None]:
# (per-channel mean, per-channel std) handed to transforms.Normalize.
stats = ((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

# Training pipeline: random crop with reflect padding, horizontal flip and a
# small rotation, followed by tensor conversion and normalisation.
transform_train = transforms.Compose([
    transforms.RandomCrop(size=32, padding=4, padding_mode='reflect'),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(degrees=(-7, 7)),
    transforms.ToTensor(),
    transforms.Normalize(mean=stats[0], std=stats[1]),
])

# Evaluation pipeline: tensor conversion and normalisation only.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=stats[0], std=stats[1]),
])

# Both splits are downloaded into ./data if they are not there yet.
train_set = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform_train,
)
test_set = torchvision.datasets.CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform_test,
)

# Short class names; their count sets the size of the model's output layer.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

In [None]:
# Settings common to both loaders.
loader_kwargs = dict(batch_size=CFG.batch_size, num_workers=CFG.num_workers)

# Only the training loader shuffles; the test loader keeps a fixed order.
train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, **loader_kwargs)
test_loader = torch.utils.data.DataLoader(test_set, shuffle=False, **loader_kwargs)

## Basic ResNet18

In [None]:
def get_simple_model():
    """Build an untrained ResNet-18 with a classifier head sized for ``classes``.

    Returns:
        The network, moved to ``CFG.device``.
    """
    net = resnet18(weights=None)
    # Swap the stock classifier for one with a single output per class name.
    head_inputs = net.fc.in_features
    net.fc = nn.Linear(head_inputs, len(classes))
    net = net.to(CFG.device)
    return net

### Training step

In [None]:
def accuracy(outputs: torch.Tensor, labels: torch.Tensor):
    """Fraction of rows of ``outputs`` whose arg-max over dim 1 equals the label.

    Args:
        outputs: Scores indexed as ``[row, class]``.
        labels: Target class index for each row.

    Returns:
        A 0-dim tensor created from the Python float ``correct / rows``.
    """
    predicted = torch.max(outputs, dim=1).indices
    num_correct = (predicted == labels).sum().item()
    return torch.tensor(num_correct / len(predicted))

def validation_step(model: nn.Module, batch: List[torch.Tensor]):
    """Score a single ``(images, labels)`` batch.

    Args:
        model: Network to evaluate.
        batch: Two-element sequence ``(images, labels)``; both are moved to
            ``CFG.device`` before the forward pass.

    Returns:
        ``{'val_loss': detached cross-entropy tensor, 'val_acc': accuracy tensor}``.
    """
    images, labels = batch
    images, labels = images.to(CFG.device), labels.to(CFG.device)
    logits = model(images)
    return {
        'val_loss': F.cross_entropy(logits, labels).detach(),
        'val_acc': accuracy(logits, labels),
    }

def validation_epoch(outputs: List[Dict[str, Any]]):
    """Reduce per-batch validation results to epoch-level metrics and log them.

    Each metric is the plain mean over batches, so every batch carries the
    same weight regardless of how many samples it contained.

    Args:
        outputs: One dict per batch holding tensors under ``'val_loss'`` and
            ``'val_acc'``. Must be non-empty (``torch.stack`` rejects an
            empty list).

    Returns:
        ``{'val_loss': float, 'val_acc': float}``.
    """
    epoch_loss = torch.stack([x['val_loss'] for x in outputs]).mean().item()
    epoch_acc = torch.stack([x['val_acc'] for x in outputs]).mean().item()

    # Every wandb.log call advances the run's step counter, so logging the two
    # metrics separately placed them at different steps. A single call keeps
    # the loss and accuracy of one epoch aligned on the same step.
    wandb.log({"Val loss": epoch_loss, "Val accuracy": epoch_acc})

    return {'val_loss': epoch_loss, 'val_acc': epoch_acc}

@torch.no_grad()
def evaluate(model: nn.Module, val_loader: torch.utils.data.DataLoader):
    """Run ``model`` over ``val_loader`` with gradients disabled.

    The model is switched to eval mode and left in that mode on return.

    Returns:
        The dict produced by ``validation_epoch`` for the collected batches.
    """
    model.eval()
    step_results = []
    for batch in val_loader:
        step_results.append(validation_step(model, batch))
    return validation_epoch(step_results)

def get_lr(optimizer: torch.optim.Optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    Returns ``None`` when the optimizer exposes no parameter groups.
    """
    first_group = next(iter(optimizer.param_groups), None)
    if first_group is None:
        return None
    return first_group['lr']

def train_one_cycle(
    model: nn.Module, 
    train_loader: torch.utils.data.DataLoader,
    val_loader: torch.utils.data.DataLoader,
    epochs: int,
    max_lr: float,
    weight_decay: float=0.0,
    grad_clip: Optional[float]=None
):
    """Train ``model`` with AdamW, a one-cycle LR schedule and CUDA AMP.

    After every epoch the model is evaluated on ``val_loader``. The training
    loss and learning rate are logged to wandb once per batch.

    Args:
        model: Network to train; batches are moved to ``CFG.device``.
        train_loader: Yields ``(images, labels)`` training batches.
        val_loader: Loader passed to ``evaluate`` after each epoch.
        epochs: Number of passes over ``train_loader``.
        max_lr: Peak learning rate of the one-cycle schedule.
        weight_decay: AdamW weight decay.
        grad_clip: If truthy, every gradient element is clamped to
            ``[-grad_clip, grad_clip]`` before the optimizer step.

    Returns:
        A list with one dict per epoch: the result of ``evaluate`` plus the
        mean training loss under ``'train_loss'``.
    """
    history = []

    wandb.watch(model, log_freq=50)
    torch.cuda.empty_cache()

    scaler = torch.cuda.amp.GradScaler(enabled=True)
    optimizer = torch.optim.AdamW(model.parameters(), max_lr, weight_decay=weight_decay)
    # The schedule spans epochs * len(train_loader) steps and is advanced once
    # per batch inside the loop below.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr, epochs=epochs,
                                                    steps_per_epoch=len(train_loader))
    for _ in trange(epochs):
        # evaluate() leaves the model in eval mode, so re-enable training mode
        # at the start of every epoch.
        model.train()
        train_losses = []
        for images, labels in train_loader:
            images = images.to(CFG.device)
            labels = labels.to(CFG.device)

            with torch.cuda.amp.autocast():
                out = model(images)
                loss = F.cross_entropy(out, labels)

            # Keep only a graph-free copy for bookkeeping: storing `loss`
            # itself would keep each batch's autograd graph referenced until
            # the end of the epoch.
            batch_loss = loss.detach()
            train_losses.append(batch_loss)

            # Each wandb.log call advances the step counter; logging both
            # values together keeps them on the same step.
            wandb.log({"Train loss": batch_loss.item(), "lr": get_lr(optimizer)})

            scaler.scale(loss).backward()
            # Gradients must be unscaled before clipping so that the threshold
            # applies to their true magnitude.
            scaler.unscale_(optimizer)
            if grad_clip:
                nn.utils.clip_grad_value_(model.parameters(), grad_clip)
            scaler.step(optimizer)
            scaler.update()
            scheduler.step()
            optimizer.zero_grad()

        result = evaluate(model, val_loader)
        result['train_loss'] = torch.stack(train_losses).mean().item()
        history.append(result)
    return history

In [None]:
# Open the wandb run for the baseline ResNet-18 experiment; the handle is
# kept so the run can be closed explicitly later.
run = wandb.init(
    project='ResNet task',
    name='ResNet basic',
    job_type='Train',
    config={},
    anonymous='must',
    reinit=True,
)

In [None]:
# Baseline experiment: a fresh ResNet-18 trained with the CFG settings.
# The CIFAR-10 test split is used as the validation set.
model = get_simple_model()

history = train_one_cycle(
    model=model,
    train_loader=train_loader,
    val_loader=test_loader,
    epochs=CFG.epochs,
    max_lr=CFG.max_lr,
    weight_decay=CFG.weight_decay,
    grad_clip=CFG.grad_clip,
)

In [None]:
# Close the run opened for the baseline experiment.
run.finish()

## ResNet18 with parametrization

In [None]:
class MLP_param(nn.Module):
    """Weight parametrization that generates a conv kernel from coordinates.

    When registered on a conv layer's ``weight``, ``forward`` receives the
    stored weight tensor but uses only its shape, dtype and device. The
    returned weight is produced by ``self.mlp`` evaluated on a 2-D grid over
    the first two weight dimensions, so the values of the stored tensor do
    not influence the result.

    Args:
        kernel_size: ``(height, width)`` of the kernel; the MLP emits
            ``height * width`` values per grid point.
        input_size: Number of coordinates per grid point. The grid built in
            ``forward`` is a cartesian product of two axes, so the reshape
            there only succeeds for the default of 2.
    """
    def __init__(self, kernel_size: Tuple[int, int], input_size: int=2):
        super().__init__()
        self.input_size = input_size
        self.kernel_size = kernel_size

        # NOTE(review): with a single entry in the hidden-channel list,
        # torchvision's MLP appears to build only the final Linear (+ Dropout),
        # in which case the GELU is never applied - confirm against torchvision.
        self.mlp = MLP(
            self.input_size,
            [self.kernel_size[0] * self.kernel_size[1]], 
            activation_layer=torch.nn.GELU
        ).to(CFG.device)
        # NOTE(review): the original followed this with
        # `self.mlp.requires_grad = False`. On an nn.Module that merely sets a
        # plain attribute and freezes nothing (freezing is `requires_grad_(False)`),
        # so the MLP has always been trainable. The misleading no-op was
        # dropped; actually freezing the MLP would leave the parametrized conv
        # weights with nothing to learn, since `forward` ignores the stored values.

    def forward(self, weight: torch.Tensor):
        """Return a generated tensor with the same shape as ``weight``."""
        shape = weight.shape
        # One (u, v) coordinate in [-1, 1]^2 for every pair of indices along
        # the first two weight dimensions.
        grid = torch.cartesian_prod(
            torch.linspace(-1, 1, shape[0], dtype=weight.dtype, device=weight.device),
            torch.linspace(-1, 1, shape[1], dtype=weight.dtype, device=weight.device)
        ).reshape(shape[0], shape[1], self.input_size)
        # Evaluate the MLP in the weight's own dtype even when the caller runs
        # under autocast. The original passed `dtype=weight.dtype`, but float32
        # is not a supported autocast target (recent PyTorch warns and disables
        # autocast); disabling it explicitly states the same intent.
        with torch.cuda.amp.autocast(enabled=False):
            result = self.mlp(grid).reshape(shape)
        return result

In [None]:
# Start from a fresh ResNet-18 and attach an MLP_param parametrization to the
# weight of every Conv2d layer it contains.
model = get_simple_model()

conv_layers = [layer for layer in model.modules() if isinstance(layer, nn.Conv2d)]
for conv in conv_layers:
    parametrize.register_parametrization(conv, "weight", MLP_param(conv.kernel_size))

In [None]:
# Open the wandb run for the parametrized ResNet-18 experiment.
wandb.init(
    project='ResNet task',
    name='ResNet with parametrization',
    job_type='Train',
    config={},
    anonymous='must',
)

In [None]:
# Train the parametrized network with the same settings as the baseline run.
history = train_one_cycle(
    model=model,
    train_loader=train_loader,
    val_loader=test_loader,
    epochs=CFG.epochs,
    max_lr=CFG.max_lr,
    weight_decay=CFG.weight_decay,
    grad_clip=CFG.grad_clip,
)