In [None]:
# Import gradslam related modules
import gradslam as gs
from gradslam import Pointclouds, RGBDImages
from gradslam.datasets import TUM
from gradslam.slam import PointFusion, ICPSLAM


import matplotlib.pyplot as plt
import numpy as np
import os 
import torch
from torch.utils.data import DataLoader

In [None]:
import gc  # garbage-collector module; gc.collect() is called in the fusion loop further down
torch.cuda.empty_cache() # clean cuda cache before the heavier cells below run

In [None]:
# Down TUM dataset
data_path = '../data/'
if not os.path.isdir(data_path + 'TUM'):
    os.mkdir(data_path + 'TUM')
if not os.path.isdir(data_path + 'TUM/rgbd_dataset_freiburg1_desk'):
    print('No dataset found in ', data_path)
    print('Downloading TUM/rgbd_dataset_freiburg1_desk dataset ...')
    os.mkdir(data_path + 'TUM/rgbd_dataset_freiburg1_desk')
    !wget https://vision.in.tum.de/rgbd/dataset/freiburg1/rgbd_dataset_freiburg1_desk.tgz -P ../data/TUM/rgbd_dataset_freiburg1_desk/ -q
    !tar -xzf ../data/TUM/rgbd_dataset_freiburg1_desk/rgbd_dataset_freiburg1_desk.tar -C ../data/TUM/rgbd_dataset_freiburg1_desk/
    !rm ../data/TUM/rgbd_dataset_freiburg1_desk/rgbd_dataset_freiburg1_desk.tar
    print('Downloaded')    

tum_path = data_path + 'TUM/'

n_frames = 50

# Load data
dataset = TUM(tum_path, seqlen=n_frames, dilation=4, height=480, width=640)
loader = DataLoader(dataset=dataset, batch_size=2)
colors, depths, intrinsics, poses, *_ = next(iter(loader))

In [None]:
# Wrap the loaded batch, poses included, in an RGBDImages object and print
# its shape and poses, followed by a separator line.
rgbdimages = RGBDImages(colors, depths, intrinsics, poses)
print(rgbdimages.shape, rgbdimages.poses, '-----', sep='\n')

In [None]:
# Build the plotly figure of the RGB and depth images for index 0, then display it
rgbd_preview = rgbdimages.plotly(0)
rgbd_preview.show()

In [None]:
# Print the CUDA memory summary. The call is skipped when no CUDA device is
# available, so the notebook still runs top to bottom on CPU-only machines
# (the SLAM cells below already fall back to the CPU).
if torch.cuda.is_available():
    print(torch.cuda.memory_summary(device=None, abbreviated=False))
else:
    print('CUDA is not available; no GPU memory summary to show.')

In [None]:

# Re-create the loader with one sequence per batch and fetch its first batch
loader = DataLoader(dataset=dataset, batch_size=1)
colors, depths, intrinsics, poses, *_ = next(iter(loader))

# Create the RGB-D image object from the four tensors, each with gradient
# tracking switched off (requires_grad_ works in place and returns the tensor).
no_grad_inputs = [t.requires_grad_(False) for t in (colors, depths, intrinsics, poses)]
rgbdimages = RGBDImages(*no_grad_inputs)

print(rgbdimages.shape)

In [None]:
import time 

# Fuse the whole sequence into one point cloud with PointFusion, printing the
# wall-clock time of every frame and the average at the end.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device ", device)
slam = PointFusion(odom='gt', device=device).requires_grad_(False)

pcds = Pointclouds(device=device)
batch_size, seq_len = rgbdimages.shape[:2]
# Identity pose per batch element, used only if the first frame carries no pose
initial_poses = torch.eye(4, device=device).view(1, 1, 4, 4).repeat(batch_size, 1, 1, 1)
prev_frame = None

total_time = 0.0
start = time.time()

