In [1]:
import numpy as np
import os
import sys
from fractions import gcd
from numbers import Number

import torch
from torch import Tensor, nn
from torch.nn import functional as F

from data import ArgoDataset, collate_fn
from utils import gpu, to_long,  Optimizer, StepLR

from layers import Conv1d, Res1d, Linear, LinearRes, Null
from numpy import float64, ndarray
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

from lanegcn import PredNet, get_model
import torch
from torch.utils.data import Sampler, DataLoader

import matplotlib.pyplot as plt

config, Dataset, collate_fn, net, loss, post_process, opt = get_model()
import os

import argparse
import numpy as np
import random
import sys
import time
import shutil
from importlib import import_module
from numbers import Number

import torch
from torch.utils.data import Sampler, DataLoader


from utils import Logger, load_pretrain
def worker_init_fn(pid):
    np_seed = int(pid)
    np.random.seed(np_seed)
    random_seed = np.random.randint(2 ** 32 - 1)
    random.seed(random_seed)


dataset = Dataset(config["train_split"], config, train=True)
train_loader = DataLoader(
    dataset,
    batch_size=config["batch_size"],
    num_workers=config["workers"],
    shuffle=False,
    collate_fn=collate_fn,
    pin_memory=True,
    worker_init_fn=worker_init_fn,
    drop_last=True,
)


for i, data in enumerate(train_loader):
    break


In [2]:
from lanegcn import ActorNet, MapNet, actor_gather, graph_gather
actor_net = ActorNet(config)
map_net = MapNet(config)

# construct actor feature
actors, actor_idcs = actor_gather(gpu(data["feats"]))
actor_ctrs = gpu(data["ctrs"])
actors = actor_net(actors)

# construct map features
graph = graph_gather(to_long(gpu(data["graph"])))
nodes, node_idcs, node_ctrs = map_net(graph)

In [5]:
from lanegcn import A2A,A2M,M2A,M2M
a2m = A2M(config)
m2m = M2M(config)
m2a = M2A(config)
a2a = A2A(config)

In [6]:
nodes = a2m(nodes, graph, actors, actor_idcs, actor_ctrs)
nodes = m2m(nodes, graph)
actors = m2a(actors, actor_idcs, actor_ctrs, nodes, node_idcs, node_ctrs)
actors = a2a(actors, actor_idcs, actor_ctrs)


In [3]:
from lanegcn import GAT
gat = GAT(config)
actors_gat = gat(actors,actor_idcs,actor_ctrs,nodes, node_idcs, node_ctrs,100)
actors_gat = gat(actors,actor_idcs,actor_ctrs,actors,actor_idcs,actor_ctrs,7)

In [4]:
actors_gat.size()

torch.Size([528, 128])

In [5]:
agts = actors
agt_idcs = actor_idcs
agt_ctrs = actor_ctrs
ctx = actors
ctx_idcs = actor_idcs
ctx_ctrs = actor_ctrs
dist_th = 6

In [6]:
res = agts
batch_size = len(agt_idcs)
hi, wi = [], []
hi_count, wi_count = 0, 0
for i in range(batch_size):
    dist = agt_ctrs[i].view(-1, 1, 2) - ctx_ctrs[i].view(1, -1, 2)
    dist = torch.sqrt((dist ** 2).sum(2))
    mask = dist <= dist_th

    idcs = torch.nonzero(mask, as_tuple=False)
    if len(idcs) == 0:
        continue

    hi.append(idcs[:, 0] + hi_count)
    wi.append(idcs[:, 1] + wi_count)
    hi_count += len(agt_idcs[i])
    wi_count += len(ctx_idcs[i])
hi = torch.cat(hi, 0)
wi = torch.cat(wi, 0)

In [7]:
print(hi_count, wi_count)

528 528


In [9]:
dist = agt_ctrs[0].view(-1, 1, 2) - ctx_ctrs[0].view(1, -1, 2)
dist = torch.sqrt((dist ** 2).sum(2))
mask = dist <= dist_th
print(mask)
idcs = torch.nonzero(mask, as_tuple=False)
print(idcs)

tensor([[ True, False, False, False, False, False],
        [False,  True, False, False, False, False],
        [False, False,  True, False,  True, False],
        [False, False, False,  True, False, False],
        [False, False,  True, False,  True, False],
        [False, False, False, False, False,  True]])
tensor([[0, 0],
        [1, 1],
        [2, 2],
        [2, 4],
        [3, 3],
        [4, 2],
        [4, 4],
        [5, 5]])


In [None]:
fc = nn.Linear(128, 128)
att_fc = nn.Linear(2*128, 1)
dist = nn.Sequential(
            nn.Linear(2, 128),
            nn.ReLU(inplace=True),
            Linear(128, 128, norm='GN', ng=1),
        )

In [None]:
agt_ctrs = torch.cat(agt_ctrs, 0)
ctx_ctrs = torch.cat(ctx_ctrs, 0)
distance = agt_ctrs[hi] - ctx_ctrs[wi]
distance_dim = dist(distance)

In [None]:
distance_dim.size()

torch.Size([892, 128])

In [None]:
z = fc(agts)
z_i = z[hi]
z_j = z[wi]

In [None]:
z_ij = torch.cat((z_i, z_j), dim =-1)
a_ij = att_fc(z_ij)
print(a_ij.size())

torch.Size([892, 1])


In [None]:
torch.nonzero(hi==2).squeeze()

tensor([2, 3])

In [None]:
hi[torch.nonzero(hi==2).squeeze()]

tensor([2, 2])

In [None]:
print(z_i.size(), z_j.size())
print(hi[:10],wi[:10])

