In [None]:
import numpy as np
import os
import pickle
import data_process
import open3d as o3d
import copy
import numba
import random
from ops import Voxelization, nms_cuda
import torch

from torch.utils.data import DataLoader
from models.pointpillars import PointPillars, PFNLayer
import data_process
import data_augment

from tqdm import tqdm

from models.gpu_mem_track import MemTracker

import inspect
from models.losses import Losses

In [None]:
# Root directory of the KITTI-format dataset.
# The location is machine-specific, so it can be overridden without editing the
# notebook by exporting POINTPILLARS_DATASET_ROOT; the default below is unchanged.
# Other roots that were used previously (kept for reference):
#   /media/chris/Workspace/Dataset/3d-object-detection-for-autonomous-vehicles/kitti_format/
#   /media/chris/Workspace/Dataset/kitti
#   /media/chris/Workspace/Dataset/lyft
#   /media/chris/胖虎的硬盘/kitti_format
dataset_root = os.environ.get(
    'POINTPILLARS_DATASET_ROOT',
    '/media/chris/Workspace/Dataset/3d-object-detection-one_scene/kitti_format',
)

# Split tag; only referenced by the (currently commented-out) pickle-loading code.
identifier = 'train'

In [None]:
# Reading the data stored in the pkl info files is currently disabled:
# data_content = data_process.read_pickle(os.path.join(dataset_root,f'lyft_infos_{identifier}.pkl'))
# database_content = data_process.read_pickle(os.path.join(dataset_root,f'lyft_dbinfos_train.pkl'))

# DataLoader settings used by the cells below.
num_workers = 8
batch_size = 6

# NOTE(review): num_classes is reassigned to 3 in the model-setup cell before
# anything in this notebook reads it.
num_classes = 5

# Ask PyTorch to release any cached GPU memory before the run starts.
torch.cuda.empty_cache()

### Set the random seed

In [None]:
# Seed every random number generator in use (Python, NumPy, PyTorch CPU and CUDA)
# so that repeated runs start from the same state.
SEED = 2023
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# cuDNN autotuning (benchmark=True) may select a different convolution algorithm
# from run to run, which works against deterministic=True. Keep it disabled while
# reproducibility is the goal. Setting these flags is harmless on CPU-only machines.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

if torch.cuda.is_available():
    device = torch.device("cuda:0")
    torch.cuda.manual_seed_all(SEED)
else:
    device = torch.device("cpu")

### Load the augmented data

In [None]:
# Custom collate function for torch's DataLoader
def collate_batch(data):
    """Gather a list of per-sample dicts into one dict of per-field lists.

    Nothing is stacked: every output value is a plain Python list holding one
    entry per sample, in the order the samples were given.

    Args:
        data: iterable of sample dicts with the keys 'pc', 'img', 'gt_labels',
            'gt_names', 'gt_bboxes_3d', 'difficulty' and 'calib'.

    Returns:
        dict with the keys 'batched_pts', 'batched_img_info', 'batched_labels',
        'batched_names', 'batched_gt_bboxes', 'batched_difficulty' and
        'batched_calib_info'. The 'pc', 'gt_labels', 'gt_bboxes_3d' and
        'difficulty' entries are wrapped with ``torch.from_numpy``; the
        remaining fields are passed through untouched.
    """
    # (output key, per-sample key, wrap with torch.from_numpy?)
    field_plan = (
        ('batched_pts', 'pc', True),
        ('batched_img_info', 'img', False),
        ('batched_labels', 'gt_labels', True),
        ('batched_names', 'gt_names', False),  # List(str)
        ('batched_gt_bboxes', 'gt_bboxes_3d', True),
        ('batched_difficulty', 'difficulty', True),
        ('batched_calib_info', 'calib', False),
    )

    collated = {out_key: [] for out_key, _, _ in field_plan}
    for sample in data:
        for out_key, sample_key, as_tensor in field_plan:
            value = sample[sample_key]
            collated[out_key].append(torch.from_numpy(value) if as_tensor else value)

    return collated

In [None]:
# Build the datasets for the training and validation splits via data_augment.DataSet.
train_data = data_augment.DataSet(dataset_root=dataset_root, identifier='train')
val_data = data_augment.DataSet(dataset_root=dataset_root, identifier='val')

# Options shared by both loaders; collate_batch keeps every field as a per-sample list.
loader_kwargs = dict(
    batch_size=batch_size,
    shuffle=True,
    num_workers=num_workers,
    drop_last=False,
    collate_fn=collate_batch,
)

