# Libraries and basic settings

In [1]:
import numpy as np

import torch
from torch.nn import functional as F
from torch import nn
import pytorch_lightning as pl

# Session-wide defaults.
# Printing: NumPy arrays with 2 decimals; PyTorch tensors with 4 decimals in
# fixed-point notation (scientific notation disabled).
np.set_printoptions(precision=2)
torch.set_printoptions(precision=4, sci_mode=False)
# Newly created floating-point tensors default to float32.
torch.set_default_dtype(torch.float32)
# Let cuDNN benchmark and pick the fastest convolution algorithms
# (most effective when input shapes do not vary between iterations).
torch.backends.cudnn.benchmark = True

In [2]:
import argparse
import sys
import os
import time
import pickle

# Hyper-parameters and paths for this notebook run.
parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', type=int, default=20, help='size of mini batch')
parser.add_argument('--learning_rate', type=float, default=0.00001, help='learning rate')
parser.add_argument('--weight_decay', type=float, default=.0001, help='weight decay (L2 penalty) for the Adam optimizer')
parser.add_argument('--lamda_weights', type=float, default=.01, help='lamda weight')
parser.add_argument('--model_dir', type=str, default='/notebooks/global_localization/lightning/baseline-dual')
parser.add_argument('--train_dataset', type=str, default='/dataset/train.lmdb')
# nargs=1 makes a command-line value a one-element list, exactly like the
# default. Downstream code calls torch.load(*args.norm_tensor); a bare string
# would be unpacked character by character.
parser.add_argument('--norm_tensor', type=str, nargs=1,
                    default=['/notebooks/global_localization/norm_mean_std.pt'],
                    help='path to the saved [mean, std] pose normalization tensors')

# Running inside a notebook: replace the kernel's own argv so parse_args()
# sees no options and every argument takes its default.
sys.argv = ['']
args = parser.parse_args()

# Load Dataset

In [3]:
import lmdb

import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torchvision.transforms import ToPILImage

import tf.transformations as tf_tran

class Michigan_Image:
    """One map image plus its pose, loaded from files on disk.

    Presumably the record type pickled into the LMDB: ``MichiganDatasetLMDB``
    unpickles values and reads their ``.img`` / ``.pose`` attributes. Keep the
    class name and attribute names stable so existing pickles still load.

    Args:
        img: path of the image file, read with ``cv2.IMREAD_UNCHANGED``.
        pose: path of a text file parsed by ``np.loadtxt`` as float32.
        sparse: if True, store the image as a ``scipy.sparse.coo_matrix``.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``img``.
    """

    def __init__(self, img, pose, sparse=False):
        # Imported here: they are only needed when building records, and
        # unpickling does not call __init__.
        import cv2
        import scipy.sparse as sp

        image = cv2.imread(img, cv2.IMREAD_UNCHANGED)
        if image is None:
            # cv2.imread reports failure by returning None rather than raising.
            raise FileNotFoundError(f'could not read image: {img}')
        self.img = sp.coo_matrix(image) if sparse else image
        self.pose = np.loadtxt(pose, dtype=np.float32)

