In [28]:
!pip install wandb -Uq
!pip install sam-pytorch
!pip install utility



In [46]:
import os
from multiprocessing import freeze_support

import torch
import wandb
from torchvision.datasets import CIFAR10
from torchvision.transforms import v2
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from sam import SAM


def get_default_device():
    """Pick the best available torch device.

    Preference order is CUDA, then Apple's Metal backend (MPS), then CPU.

    Returns:
        torch.device: the selected device.
    """
    if torch.cuda.is_available():
        return torch.device('cuda')
    # ``torch.backends.mps`` does not exist on older torch builds, so look it
    # up defensively instead of assuming the attribute is there.
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is not None and mps_backend.is_available():
        # The device type string is 'mps'; any other spelling makes
        # torch.device raise.
        return torch.device('mps')
    return torch.device('cpu')


class CachedDataset(Dataset):
    """Dataset wrapper that can materialize another dataset in memory.

    With ``cache=True`` (the default) the wrapped dataset is iterated once at
    construction time and its items are stored in a tuple, so later lookups
    are plain tuple indexing. With ``cache=False`` the wrapped dataset is kept
    as-is and every lookup is forwarded to it.
    """

    def __init__(self, dataset, cache=True):
        # One full pass over ``dataset`` happens here when caching is enabled.
        self.dataset = tuple(dataset) if cache else dataset

    def __len__(self):
        """Return the number of stored (or wrapped) items."""
        return len(self.dataset)

    def __getitem__(self, i):
        """Return item ``i`` from the cached tuple or the wrapped dataset."""
        return self.dataset[i]


