In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from functools import partial
import torch.nn.functional as F

import thre3d_atom.modules.thre3d_singan.volumetric_model as volume_model
from thre3d_atom.utils.misc import batchify
from thre3d_atom.rendering.volumetric.voxels import FeatureGrid, render_feature_grid, get_voxel_size_from_scene_bounds_and_hem_rad
from thre3d_atom.modules.thre3d_singan.utils import render_image_in_chunks
from thre3d_atom.utils.imaging_utils import pose_spherical, scale_camera_intrinsics


In [None]:
# ==========================================================================================
# tweakable hyperparameters for the notebook
# ==========================================================================================
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Saved model file that is loaded further down in this notebook.
# NOTE(review): absolute, machine-specific path -- must be edited before running elsewhere.
model_path = Path("/home/animesh/work/ucl/projects/direct_supervision_thre3d_singan/vol_mod/fg_mlp" +
             "/localized_feats/fish/model_final.pth")

# Single camera pose that is reused by every render in this notebook.
# NOTE(review): the arguments are presumably two angles and a radius -- confirm against pose_spherical.
camera_pose = pose_spherical(45, -30, 15.0)

# ===========================================================================================

In [None]:
# Display the camera pose built in the configuration cell as a quick sanity check.
camera_pose

In [None]:
# Rebuild the volumetric model from the file at `model_path`. The call returns the
# model together with `extra_info`, which later cells index with string keys.
vol_mod, extra_info = volume_model.create_vol_mod_from_saved_model(model_path)

In [None]:
# Display the extra information that was returned alongside the model.
extra_info

In [None]:
# Pull the scene description that was returned alongside the model.
scene_bounds = extra_info["scene_bounds"]
# The stored intrinsics are scaled by a factor of 0.5 before being used for rendering.
# NOTE(review): presumably this halves the rendered image resolution -- confirm against scale_camera_intrinsics.
camera_intrinsics = scale_camera_intrinsics(extra_info["camera_intrinsics"], 0.5)
hem_rad = extra_info["hemispherical_radius"]

In [None]:
# process the volumetric model:
# NOTE(review): what the processing step does is not visible in this notebook -- see
# process_hybrid_rgba_volumetric_model. The result is bound back to `vol_mod`, so
# re-running this cell on its own feeds the already-processed model in again.
vol_mod = volume_model.process_hybrid_rgba_volumetric_model(vol_mod)

### Decoding the feature grid only at feature grid locations:

In [None]:
# Display the shape of the feature grid; the next cell unpacks it as (x, y, z, channels).
vol_mod.feature_grid.features.shape

In [None]:
# Run the render MLP once per grid cell, turning the feature grid into an
# explicit RGBA grid with the same spatial size.
features = vol_mod.feature_grid.features
x_dim, y_dim, z_dim, nc = features.shape
flat_features = features.reshape(x_dim * y_dim * z_dim, nc)

with torch.no_grad():
    # every cell is decoded with the same fixed (-1, -1, -1) view direction
    fixed_view_dir = torch.full((flat_features.shape[0], 3), -1.0, device=device)
    mlp_input = torch.cat([flat_features, fixed_view_dir], dim=-1)

    # the MLP is applied to chunks of 512 rows and the outputs are concatenated again
    decode_in_chunks = batchify(
        processor_fn=vol_mod.render_mlp,
        collate_fn=partial(torch.cat, dim=0),
        chunk_size=512,
        verbose=True,
    )
    rgba_values = decode_in_chunks(mlp_input)

    # first three channels are colour, the remaining ones are passed on as alpha;
    # both go through the model's own post-processing callbacks
    rgb_values = vol_mod._colour_producer(rgba_values[..., :3])
    a_values = rgba_values[..., 3:]
    a_values = vol_mod._transmittance_behaviour(a_values, torch.ones_like(a_values))
    rgba_values = torch.cat([rgb_values, a_values], dim=-1)

    # back from a flat list of cells to the (x, y, z, channels) layout
    rgba_grid = rgba_values.reshape(x_dim, y_dim, z_dim, -1)

In [None]:
# Display the shape of the decoded grid (spatial dims of the feature grid, decoded channels last).
rgba_grid.shape

