# In[ ]:


import os
import sys
import logging
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from tqdm import trange
from scipy.spatial import cKDTree
from skimage.measure import marching_cubes
import trimesh

# Make the parent directory importable so the optional `data` package can be found.
sys.path.append("..")

# The dataset helpers are optional: when `data.pollen_dataset` cannot be imported,
# both names are bound to None and the model/rendering code below still loads.
try:
    from data.pollen_dataset import PollenDataset, get_train_test_split
except ImportError:
    PollenDataset = get_train_test_split = None

# Let cuDNN auto-tune convolution algorithms.
torch.backends.cudnn.benchmark = True
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# -----------------------------------------------------------------------------
# 1. Positional Encoding (Reduced Frequencies)
# -----------------------------------------------------------------------------
def positional_encoding(x, L=4):
    """Concatenate ``x`` with sin/cos features at ``L`` octave-spaced frequencies.

    For k = 0 .. L-1 the features sin(2^k * pi * x) and cos(2^k * pi * x) are
    appended along the last dimension (lowest frequency first, sin before cos),
    so the last dimension of the result is ``x.shape[-1] * (2 * L + 1)``.
    """
    features = [x]
    for k in range(L):
        scaled = (2.0 ** k) * np.pi * x
        features += [torch.sin(scaled), torch.cos(scaled)]
    return torch.cat(features, dim=-1)


# -----------------------------------------------------------------------------
# 2. NeRF Model
# -----------------------------------------------------------------------------
class NeRF(nn.Module):
    """Coordinate MLP used as the radiance/density field.

    Inputs are positionally encoded with ``L`` frequencies; the first layer expects
    ``3 * (2L + 1)`` features, so inputs are expected to have last dimension 3.
    The encoding goes through ``max(D, 1)`` ReLU-activated linear layers of width
    ``W``, then a final linear layer (no activation) produces 4 raw outputs.
    """

    def __init__(self, D=6, W=128, L=4):
        super().__init__()
        self.L = L
        enc_width = 3 * (2 * L + 1)
        hidden = [nn.Linear(enc_width, W)]
        hidden.extend(nn.Linear(W, W) for _ in range(D - 1))
        self.layers = nn.ModuleList(hidden)
        self.output_layer = nn.Linear(W, 4)
        # Output channel 3 is read as density through a ReLU elsewhere in this file.
        # Presumably the positive bias keeps that density non-zero at initialisation.
        with torch.no_grad():
            self.output_layer.bias[3] = 0.1

    def forward(self, x):
        """Return the raw (…, 4) network output for points ``x``."""
        h = positional_encoding(x, self.L)
        for layer in self.layers:
            h = torch.relu(layer(h))
        return self.output_layer(h)


# -----------------------------------------------------------------------------
# 3. Render Rays
# -----------------------------------------------------------------------------
def render_rays(
    model, rays_o, rays_d, near=0.5, far=1.5, N_samples=128, sigma_scale=1.0
):
    """Volume-render a batch of rays using evenly spaced samples.

    Args:
        model: callable mapping (M, 3) points to (M, 4) raw outputs. Channels 0-2
            are turned into colour with a sigmoid; channel 3 becomes density via a
            ReLU multiplied by ``sigma_scale``.
        rays_o, rays_d: ray origins and directions, one ray per row
            (presumably (N, 3) — TODO confirm for callers outside this file).
        near, far: range of the ray parameter t that is sampled.
        N_samples: number of evenly spaced t values in [near, far], shared by all rays.
        sigma_scale: multiplier applied to the density.

    Returns:
        (rgb_map, alpha_map): per-ray colour composited over black (no background
        term is added), and per-ray accumulated opacity.

    Sample spacing is measured in t and is not scaled by ``|rays_d|``. Pass
    unit-length directions for spacing in world units. The final interval is set
    to 1e10, so the last sample absorbs the remaining transmittance wherever its
    density is non-negligible.
    """
    dev = rays_o.device
    t_vals = torch.linspace(near, far, N_samples, device=dev)
    samples = rays_o.unsqueeze(1) + rays_d.unsqueeze(1) * t_vals.view(1, -1, 1)
    n_rays = samples.shape[0]
    raw_out = model(samples.reshape(-1, 3)).reshape(n_rays, N_samples, 4)
    colour = torch.sigmoid(raw_out[..., :3])
    density = torch.relu(raw_out[..., 3]) * sigma_scale

    # Distance between consecutive samples; the last interval is effectively infinite.
    gaps = torch.cat([t_vals[1:] - t_vals[:-1], torch.tensor([1e10], device=dev)])
    gaps = gaps.unsqueeze(0).expand(density.shape)
    opacity = 1.0 - torch.exp(-density * gaps)

    # Exclusive cumulative product = transmittance reaching each sample.
    # The +1e-10 stops a fully opaque sample from zeroing the product outright.
    ones = torch.ones((density.shape[0], 1), device=dev)
    survive = torch.cat([ones, 1 - opacity + 1e-10], dim=-1)
    transmittance = torch.cumprod(survive, dim=-1)[:, :-1]

    w = opacity * transmittance
    rgb_map = (w.unsqueeze(-1) * colour).sum(dim=1)
    alpha_map = w.sum(dim=1)
    return rgb_map, alpha_map


