# Using wandb to track experiments.

Demo task: multi-class image classification using CIFAR10 dataset.

In [2]:
from sklearn.metrics import average_precision_score
from torch.utils.data import DataLoader
from torchvision import datasets, models
from torchvision import transforms as T
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.optim as optim


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# The next cell includes-
- Collecting the CIFAR10 dataset and defining data loaders.
- Methods to load model, criterion, optimizer and schedulers.
- Definition of AverageMeter

In [4]:
# Downloading CIFAR10 dataset
inp_transforms = T.Compose([T.ToTensor(),
                            T.Normalize(mean=[0.485, 0.456, 0.406],
                                        std=[0.229, 0.224, 0.225])])
tgt_transforms = T.Lambda(lambda y: torch.zeros(10, dtype=torch.long).scatter_(0, torch.tensor(y), value=1))
cifar10 = datasets.CIFAR10(root = "./",
                           transform = inp_transforms,
                           target_transform = tgt_transforms,
                           download = True)

# Defining dataset split (80-20)
train_dataset, val_dataset = torch.utils.data.random_split(cifar10,
                                                           [int(len(cifar10)*0.80), int(len(cifar10)*0.20)])

# Defining the dataloaders
train_dataloader = DataLoader(train_dataset,
                              batch_size=200,
                              shuffle=True)
val_dataloader = DataLoader(val_dataset,
                            batch_size=200,
                            shuffle=False)


# Method to get model based on config param model_type
def get_model(model_type):
    model = None
    if model_type == "pretrained": # Loading pretrained ResNet18 and with updated to final fc layer. 
        model = models.resnet18(pretrained=True)
        model.fc = nn.Linear(512, 10)
        model = model.to(device)
    elif model_type == "scratch": # Loading a blank ResNet18 which generated 10 outputs.
        model = models.resnet18(num_classes=10)
        model = model.to(device)
    else:
        raise NotImplemented
    return model


# Method to get criterion, optimizer and scheduler based on config params.
def get_criterion_optimizer_scheduler(config, model):
    optim_dct = {
        "adam": optim.Adam,
        "SGD": optim.SGD,
        "RMSprop": optim.RMSprop
    }
    optimizer = optim_dct[config["optimizer"]](model.parameters(), lr=config["lr"])
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           factor=0.1,
                                                           patience=config["scheduler_patience"],
                                                           threshold=config["scheduler_thresh"])
    criterion = nn.CrossEntropyLoss()
    return criterion, optimizer, scheduler



# Remainder of this cell includes definition of AverageMeter (can be ignored)
"""
Code taken from Pytorch ImageNet examples
https://github.com/pytorch/examples/blob/main/imagenet/main.py#L375
"""
class Summary():
    NONE = 0
    AVERAGE = 1
    SUM = 2
    COUNT = 3

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE):
        self.name = name
        self.fmt = fmt
        self.summary_type = summary_type
        self.val_history = list()
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
        self.val_history = list()

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
        self.val_history.append(val)

    def __str__(self):
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)
    
    def summary(self):
        fmtstr = ''
        if self.summary_type is Summary.NONE:
            fmtstr = ''
        elif self.summary_type is Summary.AVERAGE:
            fmtstr = '{name} {avg:.3f}'
        elif self.summary_type is Summary.SUM:
            fmtstr = '{name} {sum:.3f}'
        elif self.summary_type is Summary.COUNT:
            fmtstr = '{name} {count:.3f}'
        else:
            raise ValueError('invalid summary type %r' % self.summary_type)        
        return fmtstr.format(**self.__dict__)


Files already downloaded and verified


# Following cell includes-
- Defining the train and eval loops.
- Method to trigger training loops based on config parameters.

In [13]:
# The train function without wandb logging

