In [None]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('..')

from shared import metrics
from shared import common

import cv2
import os
import numpy as np
import quaternion as quat
from scipy.spatial.transform import Rotation

import albumentations as albu

from matplotlib import pyplot as plt

from shared.transforms import ResizeKeepingRatio

import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader, ConcatDataset

import neptune
neptune.init('kail4ek/visual-odometry')
# https://ui.neptune.ai/kail4ek/visual-odometry/experiments

In [None]:
# Constant gray used to fill the border area added by PadIfNeeded.
fill_value = (127, 127, 127)


def get_train_transform(target_wh=(1024, 320)):
    """Build the augmentation pipeline used for training image pairs.

    ``image2`` is registered through ``additional_targets`` as an extra
    image target, so albumentations applies the same transform to it as to
    ``image``.

    Args:
        target_wh: Output size as ``(width, height)``.

    Returns:
        An ``albu.Compose`` to be called as ``t(image=..., image2=...)``.
    """
    noise = albu.OneOf(
        [albu.IAAAdditiveGaussianNoise(), albu.GaussNoise()],
        p=0.2,
    )
    blur = albu.OneOf(
        [
            albu.MotionBlur(p=0.2),
            albu.MedianBlur(blur_limit=3, p=0.1),
            albu.Blur(blur_limit=3, p=0.1),
        ],
        p=0.2,
    )
    # Shift/scale jitter only: rotate_limit=0 means no rotation is sampled.
    shift_scale = albu.ShiftScaleRotate(
        shift_limit=0.05, scale_limit=0.2, rotate_limit=0, p=0.4
    )
    contrast = albu.OneOf(
        [
            albu.CLAHE(clip_limit=2),
            albu.IAASharpen(),
            albu.IAAEmboss(),
            albu.RandomBrightnessContrast(),
        ],
        p=0.3,
    )
    color = albu.HueSaturationValue(p=0.3)
    # Always bring the frame to target_wh: project resize helper (presumably
    # aspect-preserving, per its name) followed by constant-gray padding.
    resize = ResizeKeepingRatio(
        target_wh=target_wh, interpolation=cv2.INTER_CUBIC, always_apply=True
    )
    pad = albu.PadIfNeeded(
        min_height=target_wh[1],
        min_width=target_wh[0],
        border_mode=cv2.BORDER_CONSTANT,
        value=fill_value,
        always_apply=True,
    )
    return albu.Compose(
        [noise, blur, shift_scale, contrast, color, resize, pad],
        additional_targets={"image2": "image"},
    )

def get_valid_transform(target_wh=(1024, 320)):
    """Build the deterministic validation pipeline for image pairs.

    No augmentation is applied: each frame is only resized with
    ``ResizeKeepingRatio`` and then padded with ``fill_value`` up to
    ``target_wh``. ``image2`` gets the same treatment as ``image``.

    Args:
        target_wh: Output size as ``(width, height)``.

    Returns:
        An ``albu.Compose`` to be called as ``t(image=..., image2=...)``.
    """
    width, height = target_wh[0], target_wh[1]
    steps = [
        ResizeKeepingRatio(
            target_wh=target_wh, interpolation=cv2.INTER_CUBIC, always_apply=True
        ),
        albu.PadIfNeeded(
            min_height=height,
            min_width=width,
            border_mode=cv2.BORDER_CONSTANT,
            value=fill_value,
            always_apply=True,
        ),
    ]
    return albu.Compose(steps, additional_targets={"image2": "image"})

