In [1]:
import pickle
import numpy as np
import torch
from torch.utils.data import Dataset
import pandas as pd

class CarlaDataset(Dataset):
    def __init__(self, traj_path=None, map_path=None):
        # Use provided paths or default to None
        self.traj_path = traj_path or './traj_data.pkl'
        self.map_path = map_path or './map_dict.pkl'
        self.data = pickle.load(open(self.traj_path, 'rb'))
        self.map_dict = pickle.load(open(self.map_path, 'rb'))

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        traj, label = self.data[index]
        traj = torch.tensor(traj.copy())
        label = torch.tensor(label.copy())

        # 获取target第19帧的坐标
        target_coordinate = traj[0][-1][:2].numpy().tolist()
        # 计算每个字典中坐标与 target_coordinate 的欧氏距离
        distances = {key: np.linalg.norm(np.array(target_coordinate) - np.array(key)) for key in self.map_dict.keys()}
        # 找到最小距离对应的字典的value
        closest_coordinate = min(distances, key=distances.get)
        closest_value = self.map_dict[closest_coordinate]

        lane_list = [torch.tensor(df[['x', 'y']].values, dtype=torch.float32) for df in closest_value]


        data = {'feat': traj, 'lane_list': lane_list, 'label': label}
        return data

    def __iter__(self):
        return iter(range(len(self)))  # Make the dataset iterable

# Example of usage
dataset = CarlaDataset()
from sklearn.model_selection import train_test_split
train_data, test_data = train_test_split(dataset, test_size=0.2, random_state=42)


ModuleNotFoundError: No module named 'pandas.core.indexes.numeric'

In [3]:
waypoints_xy = pickle.load(open('waypoints_xy.pkl', 'rb'))

In [4]:
from scipy.spatial import cKDTree
import numpy as np
import matplotlib.pyplot as plt

waypoints_xy = np.array(waypoints_xy)
tree = cKDTree(waypoints_xy)

idx = 14
target_coordinate = train_data[idx]['feat'][0][-1][:2].numpy().tolist()

# 在目标坐标周围20m的范围内进行查询
radius = 30.0
target_coordinateinate = train_data[idx]['feat'][0][-1][:2].numpy().tolist()
indices = tree.query_ball_point(target_coordinateinate, radius)

# 获取符合条件的坐标点
filtered_coords = waypoints_xy[indices]
print(len(filtered_coords))
# 绘制所有坐标点
# plt.scatter(waypoints_xy[:, 0], waypoints_xy[:, 1], label='All Points', alpha=0.5, marker='.')

# 绘制目标坐标
plt.scatter(target_coordinate[0], target_coordinate[1], color='red', label='Target', marker='o', s=100)

# 绘制在目标周围20m范围内的坐标点
plt.scatter(filtered_coords[:, 0], filtered_coords[:, 1], color='grey', label='Within 20m', marker='x', s=100)



# plt.plot(train_data[idx]['feat'][0][:, 0], train_data[idx]['feat'][0][:, 1], '-o', c='red')
# plt.plot(train_data[idx]['label'][:, 0], train_data[idx]['label'][:, 1], '-o', c='blue')


# 添加标签和标题
plt.xlabel('X')
plt.ylabel('Y')
plt.title('Spatial Filtering of Coordinates')
plt.legend()
plt.axis('equal')
# 显示图形
# plt.show()
plt.savefig('Spatial Filtering of Coordinates.png')

NameError: name 'train_data' is not defined

dataset
!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

In [2]:
import pickle
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from scipy.spatial import cKDTree

class CarlaDataset(Dataset):
    def __init__(self, traj_path=None, map_path=None):
        # Use provided paths or default to None
        self.traj_path = traj_path or './trajectories_trafficLight.pkl'
        self.waypoints_path = map_path or './waypoints_xy.pkl'
        self.data = pickle.load(open(self.traj_path, 'rb'))
        self.waypoints_xy = np.array(pickle.load(open(self.waypoints_path, 'rb')))
        self.tree = cKDTree(self.waypoints_xy)
        self.radius = 30.0

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        traj, label = self.data[index]
        traj = torch.tensor(traj.copy())
        label = torch.tensor(label.copy())

        target_coordinateinate = traj[0][-1][:2].numpy().tolist()
        indices = self.tree.query_ball_point(target_coordinateinate, self.radius)
        # 获取符合条件的坐标点
        filtered_coords = self.waypoints_xy[indices]
        nbr_waypoints = torch.tensor(filtered_coords)
        
        n = len(traj)
        # 初始化一个大小为（n，2）的tensor
        ctrs = torch.zeros((n, 2))
        # 遍历前n个轨迹
        for i in range(n):
            last_row = traj[i][-1][:2].clone().detach().requires_grad_(True)
            ctrs[i] = last_row

        data = {'feat': traj, 'ctrs': ctrs, 'nbr_waypoints': nbr_waypoints, 'label': label}
        return data


    def __iter__(self):
        return iter(range(len(self)))  # Make the dataset iterable