def train(model, criterion, optimizer, scheduler, epochs, train_dataloader, val_dataloader, device):
    for epoch in range(epochs):
        model.train()
        loss_meter = AverageMeter("train_loss", ":.5f")
        epoch_outs, epoch_tgt = list(), list()
        for data, tgt_vec in tqdm(train_dataloader):
            data, tgt_vec = data.to(device), tgt_vec.to(device)
            targets = torch.argmax(tgt_vec, axis=1)
            optimizer.zero_grad()
            out = model(data)
            loss = criterion(out, targets)
            loss_meter.update(loss.item(), data.shape[0])
            loss.backward()
            optimizer.step()
            epoch_outs.append(out)
            epoch_tgt.append(tgt_vec)
        predictions = torch.vstack([torch.softmax(out, axis=1) for out in epoch_outs]).detach().cpu().numpy()
        targets = torch.cat([tgt for tgt in epoch_tgt], dim=0).detach().cpu().numpy()
        ap_score = average_precision_score(targets, predictions)
        eval_loss_meter, eval_ap_score = evaluate(model, criterion, val_dataloader, device)
        data_to_log = {
            "epoch": epoch+1,
            "train_loss": loss_meter.avg,
            "eval_loss": eval_loss_meter.avg,
            "train_ap_score": ap_score,
            "eval_ap_score": eval_ap_score,
            "lr": optimizer.state_dict()["param_groups"][0]["lr"],
        }
        wandb.log(data_to_log)
        scheduler.step(eval_loss_meter.avg)
        print(data_to_log)


@torch.no_grad()
def evaluate(model, criterion, val_dataloader, device):
    model.eval()
    loss_meter = AverageMeter("eval_loss", ":.5f")
    epoch_outs, epoch_tgt = list(), list()
    for data, tgt_vec in val_dataloader:
        data, tgt_vec = data.to(device), tgt_vec.to(device)
        targets = torch.argmax(tgt_vec, axis=1)
        out = model(data)
        loss = criterion(out, targets)
        loss_meter.update(loss.item(), data.shape[0])
        epoch_outs.append(out)
        epoch_tgt.append(tgt_vec)
    predictions = torch.vstack([torch.softmax(out, axis=1) for out in epoch_outs]).detach().cpu().numpy()
    targets = torch.cat([tgt for tgt in epoch_tgt], dim=0).detach().cpu().numpy()
    ap_score = average_precision_score(targets, predictions)
    return loss_meter, ap_score


def trigger_training(config):
    model = get_model(config["model_type"])
    criterion, optimizer, scheduler = get_criterion_optimizer_scheduler(config, model)
    epochs = config["num_epochs"]

    train(model, criterion, optimizer, scheduler, epochs, train_dataloader, val_dataloader, device)


# Complete the config file, edit the cells in this notebook to log data to wandb and trigger training loops!

In [41]:
# Fill the Config file below and log the experiment at wandb
config = {
    "lr": 0.001, 
    "model_type": "scratch", # pretrained/scratch
    "optimizer": "SGD", # adam/SGD/RMSprop
    "criterion": "ce",
    "scheduler_patience": 3,
    "scheduler_thresh": 0.0001,
    "num_epochs": 30, # CHANGE
    "gpu_id": 0,
    "wandb_run_name": "2-SGD-final" ### FILL YOUR NAME HERE
}


In [42]:
trigger_training(config)


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:06<00:00, 31.09it/s]


{'epoch': 1, 'train_loss': 2.1485008335113527, 'eval_loss': 1.9660384631156922, 'train_ap_score': 0.1953362108020245, 'eval_ap_score': 0.2665750828760356, 'lr': 0.001}


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:06<00:00, 29.91it/s]


{'epoch': 2, 'train_loss': 1.843226251602173, 'eval_loss': 1.7911156368255616, 'train_ap_score': 0.3158394883065685, 'eval_ap_score': 0.3398399494626542, 'lr': 0.001}


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:06<00:00, 30.62it/s]


{'epoch': 3, 'train_loss': 1.6853881573677063, 'eval_loss': 1.6891206741333007, 'train_ap_score': 0.3845953811329589, 'eval_ap_score': 0.38147522715307286, 'lr': 0.001}


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:06<00:00, 29.80it/s]