In [None]:
class KITTIDataset(Dataset):
    """Consecutive frame pairs from one KITTI odometry sequence.

    Item ``idx`` is the pair of frames ``(idx, idx + 1)`` read from
    ``image_2``. It comes with the ground-truth motion from frame ``idx`` to
    frame ``idx + 1``, expressed in the coordinates of frame ``idx``:
    a quaternion (``float32[4]``, layout of ``quat.as_float_array``) and a
    translation (``float32[3]``).

    Files read under ``dataset_dir``::

        poses/<sequence_id>.txt               12 numbers per pose (3x4)
        sequences/<sequence_id>/times.txt
        sequences/<sequence_id>/image_2/NNNNNN.png
    """

    def __init__(self, dataset_dir, transform, sequence_id='00', debug=False):
        """Load poses and timestamps, and check that every pose has a frame.

        Args:
            dataset_dir: Root directory of the KITTI odometry dataset.
            transform: Callable taking ``image=`` and ``image2=`` arrays and
                returning a dict with the same keys (e.g. ``albu.Compose``).
            sequence_id: Sequence name, e.g. ``'00'``.
            debug: If True, ``__getitem__`` returns the transformed images
                as-is instead of normalized CHW tensors.

        Raises:
            FileNotFoundError: If the frame for some pose index is missing.
        """
        self.debug = debug

        sequence_dir = os.path.join(dataset_dir, 'sequences', sequence_id)
        poses_dir = os.path.join(dataset_dir, 'poses')

        pose_path = os.path.join(poses_dir, f'{sequence_id}.txt')
        times_path = os.path.join(sequence_dir, 'times.txt')

        self.poses = self._load_poses(pose_path)
        self.times = self._load_times(times_path)

        self.IMAGES_DIR = os.path.join(sequence_dir, 'image_2')
        self.images = [fname for fname in os.listdir(self.IMAGES_DIR) if fname.endswith('.png')]

        print(f'Sequence {sequence_id} length: {len(self.poses)}')

        # Sanity check: every pose index must have a frame on disk, so that
        # later file paths can be derived from the index alone.
        available = set(self.images)
        for i in range(len(self.poses)):
            fname = self._get_image_fname(i)
            if fname not in available:
                raise FileNotFoundError(
                    f'File with name {fname} not exists in {self.IMAGES_DIR}')

        self.transform = transform

    def _get_image_fname(self, idx):
        """File name of frame ``idx``: zero-padded to 6 digits, ``.png``."""
        return f'{idx:06}.png'

    def _load_times(self, fpath):
        """Read the newline-separated timestamps of the sequence."""
        times_data = np.fromfile(fpath, sep='\n')
        return times_data

    def _load_poses(self, fpath):
        """Read flattened 3x4 ``[R | t]`` poses and return them as ``(N, 4, 4)``.

        Every 12 consecutive numbers are reshaped (row-major) into a 3x4
        matrix, which is completed with the homogeneous row ``[0, 0, 0, 1]``.
        """
        poses_data = np.fromfile(fpath, sep=' ')
        poses_data = poses_data.reshape((-1, 3, 4))
        last_row = np.array([[[0, 0, 0, 1]]])
        last_rows = np.repeat(last_row, axis=0, repeats=poses_data.shape[0])
        return np.hstack((poses_data, last_rows))

    def __len__(self):
        # N poses give N - 1 consecutive pairs; never report a negative length.
        return max(len(self.poses) - 1, 0)

    def _get_transform(self, idx):
        """Ground-truth relative motion from frame ``idx`` to ``idx + 1``.

        Returns:
            ``(quaternion, translation)`` as float32 arrays. The translation
            is the origin of frame ``idx + 1`` in the coordinates of frame
            ``idx``. The quaternion is ``q_idx^-1 * q_(idx+1)``.
        """
        c_pose = self.poses[idx]
        n_pose = self.poses[idx + 1]

        # Homogeneous position of the next camera mapped into the current frame.
        local_dtrans = np.linalg.inv(c_pose) @ n_pose[:, 3]

        quat_c = quat.from_rotation_matrix(c_pose[:3, :3])
        quat_n = quat.from_rotation_matrix(n_pose[:3, :3])
        quat_t = quat_c.inverse() * quat_n

        gt_quat_t_ar = quat.as_float_array(quat_t).astype(np.float32)
        gt_trans = local_dtrans[:3].astype(np.float32)
        return gt_quat_t_ar, gt_trans

    def _read_rgb(self, idx):
        """Read frame ``idx`` from ``IMAGES_DIR`` and convert BGR -> RGB.

        Raises:
            FileNotFoundError: If OpenCV cannot read the file.
        """
        fpath = os.path.join(self.IMAGES_DIR, self._get_image_fname(idx))
        img = cv2.imread(fpath)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError(f'Could not read image {fpath}')
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    @staticmethod
    def _image2tensor(img):
        """Convert an HWC image to a CHW float32 tensor mapped from [0, 255] to [-1, 1].

        Assumes pixel values in [0, 255] (8-bit input).
        """
        img = img.transpose(2, 0, 1)
        img = img / 255 * 2 - 1
        return torch.from_numpy(img.astype(np.float32))

    def __getitem__(self, idx):
        """Return ``(img_idx, img_idx+1, gt_quat, gt_trans)``.

        The images are CHW tensors in [-1, 1], or the raw transformed arrays
        when ``debug`` is set. The ground-truth values are float32 tensors.
        """
        c_img = self._read_rgb(idx)
        n_img = self._read_rgb(idx + 1)

        transform_result = self.transform(image=c_img, image2=n_img)
        c_img = transform_result['image']
        n_img = transform_result['image2']

        gt_quat_t_ar, gt_trans = self._get_transform(idx)
        gt_quat_t_ar = torch.from_numpy(gt_quat_t_ar)
        gt_trans = torch.from_numpy(gt_trans)

        if self.debug:
            return c_img, n_img, gt_quat_t_ar, gt_trans
        return (
            self._image2tensor(c_img),
            self._image2tensor(n_img),
            gt_quat_t_ar,
            gt_trans,
        )
    
    
