<a href="https://colab.research.google.com/github/Armandpl/wandb-jetracer/blob/master/wandb_jetracer/wandb_jetracer_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<img src="https://i.imgur.com/gb6B4ig.png" width="400" alt="Weights & Biases" />

# 🏁🏎️💨 = W&B ➕ NVIDIA Jetracer

This is a companion notebook for [title of the rc car video](). We are going to use [PyTorch Lightning](https://www.pytorchlightning.ai/) and [Weights&Biases](https://wandb.ai/site) to train a neural net that can drive [an NVIDIA Jetracer RC car](https://github.com/NVIDIA-AI-IOT/jetracer). 

Weights&Biases is a lightweight developer toolkit for [experiment tracking](https://wandb.ai/site/experiment-tracking), [model management and dataset versioning](https://wandb.ai/site/artifacts). Going through this notebook, you'll see how we use it to train models and deploy them to the RC car while keeping track of which model was trained on which version of the dataset, with which hyperparameters.

<p align="center">
  <img src="https://raw.githubusercontent.com/Armandpl/wandb-jetracer/master/assets/header.png">
</p>
<p align="center">
<i>Here is the car's POV along with its label (green circle) indicating the center of the 'road'.</i>
</p>

You can experiment with this notebook even if you don't have access to an RC car. If you run an interesting experiment, ping me [@armand_dpl](https://twitter.com/armand_dpl) and I will try it on the car!


# 0. Setup
Here we install and import dependencies.  
We also clone https://github.com/Armandpl/wandb_jetracer to get utility functions. This repository builds on top of the [NVIDIA Jetracer](https://github.com/NVIDIA-AI-IOT/jetracer) project and instruments it with Weights&Biases. 

In [None]:
!pip install wandb

In [None]:
!git clone https://github.com/Armandpl/wandb_jetracer
# test_epoch_end and Trainer(gpus=...), used below, were removed in PyTorch Lightning 2.0
!pip install "pytorch-lightning<2.0" torchmetrics

In [None]:
import math
import os

import cv2
import PIL
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torchmetrics
import wandb

from wandb_jetracer.utils.xy_dataset import XYDataset
from wandb_jetracer.utils.utils import show_label, torch2cv2

# 1. Training a model

To get the car to drive by itself, we train a model to infer the center of the road. If the center is on the left, we steer left; if it's on the right, we steer right.  
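
A minimal sketch of that control rule. It assumes, in the spirit of NVIDIA's Jetracer road-following example, that the label's x coordinate is normalized to [-1, 1] with 0 at the image center and negative values meaning "left", and that negative steering values turn left. `pixel_to_normalized`, `steering_from_prediction`, `gain` and `bias` are illustrative names, not functions from this repo.

```python
def pixel_to_normalized(x_px: float, width: int) -> float:
    """Map a pixel column in [0, width] to [-1, 1], with 0 at the image center."""
    return 2.0 * (x_px / width - 0.5)


def steering_from_prediction(x: float, gain: float = 0.75, bias: float = 0.0) -> float:
    """Proportional steering toward the predicted road center, clamped to [-1, 1]."""
    steering = gain * x + bias
    return max(-1.0, min(1.0, steering))


# Road center predicted at pixel 56 of a 224-pixel-wide frame: left of center.
x = pixel_to_normalized(56, 224)
print(x)  # -0.5
print(steering_from_prediction(x))  # -0.375
```

A larger `gain` makes the car react more sharply to the same offset; `bias` compensates for a car that drifts when the steering command is 0.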

To solve this regression task and predict the center of the racetrack, we fine-tune a [ResNet](https://arxiv.org/abs/1512.03385) model: we replace its last fully-connected layer with our own fully-connected layer, which outputs two values, the x and y coordinates of the road center.  

If you want to experiment with [the model architecture](https://pytorch.org/vision/stable/models.html), feel free to modify the `build_model` function. Maybe you could try a MobileNet or an EfficientNet architecture, or a simple convolutional network?  

If you tried a different architecture and are proud of your results, feel free to send me a DM [@armand_dpl](https://twitter.com/armand_dpl) and I will try to run your model on the actual car!

In [None]:
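def build_model():
    # Reads the global `config` set in the training cell (architecture name and pretrained flag)
    model = getattr(torchvision.models, config.architecture)(pretrained=config.pretrained)
    # Replace the ImageNet classifier with a 2-output regression head: (x, y) of the road center
    model.fc = nn.Linear(model.fc.in_features, 2)
    return model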

Most of the code in the cell below is pretty standard. If you'd like to learn more about Weights and Biases and [PyTorch Lightning](https://www.pytorchlightning.ai/) you can check out this video: [⚡ Supercharge your Training with PyTorch Lightning + Weights & Biases](https://www.youtube.com/watch?v=hUXQm46TAKc).

In [None]:
class RoadRegression(pl.LightningModule):

    def __init__(self, config):
        super().__init__()
        self.config = config

        # setting up metrics
        metrics = torchmetrics.MetricCollection([
            torchmetrics.MeanSquaredError(),
            torchmetrics.MeanAbsoluteError()
        ])
        self.train_metrics = metrics.clone(prefix='train/')
        self.valid_metrics = metrics.clone(prefix='val/')
        self.test_metrics = metrics.clone(prefix='test/')

        self.model = build_model()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        images, targets = batch
        preds = self.forward(images)
        loss = getattr(F, self.config.loss)(preds, targets)

        metrics = self.train_metrics(preds, targets)
        self.log_dict(metrics, on_step=True, on_epoch=False)

        return loss

    def validation_step(self, batch, batch_idx):
        images, targets = batch
        preds = self.forward(images)

        metrics = self.valid_metrics(preds, targets)
        self.log_dict(metrics, on_step=False, on_epoch=True)

    def test_step(self, batch, batch_idx):
        images, targets = batch
        preds = self.forward(images)

        metrics = self.test_metrics(preds, targets)
        self.log_dict(metrics, on_step=False, on_epoch=True)

        return (images, preds, targets)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.config.learning_rate)
        return optimizer
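
`training_step` looks up its loss by name with `getattr(F, self.config.loss)`, so `config.loss` must be the name of a function in `torch.nn.functional` that takes `(input, target)`. A quick sketch comparing a few plausible choices on the same made-up predictions:

```python
import torch
import torch.nn.functional as F

preds = torch.tensor([[0.1, 0.5], [-0.3, 0.2]])
targets = torch.tensor([[0.0, 0.5], [-0.5, 0.0]])

for name in ["mse_loss", "l1_loss", "smooth_l1_loss"]:
    loss_fn = getattr(F, name)  # same lookup as in RoadRegression.training_step
    print(name, loss_fn(preds, targets).item())
```

`mse_loss` penalizes large errors quadratically, `l1_loss` grows linearly and is less sensitive to a few badly labeled images, and `smooth_l1_loss` behaves like MSE for small errors and like L1 for large ones.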

## Visualizing predictions on the test set  
To get an intuition for how our model is doing, we log our test set along with predictions, losses and ground truth as a [wandb table](https://docs.wandb.ai/guides/data-vis). We can then go to our dashboard and explore the predictions; for example, sorting by highest loss shows which images the model finds difficult.  
If you want to see what that looks like, you can check out this [run's page](https://wandb.ai/wandb/racecar/runs/16dlsdf7).
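
The "loss" column is computed per image: `F.mse_loss(..., reduction='none')` keeps one squared error per coordinate, and `create_table` sums the x and y terms. The ranking you get by sorting that column in the dashboard can be sketched directly in PyTorch (toy tensors, not real predictions):

```python
import torch
import torch.nn.functional as F

predictions = torch.tensor([[0.0, 0.1], [0.5, -0.2], [-0.9, 0.4]])
targets = torch.tensor([[0.0, 0.0], [0.1, -0.2], [0.6, 0.3]])

# One squared error per coordinate, shape (N, 2)
losses = F.mse_loss(predictions, targets, reduction="none")
# Per-image loss as in create_table: sum of the x and y squared errors
per_image = losses.sum(dim=1)

# Indices of the test images, hardest first
hardest = torch.argsort(per_image, descending=True)
print(hardest.tolist())  # [2, 1, 0]
```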

In [None]:
def test_epoch_end(self, test_step_outputs):
    images, predictions, targets = self.concat_test_outputs(test_step_outputs)

    # compute loss for each image of the test set 
    losses = F.mse_loss(predictions, targets, reduction='none')

    test_table = self.create_table(images, predictions, targets, losses)

    wandb.log({"test/predictions": test_table})

def create_table(self, images, predictions, targets, losses):
    # display preds and targets on images
    images_with_preds = []
    for idx, image in enumerate(images):
        img = torch2cv2(image)

        # show ground truth and prediction on the image
        img = show_label(img, targets[idx])
        img = show_label(img, predictions[idx], (0, 0, 255))

        images_with_preds.append(img)

    # create a WandB table
    my_data = [
        [wandb.Image(img), pred, target, loss.sum()] 
        for img, pred, target, loss
        in zip(images_with_preds, predictions, targets, losses)
    ]

    columns= ["image", "prediction", "target", "loss"]
    table = wandb.Table(data=my_data, columns=columns)

    return table

def concat_test_outputs(self, test_step_outputs):
    """
    Concatenate the outputs of the test steps so that we can easily iterate on them and
    compute the loss for each item in one go.
    """
    # test_step_outputs is a list of (images, preds, targets) tuples, one per batch
    images, predictions, targets = zip(*test_step_outputs)
    return torch.cat(images, dim=0), torch.cat(predictions, dim=0), torch.cat(targets, dim=0)

RoadRegression.test_epoch_end = test_epoch_end
RoadRegression.create_table = create_table
RoadRegression.concat_test_outputs = concat_test_outputs

## Preparing our data
We download our pre-processed dataset from [WandB Artifacts](https://docs.wandb.ai/guides/artifacts). 
This way we always know which model was trained on which version of the data and how the training went, and we can retrieve the trained weights.
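
As an example of that last point, the training cell logs the weights as a `model` artifact containing `model.pth`, so they can be pulled back later. A hedged sketch: `fetch_model_dir` uses the public `wandb.Api().artifact(...).download()` calls and assumes the `wandb/racecar/model:latest` name implied by the training cell's entity, project and artifact name; `load_weights` only touches local files, so the usage example exercises it with a temporary directory rather than a real download. Both helper names are hypothetical.

```python
import os
import tempfile

import torch
import torch.nn as nn


def fetch_model_dir(name: str = "wandb/racecar/model:latest") -> str:
    """Download a model artifact and return its local directory (requires W&B access)."""
    import wandb

    return wandb.Api().artifact(name).download()


def load_weights(model: nn.Module, artifact_dir: str) -> nn.Module:
    """Load the `model.pth` state dict saved by the training cell into `model`."""
    state_dict = torch.load(os.path.join(artifact_dir, "model.pth"), map_location="cpu")
    model.load_state_dict(state_dict)
    return model


# Offline usage: save weights the way the training cell does, then reload them.
with tempfile.TemporaryDirectory() as tmp:
    original = nn.Linear(512, 2)
    torch.save(original.state_dict(), os.path.join(tmp, "model.pth"))
    restored = load_weights(nn.Linear(512, 2), tmp)
    print(torch.equal(original.weight, restored.weight))  # True
```

With W&B access, calling `load_weights` on a freshly built model of the same architecture with `fetch_model_dir()` as the directory would restore the trained network.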

In [None]:
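def prepare_data(self):
    # Download the dataset artifact; setup() then reads the train/val/test folders from artifact_dir.
    # Lightning runs prepare_data on one process per node, so keeping state here is only safe on a single GPU.
    artifact = wandb.use_artifact(self.dataset_artifact)
    self.artifact_dir = artifact.download()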

<p align="center">
  <img src="https://raw.githubusercontent.com/Armandpl/wandb-jetracer/master/assets/artifacts.png" height="300">
</p>
<p align="center">
<i>WandB Artifacts' graph for this project</i>
</p>

When using artifacts, WandB automatically generates a graph that lets us visualize the whole pipeline.  
Squares represent runs, circles represent artifacts, and arrows indicate which artifacts are produced or consumed by which runs. For example, the `train`ing runs consume a `dataset` artifact and output a `model` artifact.  
Each of these artifacts is also versioned, so we can easily access older versions of models and datasets along with the runs that created them. 
[Click here](https://wandb.ai/wandb/racecar/artifacts/model/trt-model/d6bca0257e1bcec39983/graph) if you wish to explore this interactive graph!

In the cell below, we set up a [PyTorch Lightning Data Module](https://pytorch-lightning.readthedocs.io/en/stable/extensions/datamodules.html). 

In [None]:
from typing import Optional

class RoadDataModule(pl.LightningDataModule):

    def __init__(self, dataset_artifact: str, batch_size):
        super().__init__()
        self.dataset_artifact = dataset_artifact
        self.batch_size = batch_size

    def setup(self, stage: Optional[str] = None):
        # Assign train/val datasets for use in dataloaders
        train_pth, val_pth, test_pth = [os.path.join(self.artifact_dir, split) for split in ["train", "val", "test"]] 

        if stage == 'fit' or stage is None:
            self.train, self.val = XYDataset(train_pth, train=True), XYDataset(val_pth, train=False)

            self.dims = tuple(self.train[0][0].size())

        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.test = XYDataset(test_pth, train=False)

            self.dims = tuple(self.test[0][0].size())

    def train_dataloader(self):
        return self.make_loader(self.train, True)

    def val_dataloader(self):
        return self.make_loader(self.val, False)

    def test_dataloader(self):
        return self.make_loader(self.test, False)

    def make_loader(self, dataset, shuffle):
        return DataLoader(dataset=dataset,
                          batch_size=self.batch_size, 
                          shuffle=shuffle,
                          pin_memory=True, num_workers=2) 

RoadDataModule.prepare_data = prepare_data
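
`setup` expects the downloaded artifact to contain `train`, `val` and `test` sub-folders. A small, hypothetical helper (`check_split_dirs` is not part of the repo) that fails early with a clear message when a custom dataset doesn't follow that layout, and otherwise reports how many files each split holds:

```python
import os
import tempfile


def check_split_dirs(root: str, splits=("train", "val", "test")) -> dict:
    """Return the number of files per split, raising if a split folder is missing."""
    missing = [s for s in splits if not os.path.isdir(os.path.join(root, s))]
    if missing:
        raise FileNotFoundError(f"{root} is missing split folder(s): {missing}")
    return {s: len(os.listdir(os.path.join(root, s))) for s in splits}


# Example: a dataset folder that forgot its test split
with tempfile.TemporaryDirectory() as root:
    for split in ("train", "val"):
        os.makedirs(os.path.join(root, split))
    try:
        check_split_dirs(root)
    except FileNotFoundError as err:
        print(err)
```

Calling it on `self.artifact_dir` at the top of `setup` would surface a malformed artifact before any `XYDataset` is built.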

### Let's train!
This is also pretty standard PyTorch Lightning/WandB code. More about PyTorch Lightning and WandB [in our docs](https://docs.wandb.ai/guides/integrations/lightning)!

In [None]:
config = dict(
    epochs=10,
    architecture="resnet34",
    pretrained=True,
    batch_size=64,
    learning_rate=1e-4,
    dataset="mix_ready:latest",
    train_augs=False,
    loss="mse_loss"
    )

with wandb.init(project="racecar", config=config, job_type="train", entity="wandb") as run:
    config = run.config

    dm = RoadDataModule(config.dataset, config.batch_size)
    road_regression = RoadRegression(config)

    wandb_logger = WandbLogger()
    trainer = pl.Trainer(
        logger=wandb_logger,
        gpus=1,
        max_epochs=config.epochs,
        log_every_n_steps=1
    )
    trainer.fit(road_regression, dm)

    trainer.test()

    # finally we log the model to wandb.
    torch.save(road_regression.model.state_dict(), "model.pth")
    artifact = wandb.Artifact('model', type='model')
    artifact.add_file('model.pth')
    run.log_artifact(artifact)

# 2. Working out your own dataset split

Here I have mixed images from three different miniature racetracks I built: [Suzuka](https://wandb.ai/wandb/racecar/artifacts/dataset/suzuka), [Monza](https://wandb.ai/wandb/racecar/artifacts/dataset/monza) and [Nurburgring](https://wandb.ai/wandb/racecar/artifacts/dataset/nurburgring).  

You may wish to create your own split to experiment with.  
Maybe you could bring in [images from real roads](https://github.com/commaai/comma2k19)? Or train a model on one track and evaluate it on another to see if it generalizes?  
Feel free to modify the code below to achieve what you want. Make sure the artifact/folder you upload to wandb after this step contains `train`, `val` and `test` folders so that it works with the training code above!
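
For instance, to test generalization you could hold out a whole track instead of mixing all three. A hypothetical sketch of that assignment (`leave_one_track_out` is not a repo function, and the file names are placeholders): the remaining tracks are shuffled into `train`/`val`, and the held-out track becomes `test`.

```python
import random


def leave_one_track_out(files_by_track: dict, test_track: str, val_pct: float = 0.2, seed: int = 0) -> dict:
    """Use `test_track` as the test set and split the other tracks into train/val."""
    rng = random.Random(seed)  # seeded so the split is reproducible
    test = list(files_by_track[test_track])
    pool = [f for track, files in files_by_track.items() if track != test_track for f in files]
    rng.shuffle(pool)
    n_val = int(len(pool) * val_pct)
    return {"train": pool[n_val:], "val": pool[:n_val], "test": test}


files_by_track = {
    "suzuka": [f"suzuka_{i}.jpg" for i in range(10)],
    "monza": [f"monza_{i}.jpg" for i in range(10)],
    "nurburgring": [f"nurburgring_{i}.jpg" for i in range(5)],
}
splits = leave_one_track_out(files_by_track, test_track="nurburgring")
print({name: len(files) for name, files in splits.items()})  # {'train': 16, 'val': 4, 'test': 5}
```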

In [None]:
import random
import shutil

from wandb_jetracer.utils.utils import make_dirs, split_list_by_pct

config = dict(
    datasets=["suzuka:latest", "monza:latest", "nurburgring:latest"],
    output_dataset="mix_ready",
    split_pcts=[0.7, 0.2, 0.1],
)

with wandb.init(project="racecar", config=config, entity="wandb", job_type="pre-process-dataset") as run:
    config = run.config

    out_dirs = make_dirs("./tmp/")

    # make sure the train/val/test percentages add up to 1 (with a float tolerance)
    assert math.isclose(math.fsum(config.split_pcts), 1.0)

    for dataset in config.datasets:
        # download each dataset
        artifact = run.use_artifact(dataset)
        artifact_dir = artifact.download()

        all_fnames = os.listdir(artifact_dir)
        random.shuffle(all_fnames)

        train, val, test = split_list_by_pct(all_fnames, config.split_pcts)
        sets = {
            "train": train,
            "val": val,
            "test": test
        }

        # and copy its files to the train/val/test dirs of the 'mixed' dataset
        # (copying rather than moving keeps the downloaded artifact folder intact)
        for out_dir, split in zip(out_dirs, ["train", "val", "test"]):
            for fname in sets[split]:
                source = os.path.join(artifact_dir, fname)
                dest = os.path.join(out_dir, fname)
                shutil.copy2(source, dest)
    
    # then upload the mixed dataset to wandb
    artifact = wandb.Artifact(config.output_dataset, type='dataset')
    artifact.add_dir("./tmp/")
    run.log_artifact(artifact)

    shutil.rmtree('./tmp/', ignore_errors=True)
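
`split_list_by_pct` comes from the repo's utils, and its exact implementation isn't shown here. Below is a hedged stand-in that cuts a list into consecutive chunks by percentage (`split_by_pct` is a hypothetical name), paired with a seeded `random.Random`, since the `random.shuffle` call above is unseeded and produces a different split on every run.

```python
import random


def split_by_pct(items: list, pcts: list) -> list:
    """Split `items` into consecutive chunks whose sizes follow `pcts`.

    Sizes are rounded, and the last chunk takes whatever is left so no item is dropped.
    """
    chunks, start = [], 0
    for pct in pcts[:-1]:
        end = start + round(len(items) * pct)
        chunks.append(items[start:end])
        start = end
    chunks.append(items[start:])
    return chunks


fnames = [f"img_{i}.jpg" for i in range(10)]
random.Random(42).shuffle(fnames)  # seeded shuffle: same split on every run
train, val, test = split_by_pct(fnames, [0.7, 0.2, 0.1])
print(len(train), len(val), len(test))  # 7 2 1
```

Storing the seed in the run's `config` next to `split_pcts` would make each dataset version fully reproducible from its W&B run.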