In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import os, os.path 
import numpy 
import pickle
from glob import glob
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable


"""Change to the data folder"""
# Data folders, relative to the notebook's working directory.
# Sequence counts per split: train 205942, val 3200, test 36272.
# Every sequence is sampled at 10 Hz.
new_path = "../new_train/new_train/"
val_path = "../new_val_in/new_val_in/"

### Create a dataset class 

In [2]:
class ArgoverseDataset(Dataset):
    """Map-style dataset over a directory of pickled Argoverse scenes.

    Every entry directly under ``data_path`` is treated as one scene file. Files
    are indexed in sorted path order. ``__getitem__`` returns the unpickled object
    unchanged. The collate functions in this notebook read it as a dict with keys
    such as 'p_in', 'v_in', 'lane', 'lane_norm' and 'city'.

    Parameters
    ----------
    data_path : str
        Directory containing the scene files.
    training : bool, default True
        Stored as ``self.training``. No loading logic currently depends on it.
    """

    def __init__(self, data_path: str, training: bool = True):
        super().__init__()
        self.data_path = data_path
        self.training = training
        # Sort so that index -> file mapping is deterministic across runs.
        self.pkl_list = sorted(glob(os.path.join(self.data_path, '*')))

    def __len__(self) -> int:
        """Number of scene files found in ``data_path``."""
        return len(self.pkl_list)

    def __getitem__(self, idx: int):
        """Load and return the unpickled scene at position ``idx``.

        SECURITY: ``pickle.load`` can execute arbitrary code embedded in the file;
        only point this dataset at trusted data.
        """
        with open(self.pkl_list[idx], 'rb') as f:
            return pickle.load(f)


# Training scenes, plus validation scenes (the validation collate reads inputs only).
train_dataset = ArgoverseDataset(data_path=new_path)
val_dataset = ArgoverseDataset(data_path=val_path, training=False)

### Create a loader to enable batch processing

In [18]:
batch_sz = 10


def my_collate(batch):
    """Collate a list of training scenes into batched float32 tensors.

    Each scene dict must provide 'city', 'lane', 'lane_norm', 'p_in', 'v_in',
    'p_out' and 'v_out'.

    Returns
    -------
    list of torch.Tensor (float32)
        [city, laneInfo, inp, out], where
        city     -- (B, 60, 1, 4): all ones for 'PIT' scenes, zeros otherwise
        laneInfo -- (B, 60, 19, 4): the first 60*19 lane points (first two columns
                    of 'lane') next to the matching 'lane_norm' columns, folded
                    into 60 rows of 19 points; zero-padded or truncated to fit
        inp      -- 'p_in' and 'v_in' concatenated along axis 2
        out      -- 'p_out' and 'v_out' concatenated along axis 2
    """
    num_agents = 60          # rows per scene
    lane_per_agent = 19      # lane points folded into each row
    max_lane_points = num_agents * lane_per_agent

    def _stack_float(arrays):
        # Stack into one ndarray before converting: building a tensor from a
        # Python list of ndarrays is very slow (PyTorch warns about it).
        if not arrays:
            return torch.empty(0)
        return torch.from_numpy(numpy.stack(arrays)).float()

    city, laneInfo, inp, out = [], [], [], []
    for scene in batch:
        city_flag = 1.0 if scene['city'] == 'PIT' else 0.0
        city.append(numpy.full((num_agents, 1, 4), city_flag))

        n_lane = min(max_lane_points, len(scene['lane']))
        lanes = numpy.zeros((max_lane_points, 2))
        lane_norm = numpy.zeros((max_lane_points, 2))
        lanes[:n_lane] = scene['lane'][:n_lane, :2]
        lane_norm[:n_lane] = scene['lane_norm'][:n_lane, :2]
        laneInfo.append(numpy.dstack([
            lanes.reshape(num_agents, lane_per_agent, 2),
            lane_norm.reshape(num_agents, lane_per_agent, 2),
        ]))

        inp.append(numpy.dstack([scene['p_in'], scene['v_in']]))
        out.append(numpy.dstack([scene['p_out'], scene['v_out']]))

    return [_stack_float(city), _stack_float(laneInfo), _stack_float(inp), _stack_float(out)]

train_loader = DataLoader(
    train_dataset,
    batch_size=batch_sz,
    shuffle=True,
    collate_fn=my_collate,
    num_workers=8,
)