# -----------------------------------------------------------------------------
# 4. Losses
# -----------------------------------------------------------------------------
def silhouette_loss(alpha, mask):
    """Mean squared error between rendered opacity ``alpha`` and the target ``mask``."""
    residual = alpha - mask
    return residual.pow(2).mean()


def spherical_prior_loss(
    model, num_samples=2000, bound=1.0, desired_radius=0.6, sigma_scale=2.0, device=None
):
    """Penalise density that lies away from a sphere of radius ``desired_radius``.

    Draws ``num_samples`` points uniformly from the cube [-bound, bound)^3 and
    returns mean(sigma * (|p| - desired_radius)^2), where
    sigma = relu(model(p)[..., 3]) * sigma_scale. The penalty on density grows
    quadratically with distance from the shell. ``device`` defaults to the device
    of the model's first parameter.
    """
    device = device if device is not None else next(model.parameters()).device
    span = 2 * bound
    pts = torch.rand(num_samples, 3, device=device) * span - bound
    density = torch.relu(model(pts)[..., 3]) * sigma_scale
    off_shell = (pts.norm(dim=1) - desired_radius) ** 2
    return (density * off_shell).mean()


def foreground_density_loss(alpha, mask, target_density=1.0):
    """Hinge loss pushing foreground optical depth up to ``target_density``.

    For rays with ``mask > 0.5`` the optical depth is -log(1 - alpha + 1e-6).
    The loss is mean(max(target_density - depth, 0)) over those rays. Returns a
    zero tensor on ``alpha``'s device when no ray is foreground.
    """
    fg = mask > 0.5
    if not fg.any():
        return torch.tensor(0.0, device=alpha.device)
    optical_depth = -torch.log(1 - alpha[fg] + 1e-6)
    return torch.clamp(target_density - optical_depth, min=0.0).mean()


def smoothness_prior_loss(
    model, num_samples=2000, bound=1.0, offset=0.01, sigma_scale=2.0, device=None
):
    """Finite-difference smoothness penalty on the density field.

    Samples points uniformly in [-bound, bound)^3. For each of the six axis-aligned
    shifts (+x, -x, +y, -y, +z, -z) of size ``offset``, it takes the mean squared
    difference in density (relu(channel 3) * sigma_scale) between each point and
    its shifted copy. The six means are averaged. ``device`` defaults to the
    device of the model's first parameter.
    """
    device = device if device is not None else next(model.parameters()).device
    pts = torch.rand(num_samples, 3, device=device) * (2 * bound) - bound
    base = torch.relu(model(pts)[..., 3]) * sigma_scale
    shifts = torch.tensor(
        [
            [sign * offset if k == axis else 0 for k in range(3)]
            for axis in range(3)
            for sign in (1, -1)
        ],
        device=device,
    )
    penalties = [
        torch.mean((base - torch.relu(model(pts + s)[..., 3]) * sigma_scale) ** 2)
        for s in shifts
    ]
    return sum(penalties) / len(penalties)


# New strong priors:
def radial_profile_loss(
    model,
    num_samples=5000,
    bound=1.0,
    desired_radius=0.6,
    sigma_scale=2.0,
    width=0.05,
    device=None,
):
    """Regress density onto a Gaussian shell profile.

    For points drawn uniformly from [-bound, bound)^3, the target is
    exp(-0.5 * ((|p| - desired_radius) / width)^2). The result is the mean squared
    error between that target and relu(model(p)[..., 3]) * sigma_scale.
    ``device`` defaults to the device of the model's first parameter.
    """
    device = device if device is not None else next(model.parameters()).device
    pts = (torch.rand(num_samples, 3, device=device) * 2 - 1) * bound
    density = torch.relu(model(pts)[..., 3]) * sigma_scale
    radius = pts.norm(dim=1)
    shell = torch.exp(-0.5 * ((radius - desired_radius) / width) ** 2)
    return ((density - shell) ** 2).mean()


