In [1]:
import os
import torch
from tqdm import tqdm
import numpy as np

import matplotlib.pyplot as plt
from utils import setup_seed
from dataset import Kitti, get_dataloader
from model import PointPillars
from loss import Loss
from torch.utils.tensorboard import SummaryWriter

Jupyter environment detected. Enabling Open3D WebVisualizer.
[Open3D INFO] WebRTC GUI backend enabled.
[Open3D INFO] WebRTCWindowSystem: HTTP handshake server disabled.


In [2]:
def save_summary(writer, loss_dict, global_step, tag, lr=None, momentum=None):
    """Write one step's worth of scalars through ``writer.add_scalar``.

    Args:
        writer: object exposing ``add_scalar(name, value, step)``.
        loss_dict: mapping of loss name -> value; each entry is written
            under ``"<tag>/<name>"``.
        global_step: step index attached to every scalar written here.
        tag: prefix that groups the loss scalars (e.g. ``'train'``).
        lr: optional learning rate, written as ``'lr'`` when not ``None``.
        momentum: optional momentum, written as ``'momentum'`` when not ``None``.
    """
    for name in loss_dict:
        writer.add_scalar(f'{tag}/{name}', loss_dict[name], global_step)

    # Optional optimizer state: a value of None means "do not log".
    for scalar_name, scalar_value in (('lr', lr), ('momentum', momentum)):
        if scalar_value is None:
            continue
        writer.add_scalar(scalar_name, scalar_value, global_step)

### Training Arguments

In [3]:
class Args:
    """Plain container for the notebook's training configuration."""

    def __init__(self):
        # --- paths ---
        self.data_root = "dataset/KITTI"
        self.saved_path = "logs/pillar_sequence"

        # --- data loading / model size ---
        self.batch_size = 4
        self.num_workers = 4
        self.nclasses = 3

        # --- optimisation schedule ---
        self.init_lr = 2.5e-4
        self.max_epoch = 160

        # --- logging / checkpoint cadence ---
        self.log_freq = 8          # in training steps
        self.ckpt_freq_epoch = 20  # in epochs

        # --- device ---
        cuda_available = torch.cuda.is_available()
        self.no_cuda = not cuda_available


args = Args()

### Dataloader

In [4]:
# Seed the RNGs before any dataset / loader is constructed.
setup_seed()

# One Kitti dataset per split, both rooted at the configured data directory.
train_dataset = Kitti(data_root=args.data_root, split='train')
val_dataset = Kitti(data_root=args.data_root, split='val')

# Both loaders share batch size and worker count; only shuffling differs.
loader_kwargs = dict(batch_size=args.batch_size, num_workers=args.num_workers)
train_dataloader = get_dataloader(dataset=train_dataset, shuffle=True, **loader_kwargs)
val_dataloader = get_dataloader(dataset=val_dataset, shuffle=False, **loader_kwargs)

In [5]:
# Quick sanity check: fetch validation sample #8 and show the shape of its 'pts' entry.
data_dict = val_dataset[8]
print(data_dict['pts'].shape)

(18200, 4)


### Model, Loss, Optimizer, Scheduler, Log

In [6]:
# --- Model -------------------------------------------------------------
# Build the detector once, then move it to the GPU only when CUDA is enabled.
pointpillars = PointPillars(nclasses=args.nclasses)
if not args.no_cuda:
    pointpillars = pointpillars.cuda()

# --- Loss --------------------------------------------------------------
loss_func = Loss()

# --- Optimizer / LR schedule -------------------------------------------
# NOTE(review): total_steps is 2x (batches per epoch * epochs), while the training
# cell calls scheduler.step() once per training batch, so only the first half of
# the one-cycle schedule would ever be consumed -- confirm the factor 2 is intended.
max_iters = 2*len(train_dataloader) * args.max_epoch
init_lr = args.init_lr

adamw_settings = dict(lr=init_lr, betas=(0.95, 0.99), weight_decay=0.01)
optimizer = torch.optim.AdamW(params=pointpillars.parameters(), **adamw_settings)

