In [21]:
import numpy as np 
import time 
import numba 
import open3d as o3d 
import torch 
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader 
from  torch.optim.lr_scheduler import StepLR
import open3d as o3d 
import numpy as np 
import os 
import struct 
from tqdm import tqdm 
import matplotlib.pyplot as plt 


In [22]:
class VoxelGenerator:
    def __init__(self,
                 voxel_size,
                 point_cloud_range,
                 max_num_points=35, 
                 max_voxels=20000):
        point_cloud_range = np.array(point_cloud_range, dtype=np.float32) 
        # [0, -40, -3, 70.4, 40, 1]
        voxel_size = np.array(voxel_size, dtype=np.float32)
        grid_size = (point_cloud_range[3:] - point_cloud_range[:3]) / voxel_size 
        grid_size = np.round(grid_size).astype(np.int64)


        self._voxel_size = voxel_size 
        self._point_cloud_range = point_cloud_range 
        self._max_num_points = max_num_points 
        self._max_voxels = max_voxels 
        self._grid_size = grid_size 

    def generate(self, points, max_voxels):
        return points_to_voxel(points, self._voxel_size, self._point_cloud_range, self._max_num_points, True, max_voxels)
    
    @property
    def voxel_size(self):
        return self._voxel_size 
    

    @property 
    def max_num_points_per_voxel(self):
        return self._max_num_points 
    
    @property 
    def point_cloud_range(self):
        return self._point_cloud_range 
    
    @property 
    def grid_size(self):
        return self._grid_size 
    


@numba.jit(nopython=True)
def _points_to_voxel_reverse_kernel(points, 
                                    voxel_size, 
                                    coors_range,
                                    num_points_per_voxel,
                                    coor_to_voxelidx,
                                    voxels,
                                    coors,
                                    max_points=35,
                                    max_voxels=20000):
    # put all the computation in one loop.
    # we should not create large array in main jit code, otherwise reduce performance.
    N = points.shape[0]
    #ndim = points.shape[1] - 1 
    ndim = 3 
    ndim_minus_1 = ndim -  1
    grid_size = (coors_range[3:] - coors_range[:3]) / voxel_size 
    #np.round(grid_size)
    #grid_size = np.round(grid_size).astype(np.int64)(np.int32)
    grid_size = np.round(grid_size, 0, grid_size).astype(np.int32)
    coor = np.zeros(shape=(3, ), dtype=np.int32)
    voxel_num = 0
    failed = False 
    for i in range(N):
        failed = False 
        for j in range(ndim):
            c = np.floor((points[i, j] - coors_range[j]) / voxel_size[j])
            if c < 0 or c >= grid_size[j]:
                failed = True 
                break 
            coor[ndim_minus_1 - j] = c 
        if failed:
            continue 
        voxelidx = coor_to_voxelidx[coor[0], coor[1], coor[2]]
        if voxelidx == -1:
            voxelidx = voxel_num 
            if voxel_num >= max_voxels:
                break 
            voxel_num += 1 
            coor_to_voxelidx[coor[0], coor[1], coor[2]] = voxelidx 
            coors[voxelidx] = coor 
        num = num_points_per_voxel[voxelidx]
        if num < max_points:
            voxels[voxelidx, num] = points[i]
            num_points_per_voxel[voxelidx] += 1
        
    return voxel_num, coors, voxels, num_points_per_voxel






@numba.jit(nopython=True)
def _points_to_voxel_kernel(points, 
                            voxel_size,
                            coors_range,
                            num_points_per_voxel, 
                            coor_to_voxelidx, 
                            voxels,
                            coors,
                            max_points=35,
                            max_voxels=20000):
    #need mutex if write cuda but numba.cuda don't support mutex.
    # in addition pytorch don't support cuda in dataloader.
    # put all the computation in  one loop 
    # We should not create large array in main jit code, otherwise decrease performance.
    N = points.shape[0]
    # ndim = points.shape[1] - 1
    ndim = 3 
    grid_size = (coors_range[3:] - coors_range[:3]) / voxel_size 
    #  grid_size = np.round(grid_size).astype(np.float64)(np.float32)
    grid_size = np.round(grid_size, 0, grid_size).astype(np.int32)

    lower_bounnd = coors_range[:3]
    upper_bound = coors_range[3:]
    coor = np.zeros(shape=(3,), dtype=np.int32)
    voxel_num = 0
    failed = False 
    for i in range(N):
        failed = False 
        for j in range(ndim):
            c = np.floor((points[i, j] - coors_range[j]) / voxel_size[j])
            if c < 0 or c >= grid_size[j]:
                failed = True 
                break 
            coor[j] = c 
        if failed:
            continue 
        voxelidx = coor_to_voxelidx[coor[0], coor[1], coor[2]]
        if voxelidx == -1:
            voxelidx = voxel_num 
            if voxel_num >= max_voxels:
                break 
            voxel_num += 1 
            coor_to_voxelidx[coor[0], coor[1], coor[2]] = voxelidx 
            coors[voxelidx] = coor 
        num = num_points_per_voxel[voxelidx]
        if num < max_points:
            voxels[voxelidx, num] = points[i]
            num_points_per_voxel[voxelidx] += 1 

    
    return voxel_num, coors, voxels, num_points_per_voxel 






