In [None]:
import matplotlib
import matplotlib.colors as colors
import numpy as np
import torch
from imageio import imread
from IPython.display import HTML
from matplotlib import animation, rc
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
from skimage.transform import resize

from pytorch_direct_warp.direct_warp import DirectWarper
from pytorch_direct_warp.occlusion_mapper import OcclusionMapper
from utils import inverse_warp, pixel2cam, pose_vec2mat, cam2pixel

# NOTE(review): hard-coded to CUDA; pytorch_direct_warp may require a GPU, so no
# CPU fallback is attempted here — confirm before changing.
device = torch.device("cuda")
# A bare `torch.no_grad()` call only builds a context manager and discards it,
# so it never disabled autograd. Turn gradient tracking off globally instead:
# nothing in this notebook calls backward().
torch.set_grad_enabled(False)
plt.rcParams["figure.figsize"] = [12, 9]
h = 200          # side length in pixels of the square images and depth maps
f = h / 2        # focal length in pixels used in `intrinsics` (90 degree field of view)
batch_size = 1

def truncate_colormap(cmap, minval=0.0, maxval=1.0, n=100, gamma=0.8):
    """Build a colormap that only spans the ``[minval, maxval]`` slice of ``cmap``.

    Args:
        cmap: source matplotlib colormap; it must expose ``.name`` and be
            callable on an array of floats.
        minval, maxval: fractions of ``cmap`` where the new map starts / ends.
        n: number of colors sampled from ``cmap`` between the two bounds.
        gamma: forwarded to ``LinearSegmentedColormap.from_list``. The default
            0.8 keeps the previously hard-coded behaviour.

    Returns:
        A ``LinearSegmentedColormap`` named ``trunc(<name>,<minval>,<maxval>)``.
    """
    name = 'trunc({n},{a:.2f},{b:.2f})'.format(n=cmap.name, a=minval, b=maxval)
    sampled_colors = cmap(np.linspace(minval, maxval, n))
    return colors.LinearSegmentedColormap.from_list(name, sampled_colors, gamma=gamma)

# Colormap shared by the depth plots: gist_rainbow restricted to the [0, 0.85]
# part of its range.
cmap = plt.get_cmap('gist_rainbow')
new_cmap = truncate_colormap(cmap, 0, 0.85)

# Pinhole camera matrix K with focal length f and the principal point at the
# image centre (h/2, h/2). A leading batch dimension is added -> (1, 3, 3).
cx = cy = h / 2
intrinsics = torch.tensor([[f, 0, cx],
                           [0, f, cy],
                           [0, 0, 1]], dtype=torch.float32, device=device).unsqueeze(0)
intrinsics_inv = intrinsics[0].inverse().unsqueeze(0)

In [None]:
# Scene depths: the foreground square sits at fg_depth, the background plane at bg_depth.
fg_depth, bg_depth = 1, 10

# (row, col) of the top-left corner of the foreground patch in the h x h image.
fg_pos = [50, 50]

## Experiment setup