# Example of usage
dataset = CarlaDataset()
from sklearn.model_selection import train_test_split
train_data, test_data = train_test_split(dataset, test_size=0.2, random_state=42)


In [3]:
batch_size = 32

def from_numpy(data):
    """Recursively transform numpy.ndarray to torch.Tensor.
    """
    if isinstance(data, dict):
        for key in data.keys():
            data[key] = from_numpy(data[key])
    if isinstance(data, list) or isinstance(data, tuple):
        data = [from_numpy(x) for x in data]
    if isinstance(data, np.ndarray):
        """Pytorch now has bool type."""
        data = torch.from_numpy(data)
    return data


def collate_fn(batch):
    batch = from_numpy(batch)
    return_batch = dict()
    # Batching by use a list for non-fixed size
    for key in batch[0].keys():
        return_batch[key] = [x[key] for x in batch]
    return return_batch

def worker_init_fn(pid):
    np_seed = int(pid)
    np.random.seed(np_seed)
    random_seed = np.random.randint(2 ** 32 - 1)
    random.seed(random_seed)

train_loader = DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=collate_fn,
    worker_init_fn=worker_init_fn,
    pin_memory=True,
    drop_last=True,
)



In [4]:
for batch in train_loader:
    print(batch.keys())
    print(len(batch['feat']), len(batch['ctrs']), len(batch['nbr_waypoints']), len(batch['label']))
    print(batch['feat'][0].shape, batch['nbr_waypoints'][0].shape, batch['label'][0].shape)
    break

dict_keys(['feat', 'ctrs', 'nbr_waypoints', 'label'])
32 32 32 32
torch.Size([3, 20, 5]) torch.Size([660, 2]) torch.Size([30, 3])


In [19]:
from layers import Conv1d, Res1d, Linear, LinearRes, Null
import torch
from torch import Tensor, nn
from torch.nn import functional as F

In [20]:
class ActorNet(nn.Module):
    """
    Actor feature extractor with Conv1D
    """
    def __init__(self):
        super(ActorNet, self).__init__()
        norm = "GN"
        ng = 1
        n_in = 3
        n_out = [32, 64, 128]
        blocks = [Res1d, Res1d, Res1d]
        num_blocks = [2, 2, 2]

        groups = []
        for i in range(len(num_blocks)):
            group = []
            if i == 0:
                group.append(blocks[i](n_in, n_out[i], norm=norm, ng=ng))
            else:
                group.append(blocks[i](n_in, n_out[i], stride=2, norm=norm, ng=ng))

            for j in range(1, num_blocks[i]):
                group.append(blocks[i](n_out[i], n_out[i], norm=norm, ng=ng))
            groups.append(nn.Sequential(*group))
            n_in = n_out[i]
        self.groups = nn.ModuleList(groups)

        n = 128
        lateral = []
        for i in range(len(n_out)):
            lateral.append(Conv1d(n_out[i], n, norm=norm, ng=ng, act=False))
        self.lateral = nn.ModuleList(lateral)

        self.output = Res1d(n, n, norm=norm, ng=ng)

    def forward(self, actors):
        out = actors

        outputs = []
        for i in range(len(self.groups)):
            out = self.groups[i](out)
            outputs.append(out)

        out = self.lateral[-1](outputs[-1])
        for i in range(len(outputs) - 2, -1, -1):
            out = F.interpolate(out, scale_factor=2, mode="linear", align_corners=False)
            out += self.lateral[i](outputs[i])

        out = self.output(out)[:, :, -1]
        return out


In [21]:
def actor_gather(actors):
    batch_size = len(actors)
    num_actors = [len(x) for x in actors]

    actors = [x.transpose(1, 2) for x in actors]
    actors = torch.cat(actors, 0)

    actor_idcs = []
    count = 0
    for i in range(batch_size):
        idcs = torch.arange(count, count + num_actors[i]).to(actors.device)
        actor_idcs.append(idcs)
        count += num_actors[i]
    return actors, actor_idcs



In [22]:
class MapNet(nn.Module):
    def __init__(self):
        super(MapNet, self).__init__()
        input_size = 2
        output_size = 128
        self.fc1 = nn.Linear(input_size, 64)
        self.fc2 = nn.Linear(64, 64)
        self.fc3 = nn.Linear(64, 128)
        
    def forward(self, x):
        x = x.float()
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x


def lane_gather(lanes):
    batch_size = len(lanes)
    num_lanes = [len(x) for x in lanes]
    lanes = torch.cat(lanes, 0)
    lane_idcs = []
    count = 0
    for i in range(batch_size):
        idcs = torch.arange(count, count + num_lanes[i]).to(lanes.device)
        lane_idcs.append(idcs)
        count += num_lanes[i]
    return lanes, lane_idcs



