In [None]:
!sudo apt-get install ninja-build
!ninja --version

In [None]:
import os

if os.path.exists('/content/nvdiffrast'):
  !rm -rf /content/nvdiffrast

!git clone --recursive https://github.com/NVlabs/nvdiffrast
%cd /content/nvdiffrast
!pip install .
%cd /content/


In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import nvdiffrast.torch as dr

from IPython.display import Image
from torchvision.transforms import ToPILImage
from torchvision.transforms import RandomPerspective
from einops import rearrange, reduce, repeat

# Shared tensor -> PIL image converter; the visualization cells call it to save PNGs.
to_img = ToPILImage()

In [None]:
# Rasterizer backend switch. Author's note: on a T4 GPU only False works,
# but the rasterizer works much better with True (OpenGL).
use_opengl = False

if use_opengl:
    glctx = dr.RasterizeGLContext()
else:
    glctx = dr.RasterizeCudaContext()

In [None]:
def cond_noise_sampling(src_noise, level=3):
    """Conditionally upsample a noise map by a factor of ``2 ** level`` per axis.

    Each source pixel is split into an ``up_factor x up_factor`` block of
    sub-pixels. Fresh standard-normal samples are drawn for every sub-pixel,
    the mean of each block is subtracted from it, and the source pixel value
    divided by ``up_factor`` is added back. Consequently the mean over every
    block of the result equals ``src_noise / up_factor``.

    Args:
        src_noise: Floating-point tensor of shape ``(B, C, H, W)``.
        level: Upsampling level ``k``; the spatial size grows by ``2 ** k``.

    Returns:
        Tensor of shape ``(B, C, H * 2**level, W * 2**level)`` on the same
        device and with the same dtype as ``src_noise``.
    """
    up_factor = 2 ** level

    # Broadcast every source value over its up_factor x up_factor block.
    upscaled_means = F.interpolate(src_noise, scale_factor=(up_factor, up_factor), mode='nearest')

    # 1) Unconditionally sample a discrete Gaussian at the upsampled resolution.
    #    randn_like keeps the sample on the device/dtype of the input, so the
    #    function also works when src_noise lives on the GPU.
    raw_rand = torch.randn_like(upscaled_means)

    # 2) Remove the per-block mean from the fresh sample.
    Z_mean = raw_rand.unfold(2, up_factor, up_factor).unfold(3, up_factor, up_factor).mean((4, 5))
    Z_mean = F.interpolate(Z_mean, scale_factor=up_factor, mode='nearest')
    mean_removed_rand = raw_rand - Z_mean

    # 3) Add the (rescaled) source pixel value back to every sub-pixel.
    up_noise = upscaled_means / up_factor + mean_removed_rand

    return up_noise

In [None]:
import cv2

