# In [None]:  (Jupyter cell marker left over from the notebook export)
import os
import sys
import logging
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from tqdm import trange
from scipy.spatial import cKDTree
from skimage.measure import marching_cubes
import trimesh
from trimesh.transformations import euler_matrix # Added for rotate_rays

# Optional project-local dataset helpers: the parent directory is put on the
# import path so that a sibling `data` package can be found.
sys.path.append("..")
try:
    from data.pollen_dataset import PollenDataset, get_train_test_split
except ImportError:
    # The data package is not installed/available. Leave None placeholders so
    # later code can test for them instead of failing at import time.
    PollenDataset = None
    get_train_test_split = None

# Ask cuDNN to benchmark and pick kernels (presumably beneficial here because
# batch shapes stay fixed during training - confirm if shapes start varying).
torch.backends.cudnn.benchmark = True
# Module-wide logging at INFO level, with a logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# -----------------------------------------------------------------------------
# 1. Positional Encoding (Reduced Frequencies)
# -----------------------------------------------------------------------------
def positional_encoding(x, L=4):
    """
    Lift coordinates into a sinusoidal feature space.

    The raw input is kept as the first feature group. For every frequency
    level k in [0, L) a sin(2^k * pi * x) group followed by a
    cos(2^k * pi * x) group is appended, and everything is joined along the
    last axis, so the last dimension grows by a factor of (2L + 1).
    The default L=4 is lower than the classic L=10 to limit high-frequency
    overfitting.
    """
    angular_freqs = [(2.0**level) * np.pi for level in range(L)]

    features = [x]
    for freq in angular_freqs:
        # sin first, then cos, for each frequency band
        features.extend((torch.sin(freq * x), torch.cos(freq * x)))
    return torch.cat(features, dim=-1)


# -----------------------------------------------------------------------------
# 2. NeRF Model
# -----------------------------------------------------------------------------
class NeRF(nn.Module):
    def __init__(self, D=6, W=128, L=4):
        """
        Fully connected network mapping encoded 3D points to colour + density.

        Args:
            D: Number of hidden layers
            W: Number of hidden units per layer
            L: Positional encoding frequency levels
        """
        super().__init__()
        self.L = L

        # Each of the 3 coordinates contributes itself plus L sin/cos pairs.
        encoded_dim = 3 * (2 * L + 1)

        # First hidden layer consumes the encoding, the rest are W -> W.
        self.layers = nn.ModuleList([nn.Linear(encoded_dim, W)])
        for _ in range(D - 1):
            self.layers.append(nn.Linear(W, W))

        # Four outputs per point: three colour logits and one density logit.
        self.output_layer = nn.Linear(W, 4)

        # Start the density (sigma) bias slightly above zero.
        with torch.no_grad():
            self.output_layer.bias[3].fill_(0.1)

    def forward(self, x):
        """
        Encode the (N, 3) points, run them through the ReLU MLP and return the
        raw (N, 4) predictions ordered as [R, G, B, sigma].
        """
        hidden = positional_encoding(x, self.L)
        for layer in self.layers:
            hidden = torch.relu(layer(hidden))
        return self.output_layer(hidden)


# -----------------------------------------------------------------------------
# 3. Render Rays (RGB + Alpha)
# -----------------------------------------------------------------------------
def render_rays(
    model, rays_o, rays_d, near=0.5, far=1.5, N_samples=128, sigma_scale=1.0
):
    """
    Composite colour and opacity along a batch of rays.

    Steps:
      - place N_samples evenly spaced depths between near and far on each ray
      - query the model at every sample point; the first three outputs go
        through a sigmoid (rgb), the fourth through ReLU * sigma_scale (sigma)
      - alpha-composite front to back

    Returns:
        rgb_map: (B, 3) composited colour per ray.
        alpha_map: (B,) accumulated opacity per ray.
    """
    device = rays_o.device

    # Depths shared by every ray in the batch.
    depths = torch.linspace(near, far, N_samples, device=device)
    origins = rays_o[:, None, :]
    offsets = rays_d[:, None, :] * depths[None, :, None]
    samples = origins + offsets  # (B, N_samples, 3)
    n_rays = samples.shape[0]

    raw = model(samples.reshape(-1, 3)).reshape(n_rays, N_samples, 4)
    colour_logits, density_logits = raw[..., :3], raw[..., 3]
    rgb = torch.sigmoid(colour_logits)
    sigma = torch.relu(density_logits) * sigma_scale

    # Distance to the next sample; the last interval is treated as (nearly)
    # infinite so any remaining density there becomes fully opaque.
    gaps = torch.cat(
        [depths[1:] - depths[:-1], torch.tensor([1e10], device=device)]
    )
    gaps = gaps[None, :].expand(sigma.shape)

    alpha = 1.0 - torch.exp(-sigma * gaps)

    # Transmittance: product of (1 - alpha) over all samples in front,
    # with a leading 1 so the first sample is unoccluded.
    pass_through = torch.cat(
        [torch.ones((sigma.shape[0], 1), device=device), 1.0 - alpha + 1e-10],
        dim=-1,
    )
    transmittance = torch.cumprod(pass_through, dim=-1)[:, :-1]
    weights = alpha * transmittance

    rgb_map = (weights[..., None] * rgb).sum(dim=1)
    alpha_map = weights.sum(dim=1)
    return rgb_map, alpha_map


# -----------------------------------------------------------------------------
# 4. Losses: Silhouette, Spherical Prior, etc.
# -----------------------------------------------------------------------------
def silhouette_loss(alpha, mask):
    """Mean squared error between the rendered opacity and the target mask."""
    residual = alpha - mask
    return residual.pow(2).mean()


def spherical_prior_loss(
    model, num_samples=2000, bound=1.0, desired_radius=0.6, sigma_scale=2.0, device=None
):
    """
    Shell prior on the density field.

    Draws num_samples points uniformly from the cube [-bound, bound]^3 and
    returns the mean of sigma * (|p| - desired_radius)^2, i.e. density is
    penalised in proportion to its squared distance from the sphere of radius
    desired_radius centred on the origin. When device is None the device of
    the model's first parameter is used.
    """
    target_device = device if device is not None else next(model.parameters()).device

    points = torch.rand(num_samples, 3, device=target_device) * (2 * bound) - bound
    density = torch.relu(model(points)[..., 3]) * sigma_scale

    # Signed distance of every sample from the target shell.
    radial_gap = points.norm(dim=1) - desired_radius
    return (density * radial_gap**2).mean()