class MichiganDatasetLMDB(Dataset):
    """LMDB-backed dataset of (image, pose) samples, optionally as nearby-frame pairs.

    Each LMDB value is a pickled record with ``img`` and ``pose`` attributes,
    keyed by its ASCII decimal index (``b'0'``, ``b'1'``, ...). ``pose`` holds
    six values ``[x, y, z, ex, ey, ez]`` (Euler angles, 'rxyz' convention) and
    is returned as a 7-element tensor ``[x, y, z, qx, qy, qz, qw]``.

    Args:
        db_path: path of the LMDB environment.
        transform: optional callable applied to every returned image. The image
            is ``record.img[:, :, np.newaxis]``, presumably ``H x W x 1``.
        offset: used only by ``_image_argumentation`` (random crop jitter),
            which ``_request_data`` currently does not call.
        target_image_size: ``(rows, cols)`` crop size for ``_image_argumentation``.
        pair: if True, ``__getitem__`` returns ``[img_1, img_2], [pose_1, pose_2]``.
        num_connected_frames: the partner frame is drawn 1 .. num_connected_frames-1
            records away from ``index``.
        max_pose_jump: largest per-axis |dx|, |dy|, |dz| accepted between the two
            frames of a pair. Larger gaps are treated as a break in the sequence,
            and another partner is drawn.

    Raises:
        ValueError: if the LMDB contains no entries.
    """

    def __init__(self, db_path, transform=None, offset=False,
                 target_image_size=(300, 300),
                 pair=False, num_connected_frames=10, max_pose_jump=15):
        self.db_path = db_path
        # Read only the entry count here. The handle used for samples is opened
        # lazily per process (see _get_env): an LMDB environment must not be
        # used across fork(), which is how DataLoader workers start, and an
        # open handle cannot be pickled.
        env = self._open_env()
        try:
            with env.begin(write=False) as txn:
                self.length = txn.stat()['entries']
        finally:
            env.close()
        if self.length < 1:
            raise ValueError(f'LMDB at {db_path!r} contains no entries')
        self.keys = [str(i).encode('ascii') for i in range(self.length)]
        self.env = None
        self._env_pid = None

        self.transform = transform
        self.offset = offset
        self.target_image_size = target_image_size
        self.pair = pair
        self.num_connected_frames = num_connected_frames
        self.max_pose_jump = max_pose_jump

    def _open_env(self):
        """Open ``self.db_path`` read-only, without locking, readahead or meminit."""
        return lmdb.open(self.db_path, readonly=True, lock=False,
                         readahead=False, meminit=False)

    def _get_env(self):
        """Return an LMDB handle owned by the current process, (re)opening it after a fork."""
        pid = os.getpid()
        if self.env is None or self._env_pid != pid:
            self.env = self._open_env()
            self._env_pid = pid
        return self.env

    def __getstate__(self):
        # LMDB handles are process-local and unpicklable; the receiving process
        # opens its own on first access.
        state = self.__dict__.copy()
        state['env'] = None
        state['_env_pid'] = None
        return state

    def _quaternized_pose(self, pose):
        """Convert ``[x, y, z, ex, ey, ez]`` (Euler 'rxyz') to a tensor ``[x, y, z, qx, qy, qz, qw]``."""
        [px, py, pz, ex, ey, ez] = pose
        [qx, qy, qz, qw] = tf_tran.quaternion_from_euler(ex, ey, ez, 'rxyz')
        pose = torch.Tensor([px, py, pz, qx, qy, qz, qw])
        return pose

    def _image_argumentation(self, img, target, rot_angle=None):
        """Crop ``img`` to ``target_image_size`` and adjust ``target`` to match.

        ("argumentation" means augmentation; the name is kept for compatibility.)
        With ``self.offset``, the crop window is jittered around the centre: the
        fractions are drawn from N(0.5, 0.1) and clipped to [0, 1]. The pixel
        shift (scaled by RES) and an optional rotation ``rot_angle`` (radians,
        about the z axis, also applied to the image) are composed onto the
        7-element ``target`` pose, which is returned as a float32 numpy array.
        Without ``self.offset``, the centre crop is taken and ``target`` is
        returned unchanged.
        """
        import math  # only needed by the rotation branch below

        RES = 1.0  # pose units per pixel -- NOTE(review): confirm the map resolution

        margin_row = img.shape[0] - self.target_image_size[0]
        margin_col = img.shape[1] - self.target_image_size[1]

        [px, py, pz, qx, qy, qz, qw] = target

        if self.offset:
            offset_row = int(max(0.0, min(1.0, np.random.normal(0.5, 0.1))) * margin_row)
            offset_col = int(max(0.0, min(1.0, np.random.normal(0.5, 0.1))) * margin_col)

            offset_x = (offset_row - int(margin_row / 2)) * RES
            offset_y = -(offset_col - int(margin_col / 2)) * RES

            if rot_angle is None:
                # no rotation
                deltaT = tf_tran.identity_matrix()
            else:
                import imutils  # only needed when a rotation is requested

                angle = rot_angle
                deltaT = tf_tran.rotation_matrix(angle, (0, 0, 1))
                # imutils.rotate expects degrees; rot_angle is in radians.
                img = imutils.rotate(img[:, :, 0], angle / math.pi * 180)
                img = np.array(img)[:, :, np.newaxis]

            deltaT[0:3, 3] = [offset_x, offset_y, 0.]

            # ground-truth pose as a homogeneous transform
            T = tf_tran.quaternion_matrix([qx, qy, qz, qw])
            T[0:3, 3] = [px, py, pz]

            # apply the crop/rotation perturbation on the global pose
            T = np.matmul(deltaT, T)

            position = np.array(tf_tran.translation_from_matrix(T), dtype=np.single)
            quaternion = np.array(tf_tran.quaternion_from_matrix(T), dtype=np.single)

            target = np.concatenate((position, quaternion), axis=0)
        else:
            offset_row = int(margin_row / 2)
            offset_col = int(margin_col / 2)

        img = img[offset_row:offset_row + self.target_image_size[0],
                  offset_col:offset_col + self.target_image_size[1], :]

        return img, target

    def _request_data(self, index):
        """Load record ``index``; return (image with a trailing channel axis, 7-element pose tensor)."""
        with self._get_env().begin(write=False, buffers=True) as txn:
            byteflow = txn.get(self.keys[index])
            # buffers=True yields a zero-copy view valid only inside this
            # transaction, so it is unpickled before the block exits.
            # NOTE(security): pickle.loads can execute arbitrary code; only
            # point this dataset at LMDB files you produced yourself.
            data = pickle.loads(byteflow)
        img = data.img[:, :, np.newaxis]
        pose = self._quaternized_pose(data.pose)
        return img, pose

    def __len__(self):
        if self.pair:
            # The last num_connected_frames records are not used as anchors, so
            # index + offset always stays in range. Clamp so len() never goes negative.
            return max(0, self.length - self.num_connected_frames)
        return self.length

    def __getitem__(self, index):
        if not self.pair:
            img, pose = self._request_data(index)
            if self.transform:
                img = self.transform(img)
            return img, pose

        # NOTE(review): np.random state may be identical in every DataLoader
        # worker on older PyTorch versions unless a worker_init_fn reseeds it.
        paired_frame_offset = np.random.randint(1, self.num_connected_frames)
        index_offset = index + paired_frame_offset
        img_1, pose_1 = self._request_data(index)
        img_2, pose_2 = self._request_data(index_offset)

        # A large translation gap means the partner is not on the same
        # continuous stretch, so look backwards with a freshly drawn offset.
        attempts = 0
        while torch.abs(pose_1 - pose_2)[:3].max().item() > self.max_pose_jump:
            attempts += 1
            if attempts > self.num_connected_frames:
                # No continuous partner found: pair the frame with itself
                # instead of looping forever.
                index_offset = index
            else:
                # Never step before record 0: a negative index would silently
                # wrap to the end of self.keys.
                index_offset = max(index - paired_frame_offset, 0)
            img_2, pose_2 = self._request_data(index_offset)
            paired_frame_offset = np.random.randint(1, self.num_connected_frames)

        if self.transform:
            img_1 = self.transform(img_1)
            img_2 = self.transform(img_2)
        return [img_1, img_2], [pose_1, pose_2]
            