def gridy2x_pers2erp(gridy, HWy, HWx, THETA, PHI, FOVy, FOVx):
    """Map normalized perspective-view coordinates to normalized ERP coordinates.

    Each input point is placed on the image plane at x = 1, normalized onto
    the unit sphere, rotated, and converted to latitude/longitude.

    Args:
        gridy: Tensor whose last dimension is 2, holding (vertical, horizontal)
            coordinates; it is flattened to ``(N, 2)``. The caller's tensor is
            not modified.
        HWy: ``(H, W)`` of the perspective view; only used to derive the
            vertical field of view as ``FOVy * H / W``.
        HWx: Unused; kept for interface compatibility.
        THETA: Rotation about the z-axis, in degrees.
        PHI: Rotation about the (THETA-rotated) y-axis, in degrees.
        FOVy: Horizontal field of view of the perspective view, in degrees.
        FOVx: Unused; kept for interface compatibility.

    Returns:
        ``(gridx, mask)`` where ``gridx`` is an ``(N, 2)`` float32 tensor of
        ``(lat, lon)`` with ``lat = arcsin(z) / (pi / 2)`` and
        ``lon = atan2(y, x) / pi``, and ``mask`` is an ``(N,)`` float32 tensor
        that is 1 where both coordinates lie within ``[-1, 1]`` and 0 elsewhere.
    """
    H, W = HWy
    hFOVy, wFOVy = FOVy * float(H) / W, FOVy

    # gridy2x
    ### onto sphere
    plane = gridy.reshape(-1, 2).float()
    tan_half_fov = plane.new_tensor([
        float(np.tan(np.radians(hFOVy / 2.0))),
        float(np.tan(np.radians(wFOVy / 2.0))),
    ])
    # Out-of-place scaling: reshape()/float() can return a view of the caller's
    # tensor, so an in-place `*=` here would silently rescale the input grid.
    plane = plane * tan_half_fov
    rays = plane.double().flip(-1)  # (vertical, horizontal) -> (horizontal, vertical)

    x0 = torch.ones(rays.shape[0], 1, dtype=rays.dtype, device=rays.device)
    rays = torch.cat((x0, rays), dim=-1)
    rays = rays / torch.norm(rays, p=2, dim=-1, keepdim=True)

    ### rotation
    y_axis = np.array([0.0, 1.0, 0.0], np.float64)
    z_axis = np.array([0.0, 0.0, 1.0], np.float64)
    [R1, _] = cv2.Rodrigues(z_axis * np.radians(THETA))
    [R2, _] = cv2.Rodrigues(np.dot(R1, y_axis) * np.radians(PHI))

    # Keep the rotation matrices on the same device as the rays.
    rot1 = torch.from_numpy(R1).to(rays.device)
    rot2 = torch.from_numpy(R2).to(rays.device)
    rays = torch.mm(rot1, rays.t()).t()
    rays = torch.mm(rot2, rays.t()).t()

    ### sphere to gridx
    # Clamp guards arcsin against |z| exceeding 1 by floating-point round-off.
    lat = torch.arcsin(rays[:, 2].clamp(-1.0, 1.0)) / np.pi * 2
    lon = torch.atan2(rays[:, 1], rays[:, 0]) / np.pi
    gridx = torch.stack((lat, lon), dim=-1)

    # masky
    mask = torch.where(torch.abs(gridx) > 1, 0, 1)
    mask = mask[:, 0] * mask[:, 1]

    return gridx.float(), mask.float()

def gridy2x_erp2pers(gridy, HWy, HWx, THETA, PHI, FOVy, FOVx):
    """Map normalized ERP coordinates to normalized perspective-view coordinates.

    Each ``(lat, lon)`` input is converted to a unit direction, rotated by the
    inverse of the THETA/PHI rotation, and projected onto the plane x = 1.

    NOTE(review): HWy, HWx, FOVy and FOVx are accepted but never used here; the
    projection is not divided by tan(FOV / 2), unlike the scaling applied in
    gridy2x_pers2erp. Confirm whether a fixed 90-degree view is intended.

    Args:
        gridy: Tensor whose last dimension is 2, holding ``(lat, lon)`` with
            lat scaled by ``pi / 2`` and lon scaled by ``pi``; flattened to
            ``(N, 2)``.
        HWy, HWx: Unused; kept for interface compatibility.
        THETA: Rotation about the z-axis, in degrees.
        PHI: Rotation about the (THETA-rotated) y-axis, in degrees.
        FOVy, FOVx: Unused; kept for interface compatibility.

    Returns:
        ``(gridx, mask)`` where ``gridx`` is an ``(N, 2)`` float32 tensor of
        ``(z / x, y / x)`` plane coordinates and ``mask`` is an ``(N,)`` float32
        tensor that is 1 where both coordinates lie within ``[-1, 1]`` and the
        direction has non-negative x, and 0 elsewhere. ``gridx`` may contain
        inf/NaN for directions with x == 0.
    """
    # gridy2x
    ### onto sphere
    angles = gridy.reshape(-1, 2).float()
    lat = angles[:, 0] * np.pi / 2
    lon = angles[:, 1] * np.pi

    z0 = torch.sin(lat)
    y0 = torch.cos(lat) * torch.sin(lon)
    x0 = torch.cos(lat) * torch.cos(lon)
    rays = torch.stack((x0, y0, z0), dim=-1).double()

    ### rotation
    y_axis = np.array([0.0, 1.0, 0.0], np.float64)
    z_axis = np.array([0.0, 0.0, 1.0], np.float64)
    [R1, _] = cv2.Rodrigues(z_axis * np.radians(THETA))
    [R2, _] = cv2.Rodrigues(np.dot(R1, y_axis) * np.radians(PHI))

    # A rotation matrix is orthogonal, so its inverse is its transpose; this
    # avoids a general matrix inversion. The matrices are moved to the rays'
    # device so CUDA inputs do not hit a device mismatch.
    R1_inv = torch.from_numpy(R1).to(rays.device).t()
    R2_inv = torch.from_numpy(R2).to(rays.device).t()

    rays = torch.mm(R2_inv, rays.t()).t()
    rays = torch.mm(R1_inv, rays.t()).t()

    ### sphere to gridx
    z0 = rays[:, 2] / rays[:, 0]
    y0 = rays[:, 1] / rays[:, 0]
    gridx = torch.stack((z0, y0), dim=-1).float()

    # masky
    mask = torch.where(torch.abs(gridx) > 1, 0, 1)
    mask = mask[:, 0] * mask[:, 1]
    # Directions with negative x lie behind the image plane.
    mask *= torch.where(rays[:, 0] < 0, 0, 1)

    return gridx.float(), mask.float()