def foreground_density_loss(alpha_map, mask, target_density=1.0):
    """
    Encourage rays that hit the foreground mask to accumulate enough density.

    The accumulated opacity is converted to an approximate integrated density
    D = -log(1 - alpha + eps); for rays whose mask value exceeds 0.5 the
    shortfall max(target_density - D, 0) is averaged.

    Args:
        alpha_map: Accumulated opacity per ray.
        mask: Ground-truth silhouette values aligned with alpha_map.
        target_density: Integrated density below which a ray is penalised.

    Returns:
        Scalar tensor; 0.0 when the batch contains no foreground rays.
    """
    eps = 1e-6
    # Floating-point error in compositing can push alpha marginally above 1.
    # Clamp the remaining transmittance at 0 so the log never sees a negative
    # argument (which would turn the whole loss into NaN).
    transmittance = torch.clamp(1.0 - alpha_map, min=0.0)
    # D is approximate integrated density along the ray
    D = -torch.log(transmittance + eps)

    fg_mask = mask > 0.5
    if not torch.any(fg_mask):
        # No foreground pixels in this batch
        return torch.tensor(0.0, device=alpha_map.device)

    # Penalize density lower than target_density within the mask
    return torch.mean(torch.clamp(target_density - D[fg_mask], min=0.0))


def smoothness_prior_loss(
    model, num_samples=2000, bound=1.0, offset=0.01, sigma_scale=2.0, device=None
):
    """
    Penalise abrupt changes of density between nearby points.

    num_samples points are drawn uniformly from [-bound, bound]^3. Each point
    is compared with its six axis-aligned neighbours at distance `offset`;
    the mean squared sigma difference is computed per direction and averaged
    over the six directions. Neighbours are not clamped, so they may fall
    slightly outside the cube. When device is None the device of the model's
    first parameter is used.
    """
    if device is None:
        device = next(model.parameters()).device

    def density_at(points):
        # Same activation and scaling as used when rendering.
        return torch.relu(model(points)[..., 3]) * sigma_scale

    centers = torch.rand(num_samples, 3, device=device) * (2 * bound) - bound
    center_density = density_at(centers)

    # One row per direction, ordered +x, -x, +y, -y, +z, -z.
    step_rows = []
    for axis in range(3):
        for sign in (1, -1):
            row = [0, 0, 0]
            row[axis] = sign * offset
            step_rows.append(row)
    steps = torch.tensor(step_rows, device=device).float()

    accumulated = 0.0
    for step in steps:
        gap = center_density - density_at(centers + step)
        accumulated += torch.mean(gap**2)

    # Average over the offset directions.
    return accumulated / steps.shape[0]


# -----------------------------------------------------------------------------
# 5. Ray Generation (Two Orthogonal Views)
# -----------------------------------------------------------------------------
def get_rays(H, W, focal=300.0):
    """
    Build the canonical (unrotated) ray bundles of the front and side cameras.

    Returns (rays_o_front, rays_d_front, rays_o_side, rays_d_side), each a
    (H*W, 3) tensor in row-major pixel order. Directions are unit length.

    NOTE(review): the front camera sits at the world origin while the side
    camera sits at x = -1.5; confirm both frustums cover the intended object
    volume for the near/far planes in use.
    """
    i, j = torch.meshgrid(
        torch.linspace(0, W - 1, W),
        torch.linspace(0, H - 1, H),
        indexing="xy",
    )
    # Pixel offsets from the image centre, in units of the focal length.
    # The vertical one is negated so that image "up" maps to +Y.
    horizontal = (i - W / 2.0) / focal
    vertical = -(j - H / 2.0) / focal
    forward = torch.ones_like(i)

    def unit(vectors):
        return vectors / torch.norm(vectors, dim=-1, keepdim=True)

    # Front camera: placed at (0, 0, 0), looking along -Z.
    front_dirs = unit(torch.stack([horizontal, vertical, -forward], dim=-1))
    front_origins = torch.zeros_like(front_dirs)

    # Side camera: placed at x = -1.5, looking along +X; the image's
    # horizontal axis maps to -Z.
    side_dirs = unit(
        torch.stack([forward, vertical, -(i - W / 2.0) / focal], dim=-1)
    )
    side_origins = torch.zeros_like(side_dirs)
    side_origins[..., 0] = -1.5

    return tuple(
        bundle.reshape(-1, 3)
        for bundle in (front_origins, front_dirs, side_origins, side_dirs)
    )


def rotate_rays(rays_o, rays_d, euler_angles):
    """
    Rotate ray origins and directions by a sample's orientation.

    euler_angles holds [rx, ry, rz] in radians and is interpreted with the
    'sxyz' Euler convention. Returns the rotated (origins, directions).
    """
    rx, ry, rz = (float(euler_angles[axis]) for axis in range(3))

    # euler_matrix yields a 4x4 homogeneous transform; only its upper-left
    # 3x3 rotation block is needed.
    homogeneous = euler_matrix(rx, ry, rz, "sxyz")
    rotation = torch.from_numpy(homogeneous[:3, :3]).to(rays_o.device).float()

    def spin(vectors):
        # (R @ V^T)^T applies the rotation to every row vector at once.
        return (rotation @ vectors.T).T

    return spin(rays_o), spin(rays_d)


# -----------------------------------------------------------------------------
# 6. Weighted Ray Sampling (Edges + Foreground)
# -----------------------------------------------------------------------------
def _edge_foreground_weights(view_mask, H, W):
    """Per-pixel sampling weights for one view: edge response + base + foreground bonus."""
    flat = view_mask.reshape(-1)
    values = flat if flat.is_floating_point() else flat.float()

    # Simple Laplacian edge detector, scaled down by 8.
    kernel = (
        torch.tensor(
            [[-1, -1, -1], [-1, 8, -1], [-1, -1, -1]], device=values.device
        ).to(values.dtype)
        / 8.0
    ).reshape(1, 1, 3, 3)  # add batch and channel dims

    # padding=1 keeps the response the same size as the mask.
    edges = torch.abs(
        torch.nn.functional.conv2d(values.reshape(1, 1, H, W), kernel, padding=1)
    ).reshape(-1)

    # 0.1 base weight keeps every pixel reachable; foreground gets +2.0.
    return edges + 0.1 + (flat > 0.5).to(values.dtype) * 2.0


