# Testing odometry

In [None]:
import glob
import time

from tqdm.notebook import trange
import torch
import numpy as np
import matplotlib as mpl
mpl.rcParams['figure.dpi'] = 100
import matplotlib.pyplot as plt

from slam_framework.neural_slam import NeuralSLAM
from utils.helpers import log
from utils.arguments import Arguments
from localization.datasets import ColorDataset


# TODO: overwrite to the actual trained odometry checkpoint.
weights_file = "checkpoints/10_1atdnvo_c.pth"

# Run configuration; this notebook reads `data_path` and `keyframes_path` from it.
args = Arguments.get_arguments()
sequence_length = 1  # NOTE(review): not referenced by any visible cell

# Colour frames of sequence "00" that will be fed to the odometry loop.
dataset = ColorDataset(data_path=args.data_path, sequence="00")

# SLAM system whose odometry network is initialised from `weights_file`.
slam = NeuralSLAM(args, odometry_weights=weights_file)

In [None]:
# Start the odometry stage via `slam.start_odometry()` and print the mode the
# system reports, to confirm the switch before frames are fed in the next cell.
slam.start_odometry()
print("SLAM mode: ", slam.mode())

In [None]:
# Feed every frame of `dataset` through the SLAM system, recording the pose it
# returns and the latency (seconds) of each call.
# NOTE: despite its name, `global_scale` collects the per-frame poses returned
# by `slam(...)`; later cells read `global_scale[:, :3, -1]` as positions.
global_scale = []
slam_call_time = []

for i in trange(len(dataset)):
    img = dataset[i]

    # perf_counter is monotonic and high-resolution; time.time() follows the
    # wall clock, which can be coarse and can jump with clock adjustments.
    start = time.perf_counter()
    current_pose = slam(img.squeeze())
    # CUDA kernels run asynchronously: without synchronising, the measured time
    # covers only the kernel launches, inflating the FPS computed later.
    if current_pose.is_cuda:
        torch.cuda.synchronize(current_pose.device)
    end = time.perf_counter()

    slam_call_time.append(end - start)
    global_scale.append(current_pose)


# Stack the poses into one tensor (frame index first) and the latencies into
# a numpy array for the statistics cell below.
global_scale = torch.stack(global_scale, dim=0)
slam_call_time = np.array(slam_call_time)

## FPS calculation

In [None]:
# Latency statistics (seconds per SLAM call) and the throughput they imply.
mean_call_time = slam_call_time.mean()
log("Average odometry time: ", mean_call_time)
log("Odometry time std: ", slam_call_time.std())

# Frames per second from the mean latency; computed once and logged as-is.
fps_manual = 1 / mean_call_time
log("FPS from time: ", fps_manual)

## Keyframe check

In [None]:
# Finish the odometry stage.
# NOTE(review): the next cell loads `args.keyframes_path + "/poses.pth"`, so
# presumably this call persists the odometry state — confirm in
# NeuralSLAM.end_odometry.
slam.end_odometry()

### Check whether the odometry state saved poses correctly

In [None]:
import matplotlib.pyplot as plt

# Load the poses written by the odometry run and scatter them in the X-Z plane.
DATA_PATH = args.keyframes_path + "/poses.pth"
# map_location="cpu": the poses may have been saved from GPU tensors, and
# `.numpy()` below only works on CPU tensors (it also lets the file load on a
# machine without CUDA).
poses = torch.load(DATA_PATH, map_location="cpu")
print(poses.shape)

# NOTE(review): column 3 and the last column are used as X and Z, which matches
# flattened 3x4 [R|t] rows (tx at index 3, tz at index 11) — confirm against
# the code that writes poses.pth.
X = poses[:, 3]
Z = poses[:, -1]
plt.scatter(X.numpy(), Z.numpy())
plt.show()

### Check keyframe recovery

In [None]:
# Translation column of every keyframe pose currently held by `slam`.
keyframe_positions = torch.stack(
    [slam[idx].pose[:3, 3] for idx in range(len(slam))],
    dim=0,
).to("cpu")

# Positions along the full odometry trajectory, one row per processed frame.
global_pos = global_scale[:, :3, -1].to('cpu')
X, Y, Z = (global_pos[:, axis] for axis in range(3))
X_key, Y_key, Z_key = (keyframe_positions[:, axis] for axis in range(3))

# Trajectory in the X-Z plane with the keyframe locations overlaid.
plt.plot(X, Z)
plt.scatter(X_key, Z_key)
plt.show()

# Each coordinate of the trajectory against frame index, one figure per axis.
for coord in (X, Y, Z):
    plt.plot(coord)
    plt.show()