In [None]:
# Wrap the decoded values in a non-tunable FeatureGrid; the channel axis is moved
# to the front before it is handed to the constructor.
# NOTE(review): 128 is hard-coded and presumably has to match the spatial size of
# `rgba_grid` -- confirm against get_voxel_size_from_scene_bounds_and_hem_rad and
# consider deriving it from the grid shape.
rgba_feature_grid = FeatureGrid(
    features=rgba_grid.permute(3, 0, 1, 2),
    voxel_size=get_voxel_size_from_scene_bounds_and_hem_rad(hem_rad, 128, scene_bounds),
    tunable=False,
)

### Render the decoded rgba-feature-grid

In [None]:
# Render the decoded RGBA grid from `camera_pose`, reusing the model's ray-chunk
# and samples-per-ray settings. No processor MLP is passed here, and both
# post-processing callbacks only clip their first argument to [0, 1]; noise and
# sample perturbation are switched off.
# NOTE(review): presumably the grid values are then used directly as colour/alpha --
# confirm in render_image_in_chunks.
rendered_output_rgba = render_image_in_chunks(
    cam_intrinsics=camera_intrinsics,
    camera_pose=camera_pose,
    num_rays_chunk=vol_mod._render_params.num_rays_chunk,
    num_samples_per_ray=vol_mod._render_params.num_samples_per_ray,
    feature_grid=rgba_feature_grid,
    scene_bounds=scene_bounds,
    density_noise_std=0.0,
    perturb_sampled_points=False,
    raw2alpha=lambda x, y: torch.clip(x, 0.0, 1.0),
    colour_producer=lambda x: torch.clip(x, 0.0, 1.0),
    verbose=True,
)

In [None]:
# Reference render from the same camera: the model's own feature grid is passed
# together with the render MLP, and the model's own transmittance and colour
# callbacks are used. Noise and sample perturbation are switched off, as in the
# decoded-RGBA render.
rendered_output_fg = render_image_in_chunks(
    cam_intrinsics=camera_intrinsics,
    camera_pose=camera_pose,
    num_rays_chunk=vol_mod._render_params.num_rays_chunk,
    num_samples_per_ray=vol_mod._render_params.num_samples_per_ray,
    feature_grid=vol_mod.feature_grid,
    processor_mlp=vol_mod.render_mlp,
    scene_bounds=scene_bounds,
    density_noise_std=0.0,
    perturb_sampled_points=False,
    raw2alpha=vol_mod._transmittance_behaviour,
    colour_producer=vol_mod._colour_producer,
    verbose=True,
)

In [None]:
# Figure 1: the image rendered from the feature grid through the render MLP.
fig = plt.figure()
colour = rendered_output_fg.colour
plt.imshow(colour.detach().cpu().numpy())
plt.title("feature-grid_render")
plt.savefig(
    f"/home/animesh/feature_grid_render.png",
    dpi=1000,
    facecolor=fig.get_facecolor(),
    edgecolor="none",
)
plt.show()

# Figure 2: the image rendered from the decoded RGBA grid.
fig = plt.figure()
colour = rendered_output_rgba.colour
plt.imshow(colour.detach().cpu().numpy())
plt.title("decoded RGBA render")
plt.savefig(
    f"/home/animesh/feature_location_decoded_rgba_render.png",
    dpi=1000,
    facecolor=fig.get_facecolor(),
    edgecolor="none",
)
plt.show()


### Decoding the feature grid at twice the resolution of the feature-grid:

In [None]:
# Upsample the feature grid by a factor of two along every spatial axis and run
# the render MLP on each upsampled cell to obtain a higher-resolution RGBA grid.
features = vol_mod.feature_grid.features
x_dim, y_dim, z_dim, nc = features.shape

# F.interpolate works on a batched, channel-first volume, so move the channels
# to the front and add a batch axis, then undo both afterwards
channel_first = features.permute(3, 0, 1, 2)[None, ...]
upsampled = F.interpolate(
    channel_first, scale_factor=2, mode="trilinear", align_corners=False
)
interpolated_features = upsampled[0].permute(1, 2, 3, 0)

flat_features = interpolated_features.reshape(-1, nc)