def sample_rays_weighted(rays_o, rays_d, rgb, mask, original_shape, batch_size=1024):
    """
    Sample rays with higher probability at silhouette edges and foreground.

    The inputs may hold any whole number of H x W views concatenated along
    the first axis (e.g. front followed by side). Edge weights are computed
    per view so that the edge filter never mixes pixels of different views.

    Args:
        rays_o, rays_d, rgb: Per-pixel tensors indexed along their first axis.
        mask: Per-pixel silhouette values, one per ray.
        original_shape: (H, W) of a single view.
        batch_size: Number of rays to draw (with replacement).

    Returns:
        (rays_o, rays_d, rgb, mask) restricted to the sampled indices.

    Raises:
        ValueError: If the mask length is not a positive multiple of H * W.
    """
    H, W = original_shape
    pixels_per_view = H * W
    total_pixels = mask.shape[0]

    if (
        pixels_per_view <= 0
        or total_pixels == 0
        or total_pixels % pixels_per_view != 0
    ):
        raise ValueError(
            f"mask has {total_pixels} entries, which is not a positive multiple "
            f"of H*W = {pixels_per_view}"
        )

    num_views = total_pixels // pixels_per_view
    weights = torch.cat(
        [
            _edge_foreground_weights(
                mask[view * pixels_per_view : (view + 1) * pixels_per_view], H, W
            )
            for view in range(num_views)
        ]
    )

    # Normalize weights to get probabilities, then draw indices.
    p = weights / weights.sum()
    idx = torch.multinomial(p, batch_size, replacement=True)

    # Return the sampled rays and corresponding pixel data
    return rays_o[idx], rays_d[idx], rgb[idx], mask[idx]


# -----------------------------------------------------------------------------
# 7. Debug Rendering (with Extra Mask Comparison)
# -----------------------------------------------------------------------------
@torch.no_grad()
def debug_render(
    model,
    rays_o,
    rays_d,
    H,
    W,
    near=0.5,
    far=1.5,
    sigma_scale=2.0,
    N_samples=64, # Fewer samples for faster debug rendering
    device=None,
    title_prefix="debug",
    iteration=0,
    out_dir="debug_renders",
):
    """
    Render one full H x W view and save its colour and opacity images.

    Args:
        model: Network handed to render_rays; must expose eval()/train().
        rays_o, rays_d: Ray origins and directions for exactly H * W rays.
        H, W: Image height and width used to reshape the rendered output.
        near, far, sigma_scale, N_samples: Forwarded to render_rays.
        device: Unused; kept so existing callers keep working. Rendering
            happens on whatever device the ray tensors live on.
        title_prefix: Prefix of the plot titles and output file names.
        iteration: Iteration number embedded in titles and file names.
        out_dir: Directory for the PNG files (created if missing).

    Side effects:
        Writes "<title_prefix>_rgb_iter_<iteration>.png" and
        "<title_prefix>_alpha_iter_<iteration>.png" into out_dir. The model
        is switched to eval mode while rendering and then returned to the
        mode it was in when the function was called.
    """
    os.makedirs(out_dir, exist_ok=True)

    # Remember the caller's mode so that rendering from an evaluation context
    # does not silently flip the model back into training mode.
    was_training = getattr(model, "training", True)
    model.eval()

    chunk_size = 2048 # Process in chunks to avoid OOM
    rgb_chunks = []
    alpha_chunks = []
    try:
        for start in range(0, rays_o.shape[0], chunk_size):
            stop = start + chunk_size
            rgb_chunk, alpha_chunk = render_rays(
                model,
                rays_o[start:stop],
                rays_d[start:stop],
                near=near,
                far=far,
                sigma_scale=sigma_scale,
                N_samples=N_samples,
            )
            rgb_chunks.append(rgb_chunk)
            alpha_chunks.append(alpha_chunk)
    finally:
        # Restore the mode even if rendering fails part-way through.
        model.train(was_training)

    # Concatenate chunks and reshape to image dimensions
    rgb_full = torch.cat(rgb_chunks, dim=0).reshape(H, W, 3).cpu().numpy()
    alpha_full = torch.cat(alpha_chunks, dim=0).reshape(H, W).cpu().numpy()

    def save_panel(image, kind, **imshow_kwargs):
        # One 6x6 figure per image; always closed so figures do not pile up.
        fig = plt.figure(figsize=(6, 6))
        try:
            plt.imshow(image, **imshow_kwargs)
            plt.title(f"{title_prefix}_{kind}_iter_{iteration}")
            plt.axis("off")
            plt.savefig(
                os.path.join(out_dir, f"{title_prefix}_{kind}_iter_{iteration}.png")
            )
        finally:
            plt.close(fig)

    # 1. RGB image, clipped to the displayable [0, 1] range
    save_panel(np.clip(rgb_full, 0, 1), "rgb")
    # 2. Alpha map on a fixed [0, 1] grey scale
    save_panel(alpha_full, "alpha", cmap="gray", vmin=0, vmax=1)

    print(
        f"[DEBUG RENDER] Saved {title_prefix} images at iter {iteration} in {out_dir}/"
    )


@torch.no_grad()
def debug_compare_mask_and_alpha(
    model,
    rays_o,
    rays_d,
    mask, # Ground truth mask for comparison
    H,
    W,
    near=0.5,
    far=1.5,
    sigma_scale=2.0,
    N_samples=64,
    device=None,
    title_prefix="debug",
    iteration=0,
    out_dir="debug_renders",
):
    """
    Save a side-by-side image of the ground-truth mask and the predicted alpha.

    Args:
        model: Network handed to render_rays; must expose eval()/train().
        rays_o, rays_d: Ray origins and directions for exactly H * W rays.
        mask: Ground-truth silhouette with H * W entries.
        H, W: Image height and width used to reshape mask and alpha.
        near, far, sigma_scale, N_samples: Forwarded to render_rays.
        device: Unused; kept so existing callers keep working.
        title_prefix: Prefix of the plot titles and output file name.
        iteration: Iteration number embedded in titles and file name.
        out_dir: Directory for the PNG file (created if missing).

    Side effects:
        Writes "<title_prefix>_mask_vs_alpha_iter_<iteration>.png" into
        out_dir. The model is switched to eval mode while rendering and then
        returned to the mode it was in when the function was called.
    """
    os.makedirs(out_dir, exist_ok=True)

    # Remember the caller's mode instead of unconditionally forcing train().
    was_training = getattr(model, "training", True)
    model.eval()

    chunk_size = 2048
    alpha_chunks = []
    try:
        for start in range(0, rays_o.shape[0], chunk_size):
            stop = start + chunk_size
            _, alpha_chunk = render_rays( # Only need alpha
                model,
                rays_o[start:stop],
                rays_d[start:stop],
                near=near,
                far=far,
                sigma_scale=sigma_scale,
                N_samples=N_samples,
            )
            alpha_chunks.append(alpha_chunk)
    finally:
        # Restore the mode even if rendering fails part-way through.
        model.train(was_training)

    alpha_full = torch.cat(alpha_chunks, dim=0).reshape(H, W).cpu().numpy()
    # Reshape the ground-truth mask as well
    mask_gt = mask.reshape(H, W).cpu().numpy()

    panels = (
        (mask_gt, f"{title_prefix} GT Mask (iter={iteration})"),
        (alpha_full, f"{title_prefix} Predicted Alpha (iter={iteration})"),
    )

    # Plot side-by-side on a shared [0, 1] grey scale
    fig = plt.figure(figsize=(10, 5))
    try:
        for column, (image, heading) in enumerate(panels, start=1):
            plt.subplot(1, 2, column)
            plt.imshow(image, cmap="gray", vmin=0, vmax=1)
            plt.title(heading)
            plt.axis("off")

        compare_path = os.path.join(
            out_dir, f"{title_prefix}_mask_vs_alpha_iter_{iteration}.png"
        )
        plt.savefig(compare_path)
    finally:
        plt.close(fig)

    print(
        f"[DEBUG] Saved mask-vs-alpha comparison for {title_prefix} at iter {iteration} in {out_dir}/"
    )


