# DenseFusion

Code from https://github.com/j96w/DenseFusion

Train 코드를 살펴본 후, Dataset, Model, Loss 순서로 알아본다.

## Train

In [None]:
# ./experiments/scripts/train_ycb.sh
"""
experiments 폴더에는 dataset마다 train과 evaluation을 위한 shell script가 담겨있다. 
아래는 ycb video dataset을 사용해 train하는 shell script 예시이다.
"""

#!/bin/bash

# Echo every command and abort on the first failure.
set -x
set -e

export PYTHONUNBUFFERED="True"
export CUDA_VISIBLE_DEVICES=0

# Run train.py with the dataset type and the dataset root path.
# The backslash must be the very last character on the line: a trailing space
# after it would end the command there and drop the --dataset_root argument.
python3 ./tools/train.py --dataset ycb \
  --dataset_root ./datasets/ycb/YCB_Video_Dataset

In [None]:
# ./tools/train.py

# --------------------------------------------------------
# DenseFusion 6D Object Pose Estimation by Iterative Dense Fusion
# Licensed under The MIT License [see LICENSE for details]
# Written by Chen
# --------------------------------------------------------

import _init_paths
import argparse
import os
import random
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
from datasets.ycb.dataset import PoseDataset as PoseDataset_ycb
from datasets.linemod.dataset import PoseDataset as PoseDataset_linemod
from lib.network import PoseNet, PoseRefineNet
from lib.loss import Loss
from lib.loss_refiner import Loss_refine
from lib.utils import setup_logger

# Command-line options for training.
# Every numeric option declares an explicit type: without it argparse hands
# back strings for values given on the command line, which breaks the
# arithmetic (`opt.lr *= opt.lr_rate`) and the margin comparisons in main().
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='ycb', help='ycb or linemod')
parser.add_argument('--dataset_root', type=str, default='', help='dataset root dir (YCB_Video_Dataset or Linemod_preprocessed)')
parser.add_argument('--batch_size', type=int, default=8, help='batch size')
parser.add_argument('--workers', type=int, default=10, help='number of data loading workers')
parser.add_argument('--lr', type=float, default=0.0001, help='learning rate')
parser.add_argument('--lr_rate', type=float, default=0.3, help='learning rate decay rate')
# Passed to the loss as `w`; the original notes describe it as the balancing
# hyperparameter of the confidence regularisation term.
parser.add_argument('--w', type=float, default=0.015, help='confidence regularization weight used by the loss')
parser.add_argument('--w_rate', type=float, default=0.3, help='decay rate of w')
parser.add_argument('--decay_margin', type=float, default=0.016, help='margin to decay lr & w')
# Once the best validation distance drops below this margin, training switches to the refiner.
parser.add_argument('--refine_margin', type=float, default=0.013, help='margin to start the training of iterative refinement')
# Bound of the random translation offset added per axis to the training data.
parser.add_argument('--noise_trans', type=float, default=0.03, help='range of the random noise of translation added to the training data')
parser.add_argument('--iteration', type=int, default=2, help='number of refinement iterations')
parser.add_argument('--nepoch', type=int, default=500, help='max number of epochs to train')
parser.add_argument('--resume_posenet', type=str, default='', help='resume PoseNet model')
parser.add_argument('--resume_refinenet', type=str, default='', help='resume PoseRefineNet model')
parser.add_argument('--start_epoch', type=int, default=1, help='which epoch to start')
# Parse the command line once at module level; main() reads this namespace
# and also stores derived settings on it.
opt = parser.parse_args()


def _build_data_and_losses(opt):
    """Build the data loaders and loss modules for the current training stage.

    Reads ``opt.dataset``, ``opt.num_points``, ``opt.dataset_root``,
    ``opt.noise_trans``, ``opt.refine_start`` and ``opt.workers``. As a side
    effect stores ``opt.sym_list`` and ``opt.num_points_mesh`` (both queried
    from the training dataset) and prints a short summary.

    Returns:
        Tuple ``(dataloader, testdataloader, criterion, criterion_refine)``.

    Raises:
        ValueError: if ``opt.dataset`` is neither 'ycb' nor 'linemod'.
    """
    if opt.dataset == 'ycb':
        dataset_cls = PoseDataset_ycb
    elif opt.dataset == 'linemod':
        dataset_cls = PoseDataset_linemod
    else:
        raise ValueError('Unknown dataset: {0}'.format(opt.dataset))

    # Training split: augmentation on, translation noise as configured.
    dataset = dataset_cls('train', opt.num_points, True, opt.dataset_root, opt.noise_trans, opt.refine_start)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=True, num_workers=opt.workers)
    # Validation split: augmentation off, no translation noise.
    test_dataset = dataset_cls('test', opt.num_points, False, opt.dataset_root, 0.0, opt.refine_start)
    testdataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=opt.workers)

    opt.sym_list = dataset.get_sym_list()
    opt.num_points_mesh = dataset.get_num_points_mesh()

    print('>>>>>>>>----------Dataset loaded!---------<<<<<<<<\nlength of the training set: {0}\nlength of the testing set: {1}\nnumber of sample points on mesh: {2}\nsymmetry object list: {3}'.format(len(dataset), len(test_dataset), opt.num_points_mesh, opt.sym_list))

    criterion = Loss(opt.num_points_mesh, opt.sym_list)
    criterion_refine = Loss_refine(opt.num_points_mesh, opt.sym_list)
    return dataloader, testdataloader, criterion, criterion_refine


