# Libraries and Basic Settings

In [1]:
import numpy as np

import torch
from torch.nn import functional as F
from torch import nn
import pytorch_lightning as pl

# Global display / numeric defaults used throughout the notebook.
np.set_printoptions(precision=2)                      # numpy arrays print with 2 decimals
torch.set_printoptions(precision=4, sci_mode=False)   # torch tensors: 4 decimals, no scientific notation
torch.set_default_dtype(torch.float32)                # new floating-point tensors default to float32
# Let cuDNN benchmark and pick the fastest conv algorithms; this may trade
# run-to-run reproducibility for speed.
torch.backends.cudnn.benchmark = True

In [2]:
import argparse
import sys
import os
import time
import pickle


def _str2bool(value):
    """Parse a command-line boolean flag value.

    ``argparse`` with ``type=bool`` calls ``bool(str)``, which is True for every
    non-empty string (including ``"False"``), so accept the usual spellings here.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('1', 'true', 't', 'yes', 'y'):
        return True
    if lowered in ('0', 'false', 'f', 'no', 'n'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', type=int, default=130,
                    help='size of mini batch')
parser.add_argument('--learning_rate', type=float, default=0.0002, help='learning rate')
parser.add_argument('--weight_decay', type=float, default=.0001,
                    help='weight decay (L2 penalty) passed to the optimizer')
parser.add_argument('--lamda_weights', type=float, default=.01,
                    help='weight of the rotation loss relative to the translation loss')
parser.add_argument('--is_normalization', type=_str2bool,
                    default=True, help='whether do data normalization')
parser.add_argument('--target_image_size', default=[300, 300], nargs=2, type=int,
                    help='Input images will be resized to this for data augmentation.')
parser.add_argument('--model_dir', type=str,
                    default='/notebooks/global_localization/gp_net_torch',
                    help='directory containing the pretrained weights (pretrained.pth)')
# The dataset options are lists of directories; nargs='+' keeps CLI overrides
# as lists, matching the list defaults.
parser.add_argument('--test_dataset', type=str, nargs='+',
                    default=['/notebooks/michigan_nn_data/2012_02_12',
                             '/notebooks/michigan_nn_data/2012_04_29',
                             '/notebooks/michigan_nn_data/2012_05_11',
                             '/notebooks/michigan_nn_data/2012_06_15',
                             '/notebooks/michigan_nn_data/2012_08_04',
                             '/notebooks/michigan_nn_data/2012_10_28',
                             '/notebooks/michigan_nn_data/2012_11_16',
                             '/notebooks/michigan_nn_data/2012_12_01'])
parser.add_argument('--train_dataset', type=str, nargs='+',
                    default=['/notebooks/michigan_nn_data/2012_01_08',
                             '/notebooks/michigan_nn_data/2012_01_15',
                             '/notebooks/michigan_nn_data/2012_01_22',
                             '/notebooks/michigan_nn_data/2012_02_02',
                             '/notebooks/michigan_nn_data/2012_02_04',
                             '/notebooks/michigan_nn_data/2012_02_05',
                             '/notebooks/michigan_nn_data/2012_03_31',
                             '/notebooks/michigan_nn_data/2012_09_28'])
# Consumers call torch.load(*args.norm_tensor), so this must stay a one-element
# list; nargs=1 keeps a CLI override in that shape instead of a bare string.
parser.add_argument('--norm_tensor', type=str, nargs=1,
                    default=['/notebooks/global_localization/norm_mean_std.pt'],
                    help='file holding the saved [mean, std] normalisation tensors')

# Inside Jupyter, sys.argv holds the kernel's own arguments; parse an explicit
# empty list so the defaults above are used without mutating sys.argv.
args = parser.parse_args(args=[])

# Load Dataset

In [3]:
import torchvision.transforms as transforms
import os
import sys
# Make the project-local `torchlib` package importable; guard so re-running the
# cell does not keep appending the same entry.
if '..' not in sys.path:
    sys.path.append('..')
from torchlib.utils import LocalizationDataset
from torch.utils.data import DataLoader

# args.norm_tensor is a one-element list holding the path of a saved [mean, std] pair.
[args.norm_mean, args.norm_std] = torch.load(*args.norm_tensor)

transform = transforms.Compose([transforms.ToTensor()])
dataset = LocalizationDataset(dataset_dirs=args.train_dataset,
                              image_size=args.target_image_size,
                              transform=transform, get_pair=False, sampling_rate=2)
num_data = len(dataset)

# 70/30 train/validation split. The validation size is the remainder so the two
# lengths always sum to num_data: rounding both parts independently can miss
# the total (Python rounds halves to even, e.g. num_data=5 -> 4 + 2 == 6), which
# makes random_split raise.
num_train = round(num_data * 0.7)
num_val = num_data - num_train
torch.manual_seed(42)
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [num_train, num_val])

the rosdep view is empty: call 'sudo rosdep init' and 'rosdep update'
100%|██████████| 16446/16446 [00:22<00:00, 715.16it/s]
100%|██████████| 22584/22584 [00:31<00:00, 709.67it/s]
100%|██████████| 18655/18655 [00:26<00:00, 716.58it/s]
100%|██████████| 17310/17310 [00:24<00:00, 715.67it/s]
100%|██████████| 10766/10766 [00:14<00:00, 720.02it/s]
100%|██████████| 14878/14878 [00:22<00:00, 660.79it/s]
100%|██████████| 13452/13452 [00:20<00:00, 662.76it/s]
100%|██████████| 14037/14037 [00:24<00:00, 569.53it/s]


In [None]:
#from torch.cuda.amp import autocast, GradScaler
import gpytorch
import sys
sys.path.append('..')
from torchlib.GPs import Backbone, NN
from torchlib.cnn_auxiliary import normalize, denormalize_navie


# Gaussian Process Model
class GP(gpytorch.models.ApproximateGP):
    """Variational GP with independent outputs, one per task.

    Args:
        inducing_points: tensor of inducing locations; its second-to-last
            dimension is the number of inducing points. They are learnable
            (``learn_inducing_locations=True``).
        output_dim: number of output tasks (default 3).
    """

    def __init__(self, inducing_points, output_dim=3):
        num_inducing = inducing_points.size(-2)
        task_batch = torch.Size([output_dim])

        # One Cholesky-parameterised variational distribution per output task.
        q_u = gpytorch.variational.CholeskyVariationalDistribution(
            num_inducing, batch_shape=task_batch)
        # `self` is handed to the base strategy before ApproximateGP.__init__
        # runs; the construction order is intentionally kept as-is.
        base_strategy = gpytorch.variational.VariationalStrategy(
            self, inducing_points, q_u, learn_inducing_locations=True)
        multitask_strategy = gpytorch.variational.IndependentMultitaskVariationalStrategy(
            base_strategy, num_tasks=output_dim)
        super().__init__(multitask_strategy)

        # Mean and kernel use a batch of size 1, so (presumably by broadcasting)
        # their hyperparameters are shared across the output tasks.
        shared_batch = torch.Size([1])
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=shared_batch)
        rbf = gpytorch.kernels.RBFKernel(batch_shape=shared_batch)
        self.covar_module = gpytorch.kernels.ScaleKernel(rbf, batch_shape=shared_batch)

    def forward(self, x):
        """Return a MultivariateNormal with the modelled mean and covariance at ``x``."""
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x), self.covar_module(x))

class Baseline(pl.LightningModule):
    """Localisation model: frozen CNN backbone + NN head + variational GP.

    * The backbone's weights are frozen in ``__init__``.
    * Translation is predicted by ``self.gp`` from the head's ``feature_t``
      output and trained with a GP predictive log-likelihood.
    * Rotation is the last 4 columns of the head's output (presumably a unit
      quaternion -- TODO confirm) trained with ``1 - mean(<r_pred, r_target>^2)``.

    Relies on notebook globals: ``args``, ``train_dataset``, ``val_dataset``.
    """

    def __init__(self):
        super().__init__()
        # Inducing locations, shape (tasks=3, num_inducing=300, 128); the last
        # dim presumably has to match the size of NN's ``feature_t`` -- TODO confirm.
        # Random rather than all-zero init: identical inducing points make the
        # RBF K_uu rank-1 and are interchangeable, so symmetric updates cannot
        # separate them. They are learnable, so loading a checkpoint replaces them.
        inducing_points = torch.randn(3, 300, 128)
        self.backbone = Backbone()
        self.nn = NN()
        self.gp = GP(inducing_points)
        self.likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(
            num_tasks=3)

        # Target normalisation statistics saved as [mean, std]. Buffers follow
        # .to(device) and are stored in the state_dict under the same keys as
        # before, but are not trainable parameters.
        norm_mean, norm_std = torch.load(*args.norm_tensor)
        self.register_buffer('norm_mean', norm_mean)
        self.register_buffer('norm_std', norm_std)

        # Disabled: initialise from an older pretrained model by remapping keys.
        # state_dict = torch.load(os.path.join(args.model_dir, 'pretrained.pth'),map_location=self.device)
        # for key in list(state_dict):
        #     if 'net.resnet.' in key:
        #         state_dict[key.replace('net.resnet.','backbone.resnet.')] = state_dict.pop(key)
        #     if 'net.global_regressor.' in key:
        #         state_dict[key.replace('net.global_regressor.','nn.global_regressor.')] = state_dict.pop(key)
        #     elif 'net.global_context.' in key:
        #         state_dict[key.replace('net.global_context.','nn.global_context.')] = state_dict.pop(key)
        # self.load_state_dict(state_dict,strict = False)

        # Shut down backbone learning.
        self.__disable_grad(self.backbone)

    def forward(self, x):
        """Return ``(trans_pred, rot_pred)``.

        ``trans_pred`` is the GP output distribution over translation;
        ``rot_pred`` is the last 4 columns of the NN head's output.
        """
        dense_feat = self.backbone(x)
        output, feature_t, feature_r = self.nn(dense_feat)
        _, rot_pred = torch.split(output, [3, 4], dim=1)
        trans_pred = self.gp(feature_t)
        return trans_pred, rot_pred

    def training_step(self, batch, batch_idx):
        """One optimisation step on normalised targets.

        ``batch`` is a dict unpacked by value order (presumably images, poses);
        the pose target is normalised with the stored statistics. The total,
        translation and rotation losses are written to TensorBoard.
        """
        x, y = batch.values()
        y = normalize(y, self.norm_mean, self.norm_std)
        train_loss, trans_loss, rot_loss = self.__loss(x, y)
        tensorboard = self.logger.experiment
        tensorboard.add_scalars('train_loss',
                                {'total_loss': float(train_loss),
                                 'trans_loss': float(trans_loss),
                                 'rot_loss': float(rot_loss)},
                                self.global_step)
        return train_loss

    def __loss(self, x, y):
        """Return ``(total, translation, rotation)`` losses.

        ``y`` has 7 columns: 3 translation values followed by 4 rotation values.
        """
        trans_target, rot_target = torch.split(y, [3, 4], dim=1)
        trans_pred, rot_pred = self.forward(x)

        # Translation: negative GP predictive log-likelihood. num_data must be
        # the size of the set the GP is trained on (train_dataset), not the full
        # dataset before the train/val split. The MLL is built per call rather
        # than stored on self: as a Module it would likely register gp/likelihood
        # a second time in the state_dict and break checkpoint loading.
        mll = gpytorch.mlls.PredictiveLogLikelihood(
            self.likelihood, self.gp, num_data=len(train_dataset))
        trans_loss = -1. * mll(trans_pred, trans_target)
        # Rotation: 1 - mean squared dot product (insensitive to the sign of r).
        rot_loss = 1. - torch.mean(torch.square(
            torch.sum(torch.mul(rot_pred, rot_target), dim=1)))

        total_loss = trans_loss + args.lamda_weights * rot_loss
        return total_loss, trans_loss, rot_loss

    def validation_step(self, batch, batch_idx):
        """Evaluate one batch in de-normalised units.

        Translation error is the mean Euclidean distance between the
        likelihood's predictive mean and the target; rotation error uses the
        training formula. Only the translation error is returned as ``val_loss``.
        """
        x, y = batch.values()
        trans_target, rot_target = torch.split(y, [3, 4], dim=1)
        trans_pred, rot_pred = self.forward(x)

        trans_pred, trans_mean, trans_var = self._eval_gp(trans_pred)
        trans_pred = denormalize_navie(trans_pred, self.norm_mean, self.norm_std)
        trans_mean = denormalize_navie(trans_mean, self.norm_mean, self.norm_std)
        # Var(std * z + mean) = std**2 * Var(z). Assumes denormalize_navie maps
        # z -> z * std + mean -- TODO confirm.
        trans_var = trans_var.mul(self.norm_std ** 2)
        # Latent posterior samples in de-normalised units; not used for the loss.
        samples = self._sample(trans_mean, trans_var, 100)

        trans_loss = torch.sqrt(torch.sum((trans_pred - trans_target) ** 2, dim=1)).mean()
        rot_loss = 1. - torch.mean(torch.square(torch.sum(torch.mul(rot_pred, rot_target), dim=1)))

        val_loss = trans_loss
        # Epoch-level only: ModelCheckpoint monitors the plain 'val_loss' key,
        # and logging per step as well can (depending on the Lightning version)
        # split it into separate step/epoch keys.
        self.log('val_loss', val_loss, on_step=False, on_epoch=True, prog_bar=False, logger=False)
        tensorboard = self.logger.experiment
        tensorboard.add_scalars('val_loss',
                                {'trans_loss': float(trans_loss),
                                 'rot_loss': float(rot_loss)},
                                self.current_epoch * self.trainer.num_val_batches[0] + batch_idx)
        return val_loss

    def _eval_gp(self, trans_pred):
        """Summarise a GP output distribution.

        Returns:
            ``(y_mean, f_mean, f_var)``: the mean of the likelihood's predictive
            distribution, and the mean and variance of the latent ``trans_pred``.
        """
        predictive = self.likelihood(trans_pred)  # evaluate the likelihood once
        return predictive.mean, trans_pred.mean, trans_pred.variance

    def _sample(self, mean, var, num_sample=100):
        """Draw ``num_sample`` samples from independent normals N(mean, var).

        ``torch.distributions.Normal`` is parameterised by the standard
        deviation, so the variance is square-rooted first. The result has
        shape ``(num_sample, *broadcast(mean, var).shape)``.
        """
        dist = torch.distributions.Normal(mean, var.sqrt())
        return dist.sample([num_sample])

    def configure_optimizers(self):
        """Adam with per-module learning rates.

        GP and likelihood use the base rate; the NN head's ``global_regressor``
        uses 1% and ``global_context`` 0.1% of it. The frozen backbone is not
        included. All groups share ``args.weight_decay``.
        """
        lr, weight_decay = args.learning_rate, args.weight_decay
        optimizer_args = [
            {'params': self.gp.parameters(), 'lr': lr, 'weight_decay': weight_decay},
            {'params': self.likelihood.parameters(), 'lr': lr,
             'weight_decay': weight_decay},
            {'params': self.nn.global_regressor.parameters(), 'lr': lr * 0.01,
             'weight_decay': weight_decay},
            {'params': self.nn.global_context.parameters(), 'lr': lr * 0.001,
             'weight_decay': weight_decay}]
        return torch.optim.Adam(optimizer_args)

    def show_require_grad(self):
        """Print the name and shape of every parameter with ``requires_grad``."""
        for name, param in self.named_parameters():
            if param.requires_grad:
                print(name, param.shape)

    def __disable_grad(self, model):
        """Freeze all parameters of ``model`` (set ``requires_grad=False``)."""
        for param in model.parameters():
            param.requires_grad = False

    def get_progress_bar_dict(self):
        """Hide the logger version (``v_num``) from the progress bar."""
        tqdm_dict = super().get_progress_bar_dict()
        if 'v_num' in tqdm_dict:
            del tqdm_dict['v_num']
        return tqdm_dict

    def train_dataloader(self):
        """Shuffled loader over the training split (incomplete last batch dropped)."""
        train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                                  shuffle=True, num_workers=os.cpu_count(), drop_last=True)
        return train_loader

    def val_dataloader(self):
        """Loader over the held-out validation split (not the training split)."""
        val_loader = DataLoader(val_dataset, batch_size=args.batch_size,
                                shuffle=False, num_workers=os.cpu_count(), drop_last=True)
        return val_loader
    
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint

import os
import shutil

# Start each run with an empty log directory. shutil.rmtree replaces the former
# `rm -rf` shell call: same effect (missing directory is fine), no subshell,
# and it works on any OS.
shutil.rmtree('lightning_logs', ignore_errors=True)
logger = TensorBoardLogger('lightning_logs')

# Keep the two weight-only checkpoints with the lowest logged 'val_loss'.
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    filepath='model-{epoch:02d}-{val_loss:.2f}',
    save_top_k=2,
    mode='min',
    save_weights_only=True)

# Single-GPU fp32 training; each epoch uses 80% of the training batches and
# 20% of the validation batches.
trainer = pl.Trainer(gpus=1, precision=32,
                     limit_train_batches=0.8,
                     limit_val_batches=0.2,
                     accumulate_grad_batches=1,
                     reload_dataloaders_every_epoch=True,
                     logger=logger,
                     checkpoint_callback=checkpoint_callback)

# Initialise the model from a pretrained checkpoint's weights, then train.
model = Baseline.load_from_checkpoint('pretrained-model-epoch=98-val_loss=0.32.ckpt')
trainer.fit(model)