In [1]:
import numpy as np
import os
import sys
from fractions import gcd
from numbers import Number

import torch
from torch import Tensor, nn
from torch.nn import functional as F

from data import ArgoDataset, collate_fn
from utils import gpu, to_long,  Optimizer, StepLR

from layers import Conv1d, Res1d, Linear, LinearRes, Null
from numpy import float64, ndarray
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

In [None]:
import os

import argparse
import numpy as np
import random
import sys
import time
import shutil
from importlib import import_module
from numbers import Number

import torch
from torch.utils.data import Sampler, DataLoader


from utils import Logger, load_pretrain

In [None]:
from lanegcn import get_model
import matplotlib.pyplot as plt

In [None]:
# Build all LaneGCN pieces in one call. Note: `collate_fn` rebinds the name
# imported from `data` in the first cell, and `loss` is later rebound to a
# tensor by the D/G step cells.
(
    config,
    Dataset,
    collate_fn,
    net,
    loss,
    post_process,
    opt,
) = get_model()

In [None]:
def worker_init_fn(pid):
    """Seed NumPy and Python's ``random`` for one DataLoader worker.

    NumPy's global RNG is seeded with the worker id. A seed drawn from it is
    then used to seed ``random``, so each worker's streams depend only on its
    id and are reproducible.

    Args:
        pid: worker id handed to the hook by DataLoader (anything ``int()``
            accepts).
    """
    np.random.seed(int(pid))
    # dtype=np.int64: with the default dtype, NumPy < 2 uses a 32-bit int on
    # Windows, where the bound 2**32 - 1 is out of range (ValueError).
    # int(): an explicit dtype makes randint return a NumPy scalar, and
    # random.seed() only accepts built-in types on Python >= 3.11.
    random_seed = int(np.random.randint(2 ** 32 - 1, dtype=np.int64))
    random.seed(random_seed)

# Preprocessed training split, relative to the notebook's working directory.
TRAIN_PREPROCESSED_PATH = './dataset/preprocess/train_crs_dist6_angle90.p'

dataset = Dataset(TRAIN_PREPROCESSED_PATH, config, train=True)
train_loader = DataLoader(
    dataset,
    batch_size=config["batch_size"],
    num_workers=config["workers"],
    # shuffle=True would reorder the data at every epoch; left off here.
    shuffle=False,
    collate_fn=collate_fn,
    pin_memory=True,
    worker_init_fn=worker_init_fn,
    # Drop the final incomplete batch (author's note: 36 samples were
    # thrown away by this).
    drop_last=True,
)

In [None]:
# Grab a single collated batch for inspection; an empty loader leaves `data`
# as an empty dict.
data = dict(next(iter(train_loader), {}))

print(data.keys())

In [None]:
# Shape of the ground-truth prediction tensor for the first sample.
data['gt_preds'][0].shape

In [None]:
# Full model forward: relative-frame and transformed outputs.
out_rel, out = net(data)
traj_rel, traj = out_rel['reg'], out['reg']
print(traj_rel[0].shape)

In [None]:
# Overlay the first predicted trajectory in both frames (x vs y columns).
rel_xy = traj_rel[0][0][0].detach().numpy()
plt.scatter(rel_xy[:, 0], rel_xy[:, 1])

abs_xy = traj[0][0][0].detach().numpy()
plt.scatter(abs_xy[:, 0], abs_xy[:, 1])

In [None]:
from lanegcn import ActorNet, PredNet, MapNet, A2A, A2M, M2A, M2M, graph_gather, actor_gather
# Step through the LaneGCN pipeline by hand with newly constructed
# sub-modules built from `config`. These are separate instances, not the
# sub-modules inside `net`, so `out` below presumably differs from net(data).
# Construction order is left untouched in case module initialisation draws
# from the global torch RNG.
pred_net = PredNet(config)

actor_net = ActorNet(config)
map_net = MapNet(config)
a2m = A2M(config)
m2m = M2M(config)
m2a = M2A(config)
a2a = A2A(config)