In [None]:
foreground = resize(imread('img/foreground.jpg'), (h//2, h//2))
fg_tensor = torch.from_numpy(foreground.transpose(2, 0, 1)).float().to(device)
background = resize(imread('img/background.jpg'), (h, h))
bg_tensor = torch.from_numpy(background.transpose(2, 0, 1)).float().to(device)

# The (h//2 x h//2) foreground patch is pasted with its top-left corner at
# fg_pos = (row, col). Fail early with a clear message if it would not fit.
temp_h, temp_w = fg_pos
patch = h // 2
assert temp_h + patch <= h and temp_w + patch <= h, \
    'fg_pos={} puts the {}px foreground patch outside the {}px image'.format(fg_pos, patch, h)
fg_rows = slice(temp_h, temp_h + patch)
fg_cols = slice(temp_w, temp_w + patch)

img = bg_tensor.clone().unsqueeze(0)
img[:, :, fg_rows, fg_cols] = fg_tensor

# Piecewise-constant depth: background plane at bg_depth and foreground square
# at fg_depth, plus a small amount of Gaussian noise.
depth = torch.full((1, h, h), float(bg_depth)).to(device)
depth[:, fg_rows, fg_cols] = fg_depth
depth += 0.001 * torch.randn(1, h, h).to(device)
# Alternative depth profiles to experiment with:
#line = torch.linspace(-h/2,h/2,h).abs().to(device)
#depth += line
#parabola = 0.1*(line.view(1,1,h) + line.view(1,h,1))
#depth = parabola

fig, (ax_img, ax_depth) = plt.subplots(1, 2, figsize=(12, 9))
ax_img.imshow(img[0].permute(1, 2, 0).cpu().numpy())
ax_img.set_title('Image')
# Colour range tied to bg_depth so changing the scene constants does not
# silently saturate the colorbar.
depth_im = ax_depth.imshow(depth[0].cpu().numpy(), cmap=new_cmap, vmin=0, vmax=bg_depth)
ax_depth.set_title('Depth')
fig.colorbar(depth_im, ax=ax_depth, fraction=0.046, pad=0.04)
fig.tight_layout()

## Testing warping with a single pose vector

In [None]:
# Change the pose vector to see how the warping behaves.
# NOTE(review): component order is defined by utils.pose_vec2mat; the occlusion
# cell treats index 2 as a z-translation — confirm the full layout there.
pose = torch.Tensor([0.2, 0.2, 0, 0, 0, 0]).unsqueeze(0).to(device).expand(batch_size, 6)

matrix = pose_vec2mat(pose)
warper = DirectWarper()
warped, wimg = warper(depth, img, matrix, intrinsics, 0)

fig, (ax_img, ax_depth) = plt.subplots(1, 2, figsize=(12, 9), dpi=100)
ax_img.imshow(wimg[0].permute(1, 2, 0).cpu().numpy())
ax_img.set_title('Warped image')
# Use the same fixed colour range as the source depth plot so the colours of
# the two maps are directly comparable.
ax_depth.imshow(warped[0].cpu().numpy(), cmap=new_cmap, vmin=0, vmax=bg_depth)
ax_depth.set_title('Warped depth')

fig.tight_layout()
plt.show()

## Video version: warping along a pose trajectory

In [None]:
fig = plt.figure(figsize=(12, 9))

# First frame, using the pose matrix from the previous cell; `animate` below
# refreshes both images in place for every frame of the video.
warped, wimg = warper(depth, img, matrix, intrinsics, 0)
warped = warped.cpu()
wimg = wimg[0].permute(1, 2, 0).cpu()

ax_warped_depth = fig.add_subplot(121)
im1 = ax_warped_depth.imshow(warped[0].numpy(), animated=True, vmin=0, vmax=10, cmap=new_cmap)
ax_warped_img = fig.add_subplot(122)
im2 = ax_warped_img.imshow(wimg.numpy(), animated=True)
plt.show()


# Trajectory parameter: one full period sampled at 100 points (one per frame).
x = np.linspace(0, 2 * np.pi, 100)
def animate(i):
    """Warp the scene for frame ``i`` and refresh both video panels.

    Every pose component is driven by ``t = x[i]``; the ``0 *`` factors appear
    to be toggles that keep some components switched off.
    Returns the two updated images, as FuncAnimation expects with ``blit=True``.
    """
    t = x[i]
    pose_components = [0 * np.sin(t), 0 * np.cos(t), 0.5 * np.cos(t), 0.5 * np.cos(t), 0, t]
    pose = torch.Tensor(pose_components).view(1, 6).to(device).expand(batch_size, 6)

    warped_depth, warped_img = warper(depth, img, pose_vec2mat(pose), intrinsics, 0)
    im1.set_array(warped_depth[0].cpu().numpy())
    im2.set_array(warped_img[0].permute(1, 2, 0).cpu().numpy())
    return (im1, im2)
animate(0)

# `frames` bounds the animation to len(x) frames. Without it FuncAnimation
# counts up indefinitely, and matplotlib >= 3.7 no longer stops saving after
# 100 frames, so `animate` would index past the end of `x`.
ani = animation.FuncAnimation(fig, animate, frames=len(x), interval=50, blit=True)

HTML(ani.to_html5_video())

## Occlusion module

This figure breaks down the different techniques used to compute the occlusion map for a given depth map and pose.
Rendering the video is optional.

In [None]:
fig, axes = plt.subplots(3, 3, figsize=(16, 20), sharex=True, sharey=True)
ax_list = axes.reshape(-1)

# Noisy copy of the depth map, used by the middle and bottom rows. A dedicated
# seeded generator makes the figure reproducible without touching the global
# RNG state.
noise_generator = torch.Generator().manual_seed(0)
noise = 0.1 * torch.randn(1, h, h, generator=noise_generator).to(device)
noisy_depth = depth + noise

# keep_index=True: `animate` reads `warper.index` after a forward warp.
warper = DirectWarper(keep_index=True)

def draw_map(ax, dmap, title):
    """Draw the first map of ``dmap`` on ``ax`` with the shared depth colormap.

    The colour range is fixed to [0, 10] and the image is marked animated.
    Returns the created image so ``animate`` can refresh it via ``set_array``.
    """
    first_map = dmap.cpu()[0]
    image = ax.imshow(first_map, animated=True, vmin=0, vmax=10, cmap=new_cmap)
    ax.set_title(title, fontsize=20, y=1.05)
    return image

# Top row: noise-free maps. im1 and im2 are refreshed by `animate`.
im0 = draw_map(ax_list[0], depth, r'$\theta_t$')
im1 = draw_map(ax_list[1], depth, r'$\theta_{t+1}$')
im2 = draw_map(ax_list[2], depth, 'valid pixels of $\\theta_t$ \n (pixels not occluded in $\\theta_{t+1}$)')

# LaTeX fragments reused in several titles. Raw strings avoid invalid escape
# sequences such as '\w' and '\i' (a SyntaxWarning since Python 3.12) while
# producing exactly the same text as before.
warped_back = r'\widetilde{\nu}_{t \rightarrow t+1 \rightarrow t}'
range_alpha = r'\left[\frac{1}{2}\nu_t, 2\nu_t\right]'
p = r'\mathbf{p}'
valid_set = r'$\lbrace\nu_t, {wb} \in {ra} \rbrace$'.format(wb=warped_back, ra=range_alpha)

# Middle and bottom rows: the pipeline applied to the noisy depth map.
im3 = draw_map(ax_list[3], noisy_depth, r'Noisy depth $\nu_t $')
im4 = draw_map(ax_list[4], noisy_depth, r'Warped depth $\widetilde{\nu}_{t \rightarrow t+1}$')
im5 = draw_map(ax_list[5], noisy_depth, 'Indexed back depth \n'
                                         r'$\lbrace\nu_t(\mathbf{p}),\mathbf{p} \in Id \rbrace$')
im6 = draw_map(ax_list[6], noisy_depth, 'Warped back depth ${}$'.format(warped_back))
im7 = draw_map(ax_list[7], noisy_depth, valid_set)
im8 = draw_map(ax_list[8], noisy_depth, valid_set + '\n+ erosion of $2$')

cbar_ax = fig.add_axes([0.2, 0.1, 0.6, 0.01])
fig.colorbar(im0, cax=cbar_ax, orientation='horizontal')


x = np.linspace(0, 2*np.pi, 100)

def animate(i):
    """Recompute every animated panel of the occlusion figure for frame ``i``.

    The pose is built from ``x[i]``. Panels im1, im2 and im4-im8 are refreshed
    in place and returned together with im3. FuncAnimation's blit contract
    requires every modified artist to be returned; the previous version
    omitted im6-im8.
    """
    t = x[i]
    pose = torch.Tensor([0.5*np.sin(t), 0.5*np.cos(t), 0.5*np.cos(t),
                         0, 0.*np.cos(t), 0.*t]).view(1, 6).to(device).expand(batch_size, 6)
    pose_mat = pose_vec2mat(pose)
    # Inverse rigid transform [R | t]^-1 = [R^T | -R^T t]; pose_mat is sliced as
    # rotation = [:, :, :3], translation = last column (presumably (batch, 3, 4)).
    inverse_rot = pose_mat[:, :, :3].transpose(1, 2)
    inverse_tr = -inverse_rot @ pose_mat[:, :, -1:]
    inverse_pose_mat = torch.cat([inverse_rot, inverse_tr], dim=-1)

    # Reference warped depth and occlusion map computed on the noise-free depth.
    reference_mapper = OcclusionMapper(dilation=1, alpha=100)
    warped = warper(depth, None, pose_mat, intrinsics, 0)
    occlusion = reference_mapper(depth, pose_mat, intrinsics)

    # Noisy depth: warp forward, then warp that result back with the inverse pose.
    warped_noisy = warper(noisy_depth, None, pose_mat, intrinsics, 0)
    warped_back_noisy = warper(warped_noisy, None, inverse_pose_mat, intrinsics, 0)

    # Occlusion masks on the noisy depth, without and with dilation=2
    # (titled "erosion of 2" on the corresponding valid-pixel panel).
    occlusion_noisy = OcclusionMapper(dilation=0, alpha=2)(noisy_depth, pose_mat, intrinsics)
    occlusion_noisy_eroded = OcclusionMapper(dilation=2, alpha=2)(noisy_depth, pose_mat, intrinsics)

    def with_occlusions(dmap, mask):
        """Copy of ``dmap`` with the pixels selected by ``mask`` set to +inf."""
        masked = dmap.clone()
        masked[mask] = float('inf')
        return masked

    depth_occ = with_occlusions(depth, occlusion)
    depth_occ_noisy = with_occlusions(noisy_depth, occlusion_noisy)
    depth_occ_noisy2 = with_occlusions(noisy_depth, occlusion_noisy_eroded)

    # Re-run the forward warp before reading `warper.index`: it presumably
    # reflects only the most recent call, and the backward warp ran last.
    # TODO confirm against DirectWarper.
    warper(noisy_depth, None, pose_mat, intrinsics, 0)
    index = warper.index
    valid = index >= 0
    # Scatter warped depths back onto their source pixels; subtracting pose[0, 2]
    # presumably undoes the forward z-translation. Pixels not written by
    # index_copy_ stay at +inf.
    id_map = torch.full_like(noisy_depth, float('inf'))
    id_map.view(-1).index_copy_(0, index[valid], (warped_noisy - pose[0, 2])[valid])

    im1.set_array(warped.cpu()[0])
    im2.set_array(depth_occ.cpu()[0])
    im4.set_array(warped_noisy.cpu()[0])
    im5.set_array(id_map.cpu()[0])
    im6.set_array(warped_back_noisy.cpu()[0])
    im7.set_array(depth_occ_noisy.cpu()[0])
    im8.set_array(depth_occ_noisy2.cpu()[0])
    return (im1, im2, im3, im4, im5, im6, im7, im8)

# Static preview: draw one representative frame (i = 37) as a regular figure.
animate(37)
fig.subplots_adjust(hspace=0.1, wspace=0.1)
plt.show()

# Optional video. `frames` bounds the animation to len(x) frames. Without it
# FuncAnimation counts up indefinitely, and matplotlib >= 3.7 no longer stops
# saving after 100 frames, so `animate` would index past the end of `x`.
ani = animation.FuncAnimation(fig, animate, frames=len(x), interval=50, blit=True)

HTML(ani.to_html5_video())