
In [None]:
# Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved.

# Demo of Voxel grid training
In this demo, we train an Implicitron model with a voxel grid implicit function on one scene from the Blender synthetic dataset.

## 0. Install and import modules

Ensure `torch` and `torchvision` are installed (on Colab, they already are). If `pytorch3d` is not installed, install it using the following cell:


In [None]:
import os
import sys
import torch
need_pytorch3d=False
try:
    import pytorch3d
except ModuleNotFoundError:
    need_pytorch3d=True
if need_pytorch3d:
    if torch.__version__.startswith("1.12.") and sys.platform.startswith("linux"):
        # We try to install PyTorch3D via a released wheel.
        pyt_version_str=torch.__version__.split("+")[0].replace(".", "")
        version_str="".join([
            f"py3{sys.version_info.minor}_cu",
            torch.version.cuda.replace(".",""),
            f"_pyt{pyt_version_str}"
        ])
        !pip install fvcore iopath
        !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html
    else:
        # We try to install PyTorch3D from source.
        !curl -LO https://github.com/NVIDIA/cub/archive/1.10.0.tar.gz
        !tar xzf 1.10.0.tar.gz
        os.environ["CUB_HOME"] = os.getcwd() + "/cub-1.10.0"
        !pip install 'git+https://github.com/facebookresearch/pytorch3d.git@stable'
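
The wheel lookup in the cell above depends on a tag built from the Python minor version, the CUDA version and the PyTorch version. The following is a minimal sketch of that string construction as a standalone function; `wheel_tag` is a hypothetical helper name introduced here for illustration, not part of PyTorch3D.

```python
def wheel_tag(python_minor: int, cuda_version: str, torch_version: str) -> str:
    """Build the directory tag the install cell uses to locate a prebuilt wheel."""
    # Drop any local build suffix such as "+cu113", then remove the dots.
    pyt_version_str = torch_version.split("+")[0].replace(".", "")
    return f"py3{python_minor}_cu{cuda_version.replace('.', '')}_pyt{pyt_version_str}"


# Example inputs chosen for illustration only.
print(wheel_tag(9, "11.3", "1.12.1+cu113"))
```

If no wheel exists for the resulting tag, `pip` will fail to find a matching distribution, which is why the cell only takes this path for the PyTorch versions it checks for.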

Ensure `omegaconf` is installed. If it is not, run this cell. (On Colab this is needed; despite the warning, there is no need to restart the runtime.)

In [None]:
!pip install omegaconf

Download one scene from the Blender synthetic dataset.

In [None]:
!wget https://dl.fbaipublicfiles.com/pytorch3d/data/implicitron_tutorial/nerf-synthetic-chair.tar.gz
!tar -xzf nerf-synthetic-chair.tar.gz

In [None]:
import itertools
import logging
import time
from collections import defaultdict
from pathlib import Path
from typing import Iterator, Tuple

import matplotlib.animation as animation
import matplotlib.pyplot as plt
import numpy as np
import torch
import tqdm
from IPython.display import HTML
from omegaconf import OmegaConf
from PIL import Image
from pytorch3d.implicitron.dataset.dataset_base import FrameData
from pytorch3d.implicitron.dataset.data_loader_map_provider import SequenceDataLoaderMapProvider
from pytorch3d.implicitron.dataset.blender_dataset_map_provider import BlenderDatasetMapProvider
from pytorch3d.implicitron.models.generic_model import GenericModel
from pytorch3d.implicitron.models.implicit_function.base import ImplicitFunctionBase
from pytorch3d.implicitron.models.renderer.base import EvaluationMode
from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args, registry, remove_unused_components
from pytorch3d.implicitron.tools.stats import Stats
from pytorch3d.renderer import RayBundle
from pytorch3d.renderer.implicit.renderer import VolumeSampler
from pytorch3d.structures import Volumes
from pytorch3d.vis.plotly_vis import plot_batch_individually, plot_scene

In [None]:
torch.set_printoptions(sci_mode=False)

`BlenderDatasetMapProvider` in Implicitron reads data stored on disk in the layout of the Blender synthetic dataset.

In [None]:
provider = BlenderDatasetMapProvider(
    base_dir="./chair",
    object_name="chair",
)

In [None]:
dataset_map = provider.get_dataset_map()

We define the configuration for the model in YAML. This includes choosing the components we wish to use (e.g. implicit function, raysampler) as well as setting the parameters of those components.

In [None]:
model_yaml = """\
model_factory_ImplicitronModelFactory_args:
  model_GenericModel_args:
    render_image_width: 800
    render_image_height: 800
    mask_images: False
    mask_threshold: 0
    n_train_target_views: -1
    num_passes: 1
    chunk_size_grid: 16000
    implicit_function_class_type: VoxelGridImplicitFunction
    implicit_function_VoxelGridImplicitFunction_args:
      scaffold_calculating_epochs:
        - 1
        - 2
        - 3
        - 4
        - 5      
      #scaffold_resolution: Tuple[int, int, int] = (128, 128, 128)
      scaffold_empty_space_threshold: 0.01
      scaffold_filter_points: false
      volume_cropping_epochs:
        - 1
        - 2
        - 3
        - 4
        - 5

      voxel_grid_density_args:
        param_groups:
          self: "grids"
        voxel_grid_class_type: VMFactorizedVoxelGrid
        voxel_grid_VMFactorizedVoxelGrid_args:
          n_features: 1
          n_components: 48
          basis_matrix: False
          resolution_changes:
            0:
              - 128
              - 128
              - 128
            3:
              - 256
              - 256
              - 256
            6:
              - 500
              - 500
              - 500
        extents:
          - 6.5
          - 6.5
          - 6.5
      voxel_grid_color_args:
        param_groups:
          self: "grids"
        voxel_grid_class_type: VMFactorizedVoxelGrid
        voxel_grid_VMFactorizedVoxelGrid_args:
          n_features: 27
          n_components: 144
          resolution_changes:
            0:
              - 128
              - 128
              - 128
            3:
              - 256
              - 256
              - 256
            6:
              - 500
              - 500
              - 500
        extents:
          - 6.5
          - 6.5
          - 6.5
      harmonic_embedder_xyz_density_args:
        n_harmonic_functions: 0
      harmonic_embedder_xyz_color_args:
        n_harmonic_functions: 2
      harmonic_embedder_dir_color_args:
        n_harmonic_functions: 2
      decoder_density_class_type: ElementwiseDecoder
      decoder_density_ElementwiseDecoder_args:
        operation: SOFTPLUS
        scale: 1
        shift: -5 # -10 ?
      decoder_color_class_type: MLPDecoder
      decoder_color_MLPDecoder_args:
        network_args:
          n_layers: 2
          output_dim: 3
          hidden_dim: 128
          last_activation: SIGMOID
          last_layer_bias_init: 0.0
          use_xavier_init: false

    raysampler_class_type: NearFarRaySampler
    raysampler_NearFarRaySampler_args:
      n_rays_total_training: 1024
      n_rays_per_image_sampled_from_mask: null
      n_pts_per_ray_training: 512 #64 norm(resolution)/0.5
      min_depth: 2.0
      max_depth: 6.0
    renderer_MultiPassEmissionAbsorptionRenderer_args:
      density_noise_std_train: 0.0
      n_pts_per_ray_fine_training: 128
      n_pts_per_ray_fine_evaluation: 128
      raymarcher_EmissionAbsorptionRaymarcher_args:
        blend_output: false
        background_opacity: 0.0
        replicate_last_interval: true
        bg_color:
        - 0.0
    loss_weights:
      loss_rgb_mse: 1.0
      loss_prev_stage_rgb_mse: 0.0
      loss_mask_bce: 0.0
      loss_prev_stage_mask_bce: 0.0
      loss_autodecoder_norm: 0.00

    # suppress progress bars
    tqdm_trigger_threshold: 19000 
"""

In [None]:
model_cfg = OmegaConf.create(model_yaml)
model_cfg_full = OmegaConf.merge(get_default_args(GenericModel), model_cfg.model_factory_ImplicitronModelFactory_args.model_GenericModel_args)
gm = GenericModel(**model_cfg_full)
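
The `resolution_changes` entries in the YAML above are keyed by integers. Assuming those keys are the epochs at which the voxel grid is resized (an assumption about the voxel grid's behavior, not something this notebook states), the resolution in effect at a given epoch can be sketched as below. `RESOLUTION_CHANGES` and `resolution_at` are hypothetical names used only for this illustration.

```python
from typing import Dict, List

# Copied from the YAML above; assumed to map an epoch to the grid size used from then on.
RESOLUTION_CHANGES: Dict[int, List[int]] = {
    0: [128, 128, 128],
    3: [256, 256, 256],
    6: [500, 500, 500],
}


def resolution_at(epoch: int, changes: Dict[int, List[int]]) -> List[int]:
    """Return the grid size set by the latest change at or before `epoch`."""
    applicable = [e for e in changes if e <= epoch]
    if not applicable:
        raise ValueError(f"no resolution defined for epoch {epoch}")
    return changes[max(applicable)]


for epoch in range(6):
    print(epoch, resolution_at(epoch, RESOLUTION_CHANGES))
```

Under this reading, a run that stops after epoch 5 would never reach the 500-cell grid.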

In [None]:
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    gm.to(device)
    assert next(gm.parameters()).is_cuda
else:
    print("CPU ONLY")
    device = torch.device("cpu")

In [None]:
gm.train();

In [None]:
for name, param in gm.named_parameters():
    print(f"{name.ljust(75)} {tuple(param.shape)}")

In [None]:
def make_optimizer():
    decoder_params = [j for i,j in gm.named_parameters() if "voxel_grid" not in i if j.requires_grad]
    grid_params = [j for i,j in gm.named_parameters() if "voxel_grid" in i if j.requires_grad]
    p_groups = [
        {"params": decoder_params, "lr": 0.001},
        {"params": grid_params, "lr": 0.02},
    ]
    lr = 0.001
    
    optimizer = torch.optim.Adam(p_groups, foreach=True, lr=lr, weight_decay=0.0, betas=[0.9, 0.999])
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lambda epoch: 0.1 ** (epoch / 3001),
        verbose=False,
    )
    return optimizer, scheduler


optimizer, scheduler = make_optimizer()
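
`LambdaLR` multiplies each parameter group's initial learning rate by the value of the lambda, evaluated at the number of `scheduler.step()` calls made so far. The sketch below reproduces that arithmetic in plain Python for the two groups defined in `make_optimizer`. The group labels `"decoder"` and `"grids"` and the helper names are made up for this illustration.

```python
BASE_LRS = {"decoder": 0.001, "grids": 0.02}  # per-group rates from make_optimizer
DECAY_STEPS = 3001  # denominator used in the LambdaLR lambda


def lr_factor(n_scheduler_steps: int) -> float:
    """Multiplier applied to every group's base rate after the given number of steps."""
    return 0.1 ** (n_scheduler_steps / DECAY_STEPS)


def lrs_after(n_scheduler_steps: int) -> dict:
    """Learning rate of each group after the given number of scheduler steps."""
    factor = lr_factor(n_scheduler_steps)
    return {name: base * factor for name, base in BASE_LRS.items()}


for steps in (0, 1000, 3001, 6000):
    print(steps, lrs_after(steps))
```

Both rates therefore fall by a factor of ten every 3001 scheduler steps, while the ratio between the grid and decoder rates stays fixed at 20.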

Make a dataloader which provides all the training data (as a single `FrameData`) in every iteration, with 1000 iterations counting as one "epoch".

In [None]:
class WholeDatasetLoader:
    """
    Loads the whole dataset onto the device and provides an iterator over it.
    Returns `n_batches_in_epoch` batches, where one batch is the whole
    dataset.

    Members:
        train_dataset: dataset to load
        n_batches_in_epoch: how many batches to have in an epoch.
        device: torch.device on which to load the dataset.
    """

    def __init__(
        self, train_dataset, n_batches_in_epoch: int, device: torch.device
    ) -> None:
        self.n_batches_in_epoch = n_batches_in_epoch
        # pyre-ignore[6]
        train_data = [train_dataset[i] for i in range(len(train_dataset))]
        self.train_dataset_batch = train_data[0].collate(train_data).to(device)

    def __iter__(self) -> Iterator[FrameData]:
        return itertools.repeat(self.train_dataset_batch, self.n_batches_in_epoch)

    def __len__(self) -> int:
        return self.n_batches_in_epoch

In [None]:
loader = WholeDatasetLoader(dataset_map.train, 1000, device)
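
The loader above relies on `itertools.repeat` to hand out the same pre-collated batch a fixed number of times. A minimal sketch of that pattern without any PyTorch dependency follows; `RepeatLoader` is a hypothetical stand-in that holds an arbitrary Python object instead of a `FrameData`.

```python
import itertools
from typing import Any, Iterator


class RepeatLoader:
    """Hypothetical stand-in for WholeDatasetLoader that holds any Python object."""

    def __init__(self, batch: Any, n_batches_in_epoch: int) -> None:
        self.batch = batch
        self.n_batches_in_epoch = n_batches_in_epoch

    def __iter__(self) -> Iterator[Any]:
        # A fresh bounded iterator on every call, so each epoch can loop again.
        return itertools.repeat(self.batch, self.n_batches_in_epoch)

    def __len__(self) -> int:
        return self.n_batches_in_epoch


demo_loader = RepeatLoader({"frames": [0, 1, 2]}, n_batches_in_epoch=3)
first_pass = list(demo_loader)
second_pass = list(demo_loader)
print(len(first_pass), all(item is demo_loader.batch for item in first_pass))
```

Because every iteration yields the same object rather than a copy, the data is moved to the device once, in the constructor, and never reloaded.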

Collect, for each epoch, the callbacks which the model's submodules ask to have run at the start of that epoch.

In [None]:
subscribers = defaultdict(list)
def _collect_epoch_subscribers(module: torch.nn.Module) -> None:
    subscribe_to_epochs = getattr(module, "subscribe_to_epochs", None)
    if callable(subscribe_to_epochs):
        wanted_epochs, apply_func = subscribe_to_epochs()
        for epoch in wanted_epochs:
            # pyre-ignore[16]
            subscribers[epoch].append(apply_func)
gm.apply(_collect_epoch_subscribers);
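
The collection step above only requires that a module expose a callable `subscribe_to_epochs` returning a list of epochs and a function to call. The sketch below shows the same protocol with plain Python objects instead of `torch.nn.Module` instances. `ToyGrid`, `ToyDecoder` and `demo_subscribers` are hypothetical names, and the callback's return type is simplified for the illustration.

```python
from collections import defaultdict
from typing import Callable, DefaultDict, List, Tuple


class ToyGrid:
    """Hypothetical component that wants to act at the start of epochs 3 and 6."""

    def __init__(self) -> None:
        self.seen: List[int] = []

    def subscribe_to_epochs(self) -> Tuple[List[int], Callable[[int], None]]:
        return [3, 6], self.seen.append


class ToyDecoder:
    """Hypothetical component with no epoch callbacks."""


demo_subscribers: DefaultDict[int, List[Callable[[int], None]]] = defaultdict(list)
components = [ToyGrid(), ToyDecoder()]
for component in components:
    subscribe = getattr(component, "subscribe_to_epochs", None)
    if callable(subscribe):
        wanted_epochs, apply_func = subscribe()
        for epoch in wanted_epochs:
            demo_subscribers[epoch].append(apply_func)

# Run five "epochs": only the callback registered for epoch 3 fires.
for epoch in range(5):
    for callback in demo_subscribers[epoch]:
        callback(epoch)

print(components[0].seen)
```

Using a `defaultdict(list)` means epochs with no subscribers simply yield an empty list, so the training loop needs no special case for them.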

The `Stats` object keeps running averages of the losses for us.

In [None]:
stats = Stats(gm.log_vars)

In [None]:
def train_epoch(epoch):
    global optimizer, preds, scheduler
    print(f"doing epoch {epoch}")
    our_subscribers = subscribers[epoch]
    change = len(our_subscribers) > 0
    for subscriber in our_subscribers:
        subscriber(epoch)
    if change:
        optimizer, scheduler = make_optimizer()
        for _ in range(epoch):
            scheduler.step()
    stats.new_epoch()
    t_start = time.time()
    for it, net_input in enumerate(loader):  # enumerate(tqdm.tqdm(loader)):
        net_input = net_input.to(device)
        optimizer.zero_grad()
        preds = gm(**{**net_input, "evaluation_mode": EvaluationMode.TRAINING})
        preds["objective"].backward()
        optimizer.step()
        scheduler.step()
        stats.update(preds, time_start=t_start)
        if it % 100 == 0:
            # print(f"objective: {float(preds['objective']):.5f}, rgb_psnr: {float(preds['loss_rgb_psnr']):.5f}")
            stats.print()

### Do a little bit of training and visualise

In [None]:
train_epoch(0)

In [None]:
train_data_collated = [FrameData.collate([frame.to(device)]) for frame in dataset_map.train]
test_data_collated = [FrameData.collate([frame.to(device)]) for frame in dataset_map.test]


In [None]:
def to_numpy_image(image):
    # Takes an image of shape (C, H, W) in [0,1], where C=3 or 1,
    # to a numpy uint8 image of shape (H, W, 3)
    return (image * 255).to(torch.uint8).permute(1, 2, 0).detach().cpu().expand(-1, -1, 3).numpy()
def resize_image(image, output_resolution):
    if output_resolution is None:
        return image
    # Takes images of shape (B, C, H, W) to (B, C, output_resolution, output_resolution)
    return torch.nn.functional.interpolate(image, size=(output_resolution, output_resolution))

def image_data(collated_frames, output_resolution=100):
    gm.eval()
    images = []
    expected = []
    masks = []
    masks_expected = []
    psnrs = []
    for frame in tqdm.tqdm(collated_frames):
        with torch.inference_mode():
            out = gm(**frame, evaluation_mode=EvaluationMode.EVALUATION)
            rendered_image = torch.clamp(out["images_render"],0,1)
            
        image_rgb = to_numpy_image(resize_image(rendered_image, output_resolution)[0])
        mask = to_numpy_image(resize_image(out["masks_render"], output_resolution)[0])
        expd = to_numpy_image(resize_image(frame.image_rgb, output_resolution)[0])
        mask_expected = to_numpy_image(resize_image(frame.fg_probability, output_resolution)[0])

        images.append(image_rgb)
        expected.append(expd)
        masks.append(mask)
        masks_expected.append(mask_expected)
        psnrs.append(float(out["loss_rgb_psnr"]))
    return [images, expected, masks, masks_expected, psnrs]

def make_mosaic(images, expected, masks, masks_expected, n_rows=1):
    images_to_display = [images.copy(), expected.copy(), masks.copy(), masks_expected.copy()]
    n_images = len(images)
    blank_image = images[0] * 0
    n_per_row = 1+(n_images-1)//n_rows
    for _ in range(n_per_row*n_rows - n_images):
        for group in images_to_display:
            group.append(blank_image)

    images_to_display_listed = [[[i] for i in j] for j in images_to_display]
    split = []
    for row in range(n_rows):
        for group in images_to_display_listed:
            split.append(group[row*n_per_row:(row+1)*n_per_row])  

    return Image.fromarray(np.block(split))
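
In `make_mosaic`, the expression `1+(n_images-1)//n_rows` is a ceiling division, and the grid is then padded with blank images until it is rectangular. The sketch below isolates that arithmetic; `mosaic_layout` is a hypothetical helper name and the image counts are chosen only for illustration.

```python
from typing import Tuple


def mosaic_layout(n_images: int, n_rows: int) -> Tuple[int, int]:
    """Return (images per row, number of blank padding images)."""
    n_per_row = 1 + (n_images - 1) // n_rows  # ceiling of n_images / n_rows
    n_blank = n_per_row * n_rows - n_images
    return n_per_row, n_blank


print(mosaic_layout(5, 1))
print(mosaic_layout(5, 2))
```

Each of these rows is emitted four times in the final mosaic, once per group: rendered images, expected images, rendered masks and expected masks.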


In [None]:
train_image_data = image_data(train_data_collated[::20])

In [None]:
#TRAIN
print("train psnrs", train_image_data[4])
make_mosaic(*train_image_data[:4])
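
To help read the PSNR values printed above: for images with values in [0, 1], PSNR is conventionally defined as -10 log10(MSE), so a higher value means a lower mean squared error. The sketch below assumes this standard definition; the exact computation inside Implicitron is not shown in this notebook, and the helper names are hypothetical.

```python
import math
from typing import Sequence


def mean_squared_error(a: Sequence[float], b: Sequence[float]) -> float:
    """Mean squared difference between two equally long sequences of pixel values."""
    if len(a) != len(b):
        raise ValueError("inputs must have the same length")
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def psnr_from_mse(mse: float) -> float:
    """PSNR in dB for signals whose peak value is 1."""
    return -10.0 * math.log10(mse)


rendered = [0.5, 0.5, 0.5, 0.5]
expected = [0.6, 0.4, 0.6, 0.4]
print(psnr_from_mse(mean_squared_error(rendered, expected)))
```

Under this definition, every 10 dB gained corresponds to a tenfold reduction in mean squared error.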

In [None]:
test_image_data = image_data(test_data_collated[::40])

In [None]:
#TEST
print("test psnrs", test_image_data[4])
make_mosaic(*test_image_data[:4], n_rows=1)

### Do more training and visualise

In [None]:
gm.train()
for epoch in range(1, 6):
    train_epoch(epoch)

In [None]:
train_image_data = image_data(train_data_collated[::20])

In [None]:
#TRAIN
print("train psnrs", train_image_data[4])
make_mosaic(*train_image_data[:4])


In [None]:
test_image_data = image_data(test_data_collated[::40])

In [None]:
#TEST
print("test psnrs", test_image_data[4])
make_mosaic(*test_image_data[:4], n_rows=1)

In [None]:
one_full_train_image = image_data(train_data_collated[:1], output_resolution=None)

In [None]:
make_mosaic(*one_full_train_image[:4])