# Libraries and Basic Settings

In [1]:
import numpy as np

import torch
from torch.nn import functional as F
from torch import nn
import pytorch_lightning as pl

# default setting
np.set_printoptions(precision=2)
torch.set_default_dtype(torch.float32)
torch.set_printoptions(precision=4)
torch.backends.cudnn.benchmark = True
torch.set_printoptions(sci_mode=False)

In [2]:
import argparse
import sys
import os
import time
import pickle

def _str2bool(value):
    """Convert a command-line token to a bool.

    ``type=bool`` is a trap with argparse: ``bool('False')`` is ``True``
    because every non-empty string is truthy.  This converter accepts the
    usual spellings and rejects everything else.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got {!r}'.format(value))


# Experiment configuration.  Everything has a default so the notebook can be
# run without any command-line input.
parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', type=int, default=170,
                    help='size of mini batch')
parser.add_argument('--learning_rate', type=float, default=0.02, help='learning rate')
parser.add_argument('--weight_decay', type=float, default=.0001,
                    help='weight decay passed to the optimizer')
parser.add_argument('--lamda_weights', type=float, default=.01, help='lamda weight')
parser.add_argument('--is_normalization', type=_str2bool,
                    default=True, help='whether do data normalization')
parser.add_argument('--target_image_size', default=[300, 300], nargs=2, type=int,
                    help='Input images will be resized to this for data augmentation.')
parser.add_argument('--model_dir', type=str,
                    default='/notebooks/global_localization/lightning/baseline',
                    help='directory containing the pretrained baseline checkpoint')
# The dataset options hold lists of directories, so they take one or more
# values (nargs='+'); the defaults are unchanged.
parser.add_argument('--test_dataset', type=str, nargs='+',
                    default=['/notebooks/michigan_nn_data/2012_02_12',
                             '/notebooks/michigan_nn_data/2012_04_29',
                             '/notebooks/michigan_nn_data/2012_05_11',
                             '/notebooks/michigan_nn_data/2012_06_15',
                             '/notebooks/michigan_nn_data/2012_08_04',
                             '/notebooks/michigan_nn_data/2012_10_28',
                             '/notebooks/michigan_nn_data/2012_11_16',
                             '/notebooks/michigan_nn_data/2012_12_01'],
                    help='directories of the test sequences')
parser.add_argument('--train_dataset', type=str, nargs='+',
                    default=['/notebooks/michigan_nn_data/2012_01_08',
                             '/notebooks/michigan_nn_data/2012_01_15',
                             '/notebooks/michigan_nn_data/2012_01_22',
                             '/notebooks/michigan_nn_data/2012_02_02',
                             '/notebooks/michigan_nn_data/2012_02_04',
                             '/notebooks/michigan_nn_data/2012_02_05',
                             '/notebooks/michigan_nn_data/2012_03_31',
                             '/notebooks/michigan_nn_data/2012_09_28'],
                    help='directories of the training sequences')
# Kept as a one-element list (nargs=1) because callers unpack it with
# `torch.load(*args.norm_tensor)`.
parser.add_argument('--norm_tensor', type=str, nargs=1,
                    default=['/notebooks/global_localization/norm_mean_std.pt'],
                    help='file holding the [mean, std] normalization tensors')

# Replace the process argv before parsing so that only the defaults above are
# used, whatever arguments the hosting process was started with.
sys.argv = ['']
args = parser.parse_args()

# Load Dataset

In [3]:
import torchvision.transforms as transforms
import os
import sys
sys.path.append('..')
from torchlib.utils import LocalizationDataset
from torch.utils.data import DataLoader

# Normalization statistics are stored on `args` next to the other settings.
args.norm_mean, args.norm_std = torch.load(*args.norm_tensor)

transform = transforms.Compose([transforms.ToTensor()])
dataset = LocalizationDataset(
    dataset_dirs=args.train_dataset,
    image_size=args.target_image_size,
    transform=transform,
    get_pair=False,
    sampling_rate=20,
)
num_data = len(dataset)

# 70/30 train/validation split with a fixed seed.  The validation size is the
# remainder rather than round(num_data*0.3): the two rounded parts do not
# always add up to num_data (e.g. 5 -> 4 + 2), and random_split raises when
# the lengths do not sum to the dataset length.
torch.manual_seed(42)
num_train = round(num_data * 0.7)
num_val = num_data - num_train
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [num_train, num_val])

100%|██████████| 1644/1644 [00:02<00:00, 716.88it/s]
100%|██████████| 2258/2258 [00:03<00:00, 712.47it/s]
100%|██████████| 1865/1865 [00:02<00:00, 720.97it/s]
100%|██████████| 1731/1731 [00:02<00:00, 721.39it/s]
100%|██████████| 1076/1076 [00:01<00:00, 721.35it/s]
100%|██████████| 1487/1487 [00:02<00:00, 717.08it/s]
100%|██████████| 1345/1345 [00:01<00:00, 718.61it/s]
100%|██████████| 1403/1403 [00:02<00:00, 696.80it/s]


In [4]:
#from torch.cuda.amp import autocast, GradScaler
import gpytorch
import sys
sys.path.append('..')
from torchlib.GPs import Backbone, NN
from torchlib.cnn_auxiliary import normalize, denormalize_navie


# Gaussian Process Model
class GP(gpytorch.models.ApproximateGP):
    """Variational multi-output GP built from independent per-task GPs.

    Args:
        inducing_points: initial inducing locations; their second-to-last
            dimension gives the number of inducing points.
        output_dim: number of output tasks (default 3).
    """

    def __init__(self, inducing_points, output_dim=3):
        task_batch = torch.Size([output_dim])
        num_inducing = inducing_points.size(-2)

        # One Cholesky-parameterised variational distribution per task.
        q_u = gpytorch.variational.CholeskyVariationalDistribution(
            num_inducing, batch_shape=task_batch
        )
        # The base strategy learns the inducing locations; the wrapper exposes
        # the batch of GPs as `output_dim` independent tasks.
        base_strategy = gpytorch.variational.VariationalStrategy(
            self, inducing_points, q_u, learn_inducing_locations=True
        )
        multitask_strategy = gpytorch.variational.IndependentMultitaskVariationalStrategy(
            base_strategy, num_tasks=output_dim
        )
        super().__init__(multitask_strategy)

        # Mean and kernel are declared with batch shape [1], not [output_dim].
        unit_batch = torch.Size([1])
        self.mean_module = gpytorch.means.ConstantMean(batch_shape=unit_batch)
        self.covar_module = gpytorch.kernels.ScaleKernel(
            gpytorch.kernels.RBFKernel(batch_shape=unit_batch),
            batch_shape=unit_batch,
        )

    def forward(self, x):
        """Return the GP prior at ``x`` as a ``MultivariateNormal``."""
        return gpytorch.distributions.MultivariateNormal(
            self.mean_module(x), self.covar_module(x)
        )

class Baseline(nn.Module):
    """Baseline pose estimator: backbone + NN head, with a GP for translation.

    ``forward`` returns the GP output for the translation features and the
    rotation part (last 4 of 7 columns) of the NN head output.

    Args:
        pretrained: when True, load the pretrained checkpoint from
            ``args.model_dir``, freeze every parameter and keep the module in
            evaluation mode.
    """

    def __init__(self, pretrained = True):
        super().__init__()
        # parameters
        inducing_points = torch.zeros(3, 300, 128)
        self.backbone = Backbone()
        self.nn = NN()
        self.gp = GP(inducing_points)
        self.likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=3)
        # Normalization statistics are stored as non-trainable parameters so
        # they follow the module across devices and into its state_dict.
        norm_mean, norm_std = torch.load(*args.norm_tensor)
        self.norm_mean = torch.nn.parameter.Parameter(norm_mean, requires_grad=False)
        self.norm_std = torch.nn.parameter.Parameter(norm_std, requires_grad=False)

        # Plain attribute (not a parameter or buffer): absent from state_dict.
        self._frozen = False
        if pretrained:
            ckpt_path = os.path.join(args.model_dir,
                                     'pretrained-model-epoch=58-val_loss=0.26.ckpt')
            # Deserialize on the CPU; load_state_dict copies the values into
            # the existing parameters, so this works whatever device the
            # checkpoint was saved from.
            checkpoint = torch.load(ckpt_path, map_location='cpu')
            self.load_state_dict(checkpoint['state_dict'])
            self._disable_grad(self)
            self._frozen = True
            self.eval()

    def train(self, mode=True):
        """Set the training mode, but stay in eval mode once frozen.

        A parent module's ``train()`` call is applied recursively to its
        children, which would otherwise undo the ``eval()`` done in
        ``__init__`` for the frozen pretrained baseline.
        """
        return super().train(mode and not getattr(self, '_frozen', False))

    def _disable_grad(self,model):
        """Turn off gradient tracking for every parameter of ``model``."""
        for param in model.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Return ``(trans_pred, rot_pred)`` for input ``x``.

        ``trans_pred`` is the output of the GP applied to the translation
        features; ``rot_pred`` is the last 4 columns of the NN head output.
        """
        dense_feat = self.backbone(x)
        output, feature_t, feature_r = self.nn(dense_feat)
        # The head output has 7 columns: the first 3 are discarded here.
        _, rot_pred = torch.split(output, [3, 4], dim=1)
        trans_pred = self.gp(feature_t)
        return trans_pred, rot_pred

    def forward_feature(self, x):
        """Return only the translation features produced by the NN head."""
        dense_feat = self.backbone(x)
        _, feature_t, _ = self.nn(dense_feat)
        return feature_t
    