def predict_2_next_pose(c_pose, pred_quat, pred_trans):
    """Compose a relative motion onto ``c_pose`` to get the next 4x4 pose.

    This inverts ``KITTIDataset._get_transform``:

    - next rotation = ``q_c * q_t``;
    - next position = ``c_pose @ [t, 1]``, because ``t`` is expressed in the
      current frame.

    Args:
        c_pose: Current 4x4 homogeneous pose.
        pred_quat: 4 quaternion components in ``quat.from_float_array``
            layout. Lists, numpy arrays and CPU tensors are accepted.
        pred_trans: 3 translation components in the current frame. Lists,
            numpy arrays and CPU tensors are accepted.

    Returns:
        The next pose as a float64 ``(4, 4)`` array.
    """
    # Normalize the inputs to plain float64 numpy vectors, so torch tensors
    # (as produced by the model) do not go through implicit element-wise
    # conversion.
    pred_quat = np.asarray(pred_quat, dtype=np.float64).reshape(4)
    pred_trans = np.asarray(pred_trans, dtype=np.float64).reshape(3)

    quat_t = quat.from_float_array(pred_quat)
    quat_c = quat.from_rotation_matrix(c_pose[:3, :3])

    n_pose = np.eye(4)
    n_pose[:3, :3] = quat.as_rotation_matrix(quat_c * quat_t)
    n_pose[:, 3] = c_pose @ np.append(pred_trans, 1.0)
    return n_pose
        

In [None]:
# Network input size as (width, height); shared by both transform pipelines.
model_input_wh = (1024, 320)

DATASET_DIR = os.path.join('..', 'data', 'KITTI', 'dataset')
train_transform = get_train_transform(target_wh=model_input_wh)
valid_transform = get_valid_transform(target_wh=model_input_wh)
# debug=True: items return the transformed images (plottable) instead of tensors.
dataset = KITTIDataset(DATASET_DIR, train_transform, debug=True)

In [None]:
# One frame pair plus its ground-truth relative rotation and translation.
(c_img, n_img, quaternion, trns) = dataset[101]

In [None]:
# Visual check of the sample pair: the current frame, then the next frame,
# followed by the ground-truth relative motion between them.
for panel, frame in enumerate((c_img, n_img)):
    plt.figure(figsize=(15, 9))
    plt.imshow(frame)
    if panel == 0:
        print(c_img.shape)

print('Trans', trns)
print('Quat', quaternion)

