# Training a baseline model

In this notebook we create a baseline solution to our semantic segmentation problem. We use a notebook so we can iterate quickly. Later we will refactor this code into a script so that we can run hyperparameter sweeps.

In [1]:
import wandb
import pandas as pd
from types import SimpleNamespace
from fastai.vision.all import *
from fastai.callback.wandb import WandbCallback

import params
# Helper functions, e.g. the metrics we will track during our experiments
from utils import (get_predictions, create_iou_table, MIOU, BackgroundIOU,
                   RoadIOU, TrafficLightIOU, TrafficSignIOU, PersonIOU, VehicleIOU, BicycleIOU)

# Config

Create a `train_config` that is passed to the W&B `run` to control the training hyperparameters.

In [2]:
# SimpleNamespace stores values as attributes, so we don't have to write our own (almost empty) class.

train_config = SimpleNamespace(
    framework="fastai",
    img_size=(180, 320),
    batch_size=8,
    augment=True, # use data augmentation
    epochs=10, 
    lr=2e-3,
    pretrained=True,  # whether to use pretrained encoder
    seed=42,
)
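`wandb.init` records the config as a flat set of key/value pairs. You can check exactly which hyperparameters will be logged by turning the namespace into a plain dictionary with the standard-library `vars()`. The sketch below rebuilds the same values using only the standard library and makes no W&B call.

```python
from types import SimpleNamespace

# Same hyperparameters as train_config above
train_config = SimpleNamespace(
    framework="fastai",
    img_size=(180, 320),
    batch_size=8,
    augment=True,
    epochs=10,
    lr=2e-3,
    pretrained=True,
    seed=42,
)

# vars() returns the namespace's attribute dictionary: one entry per hyperparameter
config_dict = vars(train_config)
for name, value in config_dict.items():
    print(f"{name:>10}: {value}")
```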

In [3]:
# Set seed for reproducibility.
set_seed(train_config.seed, reproducible=True)
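fastai's `set_seed` seeds Python's `random`, NumPy and PyTorch. With `reproducible=True` it also asks cuDNN to behave deterministically. The standard-library sketch below shows the underlying idea with `random` alone: re-seeding with the same value reproduces the same sequence of draws. `sample_draws` is a hypothetical helper used only for illustration.

```python
import random

def sample_draws(seed, n=3):
    """Hypothetical helper: seed Python's RNG, then draw n random integers."""
    random.seed(seed)
    return [random.randint(0, 100) for _ in range(n)]

run_a = sample_draws(42)
run_b = sample_draws(42)
print(run_a == run_b)  # True: same seed, same "random" draws
```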

# Download dataset

In [4]:
# Inputs
# - pass train_config into the W&B run to control training hyperparameters
# - project=params.WANDB_PROJECT makes this run part of the same W&B project as the runs from the previous notebooks
run = wandb.init(project=params.WANDB_PROJECT, entity=params.ENTITY, job_type="training", config=train_config)



Use W&B `artifacts` to track the lineage of the models.

In [5]:
# Use artifacts to track the data lineage of our models
processed_data_artifact = run.use_artifact(f'{params.PROCESSED_DATA_AT}:latest')

# Download split data from W&B artifact
processed_dataset_dir = Path(processed_data_artifact.download())

# Read the CSV that assigns each file to a split (train/valid/test)
df = pd.read_csv(processed_dataset_dir / 'data_split.csv')

[34m[1mwandb[0m: Downloading large artifact bdd_simple_1k_split:latest, 846.07MB. 4010 files... 
[34m[1mwandb[0m:   4010 of 4010 files downloaded.  
Done. 0:0:11.5


# Preprocess data

In [6]:
# Remove test set rows
df = df[df.Stage != 'test'].reset_index(drop=True)
# Flag validation rows: fastai's ColSplitter uses the is_valid column to split the data into training and validation sets.
df['is_valid'] = df.Stage == 'valid'
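The two pandas lines above drop the test rows and turn `Stage` into a boolean `is_valid` flag. Here is a standard-library sketch of the same logic on a tiny made-up CSV. The file names are hypothetical, and the real `data_split.csv` may have more columns than the two used here.

```python
import csv
import io

# Hypothetical miniature data_split.csv (only the File_Name and Stage columns used above)
csv_text = """File_Name,Stage
img_001.jpg,train
img_002.jpg,valid
img_003.jpg,test
img_004.jpg,train
"""

rows = list(csv.DictReader(io.StringIO(csv_text)))
# Drop test rows, as df[df.Stage != 'test'] does
rows = [r for r in rows if r["Stage"] != "test"]
# Flag validation rows, as df['is_valid'] = df.Stage == 'valid' does
for r in rows:
    r["is_valid"] = r["Stage"] == "valid"

print([(r["File_Name"], r["is_valid"]) for r in rows])
```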

In [7]:
def label_func(fname):
    "Map <dataset>/images/<name>.<ext> to its mask <dataset>/labels/<name>_mask.png."
    return (fname.parent.parent / "labels") / f"{fname.stem}_mask.png"
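`label_func` assumes a dataset folder with sibling `images/` and `labels/` directories, where each mask is named after its image stem plus `_mask.png`. The sketch below uses `PurePosixPath` so that it runs without touching the file system. The dataset path is hypothetical.

```python
from pathlib import PurePosixPath

def label_func(fname):
    # <dataset>/images/<stem>.<ext>  ->  <dataset>/labels/<stem>_mask.png
    return (fname.parent.parent / "labels") / f"{fname.stem}_mask.png"

# Hypothetical dataset location, for illustration only
img = PurePosixPath("/data/dataset/images/frame_0001.jpg")
print(label_func(img))  # /data/dataset/labels/frame_0001_mask.png
```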

In [8]:
# Add image and mask label paths to dataframe
df["image_fname"] = [processed_dataset_dir/f'images/{f}' for f in df.File_Name.values]
df["label_fname"] = [label_func(f) for f in df.image_fname.values]

# Model training

We use fastai's DataBlock API to feed data into model training and validation.

In [9]:
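def get_data(df, bs=4, img_size=(180, 320), augment=True):
    block = DataBlock(
        blocks=(ImageBlock, MaskBlock(codes=params.BDD_CLASSES)),  # input image, segmentation mask target
        get_x=ColReader("image_fname"),   # column with the image paths
        get_y=ColReader("label_fname"),   # column with the mask paths
        splitter=ColSplitter(),           # splits on the is_valid column by default
        item_tfms=Resize(img_size),       # resize each image/mask pair
        batch_tfms=aug_transforms() if augment else None,  # batch-level augmentations
    )
    return block.dataloaders(df, bs=bs)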
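Called with no arguments, `ColSplitter()` reads the `is_valid` column. Rows flagged `True` go to the validation set and all other rows go to training. The function below is a hypothetical standard-library sketch of that idea, not fastai's implementation. It returns the two index lists that a splitter produces.

```python
def col_split(is_valid):
    """Hypothetical sketch of the idea behind ColSplitter():
    rows flagged True go to validation, the rest to training."""
    train_idx = [i for i, flag in enumerate(is_valid) if not flag]
    valid_idx = [i for i, flag in enumerate(is_valid) if flag]
    return train_idx, valid_idx

train_idx, valid_idx = col_split([False, True, False, False, True])
print(train_idx, valid_idx)  # [0, 2, 3] [1, 4]
```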

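We read the hyperparameters back from `wandb.config`, which holds the config passed to `wandb.init(..., config=train_config)`. Reading from `wandb.config` instead of `train_config` means that any values W&B overrides, for example during a sweep, are the ones actually used.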

In [10]:
config = wandb.config

In [12]:
dls = get_data(df, bs=config.batch_size, img_size=config.img_size, augment=config.augment)

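Could not do one pass in your dataloader, there is something wrong in it. Please see the stack trace below:

NotImplementedError: The operator 'aten::_linalg_solve_ex.result' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.

This output comes from a run on a Mac, where fastai used PyTorch's MPS (Metal) device. One operator needed by the batch augmentations is not implemented there (most likely the linear solve used by the warp transform in `aug_transforms`). As the error message suggests, setting `PYTORCH_ENABLE_MPS_FALLBACK=1` lets PyTorch run that op on the CPU instead, at some cost in speed.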
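A minimal way to apply the fallback from inside the notebook is to set the environment variable with `os.environ`. We assume it is safest to do this before anything imports PyTorch (fastai does), i.e. in the first lines of the first cell. Exporting it in the shell before launching Jupyter works as well.

```python
import os

# Assumption: set this before torch (and therefore fastai) is imported,
# e.g. as the very first lines of the notebook.
# Ops that MPS does not implement then run on the CPU instead (slower).
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# `from fastai.vision.all import *` etc. would follow here
print(os.environ["PYTORCH_ENABLE_MPS_FALLBACK"])
```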

In [13]:
# We use intersection-over-union (IoU) metrics: the mean across all classes (MIOU) and the IoU of each class separately.
metrics = [MIOU(), BackgroundIOU(), RoadIOU(), TrafficLightIOU(),
           TrafficSignIOU(), PersonIOU(), VehicleIOU(), BicycleIOU()]

# The model is a U-Net with a ResNet-18 encoder (ImageNet-pretrained when config.pretrained is True).
learn = unet_learner(dls, arch=resnet18, pretrained=config.pretrained, metrics=metrics)

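NameError: name 'dls' is not defined

This `NameError` is a knock-on failure: the previous cell raised an exception, so `dls` was never created.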
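The `MIOU` and per-class metrics come from the project's `utils` module, which is not shown here. For reference, the standard definition is: the IoU of class c is the number of pixels where both prediction and target equal c, divided by the number of pixels where either one equals c. The mean IoU averages this over classes. The sketch below computes both on flattened toy label maps. Skipping classes that are absent from both prediction and target is our assumption; `utils` may handle them differently.

```python
def per_class_iou(pred, target, n_classes):
    """IoU for each class over flattened label maps (lists of class ids).
    Classes absent from both pred and target get None (skipped in the mean)."""
    ious = []
    for c in range(n_classes):
        inter = sum(1 for p, t in zip(pred, target) if p == c and t == c)
        union = sum(1 for p, t in zip(pred, target) if p == c or t == c)
        ious.append(inter / union if union else None)
    return ious

def mean_iou(ious):
    present = [x for x in ious if x is not None]
    return sum(present) / len(present)

pred   = [0, 0, 1, 1, 2, 2]  # toy predicted class per pixel
target = [0, 1, 1, 1, 2, 0]  # toy ground-truth class per pixel
ious = per_class_iou(pred, target, n_classes=3)
print([round(x, 3) for x in ious])  # [0.333, 0.667, 0.5]
print(round(mean_iou(ious), 3))     # 0.5
```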

fastai already has a callback that integrates tightly with W&B: we only need to pass `WandbCallback` to the training call and we are ready to go. The callback logs the losses, the hyperparameters and every metric we pass to the learner.

In [None]:
callbacks = [
    SaveModelCallback(monitor='miou'),  # save the model with the best miou metric
    # We log predictions manually (log_preds=False) and upload the model to W&B (log_model=True)
    WandbCallback(log_preds=False, log_model=True)
]
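`SaveModelCallback(monitor='miou')` keeps one checkpoint and overwrites it whenever the monitored metric improves. As far as we know, fastai treats a monitored name that contains neither "loss" nor "error" as higher-is-better. At the end of training it reloads the best checkpoint. The sketch below shows only that bookkeeping, using the standard library and made-up per-epoch values. It is not fastai's code.

```python
def track_best(values):
    """Hypothetical sketch of 'save on improvement' for a higher-is-better
    metric such as miou. Returns the epochs at which a checkpoint would be
    (over)written and the epoch whose weights end up being kept."""
    best = float("-inf")
    saved_at = []
    for epoch, value in enumerate(values):
        if value > best:  # improvement -> overwrite the checkpoint
            best = value
            saved_at.append(epoch)
    return saved_at, (saved_at[-1] if saved_at else None)

# Made-up validation miou values, one per epoch
saved_at, best_epoch = track_best([0.30, 0.42, 0.40, 0.45, 0.44])
print(saved_at, best_epoch)  # [0, 1, 3] 3
```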

In [None]:
# Train model
learn.fit_one_cycle(config.epochs, config.lr, cbs=callbacks)
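`fit_one_cycle` does not train at a constant `config.lr`; it treats that value as the peak of a one-cycle schedule. We assume fastai's defaults (`div=25`, `div_final=1e5`, `pct_start=0.25`). Under those defaults, the rate warms up from `lr/25` to `lr` over the first 25% of iterations, then anneals to `lr/1e5`, with both phases following a cosine curve. Momentum is also scheduled and is not shown. A standard-library sketch of the learning-rate curve:

```python
import math

def cos_anneal(start, end, pos):
    """Cosine interpolation: returns start at pos=0 and end at pos=1."""
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2

def one_cycle_lr(pct, lr_max=2e-3, div=25.0, div_final=1e5, pct_start=0.25):
    """Learning rate at training progress pct in [0, 1] (sketch, assumed fastai defaults)."""
    if pct < pct_start:
        # Warm-up phase: lr_max/div -> lr_max
        return cos_anneal(lr_max / div, lr_max, pct / pct_start)
    # Annealing phase: lr_max -> lr_max/div_final
    return cos_anneal(lr_max, lr_max / div_final, (pct - pct_start) / (1 - pct_start))

for pct in (0.0, 0.125, 0.25, 0.5, 1.0):
    print(f"{pct:5.3f}: {one_cycle_lr(pct):.2e}")
```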

Log a `table` of model predictions and ground-truth masks to W&B so that we can do error analysis in the W&B dashboard.

In [None]:
samples, outputs, predictions = get_predictions(learn)
table = create_iou_table(samples, outputs, predictions, params.BDD_CLASSES)
wandb.log({"pred_table":table})

At the end of training, `SaveModelCallback` reloads the weights from the best checkpoint (highest `miou`).

To make sure we record the final metrics correctly, we validate this model again and save the final loss and metrics to `wandb.summary`.

In [None]:
scores = learn.validate()  # [valid_loss, metric_1, metric_2, ...]
metric_names = ['final_loss'] + [f'final_{x.name}' for x in metrics]
final_results = dict(zip(metric_names, scores))
for k, v in final_results.items():
    wandb.summary[k] = v

In [None]:
wandb.finish()