In [12]:
# NOTE: this rebinds the global batch_sz that the training-loader cell also sets;
# val_loader below is built with this value.
batch_sz = 2


def val_collate(batch):
    """Collate validation scenes (inputs only) into batched tensors plus metadata.

    Each scene dict must provide 'agent_id', 'track_id', 'scene_idx', 'city',
    'lane', 'lane_norm', 'p_in' and 'v_in'.

    Returns
    -------
    list
        [inp, sceneIdxs, agentIds, trackIds, laneInfo, city], where
        inp       -- float32 tensor: 'p_in' and 'v_in' concatenated along axis 2
        sceneIdxs, agentIds, trackIds -- Python lists holding each scene's
                     'scene_idx', 'agent_id' and 'track_id' values unchanged
        laneInfo  -- float32 (B, 60, 19, 4): the first 60*19 lane points (first two
                     columns of 'lane') next to the matching 'lane_norm' columns;
                     zero-padded or truncated to fit
        city      -- float32 (B, 60, 1, 4): all ones for 'PIT' scenes, zeros otherwise
    """
    num_agents = 60          # rows per scene
    lane_per_agent = 19      # lane points folded into each row
    max_lane_points = num_agents * lane_per_agent

    def _stack_float(arrays):
        # Stack into one ndarray before converting: building a tensor from a
        # Python list of ndarrays is very slow (PyTorch warns about it).
        if not arrays:
            return torch.empty(0)
        return torch.from_numpy(numpy.stack(arrays)).float()

    agentIds, trackIds, sceneIdxs = [], [], []
    laneInfo, inp, city = [], [], []
    for scene in batch:
        agentIds.append(scene['agent_id'])
        trackIds.append(scene['track_id'])
        sceneIdxs.append(scene['scene_idx'])

        city_flag = 1.0 if scene['city'] == 'PIT' else 0.0
        city.append(numpy.full((num_agents, 1, 4), city_flag))

        n_lane = min(max_lane_points, len(scene['lane']))
        lanes = numpy.zeros((max_lane_points, 2))
        lane_norm = numpy.zeros((max_lane_points, 2))
        lanes[:n_lane] = scene['lane'][:n_lane, :2]
        lane_norm[:n_lane] = scene['lane_norm'][:n_lane, :2]
        laneInfo.append(numpy.dstack([
            lanes.reshape(num_agents, lane_per_agent, 2),
            lane_norm.reshape(num_agents, lane_per_agent, 2),
        ]))

        inp.append(numpy.dstack([scene['p_in'], scene['v_in']]))

    return [_stack_float(inp), sceneIdxs, agentIds, trackIds, _stack_float(laneInfo), _stack_float(city)]

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_sz,
    shuffle=False,
    collate_fn=val_collate,
    num_workers=6,
)

In [19]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

class FullyConnectedModel(nn.Module):
    """Three-layer MLP (Linear-ReLU-Linear-ReLU-Linear) over a flattened scene.

    With the defaults, the input has 240 * 39 features and the output has
    240 * 30 values. The training and inference cells build the input by
    flattening city, lane and p_in/v_in into 39 x 240. They reshape the
    output to (-1, 30, 60, 4).

    Parameters
    ----------
    in_features, hidden_dim, mid_dim, out_features : int
        Layer widths. The defaults reproduce the original architecture, so
        existing checkpoints load unchanged.
    """

    def __init__(self, in_features=240 * 39, hidden_dim=240 * 60,
                 mid_dim=240 * 32, out_features=240 * 30):
        super().__init__()
        self.hidden_dim = hidden_dim
        # Keep the Sequential layout (indices 0, 2, 4) so state_dict keys are stable.
        self.linear = nn.Sequential(
            nn.Linear(in_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, mid_dim),
            nn.ReLU(),
            nn.Linear(mid_dim, out_features),
        )

    def forward(self, x):
        """Map ``x`` of shape (batch, in_features) to (batch, out_features)."""
        return self.linear(x)

    def forwardTesting(self, x, num_steps=30):
        """Not supported for this model.

        This step-by-step decoder was written for an LSTM model. It referenced
        ``self.lstm``, ``self.fc`` and ``self.num_layers``, which this class does
        not define, so calling it could only fail with an AttributeError. Use
        ``forward``, which predicts every step in one pass.
        """
        raise NotImplementedError(
            "forwardTesting is an LSTM-era decoder; FullyConnectedModel has no "
            "lstm/fc/num_layers. Use forward() instead."
        )