# Pose normalization statistics, saved as a two-element [mean, std] list.
[args.norm_mean, args.norm_std] = torch.load(*args.norm_tensor)

torch.manual_seed(42)
# Image pipeline: to PIL, resize the shorter side to 300 px, random rotation in
# [-180, 180] degrees, back to a tensor.
transform = transforms.Compose([transforms.ToPILImage(),
                                transforms.Resize(300),
                                transforms.RandomRotation(180),
                                transforms.ToTensor()])
dataset = MichiganDatasetLMDB(args.train_dataset, transform, pair=True, num_connected_frames=40)

num_data = len(dataset)
# 70/30 train/validation split. The validation size is derived from the
# training size so the two lengths always sum to len(dataset). Rounding both
# fractions independently can overshoot (e.g. 5 -> 4 + 2), and random_split
# then raises.
num_train = round(num_data * 0.7)
num_val = num_data - num_train
torch.manual_seed(42)
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [num_train, num_val])

In [4]:
import torch.nn as nn
from torch.hub import load_state_dict_from_url

class Context(nn.Module):
    """Convolutional context head applied to a backbone feature map.

    Layout: a 1x1 conv squeezing ``input_channel`` -> 128 channels, four 3x3
    convs (padding 1) at 128 channels, then a 1x1 conv squeezing 128 -> 64.
    Every conv is followed by ReLU, and the spatial size is preserved.

    Attribute names (and their creation order) form the state_dict keys
    stored in checkpoints, so they must stay as they are.
    """

    def __init__(self, input_channel=2048):
        super().__init__()

        def conv_relu(n_in, n_out, kernel, pad=0):
            # Conv2d followed by ReLU, wrapped so the keys read "<name>.0.weight".
            return nn.Sequential(
                nn.Conv2d(in_channels=n_in, out_channels=n_out,
                          kernel_size=kernel, padding=pad),
                nn.ReLU(),
            )

        self.squeeze = conv_relu(input_channel, 128, 1)
        self.context5_1 = conv_relu(128, 128, 3, pad=1)
        self.context5_2 = conv_relu(128, 128, 3, pad=1)
        self.context5_3 = conv_relu(128, 128, 3, pad=1)
        self.context5_4 = conv_relu(128, 128, 3, pad=1)
        self.squeeze2 = conv_relu(128, 64, 1)

    def forward(self, x):
        stages = (self.squeeze, self.context5_1, self.context5_2,
                  self.context5_3, self.context5_4, self.squeeze2)
        for stage in stages:
            x = stage(x)
        return x
    