# -----------------------------------------------------------------------------
# 8. 3D Extraction via Marching Cubes
# -----------------------------------------------------------------------------
def extract_3d_from_nerf(
    model,
    resolution=128,
    bound=1.0,
    sigma_scale=2.0,
    device=None,
    sigma_threshold=None,
    output_path="nerf_reconstruction.stl",
):
    """
    Extract a 3D mesh from the learned density field using marching cubes.

    Args:
        model: Network queried for density; must expose eval()/train().
        resolution: Number of grid points along each axis.
        bound: The grid spans [-bound, bound] on every axis.
        sigma_scale: Multiplier applied to the ReLU'd density output.
        device: Device for the query grid; defaults to the device of the
            model's first parameter.
        sigma_threshold: Iso-level for marching cubes. When None a heuristic
            of mean + 0.3 * std is used, falling back to the mean if that
            value lies outside the observed density range.
        output_path: Where the mesh is exported on success.

    Returns:
        The extracted trimesh.Trimesh, or None if extraction failed or the
        result was empty. Failures are reported on stdout, not raised.

    The model is put in eval mode while it is queried and afterwards
    returned to the mode it was in when the function was called.
    """
    print("\n[EXTRACT 3D] Running marching cubes...")
    if device is None:
        device = next(model.parameters()).device

    # Remember the caller's mode instead of unconditionally forcing train().
    was_training = getattr(model, "training", True)
    model.eval()
    try:
        # Regular grid covering the volume. 'ij' indexing keeps the volume
        # axes in (x, y, z) order for marching cubes.
        axis = torch.linspace(-bound, bound, resolution, device=device)
        coords = torch.stack(
            torch.meshgrid(axis, axis, axis, indexing="ij"), dim=-1
        ).reshape(-1, 3)

        # Query density in chunks; results go to the CPU to save GPU memory.
        sigmas = []
        chunk = 4096
        with torch.no_grad():
            for start in range(0, coords.shape[0], chunk):
                out = model(coords[start : start + chunk])
                # Apply ReLU and scale sigma as done during rendering
                sigmas.append((torch.relu(out[..., 3]) * sigma_scale).cpu())
    finally:
        model.train(was_training)

    # Reshape the flat sigma values into a 3D volume
    sigma_volume = torch.cat(sigmas).reshape(resolution, resolution, resolution).numpy()

    # Determine the iso-level for marching cubes
    vol_min, vol_max = sigma_volume.min(), sigma_volume.max()
    vol_mean, vol_std = sigma_volume.mean(), sigma_volume.std()
    print(
        f"  Sigma volume stats: min={vol_min:.4f}, max={vol_max:.4f}, mean={vol_mean:.4f}, std={vol_std:.4f}"
    )

    if sigma_threshold is None:
        # Heuristic: level slightly above the mean density
        level = vol_mean + 0.3 * vol_std
        # Ensure level is within the range of observed sigma values
        if (level <= vol_min) or (level >= vol_max):
            level = vol_mean # Fallback to mean if heuristic is out of bounds
        print(f"  Using iso-level={level:.4f} (heuristic)")
    else:
        level = sigma_threshold
        print(f"  Using iso-level={level:.4f} (user-defined)")

    mesh = None
    try:
        # Only vertices and faces are used; normals are recomputed below.
        verts, faces, _, _ = marching_cubes(sigma_volume, level=level)

        # Rescale vertices from grid coordinates [0, resolution-1] to world
        # coordinates [-bound, bound]
        verts = (verts / (resolution - 1.0)) * (2.0 * bound) - bound

        # process=False avoids trimesh's automatic clean-up on construction
        mesh = trimesh.Trimesh(vertices=verts, faces=faces, process=False)
        mesh.fix_normals()
        mesh.fill_holes() # Optional: try to fill holes
        mesh.remove_unreferenced_vertices() # Clean up

        # Check if mesh is valid after processing
        if not mesh.is_watertight:
            print("  Warning: Extracted mesh is not watertight.")
        if len(mesh.vertices) == 0 or len(mesh.faces) == 0:
            print("  Warning: Marching cubes resulted in an empty mesh.")
            mesh = None # Treat as failure
        else:
            mesh.export(output_path)
            print(f"  --> Saved mesh ({len(mesh.vertices)} verts, {len(mesh.faces)} faces) to {output_path}")

    except Exception as e:
        # Deliberately best-effort: a failed extraction (e.g. an iso-level
        # outside the volume's range) is reported and signalled with None.
        print(f"  Marching cubes error: {e}")
        print(f"  Failed to extract mesh at level {level:.4f}.")
        mesh = None # Ensure mesh is None on failure

    return mesh


# -----------------------------------------------------------------------------
# 9. Chamfer Distance
# -----------------------------------------------------------------------------
def chamfer_distance(points1, points2, num_points_sample=None):
    """
    Symmetric Chamfer distance between two point clouds.

    points1, points2: numpy arrays of shape (N, 3) and (M, 3).
    num_points_sample: if truthy, any cloud with more points than this is
        randomly subsampled (without replacement) down to that many points.

    Returns the sum of the two mean squared nearest-neighbour distances
    (points2 -> points1 and points1 -> points2), or float('inf') when either
    cloud is empty.
    """
    if num_points_sample:
        reduced = []
        for cloud in (points1, points2):
            count = cloud.shape[0]
            if count > num_points_sample:
                keep = np.random.choice(count, num_points_sample, replace=False)
                cloud = cloud[keep]
            reduced.append(cloud)
        points1, points2 = reduced

    if 0 in (points1.shape[0], points2.shape[0]):
        print("[Chamfer] Warning: One or both point clouds are empty.")
        return float('inf')

    # KD-trees give efficient nearest-neighbour lookups in each direction.
    from_second_to_first, _ = cKDTree(points1).query(points2)
    from_first_to_second, _ = cKDTree(points2).query(points1)

    return np.mean(from_second_to_first**2) + np.mean(from_first_to_second**2)