{'epoch': 4, 'train_loss': 1.5837844347953796, 'eval_loss': 1.6250621795654296, 'train_ap_score': 0.43157649659337893, 'eval_ap_score': 0.40938242818533066, 'lr': 0.001}


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:06<00:00, 29.10it/s]


{'epoch': 5, 'train_loss': 1.5054780840873718, 'eval_loss': 1.5748491668701172, 'train_ap_score': 0.4700009872602573, 'eval_ap_score': 0.43359858699007114, 'lr': 0.001}


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:06<00:00, 29.56it/s]


{'epoch': 6, 'train_loss': 1.445594495534897, 'eval_loss': 1.5391449165344238, 'train_ap_score': 0.49966607107301575, 'eval_ap_score': 0.448930347765073, 'lr': 0.001}


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:06<00:00, 29.25it/s]


{'epoch': 7, 'train_loss': 1.3924066418409347, 'eval_loss': 1.5106296491622926, 'train_ap_score': 0.5268410947526834, 'eval_ap_score': 0.46410834793459055, 'lr': 0.001}


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:06<00:00, 30.19it/s]


{'epoch': 8, 'train_loss': 1.3437018567323684, 'eval_loss': 1.4805212569236756, 'train_ap_score': 0.5508178799680628, 'eval_ap_score': 0.48023072919943866, 'lr': 0.001}


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:06<00:00, 30.22it/s]


{'epoch': 9, 'train_loss': 1.3023171693086624, 'eval_loss': 1.4639432120323181, 'train_ap_score': 0.5741242255961604, 'eval_ap_score': 0.4876828438743626, 'lr': 0.001}


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:06<00:00, 30.51it/s]


{'epoch': 10, 'train_loss': 1.2649261993169785, 'eval_loss': 1.4532112169265747, 'train_ap_score': 0.5930134168502191, 'eval_ap_score': 0.49617458588736385, 'lr': 0.001}


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:06<00:00, 29.58it/s]


{'epoch': 11, 'train_loss': 1.2255076599121093, 'eval_loss': 1.4316606068611144, 'train_ap_score': 0.6138517293729929, 'eval_ap_score': 0.508281303336699, 'lr': 0.001}


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:06<00:00, 29.46it/s]


{'epoch': 12, 'train_loss': 1.187906939983368, 'eval_loss': 1.4167302083969116, 'train_ap_score': 0.6333751581626108, 'eval_ap_score': 0.5150624600420745, 'lr': 0.001}


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:07<00:00, 28.02it/s]


{'epoch': 13, 'train_loss': 1.1538493132591248, 'eval_loss': 1.405981056690216, 'train_ap_score': 0.6500014778456162, 'eval_ap_score': 0.5214890184006304, 'lr': 0.001}


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:06<00:00, 28.62it/s]


{'epoch': 14, 'train_loss': 1.1223119324445725, 'eval_loss': 1.4009325623512268, 'train_ap_score': 0.6659003136225478, 'eval_ap_score': 0.5249317591276552, 'lr': 0.001}


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:07<00:00, 27.89it/s]


{'epoch': 15, 'train_loss': 1.087004269361496, 'eval_loss': 1.3878858923912047, 'train_ap_score': 0.6830686100632637, 'eval_ap_score': 0.5333743504836443, 'lr': 0.001}


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:07<00:00, 27.83it/s]


{'epoch': 16, 'train_loss': 1.0518583050370216, 'eval_loss': 1.386347966194153, 'train_ap_score': 0.7001633477503652, 'eval_ap_score': 0.5354295043306126, 'lr': 0.001}


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:06<00:00, 29.30it/s]


{'epoch': 17, 'train_loss': 1.0224031579494477, 'eval_loss': 1.3785009670257569, 'train_ap_score': 0.7146994177870399, 'eval_ap_score': 0.5395438743125077, 'lr': 0.001}


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:06<00:00, 29.66it/s]


