## Import dependencies

In [1]:
import gc
import os

import numpy as np
import PIL.Image
import torch
import torchvision
from matplotlib import pyplot as plt
from pytorch3d.transforms import quaternion_to_axis_angle, axis_angle_to_matrix, matrix_to_quaternion

## Select dataset
The `tum_desk` dataset is stored in the TUM RGB-D format; `tiny_nerf` is loaded from the `tiny_nerf_data.npz` archive.

In [16]:
# Per-dataset configuration, keyed by dataset name.
#   * 'directory' and 'camera_factor' are only read by SlamDataset, which is
#     why the 'tiny_nerf' entry does not define them.
#   * 'camera_near' / 'camera_far' bound the depths sampled along every ray in
#     render_rays; SlamDataset also uses them to mask out-of-range depth pixels.
#   * 'normalisation_factor' divides the sampled 3D points before the voxel grid
#     is queried.
datasets = {
    'tum_desk': {
        'directory': 'rgbd_dataset_freiburg3_long_office_household/',
        'camera_factor': 5000.,
        'camera_fx': 535.4,
        'camera_fy': 539.2,
        'camera_cx': 320.1,
        'camera_cy': 247.6,
        'camera_distortion': [0., 0., 0., 0., 0.], # k1, k2, p1, p2, k3
        'camera_near': 0.3,
        'camera_far': 8.,
        'img_w': 640,
        'img_h': 480,
        'normalisation_factor': 10., # used to keep the model query region close to [-1, 1] range
    },
    'tiny_nerf': {
        'camera_fx': 138.88887889922103,
        'camera_fy': 138.88887889922103,
        'camera_cx': 50.,
        'camera_cy': 50.,
        'camera_distortion': [0., 0., 0., 0., 0.], # k1, k2, p1, p2, k3
        'camera_near': 2.,
        'camera_far': 6.,
        'img_w': 100,
        'img_h': 100,
        'normalisation_factor': 10., # used to keep the model query region close to [-1, 1] range
    },
}

# Configuration used by every cell below.
dataset_dict = datasets['tiny_nerf']

# Preferred GPU. Requesting an index the machine does not have only fails later,
# at the first `.to(device)`, so fall back to GPU 0 when there are fewer GPUs
# and to the CPU when CUDA is unavailable.
PREFERRED_CUDA_INDEX = 2
if torch.cuda.is_available():
    _cuda_index = PREFERRED_CUDA_INDEX if PREFERRED_CUDA_INDEX < torch.cuda.device_count() else 0
    device = torch.device(f"cuda:{_cuda_index}")
else:
    device = torch.device("cpu")

In [17]:
class TinyNerfDataset(torch.utils.data.Dataset):
    """Images and camera poses loaded from ``tiny_nerf_data.npz``.

    Every item is a ``(timestamp, depth, rgb, gt_pose)`` tuple, the same layout
    ``SlamDataset`` produces. The archive provides neither timestamps nor depth
    maps, so those two fields are placeholders (``0.`` and an all-zero tensor).

    Args:
        dataset_dict: configuration entry for this dataset; stored on the
            instance but not otherwise read by this class.
    """

    def __init__(self, dataset_dict):
        # Zero-argument super(): the one-argument form `super(TinyNerfDataset)`
        # creates an unbound super object and never reaches the parent initialiser.
        super().__init__()

        self.dataset_dict = dataset_dict

        # Indexing an NpzFile reads the array into memory, so the archive can be
        # closed as soon as both arrays have been fetched.
        with np.load('tiny_nerf_data.npz') as data:
            self.images = data['images']
            self.poses = data['poses']

    def __len__(self):
        """Number of images in the archive."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(0., zero_depth, rgb, gt_pose)`` for image ``idx``, tensors on ``device``."""
        # assumes images are stored channels-last (H, W, C) — TODO confirm
        rgb = torch.tensor(self.images[idx]).permute((2, 0, 1))

        # assumes each pose is a camera-to-world matrix with the rotation in
        # [:3, :3] and the translation in the last column — TODO confirm
        pose = self.poses[idx]
        axis_angle = quaternion_to_axis_angle(matrix_to_quaternion(torch.tensor(pose[:3, :3])))
        translation = torch.tensor(pose[:3, -1])
        # Shape (3, 2): column 0 is the translation, column 1 the axis-angle rotation.
        gt_pose = torch.stack([translation, axis_angle], -1)

        # Placeholder depth map sized from the image instead of a fixed 100x100.
        depth = torch.zeros((1,) + tuple(rgb.shape[-2:]))

        return 0., depth.to(device), rgb.to(device), gt_pose.to(device)
    
