In [1]:
!pip install wandb -qU

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m31.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m184.3/184.3 kB[0m [31m17.5 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m206.5/206.5 kB[0m [31m23.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.7/62.7 kB[0m [31m6.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Building wheel for pathtools (setup.py) ... [?25l[?25hdone


In [3]:
import wandb
# Authenticate this session with Weights & Biases before any run is created
# (the recorded output of this cell shows it prompting for an API key).
wandb.login()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [4]:
import math
import random
import torch, torchvision
import torch.nn as nn
import torchvision.transforms as T

# Use the first CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

def get_dataloader(is_train, batch_size, slice=5):
    """Build a DataLoader over an evenly strided subset of MNIST.

    Args:
        is_train: Selects the MNIST training split when truthy and the test
            split otherwise; it also decides whether batches are shuffled.
        batch_size: Number of samples per batch.
        slice: Stride used to subsample the dataset (every `slice`-th example
            is kept, starting at index 0).

    Returns:
        A `torch.utils.data.DataLoader` over the subsampled dataset.
    """
    mnist = torchvision.datasets.MNIST(
        root=".", train=is_train, transform=T.ToTensor(), download=True
    )
    # Keep indices 0, slice, 2*slice, ... so experiments run on a smaller dataset.
    kept_indices = range(0, len(mnist), slice)
    subset = torch.utils.data.Subset(mnist, indices=kept_indices)
    loader_options = dict(
        batch_size=batch_size,
        shuffle=bool(is_train),  # only the training split is shuffled
        pin_memory=True,
        num_workers=2,
    )
    return torch.utils.data.DataLoader(dataset=subset, **loader_options)

def get_model(dropout):
    """Create the MLP classifier and move it to the module-level `device`.

    The network flattens its input, applies a 784 -> 256 linear layer followed
    by batch norm, ReLU and dropout, and ends with a 256 -> 10 linear layer.

    Args:
        dropout: Drop probability passed to the `nn.Dropout` layer.

    Returns:
        The `nn.Sequential` model, already placed on `device`.
    """
    hidden_units = 256
    layers = [
        nn.Flatten(),
        nn.Linear(28 * 28, hidden_units),
        nn.BatchNorm1d(hidden_units),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(hidden_units, 10),
    ]
    network = nn.Sequential(*layers)
    return network.to(device)

def validate_model(model, valid_dl, loss_func, log_images=False, batch_idx=0):
    """Evaluate the model on the validation dataloader.

    Args:
        model: Network to evaluate; it is switched to eval mode and left there.
        valid_dl: Iterable of `(images, labels)` batches exposing a `.dataset`
            attribute whose length is used to normalise the results.
        loss_func: Loss callable applied to `(outputs, labels)`. Each batch loss
            is weighted by the batch size, which assumes a mean-reduced loss.
        log_images: When True, log one batch of predictions as a wandb.Table.
        batch_idx: Index of the batch that is logged when `log_images` is set.

    Returns:
        `(val_loss, accuracy)` as plain Python floats: the loss averaged per
        sample and the fraction of correctly classified samples.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    with torch.inference_mode():
        for i, (images, labels) in enumerate(valid_dl):
            images, labels = images.to(device), labels.to(device)

            # Forward pass ➡
            outputs = model(images)
            # `.item()` keeps the accumulator a Python float, so the return type
            # does not depend on whether any batch was processed.
            total_loss += loss_func(outputs, labels).item() * labels.size(0)

            # Compute accuracy and accumulate
            predicted = outputs.argmax(dim=1)
            correct += (predicted == labels).sum().item()

            # Log one batch of images to the dashboard, always same batch_idx.
            if log_images and i == batch_idx:
                log_image_table(images, predicted, labels, outputs.softmax(dim=1))

    n_samples = len(valid_dl.dataset)
    return total_loss / n_samples, correct / n_samples

def log_image_table(images, predicted, labels, probs):
    """Log a wandb.Table with one row per sample: (img, pred, target, scores).

    Args:
        images: Batch of images; only the first channel of each image is logged.
        predicted: Predicted class per sample.
        labels: Target class per sample.
        probs: Per-class scores, one row per sample; the number of `score_*`
            columns is taken from its second dimension.
    """
    # Derive the score columns from the data so the table also fits models
    # whose number of classes is not 10.
    n_classes = probs.shape[1]
    score_columns = [f"score_{i}" for i in range(n_classes)]

    # 🐝 Create a wandb Table to log images, labels and predictions to
    table = wandb.Table(columns=["image", "pred", "target"] + score_columns)
    for img, pred, targ, prob in zip(images.to("cpu"), predicted.to("cpu"), labels.to("cpu"), probs.to("cpu")):
        # img[0] selects the first channel; values are scaled by 255 before logging.
        table.add_data(wandb.Image(img[0].numpy() * 255), pred, targ, *prob.numpy())
    # commit=False: do not close the current step here; the caller's next
    # wandb.log call is expected to commit it.
    wandb.log({"predictions_table": table}, commit=False)

In [5]:
# Launch several experiments, each with a randomly drawn dropout rate
N_RUNS = 5
for _ in range(N_RUNS):
    # 🐝 initialise a wandb run. Using it as a context manager closes the run
    # even when training raises, so a failed experiment is not left open.
    with wandb.init(
        project="pytorch-intro",
        config={
            "epochs": 10,
            "batch_size": 128,
            "lr": 1e-3,
            "dropout": random.uniform(0.01, 0.80),
            }) as run:

        # Copy your config
        config = run.config

        # Get the data
        train_dl = get_dataloader(is_train=True, batch_size=config.batch_size)
        valid_dl = get_dataloader(is_train=False, batch_size=2*config.batch_size)
        n_steps_per_epoch = math.ceil(len(train_dl.dataset) / config.batch_size)

        # A simple MLP model
        model = get_model(config.dropout)

        # Make the loss and optimizer
        loss_func = nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)

        # Training
        example_ct = 0
        for epoch in range(config.epochs):
            # validate_model leaves the model in eval mode, so switch back every epoch
            model.train()
            for step, (images, labels) in enumerate(train_dl):
                images, labels = images.to(device), labels.to(device)

                outputs = model(images)
                train_loss = loss_func(outputs, labels)
                optimizer.zero_grad()
                train_loss.backward()
                optimizer.step()

                example_ct += len(images)
                # Log a plain float so the dict does not hold on to the autograd graph
                metrics = {"train/train_loss": train_loss.item(),
                           "train/epoch": (step + 1 + (n_steps_per_epoch * epoch)) / n_steps_per_epoch,
                           "train/example_ct": example_ct}

                # The last step of an epoch is not logged here: it is logged
                # below together with the validation metrics.
                if step + 1 < n_steps_per_epoch:
                    # 🐝 Log train metrics to wandb
                    run.log(metrics)

            # Only the final epoch logs the table of example predictions
            val_loss, accuracy = validate_model(model, valid_dl, loss_func, log_images=(epoch==(config.epochs-1)))

            # 🐝 Log train and validation metrics to wandb
            val_metrics = {"val/val_loss": val_loss,
                           "val/val_accuracy": accuracy}
            run.log({**metrics, **val_metrics})

            print(f"Train Loss: {train_loss.item():.3f}, Valid Loss: {val_loss:.3f}, Accuracy: {accuracy:.2f}")

        # If you had a test set, this is how you could log it as a Summary metric
        # (0.8 is a placeholder value, not a measured accuracy)
        run.summary['test_accuracy'] = 0.8

    # 🐝 The wandb run is closed automatically when the `with` block exits

[34m[1mwandb[0m: Currently logged in as: [33marthur-v-qin[0m. Use [1m`wandb login --relogin`[0m to force relogin


Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:00<00:00, 160891111.55it/s]


Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 18508127.40it/s]


Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:00<00:00, 43223780.93it/s]

Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw






Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 3066236.72it/s]


Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw

Train Loss: 0.521, Valid Loss: 0.347669, Accuracy: 0.91
Train Loss: 0.399, Valid Loss: 0.282886, Accuracy: 0.92
Train Loss: 0.335, Valid Loss: 0.259039, Accuracy: 0.92
Train Loss: 0.201, Valid Loss: 0.241721, Accuracy: 0.92
Train Loss: 0.212, Valid Loss: 0.232045, Accuracy: 0.93
Train Loss: 0.335, Valid Loss: 0.221272, Accuracy: 0.93
Train Loss: 0.349, Valid Loss: 0.208346, Accuracy: 0.93
Train Loss: 0.236, Valid Loss: 0.202134, Accuracy: 0.94
Train Loss: 0.213, Valid Loss: 0.199120, Accuracy: 0.94
Train Loss: 0.239, Valid Loss: 0.190516, Accuracy: 0.94


0,1
train/epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/example_ct,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/train_loss,█▅▃▃▃▃▂▂▂▂▂▂▁▂▁▂▂▂▂▂▁▂▁▃▂▂▂▂▁▁▂▁▂▁▁▂▁▁▁▁
val/val_accuracy,▁▃▄▄▅▆▇▇██
val/val_loss,█▅▄▃▃▂▂▂▁▁

0,1
test_accuracy,0.8
train/epoch,10.0
train/example_ct,120000.0
train/train_loss,0.2389
val/val_accuracy,0.9395
val/val_loss,0.19052


Train Loss: 0.532, Valid Loss: 0.329610, Accuracy: 0.92
Train Loss: 0.279, Valid Loss: 0.264059, Accuracy: 0.92
Train Loss: 0.211, Valid Loss: 0.244534, Accuracy: 0.92
Train Loss: 0.285, Valid Loss: 0.224124, Accuracy: 0.93
Train Loss: 0.217, Valid Loss: 0.212275, Accuracy: 0.94
Train Loss: 0.311, Valid Loss: 0.201968, Accuracy: 0.94
Train Loss: 0.232, Valid Loss: 0.190789, Accuracy: 0.94
Train Loss: 0.230, Valid Loss: 0.191437, Accuracy: 0.94
Train Loss: 0.171, Valid Loss: 0.186410, Accuracy: 0.94
Train Loss: 0.121, Valid Loss: 0.177122, Accuracy: 0.95


0,1
train/epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/example_ct,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/train_loss,█▅▃▄▃▂▃▃▃▃▃▃▁▂▃▂▂▂▃▂▂▂▂▁▂▂▂▂▁▂▂▂▁▁▁▁▁▁▁▂
val/val_accuracy,▁▂▃▄▅▆▅▆▇█
val/val_loss,█▅▄▃▃▂▂▂▁▁

0,1
test_accuracy,0.8
train/epoch,10.0
train/example_ct,120000.0
train/train_loss,0.12059
val/val_accuracy,0.9495
val/val_loss,0.17712


Train Loss: 0.291, Valid Loss: 0.300051, Accuracy: 0.91
Train Loss: 0.309, Valid Loss: 0.240876, Accuracy: 0.93
Train Loss: 0.263, Valid Loss: 0.212816, Accuracy: 0.93
Train Loss: 0.255, Valid Loss: 0.199201, Accuracy: 0.94
Train Loss: 0.179, Valid Loss: 0.182370, Accuracy: 0.94
Train Loss: 0.103, Valid Loss: 0.171615, Accuracy: 0.95
Train Loss: 0.055, Valid Loss: 0.166698, Accuracy: 0.95
Train Loss: 0.054, Valid Loss: 0.159757, Accuracy: 0.95
Train Loss: 0.054, Valid Loss: 0.153360, Accuracy: 0.95
Train Loss: 0.058, Valid Loss: 0.155677, Accuracy: 0.95


0,1
train/epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/example_ct,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/train_loss,█▄▃▃▃▃▃▃▁▂▂▂▂▂▂▃▂▁▃▁▁▁▂▁▁▁▁▁▁▂▁▁▁▂▁▂▁▁▁▁
val/val_accuracy,▁▄▅▆▆▇████
val/val_loss,█▅▄▃▂▂▂▁▁▁

0,1
test_accuracy,0.8
train/epoch,10.0
train/example_ct,120000.0
train/train_loss,0.05789
val/val_accuracy,0.949
val/val_loss,0.15568


Train Loss: 0.290, Valid Loss: 0.303804, Accuracy: 0.91
Train Loss: 0.329, Valid Loss: 0.253271, Accuracy: 0.92
Train Loss: 0.259, Valid Loss: 0.214477, Accuracy: 0.94
Train Loss: 0.153, Valid Loss: 0.198665, Accuracy: 0.94
Train Loss: 0.176, Valid Loss: 0.186424, Accuracy: 0.94
Train Loss: 0.124, Valid Loss: 0.175099, Accuracy: 0.94
Train Loss: 0.087, Valid Loss: 0.170804, Accuracy: 0.94
Train Loss: 0.163, Valid Loss: 0.170874, Accuracy: 0.94
Train Loss: 0.092, Valid Loss: 0.161157, Accuracy: 0.95
Train Loss: 0.054, Valid Loss: 0.164605, Accuracy: 0.95


0,1
train/epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/example_ct,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/train_loss,█▅▃▃▂▃▃▃▃▂▁▂▂▂▂▁▂▂▁▂▁▂▁▂▁▁▁▂▁▁▁▂▁▁▂▁▁▂▁▁
val/val_accuracy,▁▂▅▆▆▇▇▇██
val/val_loss,█▆▄▃▂▂▁▁▁▁

0,1
test_accuracy,0.8
train/epoch,10.0
train/example_ct,120000.0
train/train_loss,0.05437
val/val_accuracy,0.95
val/val_loss,0.1646


Train Loss: 0.243, Valid Loss: 0.280754, Accuracy: 0.91
Train Loss: 0.252, Valid Loss: 0.226741, Accuracy: 0.93
Train Loss: 0.203, Valid Loss: 0.202921, Accuracy: 0.94
Train Loss: 0.119, Valid Loss: 0.193697, Accuracy: 0.94
Train Loss: 0.092, Valid Loss: 0.174898, Accuracy: 0.94
Train Loss: 0.100, Valid Loss: 0.162667, Accuracy: 0.95
Train Loss: 0.083, Valid Loss: 0.162034, Accuracy: 0.95
Train Loss: 0.101, Valid Loss: 0.148741, Accuracy: 0.95
Train Loss: 0.040, Valid Loss: 0.163625, Accuracy: 0.94
Train Loss: 0.027, Valid Loss: 0.147997, Accuracy: 0.95


0,1
train/epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/example_ct,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/train_loss,█▅▃▂▂▃▃▂▂▂▁▂▂▂▁▂▂▂▂▂▁▂▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁
val/val_accuracy,▁▄▅▅▆▇██▆█
val/val_loss,█▅▄▃▂▂▂▁▂▁

0,1
test_accuracy,0.8
train/epoch,10.0
train/example_ct,120000.0
train/train_loss,0.02702
val/val_accuracy,0.952
val/val_loss,0.148
