In [1]:
import wandb
import math
import random
import torch, torchvision
import torch.nn as nn
import torchvision.transforms as T
from tqdm.notebook import tqdm

# Run on the first CUDA GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

def get_dataloader(is_train, batch_size, slice=5):
    """Build a DataLoader over an evenly strided subset of MNIST.

    Args:
        is_train: If truthy, use the MNIST training split and shuffle it;
            otherwise use the test split and keep its order.
        batch_size: Number of samples per batch.
        slice: Stride used to subsample the dataset; every ``slice``-th example
            (starting at index 0) is kept. The name shadows the builtin
            ``slice`` inside this function but is kept for caller compatibility.

    Returns:
        A ``torch.utils.data.DataLoader`` over the subsampled dataset.
    """
    # download=True fetches MNIST into the current directory if it is not already there.
    full_dataset = torchvision.datasets.MNIST(root=".", train=is_train, transform=T.ToTensor(), download=True)
    sub_dataset = torch.utils.data.Subset(full_dataset, indices=range(0, len(full_dataset), slice))
    loader = torch.utils.data.DataLoader(
        dataset=sub_dataset,
        batch_size=batch_size,
        shuffle=bool(is_train),
        # Pinned host memory only speeds up host-to-GPU copies; asking for it
        # on a CPU-only machine just produces a warning.
        pin_memory=torch.cuda.is_available(),
        num_workers=2,
    )
    return loader

def get_model(dropout):
    """Build a small MLP classifier for 28x28 inputs and move it to `device`.

    Layer order: flatten -> linear(784, 256) -> batch norm -> ReLU
    -> dropout(`dropout`) -> linear(256, 10).
    """
    layers = [
        nn.Flatten(),
        nn.Linear(28*28, 256),
        nn.BatchNorm1d(256),
        nn.ReLU(),
        nn.Dropout(dropout),
        nn.Linear(256, 10),
    ]
    return nn.Sequential(*layers).to(device)

def validate_model(model, valid_dl, loss_func, log_images=False, batch_idx=0):
    """Evaluate the model on the validation dataloader, optionally logging one batch as a wandb.Table.

    Args:
        model: Module to evaluate. It is switched to eval mode and left in that mode.
        valid_dl: Dataloader yielding ``(images, labels)`` batches; must expose ``.dataset``.
        loss_func: Callable ``(outputs, labels) -> loss``. Its result is re-weighted by the
            batch size, so it is assumed to be a per-batch mean.
        log_images: If True, the batch at position ``batch_idx`` is logged via ``log_image_table``.
        batch_idx: Index of the batch to log, so the same batch is logged on every call.

    Returns:
        ``(val_loss, accuracy)``: the accumulated loss and the number of correct predictions,
        each divided by ``len(valid_dl.dataset)``.
    """
    model.eval()
    val_loss = 0.
    correct = 0
    with torch.inference_mode():
        # Wrap the dataloader itself (not the enumerate object) so tqdm can read
        # its length and display real progress instead of an unbounded counter.
        for i, (images, labels) in enumerate(tqdm(valid_dl, leave=False)):
            images, labels = images.to(device), labels.to(device)

            # Forward pass; weight the batch loss by the batch size so a smaller
            # final batch does not skew the dataset-level average.
            outputs = model(images)
            val_loss += loss_func(outputs, labels) * labels.size(0)

            # Predicted class is the index of the highest score along dim 1.
            predicted = outputs.argmax(dim=1)
            correct += (predicted == labels).sum().item()

            # Log one batch of images to the dashboard, always the same batch_idx.
            if i == batch_idx and log_images:
                log_image_table(images, predicted, labels, outputs.softmax(dim=1))

    num_samples = len(valid_dl.dataset)
    return val_loss / num_samples, correct / num_samples

def log_image_table(images, predicted, labels, probs):
    """Log a wandb.Table with one row per sample: (image, pred, target, per-class scores).

    Args:
        images: Batch of images; only the first channel (``img[0]``) of each is logged.
        predicted: Predicted class per sample.
        labels: Target class per sample.
        probs: Per-sample class scores, indexed as (sample, class); one ``score_<i>``
            column is created for each class.
    """
    # Derive the score columns from probs so the table is not tied to exactly 10 classes.
    num_classes = probs.shape[1]
    columns = ["image", "pred", "target"] + [f"score_{i}" for i in range(num_classes)]
    table = wandb.Table(columns=columns)
    # Move everything to the CPU once, then add one row per sample.
    for img, pred, targ, prob in zip(images.to("cpu"), predicted.to("cpu"), labels.to("cpu"), probs.to("cpu")):
        # The first channel is scaled by 255 before being wrapped in wandb.Image.
        table.add_data(wandb.Image(img[0].numpy()*255), pred, targ, *prob.numpy())
    # commit=False is passed so this call does not commit the step on its own
    # (per the wandb.log API); the caller's next wandb.log does that.
    wandb.log({"predictions_table":table}, commit=False)
    