class Model(pl.LightningModule):
    """Sequence of GPs fitted to the residuals of a pretrained ``Baseline``.

    The model owns ``n_gps`` GP/likelihood pairs.  Only the pair at index
    ``gp_idx`` is trainable; it is fitted to the residual between the
    normalized translation target and the sum of the baseline mean and the
    means of the GPs before it.

    Args:
        gp_idx: index of the GP stage to train / the last stage used by
            ``forward``.
        n_gps: total number of GP stages (default 3).
    """

    def __init__(self, gp_idx, n_gps = 3):
        super().__init__()
        # Baseline model
        self.baseline = Baseline(pretrained = True)
        # GP estimators
        self.gps = nn.ModuleList()
        self.likelihoods = nn.ModuleList()

        for _ in range(n_gps):
            inducing_points = torch.zeros(3, 300, 128)
            self.gps.append(GP(inducing_points))
            self.likelihoods.append(gpytorch.likelihoods.MultitaskGaussianLikelihood(num_tasks=3))

        # Freeze every stage except the one being trained.
        self.gp_idx = gp_idx
        for idx in range(n_gps):
            if idx != self.gp_idx:
                self._disable_grad(self.gps[idx])
                self._disable_grad(self.likelihoods[idx])

    def forward(self, x):
        """Return ``(trans_pred, rot_pred)`` with the GP corrections applied.

        The means of stages ``0..gp_idx`` are added IN PLACE to
        ``trans_pred.mean``; the corrected mean is therefore delivered through
        the returned ``trans_pred`` object rather than as a separate value.
        NOTE(review): this relies on ``.mean`` sharing storage with the
        distribution's internal mean - confirm for the installed gpytorch.
        NOTE(review): ``baseline(x)`` and ``baseline.forward_feature(x)`` each
        run the backbone, so it is evaluated twice per call.
        """
        trans_pred, rot_pred = self.baseline(x)
        trans_pred_mean = trans_pred.mean

        trans_feat = self.baseline.forward_feature(x) #share data

        for i, gp in enumerate(self.gps):
            # Must stay in place: see the docstring.
            trans_pred_mean += gp(trans_feat).mean
            if i == self.gp_idx:
                break

        return trans_pred, rot_pred #no need to return `trans_pred_mean`

    def training_step(self, batch, batch_idx):
        """Fit stage ``gp_idx`` to the residual of the stages before it."""
        # by default `gp_index`: {0,1,2}
        x, y = batch.values()
        y = normalize(y, self.baseline.norm_mean, self.baseline.norm_std)
        # Keep the first 3 of the 7 target columns (translation).
        y, _ = torch.split(y, [3, 4], dim=1)

        trans_pred, _ = self.baseline(x)
        trans_pred_mean = trans_pred.mean

        trans_feat = self.baseline.forward_feature(x)

        x = trans_feat
        y = self._pseudo_residual(x, y, trans_pred_mean, self.gp_idx) # compute the residual

        loss = self.__loss(x, y, self.gp_idx)
        #self.log('train_loss', train_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        tensorboard = self.logger.experiment
        #tensorboard.add_scalar('train_loss',float(train_loss),self.global_step)
        tensorboard.add_scalars('train_loss',
                                {'loss':float(loss)},
                                self.global_step)
        return loss

    def __loss(self, x, y, gp_idx):
        """Negative predictive log-likelihood of stage ``gp_idx`` on ``(x, y)``."""
        # predict
        trans_pred = self.gps[gp_idx](x)

        # loss
        # `num_data` must be the number of TRAINING points; the module-level
        # `num_data` counts the whole dataset, validation split included.
        mll = gpytorch.mlls.PredictiveLogLikelihood(
            self.likelihoods[gp_idx], self.gps[gp_idx], num_data=len(train_dataset))
        trans_loss = -1.*mll(trans_pred, y)

        return trans_loss

    def _pseudo_residual(self, x, y, y_baseline, gp_idx):
        """Return ``y`` minus the baseline mean and the means of stages ``< gp_idx``.

        The accumulation is out of place so that the caller's ``y_baseline``
        tensor is left untouched.
        """
        prediction = y_baseline
        for idx in range(gp_idx):
            prediction = prediction + self.gps[idx](x).mean
        return y - prediction

    def validation_step(self, batch, batch_idx):
        """Log translation / rotation errors; return the translation error."""
        x, y = batch.values()
        trans_target, rot_target = torch.split(y, [3, 4], dim=1)
        trans_pred, rot_pred = self.forward(x)

        trans_pred, trans_mean, trans_var = self._eval_gp(trans_pred)
        trans_pred = denormalize_navie(trans_pred, self.baseline.norm_mean, self.baseline.norm_std)
        trans_mean = denormalize_navie(trans_mean, self.baseline.norm_mean, self.baseline.norm_std)
        # A variance scales with the SQUARE of the scaling factor.
        # NOTE(review): assumes denormalize_navie is the affine map
        # x * norm_std + norm_mean - confirm against its implementation.
        trans_var = trans_var.mul(self.baseline.norm_std ** 2)
        samples = self._sample(trans_mean, trans_var, 100)

        # Mean Euclidean distance between predicted and target translation.
        trans_loss = torch.sqrt(torch.sum((trans_pred - trans_target)**2,dim=1)).mean()
        # One minus the mean squared row-wise dot product of prediction and target.
        rot_loss = 1. - torch.mean(torch.square(torch.sum(torch.mul(rot_pred,rot_target),dim=1)))
        #return trans_pred, rot_pred, trans_target, rot_target, samples

        val_loss = trans_loss
        self.log('val_loss', val_loss, on_step=True, on_epoch=True, prog_bar=False, logger=False)
        self.log('gp_idx', self.gp_idx, on_step=True, on_epoch=True, prog_bar=False, logger=False)
        tensorboard = self.logger.experiment
        tensorboard.add_scalars('val_loss',
                                {'trans_loss':float(trans_loss),
                                'rot_loss':float(rot_loss)},
                                self.current_epoch*self.trainer.num_val_batches[0]+batch_idx)
        return val_loss

    def _eval_gp(self, trans_pred):
        """Return ``(y_mean, c_mean, c_var)`` for the distribution ``trans_pred``.

        ``c_mean`` / ``c_var`` are the mean and variance of ``trans_pred``;
        ``y_mean`` is the mean after applying the baseline likelihood.
        """
        c_mean, c_var = trans_pred.mean, trans_pred.variance
        # The likelihood is evaluated once; only its mean is used.
        y_mean = self.baseline.likelihood(trans_pred).mean

        return y_mean, c_mean, c_var

    def _sample(self, mean, var, num_sample = 100):
        """Draw ``num_sample`` samples from independent normals.

        ``var`` is a variance; ``torch.distributions.Normal`` takes the
        standard deviation as its second argument, hence the square root.
        """
        dist = torch.distributions.Normal(mean, var.sqrt())
        samples = dist.sample([num_sample])
        return samples

    def configure_optimizers(self):
        """Adam over the parameters of the active GP stage and its likelihood."""
        lr,weight_decay = args.learning_rate,args.weight_decay
        optimizer_args = [
            {'params': self.gps[self.gp_idx].parameters(), 'lr': lr, 'weight_decay': weight_decay},
            {'params': self.likelihoods[self.gp_idx].parameters(), 'lr': lr,
             'weight_decay': weight_decay},]

        optimizer = torch.optim.Adam(optimizer_args)

        return optimizer

    def show_require_grad(self):
        """Print the name and shape of every trainable parameter."""
        for name, param in self.named_parameters():
            if param.requires_grad:
                print (name, param.shape)

    def _disable_grad(self,model):
        """Turn off gradient tracking for every parameter of ``model``."""
        for param in model.parameters():
            param.requires_grad = False

    def get_progress_bar_dict(self):
        """Progress-bar items without the ``v_num`` entry."""
        tqdm_dict = super().get_progress_bar_dict()
        if 'v_num' in tqdm_dict:
            del tqdm_dict['v_num']
        return tqdm_dict

    def train_dataloader(self):
        """Shuffled loader over the module-level ``train_dataset``."""
        train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                                  shuffle=True, num_workers=os.cpu_count(),drop_last=True)
        return train_loader

    def val_dataloader(self):
        """Unshuffled loader over the module-level ``val_dataset``."""
        val_loader = DataLoader(val_dataset, batch_size=args.batch_size,
                                shuffle=False, num_workers=os.cpu_count(),drop_last=True)
        return val_loader

