In [None]:
%load_ext autoreload
%autoreload 2

import cv2
import os
import numpy as np
import quaternion as quat
from scipy.spatial.transform import Rotation

import albumentations as albu

from matplotlib import pyplot as plt

import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader

In [None]:
def get_transform():
    """Build the albumentations augmentation pipeline used for frame pairs.

    ``image2`` is registered as an additional ``image`` target, so albumentations
    applies the same sampled augmentation to both frames of a pair. The whole
    pipeline fires with probability 0.5.

    Flip / rotate-90 / transpose and distortion transforms are deliberately absent
    (they were commented out in an earlier revision), presumably because they would
    be inconsistent with the pose targets.

    NOTE(review): ShiftScaleRotate still changes image geometry (shift/scale)
    without adjusting the pose labels; confirm this is acceptable.
    NOTE(review): the IAA* transforms are deprecated/removed in newer albumentations
    releases; pin the library version or swap in non-IAA equivalents.
    """
    noise = albu.OneOf(
        [albu.IAAAdditiveGaussianNoise(), albu.GaussNoise()],
        p=0.2,
    )
    blur = albu.OneOf(
        [
            albu.MotionBlur(p=0.2),
            albu.MedianBlur(blur_limit=3, p=0.1),
            albu.Blur(blur_limit=3, p=0.1),
        ],
        p=0.2,
    )
    # rotate_limit=0: only shift and scale, no in-plane rotation.
    shift_scale = albu.ShiftScaleRotate(
        shift_limit=0.0625, scale_limit=0.2, rotate_limit=0, p=0.2
    )
    enhance = albu.OneOf(
        [
            albu.CLAHE(clip_limit=2),
            albu.IAASharpen(),
            albu.IAAEmboss(),
            albu.RandomBrightnessContrast(),
        ],
        p=0.3,
    )
    color = albu.HueSaturationValue(p=0.3)

    stages = [noise, blur, shift_scale, enhance, color]
    return albu.Compose(
        stages,
        p=0.5,
        additional_targets={"image2": "image"},
    )

