# iNeRF test: optimizing a camera-pose correction against a pre-trained NeRF

In [34]:
import torch
import torch.nn as nn
from models.vanilla_nerf.model import NeRF
from PIL import Image
from pathlib import  Path as P
import json
import torchvision.transforms as transforms
import numpy as np
from datasets.ray_utils import get_ray_directions

## Helper functions

In [82]:
def load_json(json_fname):
    """Read ``json_fname`` and return the decoded JSON object (a dict for the metadata files used here)."""
    with open(json_fname) as handle:
        return json.load(handle)

def remove_model_prefix(input_dict):
    """
    Remove the "model." prefix from all key names in a dictionary.

    Args:
        input_dict (dict): The input dictionary.

    Returns:
        dict: A new dictionary with the "model." prefix removed from key names.
    """
    output_dict = {}
    for key, value in input_dict.items():
        # Check if the key starts with "model."
        if key.startswith("model."):
            # Remove the "model." prefix and add to the new dictionary
            new_key = key[len("model."):]
            output_dict[new_key] = value
        else:
            # If the key doesn't start with "model.", add it as is
            output_dict[key] = value
    return output_dict

def load_model_with_check(model, state_dict_dict):
    """
    Load a state_dict into a model non-strictly and report any key mismatches.

    Loading uses ``strict=False``, so mismatched keys do not raise. They are
    printed instead and returned to the caller.

    Args:
        model (torch.nn.Module): The model to which the state_dict should be loaded
            (modified in place).
        state_dict_dict (dict): The state_dict dictionary (parameter name -> tensor).

    Returns:
        model (torch.nn.Module): The same model object, with the matching weights loaded.
        missing_keys (list): Keys the model expects but that are absent from the state_dict
            (those parameters keep their current values).
        unexpected_keys (list): Keys present in the state_dict that the model does not have
            (they are ignored).
    """
    missing_keys, unexpected_keys = model.load_state_dict(state_dict_dict, strict=False)

    if missing_keys:
        print("Missing keys in the model's state_dict:")
        for key in missing_keys:
            print(key)
    if unexpected_keys:
        # These come from the supplied dict, not from the model.
        print("Unexpected keys found in the provided state_dict:")
        for key in unexpected_keys:
            print(key)

    return model, missing_keys, unexpected_keys

def select_element(value, seg, part_num=0):
    """
    Select the rows of ``value`` whose segmentation label equals ``part_num``.

    Args:
        value (torch.Tensor): Per-pixel data, e.g. shape [N, k].
        seg (torch.Tensor): Integer labels, shape [N, 1] (as built in this notebook) or [N].
        part_num (int): Label to keep.

    Returns:
        torch.Tensor: Rows of ``value`` where ``seg == part_num``, shape [n, k]
        (n may be 0), on ``value``'s device.
    """
    # Drop the singleton label dimension. A 1-D label tensor is used as-is,
    # because squeeze(dim=1) would raise on it.
    if seg.dim() > 1:
        seg = seg.squeeze(dim=1)

    # Build the mask on value's device so labels and data may live on different devices.
    mask = (seg == part_num).to(value.device)

    return value[mask]

## Load data for optimization

In [91]:
root_path = P("./data/laptop_art_same_pose/train/idx_1/")
device = 'cuda'
transform_meta = load_json(str(root_path / 'transforms.json'))
frame_id = 'r_0'
pose_np = np.array(transform_meta['frame'][frame_id])
rgb_pil = Image.open(str(root_path / 'rgb' / (frame_id + '.png')))
seg_pil = Image.open(str(root_path / 'seg' / (frame_id + '.png')))
pose = torch.Tensor(pose_np).to(device)

# Take the image size from the file itself (it was hard-coded as h, w = 640, 480).
# PIL reports (width, height).
w, h = rgb_pil.size

rgb = transforms.ToTensor()(rgb_pil).to(device)  # (C, H, W), values in [0, 1]
assert rgb.shape[0] == 4, f"expected an RGBA image, got {rgb.shape[0]} channels"
rgb = rgb.view(4, -1).permute(1, 0)  # (H*W, 4) RGBA
rgb = rgb[:, :3] * rgb[:, -1:] + (1 - rgb[:, -1:])  # composite onto a white background

# Labels become (H*W, 1) int64 on `device`.
# - Casting to int64 before subtracting prevents an unsigned image dtype from wrapping around.
# - `.type(torch.LongTensor)` (used previously) also silently moved the labels back to the CPU.
seg_np = np.array(seg_pil)
seg = torch.as_tensor(seg_np.astype(np.int64), device=device).reshape(-1, 1)
# Original note: part labels start with 2. After the shift, raw labels 0 and 1 both become 0.
seg = (seg - 1).clamp(min=0)
assert seg.shape[0] == rgb.shape[0], "segmentation and RGB pixel counts differ"

focal = transform_meta['focal']
directions = get_ray_directions(h, w, focal).view([-1, 3]).to(device)
assert directions.shape[0] == rgb.shape[0], "ray directions do not match the image size"

In [92]:
# Sanity check: largest part label after remapping.
torch.max(seg)

