In [None]:
import datetime
import os
import pickle
import random

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from matplotlib.pyplot import figure
from PIL import Image
from torch.utils.data import DataLoader, TensorDataset

from model_definitions import GRU_submodel, Classnet
from utils import get_ade, get_fde

SEED = 0
torch.manual_seed(SEED)
# torch.manual_seed does not touch NumPy's or Python's global RNGs; seed them as well.
np.random.seed(SEED)
random.seed(SEED)


def seed_worker(worker_id):
    """DataLoader ``worker_init_fn``: seed NumPy and ``random`` inside a worker process.

    The 32-bit seed is derived from ``torch.initial_seed()``, which PyTorch sets
    per worker, so NumPy/``random`` streams are reproducible across runs.

    Args:
        worker_id: required by the ``worker_init_fn`` callback signature; unused.
    """
    worker_seed = torch.initial_seed() % 2**32
    # NumPy is imported as `np` in this notebook; the original `numpy.` was a NameError.
    np.random.seed(worker_seed)
    random.seed(worker_seed)
# Deterministic cuBLAS GEMMs need a fixed workspace size; put the variable in place
# before enabling deterministic mode (and before any CUDA work happens).
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
torch.use_deterministic_algorithms(True)
# NOTE(review): anomaly detection is a debugging aid for backward passes and slows
# execution; this notebook only runs inference, so consider turning it off.
torch.autograd.set_detect_anomaly(True)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(datetime.datetime.now())
# Map image used as the plot background. Image.open is lazy and holds the file open,
# so load a copy eagerly and let the context manager close the file.
with Image.open("./localization_grid.pgm") as grid_img:
    im = grid_img.copy()

In [None]:
class fullmodel(nn.Module):
    """Classifier-routed ensemble of ``numrc`` pretrained GRU submodels.

    A pretrained ``Classnet`` scores ``numrc`` classes for each input sample; the
    highest-scoring class picks which ``GRU_submodel`` produces that sample's output.
    All subnetworks are loaded from checkpoints, moved to ``device`` and put in eval mode.

    Args:
        numrc: number of classes, and therefore of submodels / submodel checkpoints.
        k: number of top classes requested from ``torch.topk`` (must be <= numrc);
            only the best (top-1) class is used to route a sample.
        device: device the subnetworks are moved to; checkpoints are mapped onto it.
        embed_dim, dropout_rate, pred_len: forwarded to ``GRU_submodel``.
        hidden_dim: forwarded to both ``Classnet`` and ``GRU_submodel``.
        ckpt_dir: directory containing the checkpoint files.
        ckpt_prefix: file-name prefix shared by all checkpoint files.
    """

    def __init__(self, numrc, k, device, embed_dim, hidden_dim, dropout_rate, pred_len,
                 ckpt_dir="./subnets", ckpt_prefix="ATC_DDPTP_DC_ID"):
        super().__init__()
        self.device = device
        self.k = k
        # NOTE(review): the leading 2 is presumably the per-step input feature size (x, y) — confirm.
        self.classifier = Classnet(2, hidden_dim, RC_class_size=numrc)
        # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
        self.classifier.load_state_dict(torch.load(
            os.path.join(ckpt_dir, ckpt_prefix + "_classnet.pth"), map_location=device))
        self.classifier.to(device)
        self.classifier.eval()
        self.submodel = nn.ModuleList()
        for i in range(numrc):
            temp = GRU_submodel(embed_dim, hidden_dim, dropout_rate, pred_len)
            temp.load_state_dict(torch.load(
                os.path.join(ckpt_dir, ckpt_prefix + "_submodel_" + str(i) + ".pth"),
                map_location=device))
            temp.to(device)
            temp.eval()
            self.submodel.append(temp)

    def forward(self, x):
        """Predict with the submodel selected by the classifier, per sample.

        Args:
            x: input batch, batch dimension first (presumably (batch, obs_len, 2) — confirm).

        Returns:
            Tensor whose row ``b`` is the output of the submodel chosen for sample ``b``,
            stacked along dim 0 (same layout as a submodel's output on the batch).
        """
        # NOTE(review): assumes the classifier returns (batch, numrc) scores.
        classprob = self.classifier(x)
        # topk (rather than argmax) keeps the original tie-breaking and the k <= numrc check.
        _, top_idx = torch.topk(classprob, self.k)
        best = top_idx[:, 0]
        pieces, piece_rows = [], []
        # Run each submodel once on the sub-batch of samples routed to it.
        for cls, sub in enumerate(self.submodel):
            rows = (best == cls).nonzero(as_tuple=True)[0]
            if rows.numel() == 0:
                continue
            pieces.append(sub(x[rows]))
            piece_rows.append(rows)
        out = torch.cat(pieces, dim=0)
        # Undo the grouping so output row b lines up with input row b.
        return out[torch.argsort(torch.cat(piece_rows))]