{'epoch': 18, 'train_loss': 0.9904527333378792, 'eval_loss': 1.374214506149292, 'train_ap_score': 0.7294491381672543, 'eval_ap_score': 0.5437150182180871, 'lr': 0.001}


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:07<00:00, 28.39it/s]


{'epoch': 19, 'train_loss': 0.9622346946597099, 'eval_loss': 1.3691517090797425, 'train_ap_score': 0.7442196919704089, 'eval_ap_score': 0.5472047457811465, 'lr': 0.001}


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:07<00:00, 27.83it/s]


{'epoch': 20, 'train_loss': 0.9242797777056694, 'eval_loss': 1.370529944896698, 'train_ap_score': 0.7615228651029022, 'eval_ap_score': 0.5479670854884006, 'lr': 0.001}


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:07<00:00, 28.34it/s]


{'epoch': 21, 'train_loss': 0.8933922377228737, 'eval_loss': 1.3722400903701781, 'train_ap_score': 0.7753624624656942, 'eval_ap_score': 0.5499355057945055, 'lr': 0.001}


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:07<00:00, 28.51it/s]


{'epoch': 22, 'train_loss': 0.8635593789815903, 'eval_loss': 1.3690967226028443, 'train_ap_score': 0.7885721277999842, 'eval_ap_score': 0.5540032266970926, 'lr': 0.001}


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:06<00:00, 28.97it/s]


{'epoch': 23, 'train_loss': 0.8342064049839973, 'eval_loss': 1.3722304010391235, 'train_ap_score': 0.8013926554709243, 'eval_ap_score': 0.5534221352206312, 'lr': 0.001}


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:06<00:00, 28.87it/s]


{'epoch': 24, 'train_loss': 0.7809086227416993, 'eval_loss': 1.3678875374794006, 'train_ap_score': 0.8276861504733939, 'eval_ap_score': 0.5559052743650975, 'lr': 0.0001}


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:07<00:00, 28.50it/s]


{'epoch': 25, 'train_loss': 0.774853780567646, 'eval_loss': 1.3699351453781128, 'train_ap_score': 0.8314395091661952, 'eval_ap_score': 0.5555670551235423, 'lr': 0.0001}


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:06<00:00, 28.60it/s]


{'epoch': 26, 'train_loss': 0.7696298086643218, 'eval_loss': 1.3691083812713623, 'train_ap_score': 0.8330405180357843, 'eval_ap_score': 0.5556700189291575, 'lr': 0.0001}


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:07<00:00, 28.35it/s]


{'epoch': 27, 'train_loss': 0.7670722490549088, 'eval_loss': 1.368231213092804, 'train_ap_score': 0.8340646618182472, 'eval_ap_score': 0.5562939067756165, 'lr': 0.0001}


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:07<00:00, 28.00it/s]


{'epoch': 28, 'train_loss': 0.7594921866059303, 'eval_loss': 1.3704845476150513, 'train_ap_score': 0.8381480538899861, 'eval_ap_score': 0.5557577907360629, 'lr': 0.0001}


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:06<00:00, 28.98it/s]


{'epoch': 29, 'train_loss': 0.7544115936756134, 'eval_loss': 1.370063226222992, 'train_ap_score': 0.8403144836713221, 'eval_ap_score': 0.5566483942611228, 'lr': 1e-05}


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:07<00:00, 28.15it/s]


{'epoch': 30, 'train_loss': 0.756203075349331, 'eval_loss': 1.3687308955192565, 'train_ap_score': 0.8386781274457087, 'eval_ap_score': 0.5564694287925537, 'lr': 1e-05}


# WandB Steps

In [10]:
### Step 1: Import WandB in your code

import wandb

### Step 1 ends

In [38]:
### Step 2:
# Initiate wandb in your script. The moment we trigger wandb.init(), an active
# socket connection is established between your machine and wandb server.
# We specify the entity (wandb username) and project (which wandb project to use for logging)