In [None]:
"""
  Defining source noise map
"""
# Noise config
up_level = 3     # Upsampling level k
batch_size = 1   # Batch size
dim_channel = 3  # Channel dimension
H, W = 64, 128   # Original H, W
visualize_sz = (512, 1024)
seed = 0         # Fixed seed so Restart & Run All reproduces the same noise maps

torch.manual_seed(seed)

# Sample the source noise
src_noise = torch.randn(batch_size, dim_channel, H, W)

# Upscale to 512 x 1024 and visualize, just for visualization purposes.
# ToPILImage scales float tensors by 255 and casts to uint8, so it expects
# values in [0, 1]; map +-3 sigma onto that range and clamp instead of letting
# out-of-range Gaussian samples wrap around in the cast.
view_test_noise = F.interpolate(src_noise, size=visualize_sz, mode='nearest')
to_img((view_test_noise[0] / 6 + 0.5).clamp(0, 1)).save("0_view_test_noise.png")

# Generate conditionally upsampled noise by k = up_level
up_noise = cond_noise_sampling(src_noise, level=up_level)

# Visualize upsampled noise (same [0, 1] mapping as above)
test_upsampled_noise_vis = F.interpolate(up_noise, size=visualize_sz, mode='nearest')
to_img((test_upsampled_noise_vis[0] / 6 + 0.5).clamp(0, 1)).save("1_test_upsampled_noise_vis.png")

In [None]:
# Settings for the perspective-to-ERP mapping
H_pers = 64   # Height of the perspective (target) view
W_pers = 64   # Width of the perspective (target) view
theta = 0     # View direction, yaw (THETA)
phi = 0       # View direction, pitch (PHI)

In [None]:
"""
  Defining the partitioned polygons for target noise map
"""
B, C, H, W = src_noise.shape

# Vertex lattice at 2x the perspective resolution (pixel borders plus pixel
# centres), used to index the vertices for rasterization.
tr_H_pers = 2 * H_pers + 1
tr_W_pers = 2 * W_pers + 1

i, j = torch.meshgrid(
    torch.arange(tr_H_pers, dtype=torch.int32),
    torch.arange(tr_W_pers, dtype=torch.int32),
    indexing="ij",
)

# (row, col) index of every lattice vertex: (tr_H_pers, tr_W_pers, 2), then flattened.
mesh_idxs = torch.stack((i, j), dim=-1)
reshaped_mesh_idxs = mesh_idxs.reshape(-1, 2)

# Eight triangles partition one original pixel. Offsets are flat vertex
# indices relative to the pixel's top-left vertex: the four triangles of the
# left half first, then the same four shifted one vertex to the right.
front_tri_verts = torch.tensor([
    [0, 1, tr_W_pers + 1],
    [0, tr_W_pers, tr_W_pers + 1],
    [tr_W_pers, tr_W_pers + 1, 2 * tr_W_pers + 1],
    [tr_W_pers, 2 * tr_W_pers, 2 * tr_W_pers + 1],
])
per_tri_verts = torch.cat((front_tri_verts, front_tri_verts + 1), dim=0)