### Training


In [None]:
from statistics import mean
import random
import numpy as np

torch.cuda.empty_cache()

learning_rate = 1e-5
n_epochs = 30
device = torch.device("cuda:0")
# One checkpoint path for both the per-epoch save and the Ctrl-C save. The
# inference cell loads this exact name; the interrupt branch used to write
# "..._l2.pth", which is a different file on case-sensitive filesystems.
checkpoint_path = './FCM_Adam_30Epochs_HDim7k_L2.pth'

# Leftover settings from an earlier RNN experiment; nothing in this cell reads them.
agent_id = 0
momentum = 0.01
input_dim = 4 * 60
hidden_dim = 3000
layer_dim = 3
output_dim = 4

model = FullyConnectedModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-5, amsgrad=True)

model.train()
print("start")
# Exponential moving average of the batch loss. It is kept as a Python float:
# blending the loss *tensor* into a running tensor keeps chaining autograd nodes
# from every previous iteration.
loss_ema = None
try:
    for i_epoch in range(n_epochs):
        for i_batch, sample_batch in enumerate(train_loader):
            city, laneInfo, inp, out = (t.to(device) for t in sample_batch)
            optimizer.zero_grad()

            # Concatenate along dim 2 (1 + 19 + 19 = 39), move it ahead of the
            # agent axis, then flatten each scene to 39 * 240 features.
            mixed = torch.cat([city, laneInfo, inp], 2).transpose(1, 2).reshape(-1, 39 * 240)
            y_pred = model(mixed).reshape((-1, 30, 60, 4)).transpose(1, 2)

            loss = torch.mean((y_pred - out) ** 2)
            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            loss_ema = loss_value if loss_ema is None else 0.9 * loss_ema + 0.1 * loss_value

            if i_batch % 20 == 0:
                # Use the loader's own batch size: the global batch_sz is rebound by the val cell.
                print("epoch#{:d}".format(i_epoch), "batch#{:d}".format(i_batch),
                      "scenes#{:d}".format(i_batch * train_loader.batch_size),
                      " loss EMA (decay 0.9), batch loss: ", loss_ema, loss_value)
        torch.save(model.state_dict(), checkpoint_path)
except KeyboardInterrupt:
    print("savedModel")
    torch.save(model.state_dict(), checkpoint_path)


start
epoch#0 batch#0 scenes#0  avg loss per scene(past 100):  239863.703125 239863.703125
epoch#0 batch#20 scenes#200  avg loss per scene(past 100):  227730.703125 123027.0
epoch#0 batch#40 scenes#400  avg loss per scene(past 100):  131954.46875 94597.578125
epoch#0 batch#60 scenes#600  avg loss per scene(past 100):  88500.2578125 81683.5859375
epoch#0 batch#80 scenes#800  avg loss per scene(past 100):  60515.84375 56603.91796875
epoch#0 batch#100 scenes#1000  avg loss per scene(past 100):  47431.265625 40524.05078125
epoch#0 batch#120 scenes#1200  avg loss per scene(past 100):  47396.36328125 29581.623046875
epoch#0 batch#140 scenes#1400  avg loss per scene(past 100):  42298.6953125 50429.24609375
epoch#0 batch#160 scenes#1600  avg loss per scene(past 100):  37675.1328125 38992.7890625
epoch#0 batch#180 scenes#1800  avg loss per scene(past 100):  30586.404296875 23041.03125
epoch#0 batch#200 scenes#2000  avg loss per scene(past 100):  35121.0859375 43862.734375
epoch#0 batch#220 scen

epoch#0 batch#1760 scenes#17600  avg loss per scene(past 100):  3462.19189453125 2329.532470703125
epoch#0 batch#1780 scenes#17800  avg loss per scene(past 100):  3133.47802734375 4082.1328125
epoch#0 batch#1800 scenes#18000  avg loss per scene(past 100):  2959.32666015625 1562.3887939453125
epoch#0 batch#1820 scenes#18200  avg loss per scene(past 100):  2727.551025390625 2532.47607421875
epoch#0 batch#1840 scenes#18400  avg loss per scene(past 100):  3635.408935546875 1662.75048828125
epoch#0 batch#1860 scenes#18600  avg loss per scene(past 100):  3061.597412109375 2322.55908203125
epoch#0 batch#1880 scenes#18800  avg loss per scene(past 100):  3480.697998046875 1758.5433349609375
epoch#0 batch#1900 scenes#19000  avg loss per scene(past 100):  3072.523681640625 2035.44384765625
epoch#0 batch#1920 scenes#19200  avg loss per scene(past 100):  3447.786865234375 1673.63232421875
epoch#0 batch#1940 scenes#19400  avg loss per scene(past 100):  2727.772216796875 2603.896240234375
epoch#0 bat

