<a href="https://colab.research.google.com/github/Armandpl/wandb_jetracer/blob/master/wandb_jetracer_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<img src="https://i.imgur.com/gb6B4ig.png" width="400" alt="Weights & Biases" />

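# 🔥 = W&B ➕ PyTorch ➕ NVIDIA JetRacer

Fine-tune a pretrained ResNet-18 with a two-output (x, y) regression head on a JetRacer dataset, tracking the dataset, metrics, and trained model with Weights & Biases.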

In [1]:
%%capture
# Colab setup: install/upgrade the libraries this notebook uses and fetch the
# XYDataset helper module that `from xy_dataset import XYDataset` imports below.
# NOTE(review): only torch is pinned; wandb and torchmetrics float to whatever
# is latest, and a recent torchmetrics may require a newer torch than 1.7.1 —
# consider pinning mutually compatible versions (torchvision included) so the
# run is reproducible. `%pip` is preferred over `!pip` so the install targets
# the running kernel's environment.
# NOTE(review): the helper is downloaded from the `master` branch, so its
# contents can change between runs; pinning a commit SHA would freeze it.
!pip install wandb --upgrade
!wget -O xy_dataset.py https://raw.githubusercontent.com/Armandpl/wandb_jetracer/master/utils/xy_dataset.py
!pip install torch==1.7.1
!pip install torchmetrics

In [2]:
import os

import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import wandb
from torchmetrics.functional import mean_absolute_error, mean_squared_log_error, mean_squared_error
from tqdm.notebook import tqdm

from xy_dataset import XYDataset
# Device configuration
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Seed the global RNGs so the dataset split and other stochastic steps
# (e.g. initialisation of the new regression head, loader shuffling) are
# reproducible every time the notebook is run top to bottom.
SEED = 42
np.random.seed(SEED)
# The trailing `;` hides the `<torch._C.Generator ...>` repr that
# `torch.manual_seed` returns, which otherwise clutters the cell output.
torch.manual_seed(SEED);

<torch._C.Generator at 0x7fdd2155ffb0>

In [23]:
# Run hyperparameters. They are handed to `wandb.init(config=...)`, recorded
# with the run, and read back through `run.config.<key>` during training.
# NOTE(review): `train_augs` is not read by any code visible in this notebook.
config = {
    "epochs": 45,
    "architecture": "resnet18",    # any constructor name in torchvision.models
    "pretrained": True,
    "batch_size": 64,
    "learning_rate": 1e-4,
    "dataset": "suzuka:latest",    # W&B artifact name:alias to download
    "train_pct": 0.8,              # fraction of the dataset used for training
    "train_augs": False,
    "loss": "MSELoss",
}

In [4]:
def make(run, split_seed=42):
    """Build everything a training run needs from its wandb config.

    Downloads the dataset artifact named by ``run.config.dataset``, splits it
    into train/test subsets (``run.config.train_pct`` of it goes to train),
    wraps both in DataLoaders, and builds a torchvision backbone whose final
    ``fc`` layer is replaced by a 2-output linear head.

    Args:
        run: the active wandb run; its ``config`` supplies ``dataset``,
            ``train_pct``, ``batch_size``, ``architecture``, ``pretrained``,
            ``learning_rate`` and (optionally, default "MSELoss") ``loss``.
        split_seed: seed of a private generator used only for the train/test
            split, so the split is identical on every call, even after an
            earlier run in the same kernel has advanced the global torch RNG.

    Returns:
        (model, train_loader, test_loader, criterion, optimizer)
    """
    # Pull the dataset
    artifact = run.use_artifact(run.config.dataset)
    artifact_dir = artifact.download()

    dataset = XYDataset(artifact_dir)

    train_len = int(len(dataset) * run.config.train_pct)
    test_len = len(dataset) - train_len

    # A dedicated generator decouples the split from the global RNG state.
    split_generator = torch.Generator().manual_seed(split_seed)
    train_set, test_set = torch.utils.data.random_split(
        dataset, (train_len, test_len), generator=split_generator)

    train_loader = make_loader(train_set, batch_size=run.config.batch_size)
    test_loader = make_loader(test_set, batch_size=run.config.batch_size)

    # Make the model: backbone + 2-unit linear head (columns read as x, y)
    model = getattr(torchvision.models, run.config.architecture)(
        pretrained=run.config.pretrained)
    model.fc = nn.Linear(model.fc.in_features, 2)
    model = model.to(device)
    model.train()

    # Make the loss and optimizer. The loss class comes from the config
    # (e.g. "MSELoss") so the logged config matches what is actually used.
    criterion = getattr(nn, run.config.get("loss", "MSELoss"))()
    optimizer = torch.optim.Adam(
        model.parameters(), lr=run.config.learning_rate)

    return model, train_loader, test_loader, criterion, optimizer

In [5]:
def make_loader(dataset, batch_size):
    """Wrap ``dataset`` in a shuffling DataLoader of ``batch_size`` batches.

    Loading runs in two worker processes, and batches are placed in pinned
    (page-locked) host memory, which speeds up host-to-GPU copies.
    """
    loader_options = dict(
        batch_size=batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=2,
    )
    return torch.utils.data.DataLoader(dataset, **loader_options)

# 👟 Define Training Logic

In [12]:
def train(model, train_loader, test_loader, criterion, optimizer, run):
    """Optimize ``model`` for ``run.config.epochs`` epochs, logging to wandb.

    Every training batch logs the epoch index, the loss, and the metrics from
    ``compute_metrics``; after each epoch ``test`` evaluates on
    ``test_loader``. Records are logged at an explicit step (training batches
    seen so far), the same step ``test`` uses. Training and validation
    therefore share one x-axis, and each epoch's validation metrics land in
    the row of that epoch's last training batch.
    """
    # tell wandb to watch what the model gets up to: gradients, weights, and more!
    run.watch(model, criterion, log="all", log_freq=10)

    batch_ct = 0
    for epoch in tqdm(range(run.config.epochs)):
        for images, labels in train_loader:
            images, targets = images.to(device), labels.to(device)

            optimizer.zero_grad()
            scores = model(images)
            loss = criterion(scores, targets)
            loss.backward()
            optimizer.step()

            # Report metrics every batch. Metrics are for reporting only, so
            # detach to keep them out of this batch's autograd graph.
            batch_ct += 1
            metrics = compute_metrics(scores.detach(), targets)
            run.log({"epoch": epoch, "train/loss": loss.item(), **metrics},
                    step=batch_ct)

        # evaluate every epoch (logged at the same step as the last batch)
        test(model, test_loader, criterion, batch_ct)

In [22]:
def compute_metrics(preds, targets, step="train"):
    """Regression metrics for 2-column (x, y) predictions, keyed for wandb.

    Args:
        preds: model outputs; column 0 is treated as x and column 1 as y.
        targets: ground truth with the same layout as ``preds``.
        step: prefix for every key, e.g. "train" or "valid".

    Returns:
        dict of Python floats with keys "<step>/x/<M>", "<step>/y/<M>" and
        "<step>/<M>" (both columns together), for M in MSE, MAE, MSLE, in that
        order.
    """
    # Metrics are for reporting only: detaching keeps them out of autograd
    # even when callers pass tensors that still require grad.
    preds, targets = preds.detach(), targets.detach()

    groups = (
        (f"{step}/x/", preds[:, 0], targets[:, 0]),
        (f"{step}/y/", preds[:, 1], targets[:, 1]),
        (f"{step}/", preds, targets),
    )
    metric_fns = (
        ("MSE", mean_squared_error),
        ("MAE", mean_absolute_error),
        # NOTE(review): MSLE compares log1p(pred) with log1p(target), so it is
        # NaN for any value below -1 and distorted for negative values; the
        # gaps in the logged train/x/MSLE curve suggest x can be negative —
        # consider dropping MSLE for x.
        ("MSLE", mean_squared_log_error),
    )

    return {
        prefix + name: float(metric_fn(p, t))
        for prefix, p, t in groups
        for name, metric_fn in metric_fns
    }

# 🧪 Define Testing Logic

In [8]:
def test(model, test_loader, criterion, batch_ct):
    """Evaluate ``model`` on ``test_loader`` and log test-set metrics to wandb.

    Per-batch metrics from ``compute_metrics`` (each a per-batch mean) are
    combined into a mean weighted by batch size. Every test example
    therefore counts equally, including those in a smaller final batch. The
    result is logged at ``step=batch_ct`` so it lines up with training.

    Args:
        model: network to evaluate; put in eval mode for the pass and left
            in train mode afterwards.
        test_loader: yields (images, targets) batches.
        criterion: unused; kept so the call signature stays unchanged.
        batch_ct: wandb step at which to log the validation metrics.
    """
    model.eval()
    try:
        totals = {}
        n_examples = 0
        with torch.no_grad():
            for images, targets in test_loader:
                images, targets = images.to(device), targets.to(device)
                scores = model(images)
                metrics = compute_metrics(scores, targets, step="valid")

                # Weight each batch mean by its size to get the true mean.
                batch_size = targets.shape[0]
                n_examples += batch_size
                for key, value in metrics.items():
                    totals[key] = totals.get(key, 0.0) + float(value) * batch_size

        # An empty loader has nothing to report (wandb.log(None) would fail).
        if n_examples:
            wandb.log({key: total / n_examples for key, total in totals.items()},
                      step=batch_ct)
    finally:
        model.train()

In [9]:
from collections import Counter

def average_dict(l):
    """Average a list of dicts key by key.

    Each key's values are summed and divided by the number of dicts that
    contain that key; keys missing from some dicts are averaged only over
    the dicts that have them. Keys appear in first-seen order, and every
    value in the result is a float. An empty list yields an empty dict.
    """
    totals = {}
    occurrences = {}
    for record in l:
        for key, value in record.items():
            totals[key] = value + totals[key] if key in totals else value
            occurrences[key] = occurrences.get(key, 0) + 1

    return {key: float(total) / occurrences[key] for key, total in totals.items()}

In [24]:
# Train one model end to end and upload its weights as a W&B artifact.
# The entity still defaults to "wandb", but setting the WANDB_ENTITY
# environment variable redirects the run to another user or team.
with wandb.init(project="racecar", config=config, job_type="train",
                entity=os.environ.get("WANDB_ENTITY", "wandb")) as run:

    # make the model, data, and optimization problem
    model, train_loader, test_loader, criterion, optimizer = make(run)
    print(model)

    # and use them to train the model
    train(model, train_loader, test_loader, criterion, optimizer, run)

    # finally, log the trained model's weights to wandb as an artifact
    torch.save(model.state_dict(), 'model.pth')
    artifact = wandb.Artifact('model', type='model')
    artifact.add_file('model.pth')
    run.log_artifact(artifact)

Downloading: "https://download.pytorch.org/models/resnet18-5c106cde.pth" to /root/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth


HBox(children=(FloatProgress(value=0.0, max=46827520.0), HTML(value='')))


ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
 

HBox(children=(FloatProgress(value=0.0, max=45.0), HTML(value='')))




VBox(children=(Label(value=' 42.72MB of 42.72MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.…

0,1
epoch,44.0
train/loss,0.00502
train/x/MSE,0.00565
train/x/MAE,0.05719
train/x/MSLE,0.37365
train/y/MSE,0.00439
train/y/MAE,0.05692
train/y/MSLE,0.00521
train/MSE,0.00502
train/MAE,0.05705


0,1
epoch,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train/loss,█▄▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/x/MSE,█▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁
train/x/MAE,█▅▃▄▃▂▂▂▂▂▁▂▂▁▂▂▁▂▂▁▁▁▁▂▂▁▁▁▂▁▃▁▂▁▁▁▁▁▁▁
train/x/MSLE,█ ▆ ▃▃▅ ▃▂ ▅ ▄ ▁▂▃ ▆
train/y/MSE,█▅▃▃▂▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/y/MAE,█▆▅▄▄▄▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▂▂▂▁▁▂▂▂▂▁▁▁▁▁▁▁▁▁▂
train/y/MSLE,█▇▇▆▃▃▃▂▂▂▂▂▁▂▂▂▁▂▁▂▂▂▁▁▂▂▂▂▁▁▁▁▁▁▁▁▁▂
train/MSE,█▄▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/MAE,█▅▄▄▃▃▂▂▂▂▂▂▂▁▂▂▁▂▂▁▁▁▁▂▂▁▁▁▂▁▂▁▁▁▁▁▁▁▁▁