def symmetry_loss(model, num_samples=5000, bound=1.0, sigma_scale=2.0, device=None):
    """Mirror-symmetry penalty on the density field.

    Compares density (relu(channel 3) * sigma_scale) at random points in
    [-bound, bound)^3 with density at the points mirrored across each coordinate
    plane (x, y, then z negated). The three mean squared differences are averaged.
    ``device`` defaults to the device of the model's first parameter.
    """
    device = device if device is not None else next(model.parameters()).device
    pts = (torch.rand(num_samples, 3, device=device) * 2 - 1) * bound

    def density_at(points):
        return torch.relu(model(points)[..., 3]) * sigma_scale

    base = density_at(pts)
    # Row k is all ones except -1 in column k, i.e. a reflection that negates axis k.
    flips = 1.0 - 2.0 * torch.eye(3, device=device)
    terms = [torch.mean((base - density_at(pts * flip)) ** 2) for flip in flips]
    return sum(terms) / len(terms)


# -----------------------------------------------------------------------------
# 5. Rays & Rotation
# -----------------------------------------------------------------------------
import torch
from trimesh.transformations import euler_matrix


def get_rays(H, W, focal=300.0):
    """Build camera-frame rays for an H x W pinhole image.

    Pixel (x, y) maps to the direction ((x - W/2)/focal, -(y - H/2)/focal, -1),
    normalised to unit length, so the camera looks down -z with +y up. All origins
    are (0, 0, 0). Both outputs are (H * W, 3) in row-major pixel order and are
    created on the default device.
    """
    xs = torch.linspace(0, W - 1, W)
    ys = torch.linspace(0, H - 1, H)
    px, py = torch.meshgrid(xs, ys, indexing="xy")
    x_cam = (px - W / 2) / focal
    y_cam = -(py - H / 2) / focal
    z_cam = -torch.ones_like(px)
    directions = torch.stack([x_cam, y_cam, z_cam], dim=-1)
    directions = directions / torch.norm(directions, dim=-1, keepdim=True)
    origins = torch.zeros_like(directions)
    return origins.reshape(-1, 3), directions.reshape(-1, 3)

def rotate_rays(o, d, angles_deg):
    """Rotate ray origins and directions by Euler angles given in degrees.

    angles_deg: tensor([rx, ry, rz]) in degrees, 'sxyz' convention (static axes, as
    interpreted by ``trimesh.transformations.euler_matrix``).

    ``o`` and ``d`` are treated as row vectors with 3 columns; the same 3x3
    rotation (cast to float32 and moved to ``o``'s device) is applied to every row.
    Returns the rotated ``(o, d)`` pair.
    """
    # euler_matrix expects radians as plain Python floats.
    rad = angles_deg * np.pi / 180.0
    rx, ry, rz = float(rad[0]), float(rad[1]), float(rad[2])
    homogeneous = euler_matrix(rx, ry, rz, "sxyz")  # 4x4 homogeneous transform
    rot3 = torch.from_numpy(homogeneous[:3, :3]).float().to(o.device)

    def _apply(v):
        return (rot3 @ v.T).T

    return _apply(o), _apply(d)


# -----------------------------------------------------------------------------
# 6. Weighted Sampling
# -----------------------------------------------------------------------------
def sample_rays_weighted(rays_o, rays_d, rgb, mask, original_shape, batch_size=1024):
    """Sample a batch of rays, biased towards foreground pixels and silhouette edges.

    The per-ray tensors hold ``num_views`` images of ``H * W`` rays each,
    concatenated view after view, with each view in row-major pixel order.
    ``num_views`` is inferred from the length of ``mask`` (previously it was fixed
    at 2). Each pixel gets weight ``0.1 + |Laplacian(mask)| + 2 * [mask > 0.5]``.
    Foreground pixels and mask boundaries are therefore drawn more often, while
    background pixels keep a small non-zero probability.

    Args:
        rays_o, rays_d, rgb, mask: per-ray tensors indexed along dim 0 in the same
            order. ``mask`` may have any dtype that can be cast to float.
        original_shape: (H, W) of each view.
        batch_size: number of rays drawn (with replacement).

    Returns:
        ``(rays_o[idx], rays_d[idx], rgb[idx], mask[idx])`` for the sampled indices.

    Raises:
        ValueError: if ``mask`` is empty or its length is not a multiple of ``H * W``.
    """
    H, W = original_shape
    pixels_per_view = H * W
    flat_mask = mask.reshape(-1)
    total = flat_mask.numel()
    if pixels_per_view <= 0 or total == 0 or total % pixels_per_view != 0:
        raise ValueError(
            f"mask has {total} entries; expected a positive multiple of "
            f"H * W = {pixels_per_view}"
        )
    n_views = total // pixels_per_view

    # Normalised 3x3 Laplacian: non-zero only where the mask changes, i.e. along
    # silhouette boundaries (and at the zero-padded image border). Built once.
    laplacian = torch.tensor(
        [[-1.0, -1.0, -1.0], [-1.0, 8.0, -1.0], [-1.0, -1.0, -1.0]],
        device=flat_mask.device,
    ) / 8
    views = flat_mask.reshape(n_views, 1, H, W).float()
    edges = torch.nn.functional.conv2d(
        views, laplacian.view(1, 1, 3, 3), padding=1
    ).abs()

    # The base weight keeps every pixel reachable; foreground pixels get a large bonus.
    weights = edges.reshape(-1) + 0.1 + (flat_mask > 0.5).float() * 2.0
    probs = weights / weights.sum()
    idx = torch.multinomial(probs, batch_size, replacement=True)
    return rays_o[idx], rays_d[idx], rgb[idx], mask[idx]