epoch#0 batch#3420 scenes#34200  avg loss per scene(past 100):  1381.0311279296875 2121.402099609375
epoch#0 batch#3440 scenes#34400  avg loss per scene(past 100):  983.0147094726562 659.2813720703125
epoch#0 batch#3460 scenes#34600  avg loss per scene(past 100):  1244.8804931640625 623.2198486328125
epoch#0 batch#3480 scenes#34800  avg loss per scene(past 100):  1160.815185546875 509.2760009765625
epoch#0 batch#3500 scenes#35000  avg loss per scene(past 100):  1523.470703125 2336.001708984375
epoch#0 batch#3520 scenes#35200  avg loss per scene(past 100):  2186.9951171875 819.8563232421875
epoch#0 batch#3540 scenes#35400  avg loss per scene(past 100):  1076.76513671875 508.94610595703125
epoch#0 batch#3560 scenes#35600  avg loss per scene(past 100):  1219.9793701171875 1849.2752685546875
epoch#0 batch#3580 scenes#35800  avg loss per scene(past 100):  1170.9837646484375 417.55712890625
epoch#0 batch#3600 scenes#36000  avg loss per scene(past 100):  1249.403076171875 433.3727722167969
ep

epoch#0 batch#5080 scenes#50800  avg loss per scene(past 100):  969.2888793945312 631.1041870117188
epoch#0 batch#5100 scenes#51000  avg loss per scene(past 100):  705.234130859375 837.1282958984375
epoch#0 batch#5120 scenes#51200  avg loss per scene(past 100):  1440.3067626953125 1041.02490234375
epoch#0 batch#5140 scenes#51400  avg loss per scene(past 100):  875.6458740234375 554.9942016601562
epoch#0 batch#5160 scenes#51600  avg loss per scene(past 100):  878.2565307617188 2318.318603515625
epoch#0 batch#5180 scenes#51800  avg loss per scene(past 100):  977.2085571289062 227.0517578125
epoch#0 batch#5200 scenes#52000  avg loss per scene(past 100):  1141.7720947265625 315.6935729980469
epoch#0 batch#5220 scenes#52200  avg loss per scene(past 100):  578.8165283203125 245.27000427246094
epoch#0 batch#5240 scenes#52400  avg loss per scene(past 100):  638.941162109375 628.5588989257812
epoch#0 batch#5260 scenes#52600  avg loss per scene(past 100):  601.4143676757812 312.620849609375
epoc

epoch#0 batch#6740 scenes#67400  avg loss per scene(past 100):  547.8971557617188 586.0022583007812
epoch#0 batch#6760 scenes#67600  avg loss per scene(past 100):  500.23309326171875 220.83294677734375
epoch#0 batch#6780 scenes#67800  avg loss per scene(past 100):  354.785888671875 223.70767211914062
epoch#0 batch#6800 scenes#68000  avg loss per scene(past 100):  458.61895751953125 116.99038696289062
epoch#0 batch#6820 scenes#68200  avg loss per scene(past 100):  705.9547729492188 1195.91259765625
epoch#0 batch#6840 scenes#68400  avg loss per scene(past 100):  542.5673217773438 112.22792053222656
epoch#0 batch#6860 scenes#68600  avg loss per scene(past 100):  955.93896484375 1285.9114990234375
epoch#0 batch#6880 scenes#68800  avg loss per scene(past 100):  542.2310791015625 1454.50244140625
epoch#0 batch#6900 scenes#69000  avg loss per scene(past 100):  452.43133544921875 154.3436737060547
epoch#0 batch#6920 scenes#69200  avg loss per scene(past 100):  381.10406494140625 262.9808654785