def points_to_voxel(points, voxel_size, coors_range, max_points=35, reverse_index=True, max_voxels=20000):
    """
    convert kitti points (N, >=3)  to voxels. This version calculate everything in one loop. Now it takes only 4.2 ms (complete point cloud) with JIT.
    (Don't calculate other features).
    Note: this function in ubuntu seems faster than window 10.
    Args:
       points: [N, ndim] float tensor.points[:, :3] contain xyz points and points[:, 3:] contain other information such as reflectivity.
       voxel_size : [3] list/tuple or array. float.xyz indicate voxel size.
       coors_range: [6] list/tuple/array, float. indicate voxel range. format xyzxyz, minmax.
       max_points:int. indicate maximum points contain in voxel.
       reverse_index. boolean. indicate wheather return reverse coordinates.
              If points has xyz format and reverse index is True, output coordinate will be zyx format, but points in features always in xyz format.
       max_voxels: int. indicate maximum voxels this function create. 
             for SECOND 20000 is a good choice. You should shuffle points before call this function because max_voxels may drop some points.
    Returns:
       voxels: [M, max_points, ndim] float tensor. Only contain points.
       coordinates: [M, 3]. int32 tensor.
       num_points_per_voxel: [M] int32 tensor.
        
    """
    if not isinstance(voxel_size, np.ndarray):
        voxel_size = np.array(voxel_size, dytpe=points.dtype)
    if not isinstance(coors_range, np.ndarray):
        coors_range = np.array(coors_range, dtype=points.dtype)
    voxelmap_shape = (coors_range[3:] - coors_range[:3]) / voxel_size 
    voxelmap_shape = tuple(np.round(voxelmap_shape).astype(np.int32).tolist())
    if reverse_index:
        voxelmap_shape = voxelmap_shape[::-1]
    # don't create large array in jit(nopython=True) code.
    num_points_per_voxel = np.zeros(shape=(max_voxels, ), dtype=np.int32)
    coor_to_voxelidx = -np.ones(shape=voxelmap_shape, dtype=np.int32)
    voxels = np.zeros(shape=(max_voxels, max_points, points.shape[-1]), dtype=points.dtype)
    coors = np.zeros(shape=(max_voxels,  3), dtype=np.int32)
    if reverse_index:
        voxel_num, coors, voxels, num_points_per_voxel = _points_to_voxel_reverse_kernel(
            points, voxel_size, coors_range, num_points_per_voxel, coor_to_voxelidx, voxels, coors, max_points, max_voxels
        )
    else:
        voxel_num, coors, voxels, num_points_per_voxel = _points_to_voxel_kernel(
            points, voxel_size, coors_range, num_points_per_voxel, coor_to_voxelidx, voxels, coors, max_points, max_voxels
        )

    #coors = coors[:voxel_num]
    #voxels = voxels[:voxel_num]
    #num_points_per_voxel =  num_points_per_voxel[:voxel_num]
    # voxels[:, :, -3:] = voxels[:, :, :3] - voxels[:, :, :3].sum(axis=1, keepdims=True) /num_points_per_voxel.reshape(-1,, 1,1)
    return voxels, coors, num_points_per_voxel 





#voxel_generator = VoxelGenerator(voxel_size=[0.4, 0.4, 0.4], point_cloud_range=[0, -40, -3, 70.4, 40, 1])


#path = "/media/parvez_alam/Expansion/Dataset/training/SteeringData/1671529879.462312222.pcd"

#pcd = o3d.io.read_point_cloud(path)
#points = np.asarray(pcd.points)


#voxels,  coors, num_points_per_voxel = voxel_generator.generate(points, 20000)
#print("voxels shape = ", voxels.shape)
#print("coors shape = ", coors.shape)
#print("num_points_per_voxel shape = ", num_points_per_voxel.shape)



In [23]:
data_path = "/home/parvez_alam/Data/Kitti/Tracking/data_tracking_velodyne/training/velodyne"
calib_path = "/home/parvez_alam/Data/Kitti/Tracking/data_tracking_calib/training/calib"
label_path = "/home/parvez_alam/Data/Kitti/Tracking/data_tracking_label_2/training/label_02"


In [24]:
def load_kitti_calib(calib_file):
   
    with open(calib_file) as f_calib:
        lines = f_calib.readlines()
      
    P0 = np.array(lines[0].strip('\n').split()[1:], dtype=np.float32)
    P1 = np.array(lines[1].strip('\n').split()[1:], dtype=np.float32)
    P2 = np.array(lines[2].strip('\n').split()[1:], dtype=np.float32)
    P3 = np.array(lines[3].strip('\n').split()[1:], dtype=np.float32)
    R0_rect = np.array(lines[4].strip('\n').split()[1:], dtype=np.float32)
    Tr_velo_to_cam = np.array(lines[5].strip('\n').split()[1:], dtype=np.float32)
    Tr_imu_to_velo = np.array(lines[6].strip('\n').split()[1:], dtype=np.float32) 
   
    return {'P0': P0, 'P1':P1, 'P2':P2, 'P3':P3, 'R0_rect': R0_rect, 'Tr_velo_to_cam': Tr_velo_to_cam.reshape(3,4), 'Tr_imu_to_velo': Tr_imu_to_velo}   

In [25]:
def camera_coordinate_to_point_cloud(box3d, Tr):

    def project_cam2velo(cam, Tr):
        T = np.zeros([4,4], dtype=np.float32)
        T[:3, :] = Tr 
        T[3, 3] = 1 
     
        T_inv = np.linalg.inv(T) 
        lidar_loc_ = np.dot(T_inv, cam) 
        lidar_loc = lidar_loc_[:3]
      
        return lidar_loc.reshape(1,3) 
      
    def ry_to_rz(ry):
        angle = -ry - np.pi / 2
      
        if angle >= np.pi:
           angle -= np.pi 
        if angle < -np.pi:
           angle = 2 * np.pi + angle 
        return angle 
      
      
      
   
    h,w,l,tx,ty,tz,ry = [float(i) for i in box3d]
    cam = np.ones([4,1])
    cam[0] = tx 
    cam[1] = ty
    cam[2] = tz 
    t_lidar = project_cam2velo(cam, Tr) 
  
   
    Box = np.array([[-l/2, -l/2, l/2, l/2, -l/2, -l/2, l/2, l/2],
                   [w/2, -w/2, -w/2, w/2, w/2, -w/2, -w/2, w/2],
                   [0, 0, 0, 0, h, h,  h, h]])
                   
    rz = ry_to_rz(ry) 
   
    rotMat = np.array([[np.cos(rz), -np.sin(rz), 0.0],
                       [np.sin(rz), np.cos(rz), 0.0],
                       [0.0,          0.0,       1.0]])
   
                      
    velo_box = np.dot(rotMat, Box) 
     
    cornerPosInVelo = velo_box + np.tile(t_lidar, (8, 1)).T 
   
    box3d_corner = cornerPosInVelo.transpose() 
   
    return t_lidar, box3d_corner , rz 