# construct actor feature
# Presumably flattens the per-sample "feats" into one actor batch plus
# per-sample index lists — confirm against lanegcn.actor_gather.
actors, actor_idcs = actor_gather(data["feats"])
actor_ctrs = data["ctrs"]
actors = actor_net(actors)
# construct map features
graph = graph_gather(to_long(data["graph"]))
nodes, node_idcs, node_ctrs = map_net(graph)
# Fusion chain, by module name: actor->map, map->map, map->actor,
# actor->actor.
nodes = a2m(nodes, graph, actors, actor_idcs, actor_ctrs)
nodes = m2m(nodes, graph)
actors = m2a(actors, actor_idcs, actor_ctrs, nodes, node_idcs, node_ctrs)
actors = a2a(actors, actor_idcs, actor_ctrs)

# prediction
# NOTE: overwrites the `out` produced by net(data) earlier with the raw
# PredNet output; rot/orig are pulled out for the coordinate transform.
out = pred_net(actors, actor_idcs, actor_ctrs)
rot, orig = data["rot"], data["orig"]

In [None]:
print(out.keys())
traj = out['reg']
# First predicted trajectory from the hand-run PredNet output.
first_xy = traj[0][0][0].detach().numpy()
plt.scatter(first_xy[:, 0], first_xy[:, 1])

In [None]:
def ref_copy(data):
    """Copy nested list/dict containers while sharing every leaf object.

    Lists and dicts are rebuilt recursively, so reassigning an entry in the
    copy does not affect the original. Anything else (tensors, numbers,
    tuples, ...) is returned as-is, by reference.
    """
    if isinstance(data, dict):
        return {key: ref_copy(data[key]) for key in data}
    if isinstance(data, list):
        return [ref_copy(item) for item in data]
    return data



# Copy the list/dict containers of 'reg' and 'cls' (tensors stay shared) so
# out1["reg"] entries can be replaced without touching out["reg"].
out1 = {key: ref_copy(out[key]) for key in ('reg', 'cls') if key in out}

# Map each sample's predictions out of its local frame: right-multiply by its
# rotation rot[idx], then add its origin orig[idx], broadcast over the
# leading three dimensions.
for idx, reg_local in enumerate(out["reg"]):
    out1["reg"][idx] = reg_local @ rot[idx] + orig[idx].view(1, 1, 1, -1)


In [None]:
traj_rel, traj = out['reg'], out1['reg']
# First trajectory in the local (pre-transform) frame.
rel_xy = traj_rel[0][0][0].detach().numpy()
plt.scatter(rel_xy[:, 0], rel_xy[:, 1])


In [None]:
# Same trajectory after the rotation/translation.
abs_xy = traj[0][0][0].detach().numpy()
plt.scatter(abs_xy[:, 0], abs_xy[:, 1])

In [None]:
from lanegcn import get_fake_traj_rel, get_pred_traj_rel

# Generated ("fake") trajectories use the local-frame predictions in out['reg'].
local_predictions = out['reg']
fake_traj_rel = get_fake_traj_rel(data['traj1'], local_predictions)
pred_traj_rel = get_pred_traj_rel(data['trajs2'])

In [None]:
from lanegcn import TrajectoryDiscriminator

discriminator = TrajectoryDiscriminator(config)
# Score the real trajectories first, then the generated ones.
scores_real, scores_fake = (
    discriminator(trajectories) for trajectories in (pred_traj_rel, fake_traj_rel)
)

In [None]:
print(scores_real.shape, scores_fake.shape)

## D_step: discriminator update

In [None]:
out_rel, out = net(data)
# Build generated and real relative trajectories from this forward pass.
fake_traj_rel, pred_traj_rel = (
    get_fake_traj_rel(data['traj1'], out_rel['reg']),
    get_pred_traj_rel(data['trajs2']),
)


In [None]:
print(pred_traj_rel.shape)

In [None]:
# Row 0 is plotted against row 1 for the first sample of each set.
fake_xy = fake_traj_rel[0].detach().numpy()
real_xy = pred_traj_rel[0].detach().numpy()
plt.scatter(fake_xy[0, :], fake_xy[1, :])
plt.scatter(real_xy[0, :], real_xy[1, :])

In [None]:
from lanegcn import TrajectoryDiscriminator

# Fresh discriminator instance for the D step.
discriminator = TrajectoryDiscriminator(config)

# Detach the generator output: the discriminator update must not send
# gradients into `net`. Without the detach, the D loss backward() also
# traverses (and frees) the graph of the net(data) forward pass.
scores_fake = discriminator(fake_traj_rel.detach())
scores_real = discriminator(pred_traj_rel)