for s in range(seq_len):
    live_frame = rgbdimages[:, s].to(device)
    if s == 0 and live_frame.poses is None:
        live_frame.poses = initial_poses
    pcds, live_frame.poses = slam.step(pcds, live_frame, prev_frame)
    # Same rule as the step-by-step cell further down: with ground-truth
    # odometry no previous frame is handed to slam.step. Not holding on to the
    # frame is also what lets the `del` below actually release it.
    prev_frame = live_frame if slam.odom != 'gt' else None
    del live_frame
    gc.collect()
    frame_time = time.time() - start
    total_time += frame_time
    print('Time taken: %.3f s' % frame_time)
    start = time.time()

# max(..., 1) avoids a division by zero for an empty sequence
print("Processed %d frames.\n It has taken %.3f sec per frame." % (seq_len, total_time / max(seq_len, 1)))
pcds.plotly(0, max_num_points=10000).update_layout(autosize=False, width=600).show()

In [None]:
import numpy as np 
import plotly.graph_objects as go

def plotly_map_update_visualization(intermediate_pcs, poses, K, max_points_per_pc = 50000, ms_per_frame = 50):
    """Show an animated plotly figure of a map being built, with the camera trajectory.

    One animation frame is produced per (point cloud, pose) pair; if the two
    inputs differ in length, the extra items of the longer one are ignored.

    Args:
        intermediate_pcs: iterable of point-cloud objects, one per frame. Each
            must provide ``plotly(0, as_figure=False, max_num_points=...)``.
        poses: tensor of camera poses, one per frame. Each pose is read as
            ``pose[:3, :3]`` (rotation) and ``pose[:3, 3]`` (translation).
        K: camera intrinsics tensor; entries ``[0, 0]``, ``[1, 1]``, ``[0, 2]``
            and ``[1, 2]`` are used.
        max_points_per_pc: passed to each point cloud as ``max_num_points``.
        ms_per_frame: frame duration (milliseconds) used by the play button.

    Returns:
        The plotly figure, which has also been displayed with ``fig.show()``.

    Raises:
        ValueError: if the inputs yield no animation frame at all.
    """
    def plotly_poses(poses, K):
        """Return, per pose, the traces for the frustum, the camera position and the path so far."""
        fx = abs(K[0, 0])
        fy = abs(K[1, 1])
        f = (fx + fy) / 2
        cx = K[0, 2]
        cy = K[1, 2]

        # Express the principal point in units of the focal length, so the
        # frustum can be drawn at depth 1.
        cx = cx / f
        cy = cy / f
        f = 1.

        pos_0 = np.array([0., 0., 0.])
        # A single polyline through the camera centre and the four corners at depth f
        frustum_0 = np.array(
            [
                [-cx, -cy, f],
                [cx, -cy, f],
                list(pos_0),
                [-cx, -cy, f],
                [-cx, cy, f],
                list(pos_0),
                [cx, cy, f],
                [-cx, cy, f],
                [cx, cy, f],
                [cx, -cy, f]
            ]
        )

        traj = []
        traj_frustums = []
        for pose in poses:
            rot = pose[:3, :3]
            tvec = pose[:3, 3]

            # Rotate, then translate, the canonical frustum into this pose
            frustum_i = frustum_0 @ rot.T
            frustum_i = frustum_i + tvec
            pos_i = pos_0 + tvec

            pos_i = np.round(pos_i, decimals=2)
            frustum_i = np.round(frustum_i, decimals=2)

            traj.append(pos_i)
            # Snapshot of the path up to and including this pose
            traj_array = np.array(traj)
            traj_frustum = [
                go.Scatter3d(
                    x=frustum_i[:, 0], y=frustum_i[:, 1], z=frustum_i[:, 2],
                    marker=dict(
                        size=0.1
                    ),
                    line=dict(color='purple', width=2)
                ),
                go.Scatter3d(
                    x=pos_i[None, 0], y=pos_i[None, 1], z=pos_i[None, 2],
                    marker=dict(size=0, color='purple')
                ),
                go.Scatter3d(
                    x=traj_array[:, 0], y=traj_array[:, 1], z=traj_array[:, 2],
                    marker=dict(size=0.1),
                    line=dict(color ='purple', width=2)
                )
            ]
            traj_frustums.append(traj_frustum)
        return traj_frustums

    def frame_args(duration):
        """Animation options shared by the slider steps and the play/stop buttons."""
        return {
            "frame": {"duration": duration, "redraw": True},
            "mode": "immediate",
            "fromcurrent": True,
            # The key must be spelled "transition" to be recognised by plotly
            "transition": {"duration": duration, "easing": "linear"}
        }

    # visualization
    scatter3d_list = [pc.plotly(0, as_figure=False, max_num_points=max_points_per_pc) for pc in intermediate_pcs]
    traj_frustums = plotly_poses(poses.cpu().numpy(), K.cpu().numpy())
    data = [[*frustum, scatter3d] for frustum, scatter3d in zip(traj_frustums, scatter3d_list)]
    if not data:
        raise ValueError("plotly_map_update_visualization needs at least one point cloud and one pose")

    # One slider step per frame actually built (not a global sequence length)
    steps = [
        {"args": [[i], frame_args(0)], "label": i, "method": "animate"}
        for i in range(len(data))
    ]
    sliders = [
        {
            "active": 0,
            "yanchor": "top",
            "xanchor": "left",
            "currentvalue": {"prefix": "Frame: "},
            "pad": {"b": 10, "t": 60},
            "len": 0.9,
            "x": 0.1,
            "y": 0,
            "steps": steps,
        }
    ]
    updatemenus = [
        {
            "buttons": [
                {
                    "args": [None, frame_args(ms_per_frame)],
                    "label": "&#9654;",
                    "method": "animate",
                },
                {
                    "args": [[None], frame_args(0)],
                    "label": "&#9724;",
                    "method": "animate",
                },
            ],
            "direction": "left",
            "pad": {"r": 10, "t": 70},
            "showactive": False,
            "type": "buttons",
            "x": 0.1,
            "xanchor": "right",
            "y": 0,
            "yanchor": "top",
        }
    ]

    fig = go.Figure()
    frames = [{"data": frame, "name": i} for i, frame in enumerate(data)]
    # The first frame doubles as the figure's initial content
    fig.add_traces(frames[0]["data"])
    fig.update(frames=frames)
    fig.update_layout(
        updatemenus=updatemenus,
        sliders=sliders,
        showlegend=False,
        scene=dict(
            xaxis=dict(showticklabels=False, showgrid=False, zeroline=False, visible=False,),
            yaxis=dict(showticklabels=False, showgrid=False, zeroline=False, visible=False,),
            zaxis=dict(showticklabels=False, showgrid=False, zeroline=False, visible=False,),
        ),
        height=600, width=400
    )
    fig.show()
    return fig