class MLP(torch.nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        input_size: number of input features of the first linear layer.
        hidden_size: width of the hidden layer.
        output_size: number of output features of the second linear layer.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = torch.nn.Linear(input_size, hidden_size)
        self.fc2 = torch.nn.Linear(hidden_size, output_size)
        # In-place ReLU overwrites the hidden activation instead of allocating
        # a new tensor for it.
        self.relu = torch.nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply fc1, the ReLU, then fc2 to ``x`` and return the result."""
        hidden = self.fc1(x)
        hidden = self.relu(hidden)
        return self.fc2(hidden)

class SAM(torch.optim.Optimizer):
    """Sharpness-Aware Minimization (SAM) optimizer wrapping a base optimizer.

    An update is split in two phases: ``first_step`` saves the current
    parameters and moves them along the (optionally parameter-scaled) gradient
    direction; ``second_step`` restores the saved parameters and then calls the
    base optimizer's ``step()`` with whatever gradients are stored at that
    moment. ``step(closure)`` runs both phases with a closure call in between.

    NOTE(review): the file also does ``from sam import SAM`` in its import
    block; this later class definition rebinds the name ``SAM`` — confirm which
    implementation is intended.
    """

    def __init__(self, params, base_optimizer, rho=0.05, adaptive=False, **kwargs):
        """Build the wrapper and the wrapped base optimizer.

        Args:
            params: parameters (or parameter groups) to optimize.
            base_optimizer: callable invoked as
                ``base_optimizer(self.param_groups, **kwargs)`` to create the
                wrapped optimizer (e.g. an optimizer class).
            rho: scale of the perturbation applied in ``first_step``; must be
                non-negative (checked with ``assert``).
            adaptive: when true, the perturbation is multiplied element-wise by
                ``p ** 2`` and the gradient norm uses ``|p| * grad``.
            **kwargs: stored in this optimizer's defaults and forwarded to
                ``base_optimizer``.
        """
        assert rho >= 0.0, f"Invalid rho, should be non-negative: {rho}"

        defaults = dict(rho=rho, adaptive=adaptive, **kwargs)
        super(SAM, self).__init__(params, defaults)

        self.base_optimizer = base_optimizer(self.param_groups, **kwargs)
        # Use the base optimizer's param_groups object from here on, and merge
        # its defaults into ours.
        self.param_groups = self.base_optimizer.param_groups
        self.defaults.update(self.base_optimizer.defaults)

    @torch.no_grad()
    def first_step(self, zero_grad=False):
        """Save the parameters and perturb them along the gradient.

        For every parameter that has a gradient, a clone is stored under
        ``self.state[p]["old_p"]`` and ``e_w`` is added in place, where
        ``e_w = (p ** 2 if adaptive else 1.0) * grad * rho / (grad_norm + 1e-12)``.

        Args:
            zero_grad: when true, call ``self.zero_grad()`` afterwards.
        """
        grad_norm = self._grad_norm()
        for group in self.param_groups:
            # The 1e-12 term keeps the division finite when the norm is zero.
            scale = group["rho"] / (grad_norm + 1e-12)

            for p in group["params"]:
                if p.grad is None: continue
                self.state[p]["old_p"] = p.data.clone()
                e_w = (torch.pow(p, 2) if group["adaptive"] else 1.0) * p.grad * scale.to(p)
                p.add_(e_w)  # climb to the local maximum "w + e(w)"

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def second_step(self, zero_grad=False):
        """Restore the saved parameters and step the base optimizer.

        Parameters without a gradient are skipped when restoring.

        Args:
            zero_grad: when true, call ``self.zero_grad()`` afterwards.
        """
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None: continue
                p.data = self.state[p]["old_p"]  # get back to "w" from "w + e(w)"

        self.base_optimizer.step()  # do the actual "sharpness-aware" update

        if zero_grad: self.zero_grad()

    @torch.no_grad()
    def step(self, closure=None):
        """Run ``first_step``, the closure, then ``second_step``.

        Args:
            closure: required callable; it is re-wrapped with
                ``torch.enable_grad()`` because this method runs under
                ``no_grad``. Its return value is not used here.

        Raises:
            AssertionError: if ``closure`` is ``None``.
        """
        assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
        closure = torch.enable_grad()(closure)  # the closure should do a full forward-backward pass

        # Gradients are cleared after the perturbation so the closure's
        # backward pass starts from empty gradients.
        self.first_step(zero_grad=True)
        closure()
        self.second_step()

    def _grad_norm(self):
        """Return the 2-norm over all per-parameter gradient 2-norms.

        With ``adaptive`` set on a group, each gradient is multiplied
        element-wise by ``|p|`` before its norm is taken. Parameters without a
        gradient are skipped.

        NOTE(review): if no parameter has a gradient, the list handed to
        ``torch.stack`` is empty — confirm callers always run a backward pass
        before ``first_step``.
        """
        shared_device = self.param_groups[0]["params"][0].device  # put everything on the same device, in case of model parallelism
        norm = torch.norm(
                    torch.stack([
                        ((torch.abs(p) if group["adaptive"] else 1.0) * p.grad).norm(p=2).to(shared_device)
                        for group in self.param_groups for p in group["params"]
                        if p.grad is not None
                    ]),
                    p=2
               )
        return norm

    def load_state_dict(self, state_dict):
        """Load optimizer state, then re-point the base optimizer at our groups."""
        super().load_state_dict(state_dict)
        # Point the base optimizer at this optimizer's param_groups again
        # after loading.
        self.base_optimizer.param_groups = self.param_groups

def accuracy(output, labels):
    """Return the share of entries in ``output`` that equal ``labels``.

    Mismatches are counted over every element, while the denominator is
    ``len(output)``, so both tensors are expected to be 1-D and of equal
    length.

    Raises:
        ZeroDivisionError: if ``output`` is empty.
    """
    total = len(output)
    mismatches = (output != labels).sum().item()
    return (total - mismatches) / total


def train(model, train_loader, criterion, optimizer, device):
    """Train ``model`` for one pass over ``train_loader``.

    For each batch a forward/backward pass is run and the optimizer is
    stepped. Optimizers exposing ``first_step`` and ``second_step`` (the
    two-phase SAM interface) get a closure that repeats the forward/backward
    pass; all others are stepped without arguments.

    Args:
        model: module to train; switched to training mode.
        train_loader: iterable yielding ``(data, labels)`` batches.
        criterion: loss callable invoked as ``criterion(output, labels)``.
        optimizer: optimizer updating the model's parameters.
        device: device the batches are moved to.

    Returns:
        tuple: ``(accuracy, batches_loss)`` where ``accuracy`` is the fraction
        of correctly classified samples rounded to 4 decimals (``0.0`` for an
        empty loader) and ``batches_loss`` is the list of per-batch loss
        values as floats. The recorded loss is the one from the first forward
        pass of each batch.
    """
    model.train()

    # Decide from the optimizer itself rather than from global sweep state, so
    # the function also works outside a wandb run.
    needs_closure = hasattr(optimizer, "first_step") and hasattr(optimizer, "second_step")

    correct = 0
    seen = 0
    batches_loss = []

    for data, labels in train_loader:
        data = data.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        output = model(data)
        loss = criterion(output, labels)
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(model.parameters(), 5)

        if needs_closure:
            def closure():
                # Second forward/backward pass, run by the optimizer at the
                # perturbed parameters.
                perturbed_output = model(data)
                perturbed_loss = criterion(perturbed_output, labels)
                perturbed_loss.backward()
                return perturbed_loss

            optimizer.step(closure)
        else:
            optimizer.step()

        optimizer.zero_grad(set_to_none=True)

        # argmax over the class dimension is unaffected by softmax, and keeping
        # the batch dimension (no squeeze) makes batches of size 1 work.
        predictions = output.detach().argmax(dim=1)
        correct += (predictions == labels).sum().item()
        seen += labels.numel()

        batches_loss.append(loss.item())

    epoch_accuracy = correct / seen if seen else 0.0
    return round(epoch_accuracy, 4), batches_loss


def val(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Args:
        model: module to evaluate; switched to evaluation mode.
        val_loader: iterable yielding ``(data, labels)`` batches.
        criterion: loss callable invoked as ``criterion(output, labels)``.
        device: device the batches are moved to.

    Returns:
        tuple: ``(accuracy, val_loss)`` where ``accuracy`` is the fraction of
        correctly classified samples rounded to 4 decimals (``0.0`` for an
        empty loader) and ``val_loss`` is the SUM of the per-batch loss values
        (not their mean).
    """
    model.eval()

    correct = 0
    seen = 0
    val_loss = 0

    # Neither the forward pass nor the loss needs an autograd graph here.
    with torch.no_grad():
        for data, labels in val_loader:
            data = data.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            output = model(data)
            val_loss += criterion(output, labels).item()

            # argmax over the class dimension is unaffected by softmax, and
            # keeping the batch dimension (no squeeze) makes batches of size 1
            # work.
            predictions = output.argmax(dim=1)
            correct += (predictions == labels).sum().item()
            seen += labels.numel()

    epoch_accuracy = correct / seen if seen else 0.0
    return round(epoch_accuracy, 4), val_loss


def do_epoch(model, train_loader, val_loader, criterion, optimizer, device):
    """Run one training pass followed by one validation pass.

    Returns:
        tuple: ``(train_accuracy, val_accuracy, batches_loss, val_loss)`` as
        produced by ``train`` and ``val`` respectively.
    """
    train_result = train(model, train_loader, criterion, optimizer, device)
    val_result = val(model, val_loader, criterion, device)

    train_accuracy, batches_loss = train_result
    val_accuracy, val_loss = val_result
    return train_accuracy, val_accuracy, batches_loss, val_loss


def get_model_norm(model):
    """Return the sum of ``torch.norm`` over every parameter of ``model``.

    This adds up the per-parameter norms; it is not the norm of all
    parameters taken together. For a model without parameters the result is
    the float ``0.0``; otherwise it is a tensor.
    """
    per_param_norms = (torch.norm(weights) for weights in model.parameters())
    return sum(per_param_norms, 0.0)


def _build_optimizer(model, config):
    """Create the optimizer selected by the integer ``config.optimizer``.

    Ids 0-3 map to SGD, Adam, RMSprop and Adagrad; ids 4 and 5 wrap SGD with
    momentum in ``SAM`` (id 5 with ``rho=2.0`` and ``adaptive=True``).

    Raises:
        ValueError: if ``config.optimizer`` is not one of the known ids.
    """
    learning_rate = config.learning_rate
    plain_optimizers = {
        0: torch.optim.SGD,
        1: torch.optim.Adam,
        2: torch.optim.RMSprop,
        3: torch.optim.Adagrad,
    }
    choice = config.optimizer
    if choice in plain_optimizers:
        return plain_optimizers[choice](model.parameters(), lr=learning_rate)
    if choice == 4:
        return SAM(model.parameters(), torch.optim.SGD, lr=learning_rate, momentum=0.9)
    if choice == 5:
        return SAM(model.parameters(), torch.optim.SGD, rho=2.0, adaptive=True,
                   lr=learning_rate, momentum=0.9)
    raise ValueError(f"Unknown optimizer id: {choice!r}")


def main(device=get_default_device(), config=None):
    """Train an MLP on grayscale 28x28 CIFAR-10 inside a wandb run.

    Hyperparameters ``batch_size``, ``learning_rate`` and ``optimizer`` are
    read from ``wandb.config`` after ``wandb.init``. Metrics are written to
    TensorBoard and the validation accuracy is logged to wandb every epoch.

    Args:
        device: device used for the model and the batches. The default is
            evaluated once, when the function is defined.
        config: optional configuration forwarded to ``wandb.init``.

    Raises:
        ValueError: if the configured optimizer id is unknown.
    """
    transforms = [
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Resize((28, 28), antialias=True),
        v2.Grayscale(),
        torch.flatten,
    ]

    data_path = '../data'
    train_dataset = CIFAR10(root=data_path, train=True, transform=v2.Compose(transforms), download=True)
    val_dataset = CIFAR10(root=data_path, train=False, transform=v2.Compose(transforms), download=True)
    # Run the transform pipeline once up front and keep the results in memory.
    train_dataset = CachedDataset(train_dataset)
    val_dataset = CachedDataset(val_dataset)

    model = MLP(784, 100, 10)
    model = model.to(device)

    epochs = 60
    val_batch_size = 500
    num_workers = 2
    persistent_workers = (num_workers != 0)
    # Pinned host memory only helps when copying to a CUDA device.
    pin_memory = device.type == 'cuda'

    with wandb.init(config=config):
        config = wandb.config
        train_loader = DataLoader(train_dataset, shuffle=True, pin_memory=pin_memory, num_workers=num_workers,
                                  batch_size=config.batch_size, drop_last=True,
                                  persistent_workers=persistent_workers)
        val_loader = DataLoader(val_dataset, shuffle=False, pin_memory=pin_memory, num_workers=0,
                                batch_size=val_batch_size, drop_last=False)

        optimizer = _build_optimizer(model, config)
        criterion = torch.nn.CrossEntropyLoss()
        optimizer_name = type(optimizer).__name__

        tbar = tqdm(range(epochs))

        # The context manager closes the writer even if training fails, so
        # buffered events are flushed to disk.
        with SummaryWriter() as writer:
            writer.add_scalar("Params/Learning_rate", config.learning_rate)
            writer.add_scalar("Params/Batch_size", config.batch_size)
            writer.add_scalar("Params/Optimizer", config.optimizer)

            # Running batch counter across epochs, so per-batch losses of
            # later epochs do not overwrite those of earlier ones.
            global_batch = 0

            for epoch in tbar:
                acc, acc_val, batches_loss, val_loss = do_epoch(model, train_loader, val_loader, criterion,
                                                                optimizer, device)
                epoch_loss = sum(batches_loss) / len(batches_loss)
                tbar.set_postfix_str(f"Optimizer: {optimizer_name}, Acc: {acc}, Acc_val: {acc_val}")
                writer.add_scalar("Train/Accuracy", acc, epoch)
                writer.add_scalar("Val/Accuracy", acc_val, epoch)
                writer.add_scalar("Train/Loss", epoch_loss, epoch)
                writer.add_scalar("Val/Loss", val_loss, epoch)
                writer.add_scalar("Model/Norm", get_model_norm(model), epoch)
                for batch_loss in batches_loss:
                    writer.add_scalar("Train/Batch", batch_loss, global_batch)
                    global_batch += 1

                wandb.log({"epoch": epoch, "accuracy": acc_val})

if __name__ == '__main__':
    # freeze_support() has to be the first statement inside the main guard so
    # that spawned child processes of a frozen executable are handled before
    # anything else runs.
    freeze_support()

    wandb.login()

    # Random search over the optimizer id, the learning rate and the batch
    # size (log-uniform between 64 and 256, quantized to multiples of 8).
    sweep_config = {
        'method': 'random',
        'parameters': {
            'optimizer': {
                'values': [0, 1, 2, 3, 4, 5],
            },
            'learning_rate': {
                'distribution': 'uniform',
                'min': 0,
                'max': 0.05,
            },
            'batch_size': {
                'distribution': 'q_log_uniform_values',
                'q': 8,
                'min': 64,
                'max': 256,
            },
        },
    }

    sweep_id = wandb.sweep(sweep_config, project="Lab05")
    # Run six sampled configurations, each through main().
    wandb.agent(sweep_id, main, count=6)

Error in callback <bound method _WandbInit._resume_backend of <wandb.sdk.wandb_init._WandbInit object at 0x7a3ab6030ac0>> (for pre_run_cell):


BrokenPipeError: ignored



Create sweep with ID: tig6lovx
Sweep URL: https://wandb.ai/apetrii-radu/Lab05/sweeps/tig6lovx


[34m[1mwandb[0m: Agent Starting Run: k32fc7yp with config:
[34m[1mwandb[0m: 	batch_size: 160
[34m[1mwandb[0m: 	learning_rate: 0.04224258583346164
[34m[1mwandb[0m: 	optimizer: 1


Files already downloaded and verified
Files already downloaded and verified


Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/wandb/sdk/wandb_init.py", line 1162, in init
    wi.setup(kwargs)
  File "/usr/local/lib/python3.10/dist-packages/wandb/sdk/wandb_init.py", line 225, in setup
    with telemetry.context(obj=self._init_telemetry_obj) as tel:
  File "/usr/local/lib/python3.10/dist-packages/wandb/sdk/lib/telemetry.py", line 42, in __exit__
    self._run._telemetry_callback(self._obj)
  File "/usr/local/lib/python3.10/dist-packages/wandb/sdk/wandb_run.py", line 758, in _telemetry_callback
    self._telemetry_flush()
  File "/usr/local/lib/python3.10/dist-packages/wandb/sdk/wandb_run.py", line 769, in _telemetry_flush
    self._backend.interface._publish_telemetry(self._telemetry_obj)
  File "/usr/local/lib/python3.10/dist-packages/wandb/sdk/interface/interface_shared.py", line 101, in _publish_telemetry
    self._publish(rec)
  File "/usr/local/lib/python3.10/dist-packages/wandb/sdk/interface/interface_sock.py", line 51, in 

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/wandb/sdk/wandb_init.py", line 1162, in init
    wi.setup(kwargs)
  File "/usr/local/lib/python3.10/dist-packages/wandb/sdk/wandb_init.py", line 225, in setup
    with telemetry.context(obj=self._init_telemetry_obj) as tel:
  File "/usr/local/lib/python3.10/dist-packages/wandb/sdk/lib/telemetry.py", line 42, in __exit__
    self._run._telemetry_callback(self._obj)
  File "/usr/local/lib/python3.10/dist-packages/wandb/sdk/wandb_run.py", line 758, in _telemetry_callback
    self._telemetry_flush()
  File "/usr/local/lib/python3.10/dist-packages/wandb/sdk/wandb_run.py", line 769, in _telemetry_flush
    self._backend.interface._publish_telemetry(self._telemetry_obj)
  File "/usr/local/lib/python3.10/dist-packages/wandb/sdk/interface/interface_shared.py", line 101, in _publish_telemetry
    self._publish(rec)
  File "/usr/local/lib/python3.10/dist-packages/wandb/sdk/interface/interface_sock.py", line 51, in 

Files already downloaded and verified
Files already downloaded and verified


Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/wandb/sdk/wandb_init.py", line 1162, in init
    wi.setup(kwargs)
  File "/usr/local/lib/python3.10/dist-packages/wandb/sdk/wandb_init.py", line 225, in setup
    with telemetry.context(obj=self._init_telemetry_obj) as tel:
  File "/usr/local/lib/python3.10/dist-packages/wandb/sdk/lib/telemetry.py", line 42, in __exit__
    self._run._telemetry_callback(self._obj)
  File "/usr/local/lib/python3.10/dist-packages/wandb/sdk/wandb_run.py", line 758, in _telemetry_callback
    self._telemetry_flush()
  File "/usr/local/lib/python3.10/dist-packages/wandb/sdk/wandb_run.py", line 769, in _telemetry_flush
    self._backend.interface._publish_telemetry(self._telemetry_obj)
  File "/usr/local/lib/python3.10/dist-packages/wandb/sdk/interface/interface_shared.py", line 101, in _publish_telemetry
    self._publish(rec)
  File "/usr/local/lib/python3.10/dist-packages/wandb/sdk/interface/interface_sock.py", line 51, in 

Files already downloaded and verified
Files already downloaded and verified


Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/wandb/sdk/wandb_init.py", line 1162, in init
    wi.setup(kwargs)
  File "/usr/local/lib/python3.10/dist-packages/wandb/sdk/wandb_init.py", line 225, in setup
    with telemetry.context(obj=self._init_telemetry_obj) as tel:
  File "/usr/local/lib/python3.10/dist-packages/wandb/sdk/lib/telemetry.py", line 42, in __exit__
    self._run._telemetry_callback(self._obj)
  File "/usr/local/lib/python3.10/dist-packages/wandb/sdk/wandb_run.py", line 758, in _telemetry_callback
    self._telemetry_flush()
  File "/usr/local/lib/python3.10/dist-packages/wandb/sdk/wandb_run.py", line 769, in _telemetry_flush
    self._backend.interface._publish_telemetry(self._telemetry_obj)
  File "/usr/local/lib/python3.10/dist-packages/wandb/sdk/interface/interface_shared.py", line 101, in _publish_telemetry
    self._publish(rec)
  File "/usr/local/lib/python3.10/dist-packages/wandb/sdk/interface/interface_sock.py", line 51, in 

Files already downloaded and verified
Files already downloaded and verified


[34m[1mwandb[0m: Ctrl + C detected. Stopping sweep.


Error in callback <bound method _WandbInit._pause_backend of <wandb.sdk.wandb_init._WandbInit object at 0x7a3ab6030ac0>> (for post_run_cell):


BrokenPipeError: ignored