from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks import ModelCheckpoint
# Disabled training driver, kept as a bare string literal so that it is not
# executed (as the last expression of the cell, the string is echoed as the
# cell output).  When enabled it trains the GP stages one after another,
# warm-starting each stage from a checkpoint of the previous stage.
# NOTE(review): the text runs `rm -rf lightning_logs` through os.system.
# NOTE(review): the loop starts at gp_idx = 1, so stage 0 is never trained here.
# NOTE(review): if no checkpoint of the previous stage matches, `file` is left
# as the last entry of the sorted directory listing and that file is loaded.
'''
import os
os.system('rm -rf lightning_logs')
logger = TensorBoardLogger('lightning_logs')
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    #filepath='model-{epoch:02d}-{val_loss:.2f}',
    filepath='model-{gp_idx:02d}-{val_loss:.2f}',
    save_top_k=3,
    mode='min',
    save_weights_only = True)


n_gps = 3
for gp_idx in range(1,n_gps):
    
    trainer = pl.Trainer(gpus=1,precision=32,
                     max_epochs=5,
                     limit_train_batches=0.8,
                     limit_val_batches=0.2,
                     accumulate_grad_batches=2,
                     reload_dataloaders_every_epoch = True,
                     logger=logger,
                     checkpoint_callback=checkpoint_callback)
    
    model = Model(gp_idx,n_gps)
    if gp_idx != 0:
        # find model with lowest val error
        files = os.listdir()
        files.sort()
        for file in files:
            if file.endswith('.ckpt') and file.split('-')[1] == 'gp_idx='+'{:02d}'.format(gp_idx-1):
                break
        model.load_state_dict(torch.load(file)['state_dict'])
    model.show_require_grad()
    trainer.fit(model)
'''