In [None]:
class PosePrediction(nn.Module):
    """Regress the relative pose between two frames.

    Both frames go through one shared backbone. The two feature maps are
    concatenated along channels, then reduced by a conv head and a small MLP
    to a unit quaternion (4 values) and a translation (3 values).
    """

    def __init__(self):
        super().__init__()

        resnet = torch.hub.load('pytorch/vision:v0.5.0', 'resnet18', pretrained=True)
        # Keep everything except the last two children. For torchvision
        # ResNets these are the global avgpool and the classifier, so the
        # backbone outputs a spatial feature map.
        self.backbone = nn.Sequential(*list(resnet.children())[:-2])

        # 1024 input channels = two concatenated backbone maps
        # (512 channels each for ResNet-18).
        self.predict = nn.Sequential(
            nn.Conv2d(1024, 1024, 3),
            nn.BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
            nn.ReLU(inplace=True),

            nn.Conv2d(1024, 1024, 3),
            nn.BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
            nn.ReLU(inplace=True),

            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
        )

        # Without a nonlinearity, Linear -> Dropout -> Linear collapses into a
        # single affine map. The ReLUs live in the parameter-free slots 1 and
        # 3 (together with Dropout), so the Linear layers keep indices 0 and 2
        # and existing state_dict keys ('fc.0.*', 'fc.2.*') still load.
        self.fc = nn.Sequential(
            nn.Linear(1024, 120),
            nn.Sequential(nn.ReLU(inplace=True), nn.Dropout(p=0.5)),
            nn.Linear(120, 84),
            nn.Sequential(nn.ReLU(inplace=True), nn.Dropout(p=0.5)),
        )

        self.fc_quat = nn.Linear(84, 4)
        self.fc_trns = nn.Linear(84, 3)

    def forward(self, img1, img2):
        """Predict the motion between ``img1`` and ``img2``.

        Args:
            img1, img2: Image batches of identical shape, assumed
                ``(B, 3, H, W)``.

        Returns:
            ``(pred_quat, pred_trns)`` of shapes ``(B, 4)`` and ``(B, 3)``.
            Each row of ``pred_quat`` has unit L2 norm.
        """
        feat_img1 = self.backbone(img1)
        feat_img2 = self.backbone(img2)

        # Shared weights for both frames; fuse them along the channel axis.
        x = torch.cat((feat_img1, feat_img2), dim=1)
        x = self.predict(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        # L2-normalize over dim 1 so the output is a valid rotation quaternion.
        pred_quat = F.normalize(self.fc_quat(x))
        pred_trns = self.fc_trns(x)

        return pred_quat, pred_trns

In [None]:
# Smoke test: push a random batch of two 224x224 RGB pairs through the model
# and check the output shapes (4 quaternion values, 3 translation values).
model = PosePrediction()

sample_input1 = torch.rand((2, 3, 224, 224))
sample_input2 = torch.rand((2, 3, 224, 224))

# Only the forward pass is exercised; no autograd graph is needed.
with torch.no_grad():
    pred_quat, pred_trns = model(sample_input1, sample_input2)

assert pred_quat.shape == (2, 4), pred_quat.shape
assert pred_trns.shape == (2, 3), pred_trns.shape

# Test trajectory collection

In [None]:
# Pose-convention sanity check: re-integrate the ground-truth relative
# motions with predict_2_next_pose and compare against the stored absolute
# poses. The ATE should be ~0 if _get_transform and predict_2_next_pose are
# consistent (assuming the sequence starts at the identity pose).
dataset = KITTIDataset(DATASET_DIR, train_transform, sequence_id="00", debug=False)

gt_poses = [np.eye(4)]
for step in range(len(dataset)):
    rel_quat, rel_trns = dataset._get_transform(step)
    gt_poses.append(predict_2_next_pose(gt_poses[-1], rel_quat, rel_trns))

metrics_ate = metrics.compute_ATE(gt_poses, dataset.poses)
print(f'Metrics ATE: {metrics_ate}')

# Plot the re-integrated trajectory, then the stored ground truth.
for trajectory in (gt_poses, dataset.poses):
    plt.figure(figsize=[9, 9])
    common.plot_trajectory(trajectory)


## Training code

In [None]:
# Sequence split as in https://www.cs.ox.ac.uk/files/9026/DeepVO.pdf
train_sequences = ['00', '02', '08', '09']
valid_sequences = ['01', '03', '04', '05', '06', '07', '10']

train_transform = get_train_transform(target_wh=model_input_wh)
valid_transform = get_valid_transform(target_wh=model_input_wh)

train_datasets = [
    KITTIDataset(DATASET_DIR, train_transform, sequence_id=seq, debug=False)
    for seq in train_sequences
]
train_dataset = ConcatDataset(train_datasets)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=8)

valid_datasets = [
    KITTIDataset(DATASET_DIR, valid_transform, sequence_id=seq, debug=False)
    for seq in valid_sequences
]
valid_dataset = ConcatDataset(valid_datasets)
# Evaluation needs no shuffling; a fixed order keeps the per-batch averaging
# (including the final partial batch) identical across epochs.
valid_loader = DataLoader(valid_dataset, batch_size=8, shuffle=False, num_workers=8)

params = {
    'input_wh': model_input_wh,
    'initial_lr': 1e-3,
    'epochs': 500
}

criterion_quat = nn.MSELoss()
criterion_trns = nn.MSELoss()

optimizer = optim.Adam(model.parameters(), lr=params['initial_lr'])

# 'verbose' is omitted: it defaults to no logging and is deprecated in
# recent PyTorch releases.
scheduler_cfg = {
    'mode': 'min',
    'factor': 0.5,
    'patience': 30,
    'threshold': 0.0001,
    'threshold_mode': 'rel',
    'cooldown': 0,
    'min_lr': 1e-6,
    'eps': 1e-08
}

scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, **scheduler_cfg)

In [None]:
def get_sequence_metrics(dataset, model):
    """Chain the model's relative-pose predictions over a whole sequence and
    score the resulting trajectory against the ground truth.

    Args:
        dataset: A ``KITTIDataset`` with ``debug=False``. It is iterated in
            order, so consecutive predictions form one trajectory starting
            at the identity pose.
        model: Network returning ``(pred_quat, pred_trns)`` for an image pair.

    Returns:
        ``[ate, rpe_trns, rpe_rot]`` from ``metrics.compute_ATE`` and
        ``metrics.compute_RPE``.
    """
    # Run on whatever device the model lives on; no dependency on a global.
    device = next(model.parameters()).device
    poses = [np.eye(4)]

    dataloader = DataLoader(dataset, batch_size=8, shuffle=False, num_workers=8)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            # Ground-truth tensors are not needed for inference.
            for c_img, n_img, _gt_quat, _gt_trns in dataloader:
                pred_quat, pred_trns = model(c_img.to(device), n_img.to(device))
                pred_quat = pred_quat.cpu().numpy()
                pred_trns = pred_trns.cpu().numpy()

                for q, t in zip(pred_quat, pred_trns):
                    poses.append(predict_2_next_pose(poses[-1], q, t))
    finally:
        # Restore the caller's train/eval mode.
        model.train(was_training)

    metrics_ate = metrics.compute_ATE(poses, dataset.poses)
    rpe_trns, rpe_rot = metrics.compute_RPE(poses, dataset.poses)

    return [metrics_ate, rpe_trns, rpe_rot]

In [None]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs without CUDA.
target_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(target_device)


def move_2_device(data):
    """Return a list with every tensor of ``data`` moved to ``target_device``."""
    return [el.to(target_device) for el in data]


modes = ['train', 'valid']

# Run the sequence metrics once before training to catch errors early.
dataset = KITTIDataset(DATASET_DIR, valid_transform, sequence_id="01", debug=False)
seq_metrics = get_sequence_metrics(dataset, model)

assert len(seq_metrics) == 3

In [None]:
from datetime import datetime

tags = []

neptune.create_experiment(
    upload_stdout=False,
    params=params,
    tags=tags
)

best_val_loss = float('inf')

date = datetime.today().strftime('%Y%m%d-%H%M%S')
chk_dir = f'../data/chks/{date}'
os.makedirs(chk_dir, exist_ok=True)

try:
    for epoch in range(params['epochs']):
        for mode in modes:
            is_train = mode == 'train'
            dataloader = train_loader if is_train else valid_loader
            model.train(is_train)  # train(False) == eval()

            loss_vals = []

            # Build autograd graphs only while training; the validation pass
            # runs without them to save memory and time.
            with torch.set_grad_enabled(is_train):
                for data in dataloader:
                    c_img, n_img, gt_quat, gt_trns = move_2_device(data)
                    pred_quat, pred_trns = model(c_img, n_img)

                    loss_quat = criterion_quat(pred_quat, gt_quat)
                    loss_trns = criterion_trns(pred_trns, gt_trns)
                    loss = loss_quat + loss_trns

                    if is_train:
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()

                    loss_vals.append(loss.item())

            epoch_loss = np.mean(loss_vals)
            print(f'{mode} loss: {epoch_loss}')
            neptune.send_metric(f'{mode}_loss', epoch_loss)

            if is_train:
                continue

            scheduler.step(epoch_loss)

            if epoch_loss <= best_val_loss:
                best_val_loss = epoch_loss
                save_dict = {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }
                chkpnt_fpath = os.path.join(chk_dir, f'val_loss_{epoch}.pthck')
                torch.save(save_dict, chkpnt_fpath)

            print('Perform sequence metrics collection')
            val_metrics = []
            for seq in valid_sequences:
                dataset = KITTIDataset(DATASET_DIR, valid_transform, sequence_id=seq, debug=False)
                val_metrics.append(get_sequence_metrics(dataset, model))

            # Average [ATE, RPE_T, RPE_R] over all validation sequences.
            val_mean_metrics = np.mean(val_metrics, axis=0)

            neptune.send_metric('ATE', val_mean_metrics[0])
            neptune.send_metric('RPE_T', val_mean_metrics[1])
            neptune.send_metric('RPE_R', val_mean_metrics[2])
finally:
    # Close the Neptune experiment even if training is interrupted or fails.
    neptune.stop()

print('Finished Training')