# -----------------------------------------------------------------------------
# 7-9. Debug rendering, marching-cubes mesh extraction and Chamfer distance are
# unchanged from the previous version and omitted from this cell; reuse the existing
# debug_render, debug_compare, extract_3d_from_nerf and chamfer_distance functions.


# -----------------------------------------------------------------------------
# 10. Training
# -----------------------------------------------------------------------------
def train_nerf(
    model,
    rays_o_all,
    rays_d_all,
    target_pixels_all,
    mask_all,
    image_shape,
    num_iterations=8000,
    device=None,
    near=0.5,
    far=1.5,
    sigma_scale=2.0,
    debug_interval=1000,
    out_dir="debug_renders",
    batch_size=1024,
    n_samples=64,
    ckpt_path="nerf_best.pth",
):
    """Optimise ``model`` using photometric, silhouette and shape-prior losses.

    Args:
        model: network mapping (M, 3) points to (M, 4) raw outputs.
        rays_o_all, rays_d_all: per-ray origins/directions for all views,
            concatenated in the layout ``sample_rays_weighted`` expects. Camera
            placement (e.g. translating origins) must already be applied; this
            function uses the rays as given.
        target_pixels_all: per-ray target colours, same order as the rays.
        mask_all: per-ray foreground mask (values > 0.5 count as foreground).
        image_shape: (H, W) of each view.
        num_iterations: number of optimisation steps.
        device: device for the prior losses' random sample points. It also decides
            whether CUDA mixed precision is used. Defaults to the model's device.
        near, far: sampling range along each ray.
        sigma_scale: density multiplier used when rendering.
        debug_interval: iterations between debug hooks. The debug rendering helpers
            are not part of this cell, so the hook is currently a no-op.
        out_dir: intended output directory for debug renders (currently unused).
        batch_size: rays drawn per iteration.
        n_samples: samples per ray.
        ckpt_path: where the best checkpoint's ``state_dict`` is written.

    Returns:
        The same ``model`` object, optimised in place. Every 200 iterations, if the
        current batch loss is the lowest seen at such a check, the weights are
        saved to ``ckpt_path``.
    """
    H, W = image_shape
    if device is None:
        device = next(model.parameters()).device
    device = torch.device(device)
    # Mixed precision is only enabled on CUDA; on CPU both autocast and the grad
    # scaler become pass-throughs instead of warning.
    use_amp = device.type == "cuda"

    opt = optim.Adam(model.parameters(), lr=5e-4)
    sch = optim.lr_scheduler.ReduceLROnPlateau(
        opt, mode="min", factor=0.5, patience=300
    )
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    # Loss weights for the silhouette term and the (strong) shape priors.
    lambda_sil = 4.0
    lambda_shape = 3.0
    lambda_density = 0.5
    lambda_smooth = 0.5
    lambda_radial = 3.0
    lambda_sym = 5.0

    best = float("inf")
    for i in trange(num_iterations, desc="Training"):
        opt.zero_grad()
        # Draw one weighted batch per step; the rendered rays and their targets
        # come from the same draw, and the input ray tensors are never rebound.
        ro, rd, rgb_b, mask_b = sample_rays_weighted(
            rays_o_all, rays_d_all, target_pixels_all, mask_all, (H, W), batch_size
        )
        with torch.autocast(device_type="cuda", enabled=use_amp):
            rgb_map, alpha_map = render_rays(
                model, ro, rd, near, far, n_samples, sigma_scale
            )
            Lp = torch.mean((rgb_map - rgb_b) ** 2)
            Ls = silhouette_loss(alpha_map, mask_b)
            Lh = spherical_prior_loss(model, device=device)
            Ld = foreground_density_loss(alpha_map, mask_b)
            Lsm = smoothness_prior_loss(model, device=device)
            Lr = radial_profile_loss(model, device=device)
            Lsy = symmetry_loss(model, device=device)
            loss = (
                Lp
                + lambda_sil * Ls
                + lambda_shape * Lh
                + lambda_density * Ld
                + lambda_smooth * Lsm
                + lambda_radial * Lr
                + lambda_sym * Lsy
            )
        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()

        # Track plain floats so no tensor (or autograd graph) outlives the step.
        loss_val = loss.item()
        sch.step(loss_val)
        if (i + 1) % 200 == 0 and loss_val < best:
            best = loss_val
            torch.save(model.state_dict(), ckpt_path)
        if (i + 1) % debug_interval == 0:
            # Hook for debug renders into `out_dir`; the helpers are not in this cell.
            pass
    return model


