In [6]:
import torch.optim as optim
# from pathlib import Path
# from utils import config
# import torch.nn as nn
# from tqdm import tqdm
# import numpy as np
# import importlib
# import logging
# import shutil
# import spconv
# import json
# import yaml
# import time
# import torch
# import os

# from utils.evaluate_completion import get_eval_mask
# from torch.utils.checkpoint import checkpoint
# import models.model_utils as model_utils
# from utils.np_ioueval import iouEval

# args = config.cfg

In [None]:
# Imported here so this cell runs on a fresh kernel: the `import torch` line in the
# import cell is commented out, and `import torch.optim as optim` binds only `optim`.
import torch

# Compute device: CUDA when available, otherwise CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
dtype = torch.float32  # Tensor type to be used

In [None]:
def main(args):
    '''main'''

    LEARNING_RATE_CLIP = 1e-6
    MOMENTUM_ORIGINAL = 0.5
    MOMENTUM_DECCAY = 0.5
    BN_MOMENTUM_MAX = 0.001
    NUM_CLASS_SEG = args['DATA']['classes_seg']
    NUM_CLASS_COMPLET = args['DATA']['classes_completion']

    exp_name = args['log_dir']

    if exp_name is not None:
        experiment_dir = './log/' + exp_name
        experiment_dir = Path(experiment_dir)
        experiment_dir.mkdir(exist_ok=True)
        experiment_dir = str(experiment_dir)
    else:
        experiment_dir = Path('./log/')
        experiment_dir.mkdir(exist_ok=True)
        experiment_dir = experiment_dir.joinpath('temp')
        experiment_dir.mkdir(exist_ok=True)
        experiment_dir = str(experiment_dir)

    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)

    with open(os.path.join(experiment_dir, 'args.txt'), 'w') as f:
        json.dump(args, f, indent=2)

    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/train.txt'%(experiment_dir))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    def log_string(str):
        logger.info(str)
        print(str)

    shutil.copy('train.py', str(experiment_dir))
    shutil.copy('kitti_dataset.py', str(experiment_dir))
    # shutil.copy('carla_dataset.py', str(experiment_dir))
    shutil.copy('poss_dataset.py', str(experiment_dir))
    shutil.copy('models/model_utils.py', str(experiment_dir))
    shutil.copy('models/'+args['Segmentation']['model_name'] + '.py', str(experiment_dir))
    shutil.copy('models/'+args['Completion']['model_name'] + '.py', str(experiment_dir))

    seg_head = importlib.import_module('models.'+args['Segmentation']['model_name'])
    seg_model = seg_head.get_model

    complet_head = importlib.import_module('models.'+args['Completion']['model_name'])
    complet_model = complet_head.get_model

    if args['DATA']['dataset'] == 'SemanticKITTI':
        dataset = importlib.import_module('kitti_dataset')
    elif args['DATA']['dataset'] == 'SemanticPOSS':
        dataset = importlib.import_module('poss_dataset')
    else:
        raise TypeError

    class J3SC_Net(nn.Module):
        """Joint segmentation + completion network.

        A segmentation head runs first. Its output and/or intermediate features are
        voxel-pooled and fed, as a sparse tensor, to a completion head. The module
        also owns two learnable scalars (``seg_sigma``, ``complet_sigma``) that are
        returned alongside the predictions.

        Args:
            args: nested configuration dict. It is stored as ``self.args`` and is the
                only config source used by ``forward``.
        """

        def __init__(self, args):
            super().__init__()
            self.args = args
            self.seg_head = seg_model(args)
            self.complet_head = complet_model(args)
            self.voxelpool = model_utils.VoxelPooling(args)
            # Learnable per-task scalars, initialised uniformly in [0.2, 1).
            # NOTE(review): presumably consumed by the loss for task weighting;
            # confirm against model_utils.Loss.
            self.seg_sigma = nn.Parameter(torch.empty(1).uniform_(0.2, 1), requires_grad=True)
            self.complet_sigma = nn.Parameter(torch.empty(1).uniform_(0.2, 1), requires_grad=True)

        def forward(self, x):
            """Run the segmentation head, then the completion head.

            Args:
                x: 3-tuple ``(seg_inputs, complet_inputs, _)``; the third item is
                    ignored. ``complet_inputs`` is indexed with the keys
                    'complet_coords', 'complet_invoxel_features', 'voxel_centers',
                    'state' (and 'complet_input' when debug is enabled).

            Returns:
                ``(seg_output, complet_output, [seg_sigma, complet_sigma])``.
            """
            cfg = self.args  # use the constructor's config, not an enclosing variable
            seg_inputs, complet_inputs, _ = x

            # --- Segmentation head: predictions plus an intermediate feature map.
            seg_output, feat = self.seg_head(seg_inputs)
            torch.cuda.empty_cache()

            # --- Completion head
            # Keep column 0 and reverse the order of columns 1..3.
            coords = complet_inputs['complet_coords'][:, [0, 3, 2, 1]]
            # NOTE: dataset-specific z-coordinate adjustments were disabled here
            # (original TODO: SemanticKITTI generates a [256, 256, 31] grid).

            # Select which per-point signal is pooled into the voxels.
            feeding_mode = cfg['Completion']['feeding']
            if feeding_mode == 'both':
                feeding = torch.cat([seg_output, feat], 1)
            elif feeding_mode == 'feat':
                feeding = feat
            else:
                feeding = seg_output

            # Last channel of the in-voxel features is the point->voxel map; the rest are coordinates.
            invoxel = complet_inputs['complet_invoxel_features']
            features = self.voxelpool(invoxel_xyz=invoxel[:, :, :-1],
                                      invoxel_map=invoxel[:, :, -1].long(),
                                      src_feat=feeding,
                                      voxel_center=complet_inputs['voxel_centers'])
            if cfg['Completion']['no_fuse_feat']:
                # Ablation: constant ones, detached from the graph. Built as a new tensor
                # instead of writing in place, so no aliased storage or saved autograd
                # tensor is mutated.
                features = torch.ones_like(features)

            batch_complet = spconv.SparseConvTensor(features.float(), coords.int(),
                                                    cfg['Completion']['full_scale'],
                                                    cfg['TRAIN']['batch_size'])
            batch_complet = dataset.sparse_tensor_augmentation(batch_complet, complet_inputs['state'])

            if cfg['GENERAL']['debug']:
                model_utils.check_occupation(complet_inputs['complet_input'], batch_complet.dense())

            complet_output = self.complet_head(batch_complet)
            torch.cuda.empty_cache()

            return seg_output, complet_output, [self.seg_sigma, self.complet_sigma]

    def bn_momentum_adjust(m, momentum):
        """Set ``m.momentum`` when ``m`` is a BatchNorm1d/BatchNorm2d layer.

        Any other module, including BatchNorm3d, is left untouched. The one-module
        signature suits ``nn.Module.apply`` (with ``momentum`` bound by the caller);
        the caller is not visible here.
        """
        supported = (torch.nn.BatchNorm2d, torch.nn.BatchNorm1d)
        if not isinstance(m, supported):
            return
        m.momentum = momentum

    classifier = J3SC_Net(args).cuda()
    criteria = model_utils.Loss(args).cuda()