# The training loader must iterate over the training split. It previously wrapped
# val_data, so the model was never trained on the training samples.
train_dataloader = DataLoader(dataset=train_data, **loader_kwargs)

val_dataloader = DataLoader(dataset=val_data, **loader_kwargs)

### PointPillars Part

In [None]:
import gc

# Number of object classes predicted by the detector (replaces the earlier value).
num_classes = 3

# Place the model on the device selected in the seeding cell instead of calling
# .cuda() unconditionally, so this cell does not fail outright without a GPU.
pointpillars = PointPillars(num_classes=num_classes).to(device)

criterion = Losses()

# Voxelization settings; they are only defined here and not consumed in this cell.
max_num_points = 32
point_cloud_range = [0, -39.68, -3, 69.12, 39.68, 1]
max_voxels = (16000, 40000)
voxel_size = [0.16, 0.16, 4]

learning_rate = 0.0003

# OneCycleLR.total_steps counts scheduler steps, and the training loops call
# scheduler.step() once per batch. The cycle length is therefore
# (batches per epoch) * (epochs), not (samples) * (epochs); the latter stretched
# the schedule by a factor of batch_size so the cycle never completed.
scheduled_epochs = 160
max_iters = len(train_dataloader) * scheduled_epochs

optimizer = torch.optim.AdamW(
    params=pointpillars.parameters(),
    lr=learning_rate,
    betas=(0.95, 0.99),
    weight_decay=0.01,
)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=learning_rate * 10,
    total_steps=max_iters,
    pct_start=0.4,
    anneal_strategy='cos',
    cycle_momentum=True,
    base_momentum=0.95 * 0.895,
    max_momentum=0.95,
    div_factor=10,
)

# Drop unreachable Python objects, then release cached GPU memory.
gc.collect()
torch.cuda.empty_cache()

# Directory where checkpoints are stored
saved_ckpt_path = 'checkpoints'
os.makedirs(saved_ckpt_path, exist_ok=True)

### Training Part

In [None]:
def _move_batch_to_device(data_dict, device):
    """Move every tensor inside the per-field lists of ``data_dict`` to ``device`` (in place)."""
    for items in data_dict.values():
        for idx, item in enumerate(items):
            if torch.is_tensor(item):
                items[idx] = item.to(device)
    return data_dict


def _compute_batch_loss(model, criterion, data_dict, n_classes):
    """Run ``model`` on one collated batch and return the scalar loss from ``criterion``.

    The model is called with ``mode='train'`` because that call returns the anchor
    target dict together with the three prediction maps.
    """
    pred_bbox_cls, pred_bbox_loc, pred_bbox_dir, anchor_target_dict = model(
        batched_pts=data_dict['batched_pts'],
        mode='train',
        batched_gt_bboxes=data_dict['batched_gt_bboxes'],
        batched_gt_labels=data_dict['batched_labels'])

    # Flatten the prediction maps to one row per anchor:
    # class scores, 7 box regression values, 2 direction scores.
    pred_bbox_cls = pred_bbox_cls.permute(0, 2, 3, 1).reshape(-1, n_classes)
    pred_bbox_loc = pred_bbox_loc.permute(0, 2, 3, 1).reshape(-1, 7)
    pred_bbox_dir = pred_bbox_dir.permute(0, 2, 3, 1).reshape(-1, 2)

    anchor_labels = anchor_target_dict['batched_labels'].reshape(-1)
    anchor_label_weights = anchor_target_dict['batched_label_weights'].reshape(-1)
    anchor_bbox_loc = anchor_target_dict['batched_bbox_reg'].reshape(-1, 7)
    anchor_dir_labels = anchor_target_dict['batched_dir_labels'].reshape(-1)

    # Anchors whose label is a valid class index contribute to box/direction losses.
    pos_idx = (anchor_labels >= 0) & (anchor_labels < n_classes)
    pred_bbox_loc = pred_bbox_loc[pos_idx]
    anchor_bbox_loc = anchor_bbox_loc[pos_idx]

    # Sin-difference encoding of the last box channel:
    #   sin(a - b) = sin(a) * cos(b) - cos(a) * sin(b)
    # Both products must be built from the ORIGINAL angles. Take copies first;
    # previously the prediction was overwritten before the target term read it.
    pred_theta = pred_bbox_loc[:, -1].clone()
    target_theta = anchor_bbox_loc[:, -1].clone()
    pred_bbox_loc[:, -1] = torch.sin(pred_theta) * torch.cos(target_theta)
    anchor_bbox_loc[:, -1] = torch.cos(pred_theta) * torch.sin(target_theta)

    weighted = anchor_label_weights > 0
    pred_bbox_cls = pred_bbox_cls[weighted]
    pred_bbox_dir = pred_bbox_dir[pos_idx]
    anchor_dir_labels = anchor_dir_labels[pos_idx]

    # NOTE(review): this count is taken before negative labels are remapped, so
    # negative labels are included in it as well; kept as in the original code.
    num_cls_pos = (anchor_labels < n_classes).sum()
    # Negative labels are remapped to the extra index n_classes.
    anchor_labels[anchor_labels < 0] = n_classes
    anchor_labels = anchor_labels[weighted]

    loss_out = criterion(pred_bbox_cls=pred_bbox_cls,
                         pred_bbox_loc=pred_bbox_loc,
                         pred_bbox_dir=pred_bbox_dir,
                         anchor_labels=anchor_labels,
                         num_cls_pos=num_cls_pos,
                         anchor_bbox_loc=anchor_bbox_loc,
                         anchor_bbox_dir_labels=anchor_dir_labels,
                         num_classes=n_classes)
    # The notebook's own training loop reads loss_dict['total_loss'], so the
    # criterion returns a dict there; a bare tensor is accepted too.
    if isinstance(loss_out, dict):
        return loss_out['total_loss']
    return loss_out