In [26]:
class KITTIDATA(Dataset):
    def __init__(self, data_path, label_path, calib_path, train=True):
        self.data_path = data_path 
        self.label_path = label_path 
        self.calib_path = calib_path 

        self.scenes = sorted(os.listdir(self.data_path))
        self.labels = sorted(os.listdir(self.label_path))
        self.calibs = sorted(os.listdir(self.calib_path))

        self.train_scenes = self.scenes[0:17]
        self.test_scenes = self.scenes[17:21]

        # object for voxelization 
        self.voxel_generator = VoxelGenerator(voxel_size=[0.4, 0.4, 0.4], point_cloud_range=[0, -40, -3, 70.4, 40, 1]) 

        
        self.files = [] 


        if train == True :
            for i in tqdm(range(len(self.train_scenes))):
                pcd_file_path = os.path.join(self.data_path, self.scenes[i])
                calib_file = os.path.join(self.calib_path, self.calibs[i])
                label_file = os.path.join(self.label_path, self.labels[i])

                calibration = load_kitti_calib(calib_file)

                # get the total number of frames in particulat scenes 
                num_frames = len(os.listdir(pcd_file_path))

                bb_list = []            # store bounding boxex of complete scene 
    
                with open(label_file) as f_label:
                    lines = f_label.readlines()
       
                    for line in lines:
                       line = line.strip('\n').split() 
                       if line[2] != 'DontCare':
                            frame_index = line[0]             # frame number
                            track_id = line[1]                # track ID
                            category = line[2]                # class of the BB 
                            center, box3d_corner, rz = camera_coordinate_to_point_cloud(line[10:17], calibration['Tr_velo_to_cam'])
                            center = center[0] 
                            bb_list.append([frame_index, center, box3d_corner, rz, category, track_id])
                pcd_frames = sorted(os.listdir(pcd_file_path)) 
    
                for n in range(len(pcd_frames)):
                   pcd_path = os.path.join(pcd_file_path, pcd_frames[n])
                   size_float = 4 
                   list_pcd = [] 
                   with open(pcd_path, "rb") as f:
                      byte = f.read(size_float * 4) 
                      while byte:
                         x, y, z, intensity = struct.unpack("ffff", byte) 
                         list_pcd.append([x,y,z]) 
                         byte = f.read(size_float * 4) 

                   points = np.asarray(list_pcd) 

                   #pcd = o3d.geometry.PointCloud() 
                   #pcd.points = o3d.utility.Vector3dVector(points) 
       
       
                   bboxes = [] 
                   for k in range(len(bb_list)):
                       if int(bb_list[k][0]) == n : 
                          bboxes.append([bb_list[k][1], bb_list[k][2], bb_list[k][3], bb_list[k][4], bb_list[k][5]])   # [center, BB,rz, cetegory, track_id]

                    
                   sample = {}
                   sample['pcd'] = points 
                   sample['labels'] = bboxes 
                   self.files.append(sample)

        else:
            for i in tqdm(range(len(self.test_scenes))):
                pcd_file_path = os.path.join(self.data_path, self.scenes[i])
                calib_file = os.path.join(self.calib_path, self.calibs[i])
                label_file = os.path.join(self.label_path, self.labels[i])

                calibration = load_kitti_calib(calib_file)

                # get the total number of frames in particulat scenes 
                num_frames = len(os.listdir(pcd_file_path))

                bb_list = []            # store bounding boxex of complete scene 
    
                with open(label_file) as f_label:
                    lines = f_label.readlines()
       
                    for line in lines:
                       line = line.strip('\n').split() 
                       if line[2] != 'DontCare':
                          frame_index = line[0]             # frame number
                          track_id = line[1]                # track ID
                          category = line[2]                # class of the BB 
                          center, box3d_corner, rz = camera_coordinate_to_point_cloud(line[10:17], calibration['Tr_velo_to_cam'])
                          center = center[0] 
                          bb_list.append([frame_index, center, box3d_corner,rz, category, track_id])
                pcd_frames = sorted(os.listdir(pcd_file_path)) 
    
                for n in range(len(pcd_frames)):
                   pcd_path = os.path.join(pcd_file_path, pcd_frames[n])
                   size_float = 4 
                   list_pcd = [] 
                   with open(pcd_path, "rb") as f:
                      byte = f.read(size_float * 4) 
                      while byte:
                         x, y, z, intensity = struct.unpack("ffff", byte) 
                         list_pcd.append([x,y,z]) 
                         byte = f.read(size_float * 4) 

                   points = np.asarray(list_pcd) 

                   #pcd = o3d.geometry.PointCloud() 
                   #pcd.points = o3d.utility.Vector3dVector(points) 
       
       
                   bboxes = [] 
                   for k in range(len(bb_list)):
                       if int(bb_list[k][0]) == n : 
                          bboxes.append([bb_list[k][1], bb_list[k][2], bb_list[k][3], bb_list[k][4], bb_list[k][5]])   # [center, BB,rz,cate, track_id]

                    
                   sample = {}
                   sample['pcd'] = points 
                   sample['labels'] = bboxes 
                   self.files.append(sample)

        
    def __len__(self):
       return len(self.files) 
    
    def __getitem__(self, index):
       pcd = self.files[index]['pcd']
       bboxes = self.files[index]['labels']
       
       # voxelize the point cloud 
       voxels, coors, num_points_per_voxel = self.voxel_generator.generate(pcd,20000)

       return {'voxels':voxels, 'coors':coors, 'bboxes':bboxes}
    
                     

In [None]:
train_ds = KITTIDATA(data_path=data_path, label_path=label_path, calib_path=calib_path, train=True)
train_loader = DataLoader(dataset=train_ds, batch_size=1, shuffle=False )
print("total_frames=", len(train_ds)) 


  6%|██▌                                         | 1/17 [00:37<10:06, 37.93s/it]

In [7]:
device=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)  

cuda:0


In [8]:
class Tnet(nn.Module):
    def __init__(self, k=35):
        super().__init__() 
        self.k = k 
        self.conv1 = nn.Conv1d(k, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)
        self.conv3 = nn.Conv1d(128, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256) 
        self.fc3 = nn.Linear(256, k*k)

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512)
        self.bn5 = nn.BatchNorm1d(256)


    def forward(self, input):
        # input.shape == (bs, n, 3)
        bs = input.size(0)
        xb = F.relu(self.bn1(self.conv1(input)))
        xb = F.relu(self.bn2(self.conv2(xb)))
        xb = F.relu(self.bn3(self.conv3(xb)))
        pool = nn.MaxPool1d(xb.shape[-1]) (xb)
        flat = torch.squeeze(pool)
        xb = F.relu(self.bn4(self.fc1(flat)))
        xb = F.relu(self.bn5(self.fc2(xb)))


        # initialize as identity 
        init = torch.eye(self.k,  requires_grad=True).repeat(bs, 1,1)
        if xb.is_cuda:
            init = init.cuda() 
        matrix = self.fc3(xb).view(-1, self.k , self.k) + init 
        return matrix 
    

