In [None]:
pip install wandb

In [8]:
import os
from multiprocessing import freeze_support
import wandb

import torch
from torchvision.datasets import CIFAR10
from torchvision.transforms import v2
from torch.utils.data import Dataset, DataLoader

from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter


In [23]:
# Read the API key from the environment; never commit API keys to a notebook.
# If WANDB_API_KEY is unset, key=None lets wandb fall back to its normal login
# flow (previously stored credentials or an interactive prompt).
wandb.login(key=os.environ.get('WANDB_API_KEY'))

# Random-search sweep that maximizes the `acc_val` value logged by main().
config = {'method': 'random',
          'metric': {'goal': 'maximize', 'name': 'acc_val'},
          'parameters': {'val_batch_size': {'distribution': 'q_log_uniform_values', 'max': 256, 'min': 32, 'q': 8},
                         'epochs': {'value': 50},
                         # NOTE(review): a lower bound of 0 allows learning rates arbitrarily close to
                         # zero, which waste trials; consider a log-uniform range with a positive min.
                         'learning_rate': {'distribution': 'uniform', 'max': 0.01, 'min': 0},
                         }
          }

sweep_id = wandb.sweep(config, project="ATNN-Lab05")

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Create sweep with ID: cdffla3s
Sweep URL: https://wandb.ai/gardeleanu151/ATNN-Lab05/sweeps/cdffla3s


In [11]:
class MLP(torch.nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        input: number of input features.
        hidden: width of the hidden layer.
        output: number of outputs; no final activation is applied, so the
            forward pass returns raw scores (logits).
    """

    def __init__(self, input, hidden, output):
        super().__init__()
        # Attribute names and their registration order determine the
        # state_dict keys, so fc1, fc2, relu are kept as-is.
        self.fc1 = torch.nn.Linear(input, hidden)
        self.fc2 = torch.nn.Linear(hidden, output)
        self.relu = torch.nn.ReLU(inplace=True)

    def forward(self, x):
        hidden_activations = self.relu(self.fc1(x))
        return self.fc2(hidden_activations)


class cached_dataset(Dataset):
    """Dataset wrapper that can materialize another dataset up front.

    With ``cache=True`` every item of ``dataset`` is read once and kept in a
    tuple, so whatever per-item work the wrapped dataset does (e.g. its
    transforms) happens only once. With ``cache=False`` the wrapped dataset is
    stored and indexed directly.
    """

    def __init__(self, dataset, cache=True):
        self.dataset = tuple(dataset) if cache else dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        return self.dataset[i]

In [12]:
def get_device():
    """Pick the best available torch device: CUDA, then Apple MPS, then CPU.

    Returns:
        A ``torch.device`` of type ``'cuda'``, ``'mps'`` or ``'cpu'``.
    """
    if torch.cuda.is_available():
        return torch.device('cuda')
    # torch.backends.mps may be missing on older PyTorch builds.
    mps_backend = getattr(torch.backends, 'mps', None)
    if mps_backend is not None and mps_backend.is_available():
        # The device type string is 'mps'; anything else makes torch.device raise.
        return torch.device('mps')
    return torch.device('cpu')

def get_accuracy(results, targets):
    """Fraction of positions where ``results`` equals ``targets``.

    Args:
        results: tensor of predicted class indices (assumed 1-D and the same
            shape as ``targets`` — TODO confirm for new callers).
        targets: tensor of true class indices.

    Returns:
        A Python float in [0, 1]. Empty input raises ``ZeroDivisionError``.
    """
    count = len(results)
    wrong = int((results != targets).sum())
    return (count - wrong) / count

def train(model, train_loader, criteria, optimizer, device):
    """Run one training epoch over ``train_loader``.

    Args:
        model: module mapping a batch of inputs to per-class scores (logits).
        train_loader: iterable of ``(data, targets)`` batches.
        criteria: loss function called as ``criteria(logits, targets)``.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        device: device each batch is moved to before the forward pass.

    Returns:
        ``(accuracy, total_loss, batch_losses)``:
        ``accuracy`` is the fraction of correct argmax predictions rounded to
        4 decimals, ``total_loss`` is the float sum of the per-batch losses, and
        ``batch_losses`` lists each batch's loss in iteration order.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    model.train()

    correct = 0
    seen = 0
    total_loss = 0.0
    batch_losses = []
    for data, targets in train_loader:
        data = data.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        logits = model(data)
        loss = criteria(logits, targets)

        loss.backward()
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

        # Keep the running total in its own float: reusing `loss` as the
        # accumulator would be overwritten by the next batch's loss tensor.
        batch_value = loss.item()
        total_loss += batch_value
        batch_losses.append(batch_value)

        # softmax is monotonic, so argmax over logits gives the same prediction;
        # counting per batch avoids storing every output for the whole epoch.
        correct += (logits.detach().argmax(dim=1) == targets).sum().item()
        seen += targets.size(0)

    if seen == 0:
        raise ValueError("train_loader yielded no batches")
    return round(correct / seen, 4), total_loss, batch_losses

def validate(model, valid_loader, criteria, device):
    """Evaluate ``model`` on ``valid_loader`` without updating its parameters.

    Args:
        model: module mapping a batch of inputs to per-class scores (logits).
        valid_loader: iterable of ``(data, targets)`` batches.
        criteria: loss function called as ``criteria(logits, targets)``; it is
            expected to take raw logits (e.g. ``CrossEntropyLoss``, which
            applies log-softmax internally).
        device: device each batch is moved to.

    Returns:
        ``(accuracy, total_loss)``: the fraction of correct argmax predictions
        rounded to 4 decimals, and the float sum of the per-batch losses.

    Raises:
        ValueError: if ``valid_loader`` yields no batches.
    """
    model.eval()

    correct = 0
    seen = 0
    val_loss = 0.0
    with torch.no_grad():
        for data, targets in valid_loader:
            data = data.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)

            logits = model(data)
            # Feed raw logits to the loss; passing softmax output would
            # normalize twice and produce a wrong validation loss.
            val_loss += criteria(logits, targets).item()

            # No .squeeze(): it would drop the batch dim for a batch of size 1.
            correct += (logits.argmax(dim=1) == targets).sum().item()
            seen += targets.size(0)

    if seen == 0:
        raise ValueError("valid_loader yielded no batches")
    return round(correct / seen, 4), val_loss