In [None]:
class KITTIDataset(Dataset):
    """Consecutive-frame pairs from one KITTI odometry sequence with relative-pose targets.

    Item ``idx`` pairs frame ``idx`` (current, "c") with frame ``idx + 1`` (next, "n").
    The targets are:
    - the rotation of the next frame relative to the current one, as a quaternion;
    - the next frame's position expressed in the current frame's coordinates.
    """

    def __init__(self, dataset_dir, transform, debug=False, sequence_idx='00'):
        """
        Args:
            dataset_dir: root of the KITTI odometry dataset. Expects
                ``sequences/<sequence_idx>/image_2/*.png``,
                ``sequences/<sequence_idx>/times.txt`` and ``poses/<sequence_idx>.txt``.
            transform: callable invoked as ``transform(image=..., image2=...)`` that
                returns a dict with the same keys (e.g. an albumentations ``Compose``
                with ``additional_targets={"image2": "image"}``). ``None`` disables it.
            debug: if True, items carry the (augmented) RGB numpy images instead of
                normalized tensors.
            sequence_idx: KITTI sequence to load, as its two-digit string.

        Raises:
            FileNotFoundError: if the image for any pose is missing.
        """
        self.debug = debug
        self.transform = transform

        sequence_dir = os.path.join(dataset_dir, 'sequences', sequence_idx)
        pose_path = os.path.join(dataset_dir, 'poses', f'{sequence_idx}.txt')
        times_path = os.path.join(sequence_dir, 'times.txt')

        self.poses = self._load_poses(pose_path)
        self.times = self._load_times(times_path)

        self.IMAGES_DIR = os.path.join(sequence_dir, 'image_2')
        self.images = [fname for fname in os.listdir(self.IMAGES_DIR) if fname.endswith('.png')]

        # Sanity check: every pose needs its image, so __getitem__ can build file
        # paths from the index alone. A set keeps the membership test O(1).
        available = set(self.images)
        missing = [
            fname
            for fname in (self._get_image_fname(i) for i in range(len(self.poses)))
            if fname not in available
        ]
        if missing:
            raise FileNotFoundError(
                f'{len(missing)} image file(s) do not exist in {self.IMAGES_DIR}, '
                f'e.g. {missing[0]}'
            )

    def _get_image_fname(self, idx):
        """Return the KITTI image file name for frame ``idx`` (zero-padded to 6 digits)."""
        return f'{idx:06}.png'

    def _load_times(self, fpath):
        """Load per-frame timestamps (one float per line) as a 1-D array."""
        return np.loadtxt(fpath, ndmin=1)

    def _load_poses(self, fpath):
        """Load poses as (N, 4, 4) homogeneous matrices.

        Each line of ``fpath`` is one row-major flattened 3x4 ``[R|t]`` matrix
        (12 values).

        Raises:
            ValueError: if a line does not contain exactly 12 values.
        """
        rows = np.loadtxt(fpath, ndmin=2)
        if rows.size and rows.shape[1] != 12:
            raise ValueError(f'Expected 12 values per pose line in {fpath}, got {rows.shape[1]}')
        poses_3x4 = rows.reshape((-1, 3, 4))
        # Append the homogeneous bottom row [0, 0, 0, 1] to every pose.
        bottom = np.broadcast_to(np.array([0.0, 0.0, 0.0, 1.0]), (poses_3x4.shape[0], 1, 4))
        return np.concatenate((poses_3x4, bottom), axis=1)

    def __len__(self):
        """Number of consecutive (current, next) frame pairs."""
        return max(len(self.poses) - 1, 0)

    def _read_rgb(self, frame_idx):
        """Read frame ``frame_idx`` from disk and return it as an RGB (H, W, 3) array."""
        fpath = os.path.join(self.IMAGES_DIR, self._get_image_fname(frame_idx))
        img = cv2.imread(fpath)
        if img is None:
            # cv2.imread reports unreadable or missing files by returning None, not raising.
            raise FileNotFoundError(f'Could not read image {fpath}')
        return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    @staticmethod
    def _image2tensor(img):
        """Convert an HWC image with values in [0, 255] to a CHW float32 tensor in [-1, 1]."""
        chw = img.transpose(2, 0, 1)
        return torch.from_numpy((chw / 255 * 2 - 1).astype(np.float32))

    def __getitem__(self, idx):
        """Return ``(c_img, n_img, quat_t, local_dtrans)`` for the frame pair (idx, idx + 1).

        - ``c_img``/``n_img`` are (3, H, W) float32 tensors in [-1, 1], or the
          (augmented) HWC RGB arrays when ``debug`` is set.
        - ``quat_t`` is the float32 relative-rotation quaternion, in the component
          order of ``quaternion.as_float_array``. Its sign is fixed so the scalar
          part is >= 0.
        - ``local_dtrans`` is the float32 (3,) position of the next frame in the
          current frame's coordinates.

        Negative indices count from the end.

        Raises:
            IndexError: if ``idx`` is out of range.
        """
        n_pairs = len(self)
        if idx < 0:
            idx += n_pairs
        if not 0 <= idx < n_pairs:
            raise IndexError(f'Index {idx} out of range for {n_pairs} frame pairs')

        c_idx, n_idx = idx, idx + 1
        c_pose = self.poses[c_idx]
        n_pose = self.poses[n_idx]

        c_img = self._read_rgb(c_idx)
        n_img = self._read_rgb(n_idx)

        if self.transform is not None:
            augmented = self.transform(image=c_img, image2=n_img)
            c_img, n_img = augmented['image'], augmented['image2']

        # Relative motion T = inv(c_pose) @ n_pose.
        # - Its translation column is the next frame's position in the current frame.
        # - Its rotation corresponds to quat_c^-1 * quat_n.
        # Recovery: R_n = R(quat_c * quat_t), t_n = (c_pose @ local_dtrans)[:3].
        local_dtrans = np.linalg.inv(c_pose) @ n_pose[:, 3]
        quat_c = quat.from_rotation_matrix(c_pose[:3, :3])
        quat_n = quat.from_rotation_matrix(n_pose[:3, :3])
        quat_t = quat_c.inverse() * quat_n
        # q and -q encode the same rotation; fix the sign so the regression target is unique.
        if quat_t.w < 0:
            quat_t = -quat_t

        quat_t_ar = quat.as_float_array(quat_t).astype(np.float32)
        trans_ar = local_dtrans[:3].astype(np.float32)

        if self.debug:
            return c_img, n_img, quat_t_ar, trans_ar
        return self._image2tensor(c_img), self._image2tensor(n_img), quat_t_ar, trans_ar

