# Library

In [1]:
import numpy as np
#import os
#os.environ["CUDA_VISIBLE_DEVICES"]="-1"
import torch
import argparse
import time
import pickle

#from src.self_awareness.networks import utils
#from src.self_awareness.learning.tf_cnn_auxiliary_gp import Model
from torch.distributions import Normal
import matplotlib.pyplot as plt
import random

import roslib
import rospy
import tf as tf_ros
from nav_msgs.msg import Odometry, Path
from sensor_msgs.msg import Image
from cv_bridge import CvBridge
from geometry_msgs.msg import PoseStamped, PoseArray, Pose
import math
import cv2
import copy

the rosdep view is empty: call 'sudo rosdep init' and 'rosdep update'


# Check GPU

In [2]:
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(1))

TITAN Xp


# Set torch default parameters

In [3]:
torch.set_default_dtype(torch.float32)
torch.set_printoptions(precision=4,sci_mode=False)
torch.backends.cudnn.benchmark = True

# Set Arguments

In [4]:
import argparse
import sys
import os
import time
import pickle

parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', type=int, default=450, help='size of mini batch')
parser.add_argument('--target_image_size', default=[300, 300], nargs=2, type=int, help='Input images will be resized to this for data argumentation.')

parser.add_argument('--model_dir', type=str, default='/notebooks/global_localization/dual_resnet_torch', help='model directory')

parser.add_argument('--test_dataset', type=str, default=[# '/notebooks/michigan_nn_data/2012_01_08',
                                                         # '/notebooks/michigan_nn_data/2012_01_15',
                                                         # '/notebooks/michigan_nn_data/2012_01_22',
                                                         # '/notebooks/michigan_nn_data/2012_02_02',
                                                         # '/notebooks/michigan_nn_data/2012_02_04',
                                                         # '/notebooks/michigan_nn_data/2012_02_05',
                                                         '/notebooks/michigan_nn_data/2012_02_12',
                                                         # '/notebooks/michigan_nn_data/2012_03_31',
                                                         '/notebooks/michigan_nn_data/2012_04_29',
                                                         '/notebooks/michigan_nn_data/2012_05_11',
                                                         '/notebooks/michigan_nn_data/2012_06_15',
                                                         '/notebooks/michigan_nn_data/2012_08_04',
                                                         # '/notebooks/michigan_nn_data/2012_09_28'])
                                                         '/notebooks/michigan_nn_data/2012_10_28',
                                                         '/notebooks/michigan_nn_data/2012_11_16',
                                                         '/notebooks/michigan_nn_data/2012_12_01'
                                                        ] )

parser.add_argument('--train_dataset', type=str, default = ['/notebooks/michigan_nn_data/test'])
#parser.add_argument('--map_dataset', type=str, default='/home/kevin/data/michigan_gt/training')
parser.add_argument('--enable_ros', type=bool, default=False, help='put data into ros')
sys.argv = ['']
args = parser.parse_args()

if args.enable_ros:
    rospy.init_node('global_localization_tf_broadcaster_cnn')

# Load Dataset

In [5]:
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
import tf.transformations as tf_tran
from tqdm import tqdm
#from PIL import Image
import numpy as np
import random

import torch.nn as nn
import torch.optim as optim
from torchlib import resnet, vggnet, cnn_auxiliary
from torchlib.cnn_auxiliary import normalize, denormalize, get_relative_pose, translational_rotational_loss
from torchlib.utils import LocalizationDataset, display_loss, data2tensorboard
import time

transform = transforms.Compose([transforms.ToTensor()])
dataset = LocalizationDataset(dataset_dirs = args.test_dataset, \
                              image_size = args.target_image_size, \
                              transform = transform,
                              get_pair = False, mode='evaluate')
#[args.norm_mean, args.norm_std] = [torch.tensor(x) for x in dataset.get_norm()]
[args.norm_mean, args.norm_std] = torch.load('/notebooks/global_localization/norm_mean_std.pt')

dataloader = DataLoader(dataset, batch_size=args.batch_size, \
                        shuffle=False, num_workers=0, \
                        drop_last=False, pin_memory=True)