In [23]:
from math import gcd
class Att(nn.Module):
    """
    Attention block to pass context nodes information to target nodes
    This is used in Actor2Map, Actor2Actor, Map2Actor and Map2Map
    """
    def __init__(self) -> None:
        super(Att, self).__init__()
        norm = "GN"
        ng = 1
        n_agt = 128
        n_ctx = 128
        self.dist = nn.Sequential(
            nn.Linear(2, n_ctx),
            nn.ReLU(inplace=True),
            Linear(n_ctx, n_ctx, norm=norm, ng=ng),
        )

        self.query = Linear(n_agt, n_ctx, norm=norm, ng=ng)

        self.ctx = nn.Sequential(
            Linear(3 * n_ctx, n_agt, norm=norm, ng=ng),
            nn.Linear(n_agt, n_agt, bias=False),
        )

        self.agt = nn.Linear(n_agt, n_agt, bias=False)
        self.norm = nn.GroupNorm(gcd(ng, n_agt), n_agt)
        self.linear = Linear(n_agt, n_agt, norm=norm, ng=ng, act=False)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, agts, agt_idcs, agt_ctrs, ctx, ctx_idcs, ctx_ctrs):
        res = agts
        if len(ctx) == 0:
            agts = self.agt(agts) 
            agts = self.relu(agts)    
            agts = self.linear(agts)
            agts += res
            agts = self.relu(agts)
            return agts

        batch_size = len(agt_idcs)
        hi, wi = [], []
        hi_count, wi_count = 0, 0
        for i in range(batch_size):
            dist = agt_ctrs[i].view(-1, 1, 2) - ctx_ctrs[i].view(1, -1, 2)
            dist = torch.sqrt((dist ** 2).sum(2))
            mask = dist <= 100.0

            idcs = torch.nonzero(mask, as_tuple=False)
            if len(idcs) == 0:
                continue

            hi.append(idcs[:, 0] + hi_count)
            wi.append(idcs[:, 1] + wi_count)
            hi_count += len(agt_idcs[i])
            wi_count += len(ctx_idcs[i])
        hi = torch.cat(hi, 0)
        wi = torch.cat(wi, 0)

        agt_ctrs = torch.cat(agt_ctrs, 0)
        ctx_ctrs = torch.cat(ctx_ctrs, 0)
        dist = agt_ctrs[hi] - ctx_ctrs[wi]
        
        # 将输入张量和权重转换为相同的数据类型
        dist = self.dist(dist.to(agts.dtype))
        query = self.query(agts[hi])
        
        # 将输入张量和权重转换为相同的数据类型
        ctx = ctx[wi].to(agts.dtype)
        ctx = torch.cat((dist, query, ctx), 1)
        ctx = self.ctx(ctx)

        agts = self.agt(agts)
        agts.index_add_(0, hi, ctx)
        agts = self.norm(agts)
        agts = self.relu(agts)

        agts = self.linear(agts)
        agts += res
        agts = self.relu(agts)
        return agts

    


In [24]:

class A2A(nn.Module):
    """
    The actor to actor block performs interactions among actors.
    """
    def __init__(self):
        super(A2A, self).__init__()
        norm = "GN"
        ng = 1

        n_actor = 128
        n_map = 128

        att = []
        for i in range(2):
            att.append(Att())
        self.att = nn.ModuleList(att)

    def forward(self, actors, actor_idcs, actor_ctrs):
        for i in range(len(self.att)):
            actors = self.att[i](
                actors,
                actor_idcs,
                actor_ctrs,
                actors,
                actor_idcs,
                actor_ctrs,
            )
        return actors
        


In [25]:
class A2M(nn.Module):
    def __init__(self):
        super(A2M, self).__init__()
        n_map = 128
        norm = "GN"
        ng = 1

        att = []
        for i in range(2):
            att.append(Att())
        self.att = nn.ModuleList(att)

    def forward(self, lanes, lane_idcs, lane_ctrs, actors, actor_idcs, actor_ctrs):
        feat_list = [] 
        for i in range(len(self.att)):
            feat = self.att[i](
                lanes,
                lane_idcs,
                lane_ctrs,
                actors,
                actor_idcs,
                actor_ctrs
            )

        return feat




In [26]:
class M2A(nn.Module):
    """
    The lane to actor block fuses updated
        map information from lane nodes to actor nodes
    """
    def __init__(self):
        super(M2A, self).__init__()
        norm = "GN"
        ng = 1

        n_actor = 128
        n_map = 128

        att = []
        for i in range(2):
            att.append(Att())
        self.att = nn.ModuleList(att)

    def forward(self, actors, actor_idcs, actor_ctrs, lanes, lane_idcs, lane_ctrs):
        for i in range(len(self.att)):
            actors = self.att[i](
                actors,
                actor_idcs,
                actor_ctrs,
                lanes,
                lane_idcs,
                lane_ctrs
            )
        return actors