epoch#0 batch#8380 scenes#83800  avg loss per scene(past 100):  422.8091125488281 113.03434753417969
epoch#0 batch#8400 scenes#84000  avg loss per scene(past 100):  350.0228271484375 139.5830078125
epoch#0 batch#8420 scenes#84200  avg loss per scene(past 100):  389.1553039550781 151.33641052246094
epoch#0 batch#8440 scenes#84400  avg loss per scene(past 100):  212.9105987548828 76.44017791748047
epoch#0 batch#8460 scenes#84600  avg loss per scene(past 100):  273.03472900390625 147.91563415527344
epoch#0 batch#8480 scenes#84800  avg loss per scene(past 100):  313.828369140625 114.59235382080078
epoch#0 batch#8500 scenes#85000  avg loss per scene(past 100):  262.012451171875 107.12844848632812
epoch#0 batch#8520 scenes#85200  avg loss per scene(past 100):  324.357177734375 178.06736755371094
epoch#0 batch#8540 scenes#85400  avg loss per scene(past 100):  223.73550415039062 86.33541870117188
epoch#0 batch#8560 scenes#85600  avg loss per scene(past 100):  319.9376220703125 1118.00109863281

epoch#0 batch#10020 scenes#100200  avg loss per scene(past 100):  294.585693359375 195.99497985839844
epoch#0 batch#10040 scenes#100400  avg loss per scene(past 100):  196.30618286132812 271.7266845703125
epoch#0 batch#10060 scenes#100600  avg loss per scene(past 100):  607.0221557617188 207.52830505371094
epoch#0 batch#10080 scenes#100800  avg loss per scene(past 100):  404.41021728515625 918.6416625976562
epoch#0 batch#10100 scenes#101000  avg loss per scene(past 100):  509.8984069824219 214.5258026123047
epoch#0 batch#10120 scenes#101200  avg loss per scene(past 100):  312.16986083984375 126.79512786865234
epoch#0 batch#10140 scenes#101400  avg loss per scene(past 100):  207.30068969726562 91.54048919677734
epoch#0 batch#10160 scenes#101600  avg loss per scene(past 100):  390.3963317871094 233.2445831298828
epoch#0 batch#10180 scenes#101800  avg loss per scene(past 100):  389.1741027832031 250.98167419433594
epoch#0 batch#10200 scenes#102000  avg loss per scene(past 100):  200.05928

epoch#0 batch#11640 scenes#116400  avg loss per scene(past 100):  295.058837890625 136.65664672851562
epoch#0 batch#11660 scenes#116600  avg loss per scene(past 100):  202.675537109375 83.93341827392578
epoch#0 batch#11680 scenes#116800  avg loss per scene(past 100):  167.80470275878906 106.26513671875
epoch#0 batch#11700 scenes#117000  avg loss per scene(past 100):  379.3274230957031 628.3902587890625
epoch#0 batch#11720 scenes#117200  avg loss per scene(past 100):  304.3170471191406 105.48664093017578
epoch#0 batch#11740 scenes#117400  avg loss per scene(past 100):  166.2720184326172 77.7982406616211
epoch#0 batch#11760 scenes#117600  avg loss per scene(past 100):  229.1732940673828 218.19049072265625
epoch#0 batch#11780 scenes#117800  avg loss per scene(past 100):  238.51341247558594 290.62939453125
epoch#0 batch#11800 scenes#118000  avg loss per scene(past 100):  557.7421264648438 216.66375732421875
epoch#0 batch#11820 scenes#118200  avg loss per scene(past 100):  747.4617309570312

epoch#0 batch#13260 scenes#132600  avg loss per scene(past 100):  160.8447265625 65.5428237915039
epoch#0 batch#13280 scenes#132800  avg loss per scene(past 100):  188.39566040039062 93.12960815429688
epoch#0 batch#13300 scenes#133000  avg loss per scene(past 100):  527.8881225585938 193.4725341796875
epoch#0 batch#13320 scenes#133200  avg loss per scene(past 100):  245.87168884277344 59.36488342285156
epoch#0 batch#13340 scenes#133400  avg loss per scene(past 100):  206.1551971435547 64.46488952636719
epoch#0 batch#13360 scenes#133600  avg loss per scene(past 100):  310.76983642578125 487.9817810058594
epoch#0 batch#13380 scenes#133800  avg loss per scene(past 100):  222.253662109375 67.62409210205078
epoch#0 batch#13400 scenes#134000  avg loss per scene(past 100):  155.50592041015625 110.00395202636719
epoch#0 batch#13420 scenes#134200  avg loss per scene(past 100):  209.1947021484375 168.36195373535156
epoch#0 batch#13440 scenes#134400  avg loss per scene(past 100):  614.33612060546