# Only the TUM-style configurations define 'directory'. Build the Tiny-NeRF
# dataset only when the selected configuration has none, so that running every
# cell in order does not try to load tiny_nerf_data.npz for a TUM sequence.
if 'directory' not in dataset_dict:
    dataset = TinyNerfDataset(dataset_dict)

In [10]:
class SlamDataset(torch.utils.data.Dataset):
    """RGB-D frames with ground-truth poses read from a sequence directory.

    ``dataset_dict['directory']`` must contain:

    * ``associations.txt`` - one frame per line; the first field is the frame
      timestamp, the second the depth image path and the fourth the RGB image
      path (the third field is ignored).
    * ``groundtruth.txt`` - ``timestamp tx ty tz qx qy qz qw`` records, with
      ``#`` comment lines allowed.

    Each frame is paired with the first ground-truth record whose timestamp is
    at or after the frame timestamp. Frames that come after the last
    ground-truth record are dropped.

    Items are ``(timestamp, depth, rgb, gt_pose)`` with the tensors on ``device``.
    """

    def __init__(self, dataset_dict):
        # Zero-argument super(): `super(SlamDataset)` with one argument never
        # reaches the parent initialiser.
        super().__init__()

        self.entries = []
        self.dataset_dict = dataset_dict
        self.convert_tensor = torchvision.transforms.ToTensor()

        assoc_path = os.path.join(dataset_dict['directory'], "associations.txt")
        gt_path = os.path.join(dataset_dict['directory'], "groundtruth.txt")
        gt_timestamp = 0.
        translation = None
        quaternion_wxyz = None

        with open(assoc_path) as assoc, open(gt_path) as gt:
            for line in assoc:
                fields = line.split()
                if not fields or fields[0].startswith('#'):
                    continue  # blank line or comment

                t, depth_path, _, rgb_path = fields
                timestamp = float(t)

                # Advance through the ground truth until it catches up with the frame.
                gt_exhausted = False
                while gt_timestamp < timestamp:
                    gt_line = gt.readline()
                    if gt_line == "":
                        gt_exhausted = True
                        break
                    record = self._parse_groundtruth_line(gt_line)
                    if record is None:
                        continue
                    gt_timestamp, translation, quaternion_wxyz = record

                if gt_exhausted:
                    # No ground truth at or after this frame (nor any later one).
                    break
                if quaternion_wxyz is None:
                    # Frame precedes every ground-truth record read so far.
                    continue

                axis_angle = quaternion_to_axis_angle(torch.tensor(quaternion_wxyz))
                # Shape (3, 2): column 0 is the translation, column 1 the axis-angle rotation.
                gt_pose = torch.stack([torch.tensor(translation), axis_angle], -1)

                self.entries.append((timestamp, depth_path, rgb_path, gt_pose))

    @staticmethod
    def _parse_groundtruth_line(line):
        """Parse one ``timestamp tx ty tz qx qy qz qw`` ground-truth record.

        Returns ``None`` for blank lines and ``#`` comments, otherwise
        ``(timestamp, [tx, ty, tz], [qw, qx, qy, qz])``. The quaternion is
        reordered so that the real part comes first, which is the layout
        PyTorch3D's ``quaternion_to_axis_angle`` expects.
        """
        fields = line.split()
        if not fields or fields[0].startswith('#'):
            return None
        t, tx, ty, tz, qx, qy, qz, qw = (float(value) for value in fields)
        return t, [tx, ty, tz], [qw, qx, qy, qz]

    def __len__(self):
        """Number of frames that were paired with a ground-truth pose."""
        return len(self.entries)

    def __getitem__(self, idx):
        """Load frame ``idx`` and return ``(timestamp, depth, rgb, gt_pose)``."""
        timestamp, depth_path, rgb_path, gt_pose = self.entries[idx]
        directory = self.dataset_dict['directory']

        # Context managers close the image files once the pixels are converted.
        with PIL.Image.open(os.path.join(directory, depth_path)) as depth_image:
            depth = self.convert_tensor(depth_image)
        depth = depth.float() / self.dataset_dict['camera_factor']
        # Readings closer than the near plane become 0, readings beyond the far plane -1.
        depth[depth < self.dataset_dict['camera_near']] = 0.
        depth[depth > self.dataset_dict['camera_far']] = -1.

        with PIL.Image.open(os.path.join(directory, rgb_path)) as rgb_image:
            rgb = self.convert_tensor(rgb_image)

        return timestamp, depth.to(device), rgb.to(device), gt_pose.to(device)
    