In [27]:
class AttDest(nn.Module):
    def __init__(self):
        super(AttDest, self).__init__()
        norm = "GN"
        ng = 1
        n_agt = 128
        self.dist = nn.Sequential(
            nn.Linear(2, n_agt),
            nn.ReLU(inplace=True),
            Linear(n_agt, n_agt, norm=norm, ng=ng),
        )

        self.agt = Linear(2 * n_agt, n_agt, norm=norm, ng=ng)

    def forward(self, agts, agt_ctrs, dest_ctrs):
        n_agt = agts.size(1)
        num_mods = dest_ctrs.size(1)

        dist = (agt_ctrs.unsqueeze(1) - dest_ctrs).view(-1, 2)
        dist = self.dist(dist)
        agts = agts.unsqueeze(1).repeat(1, num_mods, 1).view(-1, n_agt)

        agts = torch.cat((dist, agts), 1)
        agts = self.agt(agts)
        return agts



class PredNet(nn.Module):
    """
    Final motion forecasting with Linear Residual block
    """
    def __init__(self):
        super(PredNet, self).__init__()
        norm = "GN"
        ng = 1
        n_actor = 128
        embedding_dim = 128
        h_dim = 128
        num_layers = 1

        self.spatial_embedding0 = nn.Linear(2, embedding_dim)
        self.decoder0 = nn.LSTM(n_actor, n_actor, num_layers)
        self.hidden2pos0 = nn.Linear(n_actor, 2*30)

        self.spatial_embedding1 = nn.Linear(2, embedding_dim)
        self.decoder1 = nn.LSTM(n_actor, n_actor, num_layers)
        self.hidden2pos1 = nn.Linear(n_actor, 2*30)
        
        self.spatial_embedding2 = nn.Linear(2, embedding_dim)
        self.decoder2 = nn.LSTM(n_actor, n_actor, num_layers)
        self.hidden2pos2 = nn.Linear(n_actor, 2*30)
        
        self.spatial_embedding3 = nn.Linear(2, embedding_dim)
        self.decoder3 = nn.LSTM(n_actor, n_actor, num_layers)
        self.hidden2pos3 = nn.Linear(n_actor, 2*30)
        
        self.spatial_embedding4 = nn.Linear(2, embedding_dim)
        self.decoder4 = nn.LSTM(n_actor, n_actor, num_layers)
        self.hidden2pos4 = nn.Linear(n_actor, 2*30)
        
        self.spatial_embedding5 = nn.Linear(2, embedding_dim)
        self.decoder5 = nn.LSTM(n_actor, n_actor, num_layers)
        self.hidden2pos5 = nn.Linear(n_actor, 2*30)


        self.att_dest = AttDest()
        self.cls = nn.Sequential(
            LinearRes(n_actor, n_actor, norm=norm, ng=ng), nn.Linear(n_actor, 1)
        )

    def forward(self, actors, actor_idcs, actor_ctrs):
        embedding_dim = h_dim = 128
        decoder_h_dim = 128
        decoder_h = actors.view(-1, decoder_h_dim)
        decoder_h = torch.unsqueeze(decoder_h, 0)
        
        num_layers = 1
        batch = actors.size(0)
        decoder_c = torch.zeros(num_layers, batch, decoder_h_dim)
        state_tuple = (decoder_h, decoder_c)
            
        ctrs = torch.cat(actor_ctrs, 0)
        batch = ctrs.size(0)
        
        
        decoder_input0 = self.spatial_embedding0(ctrs)
        decoder_input0 = decoder_input0.view(1, batch, embedding_dim)
        output0, state_tuple0 = self.decoder0(decoder_input0, state_tuple)
        rel_pos0 = self.hidden2pos0(output0.view(-1, h_dim))
        
        decoder_input1 = self.spatial_embedding1(ctrs)
        decoder_input1 = decoder_input1.view(1, batch, embedding_dim)
        output1, state_tuple1 = self.decoder1(decoder_input1, state_tuple)
        rel_pos1 = self.hidden2pos1(output1.view(-1, h_dim))
        
        decoder_input2 = self.spatial_embedding2(ctrs)
        decoder_input2 = decoder_input2.view(1, batch, embedding_dim)
        output2, state_tuple2 = self.decoder2(decoder_input2, state_tuple)
        rel_pos2 = self.hidden2pos2(output2.view(-1, h_dim))
        
        decoder_input3 = self.spatial_embedding3(ctrs)
        decoder_input3 = decoder_input3.view(1, batch, embedding_dim)
        output3, state_tuple3 = self.decoder3(decoder_input3, state_tuple)
        rel_pos3 = self.hidden2pos3(output3.view(-1, h_dim))
        
        decoder_input4 = self.spatial_embedding4(ctrs)
        decoder_input4 = decoder_input4.view(1, batch, embedding_dim)
        output4, state_tuple4 = self.decoder4(decoder_input4, state_tuple)
        rel_pos4 = self.hidden2pos4(output4.view(-1, h_dim))
        
        decoder_input5 = self.spatial_embedding5(ctrs)
        decoder_input5 = decoder_input5.view(1, batch, embedding_dim)
        output5, state_tuple5 = self.decoder5(decoder_input5, state_tuple)
        rel_pos5 = self.hidden2pos5(output5.view(-1, h_dim))
        
        preds = []
        preds.append(rel_pos0)
        preds.append(rel_pos1)
        preds.append(rel_pos2)
        preds.append(rel_pos3)
        preds.append(rel_pos4)
        preds.append(rel_pos5)
            
            
            
        reg = torch.cat([x.unsqueeze(1) for x in preds], 1)
        reg = reg.view(reg.size(0), reg.size(1), -1, 2)

        for i in range(len(actor_idcs)):
            idcs = actor_idcs[i]
            ctrs = actor_ctrs[i].view(-1, 1, 1, 2)
            reg[idcs] = reg[idcs] + ctrs

        dest_ctrs = reg[:, :, -1].detach()
        feats = self.att_dest(actors, torch.cat(actor_ctrs, 0), dest_ctrs)
        cls = self.cls(feats).view(-1, 6)

        cls, sort_idcs = cls.sort(1, descending=True)
        row_idcs = torch.arange(len(sort_idcs)).long().to(sort_idcs.device)
        row_idcs = row_idcs.view(-1, 1).repeat(1, sort_idcs.size(1)).view(-1)
        sort_idcs = sort_idcs.view(-1)
        reg = reg[row_idcs, sort_idcs].view(cls.size(0), cls.size(1), -1, 2)

        out = dict()
        out["cls"], out["reg"] = [], []
        for i in range(len(actor_idcs)):
            idcs = actor_idcs[i]
            ctrs = actor_ctrs[i].view(-1, 1, 1, 2)
            out["cls"].append(cls[idcs])
            out["reg"].append(reg[idcs])
        return out


