# W&B Model Management Guide Companion Notebook

**Please add `artifact-portfolios` to your bio in order to enable beta features required to complete this guide.**

This is a companion notebook to the [W&B Model Management Guide](https://docs.wandb.ai/guides/models).

**Table of Contents**
* **Cell 1**: Installs `wandb` python library
* **Cell 2** (Form): Allows you to specify some parameters and defines a handful of helper functions. Note: there are no `wandb`-specific library calls in these helper functions - they are purely used to allow the example cells to be more terse and focus on the key aspects of Model Management
* **Cell 3**: (Train, Log, & Link Models) Covers steps [2. Train & log a Model](https://docs.wandb.ai/guides/models#2.-train-and-log-model-versions) and [3. Link Model Versions to the Collection](https://docs.wandb.ai/guides/models#3.-link-model-versions-to-the-portfolio)
* **Cell 4**: (Use, Evaluate, and Promote a Model) Covers steps [4. Using a Model Version](https://docs.wandb.ai/guides/models#4.-use-a-model-version), [5. Evaluate Model Performance](https://docs.wandb.ai/guides/models#5.-evaluate-model-performance), and [6. Promote a Version to Production](https://docs.wandb.ai/guides/models#6.-promote-a-version-to-production)




# Setup
**Stop! 🛑** Please complete [Step 1 of the tutorial](https://docs.wandb.ai/guides/models#1.-create-a-new-model-collection) before continuing. This will ensure you have a **Model Collection** defined in your project. Enter the Project name where you created the Collection in `project_name` and the name of the Collection in the `model_collection_name` fields respectively.

In [None]:
!pip install wandb -U -qqq

[K     |████████████████████████████████| 1.8 MB 1.9 MB/s 
[K     |████████████████████████████████| 181 kB 17.2 MB/s 
[K     |████████████████████████████████| 145 kB 35.6 MB/s 
[K     |████████████████████████████████| 63 kB 1.4 MB/s 
[?25h  Building wheel for pathtools (setup.py) ... [?25l[?25hdone


In [None]:
# Notebook-wide parameters read by the cells below. The trailing `#@param`
# markers appear to be Colab form annotations - keep them on the same line.
# W&B project that the training and evaluation runs are logged to.
project_name = "model_management_docs_official_v0" #@param {type:"string"}
# dataset_name = "mnist" #@param {type:"string"}
# Prefix for the "<dataset_name>-train" / "<dataset_name>-test" dataset artifacts.
dataset_name = "mnist"
# Name of the Model Collection that model versions are linked to and loaded from.
model_collection_name = "MNIST Grayscale 28x28" #@param {type:"string"}
# True: use the `wandb.beta.workflows` helpers; False: use the plain Artifact API.
use_beta_apis = False #@param {type:"boolean"}

import torch
from torchvision import datasets, transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
from torch.optim.lr_scheduler import StepLR

class Net(nn.Module):
    """Small convolutional classifier producing log-probabilities over 10 classes.

    Two 3x3 convolutions, a 2x2 max-pool, and two linear layers with dropout
    in between. The first linear layer expects 9216 flattened features, which
    is 64 channels * 12 * 12 - the size obtained from a 1x28x28 input.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in a fixed order so seeded initialisation and
        # state_dict keys stay stable.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)
        self.dropout1 = nn.Dropout(p=0.25)
        self.dropout2 = nn.Dropout(p=0.5)
        self.fc1 = nn.Linear(in_features=9216, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=10)

    def forward(self, x):
        """Return ``log_softmax`` scores (dim 1) for the input batch ``x``."""
        # Convolutional feature extractor.
        features = F.relu(self.conv1(x))
        features = F.relu(self.conv2(features))
        features = self.dropout1(F.max_pool2d(features, 2))
        # Classifier head on the flattened feature map.
        hidden = F.relu(self.fc1(torch.flatten(features, 1)))
        logits = self.fc2(self.dropout2(hidden))
        return F.log_softmax(logits, dim=1)

def build_training_data(train_size, val_size, batch_size, data_path):
  """Build train/validation loaders from random subsets of the MNIST training split.

  Args:
    train_size: number of examples in the training subset.
    val_size: number of examples in the validation subset.
    batch_size: batch size of the training loader.
    data_path: directory torchvision downloads MNIST into / reads it from.

  Returns:
    ``(train_loader, val_loader)``. The validation loader uses ``val_size`` as
    its batch size, so the whole validation subset arrives as a single batch.

  Raises:
    ValueError: if ``train_size + val_size`` exceeds the dataset size.
  """
  mnist_data = datasets.MNIST(
      data_path,
      download=True,
      transform=transforms.Compose([
        transforms.ToTensor(),
      ]),
  )

  # Derive the size from the dataset itself rather than hard-coding 60000.
  total = len(mnist_data)
  extra = total - train_size - val_size
  # Explicit check instead of `assert`, which is stripped under `python -O`.
  if extra < 0:
    raise ValueError(
        f"train_size + val_size ({train_size + val_size}) exceeds the "
        f"{total} available training examples"
    )

  # random_split needs lengths that sum to the dataset size; the third
  # (unused) split absorbs the remainder.
  train_split, val_split, _ = torch.utils.data.random_split(
      mnist_data, [train_size, val_size, extra]
  )
  train_loader = torch.utils.data.DataLoader(train_split, batch_size=batch_size)
  val_loader = torch.utils.data.DataLoader(val_split, batch_size=val_size)
  return train_loader, val_loader

def build_test_data(test_size, data_path):
  """Build a loader over a random subset of the MNIST test split.

  Args:
    test_size: number of examples in the test subset.
    data_path: directory torchvision downloads MNIST into / reads it from.

  Returns:
    A DataLoader whose batch size is ``test_size``, so the whole subset
    arrives as a single batch.

  Raises:
    ValueError: if ``test_size`` exceeds the dataset size.
  """
  mnist_data = datasets.MNIST(
      data_path,
      train=False,
      download=True,
      transform=transforms.Compose([
        transforms.ToTensor(),
      ]),
  )

  # Derive the size from the dataset itself rather than hard-coding 10000.
  total = len(mnist_data)
  extra = total - test_size
  # Explicit check instead of `assert`, which is stripped under `python -O`.
  if extra < 0:
    raise ValueError(
        f"test_size ({test_size}) exceeds the {total} available test examples"
    )

  # random_split needs lengths that sum to the dataset size; the second
  # (unused) split absorbs the remainder.
  test_split, _ = torch.utils.data.random_split(mnist_data, [test_size, extra])
  return torch.utils.data.DataLoader(test_split, batch_size=test_size)

def build_model(learning_rate, gamma):
  """Create a CPU ``Net`` with an Adadelta optimizer and a per-step LR decay.

  Args:
    learning_rate: initial learning rate passed to Adadelta.
    gamma: multiplicative decay applied by ``StepLR`` on every scheduler step.

  Returns:
    ``(model, optimizer, scheduler)``.
  """
  cpu = torch.device("cpu")
  net = Net().to(cpu)
  adadelta = optim.Adadelta(net.parameters(), lr=learning_rate)
  lr_decay = StepLR(adadelta, step_size=1, gamma=gamma)
  return net, adadelta, lr_decay

def train_model_batch(model, optimizer, batch_x, batch_y):
  """Run one optimisation step on a single batch.

  Puts ``model`` in training mode, computes the NLL loss of its output against
  ``batch_y``, back-propagates and steps ``optimizer``.

  Returns:
    ``(loss, preds)``: the loss as a Python float and the model output
    computed for the batch before the parameter update.
  """
  model.train()
  inputs = batch_x.to("cpu")
  labels = batch_y.to("cpu")

  optimizer.zero_grad()
  preds = model(inputs)
  batch_loss = F.nll_loss(preds, labels)
  batch_loss.backward()
  optimizer.step()

  return batch_loss.item(), preds

def train_model_epoch(model, optimizer, train_loader, on_batch_end):
  """Train on every batch of ``train_loader`` once.

  After each batch, ``on_batch_end(batch_ndx, batch, preds, train_loss)`` is
  called with the batch index, the raw batch, the model output and the loss.
  """
  for step, batch in enumerate(train_loader):
    inputs, targets = batch[0], batch[1]
    loss_value, batch_preds = train_model_batch(model, optimizer, inputs, targets)
    on_batch_end(step, batch, batch_preds, loss_value)

def train_model(model, optimizer, scheduler, train_loader, val_loader, num_epochs, on_batch_end, on_epoch_end):
  """Train for ``num_epochs`` epochs, invoking the callbacks along the way.

  ``on_batch_end(epoch_ndx, batch_ndx, batch, preds, train_loss)`` runs after
  every batch and ``on_epoch_end(epoch_ndx)`` after every epoch; the scheduler
  is stepped once per epoch, after the epoch callback. ``val_loader`` is
  accepted but not used inside this function.
  """
  for epoch_ndx in range(num_epochs):
    # Prepend the current epoch index to the per-batch callback arguments.
    train_model_epoch(
        model,
        optimizer,
        train_loader,
        lambda batch_ndx, batch, preds, train_loss: on_batch_end(
            epoch_ndx, batch_ndx, batch, preds, train_loss
        ),
    )
    on_epoch_end(epoch_ndx)
    scheduler.step()

def evaluate_model(model, eval_data):
    """Evaluate ``model`` on every batch of the ``eval_data`` loader.

    The model is switched to eval mode and run without gradient tracking.

    Returns:
        ``(loss, accuracy, preds)``: the summed NLL loss divided by
        ``len(eval_data.dataset)``, the accuracy as a percentage, and the
        flat list of predicted class indices in iteration order.
    """
    cpu = torch.device("cpu")
    model.eval()

    total_loss = 0
    num_correct = 0
    preds = []
    with torch.no_grad():
        for inputs, labels in eval_data:
            inputs = inputs.to(cpu)
            labels = labels.to(cpu)
            log_probs = model(inputs)
            # Sum (not mean) per batch so the final division gives a per-example loss.
            total_loss += F.nll_loss(log_probs, labels, reduction="sum").item()
            # Index of the highest log-probability for each example.
            top1 = log_probs.argmax(dim=1, keepdim=True)
            preds.extend(top1.flatten().tolist())
            num_correct += top1.eq(labels.view_as(top1)).sum().item()

    num_examples = len(eval_data.dataset)
    mean_loss = total_loss / num_examples
    accuracy = 100.0 * num_correct / num_examples
    return mean_loss, accuracy, preds


# Train, Log, & Link Models

In [None]:
import wandb
from wandb.beta.workflows import log_model, link_model

import torch
import cloudpickle

# Startup a W&B Run. The hyperparameters in `config` are read back through
# `wandb.config` below rather than being used directly.
wandb.init(project=project_name, 
  job_type="model_trainer",
  config={
      "train_size": 100,
      "val_size": 30,
      "batch_size": 64,
      "learning_rate": 1.0,
      "gamma": 0.75,
      "epochs": 5,
  }
)

# Load in the training data
train_size = wandb.config["train_size"]
val_size = wandb.config["val_size"]
batch_size=wandb.config["batch_size"]
train_data_path = "./train_data"
# `train` and `val` are the two loaders returned by the build_training_data helper.
train, val = build_training_data(train_size, val_size, batch_size, train_data_path)

# (Optional) Declare dataset dependency: wrap the data directory in a
# "dataset" Artifact and register it as an input of this run.
art = wandb.Artifact(f"{dataset_name}-train", "dataset")
art.add_dir(train_data_path)
wandb.use_artifact(art)

# Define a model (plus its optimizer and learning-rate scheduler)
learning_rate = wandb.config["learning_rate"]
gamma = wandb.config["gamma"]
model, optimizer, scheduler = build_model(learning_rate, gamma)

# Setup callbacks:
def on_batch_end(epoch_ndx, batch_ndx, batch, preds, train_loss):
  """Log per-batch training metrics to W&B.

  ``batch`` and ``preds`` are accepted to match the callback signature but are
  not logged. The learning rate is read from the optimizer's first param group.
  """
  current_lr = optimizer.param_groups[0]["lr"]
  batch_metrics = {
      "epoch_ndx": epoch_ndx,
      "batch_ndx": batch_ndx,
      "train_loss": train_loss,
      "learning_rate": current_lr,
  }
  wandb.log(batch_metrics)

# Lowest validation loss seen so far and the model version that achieved it.
best_loss = float("inf")
best_model = None
def on_epoch_end(epoch_ndx):
  """Validate the model, log a model version for this epoch, and log metrics.

  A model version is logged every epoch. When the validation loss improves on
  ``best_loss``, the version is tagged as best and remembered in the
  module-level ``best_model``.
  """
  global best_loss, best_model
  val_loss, val_acc, _ = evaluate_model(model, val)
  is_best = val_loss < best_loss
  best_loss = val_loss if is_best else best_loss

  ##### W&B MODEL MANAGEMENT SPECIFIC CALLS ######
  if not use_beta_apis:
    # Plain Artifact API: serialise the whole module and upload it as a file.
    art = wandb.Artifact(f"mnist-{wandb.run.id}", "model")
    torch.save(model, "model.pt", pickle_module=cloudpickle)
    art.add_file("model.pt")
    wandb.log_artifact(art, aliases=["best", "latest"] if is_best else None)
    logged_version = art
  else:
    logged_version = log_model(model, "mnist", ["best"] if is_best else None)
  if is_best:
    best_model = logged_version

  epoch_metrics = {
      "epoch_ndx": epoch_ndx,
      "val_loss": val_loss,
      "val_acc": val_acc,
      "learning_rate": optimizer.param_groups[0]["lr"],
      "best_loss": best_loss,
  }
  wandb.log(epoch_metrics)
  print(f"Epoch {epoch_ndx}: val_loss: {val_loss}, val_acc: {val_acc}")
  

# Train the Model
epochs = wandb.config["epochs"]
train_model(model, optimizer, scheduler, train, val, epochs, on_batch_end, on_epoch_end)

##### W&B MODEL MANAGEMENT SPECIFIC CALLS ######
# `best_model` is still None when no epoch improved on the initial `inf` loss
# (zero epochs, or a NaN validation loss every epoch). Linking None would fail
# and leave the run unfinished, so skip the link and say so instead.
if best_model is None:
  print("No best model was recorded during training; skipping link to the Model Collection.")
elif use_beta_apis:
  link_model(best_model, model_collection_name)
else:
  wandb.run.link_artifact(best_model, model_collection_name, ["latest"])

# Finish the Run
wandb.finish()

[34m[1mwandb[0m: Adding directory to artifact (./train_data)... Done. 0.2s


Epoch 0: val_loss: 2.2376441955566406, val_acc: 13.333333333333334
Epoch 1: val_loss: 2.1073472340901693, val_acc: 20.0
Epoch 2: val_loss: 1.9300413767496745, val_acc: 33.333333333333336
Epoch 3: val_loss: 1.7272534688313803, val_acc: 36.666666666666664
Epoch 4: val_loss: 1.5247809092203777, val_acc: 46.666666666666664


VBox(children=(Label(value='22.934 MB of 22.934 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, m…

0,1
batch_ndx,▁█▁█▁█▁█▁█
best_loss,█▇▅▃▁
epoch_ndx,▁▁▁▃▃▃▅▅▅▆▆▆███
learning_rate,███▅▅▅▄▄▄▂▂▂▁▁▁
train_loss,███▇▆▅▅▄▃▁
val_acc,▁▂▅▆█
val_loss,█▇▅▃▁

0,1
batch_ndx,1.0
best_loss,1.52478
epoch_ndx,4.0
learning_rate,0.31641
train_loss,1.35139
val_acc,46.66667
val_loss,1.52478


# Use, Evaluate, and Promote a Model

In [None]:
import wandb
from wandb.beta.workflows import use_model

import torch
# Startup a W&B Run. `test_size` in `config` is read back through
# `wandb.config` below.
wandb.init(project=project_name, 
  job_type="model_evaluator",
  config={
      "test_size": 100,
  }
)

# Load in the test data
test_size = wandb.config["test_size"]
test_data_path = "./test_data"
# `test` is the loader returned by the build_test_data helper.
test = build_test_data(test_size, test_data_path)

# (Optional) Declare dataset dependency: wrap the data directory in a
# "dataset" Artifact and register it as an input of this run.
art = wandb.Artifact(f"{dataset_name}-test", "dataset")
art.add_dir(test_data_path)
wandb.use_artifact(art)

##### W&B MODEL MANAGEMENT SPECIFIC CALLS ######
# Fetch the version currently aliased "latest" in the Model Collection.
if use_beta_apis:
  model_art = use_model(f"{model_collection_name}:latest")
  model_obj = model_art.model_obj()
else:
  model_art = wandb.use_artifact(f"{model_collection_name}:latest")
  model_path = model_art.get_path("model.pt").download()
  # NOTE(security): torch.load unpickles the file, which can execute arbitrary
  # code - only load artifacts from a project you trust.
  # The training cell saves the whole module with torch.save (not a
  # state_dict), so weights-only loading cannot be used. PyTorch >= 2.6
  # defaults to weights_only=True and would refuse this file, hence the
  # explicit False.
  model_obj = torch.load(model_path, weights_only=False)

model_obj.eval()
val_loss, val_acc, preds = evaluate_model(model_obj, test)

# The test loader is built with batch_size == test_size, so its first batch
# holds the entire test subset. Fetch it once instead of materialising the
# whole loader separately for each column.
images, labels = next(iter(test))

# One row per test example: the input image, its true label, and the prediction.
table = wandb.Table(data=[], columns=[])
table.add_column("image", [wandb.Image(i.numpy()) for i in images])
table.add_column("label", labels.tolist())
table.add_column("pred", preds)

wandb.log({
    "test_loss": val_loss,
    "test_acc": val_acc,
    "predictions": table
})

##### W&B MODEL MANAGEMENT SPECIFIC CALLS ######
# Promote the evaluated version by linking it with the "production" alias.
if use_beta_apis:
  # This cell's import line only brings in `use_model`; import `link_model`
  # here so the cell does not depend on an earlier cell's imports.
  from wandb.beta.workflows import link_model
  link_model(model_art, model_collection_name, aliases=["production"])
else:
  wandb.run.link_artifact(model_art, model_collection_name, ["latest", "production"])


# Finish the Run
wandb.finish()

[34m[1mwandb[0m: Adding directory to artifact (./test_data)... Done. 0.2s


VBox(children=(Label(value='0.091 MB of 0.091 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
test_acc,▁
test_loss,▁

0,1
test_acc,48.0
test_loss,1.6166
