## Prepare Data path and load cfg

By setting the `L5KIT_DATA_FOLDER` variable, we can point the script to the folder where the data lies.

Then, we load our config file with relative paths and other configurations (rasteriser, training params...).

### Setup

In [None]:
from pathlib import Path
import os

In [None]:
#NOTE: DONT USE RELATIVE PATHS FOR THE MODELS PROVIDED BY L5
# All experiment artefacts live in an "Experiments" folder two levels above
# the notebook's current working directory.
experiments_directory = Path(os.path.abspath('')).parent.parent / "Experiments"
data_directory = experiments_directory / "data"
prediction_directory = experiments_directory / "prediction"
prediction_training_directory = prediction_directory / "training"
save_directory = prediction_training_directory / "saved_outputs"

# mkdir(parents=True) also creates every missing ancestor, so creating the two
# leaf directories materialises the whole tree above.
data_directory.mkdir(parents=True, exist_ok=True)
save_directory.mkdir(parents=True, exist_ok=True)

In [None]:
import os
# Make the training folder the working directory; relative paths used by the
# following cells (e.g. `%%writefile requirements.txt`) resolve inside it.
os.chdir(prediction_training_directory)

In [None]:
%%writefile requirements.txt
l5kit
pyyaml
ray==2.0.0rc1
ray[air]
wandb
optuna

In [None]:
from typing import Dict

from tempfile import gettempdir
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision.models.resnet import resnet50
from tqdm import tqdm

from l5kit.configs import load_config_data
from l5kit.data import LocalDataManager, ChunkedDataset
from l5kit.dataset import AgentDataset, EgoDataset
from l5kit.rasterization import build_rasterizer
from l5kit.evaluation import write_pred_csv, compute_metrics_csv, read_gt_csv, create_chopped_dataset
from l5kit.evaluation.chop_dataset import MIN_FUTURE_STEPS
from l5kit.evaluation.metrics import neg_multi_log_likelihood, time_displace, rmse, prob_true_mode, average_displacement_error_oracle, average_displacement_error_mean, final_displacement_error_oracle, final_displacement_error_mean, detect_collision, distance_to_reference_trajectory
from l5kit.geometry import transform_points
from l5kit.visualization import PREDICTED_POINTS_COLOR, TARGET_POINTS_COLOR, draw_trajectory
from prettytable import PrettyTable
from pathlib import Path

import os

### Get Data from Wandb

In [None]:
import wandb
# Authenticate with Weights & Biases before any run is started.
wandb.login()

In [None]:
# Run information for the W&B run that downloads the dataset artifact.
# wandb_entity = "l5-demo"
project_name = "l5-prediction"
run_name = "download-l5-data"
run_type = "download"
run_description = """
Download data for the task of training a prediction model
"""
tags = ["download", "data"]

In [None]:
#🪄🐝
# Start the "download" W&B run described by the constants above.
run = wandb.init(
    # entity=wandb_entity,
    project=project_name,
    job_type=run_type,
    name=run_name,
    notes=run_description,
    tags=tags
)

In [None]:
# Coordinates of the W&B dataset artifact that holds the L5 data; it is
# downloaded into `data_directory` (configs are read from there later).
# artifact_entity = "l5-demo"
artifact_project = "l5-common"
artifact_name = "l5-data"
artifact_alias = "latest"
artifact_type = "dataset"

In [36]:
#🪄🐝
# Declare the dataset artifact as an input of the current run and keep a handle
# to it for downloading.
# artifact = run.use_artifact(f"{artifact_entity}/{artifact_project}/{artifact_name}:{artifact_alias}", type=artifact_type)
artifact = run.use_artifact(f"{artifact_project}/{artifact_name}:{artifact_alias}", type=artifact_type)

In [35]:
# Download the artifact's files into data_directory (the returned path is unused).
_ = artifact.download(data_directory)

[34m[1mwandb[0m: Downloading large artifact l5-data:latest, 2386.92MB. 517 files... 
[34m[1mwandb[0m:   517 of 517 files downloaded.  
Done. 0:0:0.2


In [None]:
#BUG: need to separate runs into download and training due to issues with routing runs after ray.tune
# Close the download run; a separate training run is started after tuning.
run.finish()

In [None]:
# Load the agent-motion config shipped inside the downloaded artifact.
# l5kit locates the zarr datasets through the L5KIT_DATA_FOLDER environment
# variable, which is pointed at `l5_data_location` in the next cell.

# get config
cfg = load_config_data(Path(data_directory, "configurations", "agent_motion_config.yaml"))
l5_data_location = Path(data_directory, "dataset")
# run.config.update(cfg)

In [None]:
# cfg["zarr_dataset_location"] = l5_data_location
# Point l5kit's LocalDataManager at the downloaded dataset folder. Ray workers
# create their own LocalDataManager inside the training loop.
os.environ["L5KIT_DATA_FOLDER"] = str(l5_data_location)

## Model

Our baseline is a simple `resnet50` pretrained on `imagenet`. We must replace the input and the final layer to address our requirements.

In [None]:
def build_model(cfg: Dict) -> torch.nn.Module:
    """Build an ImageNet-pretrained ResNet-50 adapted to the raster inputs and targets.

    The stem convolution is rebuilt so that it accepts the rasterizer's channel
    count, and the classifier head is replaced by a linear regressor that
    outputs an (X, Y) pair for every future frame.

    Args:
        cfg: l5kit config dict; reads ``model_params.history_num_frames`` and
            ``model_params.future_num_frames``.

    Returns:
        The modified ``resnet50`` module with ``2 * future_num_frames`` outputs.
    """
    # load pre-trained Conv2D model
    model = resnet50(pretrained=True)

    # Rebuild the stem convolution for the rasterizer's output:
    # - 2 channels per history frame (the +1 counts the current frame);
    # - 3 base channels on top (presumably the RGB map image; confirm against
    #   the rasterizer).
    # The new layer is freshly initialised, so the ImageNet conv1 weights are
    # not reused.
    num_history_channels = (cfg["model_params"]["history_num_frames"] + 1) * 2
    num_in_channels = 3 + num_history_channels
    model.conv1 = nn.Conv2d(
        num_in_channels,
        model.conv1.out_channels,
        kernel_size=model.conv1.kernel_size,
        stride=model.conv1.stride,
        padding=model.conv1.padding,
        bias=False,
    )

    # Replace the head with a regressor producing (X, Y) for every future frame.
    # The feature width is read from the existing head rather than hard-coding
    # ResNet-50's 2048, mirroring how conv1 is rebuilt from its predecessor.
    num_targets = 2 * cfg["model_params"]["future_num_frames"]
    model.fc = nn.Linear(in_features=model.fc.in_features, out_features=num_targets)

    return model

In [None]:
def forward(data, model, criterion):
    """Run one forward pass and return the availability-masked mean loss.

    Args:
        data: batch dict with the keys "image" (model input),
            "target_positions" and "target_availabilities".
        model: callable whose output is reshaped to the shape of
            "target_positions".
        criterion: loss expected to return per-element values
            (e.g. ``nn.MSELoss(reduction="none")``).

    Returns:
        ``(loss, outputs)``: the scalar mean of the masked per-element losses
        and the reshaped predictions.
    """
    targets = data["target_positions"]
    # Add a trailing axis so the per-step mask broadcasts over the loss's last dim.
    mask = data["target_availabilities"].unsqueeze(-1)
    predictions = model(data["image"]).reshape(targets.shape)
    # Unavailable steps are zeroed out, but they still count in the mean below.
    masked_loss = criterion(predictions, targets) * mask
    return masked_loss.mean(), predictions

In [None]:
def train_prediction_model_epoch(data, model, criterion, optimizer):
    """Perform one optimisation step on a single batch.

    Despite its name, this processes one batch rather than a full epoch.
    Gradients are cleared, the masked loss from ``forward`` is back-propagated,
    and the optimizer takes one step.

    Returns:
        ``(loss, outputs)`` exactly as returned by ``forward``.
    """
    # Clearing gradients before the forward pass is equivalent: forward never reads .grad.
    optimizer.zero_grad()
    result = forward(data, model, criterion)
    result[0].backward()
    optimizer.step()
    return result

Our data pipeline maps a raw `.zarr` folder into a multi-process data loader ready for training by:
- loading the `zarr` into a `ChunkedDataset` object. This object holds references to the different arrays in the zarr (e.g. agents and traffic lights);
- wrapping the `ChunkedDataset` into an `AgentDataset`, which inherits from the torch `Dataset` class;
- passing the `AgentDataset` into a torch `DataLoader`.

# Training

Note: if you're on macOS and using the `py_satellite` rasterizer, you may need to disable OpenCV multithreading by adding
`cv2.setNumThreads(0)` before the following cell. This seems to affect only Jupyter notebooks and is caused by the `cv2.warpAffine` function.

In [None]:
import ray.train as train
from ray.air import session, Checkpoint

In [None]:
from ray import tune
from ray.tune.tuner import Tuner

In [None]:
def train_prediction_model(tuner_cfg : Dict):
    """Ray Train per-worker loop: train the prediction model for one Tune trial.

    Every Ray worker runs this function. It builds the rasterizer, dataset and
    model from ``tuner_cfg["cfg"]``, and lets Ray adapt the DataLoader and
    model to the distributed setup. After every step it reports metrics to the
    AIR session, attaching a checkpoint roughly ``num_checkpoints`` times per
    run and on the final step.

    Args:
        tuner_cfg: trial config resolved by Ray Tune. Keys read here:
            ``shuffle``, ``batch_size`` (global; split across workers),
            ``num_workers`` (DataLoader workers), ``lr``, ``max_num_steps``,
            ``dataset_key`` (resolved with ``LocalDataManager.require``) and
            ``cfg`` (the l5kit config dict).
    """
    dm = LocalDataManager()

    # ==== Configurations
    # tune.quniform samples floats, hence the int() casts.
    shuffle = tuner_cfg["shuffle"]
    batch_size = int(tuner_cfg["batch_size"])
    num_workers = tuner_cfg["num_workers"]
    lr = tuner_cfg["lr"]
    max_num_steps = int(tuner_cfg["max_num_steps"])
    dataset_key = tuner_cfg["dataset_key"]
    cfg = tuner_cfg["cfg"]

    # ==== Loading Dataset
    rasterizer = build_rasterizer(cfg, dm)

    train_zarr = ChunkedDataset(dm.require(dataset_key)).open()
    train_dataset = AgentDataset(cfg, train_zarr, rasterizer)

    # The configured batch size is global and is split across workers. Never go
    # below one sample per worker: DataLoader rejects batch_size=0.
    batch_size_per_worker = max(1, batch_size // session.get_world_size())
    train_dataloader = DataLoader(
        train_dataset,
        shuffle=shuffle,
        batch_size=batch_size_per_worker,
        num_workers=num_workers,
    )
    train_dataloader = train.torch.prepare_data_loader(train_dataloader)

    # ==== Init model
    model = build_model(cfg)
    model = train.torch.prepare_model(model)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # Per-element loss so forward() can mask out unavailable target steps.
    criterion = nn.MSELoss(reduction="none")

    # ==== TRAIN LOOP
    num_checkpoints = 5
    # If max_num_steps < num_checkpoints the interval would be 0 and the modulo
    # below would raise ZeroDivisionError, so clamp it to at least 1.
    steps_before_checkpointing = max(1, max_num_steps // num_checkpoints)
    # Use the AIR session API, consistent with get_world_size() above.
    is_rank_zero = session.get_world_rank() == 0
    running_loss_sum = 0.0

    # Nothing in the loop leaves training mode or disables autograd, so set
    # both once.
    model.train()
    torch.set_grad_enabled(True)

    tr_it = iter(train_dataloader)
    for step in range(max_num_steps):
        try:
            data = next(tr_it)
        except StopIteration:
            # Dataset exhausted: start another pass over it.
            tr_it = iter(train_dataloader)
            data = next(tr_it)

        loss, _ = train_prediction_model_epoch(data, model, criterion, optimizer)

        # .item() forces a device sync, so call it only once per step.
        loss_value = loss.item()
        running_loss_sum += loss_value
        metrics = {
            "loss": loss_value,
            # Mean of all step losses so far. The running sum replaces
            # re-averaging the whole history on every step.
            "avg_loss": running_loss_sum / (step + 1),
        }

        if is_rank_zero:
            print(metrics)

        if (step % steps_before_checkpointing == 0) or (step == max_num_steps - 1):
            # NOTE(review): this stores the prepared module object itself. Ray's
            # examples checkpoint model.state_dict() instead. Kept as-is because
            # checkpoint consumers may expect a module under "model" — confirm.
            session.report(
                metrics=metrics,
                checkpoint=Checkpoint.from_dict(dict(step=step, model=model)),
            )
        else:
            session.report(metrics=metrics)

    

### Distributed Training using Ray

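We detect the available hardware and split it across training actors. With GPUs, we launch one actor per GPU and divide the CPUs evenly among them as data-loading workers. Without GPUs, each actor gets up to 4 data-loading workers, and we launch as many actors as the CPUs allow.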

In [None]:
from ray.train.torch import TorchTrainer
from ray.air.config import RunConfig, ScalingConfig
from ray.air.callbacks.wandb import WandbLoggerCallback #🪄🐝

In [None]:
import multiprocessing

In [None]:
# Hardware visible to this process, used below to size the Ray workers and
# their DataLoader worker pools.
USE_GPU = torch.cuda.is_available()
NUM_GPUS = torch.cuda.device_count()
NUM_CPUS = multiprocessing.cpu_count()

In [None]:
# Decide how many Ray training workers ("actors") to launch and how many
# DataLoader workers each one gets.
if not USE_GPU:
    # CPU only: up to 4 loader workers per actor, as many actors as CPUs allow
    # (and at least one).
    num_data_workers = min(4, NUM_CPUS)
    ideal_num_actors = NUM_CPUS // num_data_workers
    num_actors = ideal_num_actors or 1
else:
    # One actor per GPU; CPUs are shared evenly between them for data loading.
    num_actors = NUM_GPUS
    num_data_workers = NUM_CPUS // num_actors

To use Ray, all we need to do is wrap the training function above. Apart from preparing the model and data loader with `ray.train.torch`, the only additions to the function are calls to `session.report` to log metrics (and checkpoints) during training.

In [None]:
#NOTE: To figure out if scaling config intuition is correct: num_actors divide resources between each actor and within the train func each actor can then utilize the shared resources
# Distributed trainer: launches `num_actors` workers, each running
# train_prediction_model. No train_loop_config is given here; the Tuner below
# supplies it through param_space["train_loop_config"].
trainer = TorchTrainer(
    train_loop_per_worker=train_prediction_model,
    scaling_config=ScalingConfig(num_workers=num_actors, use_gpu=USE_GPU),
)

### Distributed Hyperparameter Tuning using Ray

Thanks to Ray's simple interface, we can extend our trainer into a Ray `Tuner`, which gives us efficient distributed hyperparameter optimization. In our case, we use `optuna` as the search algorithm.

In [None]:
import copy

# Per-trial `train_loop_config`: static settings plus the Ray Tune search space.
tuner_train_config = {}
##static
tuner_train_config["shuffle"] = cfg["train_data_loader"]["shuffle"]
tuner_train_config["num_workers"] = num_data_workers
tuner_train_config["dataset_key"] = cfg["train_data_loader"]["key"]

##tunable (quniform samples floats; the training loop casts them to int)
tuner_train_config["max_num_steps"] = tune.quniform(1000, 5000, 250)
tuner_train_config["lr"] = tune.loguniform(1e-4, 1e-2)
tuner_train_config["batch_size"] = tune.quniform(6, 24, 6)

# Put the rasterizer search space on a deep copy. Otherwise the module-level
# `cfg` would carry a tune.choice object, and it is logged later through
# wandb.init(config=cfg).
search_cfg = copy.deepcopy(cfg)
search_cfg["raster_params"]["map_type"] = tune.choice(["py_semantic", "py_satellite"])

tuner_train_config["cfg"] = search_cfg

In [None]:
from ray.tune.logger import LoggerCallback
from typing import Dict, List

In [None]:
from ray.tune.stopper import ExperimentPlateauStopper
from ray.tune.search.optuna import OptunaSearch

In [None]:
# Number of trials Ray Tune samples (passed as TuneConfig.num_samples below).
n_search_attempts = 25

In [None]:
# Optuna search algorithm; created without metric/mode, which are set on the
# TuneConfig below.
optuna_search = OptunaSearch()

In [None]:
# Wrap the trainer in a Tuner. Each trial gets a sampled
# `tuner_train_config` as its train_loop_config. Trials are minimised on
# "avg_loss", the whole experiment stops once "avg_loss" plateaus, and each
# trial is logged to W&B (with checkpoints).
tuner = Tuner(
        trainer,
        tune_config=tune.TuneConfig(
            metric="avg_loss", # NOTE(review): avg_loss is a cumulative mean over all steps of a trial; confirm it is preferred over per-step "loss"
            mode="min",
            search_alg=optuna_search,
            num_samples=n_search_attempts,
        ),
        param_space={
            "train_loop_config": tuner_train_config
        },
        run_config=RunConfig(
            stop=ExperimentPlateauStopper("avg_loss"),
            callbacks=[WandbLoggerCallback(project=f"{project_name}-trials", save_checkpoints=True),]))  #🪄🐝

### Aggregate and Report Metrics from All Trials

In [None]:
# Launch the tuning job and wait for all trials to finish.
analysis = tuner.fit()

In [None]:
import time

In [None]:
# NOTE(review): fixed 30 s pause, presumably to let the W&B logger callback
# finish syncing trial runs before the next W&B run starts — confirm.
time.sleep(30)

In [None]:
# Collect the results of every trial into a DataFrame.
analysis_df = analysis.get_dataframe()

In [None]:
# Display the trial results in the notebook.
analysis_df

In [None]:
# Run information for the W&B run that records the tuning/training results.
# wandb_entity = "l5-demo"
project_name = "l5-prediction"
run_name = "train-prediction-model"
run_type = "train"
run_description = """
Train prediction model
"""
tags = ["train", "prediction"]

In [None]:
#🪄🐝
# Start the "train" W&B run and record the l5kit config on it.
# NOTE(review): confirm `cfg` holds no Ray Tune search-space objects here,
# since it is logged verbatim as the run config.
run = wandb.init(
    # entity=wandb_entity,
    project=project_name,
    job_type=run_type,
    name=run_name,
    notes=run_description,
    tags=tags,
    config=cfg
)

In [None]:
#BUG: to force a connection on the lineage graph
# Re-declare the dataset artifact as an input so the training run is linked to
# the data in W&B lineage, even though the data was downloaded by the earlier run.
#🪄🐝
# artifact = run.use_artifact(f"{artifact_entity}/{artifact_project}/{artifact_name}:{artifact_alias}", type=artifact_type)
artifact = run.use_artifact(f"{artifact_project}/{artifact_name}:{artifact_alias}", type=artifact_type)

In [None]:
#🪄🐝
# Convert the trial-results DataFrame into a W&B Table for logging.
analysis_table = wandb.Table(dataframe=analysis_df)

In [None]:
#BUG: run gets lost after tune job due to change in cwd. Forced to make 2 runs
# Log the per-trial results table to the training run, then close the run.
if not analysis_table.data:
    # Mark the run as failed and close it, so it is not left dangling in W&B
    # when the cell raises.
    run.finish(exit_code=1)
    raise ValueError(
        "analysis table is empty: expected one row per Ray Tune trial from analysis.get_dataframe()"
    )
run.log({"analysis_table": analysis_table})
run.finish()