In [28]:
import torch
import torch.nn as nn

class AttDest(nn.Module):
    def __init__(self, n_agt=128, norm="GN", ng=1):
        super(AttDest, self).__init__()
        self.dist = nn.Sequential(
            nn.Linear(2, n_agt),
            nn.ReLU(inplace=True),
            Linear(n_agt, n_agt, norm=norm, ng=ng),
        )
        self.agt = Linear(2 * n_agt, n_agt, norm=norm, ng=ng)

    def forward(self, agts, agt_ctrs, dest_ctrs):
        n_agt = agts.size(1)
        num_mods = dest_ctrs.size(1)

        dist = (agt_ctrs.unsqueeze(1) - dest_ctrs).view(-1, 2)
        dist = self.dist(dist)
        agts = agts.unsqueeze(1).repeat(1, num_mods, 1).view(-1, n_agt)

        agts = torch.cat((dist, agts), 1)
        agts = self.agt(agts)
        return agts


class PredNet(nn.Module):
    def __init__(self, n_actor=128, embedding_dim=128, h_dim=128, num_layers=1, norm="GN", ng=1):
        super(PredNet, self).__init__()
        self.att_dest = AttDest()
        self.cls = nn.Sequential(
            LinearRes(n_actor, n_actor, norm=norm, ng=ng), nn.Linear(n_actor, 1)
        )

        self.decoders = nn.ModuleList()
        self.spatial_embeddings = nn.ModuleList()
        self.hidden2pos = nn.ModuleList()

        for _ in range(6):
            self.spatial_embeddings.append(nn.Linear(2, embedding_dim))
            self.decoders.append(nn.LSTM(n_actor, n_actor, num_layers))
            self.hidden2pos.append(nn.Linear(n_actor, 2 * 30))

    def forward(self, actors, actor_idcs, actor_ctrs):
        embedding_dim = h_dim = 128
        decoder_h_dim = 128
        decoder_h = actors.view(-1, decoder_h_dim)
        decoder_h = torch.unsqueeze(decoder_h, 0)

        num_layers = 1
        batch = actors.size(0)
        decoder_c = torch.zeros(num_layers, batch, decoder_h_dim)
        state_tuple = (decoder_h, decoder_c)

        ctrs = torch.cat(actor_ctrs, 0)
        batch = ctrs.size(0)

        preds = []

        for i in range(6):
            spatial_embedding = self.spatial_embeddings[i]
            decoder = self.decoders[i]
            hidden2pos = self.hidden2pos[i]

            decoder_input = spatial_embedding(ctrs)
            decoder_input = decoder_input.view(1, batch, embedding_dim)
            output, state_tuple = decoder(decoder_input, state_tuple)
            rel_pos = hidden2pos(output.view(-1, h_dim))
            preds.append(rel_pos)

        reg = torch.cat([x.unsqueeze(1) for x in preds], 1)
        reg = reg.view(reg.size(0), reg.size(1), -1, 2)

        for i in range(len(actor_idcs)):
            idcs = actor_idcs[i]
            ctrs = actor_ctrs[i].view(-1, 1, 1, 2)
            reg[idcs] = reg[idcs] + ctrs

        dest_ctrs = reg[:, :, -1].detach()
        feats = self.att_dest(actors, torch.cat(actor_ctrs, 0), dest_ctrs)
        cls = self.cls(feats).view(-1, 6)

        cls, sort_idcs = cls.sort(1, descending=True)
        row_idcs = torch.arange(len(sort_idcs)).long().to(sort_idcs.device)
        row_idcs = row_idcs.view(-1, 1).repeat(1, sort_idcs.size(1)).view(-1)
        sort_idcs = sort_idcs.view(-1)
        reg = reg[row_idcs, sort_idcs].view(cls.size(0), cls.size(1), -1, 2)

        out = dict()
        out["cls"], out["reg"] = [], []
        for i in range(len(actor_idcs)):
            idcs = actor_idcs[i]
            ctrs = actor_ctrs[i].view(-1, 1, 1, 2)
            out["cls"].append(cls[idcs])
            out["reg"].append(reg[idcs])
        return out


