In [None]:
import datetime
import os
import sys

import matplotlib.pyplot as plt
%matplotlib inline 
import numpy as np
import torch
import torchvision.datasets as dataset
from scipy.special import expit
from torch import nn, Tensor
from torch.optim.lr_scheduler import StepLR

from heatmap_model.interaction_model import UQnet,TrajModel
from heatmap_model.utils import *
from heatmap_model.train import *
from heatmap_model.interaction_dataset import *
from heatmap_model.losses import *
from config import *

from absl import logging
# Presumably silences absl's one-off "logging before flag parsing goes to
# stderr" notice, which appears when absl logging is used outside absl.app.run().
# NOTE(review): this is a private attribute; it may be a no-op on newer absl
# versions — confirm against the installed absl release.
logging._warn_preinit_stderr = 0

# Pick the compute device once: use the GPU when PyTorch can see one,
# otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'
print(f'Using {device} device')

### This Jupyter notebook shows how to train the models

- We first give the default hyperparameter settings used in our paper.
- Change the batch size to fit your device's memory; it hardly influences the convergence speed.
- You can also convert this notebook to a <code>.py</code> script and add command-line argument parsing.

In [None]:
# Change hyper-parameters here. These assignments override entries of
# `paralist` (presumably defined in config.py and brought in by
# `from config import *` — confirm there for the full list of keys).
paralist['resolution'] = 1.
paralist['encoder_attention_size'] = 128
paralist['use_sem'] = False
# The key must be 'epochs': that is the key the model/data cell reads into
# EPOCH_NUMBER. It was previously misspelled 'epoches', so this override was
# never picked up there.
paralist['epochs'] = 8
paralist['mode'] = 'lanescore'
paralist['prob_mode'] = 'ce'
paralist['inference'] = False
# Lower this if your device runs out of memory.
paralist['batch_size'] = 8

### Now we define the model, create the training dataset and validation dataloader, and create the loss function
- The <code>test</code> argument of <code>UQnet</code> is always <code>False</code> during training.
- <code>OverAllLoss</code> is the loss function used in the paper.

In [None]:
# Model, datasets and loss used for training.
# Pass test=True only at inference time; the `drivable` flag is optional.
model = UQnet(paralist, test=False, drivable=False).to(device)

train_splits = ['train1', 'train2', 'train3', 'train4']
trainset = InteractionDataset(train_splits, 'train', paralist, paralist['mode'], filters=False)
validationset = InteractionDataset(['val'], 'val', paralist, paralist['mode'], filters=False)

# The training set is handed to train_model as a dataset together with
# BATCH_SIZE; only the validation data is wrapped in a loader here.
BATCH_SIZE = paralist['batch_size']
validation_loader = DataLoader(validationset, batch_size=BATCH_SIZE, shuffle=False)
EPOCH_NUMBER = paralist['epochs']
loss = OverAllLoss(paralist).to(device)

Now we train 7 randomly initialized models and save their parameters in the <code>interactionDE</code> folder.

In [None]:
# Train N_ENSEMBLE independently initialised UQnet models to compose an
# ensemble, saving each member's encoder and decoder weights under ENSEMBLE_DIR.
N_ENSEMBLE = 7
ENSEMBLE_DIR = './interactionDE'

# torch.save does not create missing directories. Without this, the first
# (expensive) training run would complete and then crash while saving.
os.makedirs(ENSEMBLE_DIR, exist_ok=True)

for order in range(N_ENSEMBLE):
    # A distinct, fixed seed per member: initialisations still differ across the
    # ensemble, but each member is reproducible on a Restart & Run All.
    torch.manual_seed(order)
    np.random.seed(order)

    model = UQnet(paralist).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # Multiplies the learning rate by 0.975 on every scheduler.step() call
    # (presumably once per epoch inside train_model — confirm there).
    scheduler_heatmap = StepLR(optimizer, step_size=1, gamma=0.975)
    train_model(EPOCH_NUMBER, BATCH_SIZE, trainset, model, optimizer, validation_loader, loss,
                scheduler_heatmap, paralist, mode=paralist['mode'])

    torch.save(model.encoder.state_dict(), os.path.join(ENSEMBLE_DIR, f'encoder{order}.pt'))
    torch.save(model.decoder.state_dict(), os.path.join(ENSEMBLE_DIR, f'decoder{order}.pt'))