with torch.no_grad():
    # every cell is decoded with the same fixed (-1, -1, -1) view direction
    fixed_view_dir = torch.full((flat_features.shape[0], 3), -1.0, device=device)
    mlp_input = torch.cat([flat_features, fixed_view_dir], dim=-1)

    # the MLP is applied to chunks of 512 rows and the outputs are concatenated again
    decode_in_chunks = batchify(
        processor_fn=vol_mod.render_mlp,
        collate_fn=partial(torch.cat, dim=0),
        chunk_size=512,
        verbose=True,
    )
    rgba_values = decode_in_chunks(mlp_input)

    # first three channels are colour, the remaining ones are passed on as alpha;
    # both go through the model's own post-processing callbacks
    rgb_values = vol_mod._colour_producer(rgba_values[..., :3])
    a_values = rgba_values[..., 3:]
    a_values = vol_mod._transmittance_behaviour(a_values, torch.ones_like(a_values))
    rgba_values = torch.cat([rgb_values, a_values], dim=-1)

    # back to a grid with twice the original size along each spatial axis
    rgba_grid = rgba_values.reshape(2 * x_dim, 2 * y_dim, 2 * z_dim, -1)

In [None]:
# Display the shape of the decoded grid; each spatial axis is twice the feature grid's.
rgba_grid.shape

In [None]:
# Wrap the 2x-resolution decoded values in a non-tunable FeatureGrid (channel
# axis moved to the front for the constructor).
# NOTE(review): 256 is hard-coded and presumably has to match the spatial size of
# `rgba_grid` (twice the feature grid) -- confirm and consider deriving it.
rgba_feature_grid = FeatureGrid(
    features=rgba_grid.permute(3, 0, 1, 2),
    voxel_size=get_voxel_size_from_scene_bounds_and_hem_rad(hem_rad, 256, scene_bounds),
    tunable=False,
)

# Render the decoded grid from `camera_pose`. No processor MLP is passed, and
# both post-processing callbacks only clip their first argument to [0, 1].
rendered_output_rgba = render_image_in_chunks(
    cam_intrinsics=camera_intrinsics,
    camera_pose=camera_pose,
    num_rays_chunk=vol_mod._render_params.num_rays_chunk,
    num_samples_per_ray=vol_mod._render_params.num_samples_per_ray,
    feature_grid=rgba_feature_grid,
    scene_bounds=scene_bounds,
    density_noise_std=0.0,
    perturb_sampled_points=False,
    raw2alpha=lambda x, y: torch.clip(x, 0.0, 1.0),
    colour_producer=lambda x: torch.clip(x, 0.0, 1.0),
    verbose=True,
)

# plot the Decoded RGBA render:
# The figure is written to disk before it is shown.
# NOTE(review): the output path is absolute and machine-specific.
colour = rendered_output_rgba.colour
fig = plt.figure()
plt.title("decoded RGBA render")
plt.imshow(colour.detach().cpu().numpy())
plt.savefig(f"/home/animesh/2x_resolution_fg_decoded_rgba_render.png", dpi=1000, facecolor=fig.get_facecolor(), edgecolor="none")
plt.show()

### Decoding the feature-grid at randomly jittered (off-grid) locations for viewing:

In [None]:
# Scratch check: ten evenly spaced values covering [-1, 1], the coordinate range
# that the sampling code in the next cell works in.
torch.linspace(-1, 1, 10)

In [None]:
# Decode the feature grid at off-grid locations: one sample per voxel, placed at
# the voxel's lower corner plus a random offset that is shared by all voxels.
features = vol_mod.feature_grid.features
x_dim, y_dim, z_dim, nc = features.shape
# extent of one voxel along each axis in grid_sample's normalised [-1, 1] coordinates
x_size, y_size, z_size = (2 / (x_dim - 1)), (2 / (y_dim - 1)), (2 / (z_dim - 1))

# Lower corners of all voxels. Every axis has to stop one voxel short of +1
# measured in its OWN voxel size; using the x size for all three axes is only
# correct for cubic grids and lets corner + jitter leave [-1, 1] otherwise.
# torch.meshgrid's default "ij" indexing lays the result out as (x, y, z, 3).
points = torch.stack(
    torch.meshgrid(
        torch.linspace(-1, 1 - x_size, x_dim - 1, device=device),
        torch.linspace(-1, 1 - y_size, y_dim - 1, device=device),
        torch.linspace(-1, 1 - z_size, z_dim - 1, device=device),
    ),
    dim=-1,
)