In [29]:
class AttDest(nn.Module):
    def __init__(self):
        super(AttDest, self).__init__()
        norm = "GN"
        ng = 1
        n_agt = 128
        self.dist = nn.Sequential(
            nn.Linear(2, n_agt),
            nn.ReLU(inplace=True),
            Linear(n_agt, n_agt, norm=norm, ng=ng),
        )

        self.agt = Linear(2 * n_agt, n_agt, norm=norm, ng=ng)

    def forward(self, agts, agt_ctrs, dest_ctrs):
        n_agt = agts.size(1)
        num_mods = dest_ctrs.size(1)

        dist = (agt_ctrs.unsqueeze(1) - dest_ctrs).view(-1, 2)
        dist = self.dist(dist)
        agts = agts.unsqueeze(1).repeat(1, num_mods, 1).view(-1, n_agt)

        agts = torch.cat((dist, agts), 1)
        agts = self.agt(agts)
        return agts


class PredNet(nn.Module):
    """
    Final motion forecasting with Linear Residual block
    """
    def __init__(self):
        super(PredNet, self).__init__()
        norm = "GN"
        ng = 1
        n_actor = 128
        pred = []
        for i in range(6):
            pred.append(
                nn.Sequential(
                    LinearRes(n_actor, n_actor, norm=norm, ng=ng),
                    nn.Linear(n_actor, 2 * 30),
                )
            )
        self.pred = nn.ModuleList(pred)

        self.att_dest = AttDest()
        self.cls = nn.Sequential(
            LinearRes(n_actor, n_actor, norm=norm, ng=ng), nn.Linear(n_actor, 1)
        )

    def forward(self, actors, actor_idcs, actor_ctrs):
        preds = []
        for i in range(len(self.pred)):
            preds.append(self.pred[i](actors))
        reg = torch.cat([x.unsqueeze(1) for x in preds], 1)
        reg = reg.view(reg.size(0), reg.size(1), -1, 2)

        for i in range(len(actor_idcs)):
            idcs = actor_idcs[i]
            ctrs = actor_ctrs[i].view(-1, 1, 1, 2)
            reg[idcs] = reg[idcs] + ctrs

        dest_ctrs = reg[:, :, -1].detach()
        feats = self.att_dest(actors, torch.cat(actor_ctrs, 0), dest_ctrs)
        cls = self.cls(feats).view(-1, 6)

        cls, sort_idcs = cls.sort(1, descending=True)
        row_idcs = torch.arange(len(sort_idcs)).long().to(sort_idcs.device)
        row_idcs = row_idcs.view(-1, 1).repeat(1, sort_idcs.size(1)).view(-1)
        sort_idcs = sort_idcs.view(-1)
        reg = reg[row_idcs, sort_idcs].view(cls.size(0), cls.size(1), -1, 2)

        out = dict()
        out["cls"], out["reg"] = [], []
        for i in range(len(actor_idcs)):
            idcs = actor_idcs[i]
            ctrs = actor_ctrs[i].view(-1, 1, 1, 2)
            out["cls"].append(cls[idcs])
            out["reg"].append(reg[idcs])
        return out

In [30]:
actor_net = ActorNet()
map_net = MapNet()
a2m = A2M()
m2a = M2A()
a2a = A2A()

