In [None]:
import torch

# Load the exp_02 renderer checkpoint as a plain state dict.
# map_location='cpu' makes the load independent of the device the checkpoint
# was saved from, so this cell also runs on a machine without that GPU;
# load_state_dict copies the values into the module's own parameters anyway.
pbnds_weights = torch.load('../output/exp_02/weights/NeuralRenderer01.pth', map_location='cpu')
# Disabled filter that drops every entry whose key contains 'unet':
#pbnds_w = {k: v for k, v in pbnds_weights.items() if 'unet' not in k}

In [None]:
# Drop every entry whose key contains 'unet' and write the reduced state dict
# back to disk. The filter is rebuilt here because the only definition of
# `pbnds_w` in the notebook is commented out, so on a fresh kernel this cell
# used to raise NameError.
# NOTE: this overwrites the same checkpoint file that was loaded above.
pbnds_w = {k: v for k, v in pbnds_weights.items() if 'unet' not in k}
torch.save(pbnds_w, '../output/exp_02/weights/NeuralRenderer01.pth')

In [None]:
import sys

# A later cell appends '..' to sys.path before importing this same module;
# do it here as well so the import also succeeds when the notebook is run
# top-to-bottom on a fresh kernel.
if '..' not in sys.path:
    sys.path.append('..')

from models.neural_renderer import NeuralRenderer

# Check that the checkpoint loads into a freshly constructed renderer.
nr = NeuralRenderer()

# Left as the last expression so the load_state_dict result is displayed.
nr.load_state_dict(pbnds_weights)

In [None]:
import sys, os, math
sys.path.append('..')
os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'

import cv2 as cv
import numpy as np

import torch
import torch.nn as nn
from models.neural_renderer import NeuralRenderer

exp_name = 'exp_02'

neural_renderer = NeuralRenderer()
# Deserialize onto the CPU first so the load does not depend on the device the
# checkpoint was saved from; the module is moved to the GPU afterwards.
w = torch.load(f'../output/{exp_name}/weights/NeuralRenderer.pth', map_location='cpu')
# Entries whose key contains 'unet' are dropped before loading.
neural_renderer.load_state_dict({k: v for k, v in w.items() if 'unet' not in k})
neural_renderer = neural_renderer.to('cuda')

def get_view_pos(depth, width, height, fov):
    """Back-project a depth map to per-pixel view-space positions.

    Args:
        depth: Depth value(s) that broadcast against a (height, width) grid.
        width: Image width in pixels.
        height: Image height in pixels.
        fov: Horizontal field of view in degrees.

    Returns:
        Tensor of shape (height, width, 3) holding (x, y, z) per pixel, where
        x grows to the right, y grows towards the first image row and
        z is ``-depth``.
    """
    # Half-angle tangents; the vertical one is derived from the aspect ratio.
    tan_half_x = math.tan(math.radians(fov) / 2)
    half_fov_y = math.atan(tan_half_x / (width / height))
    tan_half_y = math.tan(half_fov_y)

    # Pixel-centre coordinates mapped to [-1, 1]; rows are flipped so that the
    # first row gets the largest y value.
    row_frac = (torch.arange(height) + 0.5) / height
    col_frac = (torch.arange(width) + 0.5) / width
    ndc_y = (1 - row_frac) * 2 - 1
    ndc_x = col_frac * 2 - 1
    grid_y, grid_x = torch.meshgrid(ndc_y, ndc_x, indexing='ij')

    positions = torch.zeros(height, width, 3)
    positions[..., 0] = depth * grid_x * tan_half_x
    positions[..., 1] = depth * grid_y * tan_half_y
    positions[..., 2] = -depth
    return positions

def load_sdr(image_name):
    """Load an image file as a 256x256 tensor in RGB channel order.

    A 4-channel (BGRA) image is composited onto a black background using its
    alpha channel; a 3-channel image is only converted from BGR to RGB; a
    2-D (single-channel) image is passed through unchanged. The result is
    resized to 256x256 with nearest-neighbour interpolation.

    Args:
        image_name: Path of the image file to read.

    Returns:
        torch.Tensor wrapping the resized array. For 4-channel input the values
        are float32 in the original 0..255 range (colour * alpha); otherwise
        the dtype is whatever ``cv.imread`` produced.

    Raises:
        OSError: If ``cv.imread`` cannot read the file.
    """
    image = cv.imread(image_name, cv.IMREAD_UNCHANGED)
    if image is None:
        # cv.imread signals failure by returning None instead of raising.
        raise OSError(f'cv.imread could not read {image_name!r} (missing file or unsupported format)')

    if image.ndim == 3:
        if image.shape[2] == 4:
            rgb_channels = cv.cvtColor(image[..., :3], cv.COLOR_BGR2RGB)

            # Alpha as a float32 factor in [0, 1]; assumes 8-bit channels.
            # Kept as (H, W, 1) so it broadcasts over the three colour channels.
            alpha_factor = image[..., 3:4].astype(np.float32) / 255.

            # Composite over a black background: the background contributes
            # zero, so the result is simply colour * alpha.
            image = rgb_channels * alpha_factor
        else:
            image = cv.cvtColor(image, cv.COLOR_BGR2RGB)

    image = cv.resize(image, (256, 256), interpolation=cv.INTER_NEAREST)

    return torch.from_numpy(image)

