In [None]:
import os
import torch
import matplotlib.pyplot as plt

from pytorch3d.io import load_objs_as_meshes
from pytorch3d.renderer import (
    look_at_view_transform,
    FoVPerspectiveCameras,
    PointLights,
    DirectionalLights,
    RasterizationSettings,
    MeshRenderer,
    MeshRasterizer,
    SoftPhongShader,
    HardPhongShader,
)

# add path for demo utils functions
# os.path.abspath("") resolves to the current working directory, so modules
# that live in the directory the kernel runs from become importable.
import sys
import os

sys.path.append(os.path.abspath(""))

%load_ext autoreload
%autoreload 2

In [None]:
# Setup: run on the first GPU when CUDA is present, otherwise on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
if use_cuda:
    torch.cuda.set_device(device)

# Location of the input mesh on disk
DATA_DIR = "./data"
obj_filename = os.path.join(DATA_DIR, "cow_mesh/cow.obj")

# Read the .obj into a (single-element) Meshes batch on the chosen device
mesh = load_objs_as_meshes([obj_filename], device=device)

# Normalize the mesh in place: move the vertex mean to the origin and divide by
# the largest absolute coordinate, so every vertex ends up inside [-1, 1]^3.
# `center` and `scale` are kept so the transform can be undone later.
verts = mesh.verts_packed()
N = verts.shape[0]  # number of vertices
center = verts.mean(0)
per_axis_extent = (verts - center).abs().max(0)[0]  # largest |offset| along x, y, z
scale = per_axis_extent.max()
mesh.offset_verts_(-center)
mesh.scale_verts_(1.0 / float(scale))

In [None]:
from lac.utils.plotting import plot_poses, plot_mesh

# Interactive preview of the normalized mesh; "data" aspect mode keeps the
# three axes on the same scale so the shape is not distorted.
preview_layout = dict(width=1200, height=700, scene=dict(aspectmode="data"))

fig = plot_mesh(mesh)
fig.update_layout(**preview_layout)
fig.show()

In [None]:
# How many viewpoints to place around the mesh.
num_views = 20

# One (elevation, azimuth) pair per viewpoint, swept together.
elev = torch.linspace(0, 360, num_views)
azim = torch.linspace(-180, 180, num_views)

# White, diffuse-only directional light (ambient and specular terms are zero).
# An earlier variant used PointLights with the same colours and
# location=[[0.0, 0.0, -3.0]].
light_params = dict(
    diffuse_color=[[1.0, 1.0, 1.0]],
    ambient_color=[[0.0, 0.0, 0.0]],
    specular_color=[[0.0, 0.0, 0.0]],
    direction=[[0.0, 0.0, -1.0]],
)
lights = DirectionalLights(device=device, **light_params)

# Camera extrinsics for every viewpoint, all at the same distance from the origin.
R, T = look_at_view_transform(dist=2.7, elev=elev, azim=azim)
cameras = FoVPerspectiveCameras(device=device, R=R, T=T)
# Single-view camera (viewpoint index 1); slicing keeps the leading batch dim of 1.
camera = FoVPerspectiveCameras(device=device, R=R[1:2], T=T[1:2])

# Hard rasterization: no blur, only the closest face per pixel.
raster_settings = RasterizationSettings(
    image_size=128,
    blur_radius=0.0,
    faces_per_pixel=1,
)

# Assemble the renderer from its two stages; both default to the single-view camera.
view_rasterizer = MeshRasterizer(cameras=camera, raster_settings=raster_settings)
phong_shader = HardPhongShader(device=device, cameras=camera, lights=lights)
renderer = MeshRenderer(rasterizer=view_rasterizer, shader=phong_shader)

In [None]:
# Render the mesh from the single selected viewpoint.
target_image = renderer(mesh, cameras=camera, lights=lights)

# Take the first image of the batch, drop the alpha channel, and move the
# result to host memory so matplotlib can draw it.
target_rgb_tensor = target_image[0, ..., :3]
target_rgb = target_rgb_tensor.cpu().numpy()
plt.imshow(target_rgb)

In [None]:
# `lights` is a DirectionalLights instance, which is parameterised by
# `direction` and has no `location` attribute (that one belongs to
# PointLights), so inspect the direction instead.
lights.direction

In [None]:
# ----------------------------------------
# Step 1: Render Depth Map from Light's Perspective
# ----------------------------------------
# Distance from the origin at which a directional light is given a stand-in
# position, so that a camera can be placed there.
LIGHT_CAMERA_DIST = 3.0

# PointLights carries `location`; DirectionalLights only carries `direction`
# (the vector pointing from the scene towards the light).  Handle both so this
# cell keeps working whichever light type is configured.
if hasattr(lights, "location"):
    light_position = lights.location
else:
    light_dir = torch.nn.functional.normalize(lights.direction, dim=-1)
    light_position = light_dir * LIGHT_CAMERA_DIST

# Aim the shadow camera from the light position at the origin, where the mesh
# was centred.  For a light at (0, 0, -3) this yields R = I, T = (0, 0, 3).
# NOTE: a perspective camera is only an approximation for a directional light,
# whose rays are parallel; an orthographic camera would be exact.
light_R, light_T = look_at_view_transform(eye=light_position, device=device)
light_camera = FoVPerspectiveCameras(device=device, R=light_R, T=light_T, znear=0.1)