def is_py_or_txt_path(path):
    """Return True if `path` ends with an extension that should be logged with the run's code.

    Accepts Python sources (``.py``, as the function name promises), notebooks
    (``.ipynb``) and text files (``.txt``).
    """
    return path.endswith((".py", ".ipynb", ".txt"))

In [2]:
import wandb
# 🐝 initialise a wandb run
wandb.init(
    entity="wandb",
    project="launch-welcome",
    config={
        "epochs": 10,
        "batch_size": 128,
        "lr": 1e-3,
        "dropout": 0.5,
        }
    )

# Hyperparameters are read back through wandb.config from here on
config = wandb.config

# Get the data; validation uses twice the training batch size
train_dl = get_dataloader(is_train=True, batch_size=config.batch_size)
valid_dl = get_dataloader(is_train=False, batch_size=2*config.batch_size)
n_steps_per_epoch = math.ceil(len(train_dl.dataset) / config.batch_size)

# A simple MLP model
model = get_model(config.dropout)

# Make the loss and optimizer
loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=config.lr)

# Training
example_ct = 0
step_ct = 0
for epoch in tqdm(range(config.epochs)):
    # validate_model() leaves the model in eval mode, so switch back every epoch
    model.train()
    for step, (images, labels) in enumerate(tqdm(train_dl, leave=False)):
        images, labels = images.to(device), labels.to(device)

        outputs = model(images)
        train_loss = loss_func(outputs, labels)
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()

        example_ct += len(images)
        metrics = {
            # Log a plain Python float rather than a tensor attached to the autograd graph
            "train/train_loss": train_loss.item(),
            # Fractional epoch: completed steps so far divided by steps per epoch
            "train/epoch": (step + 1 + (n_steps_per_epoch * epoch)) / n_steps_per_epoch,
            "train/example_ct": example_ct,
        }

        # The last step of each epoch is logged together with the validation
        # metrics below, so skip it here to avoid logging it twice.
        if step + 1 < n_steps_per_epoch:
            # 🐝 Log train metrics to wandb
            wandb.log(metrics)

        step_ct += 1

    # Only the final epoch logs the table of example predictions
    val_loss, accuracy = validate_model(model, valid_dl, loss_func, log_images=(epoch==(config.epochs-1)))

    # 🐝 Log train and validation metrics to wandb
    val_metrics = {"val/val_loss": val_loss,
                   "val/val_accuracy": accuracy}
    wandb.log({**metrics, **val_metrics})

    # Both losses use three decimals (the validation loss previously used `:3f`,
    # which sets a minimum width of 3 rather than a precision of 3).
    print(f"Train Loss: {train_loss:.3f}, Valid Loss: {val_loss:.3f}, Accuracy: {accuracy:.2f}")

    # If you had a test set, this is how you could log it as a Summary metric
    # (0.8 is a hard-coded placeholder, not a measured value)
    wandb.summary['test_accuracy'] = 0.8

# Upload the files accepted by is_py_or_txt_path alongside the run
wandb.run.log_code(include_fn=is_py_or_txt_path)
# 🐝 Close your wandb run
wandb.finish()

[34m[1mwandb[0m: Currently logged in as: [33mkylegoyette[0m (use `wandb login --relogin` to force relogin)


  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.361, Valid Loss: 0.304362, Accuracy: 0.91


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.265, Valid Loss: 0.243378, Accuracy: 0.92


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.216, Valid Loss: 0.219385, Accuracy: 0.93


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.288, Valid Loss: 0.199130, Accuracy: 0.94


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.155, Valid Loss: 0.197639, Accuracy: 0.93


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.197, Valid Loss: 0.179202, Accuracy: 0.94


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.135, Valid Loss: 0.170792, Accuracy: 0.95


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.152, Valid Loss: 0.175556, Accuracy: 0.95


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.130, Valid Loss: 0.171380, Accuracy: 0.95


  0%|          | 0/94 [00:00<?, ?it/s]

0it [00:00, ?it/s]

Train Loss: 0.081, Valid Loss: 0.156511, Accuracy: 0.95



VBox(children=(Label(value='0.278 MB of 0.278 MB uploaded (0.019 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
train/epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/example_ct,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/train_loss,█▅▄▄▃▂▃▂▂▂▃▂▁▂▂▁▂▂▂▃▂▂▂▂▁▁▁▁▁▁▁▁▁▂▁▂▁▁▁▁
val/val_accuracy,▁▃▅▆▅▆▇▇▇█
val/val_loss,█▅▄▃▃▂▂▂▂▁

0,1
test_accuracy,0.8
train/epoch,10.0
train/example_ct,120000.0
train/train_loss,0.08105
val/val_accuracy,0.953
val/val_loss,0.15651