def train(model, optimizer, criterion, scheduler, trainloader, device, valloader, max_epoch, ckpt_freq_epoch, saved_ckpt_path, n_classes=None):
    """Train ``model`` for ``max_epoch`` epochs, checkpointing and validating periodically.

    Args:
        model: detector called as ``model(batched_pts=..., mode='train',
            batched_gt_bboxes=..., batched_gt_labels=...)``; it must return the
            class, box and direction prediction maps plus the anchor target dict.
        optimizer: optimizer stepped once per training batch.
        criterion: loss callable; may return a tensor or a dict with 'total_loss'.
        scheduler: learning-rate scheduler stepped once per training batch.
        trainloader: iterable of collated batch dicts (see ``collate_batch``).
        device: device the batch tensors are moved to before the forward pass.
        valloader: iterable of collated batch dicts, or None to skip validation.
        max_epoch: number of epochs to run.
        ckpt_freq_epoch: a checkpoint is written every this many epochs.
        saved_ckpt_path: directory receiving ``epoch_<n>.pth`` files.
        n_classes: number of classes; defaults to the notebook-level ``num_classes``.

    Returns:
        The same ``model`` object that was passed in.
    """
    if n_classes is None:
        n_classes = num_classes

    for epoch in range(max_epoch):
        model.train()
        for data_dict in tqdm(trainloader):
            _move_batch_to_device(data_dict, device)

            optimizer.zero_grad()
            loss = _compute_batch_loss(model, criterion, data_dict, n_classes)
            loss.backward()
            optimizer.step()
            scheduler.step()

        # Save a checkpoint of the model being trained (the `model` argument, not
        # a notebook-level global) every ckpt_freq_epoch epochs.
        if (epoch + 1) % ckpt_freq_epoch == 0:
            torch.save(model.state_dict(), os.path.join(saved_ckpt_path, f'epoch_{epoch+1}.pth'))

        # Validation only runs on odd epoch indices (every second epoch).
        if epoch % 2 == 0:
            continue
        # If a validation dataloader is given, compute the validation loss.
        if valloader is not None:
            model.eval()
            val_losses = []
            with torch.no_grad():
                # Batches are plain dicts, so they must not be tuple-unpacked.
                for data_dict in valloader:
                    _move_batch_to_device(data_dict, device)
                    val_losses.append(float(_compute_batch_loss(model, criterion, data_dict, n_classes)))
            if val_losses:
                mean_val_loss = sum(val_losses) / len(val_losses)
                print(f'epoch {epoch + 1}: mean validation loss {mean_val_loss:.4f}')
    return model

In [None]:
# Use the first GPU when CUDA is usable, otherwise fall back to the CPU.
# cuDNN flags are configured in the seeding cell near the top of the notebook;
# they are intentionally not overridden here so that cell stays the single place
# controlling the determinism/benchmark trade-off.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