"\nimport os\nos.system('rm -rf lightning_logs')\nlogger = TensorBoardLogger('lightning_logs')\ncheckpoint_callback = ModelCheckpoint(\n    monitor='val_loss',\n    #filepath='model-{epoch:02d}-{val_loss:.2f}',\n    filepath='model-{gp_idx:02d}-{val_loss:.2f}',\n    save_top_k=3,\n    mode='min',\n    save_weights_only = True)\n\n\nn_gps = 3\nfor gp_idx in range(1,n_gps):\n    \n    trainer = pl.Trainer(gpus=1,precision=32,\n                     max_epochs=5,\n                     limit_train_batches=0.8,\n                     limit_val_batches=0.2,\n                     accumulate_grad_batches=2,\n                     reload_dataloaders_every_epoch = True,\n                     logger=logger,\n                     checkpoint_callback=checkpoint_callback)\n    \n    model = Model(gp_idx,n_gps)\n    if gp_idx != 0:\n        # find model with lowest val error\n        files = os.listdir()\n        files.sort()\n        for file in files:\n            if file.endswith('.ckpt') and file.sp

In [5]:
# Build the model with the last GP stage selected, then restore the weights
# saved for that stage.
n_gps = 3
model = Model(gp_idx=n_gps - 1, n_gps=n_gps)
model.load_state_dict(
    torch.load('model-gp_idx=02-val_loss=0.26.ckpt')['state_dict']
)