100%|██████████| 14301/14301 [00:15<00:00, 932.41it/s]
100%|██████████| 7008/7008 [00:08<00:00, 799.51it/s]
100%|██████████| 12852/12852 [00:16<00:00, 777.45it/s]
100%|██████████| 9567/9567 [00:12<00:00, 768.60it/s]
100%|██████████| 13580/13580 [00:17<00:00, 785.69it/s]
100%|██████████| 14835/14835 [00:19<00:00, 773.39it/s]
100%|██████████| 7114/7114 [00:09<00:00, 764.31it/s]
100%|██████████| 12683/12683 [00:16<00:00, 782.30it/s]


# Define Model

In [6]:
class CNN_Model:
    def __init__(self, training = True, device = "cpu"):
        # device
        self.device = torch.device(device)
        
        # data
        self.model = cnn_auxiliary.Model(training).to(device)
        self.norm_mean = args.norm_mean.to(device)
        self.norm_std = args.norm_std.to(device)
        
        # training tool
        if training:
            self.optimizer = optim.Adam(self.model.parameters(), 
                                        lr=args.learning_rate, 
                                        weight_decay=args.weight_decay)
            self.scheduler = optim.lr_scheduler.LambdaLR(optimizer=self.optimizer,
                                                         lr_lambda=lambda epoch: args.decay_rate**epoch)
        
    def load_model(self, file_name = 'pretrained.pth', display_info = True):
        state_dict = torch.load(os.path.join(args.model_dir, file_name))
        if display_info:
            for name,param in state_dict.items():
                print(name, param.shape)
            print('Parameters layer:',len(state_dict.keys()))
        self.model.load_state_dict(state_dict,strict = False)
        
    def display_structure(self):
        for name, param in self.model.named_parameters():
            print(name, param.shape)
        print('Parameters layer:',len(self.model.state_dict().keys()))
    
    def display_require_grad(self):
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                print(name, param.shape)
    
    def power_resnet(self, status = False):
        if status == 'off':
            for param in self.model.resnet.parameters():
                param.requires_grad = False
        elif status == 'on':
            for param in self.model.resnet.parameters():
                param.requires_grad = True
        else:
            raise Exception("status must be 'on' or 'off'.")
            
    def power_context(self, status = False):
        if status == 'off':
            for param in self.model.global_context.parameters():
                param.requires_grad = False
        elif status == 'on':
            for param in self.model.global_context.parameters():
                param.requires_grad = True
        else:
            raise Exception("status must be 'on' or 'off'.")
    
    def power_regressor(self, status = False):
        if status == 'off':
            for param in self.model.global_regressor.parameters():
                param.requires_grad = False
        elif status == 'on':
            for param in self.model.global_regressor.parameters():
                param.requires_grad = True
        else:
            raise Exception("status must be 'on' or 'off'.")
            
    def power_all(self, status = False):
        if status == 'off':
            for param in self.model.parameters():
                param.requires_grad = False
        elif status == 'on':
            for param in self.model.parameters():
                param.requires_grad = True
        else:
            raise Exception("status must be 'on' or 'off'.")
            
    def save_model(self, file_name = 'model-{}-{}.pth'):
        checkpoint_path = os.path.join(args.model_dir, file_name)
        torch.save(self.model.state_dict(),checkpoint_path)
        print('saving model to' +  file_name)
            
    def loss(self,x0, x1, y0, y1):
        start = time.time()
        
        x0,x1,y0,y1 = x0.to(self.device),x1.to(self.device),y0.to(self.device),y1.to(self.device)
        y0_norm, y1_norm = [normalize(y,self.norm_mean, self.norm_std) for y in [y0,y1]]
        
        relative_target_normed = get_relative_pose(y0_norm, y1_norm)
        
        #self.optimizer.zero_grad()
        
        global_output0,global_output1 = self.model(x0, x1)
        relative_consistence = get_relative_pose(global_output0,global_output1)
        global_loss = translational_rotational_loss(pred=global_output1, \
                                                    gt=y1_norm, \
                                                    lamda=args.lamda_weights)
        geometry_consistent_loss = translational_rotational_loss(pred=relative_consistence, \
                                                                 gt=relative_target_normed, \
                                                                 lamda=args.lamda_weights)
        total_loss = global_loss + geometry_consistent_loss        
        #total_loss.backward()
        #self.optimizer.step()
        
        end = time.time()
        batch_time = end - start
        return batch_time, total_loss
    
    def eval_forward(self,x,y):
        x,y = x.to(self.device),y.to(self.device)
        
        global_output = self.model(x)
        trans_target, rot_target = torch.split(y, [3, 4], dim=1)
        global_output_demormed = denormalize(global_output, self.norm_mean, self.norm_std)
        trans_prediction, rot_prediction = torch.split(global_output_demormed, [3, 4], dim=1)
        return trans_prediction, rot_prediction, trans_target, rot_target