# one random offset in [0, voxel_size) per axis, broadcast over all points
jitter_offset = (
    torch.rand(size=(1, 1, 1, 3), device=device)
    * torch.tensor(
        [x_size, y_size, z_size], dtype=torch.float32, device=device
    )[None, None, None, :]
)

# add the batch axis that grid_sample expects
jittered_points = (points + jitter_offset)[None, ...]

# grid_sample reads each grid entry as (x, y, z) with x indexing the LAST input
# axis, hence the features are permuted to (1, nc, z, y, x) before sampling
point_features = F.grid_sample(
    features[None, ...].permute(0, 4, 3, 2, 1),
    jittered_points,
    align_corners=True,
)

# back to channel-last and flattened to one row per sampled point
flat_features = point_features[0].permute(1, 2, 3, 0).reshape(-1, nc)

with torch.no_grad():
    # a single random unit view direction, shared by all points; its last
    # component is forced to be non-positive (the norm is unaffected by the sign flip)
    random_view_dir = torch.rand(1, 3, device=device)
    random_view_dir /= random_view_dir.norm(dim=-1, keepdim=True)
    random_view_dir[..., -1] = -torch.abs(random_view_dir[..., -1])
    random_view_dir = random_view_dir.repeat(flat_features.shape[0], 1)

    # the MLP is applied to chunks of 512 rows and the outputs are concatenated again
    rgba_values = batchify(
        processor_fn=vol_mod.render_mlp,
        collate_fn=partial(torch.cat, dim=0),
        chunk_size=512,
        verbose=True,
    )(torch.cat([flat_features, random_view_dir], dim=-1))
    # first three channels are colour, the remaining ones are passed on as alpha;
    # both go through the model's own post-processing callbacks
    rgb_values, a_values = rgba_values[..., :3], rgba_values[..., 3:]
    rgb_values = vol_mod._colour_producer(rgb_values)
    a_values = vol_mod._transmittance_behaviour(a_values, torch.ones_like(a_values))
    rgba_values = torch.cat([rgb_values, a_values], dim=-1)

    # one decoded value per voxel, i.e. one fewer than the grid size along each axis
    rgba_grid = rgba_values.reshape(x_dim - 1, y_dim - 1, z_dim - 1, -1)

In [None]:
# Wrap the values decoded at the jittered locations in a non-tunable FeatureGrid
# (channel axis moved to the front for the constructor).
# NOTE(review): 127 is hard-coded and presumably has to match the spatial size of
# `rgba_grid` (one less than the feature grid) -- confirm and consider deriving it.
rgba_feature_grid = FeatureGrid(
    features=rgba_grid.permute(3, 0, 1, 2),
    voxel_size=get_voxel_size_from_scene_bounds_and_hem_rad(hem_rad, 127, scene_bounds),
    tunable=False,
)

# Render the decoded grid from `camera_pose`. No processor MLP is passed, and
# both post-processing callbacks only clip their first argument to [0, 1].
rendered_output_rgba = render_image_in_chunks(
    cam_intrinsics=camera_intrinsics,
    camera_pose=camera_pose,
    num_rays_chunk=vol_mod._render_params.num_rays_chunk,
    num_samples_per_ray=vol_mod._render_params.num_samples_per_ray,
    feature_grid=rgba_feature_grid,
    scene_bounds=scene_bounds,
    density_noise_std=0.0,
    perturb_sampled_points=False,
    raw2alpha=lambda x, y: torch.clip(x, 0.0, 1.0),
    colour_producer=lambda x: torch.clip(x, 0.0, 1.0),
    verbose=True,
)

# plot the Decoded RGBA render:
# The figure is written to disk before it is shown.
# NOTE(review): the output path is absolute and machine-specific, and the file
# name "4.png" does not say which experiment produced it.
colour = rendered_output_rgba.colour
fig = plt.figure()
plt.title("decoded RGBA render")
plt.imshow(colour.detach().cpu().numpy())
plt.savefig(f"/home/animesh/4.png", dpi=1000, facecolor=fig.get_facecolor(), edgecolor="none")
plt.show()