class Transform(nn.Module):
    def __init__(self, k=35):
        super().__init__() 
        self.k = k
        self.input_transform = Tnet(self.k)
        self.feature_transform = Tnet(k=64)
        self.conv1 = nn.Conv1d(self.k, 64, 1)

        self.conv2 = nn.Conv1d(64, 128, 1)
        self.conv3 = nn.Conv1d(128, 1024, 1)

        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
    
    def forward(self, input):
        matrix3x3 = self.input_transform(input)
        # batch matrix multiplication 
        xb = torch.bmm(torch.transpose(input, 1, 2), matrix3x3).transpose(1,2)

        xb =  F.relu(self.bn1(self.conv1(xb)))

        matrix64x64 = self.feature_transform(xb)
        xb = torch.bmm(torch.transpose(xb, 1,2), matrix64x64).transpose(1, 2)

        xb = F.relu(self.bn2(self.conv2(xb)))
        xb = self.bn3(self.conv3(xb))
        xb = nn.MaxPool1d(xb.shape[-1])(xb)
        output = torch.squeeze(xb)
        return output, matrix3x3, matrix64x64




class PointNet(nn.Module):
    def __init__(self, k=35,  output_dim=61):
        super().__init__() 
        self.k = k 
        self.transform = Transform(self.k) 
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 1024)
        self.fc3 = nn.Linear(1024, output_dim)
        


        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(1024)
        
        

    def forward(self, input):
        xb, matrix3x3, matrix64x64 = self.transform(input) 
        xb = F.relu(self.bn1(self.fc1(xb)))
        xb = F.relu(self.bn2(self.fc2(xb)))
        output = F.relu(self.fc3(xb))
        return output 
    

In [9]:
class AttentionHead(nn.Module):
    def __init__(self, head_dim):
        super().__init__() 
        
        self.head_dim = head_dim 
        
        self.fc1 = nn.Linear(self.head_dim, 2 * self.head_dim)
        self.fc2 = nn.Linear(2 * self.head_dim, self.head_dim)
        
    def forward(self, x):
        B, H, W, num_voxels, dim = x.shape 
        
        scale = np.sqrt( 2 * self.head_dim)
        
        Query = self.fc1(x)
        Key = self.fc1(x)
        Value = self.fc1(x)
        
        attention_weight = (Query @ Key.transpose(3,4)) / scale 
        attention_weight = F.softmax(attention_weight, dim=-1)
        
        feature = attention_weight @ Value 
        
        feature = F.relu(self.fc2(feature))
        
        return feature 
    
    

In [10]:
class MultiheadAttention(nn.Module):
    def __init__(self, num_heads=8, dim=64):
        super().__init__()
        
        self.num_heads = num_heads
        self.head_dim = int(dim/num_heads)
        
        self.head1 = AttentionHead(self.head_dim)
        self.head2 = AttentionHead(self.head_dim)
        self.head3 = AttentionHead(self.head_dim)
        self.head4 = AttentionHead(self.head_dim)
        self.head5 = AttentionHead(self.head_dim)
        self.head6 = AttentionHead(self.head_dim) 
        self.head7 = AttentionHead(self.head_dim) 
        self.head8 = AttentionHead(self.head_dim)
        
        self.bn1 = nn.BatchNorm3d(40) 
        
        self.fc1 = nn.Linear(dim, 128)
        self.fc2 = nn.Linear(128, 256)
        self.fc3 = nn.Linear(256, dim)
        
        self.bn2 = nn.BatchNorm3d(40)
        
        
        
        
        
    def forward(self, x):
        B, W, H, window_size, window_size, dim = x.shape 
        x = x.reshape(B, W, H, window_size * window_size, dim)
        x = x.reshape(B, W, H, window_size * window_size, self.num_heads, self.head_dim)
        x = x.reshape(self.num_heads, B, W, H, window_size * window_size, self.head_dim)
        
        head1_feature = self.head1(x[0])
        head2_feature = self.head2(x[1])
        head3_feature = self.head3(x[2])
        head4_feature = self.head4(x[3])
        head5_feature = self.head5(x[4])
        head6_feature = self.head6(x[5])
        head7_feature = self.head7(x[6])
        head8_feature = self.head8(x[7])
        
        feature = torch.cat((head1_feature, head2_feature, head3_feature, head4_feature,
                             head5_feature, head6_feature, head7_feature, head8_feature), dim=-1)
        
        #print("in multi_head_attention ", feature.shape)
        
        mh_feature = feature 
        
        # apply batchnorm 
        feature = self.bn1(feature)
        
        # apply MLP 
        
        feature = F.relu(self.fc1(feature))
        feature = F.relu(self.fc2(feature))
        feature = F.relu(self.fc3(feature))
        
        # Add and Norm 
        
        feature = feature + mh_feature      # residual connection 
        
        feature = self.bn2(feature)
        
        #print(feature.shape)
        
        # take the max for global feature of window 
        global_feature = torch.max(feature, dim=3)[0]
        
        
        #print("global dimension", global_feature.shape)
        
        
        
        # convert it into original shape 
        feature = feature.reshape(B, W, H, window_size, window_size, dim)
        
        #print(feature.shape)
        
        return feature, global_feature 
    
    

In [11]:
class GlobalAttentionHead(nn.Module):
    def __init__(self, head_dim):
        super().__init__() 
        
        self.head_dim = head_dim 
        
        self.fc1 = nn.Linear(self.head_dim, 2 * self.head_dim)
        self.fc2 = nn.Linear(2 * self.head_dim, self.head_dim)
        
    def forward(self, x):
        B, W, H, dim  = x.shape 
        
        scale = np.sqrt(2 * self.head_dim)
        
        Query = self.fc1(x)
        Key = self.fc1(x)
        Value = self.fc1(x)
        
        attention_weight = (Query @ Key.transpose(2, 3)) / scale 
        attention_weight = F.softmax(attention_weight, dim=-1)
        
        feature = attention_weight @ Value 
        feature = F.relu(self.fc2(feature))
        
        return feature 

    