# -----------------------------------------------------------------------------
# 11. Main
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Device:", dev)
    if get_train_test_split is None:
        raise ImportError(
            "data.pollen_dataset could not be imported; it is required to run this script"
        )

    def _view_targets(img):
        """Flatten a (C, H, W) image into per-ray RGB targets and a foreground mask."""
        if img.shape[0] == 1:  # grayscale -> replicate to 3 channels
            img = img.expand(3, -1, -1)
        rgb = img[:3].permute(1, 2, 0).reshape(-1, 3)
        # NOTE(review): render_rays composites over black, so non-black pixels are
        # treated as foreground. The 0.05 threshold is an assumption; confirm it
        # against the dataset's background.
        mask = (rgb.amax(dim=-1) > 0.05).float()
        return rgb, mask

    tf = transforms.ToTensor()
    dataset, train_ids, _ = get_train_test_split(image_transforms=tf, device=dev)
    (l_img, r_img), pts, rot, vox = dataset[train_ids[3]]
    # Plot the two input images side by side.
    plt.subplot(1, 2, 1)
    plt.imshow(l_img.permute(1, 2, 0).cpu(), interpolation="nearest")
    plt.title("Left Image")
    plt.axis("off")
    plt.subplot(1, 2, 2)
    plt.imshow(r_img.permute(1, 2, 0).cpu(), interpolation="nearest")
    plt.title("Right Image")
    plt.axis("off")
    plt.show()

    H, W = l_img.shape[1], l_img.shape[2]

    # 1) Canonical front rays (camera frame).
    rays_o_f, rays_d_f = get_rays(H, W, focal=300.0)

    # 2) Canonical side rays: 90 degree yaw about the Y axis.
    yaw90 = torch.tensor([0.0, 90.0, 0.0], device=dev)
    rays_o_s, rays_d_s = rotate_rays(rays_o_f, rays_d_f, yaw90)

    # 3) Apply the sample's rotation from the dataset (presumably [rx, ry, rz] in degrees).
    rot = torch.as_tensor(rot, device=dev)
    rays_o_front, rays_d_front = rotate_rays(rays_o_f, rays_d_f, rot)
    rays_o_side, rays_d_side = rotate_rays(rays_o_s, rays_d_s, rot)
    # NOTE(review): get_rays puts every origin at (0, 0, 0), and rotations keep it
    # there, so all samples lie at distances [near, far] from the world origin. If
    # the object is centred at the origin, the cameras likely need translating.
    # Confirm the intended camera placement.

    # 4) Stack both views: front first, then side.
    rays_o_all = torch.cat([rays_o_front, rays_o_side], dim=0).to(dev)
    rays_d_all = torch.cat([rays_d_front, rays_d_side], dim=0).to(dev)

    # Targets must follow the same view order as the rays.
    # NOTE(review): assumes the left image is the front view and the right image is
    # the side view. Confirm against the dataset.
    rgb_front, mask_front = _view_targets(l_img.to(dev))
    rgb_side, mask_side = _view_targets(r_img.to(dev))
    target_pixels_all = torch.cat([rgb_front, rgb_side], dim=0)
    mask_all = torch.cat([mask_front, mask_side], dim=0)

    model = NeRF().to(dev)
    model = train_nerf(
        model, rays_o_all, rays_d_all, target_pixels_all, mask_all, (H, W), device=dev
    )
    # Mesh extraction and Chamfer evaluation are not part of this cell.

    print("Done.")