def load_hdr(image_name, resize=True, to_ldr=False):
    """Load an image with its stored bit depth and return it as an RGB tensor.

    Args:
        image_name: Path of the image file to read (used for .exr files in
            this notebook).
        resize: When true, resize to 256x256 with nearest-neighbour
            interpolation.
        to_ldr: When true, clip values to [0, 1] and apply a 1/2.2 gamma.

    Returns:
        torch.Tensor wrapping the loaded array, channels in RGB order.

    Raises:
        OSError: If ``cv.imread`` cannot read the file.
    """
    # IMREAD_UNCHANGED keeps the file's native depth and channel count.
    image = cv.imread(image_name, cv.IMREAD_UNCHANGED)
    if image is None:
        # cv.imread signals failure by returning None instead of raising.
        raise OSError(f'cv.imread could not read {image_name!r} (missing file or unsupported format)')
    image = cv.cvtColor(image, cv.COLOR_BGR2RGB)

    if resize:
        image = cv.resize(image, (256, 256), interpolation=cv.INTER_NEAREST)

    if to_ldr:
        image = image.clip(0, 1) ** (1 / 2.2)

    return torch.from_numpy(image)

base_path = '../dataset/ffhq256_pbr/'
# Sub-folder and image id of the sample used in every path below.
subset_id, sample_id = '06000', '06112'

rgb_gt = load_sdr(f'{base_path}bgremoval/{subset_id}/{sample_id}.png') / 255.
# Normals are stored as 8-bit values; remap 0..255 to [-1, 1].
normal_gt = load_sdr(f'{base_path}texture/normal/{subset_id}/normal_{sample_id}.png')
normal_gt = ((normal_gt / 255.) * 2 - 1).to(torch.float32)
albedo_gt = load_sdr(f'{base_path}texture/albedo/{subset_id}/albedo_{sample_id}.png') / 255.
roughness_gt = load_sdr(f'{base_path}texture/roughness/{subset_id}/roughness_{sample_id}.png') / 255.
specular_gt = load_sdr(f'{base_path}texture/specular/{subset_id}/specular_{sample_id}.png') / 255.
# Only the first channel of the depth image is used.
depth_gt = load_hdr(f'{base_path}texture/depth/{subset_id}/depth_{sample_id}.exr')[..., 0]
# A pixel is foreground when any colour channel is non-zero. load_sdr
# composites transparent pixels onto black, so background pixels are exactly
# zero; testing only the first channel would also drop foreground pixels whose
# red value happens to be 0.
mask_gt = (rgb_gt != 0).any(dim=-1)
# base_path already ends with '/', so no leading slash here.
hdri_gt = load_hdr(f'{base_path}hdri/{subset_id}/hdri_{sample_id}.exr')

view_pos_gt = get_view_pos(depth=depth_gt, width=256, height=256, fov=50)

In [None]:
# Gather the foreground pixels of every per-pixel map on the GPU; the
# environment map is passed whole with a leading batch axis.
render_buffer = {
    key: pixel_map[mask_gt].cuda()
    for key, pixel_map in (
        ('rgb_gt', rgb_gt),
        ('normal_gt', normal_gt),
        ('albedo_gt', albedo_gt),
        ('roughness_gt', roughness_gt),
        ('specular_gt', specular_gt),
        ('view_pos_gt', view_pos_gt),
    )
}
render_buffer['hdri_gt'] = hdri_gt[None].cuda()

# Inference only, so gradient tracking is switched off.
with torch.no_grad():
    shading_rgb = neural_renderer(render_buffer, num_light_samples=128)

# Scatter the shaded foreground pixels back into a black 256x256 canvas.
rec_image = torch.zeros(256, 256, 3, device='cuda')
rec_image[mask_gt] = shading_rgb

In [None]:
import torchvision.transforms.functional as tvf

# to_pil_image scales float tensors by 255 and casts to uint8 without
# clamping, so any value outside [0, 1] would wrap around and show up as
# speckles; clamp first. permute turns (H, W, C) into the (C, H, W) layout
# that to_pil_image expects.
tvf.to_pil_image(rec_image.clamp(0, 1).permute(2, 0, 1))