# -----------------------------------------------------------------------------
# 10. Plot Meshes
# -----------------------------------------------------------------------------
def plot_meshes(
    gt_mesh, pred_mesh, outpath="mesh_comparison.png", title="Mesh Comparison"
):
    """
    Save a two-panel 3D figure: ground-truth mesh (left) and predicted mesh (right).

    Each mesh is centred on its vertex mean and scaled by its largest vertex
    distance from that centre before plotting, so both panels share the
    [-1, 1] axis limits. Nothing is plotted (a message is printed instead) if
    either mesh is None, has no vertices, or has no faces.
    """
    from mpl_toolkits.mplot3d import Axes3D # Local import for plotting

    meshes = (gt_mesh, pred_mesh)
    if any(m is None for m in meshes):
        print("[PLOT] Cannot plot, one or both meshes are None.")
        return
    if any(len(m.vertices) == 0 for m in meshes):
        print("[PLOT] Cannot plot, one or both meshes have no vertices.")
        return
    if any(len(m.faces) == 0 for m in meshes):
        # plot_trisurf needs triangles, so face-less meshes cannot be drawn.
        print("[PLOT] Cannot plot, one or both meshes have no faces (needed for plot_trisurf).")
        return

    # (subplot position, mesh, surface colour, panel title)
    panels = (
        (121, gt_mesh, "blue", "Ground Truth Mesh"),
        (122, pred_mesh, "red", "Predicted Mesh (NeRF)"),
    )

    fig = plt.figure(figsize=(12, 6)) # Wide figure for the two panels
    for position, mesh, colour, heading in panels:
        ax = fig.add_subplot(position, projection="3d")

        # Centre on the vertex mean, then scale by the farthest vertex.
        centred = mesh.vertices - mesh.vertices.mean(axis=0)
        radius = np.max(np.linalg.norm(centred, axis=1))
        normalised = centred / radius if radius > 1e-6 else centred

        ax.plot_trisurf(
            normalised[:, 0],
            normalised[:, 1],
            normalised[:, 2],
            triangles=mesh.faces,
            color=colour,
            alpha=0.7,
            edgecolor='gray', # Thin edges for clarity
            linewidth=0.1
        )
        ax.set_title(heading)
        ax.set_box_aspect([1, 1, 1]) # Equal aspect ratio
        ax.set_xlabel("X")
        ax.set_ylabel("Y")
        ax.set_zlabel("Z")
        # Same bounds on both panels
        ax.set_xlim([-1, 1])
        ax.set_ylim([-1, 1])
        ax.set_zlim([-1, 1])

    plt.suptitle(title)
    plt.tight_layout(rect=[0, 0.03, 1, 0.95]) # Leave room for the suptitle
    plt.savefig(outpath)
    plt.close()
    print(f"[PLOT] Saved mesh comparison to {outpath}")