In [12]:
class GlobalMultiheadAttention(nn.Module):
    def __init__(self, num_heads=8, dim=64):
        super().__init__() 
        
        self.num_heads = num_heads 
        self.head_dim = int(dim/self.num_heads) 
        
        
        self.head1 = GlobalAttentionHead(self.head_dim)
        self.head2 = GlobalAttentionHead(self.head_dim)
        self.head3 = GlobalAttentionHead(self.head_dim)
        self.head4 = GlobalAttentionHead(self.head_dim)
        self.head5 = GlobalAttentionHead(self.head_dim)
        self.head6 = GlobalAttentionHead(self.head_dim)
        self.head7 = GlobalAttentionHead(self.head_dim)
        self.head8 = GlobalAttentionHead(self.head_dim)
        
        
        self.bn1 = nn.BatchNorm2d(40)
        
        self.fc1 = nn.Linear(dim, 128)
        self.fc2 = nn.Linear(128, 256)
        self.fc3 = nn.Linear(256, dim)
        
        self.bn2 = nn.BatchNorm2d(40)
        
        
        
        
    def forward(self, x):
        B, W, H, dim = x.shape 
        
        x = x.reshape(B, W, H, self.num_heads, self.head_dim)
        x = x.reshape(self.num_heads, B , W, H, self.head_dim)
        
        head1_feature = self.head1(x[0])
        head2_feature = self.head2(x[1])
        head3_feature = self.head3(x[2])
        head4_feature = self.head4(x[3])
        head5_feature = self.head5(x[4])
        head6_feature = self.head6(x[5])
        head7_feature = self.head7(x[6])
        head8_feature = self.head8(x[7])
        
        feature = torch.cat((head1_feature, head2_feature, head3_feature, head4_feature, 
                           head5_feature, head6_feature, head7_feature, head8_feature), dim=-1)
        
        
        mh_feature = feature 
        
        # apply batchnorm 
        feature = self.bn1(feature)
        
        # apply mlp 
        feature = F.relu(self.fc1(feature))
        feature = F.relu(self.fc2(feature))
        feature = F.relu(self.fc3(feature))
        
        # add and norm 
        
        feature = feature + mh_feature   # residual connection  
        
        feature = self.bn2(feature) 
        
        return feature 
           

In [13]:
class TransformerEncoder(nn.Module):
    def __init__(self):
        super().__init__() 
        
        self.pointnet = PointNet() 
        
        self.multi_head_attention1 = MultiheadAttention(dim=64)
        self.global_multi_head_attention1 = GlobalMultiheadAttention(dim=64)
        
        self.multi_head_attention2 = MultiheadAttention(dim=64 * 2)
        self.global_multi_head_attention2 = GlobalMultiheadAttention(dim=64 * 2)
        
        self.multi_head_attention3 = MultiheadAttention(dim=64 * 2 * 2)
        self.global_multi_head_attention3 = GlobalMultiheadAttention(dim=64 * 2 * 2)
        
        
        
        
        
        
    def forward(self, x, coors):
        B, Voxels, max_points, dim = x.shape
        x = x.reshape(-1, max_points, dim)
        
        # send to pointnet
        x = self.pointnet(x).unsqueeze(0)
    
        # concatenation of voxel coordinate with voxel features ( positional embedding)
        x = torch.cat((x, coors), 2).reshape(1, 200, 100, 64)
        B, H, W ,  dim = x.shape 
        # divide into window of size 5 * 5 
        x = x.reshape(B, int(H/5), int(W/5), 5, 5, dim)
        
        # apply attention block 1
        feature, global_feature = self.multi_head_attention1(x)
        # apply global attention-block
        global_feature = self.global_multi_head_attention1(global_feature)
        global_feature = global_feature.unsqueeze(3).unsqueeze(3)
        # repeat tensor 
        global_feature = global_feature.repeat(1, 1, 1, 5, 5, 1)
        # add global feature with local features 
        feature = torch.cat((feature, global_feature), dim=-1) 
        
        
        
        # apply attention block 2 
        feature, global_feature = self.multi_head_attention2(feature)
        # apply global attention block 
        global_feature = self.global_multi_head_attention2(global_feature)
        global_feature = global_feature.unsqueeze(3).unsqueeze(3)
        # repeat tensor 
        global_feature = global_feature.repeat(1,1,1,5,5,1)
        # add global feaeture with local features 
        feature = torch.cat((feature, global_feature), dim=-1)
        
        
        # apply attention block 3 
        feature, global_feature = self.multi_head_attention3(feature)
        # apply globla attenion block 
        global_feature = self.global_multi_head_attention3(global_feature)
        global_feature = global_feature.unsqueeze(3).unsqueeze(3)
        # repeat tensor 
        global_feature = global_feature.repeat(1,1,1,5,5,1)
        # add global feature with local features 
        feature = torch.cat((feature, global_feature), dim=-1) 
        
        
        
        
        return feature 
    
    

In [14]:
class FusionAttention(nn.Module):
    def __init__(self, head_dim):
        super().__init__() 
        
        self.head_dim = head_dim 
        self.fc1 = nn.Linear(self.head_dim, 2 * self.head_dim)
        self.fc2 = nn.Linear(2*self.head_dim , self.head_dim)
        
    def forward(self, query, key, value):
        
        scale = np.sqrt(self.head_dim)
        
        Query = self.fc1(query)
        Key = self.fc1(key)
        Value = self.fc1(value)
        
        attention_weight = (Query @ Key.transpose(3, 4)) / scale 
        attention_weight = F.softmax(attention_weight, dim=-1)
        
        feature = attention_weight @ Value 
        feature = self.fc2(feature)
        return feature 
    
    

