# Introduction to DriveNet

DriveNet is a convolutional neural network (CNN) that drives an agent along a specific route. Among other things, DriveNet can follow lanes and respect traffic lights.

In this notebook you're going to train and test your own DriveNet model using the Lyft L5 dataset and L5Kit.


In [None]:
import os
from tempfile import gettempdir

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
from tqdm import tqdm

from l5kit.configs import load_config_data
from l5kit.data import LocalDataManager, ChunkedDataset
from l5kit.dataset import EgoDataset
from l5kit.rasterization import build_rasterizer
from l5kit.geometry import transform_points
from l5kit.visualization import TARGET_POINTS_COLOR, draw_trajectory
from l5kit.drivenet.model import DriveNetModel
from l5kit.kinematic import AckermanPerturbation
from l5kit.random import GaussianRandomGenerator

## Prepare Data path and load cfg

By setting the `L5KIT_DATA_FOLDER` environment variable, we point the script to the folder where the data is stored.
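Roughly, we expect `LocalDataManager(None).require(key)` to read the root folder from this environment variable, join the relative `key` onto it and fail if nothing exists there. The sketch below mirrors that idea with a hypothetical helper, `resolve_data_key`; it is an illustration under that assumption, not L5Kit's implementation, which may differ in details such as error types.

```python
import os

ENV_VAR = "L5KIT_DATA_FOLDER"


def resolve_data_key(key: str) -> str:
    """Hypothetical helper: join a relative key onto the data root and check it exists."""
    root = os.environ.get(ENV_VAR)
    if root is None:
        raise ValueError(f"{ENV_VAR} is not set")
    path = os.path.join(root, key)
    if not os.path.exists(path):
        raise FileNotFoundError(path)
    return path
```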

Then we load our config file, which contains the relative paths to the data and other settings (rasteriser, training parameters, ...).

In [None]:
# set env variable for data
os.environ["L5KIT_DATA_FOLDER"] = "/tmp/l5kit_data"
dm = LocalDataManager(None)
# get config
cfg = load_config_data("./drivenet_config.yaml")
print(cfg)
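Since `cfg` is a nested dictionary loaded from YAML, a quick sanity check can catch a missing entry before it surfaces as a `KeyError` halfway through training. A minimal sketch with a hypothetical helper, `missing_keys`; the list holds the key paths that the cells in this notebook read (the rasteriser settings consumed by `build_rasterizer` are not checked).

```python
from typing import List, Tuple

# Key paths read by the cells of this notebook (rasteriser settings not included).
REQUIRED_KEYS: List[Tuple[str, ...]] = [
    ("train_data_loader", "key"),
    ("train_data_loader", "perturb_probability"),
    ("train_data_loader", "shuffle"),
    ("train_data_loader", "batch_size"),
    ("train_data_loader", "num_workers"),
    ("model_params", "future_num_frames"),
    ("train_params", "max_num_steps"),
]


def missing_keys(cfg: dict, paths=REQUIRED_KEYS) -> List[str]:
    """Return the dotted key paths that are absent from a nested config dict."""
    missing = []
    for path in paths:
        node = cfg
        for part in path:
            # stop descending as soon as a level is missing or not a dict
            if not isinstance(node, dict) or part not in node:
                missing.append(".".join(path))
                break
            node = node[part]
    return missing
```

In the notebook, `assert not missing_keys(cfg), missing_keys(cfg)` right after loading the config would report every missing path at once.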

# Data, Rasterisation and Perturbations

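Each training sample can be perturbed: with probability `perturb_probability` (read from the config), `AckermanPerturbation` displaces the ego using a random offset sampled by `GaussianRandomGenerator` (zero mean, standard deviations of 1.0 and π/6 for the two components) and relies on an Ackerman steering model to obtain a feasible trajectory. The aim is to expose the model to situations slightly off the recorded driving, so that it learns to recover from them. The cell below shows the same sample without and with perturbation.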
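To make the two offset components concrete, here is a minimal NumPy sketch, assuming (as the `std` values in the cell below suggest) that they are a lateral offset in metres and a yaw offset in radians. It applies the sampled offset as a rigid shift and rotation of the whole trajectory with probability `perturb_prob`. This is a deliberate simplification, not the Ackerman-based fitting that `AckermanPerturbation` performs, and the helpers `rigid_offset` and `maybe_perturb` are hypothetical.

```python
import numpy as np


def rigid_offset(positions, yaws, lateral_m, yaw_rad):
    """Shift a trajectory sideways and rotate it about its first point.

    Simplified, hypothetical stand-in: a rigid offset, NOT the Ackerman
    fitting done by L5Kit's AckermanPerturbation.
    positions: (N, 2) array of x, y in metres; yaws: (N,) array in radians.
    """
    origin = positions[0]
    # unit vector pointing to the left of the initial heading
    left = np.array([-np.sin(yaws[0]), np.cos(yaws[0])])
    c, s = np.cos(yaw_rad), np.sin(yaw_rad)
    rotation = np.array([[c, -s], [s, c]])
    # rotate every point about the first one, then translate sideways
    new_positions = (positions - origin) @ rotation.T + origin + lateral_m * left
    return new_positions, yaws + yaw_rad


def maybe_perturb(positions, yaws, rng, perturb_prob, mean, std):
    """With probability perturb_prob, apply a Gaussian-sampled (lateral, yaw) offset."""
    if rng.random() >= perturb_prob:
        return positions.copy(), yaws.copy()
    lateral_m, yaw_rad = rng.normal(loc=mean, scale=std)
    return rigid_offset(positions, yaws, lateral_m, yaw_rad)


# usage on a straight 5-point trajectory heading along +x
demo_rng = np.random.default_rng(0)
demo_positions = np.stack([np.arange(5.0), np.zeros(5)], axis=1)
demo_yaws = np.zeros(5)
demo_mean, demo_std = np.array([0.0, 0.0]), np.array([1.0, np.pi / 6])
unchanged = maybe_perturb(demo_positions, demo_yaws, demo_rng, 0.0, demo_mean, demo_std)
perturbed = maybe_perturb(demo_positions, demo_yaws, demo_rng, 1.0, demo_mean, demo_std)
```

A rigid offset moves the whole future trajectory away from the lane, which is exactly why a kinematic model is useful in practice: it lets the perturbed ego steer back towards the original path instead.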

In [None]:
perturb_prob = cfg["train_data_loader"]["perturb_probability"]

# create the rasteriser
rasterizer = build_rasterizer(cfg, dm)
# perturb the ego with a random offset drawn from a zero-mean Gaussian
# (std of 1.0 and pi / 6 for the two offset components)
perturbation = AckermanPerturbation(
    random_offset_generator=GaussianRandomGenerator(mean=np.array([0.0, 0.0]), std=np.array([1.0, np.pi / 6])),
    perturb_prob=perturb_prob,
)

# ===== INIT DATASET
train_zarr = ChunkedDataset(dm.require(cfg["train_data_loader"]["key"])).open()
train_dataset = EgoDataset(cfg, train_zarr, rasterizer, perturbation)

# show the same example without (0) and with (1) perturbation
for perturbation_value in [0, 1]:
    perturbation.perturb_prob = perturbation_value

    data_ego = train_dataset[0]
    im_ego = rasterizer.to_rgb(data_ego["image"].transpose(1, 2, 0))
    target_positions = transform_points(data_ego["target_positions"], data_ego["raster_from_agent"])
    draw_trajectory(im_ego, target_positions, TARGET_POINTS_COLOR)
    plt.imshow(im_ego[::-1])
    plt.title(f"perturb_prob = {perturbation_value}")
    plt.show()

# before leaving, restore the perturbation probability from the config
perturbation.perturb_prob = perturb_prob


## Model

L5Kit provides a model file for DriveNet. The backbone is a ResNet (either ResNet-18 or ResNet-50) pretrained on ImageNet.
For a full description of the model please check `l5kit/drivenet/model.py`.

#### Inputs

Our inputs are not plain RGB images, so we need to replace the first convolutional layer with one that accepts a different number of input channels. This is already handled inside the `DriveNetModel` class, but the number of channels must be provided by the rasteriser (each rasteriser can produce a different number of channels).

#### Outputs

The model's outputs differ between training and evaluation. During training, the model computes and returns the loss value; during evaluation, it returns the full predictions.
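A minimal sketch of this train/eval split, using a hypothetical `ToyDriveHead` with a tiny convolutional network in place of the ResNet. It assumes (check `l5kit/drivenet/model.py` for the real behaviour) that the outputs are laid out as (X, Y, yaw) per future step, that the per-element `MSELoss(reduction="none")` is weighted by the target availabilities and by one `weights_scaling` factor per component, and that evaluation returns positions and yaws separately. The batch fields `target_yaws` and `target_availabilities` are assumed to be present, as in L5Kit's `EgoDataset` output.

```python
import torch
from torch import nn


class ToyDriveHead(nn.Module):
    """Hypothetical stand-in for DriveNetModel's forward logic (not the real class)."""

    def __init__(self, num_input_channels, future_num_frames, weights_scaling=(1.0, 1.0, 1.0)):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(num_input_channels, 8, kernel_size=3, padding=1),
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Linear(8, 3 * future_num_frames),  # X, Y, yaw per future step
        )
        self.register_buffer("weights_scaling", torch.tensor(weights_scaling))
        self.criterion = nn.MSELoss(reduction="none")  # keep per-element losses

    def forward(self, data):
        outputs = self.net(data["image"])
        batch_size = outputs.shape[0]
        if self.training:
            # [batch, T, 3] -> [batch, T * 3], same layout as outputs
            targets = torch.cat((data["target_positions"], data["target_yaws"]), dim=2).view(batch_size, -1)
            # zero out unavailable steps, then scale X / Y / yaw separately
            weights = (data["target_availabilities"].unsqueeze(-1) * self.weights_scaling).view(batch_size, -1)
            loss = torch.mean(self.criterion(outputs, targets) * weights)
            return {"loss": loss}
        predicted = outputs.view(batch_size, -1, 3)
        return {"positions": predicted[:, :, :2], "yaws": predicted[:, :, 2:3]}
```

Keeping `reduction="none"` is what makes the per-step weighting possible: a reduced loss would already be averaged before unavailable steps could be masked out.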

In [None]:
model = DriveNetModel(
    model_arch="resnet50",
    num_input_channels=rasterizer.num_channels(),
    num_targets=3 * cfg["model_params"]["future_num_frames"],  # (X, Y, yaw) * number of future states
    weights_scaling=[1.0, 1.0, 1.0],
    criterion=nn.MSELoss(reduction="none"),
)
print(model)

# Prepare for training

In [None]:
train_cfg = cfg["train_data_loader"]
train_dataloader = DataLoader(
    train_dataset,
    shuffle=train_cfg["shuffle"],
    batch_size=train_cfg["batch_size"],
    num_workers=train_cfg["num_workers"],
)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)

print(train_dataset)
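`EgoDataset` returns each sample as a dictionary (mostly of NumPy arrays), and the DataLoader's default collate function stacks those into a dictionary of batched tensors, which is what the model receives. A minimal sketch with a hypothetical toy dataset; the fields and shapes are made up for illustration.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class ToyEgoDataset(Dataset):
    """Hypothetical dataset that, like EgoDataset, returns one dict per sample."""

    def __len__(self):
        return 10

    def __getitem__(self, idx):
        return {
            "image": np.zeros((5, 8, 8), dtype=np.float32),  # channels, height, width
            "target_positions": np.zeros((4, 2), dtype=np.float32),  # future steps, (x, y)
            "target_availabilities": np.ones(4, dtype=np.float32),
        }


toy_loader = DataLoader(ToyEgoDataset(), batch_size=4, shuffle=True, num_workers=0)
toy_batch = next(iter(toy_loader))
for key, value in toy_batch.items():
    # each value is now a torch.Tensor with a leading batch dimension of 4
    print(key, type(value).__name__, tuple(value.shape))
```

Because a batch is a dictionary, moving it to the model's device means moving each value, for example with `{k: v.to(device) for k, v in batch.items()}`.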

# Training

Note: if you're on macOS and using the `py_satellite` rasteriser, you may need to disable OpenCV multithreading by calling `cv2.setNumThreads(0)` before running the following cell. This seems to affect only Jupyter notebooks and is caused by the `cv2.warpAffine` function.

In [None]:
# ==== TRAIN LOOP
tr_it = iter(train_dataloader)
progress_bar = tqdm(range(cfg["train_params"]["max_num_steps"]))
losses_train = []
model.train()
torch.set_grad_enabled(True)

for _ in progress_bar:
    try:
        data = next(tr_it)
    except StopIteration:
        # restart the dataloader once an epoch is exhausted
        tr_it = iter(train_dataloader)
        data = next(tr_it)
    # move the batch to the same device as the model
    data = {k: v.to(device) for k, v in data.items()}
    result = model(data)
    loss = result["loss"]
    # Backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    losses_train.append(loss.item())
    progress_bar.set_description(f"loss: {losses_train[-1]:.4f} loss(avg): {np.mean(losses_train):.4f}")
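The `try`/`except StopIteration` block restarts the dataloader whenever an epoch ends, so the loop runs for exactly `max_num_steps` steps regardless of the dataset size. The same idea as a small generator, using a hypothetical helper `infinite_batches` that starts a fresh pass over the loader each time it is exhausted:

```python
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


def infinite_batches(loader: Iterable[T]) -> Iterator[T]:
    """Yield batches forever, re-iterating the loader after each full pass."""
    while True:
        empty = True
        for batch in loader:
            empty = False
            yield batch
        if empty:
            # guard against spinning forever on an empty loader
            raise ValueError("loader produced no batches")
```

With a shuffling DataLoader, each new pass is reshuffled. `itertools.cycle` would not be a drop-in replacement here: it saves the elements of the first pass and replays them in the same order, which also keeps every batch in memory.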

### Plot Loss Curve
We can plot the training loss against the iteration (batch) index.

In [None]:
plt.plot(np.arange(len(losses_train)), losses_train, label="train loss")
plt.xlabel("iteration")
plt.ylabel("loss")
plt.legend()
plt.show()
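Batch-wise losses are noisy, so a moving average can make the trend easier to read. A minimal sketch with a hypothetical `moving_average` helper:

```python
import numpy as np


def moving_average(values, window: int) -> np.ndarray:
    """Mean over a sliding window; returns len(values) - window + 1 points."""
    values = np.asarray(values, dtype=float)
    if window < 1 or window > len(values):
        raise ValueError("window must be between 1 and len(values)")
    kernel = np.ones(window) / window
    # "valid" keeps only positions where the window fits entirely
    return np.convolve(values, kernel, mode="valid")
```

To overlay it on the plot above, something like `plt.plot(np.arange(window - 1, len(losses_train)), moving_average(losses_train, window))` aligns each averaged value with the last step of its window; the window size is your choice.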

# Store the Model

Let's store the model as TorchScript. This format lets us reload the model and its weights later without needing the class definition.
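A minimal round trip with a toy module (`TinyNet` and the file name are hypothetical): `torch.jit.script` compiles the module, `save` writes code and weights to a single file, and `torch.jit.load` restores it without the Python class. Applying the same calls to the trained model, e.g. `torch.jit.script(model.cpu())`, assumes `DriveNetModel` is compatible with TorchScript, which this sketch does not verify.

```python
import os
from tempfile import gettempdir

import torch
from torch import nn


class TinyNet(nn.Module):
    """Hypothetical toy model standing in for the trained network."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)


toy_model = TinyNet().eval()
scripted = torch.jit.script(toy_model.cpu())
toy_path = os.path.join(gettempdir(), "tiny_model.pt")  # hypothetical file name
scripted.save(toy_path)

# later, possibly in another process, without access to the TinyNet class:
reloaded = torch.jit.load(toy_path)
```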

# What's Next

Next, the trained model can be evaluated in open loop and closed loop.