# Flat index of the top-left vertex of every original pixel on the 2x lattice.
width = torch.arange(0, tr_W_pers - 1, 2)
height = torch.arange(0, tr_H_pers - 1, 2) * tr_W_pers
start_idxs = (height[:, None] + width[None, :]).reshape(-1, 1)

# Broadcast each start index over the 8 x 3 offset table -> (num_faces, 3).
vertices = (start_idxs[:, :, None] + per_tri_verts[None, :, :]).reshape(-1, 3)
num_faces = vertices.shape[0]

In [None]:
"""
  Defining target noise map coord. to source noise map
"""

# Build the perspective-view grid (pixel borders, not pixel centres)
pers_grid_y, pers_grid_x = torch.meshgrid(
    torch.linspace(-1, 1, tr_H_pers),  # vertical coordinate (normalized)
    torch.linspace(-1, 1, tr_W_pers),  # horizontal coordinate (normalized)
    indexing="ij",
)

pers_grid = torch.stack((pers_grid_y, pers_grid_x), dim=-1)  # (tr_H_pers, tr_W_pers, 2)

# Perspective-to-ERP coordinate transform
pers2erp_coords, _ = gridy2x_pers2erp(
    gridy=pers_grid,
    HWy=(2 * H_pers, 2 * W_pers),  # perspective resolution
    HWx=(2 * H, 2 * W),            # ERP resolution
    THETA=theta,
    PHI=phi,
    FOVy=90,
    FOVx=360,
)

# Range of the normalized (lat, lon) coordinates
for axis in range(2):
    print(pers2erp_coords[..., axis].max(), pers2erp_coords[..., axis].min())

# Map the normalized coordinates onto the (2x) ERP resolution, in place
for axis, extent in enumerate((2 * H, 2 * W)):
    pers2erp_coords[..., axis] = (pers2erp_coords[..., axis] + 1) / 2 * extent

# View of the same storage, laid out on the vertex lattice
tgt_to_src_map = pers2erp_coords.view(tr_H_pers, tr_W_pers, 2)

# Range after scaling to ERP pixel units
for axis in range(2):
    print(pers2erp_coords[..., axis].max(), pers2erp_coords[..., axis].min())
print("tgt_to_src_map min/max:", tgt_to_src_map.min(), tgt_to_src_map.max())
print("vertices min/max:", vertices.min(), vertices.max())
print("vertices shape:", vertices.shape)

In [None]:
"""
    Triangle rasterization using Nvdiffrast
"""

idx_y = reshaped_mesh_idxs[...,0].int()
idx_x = reshaped_mesh_idxs[...,1].int()

# Nvdiffrast input must be (x,y), so it must be flipped!!
coords_len = idx_y.shape[0] # number of lattice vertices (16641 for the default 64 x 64 view)
warped_coords = tgt_to_src_map[idx_y, idx_x].fliplr()

# Rasterize at the resolution of the conditionally upsampled noise. Both sides
# are derived from (H, W) rather than assuming a 2:1 aspect ratio, so the
# result lines up with up_noise pixel by pixel for any source shape.
resolution = H * (2 ** up_level)
rast_size = [H * (2 ** up_level), W * (2 ** up_level)]
assert tuple(up_noise.shape[-2:]) == tuple(rast_size), "rasterization size must match up_noise"
device = "cuda"

# tgt_to_src_map spans [0, 2W] x [0, 2H]; rescale both axes to [-1, 1].
warped_coords = warped_coords.float()
warped_coords[..., 0] = (warped_coords[..., 0] - W) / W
warped_coords[..., 1] = (warped_coords[..., 1] - H) / H

# Homogeneous vertex positions (x, y, z=0, w=1)
warped_vtx_pos = torch.cat((warped_coords, torch.zeros(coords_len, 1), torch.ones(coords_len, 1)), dim=-1)

# Add the batch dimension and move the inputs to the GPU; triangle indices as int32
warped_vtx_pos = warped_vtx_pos[None,...].to(device)
vertices = vertices.int().to(device)

print(warped_vtx_pos[..., 0].max(), warped_vtx_pos[..., 0].min())
print(warped_vtx_pos[..., 1].max(), warped_vtx_pos[..., 1].min())