# -----------------------------------------------------------------------------
# 11. Training Loop (with Additional Debug)
# -----------------------------------------------------------------------------
def train_nerf(
    model,
    rays_o_all,
    rays_d_all,
    target_pixels_all,
    mask_all,
    image_shape, # Tuple (H, W)
    num_iterations=8000,
    device=None,
    near=0.5,
    far=1.5,
    sigma_scale=2.0,
    debug_interval=1000,
    out_dir="debug_renders",
    lr=5e-4,
    batch_size=1024,
    N_samples_train=64, # Samples per ray during training
    N_samples_debug=128, # Samples per ray for debug renders
    lambda_photo=1.0, # Weight for RGB photo loss
    lambda_sil=1.0,   # Weight for silhouette loss
    lambda_shape=1e-3, # Weight for spherical prior loss
    lambda_density=0.2, # Weight for foreground density loss
    lambda_smooth=0.1, # Weight for smoothness loss
    desired_radius=0.6 # Target radius for spherical prior
):
    """
    Train the NeRF model on the provided rays, target pixels and masks.

    Features:
    - Weighted ray sampling (emphasizing edges and foreground).
    - Combined loss: photo, silhouette, spherical prior, density, smoothness.
    - Periodic debug rendering (RGB, alpha, mask comparison) and checkpoints
      written to out_dir every debug_interval iterations (disabled when
      debug_interval <= 0).
    - ReduceLROnPlateau scheduling driven by the per-batch total loss.
    - AMP (Automatic Mixed Precision), enabled only when training on CUDA.
    - Every 200 iterations the current total (mini-batch) loss is compared
      with the best seen so far; on improvement the weights are saved to
      "nerf_best_model.pth" in the working directory.

    Args:
        model: Network to optimise (modified in place).
        rays_o_all, rays_d_all, target_pixels_all, mask_all: Per-ray tensors;
            the first H*W entries are treated as the front view and the next
            H*W entries as the side view for debug rendering.
        image_shape: (H, W) of a single view.
        device: Training device; defaults to the device of the model's first
            parameter.
        Remaining arguments: sampling bounds, loss weights and logging
            options as named.

    Returns:
        The model, carrying the best weights saved during this run (or the
        final weights if none were saved, e.g. fewer than 200 iterations).
    """
    H, W = image_shape
    pixels_per_view = H * W
    best_model_path = "nerf_best_model.pth"

    if device is None:
        device = next(model.parameters()).device
    # Mixed precision is only meaningful (and only supported by the CUDA
    # GradScaler) when the model actually trains on a CUDA device - not
    # merely when CUDA happens to be installed.
    use_amp = torch.device(device).type == "cuda"

    optimizer = optim.Adam(model.parameters(), lr=lr)
    # Scheduler halves the LR if the loss plateaus. The deprecated `verbose`
    # flag is not passed (recent PyTorch releases may reject it); LR drops
    # are reported explicitly in the loop below instead.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=500
    )
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    print("[TRAIN] Loss weights:")
    print(f"  Photo:    {lambda_photo}")
    print(f"  Sil:      {lambda_sil}")
    print(f"  Shape:    {lambda_shape}")
    print(f"  Density:  {lambda_density}")
    print(f"  Smooth:   {lambda_smooth}")

    best_loss = float("inf")
    # Tracks whether THIS run wrote the best-model file, so a stale file left
    # behind by an earlier run is never loaded at the end.
    best_saved = False
    print("\n[TRAIN] Starting training...")

    # --- Training Loop ---
    t_range = trange(num_iterations, desc="Training", unit="iter")
    for i in t_range:
        model.train() # Ensure model is in training mode
        optimizer.zero_grad()

        # Sample a batch of rays using weighted strategy
        rays_o_batch, rays_d_batch, rgb_batch, mask_batch = sample_rays_weighted(
            rays_o_all,
            rays_d_all,
            target_pixels_all,
            mask_all,
            original_shape=(H, W),
            batch_size=batch_size,
        )

        with torch.cuda.amp.autocast(enabled=use_amp):
            # Render rays for the sampled batch
            rgb_map, alpha_map = render_rays(
                model,
                rays_o_batch,
                rays_d_batch,
                near=near,
                far=far,
                sigma_scale=sigma_scale,
                N_samples=N_samples_train,
            )

            # 1. Photometric loss over all sampled pixels (not only the
            #    foreground ones).
            photo_loss = torch.mean((rgb_map - rgb_batch) ** 2)

            # 2. Silhouette loss (alpha vs mask)
            sil_loss_val = silhouette_loss(alpha_map, mask_batch)

            # 3. Spherical shape prior (fewer samples are fine per step)
            shape_loss_val = spherical_prior_loss(
                model,
                num_samples=1024,
                bound=1.0,
                desired_radius=desired_radius,
                sigma_scale=sigma_scale,
                device=device,
            )

            # 4. Foreground density loss
            dens_loss_val = foreground_density_loss(
                alpha_map, mask_batch, target_density=1.0
            )

            # 5. Smoothness prior on the density field
            smooth_loss_val = smoothness_prior_loss(
                model,
                num_samples=1024,
                bound=1.0,
                offset=0.02,
                sigma_scale=sigma_scale,
                device=device,
            )

            # Combine losses with their respective weights
            total_loss = (
                lambda_photo * photo_loss
                + lambda_sil * sil_loss_val
                + lambda_shape * shape_loss_val
                + lambda_density * dens_loss_val
                + lambda_smooth * smooth_loss_val
            )

        # Backpropagation using AMP scaler
        scaler.scale(total_loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Feed the scheduler a plain float rather than a graph-attached tensor.
        loss_value = total_loss.item()
        lr_before = optimizer.param_groups[0]["lr"]
        scheduler.step(loss_value)
        lr_after = optimizer.param_groups[0]["lr"]
        if lr_after < lr_before:
            print(f"\n[TRAIN @ Iter {i + 1}] Reducing learning rate to {lr_after:.4e}.")

        # --- Logging ---
        if (i + 1) % 200 == 0:
            t_range.set_postfix(
                loss=f"{loss_value:.4f}",
                photo=f"{photo_loss.item():.4f}",
                sil=f"{sil_loss_val.item():.4f}",
                shape=f"{shape_loss_val.item():.4f}",
                dens=f"{dens_loss_val.item():.4f}",
                smooth=f"{smooth_loss_val.item():.4f}",
                lr=f"{optimizer.param_groups[0]['lr']:.1e}"
            )
            # Save best model based on total loss
            if loss_value < best_loss:
                best_loss = loss_value
                torch.save(model.state_dict(), best_model_path)
                best_saved = True

        # --- Debug Rendering & Checkpointing ---
        if debug_interval > 0 and (i + 1) % debug_interval == 0:
            print(f"\n[DEBUG @ Iter {i + 1}] Rendering debug images...")
            model.eval() # Switch to eval mode for rendering

            # View k is assumed to occupy rays [k*H*W, (k+1)*H*W); a view is
            # rendered only when the input actually contains it.
            for view_idx, prefix in enumerate(("front", "side")):
                start = view_idx * pixels_per_view
                stop = start + pixels_per_view
                if rays_o_all.shape[0] < stop:
                    continue

                view_rays_o = rays_o_all[start:stop]
                view_rays_d = rays_d_all[start:stop]
                view_mask = mask_all[start:stop]

                debug_render(
                    model, view_rays_o, view_rays_d, H, W,
                    near=near, far=far, sigma_scale=sigma_scale, N_samples=N_samples_debug,
                    device=device, title_prefix=prefix, iteration=i + 1, out_dir=out_dir,
                )
                debug_compare_mask_and_alpha(
                    model, view_rays_o, view_rays_d, view_mask, H, W,
                    near=near, far=far, sigma_scale=sigma_scale, N_samples=N_samples_debug,
                    device=device, title_prefix=prefix, iteration=i + 1, out_dir=out_dir,
                )

            # Save checkpoint. The directory is created here as well, because
            # no debug view may have been rendered (and created it) above.
            os.makedirs(out_dir, exist_ok=True)
            ckpt_path = os.path.join(out_dir, f"nerf_checkpoint_{i + 1}.pth")
            torch.save(model.state_dict(), ckpt_path)
            print(f"[TRAIN] Saved checkpoint {ckpt_path}\n")
            model.train() # Switch back to train mode

    print("\n[TRAIN] Training finished.")
    # Load the best performing model of this run before returning
    if best_saved:
        print("[TRAIN] Loading best model state dict.")
        model.load_state_dict(torch.load(best_model_path, map_location=device))
    return model


# -----------------------------------------------------------------------------
# 12. Main Execution Block
# -----------------------------------------------------------------------------
def sample_sphere_points(num_points, radius=0.6, rng=None):
    """
    Sample points uniformly distributed over the surface of a sphere.

    The azimuth is drawn uniformly from [0, 2*pi) and the cosine of the polar
    angle uniformly from [-1, 1]; together these give an area-uniform
    distribution that covers the whole sphere.

    Args:
        num_points: Number of points to draw.
        radius: Sphere radius.
        rng: Optional object exposing ``uniform(low, high, size)`` (for example
            ``np.random.default_rng(seed)``). Defaults to the global
            ``np.random`` state so ``np.random.seed`` keeps working.

    Returns:
        ``(num_points, 3)`` float array of xyz coordinates.
    """
    rng = np.random if rng is None else rng
    azimuth = rng.uniform(0.0, 2.0 * np.pi, num_points)
    cos_polar = rng.uniform(-1.0, 1.0, num_points)
    # Clip guards against tiny negative values from floating point rounding.
    sin_polar = np.sqrt(np.clip(1.0 - cos_polar ** 2, 0.0, None))
    unit = np.stack(
        [sin_polar * np.cos(azimuth), sin_polar * np.sin(azimuth), cos_polar],
        axis=-1,
    )
    return radius * unit


def to_rgb_unit_range(img):
    """
    Return ``img`` as a 3-channel (3, H, W) tensor with values in [0, 1].

    A 2-D (H, W) or single-channel (1, H, W) image is repeated across three
    channels. If the maximum value exceeds 1.0 the image is assumed to be in
    [0, 255] and is divided by 255. The input tensor is never modified in
    place, so tensors still owned by a dataset are left intact.
    """
    if img.ndim == 2:
        img = img.unsqueeze(0)
    if img.ndim == 3 and img.shape[0] == 1:
        img = img.repeat(3, 1, 1)
    if img.max() > 1.0:
        img = img / 255.0
    return img


def make_silhouette_mask(img, threshold=0.2, blur_kernel_size=5):
    """
    Build a flattened binary silhouette mask from a (C, H, W) image.

    The image is averaged over channels, box-blurred with an average pool
    (stride 1, zero padding of ``blur_kernel_size // 2``) to suppress noise,
    and thresholded.

    Returns:
        1-D float tensor of 0/1 values (length H*W for odd kernel sizes).
    """
    gray = img.mean(dim=0, keepdim=True)
    blurred = torch.nn.functional.avg_pool2d(
        gray.unsqueeze(0), blur_kernel_size, stride=1, padding=blur_kernel_size // 2
    )[0, 0]
    return (blurred > threshold).float().reshape(-1)


def save_mesh_render(mesh, out_path, resolution=(600, 600)):
    """
    Render ``mesh`` off-screen with trimesh and write the PNG to ``out_path``.

    Rendering is best effort: it depends on an available OpenGL/display
    backend, so any failure is reported and swallowed instead of aborting the
    script after a long training run.
    """
    try:
        scene = trimesh.Scene(mesh)
        # NOTE(review): kept from the original code; confirm that trimesh's
        # camera actually honours an `elevation` attribute.
        scene.camera.elevation = np.radians(30)
        png = scene.save_image(resolution=resolution)
        with open(out_path, "wb") as f:
            f.write(png)
        print(f"[PLOT] Saved predicted mesh render to {out_path}")
    except Exception as e:
        print(f"[PLOT] Failed to render predicted mesh: {e}")


if __name__ == "__main__":
    # --- Configuration ---
    DATA_AVAILABLE = PollenDataset is not None and get_train_test_split is not None
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    IMAGE_SIZE = 128  # Assumed image size if dataset not loaded
    FOCAL_LENGTH = 300.0  # Example focal length
    NEAR_BOUND = 0.4  # Near plane for rendering
    FAR_BOUND = 1.6  # Far plane for rendering
    SIGMA_SCALE_TRAIN = 1.0  # Multiplier for sigma during training
    SIGMA_SCALE_EXTRACT = 2.0  # Multiplier for sigma during mesh extraction (can differ)
    MESH_RESOLUTION = 192  # Resolution for marching cubes grid
    MESH_BOUND = 1.0  # Spatial bounds for mesh extraction [-bound, bound]
    SIGMA_THRESHOLD_EXTRACT = None  # Or set a specific float value, e.g. 5.0
    NUM_ITERATIONS = 10000  # Training iterations
    DEBUG_INTERVAL = 2000  # How often to save debug images/checkpoints (0 to disable)
    DEBUG_DIR = "debug_nerf_pollen"

    print(f"[SYS] Using device: {DEVICE}")

    # Create the output directory before anything is written into it; the
    # input image / mask previews below are saved long before training starts.
    os.makedirs(DEBUG_DIR, exist_ok=True)

    # --- 1. Load Data (if available) ---
    if DATA_AVAILABLE:
        print("[DATA] Loading PollenDataset...")
        image_transform = transforms.Compose([
            transforms.ToTensor(),
            # Add resize if needed: transforms.Resize((IMAGE_SIZE, IMAGE_SIZE))
        ])
        # Assuming get_train_test_split returns dataset, train_ids, test_ids
        # And dataset provides: (left_img, right_img), gt_mesh_data, rotations, voxels
        # where gt_mesh_data could be points, mesh path, etc.
        dataset, train_ids, test_ids = get_train_test_split(
            image_transforms=image_transform,
            mesh_transforms=None,  # Assuming mesh is loaded directly if needed
            device=DEVICE,
            # Add base_dir if your dataset requires it
            # base_dir="path/to/your/pollen/data"
        )

        # Select a sample (e.g., the first training sample)
        sample_idx = train_ids[0]
        # Unpack the data - ADJUST THIS BASED ON YOUR PollenDataset.__getitem__
        (left_img, right_img), gt_mesh_data, rotations, voxels = dataset[sample_idx]

        # --- Ground Truth Mesh Handling ---
        # gt_mesh_data may be a mesh path, a point tensor or a trimesh object.
        gt_mesh = None
        gt_points = None
        if isinstance(gt_mesh_data, str) and os.path.exists(gt_mesh_data):
            try:
                gt_mesh = trimesh.load(gt_mesh_data, process=False)
                gt_points = gt_mesh.sample(5000)  # Sample points for Chamfer
                print(f"[DATA] Loaded GT mesh: {gt_mesh_data} ({len(gt_mesh.vertices)} verts, {len(gt_mesh.faces)} faces)")
            except Exception as e:
                # Best effort: training can proceed without ground truth.
                print(f"[DATA] Warning: Could not load GT mesh from {gt_mesh_data}: {e}")
        elif isinstance(gt_mesh_data, torch.Tensor):  # Assuming it's points if Tensor
            gt_points = gt_mesh_data.detach().cpu().numpy()
            print(f"[DATA] Loaded GT points: {gt_points.shape}")
        elif isinstance(gt_mesh_data, trimesh.Trimesh):
            gt_mesh = gt_mesh_data
            gt_points = gt_mesh.sample(5000)
            print("[DATA] Using provided GT trimesh object.")

        print(
            f"[DATA] Loaded sample #{sample_idx} -> "
            f"Images: L={left_img.shape}, R={right_img.shape}; "
            f"GT Points: {gt_points.shape if gt_points is not None else 'None'}; "
            f"Rotations (rad): {rotations.tolist()}; "
        )
        # Use the trailing two dims so 2-D (H, W) images work as well as (C, H, W).
        H, W = left_img.shape[-2:]
        print(f"[DATA] Image dimensions: H={H}, W={W}")

        # --- 2. Preprocess Images ---
        # Move to the target device as float, force 3 channels, scale to [0, 1].
        left_img = to_rgb_unit_range(left_img.to(DEVICE).float())
        right_img = to_rgb_unit_range(right_img.to(DEVICE).float())
        rotations = rotations.to(DEVICE).float()

        # Save a preview of the input images
        input_images_path = os.path.join(DEBUG_DIR, "input_images.png")
        plt.figure(figsize=(8, 4))
        plt.subplot(1, 2, 1)
        plt.imshow(left_img.permute(1, 2, 0).cpu().clamp(0, 1))
        plt.title("Left Input Image")
        plt.axis("off")
        plt.subplot(1, 2, 2)
        plt.imshow(right_img.permute(1, 2, 0).cpu().clamp(0, 1))
        plt.title("Right Input Image")
        plt.axis("off")
        plt.tight_layout()
        plt.savefig(input_images_path)
        plt.close()
        print(f"[DATA] Saved input images view to {input_images_path}")

        # --- 3. Create Masks (Silhouettes) ---
        # Grayscale -> light blur -> threshold (adjust 0.2 if needed)
        silhouette_threshold = 0.2
        blur_kernel_size = 5
        left_mask = make_silhouette_mask(left_img, silhouette_threshold, blur_kernel_size)
        right_mask = make_silhouette_mask(right_img, silhouette_threshold, blur_kernel_size)

        # Save a preview of the masks
        input_masks_path = os.path.join(DEBUG_DIR, "input_masks.png")
        plt.figure(figsize=(8, 4))
        plt.subplot(1, 2, 1)
        plt.imshow(left_mask.reshape(H, W).cpu(), cmap='gray')
        plt.title("Left Mask")
        plt.axis("off")
        plt.subplot(1, 2, 2)
        plt.imshow(right_mask.reshape(H, W).cpu(), cmap='gray')
        plt.title("Right Mask")
        plt.axis("off")
        plt.tight_layout()
        plt.savefig(input_masks_path)
        plt.close()
        print(f"[DATA] Saved input masks view to {input_masks_path}")

    else:
        # --- Fallback: Create dummy data if dataset not loaded ---
        print("[DATA] PollenDataset not found. Creating dummy data...")
        H, W = IMAGE_SIZE, IMAGE_SIZE
        # Dummy images (white square on black background)
        left_img = torch.zeros(3, H, W, device=DEVICE)
        right_img = torch.zeros(3, H, W, device=DEVICE)
        left_img[:, H // 4:3 * H // 4, W // 4:3 * W // 4] = 1.0
        right_img[:, H // 4:3 * H // 4, W // 4:3 * W // 4] = 1.0
        # Dummy masks
        left_mask = (left_img.mean(0) > 0.1).float().reshape(-1)
        right_mask = (right_img.mean(0) > 0.1).float().reshape(-1)
        # Dummy rotations (identity)
        rotations = torch.zeros(3, device=DEVICE)
        # Dummy GT points: uniform samples on a full sphere of radius 0.6
        gt_points = sample_sphere_points(5000, radius=0.6)
        gt_mesh = None  # No dummy mesh

    # --- 4. Prepare Ray Data (shared by the real and dummy data paths) ---
    # Flatten target image pixels to (N, 3)
    left_img_flat = left_img.permute(1, 2, 0).reshape(-1, 3)
    right_img_flat = right_img.permute(1, 2, 0).reshape(-1, 3)

    # Generate canonical (unrotated) rays for front and side views
    rays_o_f, rays_d_f, rays_o_s, rays_d_s = get_rays(H, W, focal=FOCAL_LENGTH)
    rays_o_f, rays_d_f = rays_o_f.to(DEVICE), rays_d_f.to(DEVICE)
    rays_o_s, rays_d_s = rays_o_s.to(DEVICE), rays_d_s.to(DEVICE)

    # Rotate rays according to the sample's specific rotation
    # Assuming 'left' view corresponds to 'front' canonical rays
    # Assuming 'right' view corresponds to 'side' canonical rays
    rays_o_left, rays_d_left = rotate_rays(rays_o_f, rays_d_f, rotations)
    rays_o_right, rays_d_right = rotate_rays(rays_o_s, rays_d_s, rotations)

    # Concatenate data from both views for training. The left/front view goes
    # first: train_nerf's debug rendering treats the first H*W rays as "front"
    # and the next H*W as "side".
    rays_o_all = torch.cat([rays_o_left, rays_o_right], dim=0)
    rays_d_all = torch.cat([rays_d_left, rays_d_right], dim=0)
    target_pixels_all = torch.cat([left_img_flat, right_img_flat], dim=0)
    mask_all = torch.cat([left_mask, right_mask], dim=0)

    print("[DATA] Prepared training data shapes:")
    print(f"  rays_o_all: {rays_o_all.shape}")
    print(f"  rays_d_all: {rays_d_all.shape}")
    print(f"  target_pixels_all: {target_pixels_all.shape}")
    print(f"  mask_all: {mask_all.shape}")

    # --- 5. Initialize NeRF Model ---
    model = NeRF(D=6, W=128, L=4).to(DEVICE)  # Keep parameters relatively small
    print(f"[MODEL] Initialized NeRF (D=6, W=128, L=4) on {DEVICE}.")

    # --- 6. Train the Model ---
    model = train_nerf(
        model,
        rays_o_all,
        rays_d_all,
        target_pixels_all,
        mask_all,
        image_shape=(H, W),
        num_iterations=NUM_ITERATIONS,
        device=DEVICE,
        near=NEAR_BOUND,
        far=FAR_BOUND,
        sigma_scale=SIGMA_SCALE_TRAIN,
        debug_interval=DEBUG_INTERVAL,
        out_dir=DEBUG_DIR,
        # Can adjust other train_nerf parameters here if needed
        # lr=5e-4, batch_size=1024, lambda_sil=1.0, etc.
    )

    # --- 7. Extract 3D Mesh ---
    print("\n[RESULT] Extracting final mesh...")
    pred_mesh = extract_3d_from_nerf(
        model,
        resolution=MESH_RESOLUTION,
        bound=MESH_BOUND,
        sigma_scale=SIGMA_SCALE_EXTRACT,  # Potentially different scale for extraction
        device=DEVICE,
        sigma_threshold=SIGMA_THRESHOLD_EXTRACT,  # Use threshold if defined
    )

    # --- 8. Evaluate (Chamfer Distance) & Visualize ---
    pred_render_path = os.path.join(DEBUG_DIR, "final_predicted_mesh.png")
    if pred_mesh is None:
        print("[RESULT] Mesh extraction failed. No evaluation possible.")
    elif gt_points is None:
        print("[RESULT] Predicted mesh extracted, but no GT points for comparison.")
        save_mesh_render(pred_mesh, pred_render_path)
    else:
        print("[RESULT] Calculating Chamfer Distance...")
        pred_points = pred_mesh.sample(5000)  # Sample points from predicted mesh
        # Ensure gt_points is a numpy array on CPU
        if isinstance(gt_points, torch.Tensor):
            gt_points_np = gt_points.cpu().numpy()
        else:
            gt_points_np = gt_points

        cd = chamfer_distance(pred_points, gt_points_np)
        print(f"  --> Chamfer Distance (Pred Mesh vs. GT Points): {cd:.6f}")

        if gt_mesh is not None:
            # Side-by-side comparison is only possible with a GT mesh
            plot_meshes(gt_mesh, pred_mesh, outpath=os.path.join(DEBUG_DIR, "final_mesh_comparison.png"))
        else:
            # Just save the predicted mesh render if no GT mesh
            save_mesh_render(pred_mesh, pred_render_path)

    print("\n[DONE] Script finished.")

Device: cuda


FileNotFoundError: [WinError 3] The system cannot find the path specified: 'C:/Users/super/Documents/GitHub/sequoia/data\\processed\\images'