epoch#0 batch#14860 scenes#148600  avg loss per scene(past 100):  97.78162384033203 62.416988372802734
epoch#0 batch#14880 scenes#148800  avg loss per scene(past 100):  183.03968811035156 58.84638977050781
epoch#0 batch#14900 scenes#149000  avg loss per scene(past 100):  232.7743682861328 60.170902252197266
epoch#0 batch#14920 scenes#149200  avg loss per scene(past 100):  259.1224365234375 106.85298919677734
epoch#0 batch#14940 scenes#149400  avg loss per scene(past 100):  178.2078094482422 65.67948913574219
epoch#0 batch#14960 scenes#149600  avg loss per scene(past 100):  234.0336151123047 1127.168701171875
epoch#0 batch#14980 scenes#149800  avg loss per scene(past 100):  123.32902526855469 197.05104064941406
epoch#0 batch#15000 scenes#150000  avg loss per scene(past 100):  188.77951049804688 123.2465591430664
epoch#0 batch#15020 scenes#150200  avg loss per scene(past 100):  110.34806823730469 77.47073364257812
epoch#0 batch#15040 scenes#150400  avg loss per scene(past 100):  109.0088

epoch#0 batch#16460 scenes#164600  avg loss per scene(past 100):  159.84413146972656 58.18214797973633
epoch#0 batch#16480 scenes#164800  avg loss per scene(past 100):  118.97081756591797 305.8617858886719
epoch#0 batch#16500 scenes#165000  avg loss per scene(past 100):  376.2129821777344 84.37897491455078
epoch#0 batch#16520 scenes#165200  avg loss per scene(past 100):  235.41259765625 87.7412109375
epoch#0 batch#16540 scenes#165400  avg loss per scene(past 100):  178.579345703125 343.9378662109375
epoch#0 batch#16560 scenes#165600  avg loss per scene(past 100):  103.5271987915039 46.45777130126953
epoch#0 batch#16580 scenes#165800  avg loss per scene(past 100):  86.89533233642578 58.575279235839844
epoch#0 batch#16600 scenes#166000  avg loss per scene(past 100):  122.42393493652344 48.66861343383789
epoch#0 batch#16620 scenes#166200  avg loss per scene(past 100):  273.4433898925781 44.2552375793457
epoch#0 batch#16640 scenes#166400  avg loss per scene(past 100):  143.87701416015625 3

epoch#0 batch#18080 scenes#180800  avg loss per scene(past 100):  351.6915588378906 413.6252746582031
epoch#0 batch#18100 scenes#181000  avg loss per scene(past 100):  211.22923278808594 56.568912506103516
epoch#0 batch#18120 scenes#181200  avg loss per scene(past 100):  79.5420913696289 53.42406463623047
epoch#0 batch#18140 scenes#181400  avg loss per scene(past 100):  109.4108657836914 72.96163177490234
epoch#0 batch#18160 scenes#181600  avg loss per scene(past 100):  177.42056274414062 67.44419860839844
epoch#0 batch#18180 scenes#181800  avg loss per scene(past 100):  98.62753295898438 135.82203674316406
epoch#0 batch#18200 scenes#182000  avg loss per scene(past 100):  221.4917449951172 69.21240997314453
epoch#0 batch#18220 scenes#182200  avg loss per scene(past 100):  199.473388671875 161.01441955566406
epoch#0 batch#18240 scenes#182400  avg loss per scene(past 100):  151.06040954589844 61.40392303466797
epoch#0 batch#18260 scenes#182600  avg loss per scene(past 100):  356.52810668