wandb.init(entity = "dhruv_sri",   # wandb username. (NOT REQUIRED ARG. ANYMORE, it fetches from initial login)
           project = "wandb_demo", # wandb project name. New project will be created if given project is missing.
           config = config         # Config dict
          )
wandb.run.name = config["wandb_run_name"]

### Step 2 ends.


0,1
epoch,▁▂▃▄▅▅▆▇█
eval_ap_score,▁▃▅▅▆▇▇██
eval_loss,█▆▄▃▃▂▂▁▁
lr,▁▁▁▁▁▁▁▁▁
train_ap_score,▁▃▄▅▆▇▇██
train_loss,█▆▄▃▃▂▂▁▁

0,1
epoch,9.0
eval_ap_score,0.46597
eval_loss,1.51406
lr,0.001
train_ap_score,0.5605
train_loss,1.3371


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01667039038332708, max=1.0)…

In [12]:
### Step 3: Trigger wandb log
# This step is responsible for sending the logs to wandb

# evaluate()
wandb.log(data_to_log)

### Step 3 ends.


NameError: name 'data_to_log' is not defined

In [None]:
### Step 4 (Optional)
# This closes the active socket connection to wandb server. Optional since wandb destructor does the same.

wandb.finish()

### Step 4 ends.


# WandB sweeps related steps

In [18]:
### Step 1:
# Create a WandB sweep config file.
# This config file will be used at the WandB website to initialize a sweep server
program: "demo.py"
method: "grid"
metric:
  name: "eval_ap_score"
  goal: "maximize"
parameters:
    criterion:
      value: "ce"
    gpu_id:
      value: 0
    lr:
      values: [0.1, 0.001, 0.0001]
    model_type:
      values: ["scratch", "pretrained"]
    num_epochs:
      value: 25
    optimizer:
      values: ["adam", "SGD", "RMSprop"]
    scheduler_patience:
      value: 3
    scheduler_thresh:
      value: 0.01

        
### A sample sweep config file if bayes method is used-
# program: wandb_demo.py
# method: bayes
# metric:
#   name: "eval_ap_score"
#   goal: maximize
# parameters:
#   lr:
#     distribution: uniform
#     min: 0.00001
#     max: 0.1
#   criterion:
#     distribution: categorical
#     value:
#       - ce
#   optimizer:
#     distribution: categorical
#     values:
#       - adam
#       - SGD
#       - RMSprop
#   model_type:
#     distribution: categorical
#     values:
#       - pretrained
#       - scratch
#   num_epochs:
#     value:
#       - 30
#   scheduler_thresh:
#     distribution: uniform
#     min: 0.001
#     max: 0.01
#   scheduler_patience:
#     distribution: int_uniform
#     min: 2
#     max: 10


SyntaxError: invalid syntax (3096734612.py, line 6)

In [None]:
### Step 2
# After using the above config on wandb website, you will get a sweep id in return.
# E.g. sweep id- dhruv_sri/wandb_demo/hbyp0tl8
#
# Add the following agent line in your code-
# Use the generated sweep id in the below code

wandb.agent(sweep_id="### FILL SWEEP ID HERE ###", function=sweep_agent_manager, count=100)


In [None]:
### Step 3
# Notice in above command we mentioned an argument named "function"
# Wandb agents must trigger a function where they can initiate a socket to wandb and get a config.
# So, we will use the following sweep_agent_manager function here-

def sweep_agent_manager():
    wandb.init()
    config = dict(wandb.config)
    run_name = f"{config['model_type']}_{config['optimizer']}_{config['lr']}"
    wandb.run.name = run_name
    trigger_training(config)


In [None]:
### Done.
# Now execute your training script on multiple machines.
# Each run will request the config file from wandb and related experiments will be logged.
# 
# NOTE!! wandb.log(data_to_log) must be present inside the code!! Else there is no meaning to sweep.


# ------------------------------ Ends ------------------------------