<All keys matched successfully>

In [6]:
# Put the model on the GPU in evaluation mode and draw one shuffled batch.
# Note that the loader named `train_loader` reads from `val_dataset`.
model.cuda()
model.eval()
train_loader = DataLoader(
    val_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=os.cpu_count(),
    drop_last=True,
)
# The batch is a mapping with two values: inputs and targets.
x, y = (tensor.cuda() for tensor in next(iter(train_loader)).values())

In [37]:
# Predict with GP stages 0..1 applied, then de-normalize the mean.
# NOTE(review): the result is named `y_pred_2` although `gp_idx` is 1 -
# confirm which stage was intended.
model.gp_idx = 1
trans_pred, _ = model(x)
y_pred_2 = denormalize_navie(
    trans_pred.mean,
    model.baseline.norm_mean,
    model.baseline.norm_std,
)
y_pred_2

tensor([[  -182.2388,    653.4119,    -12.8527],
        [  -294.6024,    640.8831,    -14.1611],
        [   -59.3149,    275.1011,     -8.0924],
        [    16.8487,    397.6211,     -1.6299],
        [   -56.7034,     69.6301,     -2.3728],
        [  -224.2253,    718.9692,    -13.2310],
        [  -135.5402,    580.5063,    -11.8316],
        [   -80.5028,    314.2275,     -8.5604],
        [    73.1532,    238.0699,     -2.2906],
        [  -219.0957,    615.9213,    -12.2337],
        [   -73.5760,    330.6924,     -8.3790],
        [    62.5162,     43.8641,     -2.3388],
        [  -203.6885,    613.4685,    -12.0833],
        [  -219.7526,    552.8555,    -11.8488],
        [  -297.5619,    444.5047,    -12.2123],
        [  -290.8160,    546.7823,    -12.4105],
        [   -51.7065,    551.0720,    -11.5534],
        [   -63.1721,    540.7927,    -11.4655],
        [  -312.5062,    465.0785,    -12.0790],
        [    34.0626,    467.6739,     -2.4572],
        [  -281.7507

In [38]:



# First 3 of the 7 target columns, for comparison with the predictions.
torch.split(y, [3, 4], dim=1)[0]

tensor([[  -183.6275,    654.0954,    -12.8516],
        [  -294.7311,    641.0666,    -14.1471],
        [   -59.2336,    274.6865,     -8.0571],
        [    17.3419,    398.9818,     -1.6099],
        [   -55.9035,     71.1568,     -2.3737],
        [  -223.9893,    717.3274,    -13.2393],
        [  -136.6159,    580.3882,    -11.8235],
        [   -80.5338,    314.6850,     -8.6029],
        [    73.9773,    238.8143,     -2.2998],
        [  -218.7723,    614.9406,    -12.2277],
        [   -70.2328,    332.7012,     -8.3526],
        [    63.8172,     44.0244,     -2.2993],
        [  -203.2543,    613.3282,    -12.0711],
        [  -219.9280,    552.9838,    -11.8480],
        [  -298.4624,    444.6995,    -12.2160],
        [  -290.0325,    546.9589,    -12.3854],
        [   -51.4839,    550.7039,    -11.5664],
        [   -62.6016,    539.8173,    -11.4067],
        [  -312.6036,    464.7505,    -12.0812],
        [    32.7084,    468.5026,     -2.4331],
        [  -282.1566

In [40]:
# Predict with only GP stage 0 applied, then de-normalize the mean.
model.gp_idx = 0
trans_pred, _ = model(x)
y_pred_0 = denormalize_navie(
    trans_pred.mean,
    model.baseline.norm_mean,
    model.baseline.norm_std,
)
y_pred_0

tensor([[  -182.1475,    653.5393,    -12.8496],
        [  -294.4902,    641.1179,    -14.1544],
        [   -59.2198,    275.1216,     -8.0913],
        [    16.9713,    397.6711,     -1.6312],
        [   -56.6015,     69.5793,     -2.3739],
        [  -224.0988,    719.1587,    -13.2275],
        [  -135.4657,    580.5775,    -11.8298],
        [   -80.4140,    314.2381,     -8.5602],
        [    73.2616,    238.0522,     -2.2891],
        [  -219.0245,    616.0034,    -12.2304],
        [   -73.4974,    330.7007,     -8.3793],
        [    62.6239,     43.7527,     -2.3318],
        [  -203.6241,    613.5387,    -12.0804],
        [  -219.6853,    552.8760,    -11.8464],
        [  -297.4635,    444.5179,    -12.2094],
        [  -290.7291,    546.8369,    -12.4058],
        [   -51.5851,    551.2047,    -11.5512],
        [   -63.0576,    540.9080,    -11.4636],
        [  -312.4077,    465.0789,    -12.0755],
        [    34.2319,    467.7459,     -2.4583],
        [  -281.6317

In [41]:
# Per-sample Euclidean error of `y_pred_0` against the first 3 target columns:
# batch mean followed by the five largest errors.
err_0 = torch.sqrt(torch.sum((y_pred_0 - y[:, :3]) ** 2, dim=1))
print(err_0.mean(), err_0.sort(descending=True)[0][:5])

tensor(1.2089, device='cuda:0') tensor([5.1983, 4.2967, 4.2113, 4.1466, 3.9337], device='cuda:0')


In [39]:
# Per-sample Euclidean error of `y_pred_2` against the first 3 target columns:
# batch mean followed by the five largest errors.
err_2 = torch.sqrt(torch.sum((y_pred_2 - y[:, :3]) ** 2, dim=1))
print(err_2.mean(), err_2.sort(descending=True)[0][:5])

tensor(1.2052, device='cuda:0') tensor([5.2292, 4.2761, 4.0550, 3.9817, 3.9201], device='cuda:0')