# SlamDataset reads dataset_dict['directory'], which only the TUM-style
# configurations define. Skip it for configurations without a directory so that
# running every cell in order does not fail with a KeyError.
if 'directory' in dataset_dict:
    dataset = SlamDataset(dataset_dict)

In [13]:
# FROM https://github.com/krrish94/nerf-pytorch/blob/a14357da6cada433d28bf11a45c7bcaace76c06e/nerf/nerf_helpers.py

def cumprod_exclusive(tensor: torch.Tensor) -> torch.Tensor:
    r"""Exclusive cumulative product along the last dimension.

    Mimics ``tf.math.cumprod(..., exclusive=True)``, which PyTorch does not
    provide: element ``k`` of the result is the product of elements
    ``0 .. k-1`` of the input, so the first element is always ``1``.

    Args:
        tensor (torch.Tensor): input whose cumulative product along ``dim=-1``
            is computed. It is not modified.

    Returns:
        torch.Tensor: tensor of the same shape holding the exclusive
        cumulative product along ``dim=-1``.
    """
    # Inclusive product, shifted one slot to the right along the last axis ...
    shifted = torch.cumprod(tensor, dim=-1).roll(shifts=1, dims=-1)
    # ... then the wrapped-around total in slot 0 is replaced by the empty product.
    shifted[..., 0] = 1.0
    return shifted

In [18]:
# Pixel-coordinate grids of shape (img_h, img_w): `i` holds the column (x) index
# of every pixel and `j` the row (y) index.
i, j = torch.meshgrid(
    torch.arange(0, dataset_dict['img_w']),
    torch.arange(0, dataset_dict['img_h']),
    indexing='xy',
)

# Per-pixel ray directions in the camera frame, shape (img_h, img_w, 3), with
# x to the right, y up and the camera looking along -z. Rays are centred on the
# configured principal point (camera_cx, camera_cy) rather than on the image
# centre; the two coincide for 'tiny_nerf' but not for 'tum_desk'.
_ray_x = (i - dataset_dict['camera_cx']) / dataset_dict['camera_fx']
_ray_y = -(j - dataset_dict['camera_cy']) / dataset_dict['camera_fy']
ray_dirs_cannonical = torch.stack([_ray_x, _ray_y, -torch.ones_like(_ray_x)], -1).to(device)

# Length of every (unnormalised) canonical ray; render_rays uses it to turn
# depth steps along the ray into Euclidean distances.
rays_norms = torch.norm(ray_dirs_cannonical, dim=-1)

def get_rays(camera_pose):
    """Build one ray per pixel for each camera pose in the batch.

    Args:
        camera_pose: tensor indexed as ``(batch, 3, 2)``; ``[..., 0]`` is the
            camera translation and ``[..., 1]`` its axis-angle rotation.

    Returns:
        ``(ray_origins, ray_dirs)``, both of shape ``(batch, H, W, 3)`` where
        ``(H, W)`` is the pixel grid of ``ray_dirs_cannonical``. ``ray_origins``
        is a broadcast view of the translations.
    """
    axis_angles = camera_pose[..., 1]
    translations = camera_pose[:, None, None, :, 0]

    rotations = axis_angle_to_matrix(axis_angles)

    # Rotate every canonical direction: a matrix-vector product written as a
    # broadcast multiply followed by a sum over the matrix columns.
    canonical = ray_dirs_cannonical[None, :, :, None, :]
    products = canonical * rotations[:, None, None, :, :]
    ray_dirs = products.sum(dim=-1)

    ray_origins = translations.expand(ray_dirs.shape)
    return ray_origins, ray_dirs

