# Initializing Colab

Installing pytorch_xla

In [None]:
# Download the environment setup script from the pytorch/xla repository and
# save it locally as `pytorch-xla-env-setup.py`.
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
# Run the script, requesting the `nightly` version and the extra apt packages
# `libomp5` and `libopenblas-dev`.
!python pytorch-xla-env-setup.py --version nightly --apt-packages libomp5 libopenblas-dev

Downloading data

In [None]:
# Fetch the three archives of the Lyft prediction dataset (v1.1). The URLs are
# spelled out instead of using a `{a,b,c}` list: that form only works when the
# `!` shell is bash, and braces are also IPython's own `{expr}` interpolation
# syntax.
!wget https://lyft-l5-datasets-public.s3-us-west-2.amazonaws.com/prediction/v1.1/sample.tar https://lyft-l5-datasets-public.s3-us-west-2.amazonaws.com/prediction/v1.1/aerial_map.tar https://lyft-l5-datasets-public.s3-us-west-2.amazonaws.com/prediction/v1.1/semantic_map.tar
# Target layout: lyft/scenes/, lyft/aerial_map/, lyft/semantic_map/, lyft/meta.json
!mkdir -p lyft/scenes/ lyft/aerial_map/ lyft/semantic_map/
# Extract each member straight into its final location with `-C` (no separate
# `mv`), so re-running the cell overwrites in place instead of failing because
# the destination already exists. An archive is removed only after it was
# extracted successfully.
!tar -xf sample.tar -C lyft/scenes/ sample.zarr && rm sample.tar
!tar -xf aerial_map.tar -C lyft/ && rm aerial_map.tar
# `semantic_map.tar` supplies two members that go to different directories.
!tar -xf semantic_map.tar -C lyft/ meta.json
!tar -xf semantic_map.tar -C lyft/semantic_map/ semantic_map.pb && rm semantic_map.tar

Cloning repository and installing dependencies

In [None]:
# Fetch the project sources.
!git clone https://github.com/VahidZee/ReasonAwareRasterizedTrajectoryPrediction.git
# Work from inside the checkout from here on, so that `requirements.txt` and the
# relative paths used by later cells (e.g. `./config.yaml`) resolve.
# `%cd` / `%pip` are written as explicit magics: the bare `cd` / `pip` forms only
# work while IPython's automagic is enabled, and `%pip` installs into the
# environment of the running kernel.
%cd ReasonAwareRasterizedTrajectoryPrediction
%pip install -r requirements.txt

# Importing dependencies

In [None]:
import pytorch_lightning as pl
from l5kit.configs import load_config_data
from raster import BaseResnet, BaseTrainerModule, LyftDataModule

## Initializing the config, data module, model, and training module

In [None]:
# Experiment configuration, read from the repository root (the current working
# directory after the `cd` in the setup section).
config = load_config_data('./config.yaml')

# The download cell unpacks the dataset into `lyft/` *before* changing into the
# repository directory, so from here it lives one level up. The previous value
# was an absolute path on a developer machine that does not exist on Colab.
# NOTE(review): assumes LyftDataModule's first argument is the dataset root
# folder — confirm against its definition in `raster`.
data_root = '../lyft'
datamodule = LyftDataModule(data_root, config)

model_params = config["model_params"]
model = BaseResnet(
    # (history frames + the current frame) * 2 + 3.
    # NOTE(review): presumably two raster layers per frame plus three map
    # colour channels — confirm against the rasterizer settings in the config.
    in_channels=(model_params["history_num_frames"] + 1) * 2 + 3,
    # Two outputs per predicted future frame (presumably x and y — confirm).
    out_dim=2 * model_params["future_num_frames"],
    model_type=model_params["model_architecture"],
    pretrained=True,
)

# Wrap the network in the training module and attach the data module to it;
# the later `trainer.fit(training_procedure)` call passes only the module.
training_procedure = BaseTrainerModule(model=model)
training_procedure.datamodule = datamodule

In [None]:
# Build the Lightning trainer with `tpus=8`.
# NOTE(review): presumably requests 8 TPU cores; the name of this keyword has
# differed between pytorch_lightning releases — confirm it matches the version
# pinned in requirements.txt.
trainer = pl.Trainer(tpus=8)

In [None]:
# Load the TensorBoard notebook extension and open it on `lightning_logs`.
# NOTE(review): presumably the directory the trainer writes its logs to —
# confirm, since no logger or log directory is configured explicitly here.
%load_ext tensorboard
%tensorboard --logdir=lightning_logs

In [None]:
# Start training. Only the training module is passed here; the data module was
# attached to it as `training_procedure.datamodule` in the initialization cell.
trainer.fit(training_procedure)