In [None]:
# Build a fresh SLAM system in "mapping" mode and plot the keyframes it holds.
# NOTE(review): presumably these are the keyframes persisted by the odometry
# run above — confirm how NeuralSLAM populates them in this mode.
slam = NeuralSLAM(args, odometry_weights=weights_file, start_mode="mapping")

keyframe_positions = torch.stack(
    [slam[k].pose[:3, 3] for k in range(len(slam))],
    dim=0,
).to("cpu")
X_key, Y_key, Z_key = (keyframe_positions[:, axis] for axis in range(3))

# Overlay on the trajectory (X, Z) computed in the keyframe-recovery cell.
plt.plot(X, Z)
plt.scatter(X_key, Z_key)
plt.show()

------------------------------------------------
# Test mapping: relocalize a single frame against the stored keyframes

In [None]:
import time

import torch
import torchvision.transforms.functional as TF
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['figure.dpi'] = 100

from utils.helpers import log
from utils.transforms import matrix2euler
from utils.arguments import Arguments
from odometry.datasets import KittiOdometryDataset
from slam_framework.neural_slam import NeuralSLAM


# TODO: overwrite to the actual trained odometry checkpoint.
weights_file = "checkpoints/10_1atdnvo_c.pth"

# Run configuration; `data_path` locates the KITTI odometry data.
args = Arguments.get_arguments()
dataset = KittiOdometryDataset(data_path=args.data_path, sequence="00")

# SLAM system started directly in "relocalization" mode.
slam = NeuralSLAM(args, odometry_weights=weights_file, start_mode="relocalization")

with torch.no_grad():
    # One frame of the sequence together with its ground-truth orientation and
    # position, used to score the relocalization estimates below.
    rgb, true_orientation, true_position = dataset[195]
    # NOTE(review): presumably the input resolution the network expects — confirm.
    rgb = TF.resize(rgb, (376, 1232))
    initial_pose, refined_pose, distances = slam(rgb)
    log("Distances shape", distances.shape)

    distances_np = distances.cpu().numpy()

    # Histogram
    plt.hist(distances_np, bins=1000)
    plt.xlabel("Distance from sample")
    plt.ylabel("Count of elements")
    plt.show()

    plt.plot(distances_np)
    plt.xlabel("Index of keyframe")
    plt.ylabel("Embedding distance from sample")
    plt.show()

    # Predicted index: the two keyframes with the smallest distance.
    # topk returns distinct indices and, unlike overwriting the minimum with the
    # maximum, leaves `distances` intact for later inspection. With a single
    # keyframe both predictions fall back to that keyframe.
    k = min(2, distances.numel())
    nearest = torch.topk(distances.flatten(), k=k, largest=False).indices
    pred_index = nearest[0]
    second_pred_index = nearest[-1]
    print(pred_index)
    print(second_pred_index)

# Load and display the RGB images stored for the two nearest keyframes.
# map_location="cpu": the images may have been saved from GPU tensors, which
# matplotlib cannot display. permute(1, 2, 0) moves the first axis last.
# NOTE(review): this assumes the stored image is channel-first — confirm.
pred_im = torch.load(slam[int(pred_index.squeeze())].rgb_file_name, map_location="cpu").permute(1, 2, 0)
second_pred_im = torch.load(slam[int(second_pred_index.squeeze())].rgb_file_name, map_location="cpu").permute(1, 2, 0)
plt.imshow(pred_im); plt.show()
plt.imshow(second_pred_im); plt.show()

def to_vectors(mat):
    """Split a pose matrix into ``(orientation, position)``.

    The top-left 3x3 block is converted to Euler angles with ``matrix2euler``.
    The position is the last column of the first three rows.
    """
    rotation = mat[:3, :3]
    translation = mat[:3, -1]
    return matrix2euler(rotation), translation

# Express both estimates as (euler orientation, position) on the CPU.
initial_orientation, initial_position = to_vectors(initial_pose.to("cpu"))
refined_orientation, refined_position = to_vectors(refined_pose.to("cpu"))

log("True pose: ", [true_orientation, true_position])
log("Initial estimate: ", [initial_orientation, initial_position])
log("Refined estimate: ", [refined_orientation, refined_position])


def _l1_error(expected, estimated):
    """Sum of absolute element-wise differences between two vectors."""
    return (expected - estimated).abs().sum()


# L1 error of orientation and position for each estimate vs. ground truth.
for label, orientation, position in (
    ("Initial", initial_orientation, initial_position),
    ("Refined", refined_orientation, refined_position),
):
    log(f"{label} difference: ", [_l1_error(true_orientation, orientation), _l1_error(true_position, position)])