# wandb (weights and biases)

In [1]:
# Install the W&B client from inside the notebook via a shell escape
# (-qqq asks pip to keep its output quiet).
!pip install wandb -qqq
import wandb
# Log in to your W&B account. In the recorded run this appended the API key to
# the netrc file and the call returned True.
wandb.login()

[K     |████████████████████████████████| 1.8 MB 19.1 MB/s 
[K     |████████████████████████████████| 145 kB 70.6 MB/s 
[K     |████████████████████████████████| 181 kB 63.0 MB/s 
[K     |████████████████████████████████| 63 kB 2.1 MB/s 
[?25h  Building wheel for pathtools (setup.py) ... [?25l[?25hdone


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

## Run experiment

In [2]:
import random

# Launch 5 simulated experiments
total_runs = 5

# Hyperparameters and run metadata tracked for every run. The simulated loop
# below reads "epochs" from here, so the logged config and the loop that
# actually runs cannot drift apart.
run_config = {
    "learning_rate": 0.02,
    "architecture": "CNN",
    "dataset": "CIFAR-100",
    "epochs": 10,
}

for run_idx in range(total_runs):
    # Start a new run to track this script. Using the run as a context manager
    # marks it as finished when the block exits, including when the body
    # raises, so a failed iteration does not leave a run open.
    with wandb.init(
        # Set the project where this run will be logged
        project="basic-intro",
        # We pass a run name (otherwise it'll be randomly assigned, like sunshine-lollypop-10)
        name=f"experiment_{run_idx}",
        # Track hyperparameters and run metadata
        config=run_config,
    ) as run:
        # Simulation:
        # This simple block simulates a training loop logging metrics.
        epochs = run_config["epochs"]
        # Per-run random shift so the simulated curves differ between runs.
        offset = random.random() / 5
        # Both formulas divide by `epoch`, so the range must not start at 0.
        for epoch in range(2, epochs):
            acc = 1 - 2 ** -epoch - random.random() / epoch - offset
            loss = 2 ** -epoch + random.random() / epoch + offset

            # Log metrics from your script to W&B
            run.log({"acc": acc, "loss": loss})

[34m[1mwandb[0m: Currently logged in as: [33mhesfy[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
acc,▁▃▆▆▇▅█▇
loss,█▅▃▁▂▂▂▁

0,1
acc,0.91441
loss,0.04743


VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
acc,▁▃▇▆▆█▇▇
loss,█▇▄▂▃▁▃▂

0,1
acc,0.74614
loss,0.20223


VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
acc,▂▁▇▇▇▆██
loss,▇█▆▅▂▁▁▃

0,1
acc,0.78237
loss,0.26255


VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
acc,▁▅▇▅▅▇▇█
loss,█▄▃▂▃▂▂▁

0,1
acc,0.90211
loss,0.09086


VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
acc,▄▁▃▇▆▇██
loss,█▅▂▄▂▁▂▂

0,1
acc,0.91599
loss,0.09746


## Setup DataLoader

In [3]:
import wandb
import math
import random
import torch, torchvision
import torch.nn as nn
import torchvision.transforms as T
from tqdm.notebook import tqdm

# Run on the first CUDA GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

def get_dataloader(is_train, batch_size, slice=5):
    """Build a DataLoader over a strided subset of MNIST.

    Args:
        is_train: passed to MNIST as ``train``; it also decides whether the
            loader shuffles (shuffling happens only when this is truthy).
        batch_size: number of samples per batch.
        slice: stride of the subset; every ``slice``-th sample is kept,
            starting at index 0.

    Returns:
        A ``torch.utils.data.DataLoader`` over the subset, using pinned memory
        and 2 worker processes. The dataset is downloaded into the current
        directory if needed.
    """
    mnist = torchvision.datasets.MNIST(
        root=".",
        train=is_train,
        transform=T.ToTensor(),
        download=True,
    )
    # Keep indices 0, slice, 2*slice, ... to shrink the dataset.
    kept_indices = range(0, len(mnist), slice)
    subset = torch.utils.data.Subset(mnist, indices=kept_indices)
    return torch.utils.data.DataLoader(
        dataset=subset,
        batch_size=batch_size,
        shuffle=bool(is_train),
        pin_memory=True,
        num_workers=2,
    )

def get_model(dropout):
    """Create a small fully connected classifier and move it to ``device``.

    The network flattens its input, applies Linear(784 -> 256), batch norm,
    ReLU and dropout, and finishes with Linear(256 -> 10).

    Args:
        dropout: probability handed to the ``nn.Dropout`` layer.

    Returns:
        The ``nn.Sequential`` model placed on the module-level ``device``.
    """
    layers = [
        nn.Flatten(),
        nn.Linear(28 * 28, 256),
        nn.BatchNorm1d(256),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(256, 10),
    ]
    net = nn.Sequential(*layers)
    return net.to(device)

def validate_model(model, valid_dl, loss_func, log_images=False, batch_idx=0):
    """Evaluate the model on the validation dataloader.

    Puts the model in eval mode (it is not switched back here) and runs every
    batch under ``torch.inference_mode()``.

    Args:
        model: the network to evaluate.
        valid_dl: dataloader yielding ``(images, labels)`` batches; its
            ``dataset`` length is used as the normalising sample count.
        loss_func: callable ``loss_func(outputs, labels)``. Each batch loss is
            multiplied by the batch size before summing, which presumably
            assumes a mean-reduced loss -- confirm if another reduction is used.
        log_images: when true, the batch at ``batch_idx`` is sent to
            ``log_image_table``.
        batch_idx: index of the batch to log, so the same batch is logged on
            every call.

    Returns:
        ``(val_loss, accuracy)``: the summed weighted loss and the number of
        correct predictions, each divided by ``len(valid_dl.dataset)``.
        ``val_loss`` keeps whatever type ``loss_func`` produces (a tensor for
        the torch losses used in this notebook); ``accuracy`` is a float.
    """
    model.eval()
    val_loss = 0.
    correct = 0
    with torch.inference_mode():
        # Wrap the dataloader itself (not enumerate) so tqdm can read its
        # length and show a real total instead of an open-ended "0it" bar.
        for i, (images, labels) in enumerate(tqdm(valid_dl, leave=False)):
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)
            val_loss += loss_func(outputs, labels) * labels.size(0)

            # Compute accuracy and accumulate
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()

            # Log one batch of images to the dashboard, always same batch_idx.
            if log_images and i == batch_idx:
                log_image_table(images, predicted, labels, outputs.softmax(dim=1))

    n_samples = len(valid_dl.dataset)
    return val_loss / n_samples, correct / n_samples

def log_image_table(images, predicted, labels, probs):
    """Log a wandb.Table with one row per sample: (img, pred, target, scores).

    Args:
        images: batch of image tensors; only the first channel of each image
            (``img[0]``) is logged, scaled by 255.
        predicted: predicted class per sample.
        labels: target class per sample.
        probs: per-class scores, one row per sample. The number of
            ``score_i`` columns is taken from ``probs.shape[1]`` so the table
            matches models with any number of classes.

    The table is logged under ``"predictions_table"`` with ``commit=False``,
    so it is attached to the step committed by the next ``wandb.log`` call.
    """
    # Move everything to the CPU once before converting to numpy.
    images, predicted, labels, probs = (
        t.to("cpu") for t in (images, predicted, labels, probs)
    )

    # Create a wandb Table to log images, labels and predictions to
    score_columns = [f"score_{i}" for i in range(probs.shape[1])]
    table = wandb.Table(columns=["image", "pred", "target"] + score_columns)
    for img, pred, targ, prob in zip(images, predicted, labels, probs):
        table.add_data(wandb.Image(img[0].numpy()*255), pred, targ, *prob.numpy())
    wandb.log({"predictions_table": table}, commit=False)

## Explore the data

In [36]:
## Run the following code to show some samples
import matplotlib.pyplot as plt
import random

## Get one batch of training data
data_loader = get_dataloader(is_train=True, batch_size=128, slice=5)
images, classes = next(iter(data_loader))

## Show a 3x3 grid of randomly chosen images from that batch
fig, axs = plt.subplots(3, 3)
for i in range(3):
  for j in range(3):
    # randrange excludes its upper bound, so the index always stays inside the
    # batch; randint(0, 128) could return 128 and index past a 128-image batch.
    sample_idx = random.randrange(len(images))
    image = images[sample_idx].view(28, 28)
    axs[i, j].imshow(image, cmap=plt.get_cmap('gray'))
plt.show()

## Train the model

In [37]:
# Launch 5 experiments, trying different dropout rates
for _ in range(5):
    # Initialise a wandb run. Used as a context manager so the run is closed
    # when the block exits, including when training raises part-way through.
    with wandb.init(
        project="pytorch-intro",
        config={
            "epochs": 10,
            "batch_size": 128,
            "lr": 1e-3,
            "dropout": random.uniform(0.01, 0.80),
        },
    ):
        # Copy your config
        config = wandb.config

        # Get the data
        train_dl = get_dataloader(is_train=True, batch_size=config.batch_size)
        valid_dl = get_dataloader(is_train=False, batch_size=2*config.batch_size)
        n_steps_per_epoch = math.ceil(len(train_dl.dataset) / config.batch_size)

        # A simple MLP model
        model = get_model(config.dropout)

        # Make the loss and optimizer
        loss_func = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)

        # Training
        example_ct = 0
        for epoch in tqdm(range(config.epochs)):
            # validate_model leaves the model in eval mode, so switch back
            # to train mode at the start of every epoch.
            model.train()
            for step, (images, labels) in enumerate(tqdm(train_dl, leave=False)):
                images, labels = images.to(device), labels.to(device)

                outputs = model(images)
                train_loss = loss_func(outputs, labels)
                optimizer.zero_grad()
                train_loss.backward()
                optimizer.step()

                example_ct += len(images)
                # Log a plain Python float rather than the live loss tensor.
                metrics = {"train/train_loss": train_loss.item(),
                           "train/epoch": (step + 1 + (n_steps_per_epoch * epoch)) / n_steps_per_epoch,
                           "train/example_ct": example_ct}

                # The last step of the epoch is not logged here: its metrics
                # are logged together with the validation metrics below.
                if step + 1 < n_steps_per_epoch:
                    # Log train metrics to wandb
                    wandb.log(metrics)

            # Only log the image table on the final epoch.
            val_loss, accuracy = validate_model(model, valid_dl, loss_func, log_images=(epoch==(config.epochs-1)))

            # Log train and validation metrics to wandb
            val_metrics = {"val/val_loss": float(val_loss),
                           "val/val_accuracy": accuracy}
            wandb.log({**metrics, **val_metrics})

            print(f"Train Loss: {train_loss:.3f}, Valid Loss: {val_loss:.3f}, Accuracy: {accuracy:.2f}")

        # If you had a test set, this is how you could log it as a Summary metric
        # (0.8 is a placeholder value, not a measured accuracy).
        wandb.summary['test_accuracy'] = 0.8

  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.246, Valid Loss: 0.288986, Accuracy: 0.92


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.223, Valid Loss: 0.237089, Accuracy: 0.93


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.156, Valid Loss: 0.199850, Accuracy: 0.94


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.218, Valid Loss: 0.179299, Accuracy: 0.94


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.133, Valid Loss: 0.172002, Accuracy: 0.95


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.104, Valid Loss: 0.157825, Accuracy: 0.95


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.060, Valid Loss: 0.158108, Accuracy: 0.95


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.059, Valid Loss: 0.160000, Accuracy: 0.95


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.091, Valid Loss: 0.151661, Accuracy: 0.95


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.054, Valid Loss: 0.150632, Accuracy: 0.95


VBox(children=(Label(value='0.173 MB of 0.173 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
train/epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/example_ct,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/train_loss,█▅▄▄▃▃▂▂▂▃▂▂▂▁▂▂▁▁▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/val_accuracy,▁▃▅▆▆▇██▇▇
val/val_loss,█▅▃▂▂▁▁▁▁▁

0,1
test_accuracy,0.8
train/epoch,10.0
train/example_ct,120000.0
train/train_loss,0.05432
val/val_accuracy,0.95
val/val_loss,0.15063


  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.436, Valid Loss: 0.369925, Accuracy: 0.90


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.442, Valid Loss: 0.299617, Accuracy: 0.91


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.393, Valid Loss: 0.271438, Accuracy: 0.92


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.364, Valid Loss: 0.253655, Accuracy: 0.93


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.297, Valid Loss: 0.242017, Accuracy: 0.93


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.348, Valid Loss: 0.227699, Accuracy: 0.93


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.306, Valid Loss: 0.224323, Accuracy: 0.93


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.479, Valid Loss: 0.219001, Accuracy: 0.93


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.413, Valid Loss: 0.210484, Accuracy: 0.94


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.315, Valid Loss: 0.206561, Accuracy: 0.94


VBox(children=(Label(value='0.173 MB of 0.232 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.743535…

0,1
train/epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/example_ct,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/train_loss,█▅▃▃▂▃▂▃▂▁▃▁▁▂▂▂▁▁▁▂▂▂▂▁▁▁▁▁▂▁▁▂▁▁▂▁▂▁▁▂
val/val_accuracy,▁▃▅▆▆▇▆▇██
val/val_loss,█▅▄▃▃▂▂▂▁▁

0,1
test_accuracy,0.8
train/epoch,10.0
train/example_ct,120000.0
train/train_loss,0.31477
val/val_accuracy,0.937
val/val_loss,0.20656


  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.286, Valid Loss: 0.299612, Accuracy: 0.91


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.342, Valid Loss: 0.240825, Accuracy: 0.93


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.241, Valid Loss: 0.214199, Accuracy: 0.93


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.185, Valid Loss: 0.200180, Accuracy: 0.94


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.216, Valid Loss: 0.184805, Accuracy: 0.94


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.156, Valid Loss: 0.178698, Accuracy: 0.94


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.069, Valid Loss: 0.168824, Accuracy: 0.95


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.118, Valid Loss: 0.155653, Accuracy: 0.95


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.079, Valid Loss: 0.154168, Accuracy: 0.95


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.108, Valid Loss: 0.158018, Accuracy: 0.95


VBox(children=(Label(value='0.173 MB of 0.233 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.742859…

0,1
train/epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/example_ct,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/train_loss,█▅▄▃▂▃▂▂▂▃▂▂▂▂▁▂▂▁▂▂▁▁▂▂▁▂▁▁▁▁▁▂▁▂▂▁▁▁▁▁
val/val_accuracy,▁▄▅▆▆▆████
val/val_loss,█▅▄▃▂▂▂▁▁▁

0,1
test_accuracy,0.8
train/epoch,10.0
train/example_ct,120000.0
train/train_loss,0.10794
val/val_accuracy,0.9505
val/val_loss,0.15802


  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.397, Valid Loss: 0.275592, Accuracy: 0.92


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.178, Valid Loss: 0.226667, Accuracy: 0.93


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.188, Valid Loss: 0.199590, Accuracy: 0.94


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.095, Valid Loss: 0.187467, Accuracy: 0.94


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.137, Valid Loss: 0.161976, Accuracy: 0.95


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.024, Valid Loss: 0.156230, Accuracy: 0.95


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.038, Valid Loss: 0.150993, Accuracy: 0.95


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.066, Valid Loss: 0.157093, Accuracy: 0.95


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.049, Valid Loss: 0.142579, Accuracy: 0.95


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.030, Valid Loss: 0.152913, Accuracy: 0.95


VBox(children=(Label(value='0.173 MB of 0.233 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.742611…

0,1
train/epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/example_ct,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/train_loss,█▅▄▄▃▂▂▃▂▂▂▁▂▂▂▂▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
val/val_accuracy,▁▂▄▄▇▆▇███
val/val_loss,█▅▄▃▂▂▁▂▁▂

0,1
test_accuracy,0.8
train/epoch,10.0
train/example_ct,120000.0
train/train_loss,0.03015
val/val_accuracy,0.954
val/val_loss,0.15291


  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.324, Valid Loss: 0.307040, Accuracy: 0.91


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.261, Valid Loss: 0.251734, Accuracy: 0.93


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.175, Valid Loss: 0.224374, Accuracy: 0.94


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.139, Valid Loss: 0.198968, Accuracy: 0.94


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.187, Valid Loss: 0.190828, Accuracy: 0.94


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.144, Valid Loss: 0.185762, Accuracy: 0.94


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.210, Valid Loss: 0.179441, Accuracy: 0.94


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.224, Valid Loss: 0.168305, Accuracy: 0.95


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.080, Valid Loss: 0.166681, Accuracy: 0.95


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.146, Valid Loss: 0.162068, Accuracy: 0.94


VBox(children=(Label(value='0.173 MB of 0.233 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.742779…

0,1
train/epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/example_ct,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/train_loss,█▅▄▃▂▂▂▃▂▂▂▂▂▂▁▂▁▂▂▂▂▁▁▁▁▂▁▁▂▂▁▁▁▁▂▁▁▁▁▁
val/val_accuracy,▁▄▆▆▆▆▇██▇
val/val_loss,█▅▄▃▂▂▂▁▁▁

0,1
test_accuracy,0.8
train/epoch,10.0
train/example_ct,120000.0
train/train_loss,0.14596
val/val_accuracy,0.9445
val/val_loss,0.16207