class Regressor(nn.Module):
    """Fully connected head that regresses a 3-D translation from flattened features.

    Pipeline: flatten -> Linear(in_features, 4096) -> ReLU -> Linear(4096, 4096)
    -> ReLU -> Linear(4096, 128) -> ReLU (this is ``feature_t``) -> Linear(128, 3).

    ``in_features`` defaults to 6400, presumably the 64-channel Context output
    on a 10x10 grid -- TODO confirm against the backbone's output size.

    A separate rotation (quaternion) branch is disabled in this variant; only
    translation is predicted. Attribute names are the checkpoint state_dict keys.

    Returns from ``forward``: ``(logits_t, feature_t)``.
    """

    def __init__(self, in_features=6400):
        super().__init__()
        self.flatten = nn.Flatten()

        def linear_relu(n_in, n_out):
            return nn.Sequential(nn.Linear(in_features=n_in, out_features=n_out), nn.ReLU())

        # Translation branch.
        self.fc1_trans = linear_relu(in_features, 4096)
        self.fc2_trans = linear_relu(4096, 4096)
        self.fc3_trans = linear_relu(4096, 128)
        self.logits_t = nn.Linear(in_features=128, out_features=3)

    def forward(self, x):
        hidden = self.flatten(x)
        for layer in (self.fc1_trans, self.fc2_trans):
            hidden = layer(hidden)
        feature_t = self.fc3_trans(hidden)
        return self.logits_t(feature_t), feature_t
    
class vggnet(nn.Module):
    """Thin wrapper that holds either the Context head or the Regressor head.

    Args:
        opt: "context" builds ``self.context = Context(input_channel)``;
            "regressor" builds ``self.regressor = Regressor()``.
        input_channel: channels of the incoming feature map (used by "context" only).

    Raises:
        ValueError: for any other ``opt``. Without this check, the module would
            have no parameters and ``forward`` would silently return None.
    """

    def __init__(self, opt="context", input_channel=2048):
        super().__init__()
        self.opt = opt
        if opt == "context":
            self.context = Context(input_channel)
        elif opt == "regressor":
            self.regressor = Regressor()
        else:
            raise ValueError(f'unknown opt {opt!r}; expected "context" or "regressor"')

    def forward(self, x):
        if self.opt == "context":
            return self.context(x)
        if self.opt == "regressor":
            return self.regressor(x)
        raise ValueError(f'unknown opt {self.opt!r}; expected "context" or "regressor"')

In [5]:
import sys
sys.path.append('..')
from torchlib import resnet
from torchlib.cnn_auxiliary import normalize, denormalize, denormalize_navie, get_relative_pose