epoch#0 batch#19700 scenes#197000  avg loss per scene(past 100):  123.50572967529297 41.93190383911133
epoch#0 batch#19720 scenes#197200  avg loss per scene(past 100):  82.89351654052734 34.512939453125
epoch#0 batch#19740 scenes#197400  avg loss per scene(past 100):  143.42771911621094 137.0289764404297
epoch#0 batch#19760 scenes#197600  avg loss per scene(past 100):  116.4769515991211 137.27432250976562
epoch#0 batch#19780 scenes#197800  avg loss per scene(past 100):  86.83885955810547 114.54083251953125
epoch#0 batch#19800 scenes#198000  avg loss per scene(past 100):  240.08932495117188 67.7737045288086
epoch#0 batch#19820 scenes#198200  avg loss per scene(past 100):  119.66056823730469 55.939083099365234
epoch#0 batch#19840 scenes#198400  avg loss per scene(past 100):  195.19595336914062 125.37928009033203
epoch#0 batch#19860 scenes#198600  avg loss per scene(past 100):  494.24652099609375 120.57262420654297
epoch#0 batch#19880 scenes#198800  avg loss per scene(past 100):  128.6215

In [8]:
# Manually checkpoint the current weights to the file the inference cell loads.
torch.save(obj=model.state_dict(), f='./FCM_Adam_30Epochs_HDim7k_L2.pth')

In [17]:
import numpy as np
# import pandas as pd

save_file = "FCM_Adam_30Epochs_HDim7k_L2.csv"
checkpoint_path = 'FCM_Adam_30Epochs_HDim7k_L2.pth'

# One row per scene: ID, then 60 values = 30 predicted steps of (x, y), interleaved.
header = ["ID"] + ["v" + str(x) for x in range(1, 61)]
print(header)

device = "cuda:0"

testmodel = FullyConnectedModel()
# map_location lets the checkpoint load regardless of which device saved it.
testmodel.load_state_dict(torch.load(checkpoint_path, map_location=device))
testmodel.to(device)
testmodel.eval()

print("test")
# Target min/max from the disabled normalization experiment. They are only
# needed if the model was trained on normalized targets; the training cell
# here trains on raw values.
varMaxOutput = [4773., 4097.7, 193.19, 194.33]
varMinOutput = [-53.912, 0., -210.04, -187.71]

# no_grad: inference does not need an autograd graph (saves memory and time).
with open(save_file, 'w') as f, torch.no_grad():
    f.write(",".join(header) + "\n")
    for i_batch, sample_batch in enumerate(val_loader):
        if i_batch % 100 == 0:
            print("batch #: ", i_batch * val_loader.batch_size)

        inp, scene_idx, agent_ids, track_ids, laneInfo, city = sample_batch
        city, laneInfo, inp = city.to(device), laneInfo.to(device), inp.to(device)

        mixed = torch.cat([city, laneInfo, inp], 2).transpose(1, 2).reshape(-1, 39 * 240)
        y_pred = testmodel(mixed).reshape((-1, 30, 60, 4)).transpose(1, 2)

        # Write each scene's row together with its own scene_idx. The old code
        # collected rows in one list and re-paired them with scene_idx by position,
        # so a scene with zero or several matching tracks shifted every later ID.
        for i_scene, (ag_id, scene_tracks) in enumerate(zip(agent_ids, track_ids)):
            for j, track_id in enumerate(scene_tracks):
                if ag_id == track_id[0]:
                    # (30, 2) predicted positions -> x1, y1, x2, y2, ... (one host copy per row)
                    coords = y_pred[i_scene][j][:, :2].reshape(60).cpu().tolist()
                    f.write(str(scene_idx[i_scene]) + "," + ",".join(map(str, coords)) + "\n")
                    break
            else:
                print("warning: no track matches agent", ag_id, "in scene", scene_idx[i_scene])


['ID', 'v1', 'v2', 'v3', 'v4', 'v5', 'v6', 'v7', 'v8', 'v9', 'v10', 'v11', 'v12', 'v13', 'v14', 'v15', 'v16', 'v17', 'v18', 'v19', 'v20', 'v21', 'v22', 'v23', 'v24', 'v25', 'v26', 'v27', 'v28', 'v29', 'v30', 'v31', 'v32', 'v33', 'v34', 'v35', 'v36', 'v37', 'v38', 'v39', 'v40', 'v41', 'v42', 'v43', 'v44', 'v45', 'v46', 'v47', 'v48', 'v49', 'v50', 'v51', 'v52', 'v53', 'v54', 'v55', 'v56', 'v57', 'v58', 'v59', 'v60']
test
batch #:  0
batch #:  200
batch #:  400
batch #:  600
batch #:  800
batch #:  1000
batch #:  1200
batch #:  1400
batch #:  1600
batch #:  1800
batch #:  2000
batch #:  2200
batch #:  2400
batch #:  2600
batch #:  2800
batch #:  3000