cnn_model = CNN_Model(training=False,device="cuda:1")
cnn_model.load_model('pretrained.pth',display_info=False)
cnn_model.power_all('off')

# Initialize

In [7]:
trans_errors = []
rot_errors = []
uncertainties = []
pose_map = []

total_trans_error = 0.
total_rot_error = 0.

count = 0.

is_save_map = False
is_read_map = False

trans_preds = []
trans_gts = []

rot_preds = []
rot_gts = []

pred_uncertainties = []

pred_time = []

br = tf_ros.TransformBroadcaster()

GT_POSE_TOPIC = '/gt_pose'
BIRDVIEW_TOPIC_PUB = '/bird_view'
MAP_TOPIC_PUB = '/pose_map'
PARTICLES_PUB = '/particles'
NN_LOCALIZASION_PUB = '/nn_pose'
gt_pose_pub = rospy.Publisher(GT_POSE_TOPIC, Odometry, queue_size=1)
bird_view_pub = rospy.Publisher(BIRDVIEW_TOPIC_PUB, Image, queue_size=1)
map_pub = rospy.Publisher(MAP_TOPIC_PUB, Path, queue_size=1)
particles_pub = rospy.Publisher(PARTICLES_PUB, PoseArray, queue_size=1)
nn_pose_pub = rospy.Publisher(NN_LOCALIZASION_PUB, Odometry, queue_size=1)

# Evaluate

In [8]:
cnn_model.model.eval()

for b, data in enumerate(dataloader, 0):
    start = time.time()
    x,y = data.values()
    
    with torch.no_grad():
        trans_pred, rot_pred, trans_gt, rot_gt = cnn_model.eval_forward(x,y)
    trans_pred = trans_pred.cpu().numpy()
    rot_pred = rot_pred.cpu().numpy()
    trans_gt = trans_gt.cpu().numpy()
    rot_gt = rot_gt.cpu().numpy()
    x = x.numpy()
    end = time.time()
    
    if args.enable_ros:
        for i in range(y.shape[0]):
            br.sendTransform(trans_pred[i],rot_pred[i], rospy.Time.now(),"estimation", "world")
            br.sendTransform(trans_gt[i],rot_gt[i],rospy.Time.now(), "gt", "world")

            timestamp = rospy.Time.now()

            gt_msg = Odometry()
            gt_msg.header.frame_id = 'world'
            gt_msg.header.stamp = timestamp
            gt_msg.child_frame_id = 'base_link'
            gt_msg.pose.pose.position.x = trans_gt[i][0]
            gt_msg.pose.pose.position.y = trans_gt[i][1]
            gt_msg.pose.pose.position.z = trans_gt[i][2]
            gt_msg.pose.pose.orientation.x = rot_gt[i][0]
            gt_msg.pose.pose.orientation.y = rot_gt[i][1]
            gt_msg.pose.pose.orientation.z = rot_gt[i][2]
            gt_msg.pose.pose.orientation.w = rot_gt[i][3]

            bridge = CvBridge()

            bird_view_img_msg = bridge.cv2_to_imgmsg(x[i][0], encoding="passthrough") # 1x300x300 -> 300x300
            stamp_now = rospy.Time.now()
            bird_view_img_msg.header.stamp = stamp_now

            bird_view_pub.publish(bird_view_img_msg)

            rospy.sleep(.0)

            count += 1
    else:
        count += y.shape[0]

    trans_preds += [x for x in trans_pred]
    rot_preds += [x for x in rot_pred]
    trans_gts += [x for x in trans_gt]
    rot_gts += [x for x in rot_gt]

    trans_error = np.sqrt(np.sum((trans_pred - trans_gt)**2,axis=1))
    rot_error_1 = np.arccos(np.sum(np.multiply(rot_pred,rot_gt),axis=1))/math.pi*180
    rot_error_2 = np.arccos(np.sum(np.multiply(rot_pred,-rot_gt),axis=1))/math.pi*180
    rot_error = np.minimum(rot_error_1,rot_error_2)

    trans_errors += [x for x in trans_error]
    rot_errors += [x for x in rot_error]

    total_trans_error += np.sum(trans_error)
    total_rot_error += np.sum(rot_error)

    display = 1

    if b % display == 0:
        print(
            "{}/{}, translation error = {:.3f}, rotation error = {:.3f}, time/batch = {:.3f}"
            .format(
             (b+1)*args.batch_size,
            len(dataloader)*args.batch_size,
            total_trans_error / count,
            total_rot_error / count,
            end - start))