def translational_rotational_loss(pred=None, gt=None, loss_type='mse'):
    """Translation-only regression loss between ``pred`` and ``gt``.

    Despite the name, only the first three columns (the translation part) of
    each row are compared; any remaining columns are ignored.

    Args:
        pred: predictions, one pose per row (at least 3 columns).
        gt: targets with the same layout as ``pred``.
        loss_type: 'mse' gives the mean squared error over all translation
            components. Any other value gives the per-row sum of squared
            errors, averaged over rows.

    Returns:
        A scalar loss tensor.
    """
    t_pred, t_gt = pred[:, :3], gt[:, :3]
    if loss_type != 'mse':
        return ((t_pred - t_gt) ** 2).sum(dim=1).mean()
    return nn.functional.mse_loss(input=t_pred, target=t_gt)

class Backbone(nn.Module):
    """Feature extractor wrapping the project's ``torchlib.resnet.resnet50``.

    ``forward`` returns whatever that network returns. Its output feeds a
    Context head built with ``input_channel=2048``, so it is presumably a
    2048-channel feature map -- TODO confirm in torchlib.resnet.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        # Attribute name "resnet" is part of the checkpoint state_dict keys.
        self.resnet = resnet.resnet50(pretrained=pretrained)

    def forward(self, input_data):
        return self.resnet(input_data)

class NN(nn.Module):
    """Context head followed by the regressor head, applied to backbone features.

    ``forward`` returns ``(output, feature)`` as produced by the regressor
    wrapper: the translation prediction and the regressor's feature vector.
    """

    def __init__(self):
        super().__init__()
        # Creation order and attribute names are kept: they fix parameter
        # initialization order and the checkpoint state_dict keys.
        self.global_context = vggnet(input_channel=2048, opt="context")
        self.global_regressor = vggnet(opt="regressor")

    def forward(self, input_data):
        prediction, embedding = self.global_regressor(self.global_context(input_data))
        return prediction, embedding
    
class Model(pl.LightningModule):
    """Lightning module: ResNet-50 backbone plus context/regressor head predicting translation.

    Training consumes pairs of nearby frames. The loss is the sum of:
    - an absolute loss on the first frame's normalized translation, and
    - a consistency loss between the predicted and ground-truth relative poses
      of the pair.

    Validation reports the mean Euclidean translation error after
    denormalizing the prediction.
    """

    def __init__(self):
        super().__init__()
        self.backbone = Backbone()
        self.nn = NN()

        # Pose normalization statistics, registered as frozen Parameters so they
        # are part of the state_dict (saved in checkpoints) and follow the
        # module's device.
        [norm_mean, norm_std] = torch.load(*args.norm_tensor)
        self.norm_mean = torch.nn.parameter.Parameter(norm_mean, requires_grad=False)
        self.norm_std = torch.nn.parameter.Parameter(norm_std, requires_grad=False)

    def forward(self, x):
        """Predict the normalized translation for an image batch ``x``."""
        dense_feat = self.backbone(x)
        output, _ = self.nn(dense_feat)
        return output

    def training_step(self, batch, batch_idx):
        [img1, img2], [pose1, pose2] = batch

        pose1 = normalize(pose1, self.norm_mean, self.norm_std)
        pose2 = normalize(pose2, self.norm_mean, self.norm_std)

        global_loss, consistent_loss = self._loss(img1, img2, pose1, pose2)
        loss = global_loss + consistent_loss

        tensorboard = self.logger.experiment
        tensorboard.add_scalars('train_loss',
                                {'total_loss': float(loss),
                                 'global_loss': float(global_loss),
                                 'consistent_loss': float(consistent_loss)},
                                self.global_step)
        return loss

    def _loss(self, x0, x1, y0, y1):
        """Return (absolute loss on frame 0, relative-pose consistency loss).

        ``y0``/``y1`` are normalized poses with translation in columns ``:3``.
        Their remaining (rotation) columns are zeroed on local copies, so only
        translation contributes and the caller's tensors are not modified.
        """
        # Work on copies: zeroing in place would mutate the caller's batch tensors.
        y0 = y0.clone()
        y1 = y1.clone()
        y0[:, 3:] = 0
        y1[:, 3:] = 0

        # target relative pose
        relative_target_normed = get_relative_pose(y0, y1)
        # network predictions (translation only)
        global_output0 = self.forward(x0)
        global_output1 = self.forward(x1)
        # pad predictions with zero rotation so they match the target layout
        global_output0 = torch.cat([global_output0, torch.zeros_like(y0[:, 3:])], dim=1)
        global_output1 = torch.cat([global_output1, torch.zeros_like(y1[:, 3:])], dim=1)
        relative_consistence = get_relative_pose(global_output0, global_output1)

        # absolute loss (first frame only)
        global_loss = translational_rotational_loss(pred=global_output0, gt=y0)
        # relative (geometric consistency) loss
        geometry_consistent_loss = translational_rotational_loss(pred=relative_consistence,
                                                                 gt=relative_target_normed)

        return global_loss, geometry_consistent_loss

    def validation_step(self, batch, batch_idx):
        """Mean Euclidean translation error of the first frame, in raw (denormalized) pose units."""
        [x, _], [y, _] = batch

        trans_target = y[:, :3]
        trans_pred = self.forward(x)
        trans_pred = denormalize_navie(trans_pred, self.norm_mean, self.norm_std)

        trans_loss = torch.sqrt(torch.sum((trans_pred - trans_target)**2, dim=1)).mean()

        val_loss = trans_loss
        self.log('val_loss', val_loss, on_step=True, on_epoch=True, prog_bar=False, logger=False)
        tensorboard = self.logger.experiment
        tensorboard.add_scalars('val_loss',
                                {'trans_loss': float(trans_loss)},
                                self.current_epoch*self.trainer.num_val_batches[0]+batch_idx)
        return val_loss

    def configure_optimizers(self):
        # Adam with identical settings for backbone and head; kept as separate
        # parameter groups, presumably so they can be tuned independently.
        lr, weight_decay = args.learning_rate, args.weight_decay
        optimizer_args = [
            {'params': self.backbone.parameters(), 'lr': lr, 'weight_decay': weight_decay},
            {'params': self.nn.parameters(), 'lr': lr, 'weight_decay': weight_decay}]

        optimizer = torch.optim.Adam(optimizer_args)
        return optimizer

    def show_require_grad(self):
        """Print the name and shape of every trainable parameter."""
        for name, param in self.named_parameters():
            if param.requires_grad:
                print(name, param.shape)

    def get_progress_bar_dict(self):
        # Hide the version number from the progress bar.
        tqdm_dict = super().get_progress_bar_dict()
        if 'v_num' in tqdm_dict:
            del tqdm_dict['v_num']
        return tqdm_dict

    def train_dataloader(self):
        # os.cpu_count() returns None when the count cannot be determined; in
        # that case load data in the main process (num_workers=0).
        train_loader = DataLoader(train_dataset, batch_size=args.batch_size, pin_memory=True,
                                  shuffle=True, num_workers=os.cpu_count() or 0, drop_last=True)
        return train_loader

    def val_dataloader(self):
        # NOTE(review): drop_last=True excludes the final partial batch from validation.
        val_loader = DataLoader(val_dataset, batch_size=args.batch_size, pin_memory=True,
                                shuffle=False, num_workers=os.cpu_count() or 0, drop_last=True)
        return val_loader

from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint

import os
import shutil

# Start from a clean log directory. This is the equivalent of
# `rm -rf lightning_logs` without spawning a shell; a missing directory is fine.
shutil.rmtree('lightning_logs', ignore_errors=True)
logger = TensorBoardLogger('lightning_logs')
# Keep the two checkpoints with the lowest val_loss (weights only).
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    filepath='model-{epoch:02d}-{val_loss:.2f}',
    save_top_k=2,
    mode='min',
    save_weights_only=True)

trainer = pl.Trainer(gpus=1, precision=16,
                     limit_train_batches=1.0,
                     limit_val_batches=1.0,
                     accumulate_grad_batches=1,
                     reload_dataloaders_every_epoch=True,
                     logger=logger,
                     checkpoint_callback=checkpoint_callback)
# Initialize from a previous run's weights. Only the model is restored here,
# not optimizer/trainer state.
model = Model.load_from_checkpoint(os.path.join(args.model_dir, 'model-epoch=20-val_loss=2.14.ckpt'))
trainer.fit(model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Using native 16bit precision.