one_cycle_settings = dict(
    max_lr=init_lr*10,
    total_steps=max_iters,
    pct_start=0.4,
    anneal_strategy='cos',
    cycle_momentum=True,
    base_momentum=0.95*0.895,
    max_momentum=0.95,
    div_factor=10,
)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, **one_cycle_settings)

# --- Logging / checkpoint directories ----------------------------------
saved_logs_path = os.path.join(args.saved_path, 'summary')
saved_ckpt_path = os.path.join(args.saved_path, 'checkpoints')

os.makedirs(saved_logs_path, exist_ok=True)
writer = SummaryWriter(saved_logs_path)
os.makedirs(saved_ckpt_path, exist_ok=True)

In [7]:
from ptflops import get_model_complexity_info

# Profile the model's parameter count and multiply-accumulate cost with ptflops,
# using the shape of the first point cloud of one training batch as the input size.

# Take exactly one batch. next(iter(...)) raises StopIteration on an empty loader
# instead of silently reusing a stale `data_dict` left over from an earlier cell,
# and it does not leave a half-finished tqdm progress bar behind.
data_dict = next(iter(train_dataloader))

if not args.no_cuda:
    # move the tensors to the cuda
    for key in data_dict:
        for j, item in enumerate(data_dict[key]):
            if torch.is_tensor(item):
                data_dict[key][j] = item.cuda()

batched_pts = data_dict['batched_pts']
input_size = tuple(batched_pts[0].shape)

# NOTE(review): ptflops appears to switch the profiled model to eval mode --
# remember the current mode and restore it so later training is unaffected.
was_training = pointpillars.training
try:
    with torch.no_grad():
        macs, params = get_model_complexity_info(pointpillars, input_size, as_strings=True,
                                                 print_per_layer_stat=True, verbose=True)
finally:
    pointpillars.train(was_training)

# get_model_complexity_info reports multiply-accumulate operations (the per-layer
# table is printed in "Mac" units), so label the number as MACs rather than FLOPs.
print(f"MACs: {macs}")
print(f"Parameters: {params}")

# bbox_cls_pred0, bbox_pred0, bbox_dir_cls_pred0, \
#                 bbox_cls_pred1, bbox_pred1, bbox_dir_cls_pred1, \
#                     bbox_cls_pred2, bbox_pred2, bbox_dir_cls_pred2, \
#                         anchor_target_dict = pointpillars(batched_pts=batched_pts, 
#                             mode='train',
#                             batched_gt_bboxes=batched_gt_bboxes, 
#                             batched_gt_labels=batched_labels)

  0%|          | 0/1317 [00:15<?, ?it/s]




  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