def main():
    """Train the pose estimator and, once it is good enough, the pose refiner.

    Training can start in three ways:
      1. from scratch;
      2. from a saved estimator (``--resume_posenet``);
      3. from a saved refiner (``--resume_refinenet``), which starts directly
         in the refinement stage with lr and w already decayed.

    After every epoch the model is validated. When the best validation
    distance drops below ``opt.decay_margin`` lr and w are decayed once; when
    it drops below ``opt.refine_margin`` training switches to the refiner and
    the datasets and losses are rebuilt for that stage.
    """
    # Draw a seed and apply it to Python's and PyTorch's RNGs.
    opt.manualSeed = random.randint(1, 10000)
    random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)

    # Per-dataset settings.
    if opt.dataset == 'ycb':
        opt.num_objects = 21  # number of object classes in the dataset
        opt.num_points = 1000  # number of points on the input pointcloud
        opt.outf = 'trained_models/ycb'  # folder to save trained models
        opt.log_dir = 'experiments/logs/ycb'  # folder to save logs
        opt.repeat_epoch = 1  # number of repeat times for one epoch training
    elif opt.dataset == 'linemod':
        opt.num_objects = 13
        opt.num_points = 500
        opt.outf = 'trained_models/linemod'
        opt.log_dir = 'experiments/logs/linemod'
        opt.repeat_epoch = 20
    else:
        print('Unknown dataset')
        return

    # Make sure the output folders exist; otherwise os.listdir / torch.save
    # below fail on a fresh checkout.
    os.makedirs(opt.outf, exist_ok=True)
    os.makedirs(opt.log_dir, exist_ok=True)

    # estimator: model without refinement; refiner: iterative refinement model.
    estimator = PoseNet(num_points=opt.num_points, num_obj=opt.num_objects)
    estimator.cuda()
    refiner = PoseRefineNet(num_points=opt.num_points, num_obj=opt.num_objects)
    refiner.cuda()

    if opt.resume_posenet != '':  # case 2
        estimator.load_state_dict(torch.load('{0}/{1}'.format(opt.outf, opt.resume_posenet)))

    if opt.resume_refinenet != '':  # case 3
        refiner.load_state_dict(torch.load('{0}/{1}'.format(opt.outf, opt.resume_refinenet)))
        opt.refine_start = True
        opt.decay_start = True
        opt.lr *= opt.lr_rate
        opt.w *= opt.w_rate
        opt.batch_size = int(opt.batch_size / opt.iteration)
        optimizer = optim.Adam(refiner.parameters(), lr=opt.lr)
    else:  # cases 1 and 2
        opt.refine_start = False
        opt.decay_start = False
        optimizer = optim.Adam(estimator.parameters(), lr=opt.lr)

    dataloader, testdataloader, criterion, criterion_refine = _build_data_and_losses(opt)

    # np.inf rather than np.Inf: the capitalised alias was removed in NumPy 2.0.
    best_test = np.inf

    # When training starts from the first epoch, delete the existing log files.
    if opt.start_epoch == 1:
        for log in os.listdir(opt.log_dir):
            os.remove(os.path.join(opt.log_dir, log))

    st_time = time.time()

    for epoch in range(opt.start_epoch, opt.nepoch):
        # ---------------------------- Train ----------------------------
        logger = setup_logger('epoch%d' % epoch, os.path.join(opt.log_dir, 'epoch_%d_log.txt' % epoch))
        logger.info('Train time {0}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)) + ', ' + 'Training started'))
        train_count = 0
        train_dis_avg = 0.0  # running sum of `dis` over the current batch
        if opt.refine_start:
            # Refinement stage: the estimator is put in eval mode, only the
            # refiner is in train mode.
            estimator.eval()
            refiner.train()
        else:
            estimator.train()
        optimizer.zero_grad()

        for rep in range(opt.repeat_epoch):
            for i, data in enumerate(dataloader, 0):
                points, choose, img, target, model_points, idx = (Variable(t).cuda() for t in data)

                # The estimator returns per-point rotation, translation and
                # confidence predictions plus the colour embedding `emb`.
                pred_r, pred_t, pred_c, emb = estimator(img, points, choose, idx)
                loss, dis, new_points, new_target = criterion(pred_r, pred_t, pred_c, target, model_points, idx, points, opt.w, opt.refine_start)

                if opt.refine_start:
                    # Each refinement iteration feeds the points/target
                    # returned by the previous loss call back into the refiner.
                    for ite in range(0, opt.iteration):
                        pred_r, pred_t = refiner(new_points, emb, idx)
                        dis, new_points, new_target = criterion_refine(pred_r, pred_t, new_target, model_points, idx, new_points)
                        dis.backward()
                else:
                    loss.backward()

                train_dis_avg += dis.item()
                train_count += 1

                # The loader yields one frame at a time; gradients accumulate
                # and the optimizer steps once every `batch_size` frames.
                if train_count % opt.batch_size == 0:
                    logger.info('Train time {0} Epoch {1} Batch {2} Frame {3} Avg_dis:{4}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)), epoch, int(train_count / opt.batch_size), train_count, train_dis_avg / opt.batch_size))
                    optimizer.step()
                    optimizer.zero_grad()
                    train_dis_avg = 0

                # Periodic checkpoint of whichever model is being trained.
                if train_count != 0 and train_count % 1000 == 0:
                    if opt.refine_start:
                        torch.save(refiner.state_dict(), '{0}/pose_refine_model_current.pth'.format(opt.outf))
                    else:
                        torch.save(estimator.state_dict(), '{0}/pose_model_current.pth'.format(opt.outf))

        print('>>>>>>>>----------epoch {0} train finish---------<<<<<<<<'.format(epoch))

        # -------------------------- Validation --------------------------
        logger = setup_logger('epoch%d_test' % epoch, os.path.join(opt.log_dir, 'epoch_%d_test_log.txt' % epoch))
        logger.info('Test time {0}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)) + ', ' + 'Testing started'))
        test_dis = 0.0
        test_count = 0
        estimator.eval()
        refiner.eval()

        # No backward pass happens during validation, so skip building
        # autograd graphs to save memory and time.
        with torch.no_grad():
            for j, data in enumerate(testdataloader, 0):
                points, choose, img, target, model_points, idx = (Variable(t).cuda() for t in data)
                pred_r, pred_t, pred_c, emb = estimator(img, points, choose, idx)
                _, dis, new_points, new_target = criterion(pred_r, pred_t, pred_c, target, model_points, idx, points, opt.w, opt.refine_start)

                if opt.refine_start:
                    for ite in range(0, opt.iteration):
                        pred_r, pred_t = refiner(new_points, emb, idx)
                        dis, new_points, new_target = criterion_refine(pred_r, pred_t, new_target, model_points, idx, new_points)

                frame_dis = dis.item()
                test_dis += frame_dis
                # Log the plain number rather than the tensor repr.
                logger.info('Test time {0} Test Frame No.{1} dis:{2}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)), test_count, frame_dis))

                test_count += 1

        test_dis = test_dis / test_count
        logger.info('Test time {0} Epoch {1} TEST FINISH Avg dis: {2}'.format(time.strftime("%Hh %Mm %Ss", time.gmtime(time.time() - st_time)), epoch, test_dis))
        if test_dis <= best_test:
            best_test = test_dis
            if opt.refine_start:
                torch.save(refiner.state_dict(), '{0}/pose_refine_model_{1}_{2}.pth'.format(opt.outf, epoch, test_dis))
            else:
                torch.save(estimator.state_dict(), '{0}/pose_model_{1}_{2}.pth'.format(opt.outf, epoch, test_dis))
            print(epoch, '>>>>>>>>----------BEST TEST MODEL SAVED---------<<<<<<<<')

        # One-off decay of lr and w once the decay margin is reached
        # (plays the role of a scheduler).
        if best_test < opt.decay_margin and not opt.decay_start:
            opt.decay_start = True
            opt.lr *= opt.lr_rate
            opt.w *= opt.w_rate
            optimizer = optim.Adam(estimator.parameters(), lr=opt.lr)

        # Switch training to the refiner once the refine margin is reached.
        if best_test < opt.refine_margin and not opt.refine_start:
            opt.refine_start = True
            opt.batch_size = int(opt.batch_size / opt.iteration)
            optimizer = optim.Adam(refiner.parameters(), lr=opt.lr)

            # Rebuild datasets, loaders and losses: with refine enabled the
            # dataset reports a different number of mesh points.
            dataloader, testdataloader, criterion, criterion_refine = _build_data_and_losses(opt)

if __name__ == '__main__':
    main()


## Dataset

In [None]:
# ./datasets/ycb/dataset.py
"""
학습과 평가에 사용할 dataset class를 정의한다.
아래는 ycb video dataset의 예시이다.
"""

import torch.utils.data as data
from PIL import Image
import os
import os.path
import torch
import numpy as np
import torchvision.transforms as transforms
import argparse
import time
import random
from lib.transformations import quaternion_from_euler, euler_matrix, random_quaternion, quaternion_matrix
import numpy.ma as ma
import copy
import scipy.misc
import scipy.io as scio


class PoseDataset(data.Dataset):
    """YCB-Video dataset that yields one randomly chosen object per frame.

    Each item is the tuple ``(cloud, choose, img, target, model_points, idx)``:
      * cloud: camera-space points of the selected object, (num_pt, 3)
      * choose: flat indices of those points inside the cropped image, (1, num_pt)
      * img: normalised colour crop around the object, (3, H, W)
      * target: model points transformed by the ground-truth pose
      * model_points: the untransformed, subsampled model points
      * idx: zero-based class index of the object, shape (1,)
    """

    def __init__(self, mode, num_pt, add_noise, root, noise_trans, refine):
        """Read the frame list and the per-class model point clouds.

        Args:
            mode: 'train' or 'test'; selects which frame list file is read.
            num_pt: number of object points sampled per item.
            add_noise: enables the augmentations in ``__getitem__``.
            root: dataset root directory.
            noise_trans: bound of the per-axis uniform translation noise.
            refine: if true, ``num_pt_mesh_large`` model points are used per
                item, otherwise ``num_pt_mesh_small``.

        Raises:
            ValueError: if ``mode`` is neither 'train' nor 'test'.
        """
        if mode == 'train':
            self.path = 'datasets/ycb/dataset_config/train_data_list.txt'  # e.g. data/0000/000000 or data_syn/000000
        elif mode == 'test':
            self.path = 'datasets/ycb/dataset_config/test_data_list.txt'
        else:
            raise ValueError("mode must be 'train' or 'test', got {0!r}".format(mode))
        self.num_pt = num_pt
        self.root = root
        self.add_noise = add_noise
        self.noise_trans = noise_trans

        # Frame lists: all frames, plus the real / synthetic subsets.
        self.list = []
        self.real = []
        self.syn = []
        with open(self.path) as input_file:
            for input_line in input_file:
                if input_line[-1:] == '\n':
                    input_line = input_line[:-1]
                if input_line[:5] == 'data/':
                    self.real.append(input_line)
                else:
                    self.syn.append(input_line)
                self.list.append(input_line)

        self.length = len(self.list)
        self.len_real = len(self.real)
        self.len_syn = len(self.syn)

        # Model points per class, keyed by 1-based class id:
        # {1: array([[x, y, z], ...]), ...}
        self.cld = {}
        with open('datasets/ycb/dataset_config/classes.txt') as class_file:  # e.g. 002_master_chef_can
            for class_id, class_input in enumerate(class_file, start=1):
                # Strip only the newline so a final line without one keeps
                # its last character.
                class_name = class_input.rstrip('\n')
                with open('{0}/models/{1}/points.xyz'.format(self.root, class_name)) as input_file:
                    self.cld[class_id] = np.array([
                        [float(value) for value in input_line.split()[:3]]
                        for input_line in input_file
                        if input_line.strip()
                    ])

        # Two sets of camera intrinsics; __getitem__ picks one per frame.
        self.cam_cx_1 = 312.9869
        self.cam_cy_1 = 241.3109
        self.cam_fx_1 = 1066.778
        self.cam_fy_1 = 1067.487

        self.cam_cx_2 = 323.7872
        self.cam_cy_2 = 279.6921
        self.cam_fx_2 = 1077.836
        self.cam_fy_2 = 1078.189

        # xmap[r, c] == r (row index), ymap[r, c] == c (column index).
        self.xmap = np.array([[j for i in range(640)] for j in range(480)])
        self.ymap = np.array([[i for i in range(640)] for j in range(480)])

        self.trancolor = transforms.ColorJitter(0.2, 0.2, 0.2, 0.05)  # brightness, contrast, saturation, hue
        self.noise_img_loc = 0.0
        self.noise_img_scale = 7.0
        self.minimum_num_pt = 50
        self.norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.symmetry_obj_idx = [12, 15, 18, 19, 20]
        self.num_pt_mesh_small = 500
        self.num_pt_mesh_large = 2600
        self.refine = refine
        self.front_num = 2

        print(len(self.list))

    def _select_points(self, mask_crop):
        """Return exactly ``self.num_pt`` flat indices of True pixels in ``mask_crop``.

        With more candidates than ``num_pt`` a random subset is kept (in
        ascending order); with fewer the indices are repeated by wrapping.
        """
        # nonzero() returns a 1-tuple for the flattened mask; take the index
        # array itself so len() counts points instead of always being 1.
        choose = mask_crop.flatten().nonzero()[0]
        if len(choose) > self.num_pt:
            c_mask = np.zeros(len(choose), dtype=int)
            c_mask[:self.num_pt] = 1
            np.random.shuffle(c_mask)
            choose = choose[c_mask.nonzero()]
        else:
            choose = np.pad(choose, (0, self.num_pt - len(choose)), 'wrap')
        return choose

    def __getitem__(self, index):
        """Load frame ``index``, pick one object and build its training tensors.

        With ``add_noise`` enabled three augmentations are applied:
          (1) colour jitter,
          (2) a random translation added to both the cloud and the target,
          (3) compositing: (3-1) a random real frame is used as background
              for synthetic frames, (3-2) objects from a random synthetic
              frame are pasted in front of the image.
        """
        img = Image.open('{0}/{1}-color.png'.format(self.root, self.list[index]))
        depth = np.array(Image.open('{0}/{1}-depth.png'.format(self.root, self.list[index])))
        label = np.array(Image.open('{0}/{1}-label.png'.format(self.root, self.list[index])))
        meta = scio.loadmat('{0}/{1}-meta.mat'.format(self.root, self.list[index]))

        # Non-synthetic frames whose sequence number is >= 60 use the second
        # set of intrinsics; everything else uses the first.
        if self.list[index][:8] != 'data_syn' and int(self.list[index][5:9]) >= 60:
            cam_cx = self.cam_cx_2
            cam_cy = self.cam_cy_2
            cam_fx = self.cam_fx_2
            cam_fy = self.cam_fy_2
        else:
            cam_cx = self.cam_cx_1
            cam_cy = self.cam_cy_1
            cam_fx = self.cam_fx_1
            cam_fy = self.cam_fy_1

        # (3-1) True where the label is 0 (background).
        mask_back = ma.getmaskarray(ma.masked_equal(label, 0))

        # (3-2) Pick foreground objects from a random synthetic frame.
        add_front = False
        if self.add_noise:
            for k in range(5):  # at most five attempts
                seed = random.choice(self.syn)
                front = np.array(self.trancolor(Image.open('{0}/{1}-color.png'.format(self.root, seed)).convert("RGB")))  # (1)
                front = np.transpose(front, (2, 0, 1))  # (H, W, C) -> (C, H, W)
                f_label = np.array(Image.open('{0}/{1}-label.png'.format(self.root, seed)))
                # Labels present in the synthetic frame, without the smallest one.
                front_label = np.unique(f_label).tolist()[1:]
                if len(front_label) < self.front_num:
                    continue
                front_label = random.sample(front_label, self.front_num)
                # mask_front ends up True wherever none of the chosen
                # foreground objects is present.
                for f_i in front_label:
                    mk = ma.getmaskarray(ma.masked_not_equal(f_label, f_i))
                    if f_i == front_label[0]:
                        mask_front = mk
                    else:
                        mask_front = mask_front * mk

                # Labels that stay visible after the foreground occludes the frame.
                t_label = label * mask_front
                if len(t_label.nonzero()[0]) > 1000:  # enough of the original frame remains
                    label = t_label
                    add_front = True
                    break

        obj = meta['cls_indexes'].flatten().astype(np.int32)

        # Pixels with a non-zero depth value; invariant across the retries below.
        mask_depth = ma.getmaskarray(ma.masked_not_equal(depth, 0))
        # Draw random objects until one has enough labelled pixels with depth.
        # NOTE(review): this never terminates if no object in the frame qualifies.
        while 1:
            idx = np.random.randint(0, len(obj))
            mask_label = ma.getmaskarray(ma.masked_equal(label, obj[idx]))
            mask = mask_label * mask_depth
            if len(mask.nonzero()[0]) > self.minimum_num_pt:
                break

        # (1)
        if self.add_noise:
            img = self.trancolor(img)

        # Crop the RGB channels to the bounding box of the object label.
        rmin, rmax, cmin, cmax = get_bbox(mask_label)
        img = np.transpose(np.array(img)[:, :, :3], (2, 0, 1))[:, rmin:rmax, cmin:cmax]

        # (3-1) Synthetic frames get a random real frame added on background pixels.
        if self.list[index][:8] == 'data_syn':
            seed = random.choice(self.real)
            back = np.array(self.trancolor(Image.open('{0}/{1}-color.png'.format(self.root, seed)).convert("RGB")))
            back = np.transpose(back, (2, 0, 1))[:, rmin:rmax, cmin:cmax]
            img_masked = back * mask_back[rmin:rmax, cmin:cmax] + img
        else:
            img_masked = img

        # (3-2) Paste the foreground objects; done after the background fill
        # so that synthetic frames already have their background.
        if self.add_noise and add_front:
            img_masked = img_masked * mask_front[rmin:rmax, cmin:cmax] + front[:, rmin:rmax, cmin:cmax] * ~(mask_front[rmin:rmax, cmin:cmax])

        # Gaussian pixel noise on synthetic frames.
        if self.list[index][:8] == 'data_syn':
            img_masked = img_masked + np.random.normal(loc=self.noise_img_loc, scale=self.noise_img_scale, size=img_masked.shape)

        # Ground-truth pose of the selected object and the random offset for (2).
        target_r = meta['poses'][:, :, idx][:, 0:3]
        target_t = np.array([meta['poses'][:, :, idx][:, 3:4].flatten()])
        add_t = np.array([random.uniform(-self.noise_trans, self.noise_trans) for i in range(3)])

        # Sample num_pt valid object pixels inside the crop.
        choose = self._select_points(mask[rmin:rmax, cmin:cmax])

        depth_masked = depth[rmin:rmax, cmin:cmax].flatten()[choose][:, np.newaxis].astype(np.float32)
        xmap_masked = self.xmap[rmin:rmax, cmin:cmax].flatten()[choose][:, np.newaxis].astype(np.float32)
        ymap_masked = self.ymap[rmin:rmax, cmin:cmax].flatten()[choose][:, np.newaxis].astype(np.float32)
        choose = np.array([choose])

        # Back-project image coordinates to camera coordinates: ymap holds
        # column indices (paired with cx/fx), xmap row indices (cy/fy).
        cam_scale = meta['factor_depth'][0][0]
        pt2 = depth_masked / cam_scale
        pt0 = (ymap_masked - cam_cx) * pt2 / cam_fx
        pt1 = (xmap_masked - cam_cy) * pt2 / cam_fy
        cloud = np.concatenate((pt0, pt1, pt2), axis=1)
        if self.add_noise:
            cloud = np.add(cloud, add_t)  # (2)

        # Randomly drop model points until the configured number is left.
        num_model_pt = len(self.cld[obj[idx]])
        num_keep = self.num_pt_mesh_large if self.refine else self.num_pt_mesh_small
        dellist = random.sample(range(num_model_pt), num_model_pt - num_keep)
        model_points = np.delete(self.cld[obj[idx]], dellist, axis=0)

        # Apply the target rotation and translation to the model points.
        target = np.dot(model_points, target_r.T)
        if self.add_noise:
            target = np.add(target, target_t + add_t)  # (2)
        else:
            target = np.add(target, target_t)

        return torch.from_numpy(cloud.astype(np.float32)), \
               torch.LongTensor(choose.astype(np.int32)), \
               self.norm(torch.from_numpy(img_masked.astype(np.float32))), \
               torch.from_numpy(target.astype(np.float32)), \
               torch.from_numpy(model_points.astype(np.float32)), \
               torch.LongTensor([int(obj[idx]) - 1])

    def __len__(self):
        """Return the number of frames in the list file."""
        return self.length

    def get_sym_list(self):
        """Return the indices of the objects treated as symmetric."""
        return self.symmetry_obj_idx

    def get_num_points_mesh(self):
        """Return the number of model points kept per item (large when refining)."""
        if self.refine:
            return self.num_pt_mesh_large
        else:
            return self.num_pt_mesh_small


# Allowed crop extents: a box side is enlarged to the next value in this list.
border_list = [-1, 40, 80, 120, 160, 200, 240, 280, 320, 360, 400, 440, 480, 520, 560, 600, 640, 680]
img_width = 480  # upper bound applied to row coordinates
img_length = 640  # upper bound applied to column coordinates


def _snap_to_border(extent):
    """Return the next larger entry of border_list for ``extent``.

    An extent that is exactly equal to an entry is returned unchanged.
    """
    for pos, lower in enumerate(border_list):
        if lower < extent < border_list[pos + 1]:
            return border_list[pos + 1]
    return extent


def get_bbox(label):
    """Compute a size-quantised bounding box around the truthy region of ``label``.

    The tight box around the truthy entries is enlarged so that each side
    length is an entry of border_list, centred on the tight box, and then
    shifted back inside the img_width x img_length bounds.

    Returns:
        (rmin, rmax, cmin, cmax) with exclusive upper bounds.
    """
    hit_rows = np.where(np.any(label, axis=1))[0]
    hit_cols = np.where(np.any(label, axis=0))[0]
    top, bottom = hit_rows[[0, -1]]
    left, right = hit_cols[[0, -1]]
    bottom += 1
    right += 1

    half_h = int(_snap_to_border(bottom - top) / 2)
    half_w = int(_snap_to_border(right - left) / 2)
    mid_row = int((top + bottom) / 2)
    mid_col = int((left + right) / 2)

    rmin, rmax = mid_row - half_h, mid_row + half_h
    cmin, cmax = mid_col - half_w, mid_col + half_w

    # Shift the box down/right if it starts before the image origin.
    if rmin < 0:
        rmin, rmax = 0, rmax - rmin
    if cmin < 0:
        cmin, cmax = 0, cmax - cmin
    # Shift the box up/left if it ends past the image bounds.
    if rmax > img_width:
        rmin, rmax = rmin - (rmax - img_width), img_width
    if cmax > img_length:
        cmin, cmax = cmin - (cmax - img_length), img_length
    return rmin, rmax, cmin, cmax


In [4]:
import numpy as np

# Every row of xmap repeats its own row index across all 640 columns.
xmap = np.array([[row] * 640 for row in range(480)])
print(xmap.shape)
print(xmap)

(480, 640)
[[  0   0   0 ...   0   0   0]
 [  1   1   1 ...   1   1   1]
 [  2   2   2 ...   2   2   2]
 ...
 [477 477 477 ... 477 477 477]
 [478 478 478 ... 478 478 478]
 [479 479 479 ... 479 479 479]]


## Model

In [None]:
# ./lib/network.py

import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
from PIL import Image
import numpy as np
import pdb
import torch.nn.functional as F
from lib.pspnet import PSPNet

# Color embedding은 PSPNet을 사용하여 추출
""" PSPNet
- Semantic segmentation 모델로 기존 FCN 모델이 비슷하거나 작은 객체들의 class를 잘 구분하지 못하는 한계점을 극복하고자
  Pyramid Pooling Module을 통해 global contextual information을 활용
"""
# Feature를 추출하는 backbone은 다양한 깊이의 resnet 중에 선택
def _psp_factory(backend, psp_size, deep_features_size):
    """Return a zero-argument builder for a PSPNet with the given backbone."""
    def build():
        return PSPNet(sizes=(1, 2, 3, 6), psp_size=psp_size, deep_features_size=deep_features_size, backend=backend)
    return build


# Backbone name -> builder; each call of a builder constructs a new PSPNet.
psp_models = {
    'resnet18': _psp_factory('resnet18', 512, 256),
    'resnet34': _psp_factory('resnet34', 512, 256),
    'resnet50': _psp_factory('resnet50', 2048, 1024),
    'resnet101': _psp_factory('resnet101', 2048, 1024),
    'resnet152': _psp_factory('resnet152', 2048, 1024),
}

# Color embedding을 추출하는 Modified PSPNet
class ModifiedResnet(nn.Module):
    """Colour-embedding network: the 'resnet18' entry of psp_models wrapped in DataParallel."""

    def __init__(self, usegpu=True):
        # `usegpu` is accepted but not read by this constructor.
        super().__init__()
        build_backbone = psp_models['resnet18']
        self.model = nn.DataParallel(build_backbone())

    def forward(self, x):
        """Run ``x`` through the wrapped network and return its output."""
        return self.model(x)

# Embedding으로부터 pixel-wise feature를 만들어내는 모듈
class PoseNetFeat(nn.Module):
    """Fuse per-point geometry and colour embeddings into dense per-point features.

    Two levels of point-wise (kernel size 1) convolutions are applied to each
    modality; the fused features are then summarised by average pooling into
    a global feature that is appended to every point.
    """

    def __init__(self, num_points):
        super().__init__()
        # Geometry branch: 3 -> 64 -> 128 channels.
        self.conv1 = nn.Conv1d(3, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)

        # Colour-embedding branch: 32 -> 64 -> 128 channels.
        self.e_conv1 = nn.Conv1d(32, 64, 1)
        self.e_conv2 = nn.Conv1d(64, 128, 1)

        # Global-feature branch on the fused 256-channel features.
        self.conv5 = nn.Conv1d(256, 512, 1)
        self.conv6 = nn.Conv1d(512, 1024, 1)

        # The pooling window equals num_points, so a single value per channel
        # is produced only when the input carries exactly num_points points.
        self.ap1 = nn.AvgPool1d(num_points)
        self.num_points = num_points

    def forward(self, x, emb):
        """Return the fused per-point feature.

        Args:
            x: geometry input with 3 channels, (batch, 3, num_points).
            emb: colour embedding with 32 channels, (batch, 32, num_points).

        Returns:
            Tensor with 128 + 256 + 1024 = 1408 channels per point.
        """
        # Level-1 and level-2 embeddings of each modality.
        geo_lvl1 = F.relu(self.conv1(x))
        col_lvl1 = F.relu(self.e_conv1(emb))
        geo_lvl2 = F.relu(self.conv2(geo_lvl1))
        col_lvl2 = F.relu(self.e_conv2(col_lvl1))

        # Point-wise fusion at both levels (128 and 256 channels).
        fused_lvl1 = torch.cat([geo_lvl1, col_lvl1], dim=1)
        fused_lvl2 = torch.cat([geo_lvl2, col_lvl2], dim=1)

        # Global feature: 256 -> 512 -> 1024, averaged over the points.
        hidden = F.relu(self.conv5(fused_lvl2))
        pooled = self.ap1(F.relu(self.conv6(hidden)))

        # Copy the global feature onto every point and concatenate everything.
        tiled_global = pooled.view(-1, 1024, 1).repeat(1, 1, self.num_points)
        return torch.cat((fused_lvl1, fused_lvl2, tiled_global), dim=1)

# Densefusion 모델 -> estimator
class PoseNet(nn.Module):
    """Pose estimator: predicts a quaternion, translation and confidence per point.

    Args:
        num_points: number of points in the input point cloud.
        num_obj: number of object classes in the dataset.
    """

    def __init__(self, num_points, num_obj):
        super().__init__()
        self.num_points = num_points
        self.cnn = ModifiedResnet()
        self.feat = PoseNetFeat(num_points)

        # Three parallel heads (r: rotation, t: translation, c: confidence),
        # each 1408 -> 640 -> 256 -> 128 -> num_obj * k channels.
        self.conv1_r = nn.Conv1d(1408, 640, 1)
        self.conv1_t = nn.Conv1d(1408, 640, 1)
        self.conv1_c = nn.Conv1d(1408, 640, 1)

        self.conv2_r = nn.Conv1d(640, 256, 1)
        self.conv2_t = nn.Conv1d(640, 256, 1)
        self.conv2_c = nn.Conv1d(640, 256, 1)

        self.conv3_r = nn.Conv1d(256, 128, 1)
        self.conv3_t = nn.Conv1d(256, 128, 1)
        self.conv3_c = nn.Conv1d(256, 128, 1)

        self.conv4_r = nn.Conv1d(128, num_obj * 4, 1)  # quaternion
        self.conv4_t = nn.Conv1d(128, num_obj * 3, 1)  # translation
        self.conv4_c = nn.Conv1d(128, num_obj * 1, 1)  # confidence

        self.num_obj = num_obj

    @staticmethod
    def _run_head(feats, hidden_layers, out_layer):
        """Apply ReLU-activated hidden layers, then the un-activated output layer."""
        for layer in hidden_layers:
            feats = F.relu(layer(feats))
        return out_layer(feats)

    def forward(self, img, x, choose, obj):
        """Estimate per-point pose for the object class given by `obj`.

        Only batch element 0 is used for the class selection, so the outputs
        always have a leading dimension of 1.

        Returns:
            (rotation, translation, confidence, colour embedding), where the
            first three are (1, num_points, k) with k = 4, 3, 1 and the
            embedding is detached from the graph.
        """
        # 1. Colour embedding map from the CNN.
        out_img = self.cnn(img)

        # 2. Keep only the embeddings at the chosen pixel indices.
        bs, di, _height, _width = out_img.size()
        flat_emb = out_img.view(bs, di, -1)
        # Repeat the index along the channel axis so whole embedding vectors
        # are gathered; assumes `choose` is (batch, 1, num_points) — TODO confirm.
        gather_idx = choose.repeat(1, di, 1)
        emb = torch.gather(flat_emb, 2, gather_idx).contiguous()

        # 3. Fuse the point cloud (moved to channels-first) with the embedding.
        cloud = x.transpose(2, 1).contiguous()
        fused = self.feat(cloud, emb)

        # 4. Per-point predictions for every class, 1408 -> 640 -> 256 -> 128 -> out.
        rot_all = self._run_head(
            fused, (self.conv1_r, self.conv2_r, self.conv3_r), self.conv4_r)
        trans_all = self._run_head(
            fused, (self.conv1_t, self.conv2_t, self.conv3_t), self.conv4_t)
        conf_all = self._run_head(
            fused, (self.conv1_c, self.conv2_c, self.conv3_c), self.conv4_c)

        rot_all = rot_all.view(bs, self.num_obj, 4, self.num_points)
        trans_all = trans_all.view(bs, self.num_obj, 3, self.num_points)
        conf_all = torch.sigmoid(conf_all).view(bs, self.num_obj, 1, self.num_points)

        # 5. Select the requested class from batch element 0.
        class_idx = obj[0]
        out_rx = rot_all[0].index_select(0, class_idx)
        out_tx = trans_all[0].index_select(0, class_idx)
        out_cx = conf_all[0].index_select(0, class_idx)

        # 6. Move the point axis in front of the value axis: (1, k, N) -> (1, N, k).
        out_rx = out_rx.contiguous().transpose(2, 1).contiguous()
        out_tx = out_tx.contiguous().transpose(2, 1).contiguous()
        out_cx = out_cx.contiguous().transpose(2, 1).contiguous()

        return out_rx, out_tx, out_cx, emb.detach()
 


class PoseRefineNetFeat(nn.Module):
    """Feature extractor of the refiner: returns only a 1024-d global feature."""

    def __init__(self, num_points):
        super().__init__()
        # Geometry branch: 3 -> 64 -> 128 channels.
        self.conv1 = nn.Conv1d(3, 64, 1)
        self.conv2 = nn.Conv1d(64, 128, 1)

        # Colour-embedding branch: 32 -> 64 -> 128 channels.
        self.e_conv1 = nn.Conv1d(32, 64, 1)
        self.e_conv2 = nn.Conv1d(64, 128, 1)

        # Global branch on the 384-channel concatenation of both fusion levels.
        self.conv5 = nn.Conv1d(384, 512, 1)
        self.conv6 = nn.Conv1d(512, 1024, 1)

        # Pooling window equals num_points; inputs are expected to carry
        # exactly that many points.
        self.ap1 = nn.AvgPool1d(num_points)
        self.num_points = num_points

    def forward(self, x, emb):
        """Return the global feature reshaped to (-1, 1024).

        Args:
            x: geometry input with 3 channels, (batch, 3, num_points).
            emb: colour embedding with 32 channels, (batch, 32, num_points).
        """
        geo_lvl1 = F.relu(self.conv1(x))
        col_lvl1 = F.relu(self.e_conv1(emb))
        geo_lvl2 = F.relu(self.conv2(geo_lvl1))
        col_lvl2 = F.relu(self.e_conv2(col_lvl1))

        # Both fusion levels stacked along channels: (64+64) + (128+128) = 384.
        fused = torch.cat(
            (
                torch.cat((geo_lvl1, col_lvl1), dim=1),
                torch.cat((geo_lvl2, col_lvl2), dim=1),
            ),
            dim=1,
        )

        # 384 -> 512 -> 1024, then average over the points.
        hidden = F.relu(self.conv5(fused))
        pooled = self.ap1(F.relu(self.conv6(hidden)))
        return pooled.view(-1, 1024)

# Refinement를 추가한 Densefusion 모델 -> refiner
class PoseRefineNet(nn.Module):
    """Pose refiner: predicts one residual quaternion and translation per input."""

    def __init__(self, num_points, num_obj):
        super().__init__()
        self.num_points = num_points
        self.feat = PoseRefineNetFeat(num_points)

        # Two fully connected heads (r: rotation, t: translation),
        # each 1024 -> 512 -> 128 -> num_obj * k.
        self.conv1_r = nn.Linear(1024, 512)
        self.conv1_t = nn.Linear(1024, 512)

        self.conv2_r = nn.Linear(512, 128)
        self.conv2_t = nn.Linear(512, 128)

        self.conv3_r = nn.Linear(128, num_obj * 4)  # quaternion
        self.conv3_t = nn.Linear(128, num_obj * 3)  # translation

        self.num_obj = num_obj

    def forward(self, x, emb, obj):
        """Predict the pose update for the class given by `obj`.

        Args:
            x: point cloud, transposed here to channels-first before use.
            emb: colour embedding passed straight to the feature extractor.
            obj: class index tensor; only `obj[0]` is used.

        Returns:
            (quaternion, translation) selected for the class, taken from
            batch element 0 only.
        """
        batch = x.size(0)

        cloud = x.transpose(2, 1).contiguous()
        global_feat = self.feat(cloud, emb)

        # Rotation head: 1024 -> 512 -> 128 -> num_obj * 4.
        rot = F.relu(self.conv1_r(global_feat))
        rot = F.relu(self.conv2_r(rot))
        rot = self.conv3_r(rot).view(batch, self.num_obj, 4)

        # Translation head: 1024 -> 512 -> 128 -> num_obj * 3.
        trans = F.relu(self.conv1_t(global_feat))
        trans = F.relu(self.conv2_t(trans))
        trans = self.conv3_t(trans).view(batch, self.num_obj, 3)

        # Keep only the requested class of batch element 0.
        class_idx = obj[0]
        out_rx = rot[0].index_select(0, class_idx)
        out_tx = trans[0].index_select(0, class_idx)

        return out_rx, out_tx


## Loss

In [None]:
# ./lib/loss.py

from torch.nn.modules.loss import _Loss
from torch.autograd import Variable
import torch
import time
import numpy as np
import torch.nn as nn
import random
import torch.backends.cudnn as cudnn
from lib.knn.__init__ import KNearestNeighbor


def _quaternion_to_rotation_matrix(quat):
    """Convert normalised quaternions to 3x3 rotation matrices.

    Args:
        quat: tensor of shape (bs, num_p, 4) whose component 0 is the scalar
            part. Must already be unit length.

    Returns:
        Tensor of shape (bs * num_p, 3, 3), entries laid out row by row.
    """
    q0, q1, q2, q3 = quat[:, :, 0], quat[:, :, 1], quat[:, :, 2], quat[:, :, 3]
    entries = (
        1.0 - 2.0*(q2**2 + q3**2),
        2.0*q1*q2 - 2.0*q0*q3,
        2.0*q0*q2 + 2.0*q1*q3,
        2.0*q1*q2 + 2.0*q3*q0,
        1.0 - 2.0*(q1**2 + q3**2),
        -2.0*q0*q1 + 2.0*q2*q3,
        -2.0*q0*q2 + 2.0*q1*q3,
        2.0*q0*q1 + 2.0*q2*q3,
        1.0 - 2.0*(q1**2 + q2**2),
    )
    return torch.stack(entries, dim=2).view(-1, 3, 3)


def loss_calculation(pred_r, pred_t, pred_c, target, model_points, idx, points, w, refine, num_point_mesh, sym_list):
    """Confidence-weighted dense pose loss of the estimator.

    Every input point predicts its own pose; the model points are transformed
    by each prediction and compared with `target`.

    Args:
        pred_r: per-point quaternions, (bs, num_p, 4); normalised here.
        pred_t: per-point translation offsets, (bs, num_p, 3), added to `points`.
        pred_c: per-point confidences, (bs, num_p, 1).
        target: ground-truth transformed model points, (bs, num_point_mesh, 3).
        model_points: model points in the object frame, (bs, num_point_mesh, 3).
        idx: object class index tensor; `idx[0].item()` is looked up in `sym_list`.
        points: input point cloud, (bs, num_p, 3).
        w: weight of the `-log(confidence)` regulariser.
        refine: when true, the symmetric nearest-neighbour matching is skipped.
        num_point_mesh: number of model points.
        sym_list: class indices treated as symmetric objects.

    Returns:
        (loss, dis_best, new_points, new_target) where `dis_best` is the mean
        distance of the most confident prediction, and `new_points` /
        `new_target` are `points` / `target` moved into the frame of that
        prediction (both detached).

    Note:
        The best prediction is selected with `which_max[0]` on tensors
        flattened to bs * num_p rows, so the function effectively assumes bs == 1.
    """
    bs, num_p, _ = pred_c.size()

    # Normalise the predicted quaternions to unit length.
    pred_r = pred_r / (torch.norm(pred_r, dim=2).view(bs, num_p, 1))

    ori_base = _quaternion_to_rotation_matrix(pred_r)
    # Points are row vectors, so they are rotated by multiplying with R^T.
    base = ori_base.transpose(2, 1).contiguous()

    # One copy of the model / target points per predicting point.
    model_points = model_points.view(bs, 1, num_point_mesh, 3).repeat(1, num_p, 1, 1).view(bs * num_p, num_point_mesh, 3)
    target = target.view(bs, 1, num_point_mesh, 3).repeat(1, num_p, 1, 1).view(bs * num_p, num_point_mesh, 3)
    ori_target = target

    pred_t = pred_t.contiguous().view(bs * num_p, 1, 3)
    ori_t = pred_t
    points = points.contiguous().view(bs * num_p, 1, 3)
    pred_c = pred_c.contiguous().view(bs * num_p)

    # Model points under every predicted pose.
    pred = torch.add(torch.bmm(model_points, base), points + pred_t)

    if not refine and idx[0].item() in sym_list:
        # Symmetric object: compare each predicted point with its nearest
        # target point. The KNN helper is built only when it is needed.
        knn = KNearestNeighbor(1)
        target = target[0].transpose(1, 0).contiguous().view(3, -1)
        pred = pred.permute(2, 0, 1).contiguous().view(3, -1)
        inds = knn(target.unsqueeze(0), pred.unsqueeze(0))
        # NOTE(review): the "- 1" suggests the KNN indices are 1-based — confirm in lib/knn.
        target = torch.index_select(target, 1, inds.view(-1) - 1)
        target = target.view(3, bs * num_p, num_point_mesh).permute(1, 2, 0).contiguous()
        pred = pred.view(3, bs * num_p, num_point_mesh).permute(1, 2, 0).contiguous()

    # Mean point-to-point distance for each predicting point.
    dis = torch.mean(torch.norm((pred - target), dim=2), dim=1)

    # Distance weighted by confidence, with a log penalty against low confidence.
    loss = torch.mean((dis * pred_c - w * torch.log(pred_c)), dim=0)

    # Pick the most confident prediction.
    pred_c = pred_c.view(bs, num_p)
    _, which_max = torch.max(pred_c, 1)
    dis = dis.view(bs, num_p)
    best = which_max[0]

    # Absolute translation and rotation of the best prediction.
    t = ori_t[best] + points[best]
    best_rot = ori_base[best].view(1, 3, 3).contiguous()

    # Undo the best pose: (p - t) @ R applies R^T to each point.
    points = points.view(1, bs * num_p, 3)
    new_points = torch.bmm((points - t), best_rot).contiguous()

    new_target = ori_target[0].view(1, num_point_mesh, 3).contiguous()
    new_target = torch.bmm((new_target - t), best_rot).contiguous()

    return loss, dis[0][best], new_points.detach(), new_target.detach()


class Loss(_Loss):
    """Estimator loss module; a thin wrapper around `loss_calculation`.

    Args:
        num_points_mesh: number of model points used per object.
        sym_list: class indices treated as symmetric objects.
    """

    def __init__(self, num_points_mesh, sym_list):
        # Rely on the base-class defaults. Passing True positionally would set
        # the deprecated `size_average` argument and emit a deprecation
        # warning, while resulting in the same 'mean' reduction.
        super().__init__()
        self.num_pt_mesh = num_points_mesh
        self.sym_list = sym_list

    def forward(self, pred_r, pred_t, pred_c, target, model_points, idx, points, w, refine):
        """Forward all arguments, plus the stored mesh size and symmetry list, to `loss_calculation`."""
        return loss_calculation(pred_r, pred_t, pred_c, target, model_points, idx, points, w, refine, self.num_pt_mesh, self.sym_list)


In [None]:
# ./lib/loss_refiner.py

from torch.nn.modules.loss import _Loss
from torch.autograd import Variable
import torch
import time
import numpy as np
import torch.nn as nn
import random
import torch.backends.cudnn as cudnn
from lib.knn.__init__ import KNearestNeighbor


def _quaternion_to_rotation_matrix(quat):
    """Convert normalised quaternions to 3x3 rotation matrices.

    Args:
        quat: tensor of shape (bs, num_p, 4) whose component 0 is the scalar
            part. Must already be unit length.

    Returns:
        Tensor of shape (bs * num_p, 3, 3), entries laid out row by row.
    """
    q0, q1, q2, q3 = quat[:, :, 0], quat[:, :, 1], quat[:, :, 2], quat[:, :, 3]
    entries = (
        1.0 - 2.0*(q2**2 + q3**2),
        2.0*q1*q2 - 2.0*q0*q3,
        2.0*q0*q2 + 2.0*q1*q3,
        2.0*q1*q2 + 2.0*q3*q0,
        1.0 - 2.0*(q1**2 + q3**2),
        -2.0*q0*q1 + 2.0*q2*q3,
        -2.0*q0*q2 + 2.0*q1*q3,
        2.0*q0*q1 + 2.0*q2*q3,
        1.0 - 2.0*(q1**2 + q2**2),
    )
    return torch.stack(entries, dim=2).view(-1, 3, 3)


def loss_calculation(pred_r, pred_t, target, model_points, idx, points, num_point_mesh, sym_list):
    """Pose loss of the refiner for a single predicted pose.

    Args:
        pred_r: predicted quaternion; reshaped to (1, 1, 4) and normalised here.
        pred_t: predicted translation; reshaped to (1, 1, 3).
        target: ground-truth transformed model points, (1, num_point_mesh, 3).
        model_points: model points in the object frame, (1, num_point_mesh, 3).
        idx: object class index tensor; `idx[0].item()` is looked up in `sym_list`.
        points: current input point cloud; `len(points[0])` points of 3 values.
        num_point_mesh: number of model points.
        sym_list: class indices treated as symmetric objects.

    Returns:
        (dis, new_points, new_target): `dis` is the mean model-point distance
        with shape (1,); `new_points` / `new_target` are `points` / `target`
        moved into the frame of the predicted pose (both detached).
    """
    # A single pose is predicted, so batch size and pose count are both 1.
    pred_r = pred_r.view(1, 1, -1)
    pred_t = pred_t.view(1, 1, -1)
    bs, num_p, _ = pred_r.size()
    num_input_points = len(points[0])

    # Normalise the predicted quaternion to unit length.
    pred_r = pred_r / (torch.norm(pred_r, dim=2).view(bs, num_p, 1))

    ori_base = _quaternion_to_rotation_matrix(pred_r)
    # Points are row vectors, so they are rotated by multiplying with R^T.
    base = ori_base.transpose(2, 1).contiguous()

    model_points = model_points.view(bs * num_p, num_point_mesh, 3)
    target = target.view(bs * num_p, num_point_mesh, 3)
    ori_target = target
    pred_t = pred_t.contiguous().view(bs * num_p, 1, 3)

    # Model points under the predicted pose.
    pred = torch.add(torch.bmm(model_points, base), pred_t)

    if idx[0].item() in sym_list:
        # Symmetric object: compare each predicted point with its nearest
        # target point. The KNN helper is built only when it is needed.
        knn = KNearestNeighbor(1)
        target = target[0].transpose(1, 0).contiguous().view(3, -1)
        pred = pred.permute(2, 0, 1).contiguous().view(3, -1)
        inds = knn(target.unsqueeze(0), pred.unsqueeze(0))
        # NOTE(review): the "- 1" suggests the KNN indices are 1-based — confirm in lib/knn.
        target = torch.index_select(target, 1, inds.view(-1) - 1)
        target = target.view(3, bs * num_p, num_point_mesh).permute(1, 2, 0).contiguous()
        pred = pred.view(3, bs * num_p, num_point_mesh).permute(1, 2, 0).contiguous()

    dis = torch.mean(torch.norm((pred - target), dim=2), dim=1)

    t = pred_t[0]
    rot = ori_base[0].view(1, 3, 3).contiguous()

    # Undo the predicted pose: (p - t) @ R applies R^T to each point.
    points = points.view(1, num_input_points, 3)
    new_points = torch.bmm((points - t), rot).contiguous()

    new_target = ori_target[0].view(1, num_point_mesh, 3).contiguous()
    new_target = torch.bmm((new_target - t), rot).contiguous()

    return dis, new_points.detach(), new_target.detach()


class Loss_refine(_Loss):
    """Refiner loss module; a thin wrapper around `loss_calculation`.

    Args:
        num_points_mesh: number of model points used per object.
        sym_list: class indices treated as symmetric objects.
    """

    def __init__(self, num_points_mesh, sym_list):
        # Rely on the base-class defaults. Passing True positionally would set
        # the deprecated `size_average` argument and emit a deprecation
        # warning, while resulting in the same 'mean' reduction.
        super().__init__()
        self.num_pt_mesh = num_points_mesh
        self.sym_list = sym_list

    def forward(self, pred_r, pred_t, target, model_points, idx, points):
        """Forward all arguments, plus the stored mesh size and symmetry list, to `loss_calculation`."""
        return loss_calculation(pred_r, pred_t, target, model_points, idx, points, self.num_pt_mesh, self.sym_list)