In [None]:
# Prepare inputs for a hand-written shading pass outside the network.
# Note: this rebinds `w`, which an earlier cell used for the loaded state dict.
h, w, _ = albedo_gt.shape

# Flatten the maps to (1, h*w, 1, C).
# NOTE(review): the reshapes for roughness and specular only succeed if those
# maps hold exactly h*w values (i.e. are single-channel) — confirm.
albedo = albedo_gt.reshape(1, h*w, 1, 3)
roughness = roughness_gt.reshape(1, h*w, 1, 1)
specular = specular_gt.reshape(1, h*w, 1, 1)
normal = normal_gt.reshape(1, h*w, 1, 3)

# Sampling the HDRi environment map
# NOTE(review): hdri_gt is still a CPU tensor here while neural_renderer was
# moved to 'cuda' — confirm uniform_sampling accepts CPU input.
sampled_hdri_map, sampled_direction = neural_renderer.uniform_sampling(hdri_map=hdri_gt[None], num_samples=128)

# Camera sits at the view-space origin.
cam_pos = torch.tensor([0., 0., 0.])[None, None, :]

# NOTE(review): view_pos_gt is (H, W, 3) as returned by get_view_pos, so
# shape[0] is the image height, not the pixel count h*w — confirm intent.
in_dirs = sampled_direction.repeat(view_pos_gt.shape[0],1,1)  # intended shape per original note: [S,N,3]
# NOTE(review): view_pos_gt[None].unsqueeze(1) is (1, 1, H, W, 3), so the
# broadcast result is 5-D rather than the [S,N,3] noted originally.
out_dirs = (cam_pos - view_pos_gt[None].unsqueeze(1))
out_dirs = nn.functional.normalize(out_dirs, dim=-1)  # unit vectors; intended shape per original note: [S,N,3]

In [None]:
# Debug: shape of the HDRI tensor returned by load_hdr.
hdri_gt.shape

In [None]:
# Debug: shape of the repeated light directions.
in_dirs.shape

In [None]:
# Debug: shapes of the uniform_sampling outputs next to the view-position map.
sampled_hdri_map.shape, sampled_direction.shape, view_pos_gt.shape

In [None]:
# Debug: compare incoming and outgoing direction shapes before combining them.
in_dirs.shape, out_dirs.shape

In [None]:


# NOTE(review): this cell looks like a fragment pasted from a class method.
# It references `light`, `b`, `self` and `metallic`, none of which are defined
# in the notebook cells visible here, so it raises NameError as written.

# Repeat light for multiple pixels for sharing
# (each entry is duplicated twice along dim 1 and along dim 2, then flattened)
light = light.repeat_interleave(2, dim=1)
light = light.repeat_interleave(2, dim=2)
light = light.reshape(b,h*w,self.env_height*self.env_width,3)

# Diffuse BRDF
# diffuse_brdf = (1 - metallic) * albedo / torch.pi
# NOTE(review): this uses the unflattened albedo_gt rather than the reshaped
# `albedo` — confirm that the broadcast below is the intended one.
diffuse_brdf = albedo_gt

# Half vector and clamped cosine terms used by the specular BRDF
half_dirs = in_dirs + out_dirs
half_dirs = nn.functional.normalize(half_dirs, dim=-1)
h_d_n = (half_dirs * normal).sum(dim=-1, keepdim=True).clamp(min=0)
h_d_o = (half_dirs * out_dirs).sum(dim=-1, keepdim=True).clamp(min=0)
n_d_i = (normal * in_dirs).sum(dim=-1, keepdim=True).clamp(min=0)
n_d_o = (normal * out_dirs).sum(dim=-1, keepdim=True).clamp(min=0)

# Fresnel term F (Schlick Approximation)
# F0 blends a constant 0.04 with the albedo according to `metallic`.
F0 = 0.04 * (1 - metallic) + albedo * metallic
F = F0 + (1. - F0) * ((1. - h_d_o) ** 5)

# Geometry term with Smith's approximation (product of the incoming and
# outgoing v_schlick_ggx terms)
V = self.v_schlick_ggx(roughness, n_d_i) * self.v_schlick_ggx(roughness, n_d_o)

# Normal distribution function (SG — presumably spherical Gaussian; confirm
# against d_sg), capped at 1
D = self.d_sg(roughness, h_d_n).clamp(max=1)

specular_brdf = D * F * V 

# RGB color shading
# Every sample is weighted by a constant 2*pi, and the result is averaged over
# dim 2 (presumably the light-sample axis — confirm).
incident_area = torch.ones_like(light) * 2 * torch.pi
render_output = ((diffuse_brdf + specular_brdf) * light * incident_area * n_d_i).mean(dim=2)