This example is taken from https://colab.research.google.com/github/wandb/examples/
It shows how Weights & Biases (W&B) can be used for machine learning experiment tracking, model checkpointing, and collaboration with your team. Here we focus on training a PyTorch neural network and tracking its training runs with W&B, after setting up a W&B account.

<img src="http://wandb.me/logo-im-png" width="400" alt="Weights & Biases" />

## 🪄 Install the `wandb` library and log in


Start by installing the library and logging in to your free account.



In [2]:
!pip install wandb -qU


In [7]:
# Log in to your W&B account
import wandb
wandb.login()  # prompts for your API key; alternatively pass key="<your-api-key>"

[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

## 👟 Run an experiment
1️⃣. **Start a new run** and pass in hyperparameters to track

2️⃣. **Log metrics** from training or evaluation

3️⃣. **Visualize results** in the dashboard

In [8]:
import random

# Launch 10 simulated experiments
total_runs = 10
for run in range(total_runs):
    # 🐝 1️⃣ Start a new run to track this script
    wandb.init(
        # Set the project where this run will be logged
        project="basic-intro",
        # Pass a run name (otherwise one is randomly assigned, like sunshine-lollypop-10)
        name=f"experiment_{run}",
        # Track hyperparameters and run metadata
        config={
            "learning_rate": 0.02,
            "architecture": "CNN",
            "dataset": "CIFAR-10",
            "epochs": 20,
        })

    # This simple block simulates a training loop logging metrics
    epochs = wandb.config.epochs  # read back from the tracked config
    offset = random.random() / 5
    for epoch in range(2, epochs):
        acc = 1 - 2 ** -epoch - random.random() / epoch - offset
        loss = 2 ** -epoch + random.random() / epoch + offset

        # 🐝 2️⃣ Log metrics from your script to W&B
        wandb.log({"acc": acc, "loss": loss})

    # Mark the run as finished
    wandb.finish()

| Run | `acc` history | `loss` history | Final `acc` | Final `loss` |
|---|---|---|---|---|
| `experiment_0` | ▁▆▄▅▆█▇▇▇███▇█▇▇██ | █▃▃▂▁▁▂▂▂▁▂▂▂▂▁▁▁▁ | 0.95993 | 0.03696 |
| `experiment_1` | ▁▆▄▆▇▇▇▇▇▇██▇███▇█ | ██▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁ | 0.86254 | 0.09406 |
| `experiment_2` | ▁▆▅▆▇▇▇▇█▇▇█▇█████ | █▃▃▄▁▁▂▂▁▁▂▂▁▁▁▂▂▁ | 0.78601 | 0.20111 |
| `experiment_3` | ▁▆▅████▇█▇█████▇██ | █▇▃▃▁▁▂▂▂▂▁▂▁▁▂▁▂▁ | 0.84449 | 0.13137 |
| `experiment_4` | ▁▃▆▆▆▇▇▇██▇▇█▇██▇█ | █▆▃▂▂▃▂▁▁▁▁▁▂▁▁▁▁▁ | 0.8309 | 0.18071 |
| `experiment_5` | ▁▆▇▆▇▇▇▇▇████▇████ | █▇▃▂▃▁▁▁▃▂▁▂▂▁▁▁▁▂ | 0.97506 | 0.05058 |
| `experiment_6` | ▁▄▅█▇▇▇█▇█▇███████ | ▇█▃▃▃▂▁▁▁▂▂▁▂▁▂▂▁▁ | 0.97775 | 0.00957 |
| `experiment_7` | ▂▁▇▇▇▆█▇████▇▇▇▇▇█ | ▆█▅▃▃▁▂▁▁▁▁▁▂▂▂▁▁▁ | 0.96468 | 0.03202 |
| `experiment_8` | ▁▃▃▄▇▄▆▅▇█▆▇█▇▇▇██ | █▃▃▃▃▂▂▂▁▂▂▂▂▁▂▁▁▁ | 0.79705 | 0.19784 |
| `experiment_9` | ▁▃▇▇█▇▇▇██▇▇█▇█▇█▇ | █▇▅▂▁▁▂▂▁▁▂▁▁▁▁▁▁▂ | 0.77502 | 0.23179 |


3️⃣ You can find your interactive dashboard by clicking any of the  👆 wandb links above.
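
The simulated curves come from a simple formula: at each epoch, accuracy subtracts $2^{-\text{epoch}}$, a uniform noise term scaled by $1/\text{epoch}$, and a per-run `offset` drawn once from $[0, 0.2)$, while loss adds the mirror-image terms. The plain-Python sketch below replays that loop without `wandb`; the `simulate_run` helper and its `seed` argument are hypothetical additions for reproducibility. It lets you check two properties visible in the table above: every run logs 18 points (epochs 2 to 19), and the final accuracy always lies between roughly 0.747 and $1 - 2^{-19}$, with `acc + loss` staying within $1/19$ of 1.

```python
import random

def simulate_run(seed, epochs=20):
    """Replay the simulated training loop from the cell above without logging to W&B.

    `seed` is a hypothetical addition so results are reproducible.
    Returns a list of (acc, loss) pairs, one per logged step.
    """
    rng = random.Random(seed)
    offset = rng.random() / 5           # per-run offset in [0, 0.2)
    history = []
    for epoch in range(2, epochs):      # epochs 2..19 -> 18 logged steps
        acc = 1 - 2 ** -epoch - rng.random() / epoch - offset
        loss = 2 ** -epoch + rng.random() / epoch + offset
        history.append((acc, loss))
    return history

for seed in range(3):
    history = simulate_run(seed)
    final_acc, final_loss = history[-1]
    print(f"seed={seed}: {len(history)} steps, final acc={final_acc:.5f}, final loss={final_loss:.5f}")
```

Because the two noise draws are independent, `acc + loss` at the last step equals $1 + (r_2 - r_1)/19$, which is why the final pairs in the table sum to nearly 1.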

# 🔥 Simple PyTorch Neural Network

💪 Run this model to train a simple MNIST classifier, and click on the project page link to see your results stream in live to a W&B project.
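
MNIST digits are 1×28×28 grayscale images, so after `nn.Flatten()` each example is a 784-dimensional vector and the first `nn.Linear` needs `in_features=784`. To give a sense of the model's size, this plain-Python sketch counts the trainable parameters of the MLP built by `get_model`: two `nn.Linear` layers with bias, plus the per-feature scale and shift of `nn.BatchNorm1d(256)` (`nn.ReLU` and `nn.Dropout` have none). The helper names are hypothetical; on a real model you can cross-check with `sum(p.numel() for p in model.parameters())`.

```python
def linear_params(in_features, out_features, bias=True):
    """Trainable parameters of an nn.Linear layer: weight matrix plus optional bias."""
    return in_features * out_features + (out_features if bias else 0)

def batchnorm1d_params(num_features):
    """nn.BatchNorm1d (affine=True by default) learns one scale and one shift per feature.
    Its running mean/variance are buffers, not trainable parameters."""
    return 2 * num_features

input_dim = 28 * 28          # MNIST images are 1 x 28 x 28
hidden, classes = 256, 10

total = (linear_params(input_dim, hidden)
         + batchnorm1d_params(hidden)
         + linear_params(hidden, classes))
print(input_dim, total)      # 784 204042
```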


Each `wandb` run captures the [metrics](https://docs.wandb.ai/ref/app/pages/run-page#charts-tab) you log,
[system information](https://docs.wandb.ai/ref/app/pages/run-page#system-tab),
[hyperparameters](https://docs.wandb.ai/ref/app/pages/run-page#overview-tab), and
[terminal output](https://docs.wandb.ai/ref/app/pages/run-page#logs-tab); in this example
you'll also see an [interactive table](https://docs.wandb.ai/guides/data-vis)
with model inputs and outputs.

## Set up Dataloader
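
`get_dataloader` wraps MNIST in `torch.utils.data.Subset` with `indices=range(0, len(full_dataset), slice)`, so with the default `slice=5` it keeps every fifth image. MNIST has 60,000 training and 10,000 test images; the sketch below works out the resulting sizes with hard-coded numbers (plain Python, not a call into torchvision). The same arithmetic explains why each run summary reports `train/example_ct` = 240000 after 20 epochs.

```python
import math

MNIST_TRAIN, MNIST_TEST = 60_000, 10_000   # standard MNIST split sizes
slice_step = 5                             # default `slice` in get_dataloader
batch_size, epochs = 64, 20                # values from the training config

train_size = len(range(0, MNIST_TRAIN, slice_step))   # every 5th training image
valid_size = len(range(0, MNIST_TEST, slice_step))    # every 5th test image
steps_per_epoch = math.ceil(train_size / batch_size)  # the last batch is partial
last_batch = train_size - (steps_per_epoch - 1) * batch_size
examples_seen = train_size * epochs                   # final train/example_ct

print(train_size, valid_size, steps_per_epoch, last_batch, examples_seen)
```

This is also why the training cell computes `n_steps_per_epoch` with `math.ceil`: the final batch of each epoch holds only 32 images.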

In [10]:
import wandb
import math
import random
import torch, torchvision
import torch.nn as nn
import torchvision.transforms as T

device = "cuda:0" if torch.cuda.is_available() else "cpu"

def get_dataloader(is_train, batch_size, slice=5):
    "Get a training or validation dataloader over every `slice`-th MNIST image"
    full_dataset = torchvision.datasets.MNIST(root=".", train=is_train, transform=T.ToTensor(), download=True)
    sub_dataset = torch.utils.data.Subset(full_dataset, indices=range(0, len(full_dataset), slice))
    loader = torch.utils.data.DataLoader(dataset=sub_dataset,
                                         batch_size=batch_size,
                                         shuffle=is_train,  # shuffle only the training data
                                         pin_memory=torch.cuda.is_available(),
                                         num_workers=2)
    return loader

def get_model(dropout):
    "A simple MLP: flattened 28x28 MNIST image -> 256 hidden units -> 10 class logits"
    model = nn.Sequential(nn.Flatten(),
                          nn.Linear(28*28, 256),
                          nn.BatchNorm1d(256),
                          nn.ReLU(),
                          nn.Dropout(dropout),
                          nn.Linear(256, 10)).to(device)
    return model

def validate_model(model, valid_dl, loss_func, log_images=False, batch_idx=0):
    "Compute performance of the model on the validation dataset and log a wandb.Table"
    model.eval()
    val_loss = 0.
    correct = 0
    with torch.inference_mode():
        for i, (images, labels) in enumerate(valid_dl):
            images, labels = images.to(device), labels.to(device)

            # Forward pass ➡
            outputs = model(images)
            # Weight each batch's mean loss by its size so the final average is per example
            val_loss += loss_func(outputs, labels).item() * labels.size(0)

            # Compute accuracy and accumulate
            predicted = outputs.argmax(dim=1)
            correct += (predicted == labels).sum().item()

            # Log one batch of images to the dashboard, always the same batch_idx
            if i == batch_idx and log_images:
                log_image_table(images, predicted, labels, outputs.softmax(dim=1))
    return val_loss / len(valid_dl.dataset), correct / len(valid_dl.dataset)

def log_image_table(images, predicted, labels, probs):
    "Log a wandb.Table with (img, pred, target, scores)"
    # 🐝 Create a wandb Table to log images, labels and predictions to
    table = wandb.Table(columns=["image", "pred", "target"]+[f"score_{i}" for i in range(10)])
    for img, pred, targ, prob in zip(images.to("cpu"), predicted.to("cpu"), labels.to("cpu"), probs.to("cpu")):
        table.add_data(wandb.Image(img[0].numpy()*255), pred.item(), targ.item(), *prob.numpy())
    # commit=False attaches the table to the next wandb.log call instead of starting a new step
    wandb.log({"predictions_table": table}, commit=False)
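
`validate_model` multiplies each batch's mean loss by `labels.size(0)` and divides the total by `len(valid_dl.dataset)`. That gives the exact per-example mean even when the last batch is smaller than the others; simply averaging the per-batch means would over-weight the short batch. A minimal illustration in plain Python, with made-up batch losses (the numbers and helper names are hypothetical, chosen only to show the effect):

```python
import math

def size_weighted_mean(batch_means, batch_sizes):
    """Per-example mean, computed the way validate_model does: sum(mean * size) / total size."""
    total = sum(m * n for m, n in zip(batch_means, batch_sizes))
    return total / sum(batch_sizes)

def mean_of_means(batch_means):
    """Naive alternative: every batch counts equally, whatever its size."""
    return sum(batch_means) / len(batch_means)

# Hypothetical per-example losses in batches of 4, 4 and 2 (the last batch is partial).
batches = [[0.2, 0.2, 0.2, 0.2], [0.4, 0.4, 0.4, 0.4], [1.0, 1.0]]
batch_means = [sum(b) / len(b) for b in batches]
batch_sizes = [len(b) for b in batches]
exact = sum(x for b in batches for x in b) / sum(batch_sizes)

weighted = size_weighted_mean(batch_means, batch_sizes)
naive = mean_of_means(batch_means)
print(f"exact={exact:.4f} weighted={weighted:.4f} naive={naive:.4f}")
print(math.isclose(weighted, exact))
```

Here the size-weighted result matches the true mean of 0.44, while the naive average comes out at about 0.533 because the two-example batch counts as much as each four-example batch.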

## Train Your Model
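
In the training cell below, every step except the last one in an epoch is sent with `wandb.log(metrics)`; the last step's train metrics are logged together with the validation metrics in a single `wandb.log({**metrics, **val_metrics})` call. The `train/epoch` value, `(step + 1 + n_steps_per_epoch * epoch) / n_steps_per_epoch`, counts completed batches in units of epochs. This sketch replays only that bookkeeping in plain Python (no model, no `wandb`; `replay_logging` is a hypothetical helper), using the 188 steps per epoch that this configuration produces.

```python
def replay_logging(n_steps_per_epoch, epochs):
    """Return the (train/epoch, has_validation_metrics) pairs the training cell logs.

    Mirrors the control flow only: steps before the last one in each epoch are logged
    on their own, and the last step is logged together with the validation metrics.
    """
    logged = []
    for epoch in range(epochs):
        for step in range(n_steps_per_epoch):
            frac_epoch = (step + 1 + n_steps_per_epoch * epoch) / n_steps_per_epoch
            if step + 1 < n_steps_per_epoch:
                logged.append((frac_epoch, False))
        # After the inner loop, `metrics` still holds the last step's values
        logged.append((frac_epoch, True))
    return logged

log = replay_logging(n_steps_per_epoch=188, epochs=20)
print(len(log), log[0][0], log[-1])
```

Each W&B step therefore carries at most one validation point, the validation points land exactly on whole epochs, and the last logged `train/epoch` is 20.0, matching the run summaries.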

In [9]:
# Launch 5 experiments, trying different dropout rates
for _ in range(5):
    # 🐝 initialise a wandb run
    wandb.init(
        project="pytorch-intro",
        config={
            "epochs": 20,
            "batch_size": 64,
            "lr": 1e-3,
            "dropout": random.uniform(0.02, 0.90),
            })

    # Copy your config
    config = wandb.config

    # Get the data
    train_dl = get_dataloader(is_train=True, batch_size=config.batch_size)
    valid_dl = get_dataloader(is_train=False, batch_size=2*config.batch_size)
    n_steps_per_epoch = math.ceil(len(train_dl.dataset) / config.batch_size)

    # A simple MLP model
    model = get_model(config.dropout)

    # Make the loss and optimizer
    loss_func = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)

    # Training
    example_ct = 0
    step_ct = 0
    for epoch in range(config.epochs):
        model.train()
        for step, (images, labels) in enumerate(train_dl):
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            train_loss = loss_func(outputs, labels)
            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()

            example_ct += len(images)
            metrics = {"train/train_loss": train_loss.item(),
                       "train/epoch": (step + 1 + (n_steps_per_epoch * epoch)) / n_steps_per_epoch,
                       "train/example_ct": example_ct}

            if step + 1 < n_steps_per_epoch:
                # 🐝 Log train metrics to wandb
                wandb.log(metrics)

            step_ct += 1

        val_loss, accuracy = validate_model(model, valid_dl, loss_func, log_images=(epoch==(config.epochs-1)))

        # 🐝 Log the last step's train metrics together with the validation metrics
        val_metrics = {"val/val_loss": val_loss,
                       "val/val_accuracy": accuracy}
        wandb.log({**metrics, **val_metrics})

        print(f"Train Loss: {train_loss:.3f}, Valid Loss: {val_loss:.3f}, Accuracy: {accuracy:.2f}")

    # If you had a test set, this is how you could log it as a Summary metric
    # (0.8 is a placeholder, not a measured result)
    wandb.summary['test_accuracy'] = 0.8

    # 🐝 Close your wandb run
    wandb.finish()

Train Loss: 0.538, Valid Loss: 0.345946, Accuracy: 0.90
Train Loss: 0.683, Valid Loss: 0.286832, Accuracy: 0.91
Train Loss: 0.471, Valid Loss: 0.259266, Accuracy: 0.92
Train Loss: 0.412, Valid Loss: 0.243488, Accuracy: 0.93
Train Loss: 0.289, Valid Loss: 0.237428, Accuracy: 0.93
Train Loss: 0.345, Valid Loss: 0.227471, Accuracy: 0.93
Train Loss: 0.318, Valid Loss: 0.220237, Accuracy: 0.93
Train Loss: 0.284, Valid Loss: 0.210837, Accuracy: 0.94
Train Loss: 0.657, Valid Loss: 0.206419, Accuracy: 0.94
Train Loss: 0.130, Valid Loss: 0.201418, Accuracy: 0.94
Train Loss: 0.099, Valid Loss: 0.200119, Accuracy: 0.94
Train Loss: 0.256, Valid Loss: 0.192761, Accuracy: 0.94
Train Loss: 0.347, Valid Loss: 0.191222, Accuracy: 0.94
Train Loss: 0.116, Valid Loss: 0.196548, Accuracy: 0.94
Train Loss: 0.157, Valid Loss: 0.183463, Accuracy: 0.95
Train Loss: 0.103, Valid Loss: 0.186012, Accuracy: 0.94
Train Loss: 0.185, Valid Loss: 0.181211, Accuracy: 0.94
Train Loss: 0.209, Valid Loss: 0.183602, Accurac

| Metric | History | Final value |
|---|---|---|
| `train/epoch` | ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███ | 20.0 |
| `train/example_ct` | ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███ | 240000.0 |
| `train/train_loss` | █▅▅▄▃▄▆▄▃▄▅▂▃▄▄▅▃▂▃▂▃▂▂▁▄▁▁▃▃▂▁▁▂▂▂▂▃▂▁▁ | 0.28726 |
| `val/val_accuracy` | ▁▂▄▅▄▅▆▆▆▇▆▇▆▇█▇▇█▇▇ | 0.9405 |
| `val/val_loss` | █▆▄▄▄▃▃▂▂▂▂▂▂▂▁▁▁▁▁▁ | 0.17682 |
| `test_accuracy` | | 0.8 |


Train Loss: 0.330, Valid Loss: 0.262032, Accuracy: 0.93
Train Loss: 0.307, Valid Loss: 0.219681, Accuracy: 0.93
Train Loss: 0.153, Valid Loss: 0.199535, Accuracy: 0.94
Train Loss: 0.109, Valid Loss: 0.171429, Accuracy: 0.95
Train Loss: 0.135, Valid Loss: 0.175242, Accuracy: 0.95
Train Loss: 0.076, Valid Loss: 0.160210, Accuracy: 0.95
Train Loss: 0.025, Valid Loss: 0.156895, Accuracy: 0.95
Train Loss: 0.012, Valid Loss: 0.146056, Accuracy: 0.96
Train Loss: 0.051, Valid Loss: 0.148607, Accuracy: 0.95
Train Loss: 0.141, Valid Loss: 0.149910, Accuracy: 0.96
Train Loss: 0.040, Valid Loss: 0.150638, Accuracy: 0.96
Train Loss: 0.024, Valid Loss: 0.155221, Accuracy: 0.95
Train Loss: 0.028, Valid Loss: 0.148495, Accuracy: 0.95
Train Loss: 0.057, Valid Loss: 0.161325, Accuracy: 0.96
Train Loss: 0.017, Valid Loss: 0.167858, Accuracy: 0.95
Train Loss: 0.042, Valid Loss: 0.157609, Accuracy: 0.95
Train Loss: 0.021, Valid Loss: 0.153134, Accuracy: 0.96
Train Loss: 0.022, Valid Loss: 0.160536, Accurac

| Metric | History | Final value |
|---|---|---|
| `train/epoch` | ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███ | 20.0 |
| `train/example_ct` | ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███ | 240000.0 |
| `train/train_loss` | █▄▅▄▃▃▂▂▂▂▂▂▁▂▂▁▁▂▂▂▁▂▁▂▂▁▁▁▂▁▁▁▁▂▂▁▁▂▂▂ | 0.16783 |
| `val/val_accuracy` | ▁▃▄▅▆▆▇█▇▇▇▇▇█▇▇███▇ | 0.955 |
| `val/val_loss` | █▅▄▃▃▂▂▁▁▁▁▂▁▂▂▂▁▂▂▂ | 0.16238 |
| `test_accuracy` | | 0.8 |


Train Loss: 0.261, Valid Loss: 0.269500, Accuracy: 0.92
Train Loss: 0.309, Valid Loss: 0.216647, Accuracy: 0.94
Train Loss: 0.277, Valid Loss: 0.190791, Accuracy: 0.94
Train Loss: 0.285, Valid Loss: 0.179711, Accuracy: 0.94
Train Loss: 0.093, Valid Loss: 0.168203, Accuracy: 0.95
Train Loss: 0.077, Valid Loss: 0.155528, Accuracy: 0.95
Train Loss: 0.044, Valid Loss: 0.153432, Accuracy: 0.95
Train Loss: 0.225, Valid Loss: 0.146848, Accuracy: 0.95
Train Loss: 0.064, Valid Loss: 0.158515, Accuracy: 0.95
Train Loss: 0.313, Valid Loss: 0.151116, Accuracy: 0.95
Train Loss: 0.046, Valid Loss: 0.136666, Accuracy: 0.96
Train Loss: 0.008, Valid Loss: 0.144247, Accuracy: 0.96
Train Loss: 0.240, Valid Loss: 0.133902, Accuracy: 0.96
Train Loss: 0.105, Valid Loss: 0.148344, Accuracy: 0.96
Train Loss: 0.018, Valid Loss: 0.150337, Accuracy: 0.95
Train Loss: 0.085, Valid Loss: 0.152747, Accuracy: 0.96
Train Loss: 0.145, Valid Loss: 0.147039, Accuracy: 0.96
Train Loss: 0.058, Valid Loss: 0.145244, Accurac

| Metric | History | Final value |
|---|---|---|
| `train/epoch` | ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███ | 20.0 |
| `train/example_ct` | ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███ | 240000.0 |
| `train/train_loss` | █▄▃▅▂▂▂▃▃▂▂▂▂▂▂▁▂▂▁▂▂▂▁▁▁▃▂▁▁▁▁▁▁▂▁▂▁▁▁▁ | 0.05924 |
| `val/val_accuracy` | ▁▃▄▅▆▇▆▇▇▇████▇▇███▇ | 0.952 |
| `val/val_loss` | █▅▄▃▃▂▂▂▂▂▁▂▁▂▂▂▂▂▁▂ | 0.15092 |
| `test_accuracy` | | 0.8 |


Train Loss: 0.119, Valid Loss: 0.245016, Accuracy: 0.93
Train Loss: 0.409, Valid Loss: 0.212966, Accuracy: 0.93
Train Loss: 0.108, Valid Loss: 0.190256, Accuracy: 0.94
Train Loss: 0.317, Valid Loss: 0.174174, Accuracy: 0.94
Train Loss: 0.180, Valid Loss: 0.165155, Accuracy: 0.95
Train Loss: 0.207, Valid Loss: 0.167066, Accuracy: 0.95
Train Loss: 0.036, Valid Loss: 0.154895, Accuracy: 0.95
Train Loss: 0.060, Valid Loss: 0.148212, Accuracy: 0.95
Train Loss: 0.059, Valid Loss: 0.149658, Accuracy: 0.95
Train Loss: 0.073, Valid Loss: 0.149387, Accuracy: 0.95
Train Loss: 0.020, Valid Loss: 0.137227, Accuracy: 0.96
Train Loss: 0.011, Valid Loss: 0.150689, Accuracy: 0.96
Train Loss: 0.080, Valid Loss: 0.144276, Accuracy: 0.95
Train Loss: 0.071, Valid Loss: 0.158557, Accuracy: 0.95
Train Loss: 0.047, Valid Loss: 0.149366, Accuracy: 0.96
Train Loss: 0.037, Valid Loss: 0.150438, Accuracy: 0.95
Train Loss: 0.039, Valid Loss: 0.157798, Accuracy: 0.96
Train Loss: 0.077, Valid Loss: 0.161429, Accurac

| Metric | History | Final value |
|---|---|---|
| `train/epoch` | ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███ | 20.0 |
| `train/example_ct` | ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███ | 240000.0 |
| `train/train_loss` | █▄▂▃▂▄▃▃▂▂▁▂▁▂▂▁▁▂▁▂▁▁▁▂▁▁▁▂▁▁▁▁▁▂▁▁▁▁▁▂ | 0.00665 |
| `val/val_accuracy` | ▁▂▃▄▆▅▇▆▆▇█▇▇▆█▇▇▇▇▇ | 0.953 |
| `val/val_loss` | █▆▄▃▃▃▂▂▂▂▁▂▁▂▂▂▂▃▂▃ | 0.16148 |
| `test_accuracy` | | 0.8 |


Train Loss: 0.321, Valid Loss: 0.264122, Accuracy: 0.92
Train Loss: 0.179, Valid Loss: 0.212144, Accuracy: 0.94
Train Loss: 0.154, Valid Loss: 0.186276, Accuracy: 0.94
Train Loss: 0.098, Valid Loss: 0.169114, Accuracy: 0.95
Train Loss: 0.058, Valid Loss: 0.171764, Accuracy: 0.95
Train Loss: 0.102, Valid Loss: 0.158962, Accuracy: 0.95
Train Loss: 0.047, Valid Loss: 0.158981, Accuracy: 0.95
Train Loss: 0.078, Valid Loss: 0.158298, Accuracy: 0.95
Train Loss: 0.124, Valid Loss: 0.135745, Accuracy: 0.96
Train Loss: 0.113, Valid Loss: 0.140872, Accuracy: 0.95
Train Loss: 0.125, Valid Loss: 0.147893, Accuracy: 0.95
Train Loss: 0.026, Valid Loss: 0.145661, Accuracy: 0.95
Train Loss: 0.048, Valid Loss: 0.138221, Accuracy: 0.96
Train Loss: 0.151, Valid Loss: 0.144991, Accuracy: 0.96
Train Loss: 0.147, Valid Loss: 0.141671, Accuracy: 0.95
Train Loss: 0.080, Valid Loss: 0.146887, Accuracy: 0.96
Train Loss: 0.051, Valid Loss: 0.138932, Accuracy: 0.96
Train Loss: 0.022, Valid Loss: 0.144466, Accurac

| Metric | History | Final value |
|---|---|---|
| `train/epoch` | ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███ | 20.0 |
| `train/example_ct` | ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███ | 240000.0 |
| `train/train_loss` | █▄▅▃▃▄▂▃▁▃▂▂▁▂▂▁▂▂▂▁▂▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▂▁ | 0.12063 |
| `val/val_accuracy` | ▁▄▅▆▆▆▆▇▇▇▇▇▇▇▇████▇ | 0.9555 |
| `val/val_loss` | █▅▄▃▃▂▂▂▁▁▂▂▁▂▁▂▁▁▁▂ | 0.1489 |
| `test_accuracy` | | 0.8 |


You have now trained your first model using wandb! 👆 Click on the wandb links above to see your metrics.