light_raster_settings = RasterizationSettings(image_size=128, blur_radius=0.0, faces_per_pixel=1)

light_rasterizer = MeshRasterizer(cameras=light_camera, raster_settings=light_raster_settings)

# Render depth from light's view
fragments = light_rasterizer(mesh)
# zbuf is (N, H, W, K) with faces sorted nearest-first, so slot 0 is the depth
# of the closest surface.  Pixels not covered by any face hold -1.
depth_map = fragments.zbuf[0, ..., 0]  # (H, W)

In [None]:
# Show the light-view depth buffer (copied to host memory for matplotlib).
depth_map_np = depth_map.cpu().numpy()
plt.imshow(depth_map_np)

In [None]:
def apply_shadows(mesh, cameras, light_camera, depth_map, image_size=128, depth_bias=0.02):
    """Render `mesh` and darken the faces that are occluded from the light.

    This is a simple shadow-mapping pass: every vertex is projected into the
    light camera, its depth is compared with the depth map rendered from the
    light, and faces whose vertices lie behind the recorded surface are
    darkened in the final image.

    Relies on the notebook-level `renderer` (for rasterization and shading)
    and `lights` (passed to the shader).

    Args:
        mesh: Meshes object to render (a single mesh is assumed).
        cameras: camera(s) used for the final render.
        light_camera: camera placed at the light, the one `depth_map` was
            rendered with.
        depth_map: (image_size, image_size) tensor of view-space depths seen
            from the light; negative values mark pixels with no surface.
        image_size: side length in pixels of the square `depth_map`.
        depth_bias: tolerance, in light view-space depth units, added to the
            stored depth before the comparison so that a visible surface does
            not shadow itself because of depth-map discretisation.

    Returns:
        Tensor of shape (N, H, W, 3): the shaded RGB image with shadowed
        pixels multiplied by 0.2.
    """
    # Step 1: depth of every vertex as seen from the light (view-space z,
    # the same quantity the rasterizer stores in zbuf).
    verts_world = mesh.verts_packed()  # (V, 3)
    verts_light = light_camera.get_world_to_view_transform().transform_points(verts_world)  # (V, 3)

    # Step 2: project vertices into the light's NDC space and convert to pixel
    # indices.  PyTorch3D NDC has +X pointing left and +Y pointing up, so the
    # top-left pixel (row 0, col 0) corresponds to NDC (+1, +1); hence the
    # (1 - ndc) flip on both axes.
    verts_light_ndc = light_camera.transform_points(verts_world)  # (V, 3)
    pix_coords = ((1.0 - verts_light_ndc[:, :2]) * 0.5 * image_size).floor().long()

    # Step 3: remember which vertices fall inside the depth map, then clamp so
    # the lookup below is always in bounds.
    in_frame = ((pix_coords >= 0) & (pix_coords < image_size)).all(dim=1)
    pix_coords = pix_coords.clamp(0, image_size - 1)
    pix_x = pix_coords[:, 0]
    pix_y = pix_coords[:, 1]

    # Ensure depth map is on the same device as the pixel indices
    depth_map = depth_map.to(verts_light.device)

    # Step 4: sample the depth map at the projected pixel locations
    sampled_depth = depth_map[pix_y, pix_x]

    # Step 5: a vertex is shadowed when something closer to the light was
    # recorded at its pixel.  Empty depth-map pixels (negative depth) and
    # vertices outside the light's view have no occluder.
    has_occluder = in_frame & (sampled_depth >= 0)
    vert_shadowed = has_occluder & (verts_light[:, 2] > sampled_depth + depth_bias)

    # pix_to_face holds FACE indices, so lift the per-vertex result to faces:
    # a face counts as shadowed when at least two of its three vertices are.
    faces = mesh.faces_packed()  # (F, 3)
    face_shadowed = vert_shadowed[faces].sum(dim=1) >= 2  # (F,)

    # Step 6: render the mesh from the requested camera and apply shadows
    fragments = renderer.rasterizer(mesh, cameras=cameras)
    images = renderer.shader(fragments, mesh, lights=lights, cameras=cameras)

    # Nearest face per pixel; -1 marks background, which must stay unshadowed
    # (and must not be used as an index, where it would wrap to the last face).
    pix_to_face = fragments.pix_to_face[..., 0]  # (N, H, W)
    covered = pix_to_face >= 0
    shadow_mask = covered & face_shadowed[pix_to_face.clamp(min=0)]

    # Darken by 80% in shadow
    shadow_factor = torch.where(shadow_mask, 0.2, 1.0).unsqueeze(-1)
    shadowed_image = images[..., :3] * shadow_factor

    return shadowed_image


# Produce the shadowed render for the single-view camera and display it.
shadowed_image = apply_shadows(mesh, camera, light_camera, depth_map)
shadowed_rgb = shadowed_image[0, ..., :3].cpu().numpy()
plt.imshow(shadowed_rgb)
plt.show()