PointPillars(
  6.36 M, 100.000% Params, 63.29 GMac, 99.922% MACs, 
  (pillar_layer): PillarLayer(
    0, 0.000% Params, 0.0 Mac, 0.000% MACs, 
    (voxel_layer): Voxelization(voxel_size=[0.16, 0.16, 4], point_cloud_range=[0, -39.68, -3, 69.12, 39.68, 1], max_num_points=32, max_voxels=(16000, 40000), deterministic=True)
  )
  (pillar_encoder): PillarEncoder(
    704, 0.011% Params, 22.53 KMac, 0.000% MACs, 
    (conv): Conv1d(576, 0.009% Params, 18.43 KMac, 0.000% MACs, 9, 64, kernel_size=(1,), stride=(1,), bias=False)
    (bn): BatchNorm1d(128, 0.002% Params, 4.1 KMac, 0.000% MACs, 64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
  )
  (backbone0): Backbone(
    4.21 M, 66.155% Params, 29.71 GMac, 46.904% MACs, 
    (multi_blocks): ModuleList(
      4.21 M, 66.155% Params, 29.71 GMac, 46.904% MACs, 
      (0): Sequential(
        147.97 k, 2.326% Params, 7.94 GMac, 12.536% MACs, 
        (0): Conv2d(36.86 k, 0.580% Params, 1.97 GMac, 3.118% MACs, 64, 64, kernel_siz

In [8]:
# NOTE(review): `flops` is not assigned in any cell visible in this notebook. The saved
# output below shows an fvcore FlopCountAnalysis object, so it presumably came from a
# cell that has since been removed. On a fresh kernel (Restart & Run All) this line
# raises NameError -- restore the defining cell or delete this one.
print(flops)

<fvcore.nn.flop_count.FlopCountAnalysis object at 0x7fd9a2699340>


In [6]:
# Disabled experiment: the triple-quoted string below is never executed. It holds a
# one-batch forward pass through `pointpillars` in 'train' mode; the cell's only
# effect is to echo the string as its output.
'''
for i, data_dict in enumerate(tqdm(train_dataloader)):
    if not args.no_cuda:
        # move the tensors to the cuda
        print("Here")
        for key in data_dict:
            for j, item in enumerate(data_dict[key]):
                if torch.is_tensor(item):
                    data_dict[key][j] = data_dict[key][j].cuda()
    
    optimizer.zero_grad()

    batched_pts = data_dict['batched_pts']
    batched_gt_bboxes = data_dict['batched_gt_bboxes']
    batched_labels = data_dict['batched_labels']
    batched_difficulty = data_dict['batched_difficulty']
    bbox_cls_pred0, bbox_pred0, bbox_dir_cls_pred0, \
        bbox_cls_pred1, bbox_pred1, bbox_dir_cls_pred1, \
            bbox_cls_pred2, bbox_pred2, bbox_dir_cls_pred2, anchor_target_dict = \
                pointpillars(batched_pts=batched_pts, 
                                mode='train',
                                batched_gt_bboxes=batched_gt_bboxes, 
                                batched_gt_labels=batched_labels)
    
    break

'''

'\nfor i, data_dict in enumerate(tqdm(train_dataloader)):\n    if not args.no_cuda:\n        # move the tensors to the cuda\n        print("Here")\n        for key in data_dict:\n            for j, item in enumerate(data_dict[key]):\n                if torch.is_tensor(item):\n                    data_dict[key][j] = data_dict[key][j].cuda()\n    \n    optimizer.zero_grad()\n\n    batched_pts = data_dict[\'batched_pts\']\n    batched_gt_bboxes = data_dict[\'batched_gt_bboxes\']\n    batched_labels = data_dict[\'batched_labels\']\n    batched_difficulty = data_dict[\'batched_difficulty\']\n    bbox_cls_pred0, bbox_pred0, bbox_dir_cls_pred0,         bbox_cls_pred1, bbox_pred1, bbox_dir_cls_pred1,             bbox_cls_pred2, bbox_pred2, bbox_dir_cls_pred2, anchor_target_dict =                 pointpillars(batched_pts=batched_pts, \n                                mode=\'train\',\n                                batched_gt_bboxes=batched_gt_bboxes, \n                                batched_g

### Training

In [9]:
# Per-epoch loss histories, one list per prediction head (0, 1, 2).
# Each name is bound to its own, initially empty, list.
training_loss0, training_loss1, training_loss2 = [], [], []
val_loss0, val_loss1, val_loss2 = [], [], []

In [10]:
def _move_batch_to_cuda(batch):
    """Move every tensor stored in the per-key sequences of ``batch`` to the GPU, in place."""
    for key in batch:
        for j, item in enumerate(batch[key]):
            if torch.is_tensor(item):
                batch[key][j] = item.cuda()


def _forward_with_targets(model, batch):
    """Call the detector in ``mode='train'`` so it also returns the anchor targets.

    Returns whatever the model returns. The callers below only rely on the first
    three entries being the (cls, bbox, dir) predictions of the full-feature head
    and on the last entry being ``anchor_target_dict``.
    """
    return model(batched_pts=batch['batched_pts'],
                 mode='train',
                 batched_gt_bboxes=batch['batched_gt_bboxes'],
                 batched_gt_labels=batch['batched_labels'])


def _head_loss(head_outputs, anchor_target_dict, nclasses, loss_func):
    """Compute the loss dict of one prediction head against the anchor targets.

    Args:
        head_outputs: ``(bbox_cls_pred, bbox_pred, bbox_dir_cls_pred)``; each is a
            4-D tensor that is permuted with ``(0, 2, 3, 1)`` and flattened to
            ``(-1, nclasses)``, ``(-1, 7)`` and ``(-1, 2)`` respectively
            (presumably (B, C, H, W) feature maps -- TODO confirm).
        anchor_target_dict: dict with ``'batched_labels'``,
            ``'batched_label_weights'``, ``'batched_bbox_reg'`` and
            ``'batched_dir_labels'``. It is not modified.
        nclasses: number of foreground classes; label ``nclasses`` is used as
            the background class for the classification loss.
        loss_func: callable taking the keyword arguments used below and
            returning a dict that contains ``'total_loss'``.

    Returns:
        The dict returned by ``loss_func``.
    """
    bbox_cls_pred, bbox_pred, bbox_dir_cls_pred = head_outputs
    bbox_cls_pred = bbox_cls_pred.permute(0, 2, 3, 1).reshape(-1, nclasses)
    bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 7)
    bbox_dir_cls_pred = bbox_dir_cls_pred.permute(0, 2, 3, 1).reshape(-1, 2)

    # Labels are rewritten in place further down, so work on a private copy. This
    # keeps the helper safe to call once per head on the same anchor_target_dict.
    batched_bbox_labels = anchor_target_dict['batched_labels'].reshape(-1).clone()
    batched_label_weights = anchor_target_dict['batched_label_weights'].reshape(-1)
    batched_bbox_reg = anchor_target_dict['batched_bbox_reg'].reshape(-1, 7)
    batched_dir_labels = anchor_target_dict['batched_dir_labels'].reshape(-1)

    # Regression and direction losses only look at anchors with a foreground label.
    pos_idx = (batched_bbox_labels >= 0) & (batched_bbox_labels < nclasses)
    bbox_pred = bbox_pred[pos_idx]
    batched_bbox_reg = batched_bbox_reg[pos_idx]
    bbox_dir_cls_pred = bbox_dir_cls_pred[pos_idx]
    batched_dir_labels = batched_dir_labels[pos_idx]

    # sin(a - b) = sin(a)*cos(b) - cos(a)*sin(b)
    # Encode the last (yaw) channel so that pred - target equals sin(a - b).
    # Both terms are built from the *untouched* yaw values: the previous version
    # overwrote the predicted yaw with sin(a)*cos(b) first and then fed that
    # already-encoded value into cos(.) for the target term.
    pred_yaw = bbox_pred[:, -1]
    target_yaw = batched_bbox_reg[:, -1]
    pred_yaw_enc = torch.sin(pred_yaw) * torch.cos(target_yaw)
    target_yaw_enc = torch.cos(pred_yaw) * torch.sin(target_yaw)
    bbox_pred = torch.cat([bbox_pred[:, :-1], pred_yaw_enc.unsqueeze(-1)], dim=-1)
    batched_bbox_reg = torch.cat([batched_bbox_reg[:, :-1], target_yaw_enc.unsqueeze(-1)], dim=-1)

    # Counted before negative labels are remapped, exactly as before.
    num_cls_pos = (batched_bbox_labels < nclasses).sum()

    # Classification loss uses only anchors with a positive label weight;
    # negative labels are mapped to the background class index `nclasses`.
    weighted = batched_label_weights > 0
    bbox_cls_pred = bbox_cls_pred[weighted]
    batched_bbox_labels[batched_bbox_labels < 0] = nclasses
    batched_bbox_labels = batched_bbox_labels[weighted]

    return loss_func(bbox_cls_pred=bbox_cls_pred,
                     bbox_pred=bbox_pred,
                     bbox_dir_cls_pred=bbox_dir_cls_pred,
                     batched_labels=batched_bbox_labels,
                     num_cls_pos=num_cls_pos,
                     batched_bbox_reg=batched_bbox_reg,
                     batched_dir_labels=batched_dir_labels)


def _show_loss_curve(values, title):
    """Plot one per-epoch loss history in its own figure."""
    plt.figure()
    plt.plot(np.array(values))
    plt.title(title)
    plt.legend(["64 channel","32 channel","16 channel"])
    plt.show()


# NOTE(review): the saved output of this cell ends in
# "TypeError: forward() got an unexpected keyword argument 'batched_gt_bboxes'",
# so the imported PointPillars.forward presumably does not accept the keyword
# arguments used in _forward_with_targets -- confirm against model.py.
#
# Only the full-feature head (outputs[:3]) is trained and evaluated here. If the
# model returns further heads between the first three outputs and the trailing
# anchor_target_dict (the earlier 10-value unpacking suggests outputs[3:6] and
# outputs[6:9]), their losses can be obtained with additional _head_loss calls.
for epoch in range(args.max_epoch):
    epoch_loss0 = 0
    val_epoch_loss0 = 0

    print('=' * 20, epoch, '=' * 20)
    train_step, val_step = 0, 0

    # Be explicit about the mode: an earlier cell (e.g. profiling) or the previous
    # epoch's validation may have left the model in eval mode.
    pointpillars.train()
    for data_dict in tqdm(train_dataloader):
        if not args.no_cuda:
            _move_batch_to_cuda(data_dict)

        optimizer.zero_grad()

        outputs = _forward_with_targets(pointpillars, data_dict)
        loss_dict0 = _head_loss(outputs[:3], outputs[-1], args.nclasses, loss_func)

        loss0 = loss_dict0['total_loss']
        loss = loss0
        loss.backward()

        epoch_loss0 = epoch_loss0 + loss0.item()

        # torch.nn.utils.clip_grad_norm_(pointpillars.parameters(), max_norm=35)
        optimizer.step()
        scheduler.step()

        global_step = epoch * len(train_dataloader) + train_step + 1

        if global_step % args.log_freq == 0:
            save_summary(writer, loss_dict0, global_step, 'train',
                         lr=optimizer.param_groups[0]['lr'],
                         momentum=optimizer.param_groups[0]['betas'][0])
        train_step += 1

    training_loss0.append(epoch_loss0)

    if (epoch + 1) % args.ckpt_freq_epoch == 0:
        torch.save(pointpillars.state_dict(), os.path.join(saved_ckpt_path, f'epoch_{epoch+1}.pth'))
        _show_loss_curve(training_loss0, "Training Loss")
        _show_loss_curve(val_loss0, "Validation Loss")

    # Validation runs on odd epochs only.
    if epoch % 2 == 0:
        continue

    pointpillars.eval()
    val_failures, last_val_error = 0, None
    with torch.no_grad():
        for data_dict in tqdm(val_dataloader):
            try:
                if not args.no_cuda:
                    _move_batch_to_cuda(data_dict)

                # Index the outputs instead of unpacking a fixed number of values:
                # the previous 4-value unpacking disagreed with the 10 values
                # unpacked from the same call during training, and the resulting
                # error was silently swallowed, leaving the validation loss at 0.
                outputs = _forward_with_targets(pointpillars, data_dict)
                loss_dict0 = _head_loss(outputs[:3], outputs[-1], args.nclasses, loss_func)

                val_epoch_loss0 = val_epoch_loss0 + loss_dict0['total_loss'].item()

                global_step = epoch * len(val_dataloader) + val_step + 1
                if global_step % args.log_freq == 0:
                    save_summary(writer, loss_dict0, global_step, 'val')
                val_step += 1
            except Exception as exc:
                # Validation stays best-effort (a bad batch is skipped), but the
                # failure is recorded and reported. Catching Exception rather
                # than everything lets KeyboardInterrupt stop the run.
                val_failures += 1
                last_val_error = exc

    if val_failures:
        print(f"[val] epoch {epoch}: skipped {val_failures} batch(es); "
              f"last error: {last_val_error!r}")

    val_loss0.append(val_epoch_loss0)

    pointpillars.train()




  0%|          | 0/1317 [00:09<?, ?it/s]


TypeError: forward() got an unexpected keyword argument 'batched_gt_bboxes'

In [None]:
# Draw and save the per-epoch loss histories, one figure per curve.
# Each entry: (figure title, recorded values, output image file).
loss_curves = (
    ("Training Loss", training_loss0, "Matryoshka_train_loss.png"),
    ("Validation Loss", val_loss0, "Matryoshka_val_loss.png"),
)

for curve_title, curve_values, curve_file in loss_curves:
    plt.figure()
    plt.plot(np.array(curve_values))
    plt.title(curve_title)
    plt.legend(["64 channel","32 channel","16 channel"])
    plt.savefig(curve_file)  # save before show(), as in the original order
    plt.show()