In [1]:
import wandb
import os
import math
import pandas as pd
from wandb.workflows import model_registry as MR
from wandb.workflows import model_registry_example as UserCode

# Project that this demo's runs and model collections are addressed under.
project_name = "model_reg_api_demo_5"
# Export the name through WANDB_PROJECT: the `wandb.init()` calls in this notebook
# pass no explicit project, so the environment is where the project is configured.
os.environ.update({"WANDB_PROJECT": project_name})

# 💎 Step 1: Train a Model

Here, we train a model on the famous MNIST dataset. To keep this demo brief, typical user code has been abstracted out into a `UserCode` package.

* We initialize a run with a configuration - as is standard for wandb scripts
* We then call some user code to fetch training data, build a model, and train
* Importantly, we define an `onEpochEnd` callback hook which simulates how a user might instrument periodic model saving. We log loss and accuracy with the familiar `wandb.log` method.
* **The key method to notice is `log_model`**, a one-liner that logs a model.

The script will produce `epochs` versions in an Artifact Sequence (one model version per epoch).

In [2]:
# Start the training run; the dictionary below is passed as the run's config and is
# read back through `wandb.config` for every hyperparameter used in this cell.
run1 = wandb.init(
    config={
        "batch_size": 64,
        "gamma": 0.7,
        "lr": 1.0,
        "epochs": 5,
        "seed": 1,
        "train_count": 1000,
        "val_count": 200,
    }
)
cfg = wandb.config
_ = UserCode.seed(cfg.seed)

train_data, val_data = UserCode.load_training_data_split(
    train_count=cfg.train_count,
    val_count=cfg.val_count,
)
model, opt = UserCode.build_model(lr=cfg.lr)

# Module-level bookkeeping that `onEpochEnd` updates through `global`:
# the smallest validation loss seen so far and the value returned when that model was logged.
lowest_loss, best_model = math.inf, None

def onEpochEnd(epoch, model):
    """Evaluate `model` on the validation split, log metrics, and log the model.

    Reads the module-level `val_data` and updates the module-level
    `lowest_loss` / `best_model` when the validation loss strictly improves.

    Args:
        epoch: Value logged under the "epoch" key.
        model: Model passed to `UserCode.evaluate_model` and `MR.log_model`.

    Returns:
        None.
    """
    global lowest_loss, best_model

    val_loss, val_acc, _ = UserCode.evaluate_model(model, val_data)
    metrics = {"epoch": epoch, "val_loss": val_loss, "val_acc": val_acc}
    wandb.log(metrics)

    # No strict improvement: still log this epoch's model, just without an alias.
    if not (val_loss < lowest_loss):
        MR.log_model(model, "mnist_nn")
        return

    # New lowest validation loss: remember it and log the model with the "best" alias.
    lowest_loss = val_loss
    best_model = MR.log_model(model, "mnist_nn", aliases=["best"])
    

# Train with the hyperparameters stored on the run config. `onEpochEnd` is handed
# over as the callback hook that evaluates, logs metrics, and logs the model.
_ = UserCode.train_model(
    model=model,
    optimizer=opt,
    train_data=train_data,
    batch_size=cfg.batch_size,
    gamma=cfg.gamma,
    epochs=cfg.epochs,
    onEpochEnd=onEpochEnd,
)

# Close the training run before the evaluation run is started.
wandb.finish()


[34m[1mwandb[0m: Currently logged in as: [33mtimssweeney[0m (use `wandb login --relogin` to force relogin)



Test set: Average loss: 1.0309, Accuracy: 132/200 (66%)


Test set: Average loss: 0.4843, Accuracy: 168/200 (84%)


Test set: Average loss: 0.3149, Accuracy: 177/200 (88%)


Test set: Average loss: 0.2868, Accuracy: 183/200 (92%)


Test set: Average loss: 0.2658, Accuracy: 186/200 (93%)



0,1
epoch,▁▃▅▆█
val_acc,▁▆▇██
val_loss,█▃▁▁▁

0,1
epoch,5.0
val_acc,93.0
val_loss,0.26582


# ✈️ Step 2: Publish to a Registered Model

TODO: Either get `MR.link(best_model, "mnist")` to work, or show this in the UI

# ⚙️ Step 3: Evaluate a Model

In this next step, we use similar abstractions. We show the following:

* User code to load test data
* **New `use_model` function that fetches a SavedModel**
* User code to evaluate the model & create a table of predictions
* **New `log_evaluation_table` function to log an eval table for the model**

In [3]:
# TODO: point this at the collection you want to target.
# Ideally it would be a portfolio rather than the sequence produced by the training
# run, but that is waiting on linking.
registered_model_name = f"mnist_nn-{run1.id}"

# Open a second run, tagged as an evaluation job, with its own config.
wandb.init(job_type="evaluation", config={"test_size": 100})

test_data = UserCode.load_test_data(wandb.config.test_size)
model = MR.use_model(f"{project_name}/{registered_model_name}")
loss, acc, preds = UserCode.evaluate_model(model.raw_model(), test_data)

# One table row per test example: element 0 wrapped as an image, element 1 as-is,
# then the predictions appended as a third column.
# (The original author noted they were not happy with this construction.)
t = wandb.Table(
    columns=["x", "y"],
    data=[[wandb.Image(row[0]), row[1]] for row in test_data],
)
t.add_column("pred", preds)

MR.log_evaluation_table(
    table=t,
    # model_or_id=model,
    additional_metrics={"test_loss": loss, "test_accuracy": acc},
)
wandb.finish()



Test set: Average loss: 0.2430, Accuracy: 96/100 (96%)



In [4]:
# Version fetching: look up the versions stored under the same
# "<project>/<collection>" path that the evaluation cell loaded its model from.
model_collection = f"{project_name}/{registered_model_name}"
versions = MR.model_versions(model_collection)
print(f"Found {len(versions)} versions")

Found 5 versions