In [15]:
class CorrelationAttentionHead(nn.Module):
    def __init__(self, num_head=8, dim=512):
        super().__init__() 
        
        self.num_head = num_head 
        self.head_dim = int(dim/self.num_head)
        
        
        self.head1 = FusionAttention(self.head_dim)
        self.head2 = FusionAttention(self.head_dim)
        self.head3 = FusionAttention(self.head_dim)
        self.head4 = FusionAttention(self.head_dim)
        self.head5 = FusionAttention(self.head_dim)
        self.head6 = FusionAttention(self.head_dim)
        self.head7 = FusionAttention(self.head_dim)
        self.head8 = FusionAttention(self.head_dim)
        
        self.bn1 = nn.BatchNorm3d(40)
        
        self.fc1 = nn.Linear(dim, 2 * dim)
        self.fc2 = nn.Linear(2 * dim, 3 * dim)
        self.fc3 = nn.Linear(3 * dim, dim)
        
        self.bn2 = nn.BatchNorm3d(40)
        
    def forward(self, current_x, past_x):
        c_B, c_W, c_H, c_window_size, c_window_size, c_dim = current_x.shape 
        p_B, p_W, p_H, p_window_size, p_window_size, p_dim = past_x.shape 
        
        # change shape 
        current_x = current_x.reshape(c_B, c_W, c_H, c_window_size * c_window_size, c_dim)
        past_x = past_x.reshape(p_B, p_W, p_H, p_window_size * p_window_size, p_dim)
        
        Query = current_x.reshape(c_B, c_W, c_H, c_window_size * c_window_size, self.num_head, self.head_dim)
        Query = current_x.reshape(self.num_head, c_B, c_W, c_H, c_window_size*c_window_size, self.head_dim)
        
        Key = past_x.reshape(p_B, p_W, p_H, p_window_size * p_window_size, self.num_head, self.head_dim)
        Key = past_x.reshape(self.num_head, p_B, p_W, p_H, p_window_size*p_window_size, self.head_dim)
        Value = Key 
        
        head1_feature = self.head1(Query[0], Key[0], Value[0])
        head2_feature = self.head2(Query[1], Key[1], Value[1])
        head3_feature = self.head3(Query[2], Key[2], Value[2])
        head4_feature = self.head4(Query[3], Key[3], Value[3])
        head5_feature = self.head5(Query[4], Key[4], Value[4])
        head6_feature = self.head6(Query[5], Key[5], Value[5])
        head7_feature = self.head7(Query[6], Key[6], Value[6])
        head8_feature = self.head8(Query[7], Key[7], Value[7])
        
        feature = torch.cat((head1_feature, head2_feature, head3_feature, head4_feature,
                            head5_feature, head6_feature, head7_feature, head8_feature), dim=-1)
        
        
        mh_feature = feature 
        
        # apply batchnorm 
        feature = self.bn1(feature)
        
        # apply mlp 
        feature = F.relu(self.fc1(feature))
        feature = F.relu(self.fc2(feature))
        feature = F.relu(self.fc3(feature))
        
        # add and norm 
        feature = feature + mh_feature    # residual connections 
        
        feature = self.bn2(feature)
        
        # convert it into original shape
        feature = feature.reshape(c_B, c_W, c_H, c_window_size, c_window_size, c_dim)
        
        return feature  
        

In [16]:
class Model(nn.Module):
    def __init__(self):
        super().__init__() 
        
        self.encoder = TransformerEncoder()
        
        self.correlationModule = CorrelationAttentionHead() 
        
        self.past_frame_features = 0
        
        # make displacement head 
        self.fc_d1 = nn.Linear(512, 1024)
        self.fc_d2 = nn.Linear(1024, 2024)
        self.fc_d3 = nn.Linear(2024, 3 * 80*70)
        
        # meke score head 
        self.fc_s1 = nn.Linear(512, 1024)
        self.fc_s2 = nn.Linear(1024, 2024)
        self.fc_s3 = nn.Linear(2024, 80*70)
        
        # make detection_head 
        self.fc_cent1 = nn.Linear(512, 1024)
        self.fc_cent2 = nn.Linear(1024, 2024)
        self.fc_cent3 = nn.Linear(2024, 80*70*3)
        
        # make coordinate prediction head 
        self.fc_coord1 = nn.Linear(512, 1024)
        self.fc_coord2 = nn.Linear(1024, 2024)
        self.fc_coord3 = nn.Linear(2024, 80*70*8*3)
        
        # make classification_head 
        self.fc_cl1 = nn.Linear(512, 1024)
        self.fc_cl2 = nn.Linear(1024, 2024)
        self.fc_cl3 = nn.Linear(2024, 80*70*10) 
        
        # make rotation prediction head 
        
        self.fc_rot1 = nn.Linear(512, 1024)
        self.fc_rot2 = nn.Linear(1024, 2024)
        self.fc_rot3 = nn.Linear(2024, 80*70)
        
        
        
      
        
    def forward(self, voxels, coors, n):
        num_frames = n+1                        # for tracking frame number
        # apply encoder module 
        features = self.encoder(voxels, coors)
        
        
        # apply correlation fusion module
        
        if num_frames == 1:
            correlated_features = self.correlationModule(features, features)   # current_x and past_x are  same for end to end training
            self.past_frame_features= features # make current frame features as a past frame features.
        else:
            correlated_features = self.correlationModule(features, self.past_frame_features)
            self.past_frame_features = features #make current frame features as a past frame features.
            
            
        
        # take max pool for all voxel features for redusing number of features amd number of model parameters
        
        features = features.reshape(1, -1, 512).max(1, keepdim=True)[0]
        correlated_features = correlated_features.reshape(1, -1, 512).max(1, keepdim=True)[0]
        
        # apply displacement head 
        displacement = F.relu(self.fc_d1(correlated_features))
        displacement = F.relu(self.fc_d2(displacement))
        displacement = self.fc_d3(displacement).reshape(1, 3, 80, 70)
        
        # apply score head 
        score = F.relu(self.fc_s1(features))
        score = F.relu(self.fc_s2(score))
        score = self.fc_s3(score).reshape(1, 80, 70)
        
        # apply detection head 
        detection = F.relu(self.fc_cent1(features))
        detection = F.relu(self.fc_cent2(detection))
        detection = self.fc_cent3(detection).reshape(1, 3, 80, 70)
        
        # apply coordinate prediction head 
        coordinate = F.relu(self.fc_coord1(features))
        coordinate = F.relu(self.fc_coord2(coordinate))
        coordinate = self.fc_coord3(coordinate).reshape(1, 8*3, 80, 70)
        
        
        
        # apply classification head 
        classification = F.relu(self.fc_cl1(features))
        classification = F.relu(self.fc_cl2(classification))
        classification = self.fc_cl3(classification).reshape(1, 10, 80, 70)
        
        # apply rotation prediction head 
        rot = F.relu(self.fc_rot1(features)) 
        rot = F.relu(self.fc_rot2(rot))
        rot = self.fc_rot3(rot).reshape(1, 80, 70)
        
        
        return displacement, score, detection, coordinate, classification, rot
           
        