In [None]:
from loss import gan_d_loss

# Discriminator loss on the real/fake scores from the previous cell.
# Only the data term is computed here; no gradient penalty is applied.
d_loss_fn = gan_d_loss
losses = {}
# NOTE: rebinds `loss`, shadowing the criterion returned by get_model().
loss = torch.zeros(1)

data_loss = d_loss_fn(scores_real, scores_fake)
losses['D_data_loss'] = data_loss.item()
loss += data_loss
losses['D_total_loss'] = loss.item()


In [None]:
# Display the scalar loss values recorded as Python floats (via .item()).
losses

In [None]:
from torch import optim

# Adam over the discriminator's parameters; back-propagate the D loss from the
# previous cell and apply one update.
optimizer_d = optim.Adam(discriminator.parameters(), lr=5e-4)
optimizer_d.zero_grad()
loss.backward()
optimizer_d.step()

In [None]:
# `losses` holds floats captured before the optimizer step, so this shows the
# same values as before the update.
losses

## g_step: generator update

In [None]:
from loss import gan_g_loss
from lanegcn import Loss
# g_loss
def bce_loss(input, target):
    """Mean binary cross-entropy computed directly from logits.

    Uses the numerically stable form
    ``max(x, 0) - x * z + log(1 + exp(-|x|))``, so large-magnitude logits
    never overflow ``exp``.

    Args:
        input: logits tensor.
        target: target tensor combined elementwise with ``input``.

    Returns:
        0-dim tensor holding the mean over all elements.
    """
    positive_part = input.clamp(min=0)
    softplus_tail = (1 + (-input.abs()).exp()).log()
    elementwise = positive_part - input * target + softplus_tail
    return elementwise.mean()
    
def gan_g_loss(scores_fake):
    """Adversarial loss for the generator.

    Compares the discriminator's scores on generated trajectories against a
    "real" target. The target is one value drawn per call with
    ``random.uniform(0.7, 1.2)`` and applied to every element of
    ``scores_fake`` (a noisy soft label).

    NOTE(review): targets above 1.0 fall outside the usual [0, 1] range for
    binary cross-entropy; kept as-is.
    """
    target_value = random.uniform(0.7, 1.2)
    y_fake = torch.ones_like(scores_fake) * target_value
    return bce_loss(scores_fake, y_fake)

# LaneGCN's own prediction loss (reported later under 'loss_reg_cls'),
# evaluated on the out_rel from the most recent net(data) call.
loss_fn = Loss(config)
loss_out = loss_fn(out_rel, data)

In [None]:
# Inspect which entries the Loss module returns; the G step reads "loss".
print(loss_out.keys())

In [None]:
# Generator (G) step: run one forward pass of `net`, then back-propagate
# prediction loss + adversarial loss into `net`'s parameters.
losses = {}
# NOTE: rebinds `loss`, shadowing the criterion returned by get_model().
loss = torch.zeros(1)

g_loss_fn = gan_g_loss
out_rel, out = net(data)

# Recompute the prediction loss on *this* forward pass. The `loss_out` from an
# earlier cell belongs to a different forward pass. Its graph may already have
# been consumed by the discriminator's backward(), which can make
# loss.backward() below fail ("backward through the graph a second time").
loss_out = loss_fn(out_rel, data)
losses['loss_reg_cls'] = loss_out["loss"].item()

fake_traj_rel = get_fake_traj_rel(data['traj1'], out_rel['reg'])
pred_traj_rel = get_pred_traj_rel(data['trajs2'])
scores_fake = discriminator(fake_traj_rel)
# Not used by the generator loss; kept so the real scores remain inspectable.
scores_real = discriminator(pred_traj_rel)
discriminator_loss = g_loss_fn(scores_fake)

loss += loss_out["loss"]
loss += discriminator_loss

losses['G_discriminator_loss'] = discriminator_loss.item()
losses['G_total_loss'] = loss.item()

# NOTE(review): a new Adam is built every time this cell runs, so its moment
# estimates never carry over between steps.
optimizer_g = optim.Adam(net.parameters(), lr=5e-4)

optimizer_g.zero_grad()
loss.backward()

optimizer_g.step()