In [None]:
import datetime
import os
import pickle
import sys

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.datasets as dataset
from scipy.special import expit
from torch import nn, Tensor
from torch.optim.lr_scheduler import StepLR
from tqdm import tqdm
%matplotlib inline

from heatmap_model.interaction_model import CTnet, CTnet_causal
from heatmap_model.utils import *
from heatmap_model.uncertainty_utils import *
from heatmap_model.inference import *
from heatmap_model.train import *
from heatmap_model.interaction_dataset import *
from heatmap_model.losses import *
from vis_utils.visualization import *
from config import *
# Keep these after the star imports (original order) so name shadowing is unchanged.
from scipy.interpolate import make_interp_spline
from absl import logging
# Private absl flag; presumably silences absl's "logging before flag parsing" stderr
# notice (absl is not initialised through app.run here) — confirm for your absl version.
logging._warn_preinit_stderr = 0

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {device} device')

In [None]:
# change hyper-parameters here
para_train = paralist.copy()
para_train['resolution'] = 0.5
para_train['encoder_attention_size'] = 128
para_train['epoches'] = 64
para_train['test'] = False 
para_train['batch_size'] = 16
para_train['step'] = 1

In [None]:
para_test = para_train.copy()
para_test['test'] = True
para_test['ymax'] = 85
para_test['resolution'] = 0.5

In [None]:
dz = np.load('./results/kld.npz', allow_pickle=True)
Dtest = dz['Dtest']

In [None]:
dz = np.load('./results/Utest3.npz', allow_pickle=True)
Etest = dz['E']
# dz = np.load('./results/Uval3.npz', allow_pickle=True)
# Eval = dz['E'][:,-1]

In [None]:
dz = np.load('./results/Htest3.npz', allow_pickle=True)
H = dz['Heatmap']
Atest = seq_area(H,0.5,0.2)
# dz = np.load('./results/Hval3.npz', allow_pickle=True)
# H = dz['Heatmap']
# Aval = seq_area(H,0.5,0.2)

In [None]:
# set test=True during inference, drivale is optional
model = CTnet_causal(para_train).to(device)
trainset = InteractionDataset(['train1', 'train2','train3','train4'], 'train', para_train)
validationset = InteractionDataset(['val'], 'val', para_train)
validation_loader = DataLoader(validationset, batch_size=para_train['batch_size'], shuffle=False)
BATCH_SIZE = para_train['batch_size']
EPOCH_NUMBER = para_train['epoches']
loss = OverAllLoss_reg(para_train).to(device)

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler_heatmap = StepLR(optimizer, step_size=1, gamma=0.975)
train_model(EPOCH_NUMBER, BATCH_SIZE, trainset, model, optimizer, validation_loader, loss, scheduler_heatmap)

In [None]:
scenario, V, polygons = read_polygons()

In [None]:
model = CTnet(para_test).to(device)
model.encoder.load_state_dict(torch.load('./pretrained/encoder.pt'))
model.decoder.load_state_dict(torch.load('./pretrained/decoder.pt'))
model.eval()

In [None]:
testset = InteractionDataset(['val'], 'val', para_test)

In [None]:
POLY = Inference_Polygon_sup(model, para_test, 'valall', nmax=107848, T=30)

In [None]:
selected, _ = selected_trainset()
print(len(selected))
POLY = Inference_Polygon_train(model, para_test, selected)

In [None]:
len(POLY)

In [None]:
with open('poly_results/polygon_train.pkl', 'wb') as f:
    pickle.dump(POLY, f)

In [None]:
# Case to visualise (alternative choice: np.argsort(Dtest)[-3]).
ind = 2356
print(ind, scenario[ind], V[ind])