tensor(2)

## NeRF model setup

In [30]:
# load ckpt state_dict
ckpt_file = "results/laptop/nerf_laptop/last.ckpt"
ckpt_dict_model = torch.load(ckpt_file)['state_dict']
ckpt_dict = remove_model_prefix(ckpt_dict_model)
# initialize model and load pre-trained weights
model = NeRF()
model, _, _ = load_model_with_check(model, ckpt_dict)
model = model.to(device)

## Add a learnable view transform (pose correction)

In [76]:
class ViewTransform(nn.Module):
    """
    Learnable affine correction that is left-multiplied onto a camera pose.

    ``weight`` stores the top three rows of a 4x4 matrix as 12 values and
    starts at the identity. The bottom row is fixed to [0, 0, 0, 1].
    """

    def __init__(self):
        super().__init__()
        # Top three rows of the 4x4 identity, flattened to 12 values.
        identity_rows = torch.eye(4)[:3].flatten().clone()
        self.weight = nn.Parameter(identity_rows, requires_grad=True)

    def forward(self, input):
        """
        Apply the current correction to a pose.

        input: 4x4 c2w matrix
        Returns ``M @ input``, where M is the 4x4 matrix built from ``weight``.
        """
        top_rows = self.weight.view(3, 4)
        # new_tensor inherits weight's dtype and device.
        homogeneous_row = top_rows.new_tensor([[0, 0, 0, 1]])
        correction = torch.cat([top_rows, homogeneous_row], dim=0)
        return correction @ input

# Pose correction to optimize; it starts at the identity transform.
view_deform = ViewTransform().to(device=device)

## Configure the optimizer

In [77]:
# Adam over the pose-correction parameters only; the NeRF weights are not optimized.
# No weight decay: decay pulls the 3x4 matrix toward all zeros (a degenerate pose),
# not toward the identity it is initialized at.
optimizer = torch.optim.Adam(view_deform.parameters(), lr=1e-2)

## Forward function

In [93]:
from models.vanilla_nerf.model_nerfseg import get_rays_torch
from models.vanilla_nerf.helper import img2mse

directions = directions.to(device)

# Keep only the pixels (ray directions and target colours) that belong to one part.
part_num = 2  # label id after the remapping in the data-loading cell
selected_dirs = select_element(directions, seg, part_num=part_num).to(device)
selected_rgbs = select_element(rgb, seg, part_num=part_num).to(device)
assert selected_dirs.shape[0] > 0, f"no pixels with label {part_num} in frame {frame_id}"
assert selected_dirs.shape[0] == selected_rgbs.shape[0]



In [94]:
# Number of rays (and direction components) selected for the part.
selected_dirs.size()

torch.Size([14282, 3])

In [95]:
optimize_step = 30
# Random subset of part pixels rendered per step (iNeRF-style ray batching).
# Rendering every selected pixel each step is slow; set to None to use all of them.
rays_per_step = 2048
log_every = 5
torch.manual_seed(0)  # reproducible ray sampling
result = []  # total loss per step, for inspecting convergence afterwards
n_rays = selected_dirs.shape[0]

for step in range(optimize_step):
    optimizer.zero_grad()

    # Apply the learnable correction to the initial pose.
    new_pose = view_deform(pose)

    if rays_per_step is not None and rays_per_step < n_rays:
        ray_idx = torch.randperm(n_rays, device=selected_dirs.device)[:rays_per_step]
        batch_dirs = selected_dirs[ray_idx]
        batch_rgbs = selected_rgbs[ray_idx]
    else:
        batch_dirs, batch_rgbs = selected_dirs, selected_rgbs

    # Generate rays from the corrected pose and the sampled directions.
    rays_o, viewdirs, rays_d = get_rays_torch(batch_dirs, new_pose[:3, :], output_view_dirs=True)
    input_dict = {
        'rays_o': rays_o,
        'rays_d': rays_d,
        'viewdirs': viewdirs
    }

    # NOTE(review): the positional args are presumably (randomized=False, white_bkgd=True,
    # near=2, far=6). TODO confirm against NeRF.forward.
    rendered_results = model(input_dict, False, True, 2, 6)
    coarse_rgb = rendered_results[0][0]
    fine_rgb = rendered_results[1][0]

    loss = img2mse(coarse_rgb, batch_rgbs) + img2mse(fine_rgb, batch_rgbs)
    loss.backward()
    optimizer.step()

    result.append(loss.item())
    if step % log_every == 0 or step == optimize_step - 1:
        print(f"step {step:3d}  loss {result[-1]:.6f}")


KeyboardInterrupt: 

In [69]:
# Initial camera-to-world pose for this frame. The optimization never reassigns it;
# the refined pose is view_deform(pose).
pose

tensor([[-9.3541e-01,  2.9152e-01, -2.0005e-01, -1.0130e+00],
        [-3.5356e-01, -7.7128e-01,  5.2926e-01,  2.6800e+00],
        [-7.4506e-09,  5.6581e-01,  8.2454e-01,  4.1751e+00],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  1.0000e+00]], device='cuda:0')