#print("pred time", np.mean(np.array(pred_time)))
#print("time std", np.std(np.array(pred_time)))
    

450/92250, translation error = 8.783, rotation error = 4.420, time/batch = 4.818
900/92250, translation error = 9.578, rotation error = 3.906, time/batch = 1.135
1350/92250, translation error = 8.834, rotation error = 3.834, time/batch = 1.127
1800/92250, translation error = 7.584, rotation error = 3.928, time/batch = 1.119
2250/92250, translation error = 6.880, rotation error = 4.094, time/batch = 1.116
2700/92250, translation error = 6.312, rotation error = 4.231, time/batch = 1.114
3150/92250, translation error = 6.081, rotation error = 4.153, time/batch = 1.110
3600/92250, translation error = 5.934, rotation error = 4.187, time/batch = 1.119
4050/92250, translation error = 5.589, rotation error = 3.978, time/batch = 1.117
4500/92250, translation error = 5.311, rotation error = 3.883, time/batch = 1.115
4950/92250, translation error = 5.052, rotation error = 3.821, time/batch = 1.113
5400/92250, translation error = 4.919, rotation error = 3.799, time/batch = 1.121
5850/92250, transl

45000/92250, translation error = 9.659, rotation error = 5.194, time/batch = 1.129
45450/92250, translation error = 9.594, rotation error = 5.174, time/batch = 1.131
45900/92250, translation error = 9.534, rotation error = 5.161, time/batch = 1.130
46350/92250, translation error = 9.490, rotation error = 5.164, time/batch = 1.129
46800/92250, translation error = 9.446, rotation error = 5.148, time/batch = 1.134
47250/92250, translation error = 9.398, rotation error = 5.134, time/batch = 1.136
47700/92250, translation error = 9.380, rotation error = 5.117, time/batch = 1.129
48150/92250, translation error = 9.336, rotation error = 5.117, time/batch = 1.132
48600/92250, translation error = 9.283, rotation error = 5.096, time/batch = 1.127
49050/92250, translation error = 9.235, rotation error = 5.075, time/batch = 1.131
49500/92250, translation error = 9.179, rotation error = 5.058, time/batch = 1.127
49950/92250, translation error = 9.133, rotation error = 5.039, time/batch = 1.128
5040

89100/92250, translation error = 12.642, rotation error = 6.572, time/batch = 1.132
89550/92250, translation error = 12.593, rotation error = 6.565, time/batch = 1.130
90000/92250, translation error = 12.604, rotation error = 6.561, time/batch = 1.130
90450/92250, translation error = 12.564, rotation error = 6.551, time/batch = 1.135
90900/92250, translation error = 12.531, rotation error = 6.553, time/batch = 1.129
91350/92250, translation error = 12.502, rotation error = 6.544, time/batch = 1.127
91800/92250, translation error = 12.458, rotation error = 6.531, time/batch = 1.141
92250/92250, translation error = 12.450, rotation error = 6.535, time/batch = 2.143


In [9]:
import scipy.io as sio

sio.savemat('results.mat', {'trans_pred': np.array(trans_preds), 'trans_gt': np.array(trans_gts), 'uncertainty': np.array(pred_uncertainties)})

