In [14]:
import torch.optim as optim
from pathlib import Path
from utils import config
# cfg = get_parser()
import torch.nn as nn
from tqdm import tqdm
import numpy as np
import importlib
import logging
import shutil
import spconv
import json
import yaml
import time
import torch
import os

from utils.evaluate_completion import get_eval_mask
from torch.utils.checkpoint import checkpoint
import models.model_utils as model_utils
from utils.np_ioueval import iouEval
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter

args = config.cfg

In [15]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
dtype = torch.float32  # Tensor type to be used
os.environ["CUDA_VISIBLE_DEVICES"] = args['gpu']



LEARNING_RATE_CLIP = 1e-6
MOMENTUM_ORIGINAL = 0.5
MOMENTUM_DECCAY = 0.5
BN_MOMENTUM_MAX = 0.001
NUM_CLASS_SEG = args['DATA']['classes_seg']
NUM_CLASS_COMPLET = args['DATA']['classes_completion']

exp_name = args['log_dir']

seg_head = importlib.import_module('models.'+args['Segmentation']['model_name'])
seg_model = seg_head.get_model

complet_head = importlib.import_module('models.'+args['Completion']['model_name'])
complet_model = complet_head.get_model

print(args['DATA']['dataset'])
if args['DATA']['dataset'] == 'SemanticKITTI':
    dataset = importlib.import_module('kitti_dataset')

SemanticKITTI


In [20]:
data_dir = '/media/neofelis/Jingyu-SSD/carla_dataset/Train'

coordinate_type = "cartesian"
cylindrical = coordinate_type=="cylindrical"
T = 1 #single sweep
B = 4 # Matching paper

from dataset import CarlaDataset, collate_fn_test

carla_ds = CarlaDataset(directory=data_dir, device=device, num_frames=T, cylindrical=cylindrical)
train_dataloader = DataLoader(carla_ds, batch_size=B, shuffle=True, collate_fn=collate_fn_test)

In [18]:
class J3SC_Net(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.seg_head = seg_model(args)
        self.complet_head = complet_model(args)
        self.voxelpool = model_utils.VoxelPooling(args)
        self.seg_sigma = nn.Parameter(torch.Tensor(1).uniform_(0.2, 1), requires_grad=True)
        self.complet_sigma = nn.Parameter(torch.Tensor(1).uniform_(0.2, 1), requires_grad=True)

    def forward(self, x):
        seg_inputs, complet_inputs, _ = x

        '''Segmentation Head'''
        seg_output, feat = self.seg_head(seg_inputs)
        torch.cuda.empty_cache()

        '''Completion Head'''
        coords = complet_inputs['complet_coords']
        coords = coords[:, [0, 3, 2, 1]]

        # if args['DATA']['dataset'] == 'SemanticKITTI':
        #     coords[:, 3] += 1  # TODO SemanticKITTI will generate [256,256,31]
        # elif args['DATA']['dataset'] == 'SemanticPOSS':
        #     coords[:, 3][coords[:, 3] > 31] = 31

        if args['Completion']['feeding'] == 'both':
            feeding = torch.cat([seg_output, feat],1)
        elif args['Completion']['feeding'] == 'feat':
            feeding = feat
        else:
            feeding = seg_output
        features = self.voxelpool(invoxel_xyz=complet_inputs['complet_invoxel_features'][:, :, :-1],
                                    invoxel_map=complet_inputs['complet_invoxel_features'][:, :, -1].long(),
                                    src_feat=feeding,
                                    voxel_center=complet_inputs['voxel_centers'])
        if self.args['Completion']['no_fuse_feat']:
            features[...] = 1
            features = features.detach()

        batch_complet = spconv.SparseConvTensor(features.float(), coords.int(), args['Completion']['full_scale'], args['TRAIN']['batch_size'])
        batch_complet = dataset.sparse_tensor_augmentation(batch_complet, complet_inputs['state'])

        if args['GENERAL']['debug']:
            model_utils.check_occupation(complet_inputs['complet_input'], batch_complet.dense())

        complet_output = self.complet_head(batch_complet)
        torch.cuda.empty_cache()

        return seg_output, complet_output, [self.seg_sigma, self.complet_sigma]

def bn_momentum_adjust(m, momentum):
    if isinstance(m, torch.nn.BatchNorm2d) or isinstance(m, torch.nn.BatchNorm1d):
        m.momentum = momentum



In [25]:
classifier = J3SC_Net(args).cuda()
criteria = model_utils.Loss(args).cuda()

training_epochs = args['TRAIN']['epochs']
# training_epoch = model_utils.checkpoint_restore(classifier, experiment_dir, True, train_from=args['TRAIN']['train_from'])
optimizer = optim.Adam(classifier.parameters(), lr=args['TRAIN']['learning_rate'], weight_decay=1e-4)

global_epoch = 0
best_iou_sem_complt = 0
best_iou_complt = 0
best_iou_seg = 0

for epoch in range(training_epochs):
    classifier.train()
    lr = max(args['TRAIN']['learning_rate'] * (args['TRAIN']['lr_decay'] ** (epoch // args['TRAIN']['decay_step'])), LEARNING_RATE_CLIP)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    momentum = max(MOMENTUM_ORIGINAL * (MOMENTUM_DECCAY ** (epoch // args['TRAIN']['decay_step'])), BN_MOMENTUM_MAX)
    if momentum < 0.01:
        momentum = 0.01
    classifier = classifier.apply(lambda x: bn_momentum_adjust(x, momentum))
    train_loss = 0
    with tqdm(total=len(train_dataloader)) as pbar:
        for i, batch in enumerate(train_dataloader):
            optimizer.zero_grad()
            # seg_label = batch
            print(len(batch[0]))


    

  0%|          | 0/8100 [00:00<?, ?it/s]

4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
4