In [None]:
# Hyper-parameters handed to `fullmodel`; they have to agree with the saved checkpoints.
numrc, k = 4, 1                  # classifier classes (one GRU submodel each); top-k passed to topk
embed_dim = hidden_dim = 128     # GRU embedding / hidden sizes (hidden_dim also feeds Classnet)
dropout_rate = 0.5               # dropout is disabled while the submodels are in eval() mode
pred_len = 20                    # presumably the number of predicted time steps — confirm

In [None]:
# Build the routed ensemble from the saved checkpoints and switch it to inference mode.
net = fullmodel(
    numrc=numrc,
    k=k,
    device=device,
    embed_dim=embed_dim,
    hidden_dim=hidden_dim,
    dropout_rate=dropout_rate,
    pred_len=pred_len,
)
net = net.to(device)  # Module.to returns the module itself
net.eval()

In [None]:
# Load the pickled test split.
# NOTE(review): pickle.load can execute arbitrary code — only open trusted files.
with open('datas/ATC_test_x.pickle', 'rb') as f:
    x = pickle.load(f)
with open('datas/ATC_test_y.pickle', 'rb') as f:
    y = pickle.load(f)
# Scale raw coordinates down by 1000 (presumably mm -> m; confirm against the dataset).
tensor_x = torch.Tensor(x)/100/10
tensor_y = torch.Tensor(y)/100/10
vdata = TensorDataset(tensor_x,tensor_y)
allset = DataLoader(vdata,batch_size=1,shuffle=False,num_workers =4)
accuADE = []
accuFDE = []

# Inference only: no_grad skips building autograd graphs, saving memory and time.
with torch.no_grad():
    for inputs, labels in allset:
        inputs, labels = inputs.to(device), labels.to(device)
        output1 = net(inputs)
        # Convert once and reuse for both metrics.
        pred = output1.cpu().numpy()
        truth = labels.cpu().numpy()
        accuADE.append(get_ade(pred, truth))
        accuFDE.append(get_fde(pred, truth))
print('Final ADE: %.3f' %(np.mean(accuADE)))
print('Final FDE: %.3f' %(np.mean(accuFDE)))

In [None]:
# Plot one test trajectory (observation, ground truth, prediction) over the map image.
indx = 478  # index into vdata of the sample to visualise

inputs, labels = vdata[indx][0].to(device), vdata[indx][1].to(device)
with torch.no_grad():
    # unsqueeze adds the batch dimension the model expects.
    output1 = net(torch.unsqueeze(inputs, 0))
inp = inputs.cpu().numpy()
testest = output1.cpu().numpy()    # prediction, batch of one
testest2 = labels.cpu().numpy()    # ground-truth future

plt.figure(figsize=(16, 9))
plt.scatter(testest[0, :, 0], testest[0, :, 1])
plt.scatter(testest2[:, 0], testest2[:, 1])
plt.scatter(inp[:, 0], inp[:, 1])
plt.legend(['prediction',"ground_truth","observation"])
plt.imshow(im,
           aspect='equal',
           origin="upper",
           extent=[-60, 80, -40, 20],
           vmin=0,
           vmax=255,
           cmap='gray')
# Zoom in after the map is drawn so these limits are the final view.
plt.axis([-45, 30, -20, 20])
plt.title(f"Test sample {indx}")
plt.show()