In [None]:
# # model, optimizer, criterion, scheduler, trainloader, device, valloader, max_epoch, ckpt_freq_epoch, saved_ckpt_path
# model = train(model=pointpillars, 
#               optimizer=optimizer, 
#               criterion=criterion,
#               scheduler=scheduler, 
#               trainloader=train_dataloader, 
#               device=device, 
#               valloader=val_dataloader, 
#               max_epoch=30,
#               ckpt_freq_epoch=20,
#               saved_ckpt_path=saved_ckpt_path)

In [None]:
# Short manual training run: the same per-batch steps as train(), written inline.
max_epoch = 1
no_cuda = False  # NOTE(review): not read anywhere in this cell

# Make sure layers such as dropout/batch-norm are in training mode.
pointpillars.train()

for epoch in range(max_epoch):
    print(epoch)
    for data_dict in tqdm(train_dataloader):
        # Move every tensor of the batch to the selected device.
        for batch_items in data_dict.values():
            for j, item in enumerate(batch_items):
                if torch.is_tensor(item):
                    batch_items[j] = item.to(device)

        optimizer.zero_grad()
        # Fields of the current batch
        batched_pts = data_dict['batched_pts']
        batched_gt_bboxes = data_dict['batched_gt_bboxes']
        batched_labels = data_dict['batched_labels']
        batched_difficulty = data_dict['batched_difficulty']

        pred_bbox_cls, pred_bbox_loc, bbox_dir_cls_pred, anchor_target_dict = pointpillars(
            batched_pts=batched_pts,
            mode='train',
            batched_gt_bboxes=batched_gt_bboxes,
            batched_gt_labels=batched_labels)

        # Flatten the prediction maps to one row per anchor:
        # class scores, 7 box regression values, 2 direction scores.
        pred_bbox_cls = pred_bbox_cls.permute(0, 2, 3, 1).reshape(-1, num_classes)
        pred_bbox_loc = pred_bbox_loc.permute(0, 2, 3, 1).reshape(-1, 7)
        bbox_dir_cls_pred = bbox_dir_cls_pred.permute(0, 2, 3, 1).reshape(-1, 2)

        anchor_bbox_labels = anchor_target_dict['batched_labels'].reshape(-1)
        anchor_label_weights = anchor_target_dict['batched_label_weights'].reshape(-1)
        anchor_bbox_loc = anchor_target_dict['batched_bbox_reg'].reshape(-1, 7)
        anchor_dir_labels = anchor_target_dict['batched_dir_labels'].reshape(-1)

        # Anchors whose label is a valid class index contribute to box/direction losses.
        pos_idx = (anchor_bbox_labels >= 0) & (anchor_bbox_labels < num_classes)
        pred_bbox_loc = pred_bbox_loc[pos_idx]
        anchor_bbox_loc = anchor_bbox_loc[pos_idx]

        # Sin-difference encoding of the last box channel:
        #   sin(a - b) = sin(a) * cos(b) - cos(a) * sin(b)
        # Both products must be built from the ORIGINAL angles. Take copies first;
        # previously the prediction was overwritten before the target term read it.
        pred_theta = pred_bbox_loc[:, -1].clone()
        target_theta = anchor_bbox_loc[:, -1].clone()
        pred_bbox_loc[:, -1] = torch.sin(pred_theta) * torch.cos(target_theta)
        anchor_bbox_loc[:, -1] = torch.cos(pred_theta) * torch.sin(target_theta)

        pred_bbox_cls = pred_bbox_cls[anchor_label_weights > 0]
        bbox_dir_cls_pred = bbox_dir_cls_pred[pos_idx]
        anchor_dir_labels = anchor_dir_labels[pos_idx]

        # NOTE(review): this count is taken before negative labels are remapped, so
        # negative labels are included in it as well; kept as in the original code.
        num_cls_pos = (anchor_bbox_labels < num_classes).sum()
        # Negative labels are remapped to the extra index num_classes.
        anchor_bbox_labels[anchor_bbox_labels < 0] = num_classes
        anchor_bbox_labels = anchor_bbox_labels[anchor_label_weights > 0]

        loss_dict = criterion(pred_bbox_cls=pred_bbox_cls,
                              pred_bbox_loc=pred_bbox_loc,
                              pred_bbox_dir=bbox_dir_cls_pred,
                              anchor_labels=anchor_bbox_labels,
                              num_cls_pos=num_cls_pos,
                              anchor_bbox_loc=anchor_bbox_loc,
                              anchor_bbox_dir_labels=anchor_dir_labels,
                              num_classes=num_classes)
        loss = loss_dict['total_loss']
        loss.backward()
        optimizer.step()
        scheduler.step()
        