def run_epoch(model, train_loader, valid_loader, criteria, optimizer, device):
    """Train for one epoch, then evaluate on the validation loader.

    Returns:
        ``(train_acc, valid_acc, train_loss, valid_loss, batch_losses)`` —
        accuracies first, then losses, then the per-batch training losses.
    """
    train_acc, train_loss, batch_losses = train(model, train_loader, criteria, optimizer, device)
    valid_acc, valid_loss = validate(model, valid_loader, criteria, device)
    return train_acc, valid_acc, train_loss, valid_loss, batch_losses

def get_norm(model):
    """Sum of the L2 norms of each parameter tensor of ``model``.

    Note this adds per-tensor norms; it is not the global L2 norm
    (square root of the sum of all squared entries). Returns the float
    ``0.0`` for a model without parameters, otherwise a 0-d tensor.
    """
    return sum((torch.norm(weight) for weight in model.parameters()), 0.0)


In [None]:
def main(device=get_device()):
    """Run one W&B sweep trial: train an MLP on grayscale 28x28 CIFAR-10.

    Invoked by ``wandb.agent`` once per trial. Hyper-parameters come from
    ``wandb.config`` (``epochs``, ``val_batch_size``, ``learning_rate``).
    Per-epoch metrics are written to TensorBoard and to W&B, where
    ``acc_val`` is the sweep's optimisation metric.

    Args:
        device: device to train on. NOTE: this default is evaluated once, when
            the defining cell runs, not on every call.
    """
    # ToImage + ToDtype(scale=True) give a float image in [0, 1]; Resize +
    # Grayscale give one 28x28 channel and flatten yields the 784 inputs the MLP expects.
    transform = v2.Compose([
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        v2.Resize((28, 28), antialias=True),
        v2.Grayscale(),
        torch.flatten,
    ])

    data_path = '../data'
    # Cache so the transform pipeline runs once per image instead of every epoch.
    train_data = cached_dataset(CIFAR10(root=data_path, train=True, transform=transform, download=True))
    valid_data = cached_dataset(CIFAR10(root=data_path, train=False, transform=transform, download=True))

    model = MLP(784, 100, 10).to(device)

    with wandb.init(config=None):
        config = wandb.config
        epochs = config.epochs
        batch_size = 256
        # NOTE(review): only the validation batch size is swept; it does not
        # affect training, so the sweep effectively tunes learning_rate alone.
        val_batch_size = config.val_batch_size
        num_workers = 2
        persistent_workers = num_workers != 0
        # Pinned host memory only helps host->GPU copies. It must be defined
        # before the loaders below use it.
        pin_memory = device.type == 'cuda'
        criteria = torch.nn.CrossEntropyLoss()

        train_loader = DataLoader(train_data, shuffle=True, pin_memory=pin_memory, num_workers=num_workers,
                                  batch_size=batch_size, drop_last=True, persistent_workers=persistent_workers)
        valid_loader = DataLoader(valid_data, shuffle=False, pin_memory=pin_memory, num_workers=0,
                                  batch_size=val_batch_size, drop_last=False)

        optimizer = torch.optim.RMSprop(model.parameters(), lr=config.learning_rate)

        summ_writer = SummaryWriter()
        try:
            tbar = tqdm(range(epochs))
            # Global step for per-batch scalars; it keeps counting across epochs
            # so later epochs do not overwrite earlier points.
            batch_step = 0
            for epoch in tbar:
                acc, acc_val, loss, valid_loss, batch_loss = run_epoch(
                    model, train_loader, valid_loader, criteria, optimizer, device)
                tbar.set_postfix_str(f"Acc: {acc}, Acc_val: {acc_val}")

                # Mean per-batch loss for the epoch.
                train_loss = float(loss) / len(train_loader)
                val_loss = float(valid_loss) / len(valid_loader)

                summ_writer.add_scalar("Train/Loss", train_loss, epoch)
                summ_writer.add_scalar("Train/Accuracy", acc, epoch)

                summ_writer.add_scalar("Val/Loss", val_loss, epoch)
                summ_writer.add_scalar("Val/Accuracy", acc_val, epoch)

                summ_writer.add_scalar("Model/Norm", get_norm(model), epoch)
                summ_writer.add_scalar("Constants/Learning rate", config.learning_rate, epoch)
                summ_writer.add_scalar("Constants/Batch size", val_batch_size, epoch)

                for batch_value in batch_loss:
                    summ_writer.add_scalar("Batch Train/Loss", batch_value, batch_step)
                    batch_step += 1

                wandb.log({"acc_val": acc_val, "epoch": epoch, "acc": acc,
                           "train_loss": train_loss, "val_loss": val_loss})
        finally:
            # Flush TensorBoard events even if training fails; the `with`
            # block above finishes the W&B run on exit.
            summ_writer.close()


if __name__ == '__main__':
    # Only needed for frozen Windows executables that use multiprocessing;
    # a no-op otherwise.
    freeze_support()
    # Run three sweep trials; each calls main() with a freshly sampled wandb.config.
    wandb.agent(sweep_id, function=main, count=3)