<a href="https://colab.research.google.com/github/Armandpl/wandb_jetracer/blob/master/wandb_jetracer_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<img src="https://i.imgur.com/gb6B4ig.png" width="400" alt="Weights & Biases" />

# 🏁🏎️💨 = W&B ➕ NVIDIA JetRacer

In this notebook we train a model to infer the center of the road from the car's camera images, and use it to drive a remote-controlled car.  

[link to video]  
[image from the car's POV]

# 0. Setup
Install and import the dependencies, and clone https://github.com/Armandpl/wandb_jetracer to get the utility functions. 

In [None]:
!pip install wandb


In [None]:
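!git clone https://github.com/Armandpl/wandb_jetracer
# the training code below uses the Lightning 1.x API (`test_epoch_end`, `Trainer(gpus=...)`), which was removed in Lightning 2.0
!pip install "pytorch-lightning<2.0" torchmetrics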


In [None]:
import math
import os

import cv2
import PIL
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torchmetrics
import wandb

from wandb_jetracer.utils.xy_dataset import XYDataset

# 1. Training a model

In [None]:
class RoadRegression(pl.LightningModule):

    def __init__(self, config):
        super().__init__()
        self.config = config

        # setting up metrics
        metrics = torchmetrics.MetricCollection([
            torchmetrics.MeanSquaredError(),
            torchmetrics.MeanAbsoluteError()
        ])
        self.train_metrics = metrics.clone(prefix='train/')
        self.valid_metrics = metrics.clone(prefix='val/')
        self.test_metrics = metrics.clone(prefix='test/')

        # setting up the model, here we are fine-tuning a ResNet
        self.model = torchvision.models.__dict__[config.architecture](pretrained=config.pretrained)
        self.model.fc = nn.Linear(self.model.fc.in_features, 2)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        images, targets = batch
        preds = self.forward(images)
        loss = getattr(F, self.config.loss)(preds, targets)

        metrics = self.train_metrics(preds, targets)
        self.log_dict(metrics, on_step=True, on_epoch=False)

        return loss

    def validation_step(self, batch, batch_idx):
        images, targets = batch
        preds = self.forward(images)

        metrics = self.valid_metrics(preds, targets)
        self.log_dict(metrics, on_step=False, on_epoch=True)

    def test_step(self, batch, batch_idx):
        images, targets = batch
        preds = self.forward(images)

        metrics = self.test_metrics(preds, targets)
        self.log_dict(metrics, on_step=False, on_epoch=True)

        return (images, preds, targets)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.config.learning_rate)
        return optimizer
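The `train/`, `val/` and `test/` metrics come from torchmetrics' `MeanSquaredError` and `MeanAbsoluteError`, which average the error over every element of the target, i.e. over both coordinates of every image in the batch. The pure-Python sketch below mirrors that reduction for a small batch of made-up `(x, y)` predictions so the logged numbers are easier to read. `mse_mae` is a hypothetical helper for illustration, not part of torchmetrics or of the `wandb_jetracer` repo.

```python
def mse_mae(preds, targets):
    """Mean squared and mean absolute error over every coordinate.

    preds and targets are lists of (x, y) pairs in [-1, 1]. The errors are
    averaged over all N * 2 values, the same flattening the batch metrics use.
    """
    diffs = [p - t for pred, target in zip(preds, targets) for p, t in zip(pred, target)]
    mse = sum(d * d for d in diffs) / len(diffs)
    mae = sum(abs(d) for d in diffs) / len(diffs)
    return mse, mae


# Made-up batch of two images: predicted vs. ground-truth road centers.
preds = [(0.1, -0.2), (0.5, 0.0)]
targets = [(0.0, -0.2), (0.3, 0.4)]
mse, mae = mse_mae(preds, targets)
print(f"MSE={mse:.4f} MAE={mae:.4f}")
```

Because MSE squares the errors, the single large miss on the second image's `y` coordinate dominates it, while MAE weighs all four errors linearly.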



## Visualizing predictions on the test set  
To get an intuition of how our model is doing, we log the test set along with the predictions, per-image losses and ground truth as a wandb Table. We can then go to our dashboard and, for example, sort by highest loss to see which images the model struggles with. **[add link to dashboard or example run]**

In [None]:
def test_epoch_end(self, test_step_outputs):
    images, predictions, targets = self.concat_test_outputs(test_step_outputs)

    # squared error for each coordinate of each test image, shape (N, 2)
    losses = F.mse_loss(predictions, targets, reduction='none')

    test_table = self.create_table(images, predictions, targets, losses)

    wandb.log({"test/predictions": test_table})

def create_table(self, images, predictions, targets, losses):
    # display preds and targets on images
    images_with_preds = []
    for idx, image in enumerate(images):
        img = torch2cv2(image)

        # show ground truth and prediction on the image
        img = show_label(img, targets[idx])
        img = show_label(img, predictions[idx], (0, 0, 255))

        images_with_preds.append(img)

    # create a WandB table
    my_data = [
        [wandb.Image(img), pred, target, loss.sum()] 
        for img, pred, target, loss
        in zip(images_with_preds, predictions, targets, losses)
    ]

    columns = ["image", "prediction", "target", "loss"]
    table = wandb.Table(data=my_data, columns=columns)

    return table

def concat_test_outputs(self, test_step_outputs):
    """
    Concatenate the (images, preds, targets) tuples returned by each test step
    so that we can easily iterate over the whole test set and compute the loss
    for each image.
    """
    # zip(*...) groups the images, predictions and targets of every batch
    images, predictions, targets = (
        torch.cat(batch_parts, dim=0) for batch_parts in zip(*test_step_outputs)
    )

    return images, predictions, targets

RoadRegression.test_epoch_end = test_epoch_end
RoadRegression.create_table = create_table
RoadRegression.concat_test_outputs = concat_test_outputs
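`F.mse_loss(..., reduction='none')` keeps one squared error per coordinate, so `losses` has shape `(N, 2)`, and `create_table` collapses each row with `loss.sum()` into a single number per image (the sum of the x and y squared errors, i.e. twice that image's MSE). The pure-Python sketch below reproduces that per-image reduction and then ranks images from worst to best, which is what sorting the table's `loss` column does in the dashboard. `per_image_losses`, `rank_by_loss` and the sample coordinates are hypothetical.

```python
def per_image_losses(predictions, targets):
    """Sum of squared errors over (x, y) for each image, like loss.sum() per row."""
    return [
        sum((p - t) ** 2 for p, t in zip(pred, target))
        for pred, target in zip(predictions, targets)
    ]


def rank_by_loss(predictions, targets):
    """Return (image index, loss) pairs, highest loss first."""
    losses = per_image_losses(predictions, targets)
    return sorted(enumerate(losses), key=lambda item: item[1], reverse=True)


# Hypothetical test-set predictions and ground truth, normalized to [-1, 1].
predictions = [(0.0, 0.1), (0.8, -0.5), (-0.2, 0.0)]
targets = [(0.0, 0.0), (0.1, -0.4), (-0.2, 0.3)]
for idx, loss in rank_by_loss(predictions, targets):
    print(idx, round(loss, 4))
```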

### Utility functions (to be extracted to the GitHub repo)

In [None]:
def show_label(image, coordinates, color=(0, 255, 0)):
    """
    Draw a circle at the (x, y) coordinates on the image.
    x and y are normalized to [-1, 1]; (-1, -1) is the top-left corner.
    """
    img_h, img_w, _ = image.shape
    x, y = coordinates

    # map x, y from [-1, 1] to pixel coordinates
    x = int((x + 1) / 2 * img_w)
    y = int((y + 1) / 2 * img_h)

    cv2.circle(image, (x, y), 5, color, 2)

    return image

def torch2cv2(tensor):
    """
    Convert a CHW image tensor to an HWC float array scaled by 255,
    swapping the channel order with OpenCV.
    """
    img = tensor.permute(1, 2, 0).cpu().numpy() * 255
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    return img
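`show_label` maps normalized coordinates in [-1, 1] to pixels with `(x + 1) / 2 * width`. The sketch below isolates that mapping so you can check a few values by hand, assuming a 224 x 224 input (the usual ResNet input size). `to_pixels` mirrors the arithmetic in `show_label`; `to_normalized` is a hypothetical inverse that assumes labels are encoded the same way, which the docstring suggests but this notebook does not show.

```python
def to_pixels(x, y, img_w, img_h):
    """Map normalized (x, y) in [-1, 1] to integer pixel coordinates, as show_label does."""
    px = int((x + 1) / 2 * img_w)
    py = int((y + 1) / 2 * img_h)
    return px, py


def to_normalized(px, py, img_w, img_h):
    """Hypothetical inverse: pixel coordinates back to [-1, 1]."""
    return px / img_w * 2 - 1, py / img_h * 2 - 1


# On a 224 x 224 image, (0, 0) in normalized units is the center of the frame.
print(to_pixels(0.0, 0.0, 224, 224))    # (112, 112)
print(to_pixels(-1.0, -1.0, 224, 224))  # (0, 0): top-left corner
print(to_normalized(168, 56, 224, 224)) # (0.5, -0.5)
```

Note that `x = 1` maps to column `img_w`, one past the last valid pixel; OpenCV clips the part of the circle that falls outside the image, so this only matters at the very edge.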

## Preparing our data
Most of this is standard PyTorch Lightning code. If you'd like to learn more about using Weights & Biases with PyTorch Lightning, you can check out this video: **[add link to Charles's pl tutorial]**

In [None]:
from typing import Optional

class RoadDataModule(pl.LightningDataModule):

    def __init__(self, dataset_artifact: str, batch_size):
        super().__init__()
        self.dataset_artifact = dataset_artifact
        self.batch_size = batch_size

    def setup(self, stage: Optional[str] = None):
        # Assign train/val datasets for use in dataloaders
        train_pth, val_pth, test_pth = [os.path.join(self.artifact_dir, split) for split in ["train", "val", "test"]] 

        if stage == 'fit' or stage is None:
            self.train, self.val = XYDataset(train_pth, train=True), XYDataset(val_pth, train=False)

            self.dims = tuple(self.train[0][0].size())

        # Assign test dataset for use in dataloader(s)
        if stage == 'test' or stage is None:
            self.test = XYDataset(test_pth, train=False)

            self.dims = tuple(self.test[0][0].size())

    def train_dataloader(self):
        return self.make_loader(self.train, True)

    def val_dataloader(self):
        return self.make_loader(self.val, False)

    def test_dataloader(self):
        return self.make_loader(self.test, False)

    def make_loader(self, dataset, shuffle):
        return DataLoader(dataset=dataset,
                          batch_size=self.batch_size, 
                          shuffle=shuffle,
                          pin_memory=True, num_workers=2) 

### The interesting bit is this:



In [None]:
def prepare_data(self):
    # we download the dataset from WandB Artifacts
    artifact = wandb.use_artifact(self.dataset_artifact)
    self.artifact_dir = artifact.download()

RoadDataModule.prepare_data = prepare_data

We download our pre-processed dataset from WandB Artifacts. 
Because each training run declares which dataset artifact it uses, and also logs its trained weights as an artifact, we always know which model was trained on which version of the data and how the training went, and we can retrieve the trained weights later. 

### Let's train then!
Again, this is fairly standard PyTorch Lightning / WandB code.

In [None]:
config = dict(
    epochs=10,
    architecture="resnet34",
    pretrained=True,
    batch_size=64,
    learning_rate=1e-4,
    dataset="mix_ready:latest",
    train_augs=False,
    loss="mse_loss"
    )

with wandb.init(project="racecar", config=config, job_type="train", entity="wandb") as run:
    config = run.config

    dm = RoadDataModule(config.dataset, config.batch_size)
    road_regression = RoadRegression(config)

    wandb_logger = WandbLogger()
    trainer = pl.Trainer(
        logger=wandb_logger,
        gpus=1,
        max_epochs=config.epochs,
        log_every_n_steps=1
    )
    trainer.fit(road_regression, dm)

    trainer.test()

    # finally we log the model to wandb.
    torch.save(road_regression.model.state_dict(), "model.pth")
    artifact = wandb.Artifact('model', type='model')
    artifact.add_file('model.pth')
    run.log_artifact(artifact)

wandb: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter: ··········
wandb: Appending key for api.wandb.ai to your netrc file: /root/.netrc

Downloading: "https://download.pytorch.org/models/resnet34-333f7ec4.pth" to /root/.cache/torch/hub/checkpoints/resnet34-333f7ec4.pth




GPU available: True, used: True
TPU available: False, using: 0 TPU cores
wandb: Downloading large artifact mix_ready:latest, 51.06MB. 3025 files... Done. 0:0:0
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type             | Params
---------------------------------------------------
0 | train_metrics | MetricCollection | 0     
1 | valid_metrics | MetricCollection | 0     
2 | test_metrics  | MetricCollection | 0     
3 | model         | ResNet           | 21.3 M
---------------------------------------------------
21.3 M    Trainable params
0         Non-trainable params
21.3 M    Total params
85.143    Total estimated model params size (MB)






LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]




--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test/MeanAbsoluteError': 0.0898691862821579,
 'test/MeanSquaredError': 0.014147796668112278}
--------------------------------------------------------------------------------


101.26MB of 101.26MB uploaded (0.00MB deduped)

Run summary:
train/MeanSquaredError    0.02591
train/MeanAbsoluteError   0.14609
epoch                     9.0
trainer/global_step       340.0
_runtime                  369.0
_timestamp                1623315055.0
_step                     351.0
val/MeanSquaredError      0.01795
val/MeanAbsoluteError     0.09238
test/MeanSquaredError     0.01415


Run history:
train/MeanSquaredError    █▄▃▃▃▂▂▂▂▂▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/MeanAbsoluteError   █▆▅▄▄▃▃▃▂▃▂▂▃▂▂▂▂▂▂▂▂▂▁▂▂▁▂▂▁▁▂▁▂▂▁▁▂▁▁▃
epoch                     ▁▁▁▁▂▂▂▂▃▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇████
trainer/global_step       ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
_runtime                  ▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇█
_timestamp                ▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇█
_step                     ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
val/MeanSquaredError      █▄▄▂▂▂▁▁▁▁
val/MeanAbsoluteError     █▅▄▃▂▂▁▁▁▁
test/MeanSquaredError     ▁


# 2. Working out your own dataset split

Here I have mixed images from three different racetracks I built: Suzuka, Monza and Nurburgring.  
**[add images with captions?]**  
You may wish to tailor your own split to experiment.  
Maybe you could bring in images from real roads? Maybe you could train on one track and evaluate on another to see whether your model generalizes?  
Feel free to modify the code below to achieve what you want. Make sure the artifact/folder you upload to wandb after this step contains `train`, `val` and `test` folders so that it works with the training code above!
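As an example of a custom split, the sketch below holds out one whole track as the test set and splits the remaining tracks' images into train and validation sets, which checks whether the model generalizes to a track it has never seen. It works on plain filename lists: the track names follow the artifacts used in this notebook, while the function name `holdout_split`, the 80/20 train/val ratio, the fixed seed and the file names are assumptions for illustration. In the notebook you would fill `files_by_track` from `os.listdir(artifact.download())` for each dataset and then put the files into `train`, `val` and `test` folders, as the next cell does.

```python
import random


def holdout_split(files_by_track, test_track, val_pct=0.2, seed=0):
    """Use every image of `test_track` as the test set and split the other
    tracks' images into train/val. Returns a dict of (track, filename) lists."""
    rng = random.Random(seed)  # fixed seed so the split is reproducible
    pool = [
        (track, fname)
        for track, fnames in files_by_track.items()
        if track != test_track
        for fname in fnames
    ]
    rng.shuffle(pool)
    n_val = int(len(pool) * val_pct)
    return {
        "train": pool[n_val:],
        "val": pool[:n_val],
        "test": [(test_track, fname) for fname in files_by_track[test_track]],
    }


# Hypothetical file lists; real ones would come from the downloaded artifacts.
files_by_track = {
    "suzuka": [f"suzuka_{i}.jpg" for i in range(10)],
    "monza": [f"monza_{i}.jpg" for i in range(10)],
    "nurburgring": [f"nurburgring_{i}.jpg" for i in range(5)],
}
splits = holdout_split(files_by_track, test_track="nurburgring")
print({name: len(items) for name, items in splits.items()})
```

Keeping the track name next to each filename also helps avoid collisions if two tracks happen to contain files with the same name.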

In [None]:
import random
import shutil

from wandb_jetracer.utils.utils import make_dirs, split_list_by_pct

config = dict(
    datasets=["suzuka:latest", "monza:latest", "nurburgring:latest"],
    output_dataset="mix_ready",
    split_pcts=[0.7, 0.2, 0.1],
)

with wandb.init(project="racecar", config=config, entity="wandb", job_type="pre-process-dataset") as run:
    config = run.config

    out_dirs = make_dirs("./tmp/")

    # make sure the train/val/test percentages add up to 1 (isclose tolerates float rounding)
    assert math.isclose(math.fsum(config.split_pcts), 1.0)

    for dataset in config.datasets:
        artifact = run.use_artifact(dataset)
        artifact_dir = artifact.download()

        # TODO: might be a good place to run tests on wandb projects        
        all_fnames = os.listdir(artifact_dir)
        random.shuffle(all_fnames)

        train, val, test = split_list_by_pct(all_fnames, config.split_pcts)
        sets = {
            "train": train,
            "val": val,
            "test": test
        }
     
        for out_dir, split in zip(out_dirs, ["train", "val", "test"]):
            for fname in sets[split]:
                source = os.path.join(artifact_dir, fname)
                dest = os.path.join(out_dir, fname)
                os.rename(source, dest)
    
    # upload the pre-processed dataset as a new artifact
    artifact = wandb.Artifact(config.output_dataset, type='dataset')
    artifact.add_dir("./tmp/")
    run.log_artifact(artifact)

    shutil.rmtree('./tmp/', ignore_errors=True)

wandb: Adding directory to artifact (./tmp)... Done. 0.7s

51.19MB of 51.19MB uploaded (50.99MB deduped)
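`split_list_by_pct` comes from the cloned `wandb_jetracer` repo and its code is not shown in this notebook. The sketch below is one plausible way such a helper could work, not the repo's implementation: it cuts an already-shuffled list into consecutive slices sized by the percentages and gives any rounding remainder to the last slice, so no file is dropped. The name `split_by_pct` is hypothetical.

```python
import math


def split_by_pct(items, pcts):
    """Split `items` into len(pcts) consecutive slices sized by `pcts`.

    The last slice takes whatever remains, so rounding never drops an item.
    """
    if not math.isclose(sum(pcts), 1.0):
        raise ValueError(f"percentages must sum to 1, got {sum(pcts)}")
    splits, start = [], 0
    for pct in pcts[:-1]:
        end = start + int(len(items) * pct)
        splits.append(items[start:end])
        start = end
    splits.append(items[start:])
    return splits


train, val, test = split_by_pct(list(range(25)), [0.7, 0.2, 0.1])
print(len(train), len(val), len(test))
```

With 25 items, `int()` truncates 17.5 to 17 for train, val gets 5, and test receives the remaining 3 rather than the nominal 2.5.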