def render_rays(model, ray_origins, ray_dirs, n_samples=48):
    """Volume-render a batch of full images from the voxel grid ``model``.

    Args:
        model: voxel grid of shape ``(1, 4, D, H, W)``. Channel 0 is turned into
            a density with ReLU, the remaining channels into a colour with sigmoid.
        ray_origins: ray start points, shape ``(batch, img_h, img_w, 3)``.
        ray_dirs: unnormalised ray directions, same shape as ``ray_origins``.
        n_samples: number of evenly spaced depth samples per ray, taken between
            ``dataset_dict['camera_near']`` and ``dataset_dict['camera_far']``.

    Returns:
        ``(rgb_map, depth_map)`` of shapes ``(batch, 3, img_h, img_w)`` and
        ``(batch, 1, img_h, img_w)``.
    """
    depth_values = torch.linspace(dataset_dict['camera_near'], dataset_dict['camera_far'], n_samples).to(ray_origins)
    depth_values = depth_values.broadcast_to(list(ray_origins.shape[:-1]) + [n_samples])

#     # Randomize
#     noise_shape = list(ray_origins.shape[:-1]) + [n_samples]
#     depth_values = (
#         depth_values
#         + torch.rand(noise_shape).to(ray_origins)
#         * (dataset_dict['camera_far'] - dataset_dict['camera_near'])
#         / n_samples
#     )

    # 3D sample positions along every ray: (batch, img_h, img_w, n_samples, 3).
    query_points = (
        ray_origins[..., None, :]
        + ray_dirs[..., None, :] * depth_values[..., :, None]
    )

    # grid_sample wants a 5-D grid (1, D_out, H_out, W_out, 3); fold batch and
    # image rows together and scale the points towards its [-1, 1] range.
    query_points = query_points.reshape((-1, dataset_dict['img_w'], n_samples, 3)).unsqueeze(0)
    query_points = query_points / dataset_dict['normalisation_factor']

    # align_corners=False is what grid_sample falls back to when the argument is
    # omitted; stating it keeps the result unchanged and avoids a warning per call.
    output = torch.nn.functional.grid_sample(model, query_points, align_corners=False)
    output = output.reshape((1, 4, -1, dataset_dict['img_h'], dataset_dict['img_w'], n_samples))
    output = output.squeeze(0).swapaxes(0, 1)  # (batch, 4, img_h, img_w, n_samples)

    sigma_a = torch.relu(output[:, :1, ...])
    rgb = torch.sigmoid(output[:, 1:, ...])

    # Distance between consecutive samples. The last sample has no successor and
    # gets a very large interval instead; it is created with the dtype and device
    # of the depth samples so the function never depends on the global `device`.
    last_interval = depth_values.new_full(depth_values[..., :1].shape, 1e5)
    dists = torch.cat(
        (
            depth_values[..., 1:] - depth_values[..., :-1],
            last_interval,
        ),
        dim=-1,
    )
    # Depth steps are measured along unnormalised rays; scale by the ray length
    # to obtain Euclidean distances.
    dists = dists[:, None, ...] * rays_norms[None, None, ..., None]

    alpha = 1.0 - torch.exp(-sigma_a * dists)
    # Weight of a sample = its opacity times the transmittance of everything before it.
    weights = alpha * cumprod_exclusive(1.0 - alpha + 1e-10)
    rgb_map = (weights * rgb).sum(dim=-1)
    depth_map = (weights * depth_values[:, None, ...]).sum(dim=-1)

    return rgb_map, depth_map

def render(model, pose):
    """Render the images seen from ``pose`` with ``model``.

    Builds the per-pixel rays with ``get_rays`` and passes them to
    ``render_rays`` (default sample count); returns its ``(rgb_map, depth_map)``.
    """
    origins, directions = get_rays(pose)
    rendered = render_rays(model, origins, directions)
    return rendered

def render_and_optim(model, optimizer, pose, rgb_gt, depth_gt=None):
    """Run one optimisation step on the photometric (RGB) loss.

    Args:
        model: voxel grid rendered by ``render``.
        optimizer: optimiser that is zeroed and stepped once.
        pose: camera poses of the batch.
        rgb_gt: ground-truth images compared against the rendered RGB.
        depth_gt: ground-truth depth; currently unused because the depth loss
            below is disabled, hence optional.

    Returns:
        The loss of this step as a detached tensor, so callers can log it.
    """
    optimizer.zero_grad()
    rgb_pred, depth_pred = render(model, pose)
    loss = torch.nn.functional.mse_loss(rgb_pred, rgb_gt)

#     depth_loss = torch.nn.functional.mse_loss(depth_pred, depth_gt)
#     loss += 1e-1 * depth_loss

#     # TV regularisation
#     shifted = torch.roll(model, shifts=(1, 1, 1), dims=(2, 3, 4))
#     tv_losses = torch.sqrt(torch.sum((model[...,1:-1,1:-1,1:-1] - shifted[...,1:-1,1:-1,1:-1]) ** 2., dim=1))
#     loss += 1e-1 * torch.mean(tv_losses)

    loss.backward()
    optimizer.step()

    return loss.detach()