actors, actor_idcs = actor_gather(batch['feat'])   # torch.Size([116, 3, 20])
actors = actor_net(actors)
lanes, lane_idcs = lane_gather(batch['nbr_waypoints'])
lanes = map_net(lanes)
# lanes = torch.sum(lanes, dim=0).view(1, -1)
lanes = a2m(lanes, lane_idcs, batch['nbr_waypoints'], actors, actor_idcs, batch['ctrs'])
actors = m2a(actors, actor_idcs, batch['ctrs'], lanes, lane_idcs, batch['nbr_waypoints'])
actors = a2a(actors, actor_idcs, batch['ctrs'])

In [31]:
pred_net = PredNet()
out = pred_net(actors, actor_idcs, batch['ctrs'])


In [32]:
class PredLoss(nn.Module):
    def __init__(self):
        super(PredLoss, self).__init__()
        self.reg_loss = nn.SmoothL1Loss(reduction="sum")

    def forward(self, out, gt_preds, has_preds):
        cls, reg = out["cls"], out["reg"]
        cls = torch.cat([x[0].unsqueeze(0) for x in cls], 0)
        reg = torch.cat([x[0].unsqueeze(0) for x in reg], 0)
        gt_preds = torch.cat([x.unsqueeze(0) for x in gt_preds], 0)
        has_preds = torch.cat([x.unsqueeze(0) for x in has_preds], 0)
        has_preds = (has_preds == 1)
        
        loss_out = dict()
        zero = 0.0 * (cls.sum() + reg.sum())
        loss_out["cls_loss"] = zero.clone()
        loss_out["num_cls"] = 0
        loss_out["reg_loss"] = zero.clone()
        loss_out["num_reg"] = 0

        num_mods, num_preds = 6, 30

        last = has_preds.float() + 0.1 * torch.arange(num_preds).float().to(
            has_preds.device
        ) / float(num_preds)
        max_last, last_idcs = last.max(1)
        mask = max_last > 1.0

        cls = cls[mask]
        reg = reg[mask]
        gt_preds = gt_preds[mask]
        has_preds = has_preds[mask]
        last_idcs = last_idcs[mask]

        row_idcs = torch.arange(len(last_idcs)).long().to(last_idcs.device)
        dist = []
        for j in range(num_mods):
            dist.append(
                torch.sqrt(
                    (
                        (reg[row_idcs, j, last_idcs] - gt_preds[row_idcs, last_idcs])
                        ** 2
                    ).sum(1)
                )
            )
        dist = torch.cat([x.unsqueeze(1) for x in dist], 1)
        min_dist, min_idcs = dist.min(1)
        row_idcs = torch.arange(len(min_idcs)).long().to(min_idcs.device)

        mgn = cls[row_idcs, min_idcs].unsqueeze(1) - cls
        mask0 = (min_dist < 2.0).view(-1, 1)
        mask1 = dist - min_dist.view(-1, 1) > 0.2
        mgn = mgn[mask0 * mask1]
        mask = mgn < 0.2
        coef = 1.0
        loss_out["cls_loss"] += coef * (
            0.2 * mask.sum() - mgn[mask].sum()
        )
        loss_out["num_cls"] += mask.sum().item()

        reg = reg[row_idcs, min_idcs]
        coef = 1.0
        loss_out["reg_loss"] += coef * self.reg_loss(
            reg[has_preds], gt_preds[has_preds]
        )
        loss_out["num_reg"] += has_preds.sum().item()
        return loss_out


class Loss(nn.Module):
    def __init__(self):
        super(Loss, self).__init__()
        self.pred_loss = PredLoss()

    def forward(self, out, data):
        loss_out = self.pred_loss(out, [label[:, :2] for label in data["label"]], [label[:, 2] for label in data["label"]])
        loss_out["loss"] = loss_out["cls_loss"] / (
            loss_out["num_cls"] + 1e-10
        ) + loss_out["reg_loss"] / (loss_out["num_reg"] + 1e-10)
        return loss_out

In [33]:
loss = Loss()
loss_out = loss(out, batch)

In [34]:
loss_out

{'cls_loss': tensor(16.8879, grad_fn=<AddBackward0>),
 'num_cls': 63,
 'reg_loss': tensor(6498.7559, grad_fn=<AddBackward0>),
 'num_reg': 960,
 'loss': tensor(7.0376, grad_fn=<AddBackward0>)}