with torch.no_grad():
    rast_out, _ = dr.rasterize(glctx, warped_vtx_pos, vertices, resolution=rast_size)

# Last channel is the per-pixel triangle index; the aggregation cell treats 0 as "no triangle".
rast = rast_out[:,:,:,3:].permute(0,3,1,2).to(torch.int64) # 1, 1, H * 2**up_level, W * 2**up_level

# Rasterization visualization
up_noise_vis = F.interpolate(rast.float(), size=visualize_sz, mode='nearest')

plt.imsave(f'2_rast_{theta}_{phi}.png', rast[0, 0].cpu().numpy(), cmap='viridis')  # adjust cmap if needed

In [None]:
"""
  finding pixel indices in cond-upsampled map that belong to each polygon triangle,
  and then adding them up

  This implementation uses torch.scatter() to do this in parallel, so it is faster
"""

# Assign the same index to all triangles of one original pixel, 0 if no triangle.
# There are 8 triangles per pixel. (rast + 7) // 8 gives the same mapping as
# (rast - 1) // 8 + 1 but never divides a negative number, so the result does
# not depend on whether `//` floors or truncates (older PyTorch releases truncated).
tris_per_pixel = 8
indices = (rast + tris_per_pixel - 1) // tris_per_pixel

# Flatten the upsampled noise
up_noise_flat = up_noise.reshape(B*C, -1).cpu()

# Create a flatten vector of ones for "Cardinality" value i.e. number of contained pixels
ones_flat = torch.ones_like(up_noise_flat[:1])

# Flatten the indices (and broadcast to batch size)
indices_flat = indices.reshape(1, -1).cpu().to(torch.int64)

# Aggregate the noise values and cardinality using scattering operation.
# Slot 0 collects the pixels not covered by any triangle and is dropped by [..., 1:].
fin_v_val = torch.zeros(B*C, H_pers*W_pers+1).scatter_add_(1, index=indices_flat.repeat(B*C, 1), src=up_noise_flat)[..., 1:]
fin_v_num = torch.zeros(1, H_pers*W_pers+1).scatter_add_(1, index=indices_flat, src=ones_flat)[..., 1:]

# Every target pixel must cover at least one sub-pixel, otherwise the division below yields NaN.
assert fin_v_num.min() != 0, "some target pixels received no upsampled-noise samples"

print(fin_v_num.min(), fin_v_num.max())

# Dividing the sum by sqrt(count) rather than by count
final_values = fin_v_val / torch.sqrt(fin_v_num)

warped_noise_fast = final_values.reshape(B, C, H_pers, W_pers).float()

up_vis_fast = F.interpolate(warped_noise_fast, size=(512, 512), mode='nearest')

# ToPILImage scales float tensors by 255 and casts to uint8, so map +-3 sigma
# onto [0, 1] and clamp instead of letting out-of-range values wrap around.
to_img((up_vis_fast[0] / 6 + 0.5).clamp(0, 1)).save(f"3_up_vis_fast_{theta}_{phi}.png")

In [None]:
# Collapse every batch/channel/pixel value into one 1-D array
flattened_noise = warped_noise_fast.flatten().numpy()

# Density histogram of the warped noise values
fig, ax = plt.subplots(figsize=(8, 6))
ax.hist(flattened_noise, bins=100, density=True, alpha=0.6, color='blue', label='Pixel Values')

# Reference curve: the density of a Gaussian with the mean/std given below
mean, std = 0, 1
print(mean, std)
x = np.linspace(-4.5, 4.5, 1000)
gaussian_curve = (1 / (std * np.sqrt(2 * np.pi))) * np.exp(-0.5 * ((x - mean) / std) ** 2)
ax.plot(x, gaussian_curve, color='red', linestyle='--', label='Gaussian Fit')

ax.set_title("Distribution of Pixel Values in warped_noise_fast")
ax.set_xlabel("Pixel Value")
ax.set_ylabel("Density")
ax.set_xticks(np.arange(-5, 6, 1))      # x-axis grid positions
ax.set_yticks(np.arange(0, 0.5, 0.05))  # y-axis grid positions
ax.legend()
ax.grid(True)

fig.savefig(f"4_hist_{theta}_{phi}.png")
plt.show()