In [None]:
# Root of the KITTI odometry dataset, relative to the notebook's working directory.
DATASET_DIR = os.path.join('../', 'data/KITTI/dataset')
# Fail fast with a clear message instead of an obscure error deep inside the dataset class.
if not os.path.isdir(DATASET_DIR):
    raise FileNotFoundError(f'KITTI dataset not found at {os.path.abspath(DATASET_DIR)}')

transform = get_transform()
# debug=True: items carry the (augmented) RGB images instead of tensors, for visual inspection.
dataset = KITTIDataset(DATASET_DIR, transform, debug=True)

In [None]:
# Inspect one (current, next) frame pair. With debug=True the images are RGB arrays
# rather than tensors; `quaternion` and `trns` are the relative-pose targets.
c_img, n_img, quaternion, trns = dataset[101]

In [None]:
# Show both frames of the pair with the regression targets for that pair.
fig_c, ax_c = plt.subplots(figsize=(15, 9))
ax_c.imshow(c_img)
ax_c.set_title('Current frame')
print(c_img.shape)

fig_n, ax_n = plt.subplots(figsize=(15, 9))
ax_n.imshow(n_img)
ax_n.set_title('Next frame')
print('Trans', trns)
print('Quat', quaternion)

In [None]:
class PosePrediction(nn.Module):
    """Regress the relative pose between two RGB frames.

    Pipeline:
    - Both frames pass through a shared feature extractor: by default ResNet-18
      with its global pool and classifier removed.
    - The two feature maps are concatenated along channels and fused by two
      unpadded 3x3 conv blocks, so backbone feature maps must be at least 5x5.
    - The result is average-pooled to a vector and decoded by a small MLP into a
      unit quaternion (4 values) and a translation (3 values).
    """

    # Channels of the truncated ResNet-18 feature map; two frames are concatenated.
    FEATURE_CHANNELS = 512

    def __init__(self, pretrained=True, backbone=None):
        """
        Args:
            pretrained: load pretrained weights for the default ``torch.hub``
                ResNet-18 backbone. Ignored when ``backbone`` is given.
            backbone: optional feature extractor used instead of the hub ResNet-18.
                It must map a (B, 3, H, W) batch to (B, 512, h, w) with h, w >= 5.
        """
        super().__init__()

        if backbone is None:
            resnet = torch.hub.load('pytorch/vision:v0.5.0', 'resnet18', pretrained=pretrained)
            # Drop the global average pool and the classification head to keep spatial features.
            backbone = nn.Sequential(*list(resnet.children())[:-2])
        self.backbone = backbone

        fused = 2 * self.FEATURE_CHANNELS
        self.predict = nn.Sequential(
            nn.Conv2d(fused, fused, 3),
            nn.BatchNorm2d(fused, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
            nn.ReLU(inplace=True),

            nn.Conv2d(fused, fused, 3),
            nn.BatchNorm2d(fused, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
            nn.ReLU(inplace=True),

            nn.AdaptiveAvgPool2d(output_size=(1, 1)),
        )

        # ReLUs between the Linear layers: without them, consecutive Linear layers
        # (including the output heads) collapse into a single affine map.
        self.fc = nn.Sequential(
            nn.Linear(fused, 120),
            nn.ReLU(inplace=True),
            nn.Linear(120, 84),
            nn.ReLU(inplace=True),
        )

        self.fc_quat = nn.Linear(84, 4)
        self.fc_trns = nn.Linear(84, 3)

    def forward(self, img1, img2):
        """Predict the relative pose for an image pair.

        In this notebook, ``img1`` is the current frame and ``img2`` the next frame.

        Returns:
            (pred_quat, pred_trns): a (B, 4) L2-normalized quaternion and a (B, 3)
            translation.
        """
        pair_feat = torch.cat((self.backbone(img1), self.backbone(img2)), dim=1)

        x = torch.flatten(self.predict(pair_feat), 1)
        x = self.fc(x)

        pred_quat = F.normalize(self.fc_quat(x), dim=1)
        pred_trns = self.fc_trns(x)

        return pred_quat, pred_trns

In [None]:
model = PosePrediction()

# Smoke test on random inputs: the forward pass must yield one unit quaternion and
# one 3-vector per sample.
sample_input1 = torch.rand((2, 3, 224, 224))
sample_input2 = torch.rand((2, 3, 224, 224))

# eval() + no_grad(): don't build an autograd graph and don't update BatchNorm
# running statistics with random data. Train mode is restored for the training cells.
model.eval()
with torch.no_grad():
    pred_quat, pred_trns = model(sample_input1, sample_input2)
model.train()

assert pred_quat.shape == (2, 4) and pred_trns.shape == (2, 3)
assert torch.allclose(pred_quat.norm(dim=1), torch.ones(2), atol=1e-5)

In [None]:
# Display the L2-normalized quaternion predictions from the forward pass on random
# inputs in the previous cell (one row per sample).
pred_quat

## Training code

In [None]:
BATCH_SIZE = 4
NUM_WORKERS = 8
LEARNING_RATE = 1e-3

# Separate non-debug dataset (items are normalized tensors), so the debug `dataset`
# used for visual inspection is not silently replaced.
train_dataset = KITTIDataset(DATASET_DIR, transform, debug=False)

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=torch.cuda.is_available(),
)

criterion = nn.MSELoss()

# Move the model to its device before constructing the optimizer, as recommended in the
# torch.optim documentation.
target_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(target_device)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [None]:
target_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(target_device)

N_EPOCHS = 2
# Mini-batches between loss reports. The tail of each epoch is always flushed, so no
# window goes unreported even when an epoch is shorter than this interval.
LOG_EVERY = 100


def move_2_device(data):
    """Move every tensor of a batch (a sequence of tensors) to ``target_device``."""
    return [el.to(target_device) for el in data]


model.train()
for epoch in range(N_EPOCHS):  # loop over the dataset multiple times
    running_loss = 0.0
    n_batches = 0
    for i, data in enumerate(train_loader):
        c_img, n_img, gt_quat, gt_trns = move_2_device(data)
        optimizer.zero_grad()

        pred_quat, pred_trns = model(c_img, n_img)

        # NOTE(review): rotation and translation MSE are summed with equal weight even
        # though they have different scales/units; consider a weighting factor.
        loss_quat = criterion(pred_quat, gt_quat)
        loss_trns = criterion(pred_trns, gt_trns)
        loss = loss_quat + loss_trns

        loss.backward()
        optimizer.step()

        # Average over the batches actually accumulated since the last report.
        running_loss += loss.item()
        n_batches += 1
        if (i + 1) % LOG_EVERY == 0:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / n_batches))
            running_loss = 0.0
            n_batches = 0

    if n_batches:
        print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / n_batches))

print('Finished Training')