# One heatmap per sampled time 0.1, 0.2, ..., 3.0, each normalised to a peak of 1.
Ht = []
with torch.no_grad():  # inference only: skip building autograd graphs
    for i in np.arange(0.1, 3.1, 0.1):
        traj, maps, lanefeatures, adj, Af, c_mask, timestamp, gtxy = testset.test_sampling(ind, i)
        # Optional ablations: zero out parts of the inputs before the forward pass.
        # traj[:,1:] = 0
        # adj[:,56:,56:] = 0
        # c_mask[:,56:] = 0
        heatmap = model(traj, maps, lanefeatures, adj, Af, c_mask, timestamp, gtxy)
        hr = heatmap.detach().to('cpu').numpy()
        peak = np.amax(hr)
        if peak != 0:  # an all-zero heatmap would otherwise become all NaN (0/0)
            hr = hr / peak
        Ht.append(hr)
Ht = np.array(Ht)
Hsum = np.sum(Ht, 0)

In [None]:
hist = testset.T[:, 0, :-1, 2:4]
Y = trajectory_generation(Ht, para_test)

In [None]:
for i in range(len(polygons)):
    if len(polygons[i][-1]) > 1:
        print(i)
#Y[:,1] = Y[:,1]*1.1

In [None]:
Y = ModalSamplingm2(Ht[-1], 0.4, para_test, r=2., k=6)

In [None]:
Ys = []
Y  = []

In [None]:
fig, (ax1, ax2, ax3) = plt.subplots(1,3,figsize=(7.5,7.5))

In [None]:
fig, ax1 = Visualize_index_double(ind, Ht[-1]**1.5, [], Y[-1:], 'Predictor', [-25,20], [-25,95], para_test, fig, ax1, mode='test')

In [None]:
fig, ax2 = Visualize_index_double(ind, Hsum, [], [], 'Regularizor', [-25,20], [-25,95], para_test, fig, ax2, mode='valall')

In [None]:
fig, ax3 = Visualize_index_double(ind, Ht[-1],[], Y[-1:], 'Drive alone', [-25,20], [-25,95], para_test, fig, ax3, mode='test')

In [None]:
fig, axes = Visualize_index(ind, Ht[-1], [], [], sum(polygons[ind][::5], []), '', [-25, 25], [-12, 35], para=para_test, mode='valall')

In [None]:
fig.savefig('imgs/uqnet_ap1.jpg', dpi=600)

In [None]:
# Overlay one polygon (case 0, last entry, first polygon) before and after the fixed map
# x' = (y + 46) / 2, y' = (x + 24) / 2. The constants look like a world -> heatmap-grid
# transform (offsets 46/24, scale 1/0.5) — TODO confirm.
polygon = np.array(polygons[0][-1][0])
px = (polygon[:, 1] + 46) / 2
py = (polygon[:, 0] + 24) / 2
p = Polygon(np.column_stack([px, py]))
p2 = Polygon(polygon)
plt.plot(*p.exterior.xy)
plt.plot(*p2.exterior.xy)

In [None]:
Yp = final_prediction(model, testset, para_test, mode='test')

In [None]:
polygons[ind][-1]

In [None]:
np.savez_compressed('./results/fpcontrast1', FinalPoint=Yp)

In [None]:
FDE, MR = ComputeError(Yp,testset.Y[:,-1], r=1.5, sh=2)

In [None]:
Ht[Ht<0.1] = -3

In [None]:
datafiles = os.listdir('./rawdata/test/')
datafiles.sort()

In [None]:
Ht.shape

In [None]:
with open('./interaction_merge/test_index.pickle', 'rb') as f:
    Dnew = pickle.load(f)

In [None]:
S = np.zeros(len(Dnew[0]))
for i in tqdm(range(len(Dnew[0]))):
    file_id = int(Dnew[0][i][:-6])
    if file_id in [1,2,3,6,7,8,9]:
        S[i] = 1
    if file_id in [5,11,12,13,14]:
        S[i] = 2
    if file_id in [4,10, 15, 16, 17]:
        S[i] = 3