In [17]:
def focal_loss(input, target, gamma):
    ce = F.cross_entropy(input, target, reduction='none')
    pt = torch.exp(-ce)

    focal_loss = (1 - pt)**gamma * ce 

    # take sum of focal loss of all classes 
    focal_loss = focal_loss.sum()

    return focal_loss 

In [None]:
model = Model()
model.to(device)

In [18]:
optimizer = torch.optim.Adam(model.parameters(), lr = 0.1)

NameError: name 'model' is not defined

In [19]:
scheduler = StepLR(optimizer, step_size=10, gamma=0.1)

NameError: name 'optimizer' is not defined

In [20]:
epoch_loss = [] 

In [None]:
def train(model, train_loader, epoch):
    for i in range(epoch):
        running_loss = 0.0 
        global past_canters 
        past_centers = 0
        for n, data in enumerate(train_loader):
            loss = 0.0
            if n == 0:
                voxels = data['voxels'].to(device).float()
                coors = data['coors'].to(device).float()
                bboxes = data['bboxes']
                
                # store centers, bboxes 
                centers = []
                coordinates = []
                rot = [] 
                ids = [] 
                category = [] 
                
                for l in range(len(bboxes)):
                    centers.append(bboxes[l][0][0])
                    coordinates.append(bboxes[l][1][0].reshape(1, -1)[0])
                    rot.append(bboxes[l][2][0])
                    category.append(bboxes[l][3][0])
                    ids.append(bboxes[l][4][0])
                    
                #print("centers = ", centers)
                #print("coordinates = ", coordinates[0])
                #print("rot=",rot[0].item())
                #print("ids=", ids)
                #print("category=",category[0])
                                    
                # send input to the model
                displacement_pred, score_pred, detection_pred, coordinate_pred, classification_pred, rot_pred = model(voxels, coors, n) 
                
                displacement_target = torch.zeros(displacement_pred.shape).to(device)
                score_target = torch.zeros(score_pred.shape).to(device)
                detection_target = torch.zeros(detection_pred.shape).to(device)
                coordinate_target = torch.zeros(coordinate_pred.shape).to(device)
                classification_target = torch.zeros(classification_pred.shape).to(device)
                rot_target = torch.zeros(rot_pred.shape).to(device)
                
                
                
                
                
                for l in range(len(centers)):
                    center = torch.round(centers[l])
                    x = center[0].item()
                    y = center[1].item()
                     
                    if x >= 70 or y > 40 or y <= -40:      # delete all annotation outside the range
                        continue 
                        
                    x = int(x)
                    if y < 0:
                        y = abs(y) + 40 
                        y = int(y)
                    else:
                        y = int(y) 
                        
                    score_target[0][y][x] = 1      # y , x locate the location on output grid 
                    
                    # get the center location of BB
                    center_bb = centers[l]
                    center_x = center_bb[0].item()
                    center_y = center_bb[1].item()
                    center_z = center_bb[2].item()
                    
                    detection_target[0][0][y][x] = center_x 
                    detection_target[0][1][y][x] = center_y 
                    detection_target[0][2][y][x] = center_z 
                    
                    
                    for k in range(24):
                        coordinate_target[0][k][y][x] = coordinates[l][k]
                        
                    # 'Car'->0, 'Van'->1, 'Truck'->2, 'Pedestrian'->3, 'Person_sitting'->4, 'Cyclist'->5, 
                    #'Tram'->6, 'Misc'->7, 'Person'->8, 'Background'->9 
                    
                    if category[l] == 'Car':
                        classification_target[0][0][y][x] = 1 
                    if category[l] == 'Van':
                        classification_target[0][1][y][x] = 1 
                    if category[l] == 'Truck':
                        classification_target[0][2][y][x] = 1 
                    if category[l] == 'Pedestrian':
                        classification_target[0][3][y][x] = 1 
                    if category[l] == 'Person_sitting':
                        classification_target[0][4][y][x] = 1 
                    if category[l] == 'Cyclist':
                        classification_target[0][5][y][x] = 1 
                    if category[l] == 'Tram':
                        classification_target[0][6][y][x] = 1 
                    if category[l] == 'Misc':
                        classification_target[0][7][y][x] = 1 
                    if category[l] == 'Person':
                        classification_target[0][8][y][x] = 1 
                     
                    #  add one at background (complementation)
                    classification_target[0][9][y][x] = 1 
                    # prepare rot_target 
                    rot_target[0][y][x] = rot[l].item()
                    
                # correct background class  
                
                for k in  range(80):
                    for j in range(70):
                        if classification_target[0][9][k][j] == 1:
                            classification_target[0][9][k][j] = 0 
                        else:
                            classification_target[0][9][k][j] = 1 
               
                # calculate loss 
                displacement_loss = ((displacement_pred - displacement_target)**2).sum()
                # calculate score loss 
                score_loss = ((score_pred - score_target) ** 2).sum()
                # calculate detection loss 
                detection_loss = ((detection_pred - detection_target)**2).sum()
                # calculate coordinate loss 
                coordinate_loss = ((coordinate_pred - coordinate_target)**2).sum() 
                # rotation loss 
                rot_loss = ((rot_pred - rot_target)**2).sum()
                # calculate classification loss using focal loss function 
                classification_loss = focal_loss(classification_pred, classification_target, 2)
                
                loss = loss + 1* classification_loss + 2 * detection_loss + 2 * coordinate_loss + 2 * score_loss + 2 * rot_loss + displacement_loss 
                
                # apply backprop 
                optimizer.zero_grad() 
                loss.backward(retain_graph=True) 
                
                # print gradient 
                for name, param in model.named_parameters():
                    print(name, param.grad)
                    
                
                
                optimizer.step() 
                
                scheduler.step()
                
                running_loss = running_loss + loss.item() 
                
                # assign current detection target to past_centers
                past_centers = detection_target 
                
                
                del loss, displacement_loss, score_loss, detection_loss, coordinate_loss, rot_loss, classification_loss 
                del displacement_pred, score_pred, detection_pred, coordinate_pred, classification_pred, rot_pred 
                del displacement_target, score_target, detection_target, coordinate_target, classification_target, rot_target 
                
                     
            else:
                voxels = data['voxels'].to(device).float()
                coors = data['coors'].to(device).float()
                bboxes = data['bboxes']
                
                # store centers, bboxes 
                centers = []
                coordinates = []
                rot = [] 
                ids = [] 
                category = [] 
                
                for l in range(len(bboxes)):
                    centers.append(bboxes[l][0][0])
                    coordinates.append(bboxes[l][1][0].reshape(1, -1)[0])
                    rot.append(bboxes[l][2][0])
                    category.append(bboxes[l][3][0])
                    ids.append(bboxes[l][4][0])
                    
                #print("centers = ", centers)
                #print("coordinates = ", coordinates[0])
                #print("rot=",rot[0].item())
                #print("ids=", ids)
                #print("category=",category[0])
                                    
                # send input to the model
                displacement_pred, score_pred, detection_pred, coordinate_pred, classification_pred, rot_pred = model(voxels, coors, n) 
                
                #displacement_target = torch.zeros(displacement_pred.shape).to(device)
                score_target = torch.zeros(score_pred.shape).to(device)
                detection_target = torch.zeros(detection_pred.shape).to(device)
                coordinate_target = torch.zeros(coordinate_pred.shape).to(device)
                classification_target = torch.zeros(classification_pred.shape).to(device)
                rot_target = torch.zeros(rot_pred.shape).to(device)
                
                
                
                
                
                for l in range(len(centers)):
                    center = torch.round(centers[l])
                    x = center[0].item()
                    y = center[1].item()
                     
                    if x >= 70 or y > 40 or y <= -40:      # delete all annotation outside the range
                        continue 
                        
                    x = int(x)
                    if y < 0:
                        y = abs(y) + 40 
                        y = int(y)
                    else:
                        y = int(y) 
                        
                    score_target[0][y][x] = 1      # y , x locate the location on output grid 
                    
                    # get the center location of BB
                    center_bb = centers[l]
                    center_x = center_bb[0].item()
                    center_y = center_bb[1].item()
                    center_z = center_bb[2].item()
                    
                    detection_target[0][0][y][x] = center_x 
                    detection_target[0][1][y][x] = center_y 
                    detection_target[0][2][y][x] = center_z 
                    
                    
                    for k in range(24):
                        coordinate_target[0][k][y][x] = coordinates[l][k]
                        
                    # 'Car'->0, 'Van'->1, 'Truck'->2, 'Pedestrian'->3, 'Person_sitting'->4, 'Cyclist'->5, 
                    #'Tram'->6, 'Misc'->7, 'Person'->8, 'Background'->9 
                    
                    if category[l] == 'Car':
                        classification_target[0][0][y][x] = 1 
                    if category[l] == 'Van':
                        classification_target[0][1][y][x] = 1 
                    if category[l] == 'Truck':
                        classification_target[0][2][y][x] = 1 
                    if category[l] == 'Pedestrian':
                        classification_target[0][3][y][x] = 1 
                    if category[l] == 'Person_sitting':
                        classification_target[0][4][y][x] = 1 
                    if category[l] == 'Cyclist':
                        classification_target[0][5][y][x] = 1 
                    if category[l] == 'Tram':
                        classification_target[0][6][y][x] = 1 
                    if category[l] == 'Misc':
                        classification_target[0][7][y][x] = 1 
                    if category[l] == 'Person':
                        classification_target[0][8][y][x] = 1 
                     
                    #  add one at background (complementation)
                    classification_target[0][9][y][x] = 1 
                    # prepare rot_target 
                    rot_target[0][y][x] = rot[l].item()
                    
                # correct background class  
                
                for k in  range(80):
                    for j in range(70):
                        if classification_target[0][9][k][j] == 1:
                            classification_target[0][9][k][j] = 0 
                        else:
                            classification_target[0][9][k][j] = 1 
                
                
                # calculate displacement_target 
                displacement_target = detection_target - past_centers 
                
                
                # calculate loss 
                displacement_loss = ((displacement_pred - displacement_target)**2).sum()
                # calculate score loss 
                score_loss = ((score_pred - score_target) ** 2).sum()
                # calculate detection loss 
                detection_loss = ((detection_pred - detection_target)**2).sum()
                # calculate coordinate loss 
                coordinate_loss = ((coordinate_pred - coordinate_target)**2).sum() 
                # rotation loss 
                rot_loss = ((rot_pred - rot_target)**2).sum()
                # calculate classification loss using focal loss function 
                classification_loss = focal_loss(classification_pred, classification_target, 2)
                
                loss = loss + 1* classification_loss + 2 * detection_loss + 2 * coordinate_loss + 2 * score_loss + 2 * rot_loss + displacement_loss 
                
                # apply backprop 
                optimizer.zero_grad() 
                loss.backward(retain_graph=True) 
                # print gradient
                for name, param in model.named_parameters():
                    print(name, param.grad)
                
                optimizer.step() 
                
                scheduler.step()
                
                
                running_loss = running_loss + loss.item() 
                
                # assign current detection_tartget to past_centers
                past_centers = detection_target
                
                
                del loss, displacement_loss, score_loss, detection_loss, coordinate_loss, rot_loss, classification_loss 
                del displacement_pred, score_pred, detection_pred, coordinate_pred, classification_pred, rot_pred 
                del displacement_target, score_target, detection_target, coordinate_target, classification_target, rot_target 
        
        epoch_loss.append(running_loss)
        print(" epoch = {} , running_loss = {}".format(i+1, running_loss))
        
        checkpoint = {
            "epoch_number" : i+1,
            "model_state" : model.state_dict()
        }
        torch.save(checkpoint, "tracking_model2.pth") 
        
        
            

In [None]:
train(model, train_loader, 80)