if len(pose_map):
    np.savetxt(os.path.join(args.map_dataset, 'map.txt'), np.asarray(pose_map, dtype=np.float32))
    print("map is saved!")

plt.hist(trans_errors, bins='auto')
plt.title("Translation errors")
plt.xlabel("translational error in meters")
plt.ylabel("number of frames")
plt.savefig('terror.png', bbox_inches='tight')

plt.hist(rot_errors, bins='auto')
plt.title("Rotation errors")
plt.xlabel("rotational error in degree")
plt.ylabel("number of frames")
plt.savefig('rerror.png', bbox_inches='tight')

median_trans_errors = np.median(trans_errors)
median_rot_errors = np.median(rot_errors)
mean_trans_errors = np.mean(trans_errors)
mean_rot_errors = np.mean(rot_errors)

print("median translation error = {:.3f}".format(median_trans_errors))
print("median rotation error = {:.3f}".format(median_rot_errors))
print("mean translation error = {:.3f}".format(mean_trans_errors))
print("mean rotation error = {:.3f}".format(mean_rot_errors))   

median translation error = 4.014
median rotation error = 3.254
mean translation error = 12.450
mean rotation error = 6.535


In [10]:
def evaluate(trans_errors,rot_errors):
    t = [14301,7008,12852,9567,13580,14835,7114,12683]
    for i in range(len(t)):
        if i >0:
            t[i] += t[i-1]
    trans_errors_month = list()
    trans_errors_month.append(trans_errors[:t[0]])
    trans_errors_month.append(trans_errors[t[0]:t[1]])
    trans_errors_month.append(trans_errors[t[1]:t[2]])
    trans_errors_month.append(trans_errors[t[2]:t[3]])
    trans_errors_month.append(trans_errors[t[3]:t[4]])
    trans_errors_month.append(trans_errors[t[4]:t[5]])
    trans_errors_month.append(trans_errors[t[5]:t[6]])
    trans_errors_month.append(trans_errors[t[6]:])

    rot_errors_month = list()
    rot_errors_month.append(rot_errors[:t[0]])
    rot_errors_month.append(rot_errors[t[0]:t[1]])
    rot_errors_month.append(rot_errors[t[1]:t[2]])
    rot_errors_month.append(rot_errors[t[2]:t[3]])
    rot_errors_month.append(rot_errors[t[3]:t[4]])
    rot_errors_month.append(rot_errors[t[4]:t[5]])
    rot_errors_month.append(rot_errors[t[5]:t[6]])
    rot_errors_month.append(rot_errors[t[6]:])
    
    print('================== median translation error ==================')
    for trans_errors_i in trans_errors_month:
        print("median translation error = {:.3f}".format(np.median(trans_errors_i)))
        
    print('================== median rotation error ==================')
    for rot_errors_i in rot_errors_month:
        print("median rotation error = {:.3f}".format(np.median(rot_errors_i)))
    
    print('================== mean translation error ==================')
    for trans_errors_i in trans_errors_month:
        print("mean translation error = {:.3f}".format(np.mean(trans_errors_i)))
        
    print('================== mean rotation error ==================')  
    for rot_errors_i in rot_errors_month:
        print("mean rotation error = {:.3f}".format(np.mean(rot_errors_i)))
        
evaluate(trans_errors,rot_errors)

median translation error = 3.420
median translation error = 3.159
median translation error = 3.957
median translation error = 3.659
median translation error = 4.189
median translation error = 4.038
median translation error = 5.335
median translation error = 5.517
median rotation error = 3.024
median rotation error = 2.813
median rotation error = 3.035
median rotation error = 2.794
median rotation error = 3.319
median rotation error = 3.360
median rotation error = 3.776
median rotation error = 4.268
mean translation error = 6.054
mean translation error = 5.826
mean translation error = 14.223
mean translation error = 12.198
mean translation error = 11.937
mean translation error = 14.344
mean translation error = 20.826
mean translation error = 15.355
mean rotation error = 4.429
mean rotation error = 4.218
mean rotation error = 6.175
mean rotation error = 5.712
mean rotation error = 6.322
mean rotation error = 6.256
mean rotation error = 8.596
mean rotation error = 10.572