torch.Size([892, 128]) torch.Size([892, 128])
tensor([0, 1, 2, 2, 3, 4, 4, 5, 6, 7]) tensor([0, 1, 2, 4, 3, 2, 4, 5, 6, 7])


In [None]:
a_ij.size()

torch.Size([892, 1])

In [None]:
alpha_ij = []
for i in range(527):
    # print(a_ij[torch.nonzero(hi==i).squeeze()])
    index = torch.nonzero(hi==i).squeeze(-1)
    alpha_ij.append(F.softmax(a_ij[index], dim=0))
    
    a_ij_new = F.softmax(a_ij[index], dim=0)
    # print(a_ij_new)
    x_i = z_i[torch.nonzero(hi==i)[0]]
    x_j = z_j[index]
    print(x_i.size(), x_j.size())
    distance_dim_ij = distance_dim[index]
    print(distance_dim_ij.size())

    x_j_new = torch.cat((distance_dim_ij, x_j), dim=-1)
    print(x_j_new.size())
    print('*', (a_ij_new*x_j_new).size())



    x_j_new_ = torch.sum(a_ij_new*x_j_new, dim=0).unsqueeze(0)
    print("1", x_j_new_.size())
    
    x = torch.cat((x_i, x_j_new_), dim=1)
    print(x.size())
    # break
# print(alpha_ij[15])

torch.Size([1, 128]) torch.Size([1, 128])
torch.Size([1, 128])
torch.Size([1, 256])
* torch.Size([1, 256])
1 torch.Size([1, 256])
torch.Size([1, 384])
torch.Size([1, 128]) torch.Size([1, 128])
torch.Size([1, 128])
torch.Size([1, 256])
* torch.Size([1, 256])
1 torch.Size([1, 256])
torch.Size([1, 384])
torch.Size([1, 128]) torch.Size([2, 128])
torch.Size([2, 128])
torch.Size([2, 256])
* torch.Size([2, 256])
1 torch.Size([1, 256])
torch.Size([1, 384])
torch.Size([1, 128]) torch.Size([1, 128])
torch.Size([1, 128])
torch.Size([1, 256])
* torch.Size([1, 256])
1 torch.Size([1, 256])
torch.Size([1, 384])
torch.Size([1, 128]) torch.Size([2, 128])
torch.Size([2, 128])
torch.Size([2, 256])
* torch.Size([2, 256])
1 torch.Size([1, 256])
torch.Size([1, 384])
torch.Size([1, 128]) torch.Size([1, 128])
torch.Size([1, 128])
torch.Size([1, 256])
* torch.Size([1, 256])
1 torch.Size([1, 256])
torch.Size([1, 384])
torch.Size([1, 128]) torch.Size([1, 128])
torch.Size([1, 128])
torch.Size([1, 256])
* torch.Si

In [None]:
alpha_ij

[tensor([1.], grad_fn=<SoftmaxBackward>),
 tensor([1.], grad_fn=<SoftmaxBackward>),
 tensor([[0.4997],
         [0.5003]], grad_fn=<SoftmaxBackward>),
 tensor([1.], grad_fn=<SoftmaxBackward>),
 tensor([[0.4997],
         [0.5003]], grad_fn=<SoftmaxBackward>),
 tensor([1.], grad_fn=<SoftmaxBackward>),
 tensor([1.], grad_fn=<SoftmaxBackward>),
 tensor([[0.5024],
         [0.4976]], grad_fn=<SoftmaxBackward>),
 tensor([1.], grad_fn=<SoftmaxBackward>),
 tensor([[0.3248],
         [0.3357],
         [0.3396]], grad_fn=<SoftmaxBackward>),
 tensor([1.], grad_fn=<SoftmaxBackward>),
 tensor([1.], grad_fn=<SoftmaxBackward>),
 tensor([1.], grad_fn=<SoftmaxBackward>),
 tensor([[0.5024],
         [0.4976]], grad_fn=<SoftmaxBackward>),
 tensor([1.], grad_fn=<SoftmaxBackward>),
 tensor([[0.3407],
         [0.3236],
         [0.3358]], grad_fn=<SoftmaxBackward>),
 tensor([[0.3248],
         [0.3357],
         [0.3396]], grad_fn=<SoftmaxBackward>),
 tensor([1.], grad_fn=<SoftmaxBackward>),
 tensor([1.]

In [None]:
alpha_ij = []
h_new = []
for i in range(hi[-1]+1):
    index = torch.nonzero(hi==i).squeeze(-1)

    alpha_ij.append(F.softmax(a_ij[index], dim=0))

    a_ij_new = F.softmax(a_ij[index], dim=0)
    
    x_i = z_i[torch.nonzero(hi==i)[0]]
    x_j = z_j[index]
    distance_dim_ij = distance_dim[index]

    x_j_new = torch.cat((distance_dim_ij, x_j), dim=-1)
    # print((alpha*x_j_new).size())

    x_j_new_ = torch.sum(a_ij_new*x_j_new, dim=0).unsqueeze(0)

    # print(x_i.size(), x_j_new_.size())
    h_new.append(torch.cat((x_i, x_j_new_), -1))
    # break


In [None]:
h_new_ = torch.cat(h_new)
h_new_.size()

torch.Size([528, 384])

In [None]:
norm = nn.GroupNorm(1, 128*3)
linear = Linear(128*3, 128, norm='GN', ng=1, act=False)
relu = nn.ReLU(inplace=True)

In [None]:
agts = norm(h_new_)
agts = relu(agts)

agts = linear(agts)
agts += res
gts = relu(agts)


In [None]:
gts.size()

torch.Size([528, 128])