# for img in range(100):
#     timestamp, depth_gt, rgb_gt, pose_gt = dataset[:500:20]
#     rgb_gt = rgb_gt.unsqueeze(0).to(device)
#     pose_gt = pose_gt.to(device)

#     for iter in range(20):
#         render_and_optim(voxel_model, model_optimizer, pose_gt, rgb_gt)

# #     for iter in range(50):
# #         render_and_optim(voxel_model, pose_optimizer, current_pose, rgb_gt)
    
#     if img % 10 == 0:
#         ray_origins, ray_dirs = get_rays(pose_gt)
#         rgb_pred, depth_pred = render_rays(voxel_model, ray_origins, ray_dirs)
#         plt.imshow(rgb_pred.detach().cpu()[0].permute((1,2,0)))
#         plt.show()
#         plt.imshow(depth_pred.detach().cpu()[0,0])
#         plt.show()

In [None]:
# Hold out a single item for validation and train on all the others.
train_set, val_set = torch.utils.data.random_split(
    dataset,
    lengths=[len(dataset) - 1, 1],
)
train_loader = torch.utils.data.DataLoader(train_set, shuffle=True, batch_size=4)
val_loader = torch.utils.data.DataLoader(val_set, shuffle=False, batch_size=1)

# The scene is a randomly initialised 256^3 voxel grid with 4 channels per
# voxel, optimised directly as a single parameter tensor.
voxel_model = torch.nn.Parameter(torch.rand((1, 4, 256, 256, 256)).to(device))
model_optimizer = torch.optim.RMSprop(params=[voxel_model], lr=1e-1)

# _, _, _, pose_gt = dataset[0]
# current_pose = torch.nn.parameter.Parameter(pose_gt.clone())
# pose_optimizer = torch.optim.Adam([current_pose], lr=1e-3)

def train_step(model, model_optimizer, data):
    """Run one optimisation step of ``model`` on one batch.

    ``data`` is a ``(timestamp, depth_gt, rgb_gt, pose_gt)`` batch; the
    timestamp is not used.
    """
    _, depth_gt, rgb_gt, pose_gt = data
    render_and_optim(
        model,
        model_optimizer,
        pose_gt,
        rgb_gt,
        depth_gt,
    )
    
def val_step(model, val_loader):
    """Report the mean RGB loss over ``val_loader`` and plot the last prediction.

    Prints the mean-squared error averaged over the validation batches, then
    shows the first image of the last batch next to its ground truth. Nothing
    is plotted when the loader is empty.

    Args:
        model: voxel grid rendered by ``render``.
        val_loader: iterable of ``(timestamp, depth_gt, rgb_gt, pose_gt)`` batches.
    """
    total_loss = 0.
    n_batches = 0
    rgb_pred = rgb_gt = None
    with torch.no_grad():
        for data in val_loader:
            timestamp, depth_gt, rgb_gt, pose_gt = data
            rgb_pred, depth_pred = render(model, pose_gt)
            total_loss += torch.nn.functional.mse_loss(rgb_pred, rgb_gt).item()
            n_batches += 1

    if n_batches == 0:
        print('Validation loss: ', 'n/a (empty validation loader)')
        return

    # Average over the batches actually seen instead of a fixed divisor.
    print('Validation loss: ', round(total_loss / n_batches, 5))

    fig, axes = plt.subplots(1, 2)
    axes[0].imshow(rgb_pred.detach().cpu()[0].permute((1, 2, 0)))
    axes[0].set_title('Prediction')
    axes[1].imshow(rgb_gt.cpu()[0].permute((1, 2, 0)))
    axes[1].set_title('Ground truth')
    plt.show()

# Optimise the voxel grid for 100 passes over the training loader.
for epoch in range(100):
    for i, data in enumerate(train_loader):
        train_step(voxel_model, model_optimizer, data)
        # Validate (and plot) after the first batch of each pass and after
        # every tenth batch that follows it.
        if i % 10:
            continue
        val_step(voxel_model, val_loader)

In [None]:
# Free memory after (or between) training runs: collect unreachable Python
# objects first, then ask PyTorch to release its cached CUDA memory.
# `gc` is imported with the other dependencies in the import cell.
gc.collect()
torch.cuda.empty_cache()