In [None]:
S = np.zeros(len(Dnew[0]))
for i in tqdm(range(len(Dnew[0]))):
    file_id = int(Dnew[0][i][:-6])
    if file_id in [1,2,4]:
        S[i] = 1
    if file_id in [6,7,8,9]:
        S[i] = 2
    if file_id in [3, 5, 10, 11, 12]:
        S[i] = 3

In [None]:
with open('Uresult/test_scenario.pkl', 'wb') as f:
    pickle.dump(S.astype(int).tolist(), f)

In [None]:
len(np.where(S==3)[0])

In [None]:
plt.plot(Aval[ind].T-3)
plt.show()

In [None]:
plt.hist(Atest[:,-1], bins=30)
plt.show()

In [None]:
#plt.rcParams["figure.figsize"] = (6,2)
fig=plt.figure(figsize=(6,2))
ax=fig.add_subplot(111)
ax.plot(np.arange(0.1, 3.1, 0.1), Atest[1791]-1.8, linewidth=3)
#plt.scatter(np.arange(0.1, 3.1, 0.1), np.mean(E, 0))
#plt.hlines(1.62, 0, 3.1,linestyles='dashed', label='added white noise')
plt.xlim(0.1,3)
plt.ylim(0, 60)
plt.xlabel('t(s)', fontsize=12)
plt.ylabel('A ($m^2$)', fontsize=12)
plt.title('A(t) for test case-1791', fontsize=14)
ax.xaxis.set_label_coords(0.95,0.2)
ax.yaxis.set_label_coords(0.08,0.7)
plt.grid()
#plt.legend()
plt.savefig('./imgs/Atest1791.pdf')
plt.show()

In [None]:
Y = testset.Y[:,-1]

In [None]:
dz = np.load('./results/Hval3.npz', allow_pickle=True)
H = dz['Heatmap'][-1]

In [None]:
nll = np.zeros(len(H))
for i in range(len(H)):
    print(i, end='\r')
    #nll[i] = NLLEstimate(H[i].toarray(), Y[i], para_test)
    nll[i] = NLLEstimate_test(H[i].toarray(), para_test)

In [None]:
plt.plot(Y[:,0], Y[:,1])
plt.show()

In [None]:
S = bezier_curve(H[:,6459], n=39)

In [None]:
plt.plot(S[:,0], S[:,1])
plt.plot(H[:,6459,0], H[:,6459,1])
plt.show()

In [None]:
H = rawtrajectory(model, para_test, 'val', batchsize=4, T=30)

In [None]:
np.savez_compressed('./results/rawtrajval2', T=H)

In [None]:
H.shape

In [None]:
plt.plot(H[:,6459,0], H[:,6459,1])
plt.show()

In [None]:
V = np.sqrt(testset.T[:,0,-1,4]**2+testset.T[:,0,-1,5]**2)

In [None]:
testset = InteractionDataset(['val'], 'val', para_test)

In [None]:
V = np.sqrt(testset.T[:,0,-1,4]**2+testset.T[:,0,-1,5]**2)

In [None]:
np.savez_compressed('./results/valspeed', V=V)

In [None]:
with open('./interaction_merge/val_index.pickle', 'rb') as f:
    Dnew = pickle.load(f)

In [None]:
len(Dnew[0])

In [None]:
Dnew[0]

In [None]:
with open('./interaction_merge/val_all_index.pickle', 'rb') as f:
    Dnew = pickle.load(f)

In [None]:
from tqdm import tqdm
import gc

# Flatten the index into one (case head, car) pair per row.
F = []
Nb = []
# Suspend the cyclic GC once for the whole build (toggling it per append defeats the
# purpose); `finally` re-enables it even if the loop raises.
gc.disable()
try:
    for i in tqdm(range(len(Dnew[0]))):
        head = Dnew[0][i]
        for car in Dnew[1][i]:
            F.append(head)
            Nb.append(car)
finally:
    gc.enable()

D = (np.array(F), np.array(Nb).astype('str'))


In [None]:
import pickle
with open('./interaction_merge/valall_index.pickle', 'wb') as handle:
    pickle.dump(D, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
D[0]