In [None]:
# Load one 20-frame sequence (dilation 19) as a single batch
dataset = TUM(tum_path, seqlen=20, dilation=19, height=480, width=640)
loader = DataLoader(dataset=dataset, batch_size=1)
colors, depths, intrinsics, poses, *_ = next(iter(loader))

# Wrap the batch in an RGBDImages object
rgbdimages = RGBDImages(colors, depths, intrinsics, poses)

# Run SLAM one frame at a time, keeping a copy of the map after every step.
# Ground-truth poses are used because the large dilation (small fps) makes ICP difficult.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
slam = PointFusion(odom='gt', device=device)
pointclouds = Pointclouds(device=device)
batch_size, seq_len = rgbdimages.shape[:2]
initial_poses = torch.eye(4, device=device)[None, None].repeat(batch_size, 1, 1, 1)
prev_frame = None
intermediate_pcs = []
for s in range(seq_len):
    live_frame = rgbdimages[:, s].to(device)
    # Only the very first frame may need a fallback (identity) pose
    if s == 0 and live_frame.poses is None:
        live_frame.poses = initial_poses
    pointclouds, live_frame.poses = slam.step(pointclouds, live_frame, prev_frame)
    # No previous frame is carried over when the odometry is ground truth
    if slam.odom == 'gt':
        prev_frame = None
    else:
        prev_frame = live_frame
    intermediate_pcs.append(pointclouds[0])

# Show the input frames, then the animated map build-up
rgbd_figure = rgbdimages.plotly(0)
rgbd_figure.update_layout(autosize=False, height=600, width=400).show()
fig = plotly_map_update_visualization(intermediate_pcs, poses[0], intrinsics[0, 0], 15000)