In [35]:
class PostProcess(nn.Module):
    def __init__(self):
        super(PostProcess, self).__init__()

    def forward(self, out, data):
        post_out = dict()
        post_out["preds"] = [x[0:1].detach().cpu().numpy() for x in out["reg"]]
        post_out["gt_preds"] = [x[0:1].numpy() for x in [label[:, :2] for label in data["label"]]]
        post_out["has_preds"] = [x[0:1].numpy() for x in [label[:, 2] for label in data["label"]]]
        return post_out

    def append(self, metrics, loss_out, post_out):
        if len(metrics.keys()) == 0:
            for key in loss_out:
                if key != "loss":
                    metrics[key] = 0.0

            for key in post_out:
                metrics[key] = []

        for key in loss_out:
            if key == "loss":
                continue
            if isinstance(loss_out[key], torch.Tensor):
                metrics[key] += loss_out[key].item()
            else:
                metrics[key] += loss_out[key]

        for key in post_out:
            metrics[key] += post_out[key]
        return metrics

    def display(self, metrics, dt, epoch, lr=None):
        """Every display-iters print training/val information"""
        if lr is not None:
            print("Epoch %3.3f, lr %.5f, time %3.2f" % (epoch, lr, dt))
        else:
            print(
                "************************* Validation, time %3.2f *************************"
                % dt
            )

        cls = metrics["cls_loss"] / (metrics["num_cls"] + 1e-10)
        reg = metrics["reg_loss"] / (metrics["num_reg"] + 1e-10)
        loss = cls + reg

        preds = np.concatenate(metrics["preds"], 0)
        gt_preds = np.concatenate(metrics["gt_preds"], 0)
        has_preds = np.concatenate(metrics["has_preds"], 0)
        ade1, fde1, ade, fde, min_idcs = pred_metrics(preds, gt_preds, has_preds)

        print(
            "loss %2.4f %2.4f %2.4f, ade1 %2.4f, fde1 %2.4f, ade %2.4f, fde %2.4f"
            % (loss, cls, reg, ade1, fde1, ade, fde)
        )
        print()



def pred_metrics(preds, gt_preds, has_preds):
    assert has_preds.all()
    preds = np.asarray(preds, np.float32)
    gt_preds = np.asarray(gt_preds, np.float32)

    print(preds.shape, gt_preds.shape)


    """batch_size x num_mods x num_preds"""
    err = np.sqrt(((preds - np.expand_dims(gt_preds, (1, 2))) ** 2).sum(3))

    ade1 = err[:, 0].mean()
    fde1 = err[:, 0, -1].mean()

    min_idcs = err[:, :, -1].argmin(1)
    row_idcs = np.arange(len(min_idcs)).astype(np.int64)
    err = err[row_idcs, min_idcs]
    ade = err.mean()
    fde = err[:, -1].mean()
    return ade1, fde1, ade, fde, min_idcs

In [36]:
post_process = PostProcess()
metrics = dict()
post_out = post_process(out, batch)
post_process.append(metrics, loss_out, post_out)

{'cls_loss': 16.88785743713379,
 'num_cls': 63.0,
 'reg_loss': 6498.755859375,
 'num_reg': 960.0,
 'preds': [array([[[[-36.528793, -95.908035],
           [-34.525288, -94.152084],
           [-33.72307 , -95.67411 ],
           [-35.0533  , -94.56317 ],
           [-35.023643, -92.53043 ],
           [-31.937307, -96.144455],
           [-35.088024, -94.98214 ],
           [-35.34944 , -95.28975 ],
           [-33.73517 , -91.571396],
           [-33.330284, -94.350555],
           [-33.597202, -93.10538 ],
           [-33.854527, -94.062065],
           [-37.00071 , -95.42738 ],
           [-35.52767 , -96.109856],
           [-34.013145, -94.02513 ],
           [-34.03819 , -95.23058 ],
           [-32.39424 , -92.274666],
           [-33.735966, -94.314514],
           [-34.79856 , -96.21058 ],
           [-35.833378, -95.7268  ],
           [-34.953117, -95.54326 ],
           [-33.28731 , -92.939445],
           [-32.48859 , -93.625435],
           [-35.080463, -93.8114  ],
     

In [37]:
import time
dt = time.time()
# metrics = sync(metrics)
post_process.display(metrics, dt, 10)
start_time = time.time()
metrics = dict()


************************* Validation, time 1699959050.84 *************************
(32, 6, 30, 2) (32, 2)
loss 7.0376 0.2681 6.7695, ade1 1.9235, fde1 2.1752, ade 1.7780, fde 1.3248



In [42]:
from Net import Net
net = Net()
print(net)

Net(
  (actor_net): ActorNet(
    (groups): ModuleList(
      (0): Sequential(
        (0): Res1d(
          (conv1): Conv1d(3, 32, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
          (conv2): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
          (relu): ReLU(inplace=True)
          (bn1): GroupNorm(1, 32, eps=1e-05, affine=True)
          (bn2): GroupNorm(1, 32, eps=1e-05, affine=True)
          (downsample): Sequential(
            (0): Conv1d(3, 32, kernel_size=(1,), stride=(1,), bias=False)
            (1): GroupNorm(1, 32, eps=1e-05, affine=True)
          )
        )
        (1): Res1d(
          (conv1): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
          (conv2): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,), bias=False)
          (relu): ReLU(inplace=True)
          (bn1): GroupNorm(1, 32, eps=1e-05, affine=True)
          (bn2): GroupNorm(1, 32, eps=1e-05, affine=True)
        )
      )
    