# WANDB
This notebook shows basic usage of wandb to log ML runs. There are three hierarchy levels for logging -
```
entity
    project
        run (aka experiment)
```

The `entity` is usually the organization or team. In my case it is my username. Each `entity` can have many `project`s and each `project` can have many `run`s. A `run` is also referred to as an experiment in some of the documentation. I can log a bunch of things with wandb including matplotlib plots, histograms, images, etc. See [documentation for the `Run` object](https://docs.wandb.ai/ref/python/run).
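The hierarchy can be seen in a run's URL. The run linked in the `eval` cell below lives at `https://wandb.ai/avilay/learn-wandb-exp/runs/ssdqommr`: entity `avilay`, project `learn-wandb-exp`, run `ssdqommr`. The stdlib-only sketch below uses a hypothetical `parse_run_url` helper to split such a URL back into the three levels. It assumes run pages follow the `/<entity>/<project>/runs/<run_id>` layout. When creating a run, `wb.init` accepts `entity=` and `project=` keyword arguments. If `entity` is omitted, wandb uses your default entity.

```python
from typing import NamedTuple
from urllib.parse import urlparse


class RunPath(NamedTuple):
    entity: str
    project: str
    run_id: str


def parse_run_url(url: str) -> RunPath:
    # Assumed layout of a run page: https://wandb.ai/<entity>/<project>/runs/<run_id>
    # urlparse separates the query string (?workspace=...) from the path.
    parts = urlparse(url).path.strip("/").split("/")
    if len(parts) < 4 or parts[2] != "runs":
        raise ValueError(f"not a run URL: {url}")
    return RunPath(entity=parts[0], project=parts[1], run_id=parts[3])


path = parse_run_url(
    "https://wandb.ai/avilay/learn-wandb-exp/runs/ssdqommr?workspace=user-avilay"
)
print(path)  # RunPath(entity='avilay', project='learn-wandb-exp', run_id='ssdqommr')
```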

#### How to interpret parameter/gradient histograms
Even though I have not explicitly named the layers in my model, PyTorch gives them default names based on their position in the `nn.Sequential`. In this example the model has the following named parameters -
```
1.weight
1.bias
2.weight
2.bias
5.weight
5.bias
```
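`nn.Sequential` registers each submodule under its position as a string ("0", "1", ...), and `named_parameters()` joins that name and the parameter name with a dot. Layers 0, 3 and 4 (`Flatten`, `ReLU`, `Dropout`) have no learnable parameters, which is why the indices jump from 2 to 5. The simulation below needs no PyTorch. `layers` is a hypothetical, hand-written stand-in for the model built in `build_model`, and the simulation reproduces the list above.

```python
# Hypothetical stand-in for the Sequential stack: (layer type, learnable parameter names).
layers = [
    ("Flatten", []),
    ("Linear", ["weight", "bias"]),
    ("BatchNorm1d", ["weight", "bias"]),  # running_mean/running_var are buffers, not parameters
    ("ReLU", []),
    ("Dropout", []),
    ("Linear", ["weight", "bias"]),
]

# Position of the layer + "." + parameter name, skipping parameterless layers.
param_names = [
    f"{idx}.{pname}"
    for idx, (_layer, pnames) in enumerate(layers)
    for pname in pnames
]
print(param_names)
# ['1.weight', '1.bias', '2.weight', '2.bias', '5.weight', '5.bias']
```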

When logging the histogram of these parameters (or their corresponding gradients), wandb uses the following naming scheme -
`graph_{idx}{param_name}`

**Refs**
  * [wandb_watch.py:86](https://github.com/wandb/wandb/blob/722f9737ce1a77b8970fef275047e8a0f4a1a68e/wandb/sdk/wandb_watch.py#L86C45-L86C45)
  * [wandb_torch.py:105](https://github.com/wandb/wandb/blob/07051ce76e01d3e30cf3b25b42c22cd94cf62f5f/wandb/wandb_torch.py#L105)

The `idx` comes from a global counter inside wandb, so its value is not tied to this model and can be almost anything. This results in funny chart names like `graph_51.bias`, where `idx = 5` and `1.bias` is the param name.
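The scheme `graph_{idx}{param_name}` has no separator between `idx` and the parameter name, and `nn.Sequential` parameter names start with a digit. As a result, a chart name alone cannot always be decoded. The sketch below uses a hypothetical `chart_name` helper that applies the scheme as written above. It shows two different (idx, parameter) pairs producing the same name. Real chart keys also carry a `parameters/` or `gradients/` prefix, which this sketch leaves out.

```python
def chart_name(idx: int, param_name: str) -> str:
    # Scheme described above: graph_{idx}{param_name}, with no separator
    return f"graph_{idx}{param_name}"


print(chart_name(5, "1.bias"))     # graph_51.bias  (idx=5, parameter "1.bias")
# With index-named layers the boundary is ambiguous: these two collide.
print(chart_name(1, "15.weight"))  # graph_115.weight
print(chart_name(11, "5.weight"))  # graph_115.weight
```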

![parameters](./parameters.png)

In [1]:
from dataclasses import asdict, dataclass
from pathlib import Path

import torch as t
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from tqdm import tqdm

import wandb as wb

In [2]:
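# Use Apple's Metal (MPS) backend when available, otherwise fall back to the CPU
DEVICE = t.device("mps" if t.backends.mps.is_available() else "cpu")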
DATAROOT = Path.home()/"mldata"/"mnist"

In [3]:
@dataclass
class Hyperparams:
    n_epochs: int
    batch_size: int
    dropout: float
    lr: float
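`wb.init(config=...)` takes a dictionary of hyperparameters, which wandb stores with the run so runs can be compared by their settings. `asdict` converts the dataclass into that dictionary. This stdlib-only sketch shows the dict that gets sent for the values used in the training cell, and shows that it round-trips back into a `Hyperparams`.

```python
from dataclasses import asdict, dataclass


@dataclass
class Hyperparams:
    n_epochs: int
    batch_size: int
    dropout: float
    lr: float


hparams = Hyperparams(n_epochs=3, batch_size=32, dropout=0.2, lr=0.001)
# asdict() produces a plain dict, the form passed as wb.init(config=...)
config = asdict(hparams)
print(config)  # {'n_epochs': 3, 'batch_size': 32, 'dropout': 0.2, 'lr': 0.001}

# A plain dict can be turned back into the dataclass with keyword unpacking
assert Hyperparams(**config) == hparams
```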

In [4]:
def build_model(dropout: float) -> t.nn.Module:
    return t.nn.Sequential(
        t.nn.Flatten(),
        t.nn.Linear(28*28, 256),
        t.nn.BatchNorm1d(256),
        t.nn.ReLU(),
        t.nn.Dropout(dropout),
        t.nn.Linear(256, 10)).to(DEVICE)

In [5]:
def build_dataloader(is_train: bool, batch_size: int, stride: int = 5):
    full_dataset = MNIST(
        root=DATAROOT,
        train=is_train,
        transform=ToTensor(),
        download=True
    )
    # Keep every `stride`-th example so that each epoch runs quickly
    sub_dataset = Subset(
        full_dataset,
        indices=range(0, len(full_dataset), stride)
    )
    return DataLoader(
        dataset=sub_dataset,
        batch_size=batch_size,
        shuffle=is_train,
        # Pinned memory only speeds up host-to-GPU copies on CUDA
        pin_memory=DEVICE.type == "cuda",
        num_workers=2
    )
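The subset stride determines how much data each epoch sees. MNIST has 60,000 training and 10,000 test images. Keeping every 5th example leaves 12,000 and 2,000. With `batch_size=32` that is 375 training batches per epoch, which matches the `375/375` in the tqdm progress bars of the training cell. The arithmetic as a quick check:

```python
import math

# Sizes of the standard MNIST splits
N_TRAIN, N_TEST = 60_000, 10_000
STRIDE = 5  # keep every 5th example, as build_dataloader does by default


def n_batches(n_examples: int, stride: int, batch_size: int) -> int:
    # Size of Subset(indices=range(0, n, stride)); DataLoader keeps a final
    # partial batch because drop_last defaults to False, hence the ceil
    subset_size = len(range(0, n_examples, stride))
    return math.ceil(subset_size / batch_size)


train_batches = n_batches(N_TRAIN, STRIDE, batch_size=32)
val_batches = n_batches(N_TEST, STRIDE, batch_size=100)
print(train_batches, val_batches)  # 375 20
```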

In [6]:
def log_image_table(tablename, images, predicted, labels, probs):
    "Log a wandb.Table with (img, pred, target, scores)"
    # 🐝 Create a wandb Table to log images, labels and predictions to
    table = wb.Table(
        columns=["image", "pred", "target"]+[f"score_{i}" for i in range(10)]
    )
    for img, pred, targ, prob in zip(images.to("cpu"), predicted.to("cpu"), labels.to("cpu"), probs.to("cpu")):
        table.add_data(wb.Image(img[0].numpy()*255), pred, targ, *prob.numpy())
    wb.log({tablename: table}, commit=False)
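`wb.Table.add_data` expects one value per column, in column order. The table has 3 fixed columns plus one `score_i` column per digit class, so each row must carry 13 values. The `*prob.numpy()` unpacking is what spreads the 10 softmax scores across those columns. This stdlib-only sketch shows that alignment, using a hypothetical `make_row` helper and made-up values:

```python
N_CLASSES = 10

# Same column layout as log_image_table: 3 fixed columns + one score per class
columns = ["image", "pred", "target"] + [f"score_{i}" for i in range(N_CLASSES)]


def make_row(image, pred, target, probs):
    # Unpacking probs fills the score_* columns, mirroring *prob.numpy()
    row = [image, pred, target, *probs]
    if len(row) != len(columns):
        raise ValueError(f"expected {len(columns)} values, got {len(row)}")
    return row


# Hypothetical example: a placeholder instead of an image, made-up probabilities
probs = [0.0] * N_CLASSES
probs[7] = 1.0
row = make_row("<image>", 7, 7, probs)
print(len(columns), dict(zip(columns, row))["score_7"])  # 13 1.0
```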

In [7]:
def eval(model, valdl, loss_fn, epoch=0, batch_idx=0, log_images=False):
    losses = []
    corrects = []
    totals = []

    model.eval()
    with t.inference_mode():
        for i, (inputs, targets) in enumerate(valdl):
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)
            
            outputs = model(inputs)

            loss = loss_fn(outputs, targets).item()
            losses.append(loss)
            
            preds = t.argmax(outputs, dim=1)
            correct = (preds == targets).sum().item()            
            corrects.append(correct)
            totals.append(len(targets))

            # If I don't give each epoch's table a distinct name, the table
            # keeps getting overwritten with every new eval epoch. I can name
            # each epoch's table differently, as below, but wandb does not
            # render that very well. See
            # https://wandb.ai/avilay/learn-wandb-exp/runs/ssdqommr?workspace=user-avilay
            # BEST PRACTICE: Just log the images for the last eval run.
            if i == batch_idx and log_images:
                log_image_table(
                    f"epoch-{epoch}",
                    inputs[:5],
                    preds[:5],
                    targets[:5],
                    outputs[:5].softmax(dim=1)
                )
    avg_loss = t.mean(t.tensor(losses)).item()
    avg_acc = (t.sum(t.tensor(corrects)) / t.sum(t.tensor(totals))).item()
    wb.log({
        "val/loss": avg_loss,
        "val/Accuracy": avg_acc
    })
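`eval` computes accuracy by pooling `corrects` and `totals` over all batches before dividing, which weights every example equally. Averaging per-batch accuracies instead would over-weight a smaller final batch. The validation subset here (2,000 images, `batch_size=100`) splits evenly, so the two approaches agree in this notebook. The difference only shows up with uneven batches. `avg_loss` is a plain mean of per-batch mean losses, so it carries exactly that caveat. This stdlib-only illustration uses made-up counts:

```python
# Hypothetical per-batch counts where the last batch is smaller than the rest
corrects = [90, 80, 10]
totals = [100, 100, 20]

# What eval() does: pool the counts, then divide once
pooled_acc = sum(corrects) / sum(totals)

# Naive alternative: average the per-batch accuracies
mean_of_batch_accs = sum(c / n for c, n in zip(corrects, totals)) / len(totals)

print(round(pooled_acc, 4), round(mean_of_batch_accs, 4))  # 0.8182 0.7333
```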


In [None]:
model = build_model(dropout=0.25)
model.to(DEVICE)
valdl = build_dataloader(is_train=False, batch_size=100)
loss_fn = t.nn.CrossEntropyLoss()

In [None]:
wb.init(mode="disabled")  # eval() calls wb.log, which needs an active run
eval(model, valdl, loss_fn)

In [None]:
wb.finish()
del model
del valdl
del loss_fn

In [10]:
def train(model, traindl, loss_fn, optim, start_step=0, log_freq=1):
    losses = []
    model.train()
    with t.enable_grad():
        for step, batch in enumerate(tqdm(traindl)):
            images = batch[0].to(DEVICE)
            targets = batch[1].to(DEVICE)

            optim.zero_grad()
            outputs = model(images)
            loss = loss_fn(outputs, targets)
            loss.backward()
            optim.step()

            losses.append(loss.detach().item())
            if (start_step + step) % log_freq == 0:
                avg_loss = t.mean(t.tensor(losses)).item()
                wb.log({
                    "train/loss": avg_loss
                }, step=(start_step + step))
                losses = []
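`train` logs on the global step `start_step + step` rather than the per-epoch step. This keeps the logged points evenly spaced across epochs even though 375 is not a multiple of 100. The simulation below copies its constants from the training cell and lists the steps that receive a `train/loss` point. There are 12 in total, consistent with the 12 characters of the `train/loss` sparkline in the run history. One side effect is that `losses` is local to each `train` call. Losses after the last logged step of an epoch (for example steps 301 to 374 in epoch 0) are therefore dropped rather than folded into the next logged average.

```python
STEPS_PER_EPOCH = 375  # len(traindl) for this notebook's subset and batch size
N_EPOCHS = 3
LOG_FREQ = 100

logged_steps = []
for epoch in range(N_EPOCHS):
    start_step = STEPS_PER_EPOCH * epoch
    for step in range(STEPS_PER_EPOCH):
        # Same condition as in train(): log when the global step hits log_freq
        if (start_step + step) % LOG_FREQ == 0:
            logged_steps.append(start_step + step)

print(len(logged_steps), logged_steps[:5], logged_steps[-1])
# 12 [0, 100, 200, 300, 400] 1100
```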

In [12]:
hparams = Hyperparams(
    n_epochs=3,
    batch_size=32,
    dropout=0.2,
    lr=0.001
)

run = wb.init(
    project="learn-wandb-basic",
    config=asdict(hparams)
)

model = build_model(dropout=hparams.dropout)
model.to(DEVICE)

traindl = build_dataloader(is_train=True, batch_size=hparams.batch_size)
valdl = build_dataloader(is_train=False, batch_size=100)

loss_fn = t.nn.CrossEntropyLoss()
optim = t.optim.AdamW(model.parameters(), lr=hparams.lr)
steps_per_epoch = len(traindl)

wb.watch(model, loss_fn, log="all", log_freq=100)

for epoch in range(hparams.n_epochs):
    train(model, traindl, loss_fn, optim, start_step=steps_per_epoch * epoch, log_freq=100)
    eval(
        model, 
        valdl, 
        loss_fn, 
        epoch=epoch, 
        log_images=(epoch == hparams.n_epochs - 1)
    )

wb.finish()


100%|██████████| 375/375 [00:06<00:00, 55.12it/s]
100%|██████████| 375/375 [00:07<00:00, 52.60it/s]
100%|██████████| 375/375 [00:07<00:00, 49.26it/s]





Run history:

| metric | history |
|---|---|
| train/loss | █▃▂▂▁▁▁▁▁▁▁▁ |
| val/Accuracy | ▁▇█ |
| val/loss | █▂▁ |

Run summary:

| metric | value |
|---|---|
| train/loss | 0.1464 |
| val/Accuracy | 0.9445 |
| val/loss | 0.16845 |


In [13]:
model

Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): Linear(in_features=784, out_features=256, bias=True)
  (2): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (3): ReLU()
  (4): Dropout(p=0.2, inplace=False)
  (5): Linear(in_features=256, out_features=10, bias=True)
)

In [14]:
for name, parameter in model.named_parameters():
    print(name, parameter.shape)

1.weight torch.Size([256, 784])
1.bias torch.Size([256])
2.weight torch.Size([256])
2.bias torch.Size([256])
5.weight torch.Size([10, 256])
5.bias torch.Size([10])
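These six tensors are what `wb.watch(..., log="all")` tracks. With `log="all"`, wandb logs histograms of both the parameters and their gradients, so each name above appears twice, once as a parameter histogram and once as a gradient histogram. BatchNorm's `running_mean` and `running_var` are buffers rather than parameters, so they do not appear. Tallying the printed shapes gives the model size. A stdlib-only sketch:

```python
import math

# Parameter shapes as printed by model.named_parameters() above
shapes = {
    "1.weight": (256, 784),
    "1.bias": (256,),
    "2.weight": (256,),
    "2.bias": (256,),
    "5.weight": (10, 256),
    "5.bias": (10,),
}

# Number of scalars in each tensor is the product of its dimensions
counts = {name: math.prod(shape) for name, shape in shapes.items()}
total = sum(counts.values())
print(total)  # 204042
```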
