In [None]:
!pip install warp-lang

In [None]:
import warp as wp
import numpy as np
import torch

wp.init()

In [24]:
import warp as wp
import numpy as np
import torch

# Initialize Warp
wp.init()

@wp.struct
class Drop:
    position: wp.vec3
    radius: wp.float32
    weight: wp.float32
    energy: wp.float32
    sigma: wp.float32

@wp.struct
class Ray:
    origin: wp.vec3
    dir: wp.vec3

# --- Differentiable Influence Function ---

@wp.func
def influence_func(sample_pos: wp.vec3, drop_pos: wp.vec3, sigma: wp.float32):
    dist_sq = wp.length_sq(sample_pos - drop_pos)
    sigma_sq = sigma * sigma
    return wp.exp(-dist_sq / (2.0 * sigma_sq + 1e-8))

# --- Volumetric Tracing Kernel with HashGrid ---

@wp.kernel
def volumetric_trace(
    # Scene Data
    grid: wp.uint64, # CHANGED: We now use a hash grid ID
    drop_positions: wp.array(dtype=wp.vec3),
    drop_learnable_params: wp.array(dtype=wp.float32, ndim=2),

    # Input/Output
    initial_rays: wp.array(dtype=Ray),
    outputs: wp.array(dtype=wp.float32),

    # Kernel Parameters
    num_steps: wp.int32,
    step_size: wp.float32,
    query_radius: wp.float32
):
    tid = wp.tid()
    ray = initial_rays[tid]
    accumulated_radiance = wp.float32(0.0)

    for i in range(num_steps):
        t = wp.float32(i) * step_size
        sample_pos = ray.origin + ray.dir * t

        # FIXED: Use the correct hash grid query function
        query = wp.hash_grid_query(grid, sample_pos, query_radius)
        candidate_index = wp.int32(0)
        
        local_radiance = wp.float32(0.0)

        # Loop over neighbors found by the query
        while wp.hash_grid_query_next(query, candidate_index):
            drop_weight = drop_learnable_params[candidate_index, 0]
            drop_energy = drop_learnable_params[candidate_index, 1]
            drop_sigma  = drop_learnable_params[candidate_index, 2]

            influence = influence_func(sample_pos, drop_positions[candidate_index], drop_sigma)
            contribution = influence * drop_weight * drop_energy
            local_radiance += contribution
        
        accumulated_radiance += local_radiance * step_size
    
    outputs[tid] = accumulated_radiance
@wp.kernel
def calculate_sparsity_cost(
    # Same inputs as the main kernel to trace the same paths
    grid: wp.uint64,
    initial_rays: wp.array(dtype=Ray),
    
    # Output
    total_neighbor_interactions: wp.array(dtype=wp.int32),

    # Kernel Parameters
    num_steps: wp.int32,
    step_size: wp.float32,
    query_radius: wp.float32
):
    tid = wp.tid()
    ray = initial_rays[tid]

    for i in range(num_steps):
        t = wp.float32(i) * step_size
        sample_pos = ray.origin + ray.dir * t

        query = wp.hash_grid_query(grid, sample_pos, query_radius)
        candidate_index = wp.int32(0)
        
        # Loop over neighbors and just count them
        while wp.hash_grid_query_next(query, candidate_index):
            # Atomically increment a global counter for each interaction
            wp.atomic_add(total_neighbor_interactions, 0, 1)

#main area

# Sim step
num_drops = 400
num_initial_rays = 4096
training_steps = 150
device = wp.get_preferred_device()
torch_device = torch.device('cuda' if wp.get_device(device).is_cuda else 'cpu')

# create drops
rng = np.random.default_rng(42)
positions_np = (rng.random(size=(num_drops, 3), dtype=np.float32) - 0.5) * 8.0
positions_np[:, 2] += 5.0
positions_device = wp.array(positions_np, dtype=wp.vec3, device=device)

print("Building HashGrid...")
grid_cell_size = 2.0 
grid = wp.HashGrid(dim_x=128, dim_y=128, dim_z=128, device=device)
# The grid needs to be built with the particle positions
grid.build(points=positions_device, radius=grid_cell_size)
print(f"HashGrid built for {num_drops} drops.")

#drop properties
initial_weights = (rng.random(size=(num_drops, 1), dtype=np.float32) - 0.5) * 0.1
initial_energies = np.ones((num_drops, 1), dtype=np.float32)
initial_sigmas = np.full((num_drops, 1), 0.5, dtype=np.float32)
learnable_data_np = np.hstack([initial_weights, initial_energies, initial_sigmas])
learnable_params_tensor = torch.tensor(learnable_data_np, requires_grad=True, device=torch_device)

optimizer = torch.optim.Adam([learnable_params_tensor], lr=0.01)
loss_fn = torch.nn.MSELoss()

# fake data
y_true_np = np.zeros(num_initial_rays, dtype=np.float32)
num_rays_side = int(np.sqrt(num_initial_rays))
for j in range(num_rays_side):
    for i in range(num_rays_side):
        idx = j * num_rays_side + i
        if idx >= num_initial_rays: continue
        u = (i / (num_rays_side - 1) - 0.5) * 2.0
        v = (j / (num_rays_side - 1) - 0.5) * 2.0
        distance_sq = u*u + v*v
        sigma_target = 0.5
        y_true_np[idx] = 5.0 * np.exp(-distance_sq / (2.0 * sigma_target**2.0))
y_true_tensor = torch.tensor(y_true_np, device=torch_device)

# initial rays
initial_rays_np = np.empty(num_initial_rays, dtype=Ray.numpy_dtype())
for j in range(num_rays_side):
    for i in range(num_rays_side):
        idx = j * num_rays_side + i
        if idx >= num_initial_rays: continue
        ox = (i / (num_rays_side - 1) - 0.5) * 5.0
        oy = (j / (num_rays_side - 1) - 0.5) * 5.0
        initial_rays_np[idx]['origin'] = (ox, oy, -1.0)
        initial_rays_np[idx]['dir'] = (0.0, 0.0, 1.0)
initial_rays_wp = wp.array(initial_rays_np, dtype=Ray, device=device)


print(f"been training for like {training_steps} steps ---")

# sparse params
lambda_l1 = 1e-4 

for step in range(training_steps):
    optimizer.zero_grad()
    
    learnable_params_wp = wp.from_torch(learnable_params_tensor)
    outputs_tensor = torch.zeros(num_initial_rays, dtype=torch.float32, device=torch_device)
    outputs_wp = wp.from_torch(outputs_tensor)

    # forward pass
    tape = wp.Tape()
    with tape:
        wp.launch(
            kernel=volumetric_trace,
            dim=num_initial_rays,
            inputs=[
                grid.id,
                positions_device,
                learnable_params_wp,
                initial_rays_wp,
                outputs_wp,
                64,    # num_steps
                0.2,   # step_size
                1.0    # query_radius
            ],
            device=device
        )

    # main render loss
    rendering_loss = loss_fn(outputs_tensor, y_true_tensor)
    drop_weights = learnable_params_tensor[:, 0]
    l1_loss = lambda_l1 * torch.abs(drop_weights).sum()
    
    # aux loss
    total_loss = rendering_loss + l1_loss

    # back pass
    # 
    seed_grad_tensor = (2.0 / num_initial_rays) * (outputs_tensor - y_true_tensor)
    tape.backward(grads={outputs_wp: wp.from_torch(seed_grad_tensor)})
    
    # Copy the Warp gradients to the PyTorch tensor
    learnable_params_tensor.grad = wp.to_torch(learnable_params_wp.grad)
    
    
    l1_loss.backward()

    # --- 5. Optimizer Step ---
    optimizer.step()
    
    if step % 10 == 0 or step == training_steps - 1:
        print(f"Step {step:03d}, Rendering Loss: {rendering_loss.item():.6f}, L1 Loss: {l1_loss.item():.6f}")

print("\n--- Training complete ---")

Building HashGrid...
HashGrid built for 400 drops.

--- Starting Training for 150 steps ---
Step 000, Rendering Loss: 4.749850, L1 Loss: 0.001019
Step 010, Rendering Loss: 1.325285, L1 Loss: 0.002278
Step 020, Rendering Loss: 0.351572, L1 Loss: 0.002482
Step 030, Rendering Loss: 0.073063, L1 Loss: 0.002218
Step 040, Rendering Loss: 0.052650, L1 Loss: 0.002323
Step 050, Rendering Loss: 0.023236, L1 Loss: 0.002354
Step 060, Rendering Loss: 0.010697, L1 Loss: 0.002223
Step 070, Rendering Loss: 0.006464, L1 Loss: 0.002115
Step 080, Rendering Loss: 0.003927, L1 Loss: 0.001986
Step 090, Rendering Loss: 0.002607, L1 Loss: 0.001880
Step 100, Rendering Loss: 0.001990, L1 Loss: 0.001803
Step 110, Rendering Loss: 0.001588, L1 Loss: 0.001745
Step 120, Rendering Loss: 0.001333, L1 Loss: 0.001690
Step 130, Rendering Loss: 0.001145, L1 Loss: 0.001641
Step 140, Rendering Loss: 0.000996, L1 Loss: 0.001604
Step 149, Rendering Loss: 0.000889, L1 Loss: 0.001573

--- Training complete ---


In [43]:
import warp as wp
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

wp.init()

@wp.struct
class Ray:
    origin: wp.vec3
    dir: wp.vec3


@wp.func
def influence_func(sample_pos: wp.vec3, drop_pos: wp.vec3, sigma: wp.float32):
    """Calculates the influence of a drop at a sample point."""
    dist_sq = wp.length_sq(sample_pos - drop_pos)
    sigma_sq = sigma * sigma
    return wp.exp(-dist_sq / (2.0 * sigma_sq + 1e-9))


@wp.kernel
def volumetric_trace(
    # Scene Data
    grid: wp.uint64,
    drop_positions: wp.array(dtype=wp.vec3),
    drop_learnable_params: wp.array(dtype=wp.float32, ndim=2),

    # Input/Output
    initial_rays: wp.array(dtype=Ray),
    outputs: wp.array(dtype=wp.float32),

    # Kernel Parameters
    num_steps: wp.int32,
    step_size: wp.float32,
    query_radius: wp.float32
):
    """Traces rays through the volume and accumulates radiance."""
    tid = wp.tid()
    ray = initial_rays[tid]
    accumulated_radiance = wp.float32(0.0)

    for i in range(num_steps):
        t = wp.float32(i) * step_size
        sample_pos = ray.origin + ray.dir * t

        query = wp.hash_grid_query(grid, sample_pos, query_radius)
        candidate_index = wp.int32(0)
        
        local_radiance = wp.float32(0.0)

        while wp.hash_grid_query_next(query, candidate_index):
            drop_weight = wp.clamp(drop_learnable_params[candidate_index, 0], -1.0, 1.0)
            drop_energy = wp.clamp(drop_learnable_params[candidate_index, 1], 0.0, 5.0)
            drop_sigma  = wp.clamp(drop_learnable_params[candidate_index, 2], 0.01, 2.0)

            influence = influence_func(sample_pos, drop_positions[candidate_index], drop_sigma)
            contribution = influence * drop_weight * drop_energy
            local_radiance += contribution
        
        accumulated_radiance += local_radiance * step_size
    
    outputs[tid] = accumulated_radiance


class VolumetricSplattingLayer(nn.Module):
    """
    A plug-and-play neural network layer for volumetric rendering.
    Manages its own state of "drops" and handles rendering, densification, and pruning.
    """
    def __init__(self, initial_positions, initial_learnable_params, device,
                 densify_start_step=100, densify_interval=100,
                 prune_threshold_weight=0.005, densify_grad_threshold=0.001,
                 split_threshold_sigma=0.5, ema_beta=0.9):
        super().__init__()
        self.device = 'cuda'
                self.positions = nn.Parameter(torch.tensor(initial_positions, device=self.device))
        self.learnable_params = nn.Parameter(torch.tensor(initial_learnable_params, device=self.device))

        self.densify_start_step = densify_start_step
        self.densify_interval = densify_interval
        self.prune_threshold_weight = prune_threshold_weight
        self.densify_grad_threshold = densify_grad_threshold
        self.split_threshold_sigma = split_threshold_sigma
        
        self.ema_beta = ema_beta
        self.position_grad_ema = torch.zeros_like(self.positions)

    def forward(self, initial_rays_wp: wp.array):
        """
        Performs the forward pass: renders the volume for the given rays.
        """
        num_initial_rays = initial_rays_wp.shape[0]
        
        # Rebuild HashGrid at every step since positions are learnable
        # Store these warp arrays as attributes to be accessed in the backward pass
        self.current_positions_wp = wp.from_torch(self.positions, dtype=wp.vec3)
        grid = wp.HashGrid(dim_x=128, dim_y=128, dim_z=128, device=self.device)
        grid.build(points=self.current_positions_wp, radius=2.0)
        
        self.learnable_params_wp = wp.from_torch(self.learnable_params)
        outputs_tensor = torch.zeros(num_initial_rays, dtype=torch.float32, device=self.device)
        self.outputs_wp = wp.from_torch(outputs_tensor)

        
        self.tape = wp.Tape()
        with self.tape:
            wp.launch(
                kernel=volumetric_trace,
                dim=num_initial_rays,
                inputs=[
                    grid.id,
                    self.current_positions_wp,
                    self.learnable_params_wp,
                    initial_rays_wp,
                    self.outputs_wp,
                    64, 0.2, 1.5 # num_steps, step_size, query_radius
                ],
                device=self.device
            )
        
        return outputs_tensor

    @torch.no_grad()
    def update_state(self, step: int, optimizer: torch.optim.Optimizer):
        """
        Updates the layer's state by pruning and densifying drops.
        This should be called from the main training loop.
        Returns True if the optimizer needs to be reset.
        """
        if step <= self.densify_start_step or step % self.densify_interval != 0:
            return False

        if self.positions.grad is None:
            return False

        # Update gradient EMA for stable densification ---
        self.position_grad_ema = self.ema_beta * self.position_grad_ema + \
                                 (1.0 - self.ema_beta) * self.positions.grad
        
        pos_grads_magnitude = torch.norm(self.position_grad_ema, dim=1)
        
        print(f"--- Step {step:03d}: Densification ---")
        print(f"Max position grad magnitude (EMA): {pos_grads_magnitude.max().item():.6f}")

        # prune
        prune_mask = torch.abs(self.learnable_params[:, 0]) > self.prune_threshold_weight
        n_pruned = self.positions.shape[0] - prune_mask.sum().item()
        
        if n_pruned > 0:
            self.positions.data = self.positions.data[prune_mask]
            self.learnable_params.data = self.learnable_params.data[prune_mask]
            pos_grads_magnitude = pos_grads_magnitude[prune_mask]
            self.position_grad_ema = self.position_grad_ema[prune_mask]

        # splitting
        split_mask = (self.learnable_params[:, 2] > self.split_threshold_sigma) & (pos_grads_magnitude > self.densify_grad_threshold)
        n_split = 0
        if split_mask.any():
            n_split = split_mask.sum().item()
            split_positions = self.positions.data[split_mask]
            split_learnables = self.learnable_params.data[split_mask]

            new_positions = split_positions.repeat(2, 1)
            new_learnables = split_learnables.repeat(2, 1)
            new_learnables[:, 2] *= 0.7
            new_learnables[:, 0] *= 0.7 
            
            stdev = torch.sqrt(new_learnables[:n_split, 2])
            mean = torch.zeros_like(stdev)
            offset_dir = self.positions.grad[prune_mask if n_pruned > 0 else torch.ones_like(prune_mask, dtype=torch.bool)][split_mask]
            offset = torch.normal(mean, stdev).unsqueeze(1) * (offset_dir / (torch.norm(offset_dir, dim=1, keepdim=True) + 1e-9))

            new_positions[:n_split] += offset
            new_positions[n_split:] -= offset
            
            keep_mask = ~split_mask
            self.positions.data = torch.cat((self.positions.data[keep_mask], new_positions), dim=0)
            self.learnable_params.data = torch.cat((self.learnable_params.data[keep_mask], new_learnables), dim=0)

            split_grads_magnitude = pos_grads_magnitude[split_mask].repeat(2)
            pos_grads_magnitude = torch.cat((pos_grads_magnitude[keep_mask], split_grads_magnitude), dim=0)

            split_grad_ema = self.position_grad_ema[split_mask].repeat(2, 1)
            self.position_grad_ema = torch.cat((self.position_grad_ema[keep_mask], split_grad_ema), dim=0)

        # cloning
        clone_mask = (self.learnable_params[:, 2] <= self.split_threshold_sigma) & (pos_grads_magnitude > self.densify_grad_threshold)
        n_cloned = 0
        if clone_mask.any():
            n_cloned = clone_mask.sum().item()
            clone_positions = self.positions.data[clone_mask]
            clone_learnables = self.learnable_params.data[clone_mask]

            self.learnable_params.data[clone_mask, 0] /= 2.0
            
            self.positions.data = torch.cat((self.positions.data, clone_positions), dim=0)
            self.learnable_params.data = torch.cat((self.learnable_params.data, clone_learnables), dim=0)
            self.position_grad_ema = torch.cat((self.position_grad_ema, self.position_grad_ema[clone_mask]), dim=0)


        if n_pruned > 0 or n_split > 0 or n_cloned > 0:
            print(f"Pruned: {n_pruned}, Split: {n_split}, Cloned: {n_cloned}.")
            return True # Optimizer needs reset
            
        return False



# Simulation setup
num_drops_initial = 400
num_initial_rays = 4096
training_steps = 300
device = wp.get_preferred_device()
torch_device = 'cuda' if wp.get_device(device).is_cuda else 'cpu'
lambda_l1 = 1e-4

# init drop data
rng = np.random.default_rng(42)
positions_np = (rng.random(size=(num_drops_initial, 3), dtype=np.float32) - 0.5) * 8.0
positions_np[:, 2] += 5.0
initial_weights = (rng.random(size=(num_drops_initial, 1), dtype=np.float32) - 0.5) * 0.1
initial_energies = np.ones((num_drops_initial, 1), dtype=np.float32)
initial_sigmas = np.full((num_drops_initial, 1), 0.5, dtype=np.float32)
learnable_data_np = np.hstack([initial_weights, initial_energies, initial_sigmas])

#fake data and rays
y_true_np = np.zeros(num_initial_rays, dtype=np.float32)
num_rays_side = int(np.sqrt(num_initial_rays))
for j in range(num_rays_side):
    for i in range(num_rays_side):
        idx = j * num_rays_side + i
        if idx >= num_initial_rays: continue
        u, v = (i / (num_rays_side - 1) - 0.5) * 2.0, (j / (num_rays_side - 1) - 0.5) * 2.0
        y_true_np[idx] = 5.0 * np.exp(-(u*u + v*v) / (2.0 * 0.5**2.0))
y_true_tensor = torch.tensor(y_true_np, device=torch_device)

initial_rays_np = np.empty(num_initial_rays, dtype=Ray.numpy_dtype())
for j in range(num_rays_side):
    for i in range(num_rays_side):
        idx = j * num_rays_side + i
        if idx >= num_initial_rays: continue
        ox, oy = (i / (num_rays_side - 1) - 0.5) * 5.0, (j / (num_rays_side - 1) - 0.5) * 5.0
        initial_rays_np[idx]['origin'], initial_rays_np[idx]['dir'] = (ox, oy, -1.0), (0.0, 0.0, 1.0)
initial_rays_wp = wp.array(initial_rays_np, dtype=Ray, device=device)

model = VolumetricSplattingLayer(positions_np, learnable_data_np, device).to(torch_device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

print(f"\n training for like {training_steps} steps ---")

for step in range(training_steps):
    optimizer.zero_grad()
    #forward
    outputs_tensor = model(initial_rays_wp)
    
    # loss calc
    rendering_loss = F.mse_loss(outputs_tensor, y_true_tensor)
    l1_loss = lambda_l1 * torch.abs(model.learnable_params[:, 0]).sum()
    total_loss = rendering_loss + l1_loss
    
    # back pass
    # get rend grads from warp
    seed_grad_tensor = (2.0 / outputs_tensor.numel()) * (outputs_tensor - y_true_tensor)
    seed_grad_wp = wp.from_torch(seed_grad_tensor.contiguous())
    # Use the output warp array stored as a model attribute as the key
    model.tape.backward(grads={model.outputs_wp: seed_grad_wp})
    
    # Retrieve gradients using the input warp arrays (also stored as attributes)
    model.learnable_params.grad = wp.to_torch(model.tape.gradients[model.learnable_params_wp])
    model.positions.grad = wp.to_torch(model.tape.gradients[model.current_positions_wp])
    
    # Manually add L1 grads to avoid autograd issues with in-place updates
    with torch.no_grad():
        # Gradient of L1 loss is lambda * sign(parameter)
        l1_grad_weights = lambda_l1 * torch.sign(model.learnable_params.data[:, 0])
        
        # Create a zero tensor with the same shape as the full gradient
        l1_grad_full = torch.zeros_like(model.learnable_params.grad)
        
        # Place the weight gradients in the first column
        l1_grad_full[:, 0] = l1_grad_weights
        
        # Accumulate the gradients
        model.learnable_params.grad += l1_grad_full

    # optimizer step
    optimizer.step()

    # prune and densiffy 
    if model.update_state(step, optimizer):
        # If drops were added/removed, we must reset the optimizer's state
        optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    if step % 10 == 0 or step == training_steps - 1:
        print(f"Step {step:03d}, Total Loss: {total_loss.item():.6f} (Render: {rendering_loss.item():.6f}, L1: {l1_loss.item():.6f}), Num Drops: {model.positions.shape[0]}")

print("\n--- Training complete ---")




--- Starting Training for 300 steps ---
Step 000, Total Loss: 4.751473 (Render: 4.750454, L1: 0.001019), Num Drops: 400
Step 010, Total Loss: 1.061185 (Render: 1.058930, L1: 0.002255), Num Drops: 400
Step 020, Total Loss: 0.249210 (Render: 0.246908, L1: 0.002302), Num Drops: 400
Step 030, Total Loss: 0.112066 (Render: 0.110066, L1: 0.002001), Num Drops: 400
Step 040, Total Loss: 0.023191 (Render: 0.020962, L1: 0.002229), Num Drops: 400
Step 050, Total Loss: 0.013817 (Render: 0.011561, L1: 0.002256), Num Drops: 400
Step 060, Total Loss: 0.007593 (Render: 0.005451, L1: 0.002142), Num Drops: 400
Step 070, Total Loss: 0.005473 (Render: 0.003422, L1: 0.002050), Num Drops: 400
Step 080, Total Loss: 0.004150 (Render: 0.002178, L1: 0.001972), Num Drops: 400
Step 090, Total Loss: 0.003455 (Render: 0.001541, L1: 0.001914), Num Drops: 400
Step 100, Total Loss: 0.002929 (Render: 0.001078, L1: 0.001851), Num Drops: 400
Step 110, Total Loss: 0.002604 (Render: 0.000798, L1: 0.001806), Num Drops: 400

In [44]:
import warp as wp
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# --- (Your existing Warp structs and kernels remain unchanged) ---
# Initialize Warp
wp.init()

@wp.struct
class Ray:
    origin: wp.vec3
    dir: wp.vec3

@wp.func
def influence_func(sample_pos: wp.vec3, drop_pos: wp.vec3, sigma: wp.float32):
    """Calculates the influence of a drop at a sample point."""
    dist_sq = wp.length_sq(sample_pos - drop_pos)
    sigma_sq = sigma * sigma
    return wp.exp(-dist_sq / (2.0 * sigma_sq + 1e-9))

@wp.kernel
def volumetric_trace(
    # Scene Data
    grid: wp.uint64,
    drop_positions: wp.array(dtype=wp.vec3),
    drop_learnable_params: wp.array(dtype=wp.float32, ndim=2),

    # Input/Output
    initial_rays: wp.array(dtype=Ray),
    outputs: wp.array(dtype=wp.float32),

    # Kernel Parameters
    num_steps: wp.int32,
    step_size: wp.float32,
    query_radius: wp.float32
):
    """Traces rays through the volume and accumulates radiance."""
    tid = wp.tid()
    ray = initial_rays[tid]
    accumulated_radiance = wp.float32(0.0)

    for i in range(num_steps):
        t = wp.float32(i) * step_size
        sample_pos = ray.origin + ray.dir * t

        query = wp.hash_grid_query(grid, sample_pos, query_radius)
        candidate_index = wp.int32(0)
        
        local_radiance = wp.float32(0.0)

        while wp.hash_grid_query_next(query, candidate_index):
            drop_weight = wp.clamp(drop_learnable_params[candidate_index, 0], -1.0, 1.0)
            drop_energy = wp.clamp(drop_learnable_params[candidate_index, 1], 0.0, 5.0)
            drop_sigma  = wp.clamp(drop_learnable_params[candidate_index, 2], 0.01, 2.0)

            influence = influence_func(sample_pos, drop_positions[candidate_index], drop_sigma)
            contribution = influence * drop_weight * drop_energy
            local_radiance += contribution
        
        accumulated_radiance += local_radiance * step_size
    
    outputs[tid] = accumulated_radiance

# --- (Your VolumetricSplattingLayer remains mostly unchanged) ---
class VolumetricSplattingLayer(nn.Module):
    """
    A plug-and-play neural network layer for volumetric rendering.
    Manages its own state of "drops" and handles rendering, densification, and pruning.
    """
    def __init__(self, initial_positions, initial_learnable_params, device,
                 densify_start_step=500, densify_interval=100,
                 prune_threshold_weight=0.005, densify_grad_threshold=0.0005,
                 split_threshold_sigma=0.5, ema_beta=0.9):
        super().__init__()
        self.device = 'cuda'
        
        # --- Register drops as learnable parameters ---
        self.positions = nn.Parameter(torch.tensor(initial_positions, device=self.device))
        self.learnable_params = nn.Parameter(torch.tensor(initial_learnable_params, device=self.device))

        # --- Densification and Pruning state ---
        self.densify_start_step = densify_start_step
        self.densify_interval = densify_interval
        self.prune_threshold_weight = prune_threshold_weight
        self.densify_grad_threshold = densify_grad_threshold
        self.split_threshold_sigma = split_threshold_sigma
        
        # --- EMA for stable gradient-based densification ---
        self.ema_beta = ema_beta
        self.position_grad_ema = torch.zeros_like(self.positions)

    def forward(self, initial_rays_wp: wp.array):
        num_initial_rays = initial_rays_wp.shape[0]
        
        # Store these warp arrays as attributes to be accessed in the backward pass
        self.current_positions_wp = wp.from_torch(self.positions, dtype=wp.vec3)
        grid = wp.HashGrid(dim_x=128, dim_y=128, dim_z=128, device=self.device)
        grid.build(points=self.current_positions_wp, radius=2.0)
        
        self.learnable_params_wp = wp.from_torch(self.learnable_params)
        outputs_tensor = torch.zeros(num_initial_rays, dtype=torch.float32, device=self.device, requires_grad=True)
        self.outputs_wp = wp.from_torch(outputs_tensor)

        self.tape = wp.Tape()
        with self.tape:
            wp.launch(
                kernel=volumetric_trace,
                dim=num_initial_rays,
                inputs=[
                    grid.id,
                    self.current_positions_wp,
                    self.learnable_params_wp,
                    initial_rays_wp,
                    self.outputs_wp,
                    32, 0.25, 1.5 # num_steps, step_size, query_radius
                ],
                device=self.device
            )
        
        return outputs_tensor

    @torch.no_grad()
    def update_state(self, step: int):
        if step <= self.densify_start_step or step % self.densify_interval != 0:
            return False

        if self.positions.grad is None:
            return False

        self.position_grad_ema = self.ema_beta * self.position_grad_ema + \
                                 (1.0 - self.ema_beta) * self.positions.grad
        
        pos_grads_magnitude = torch.norm(self.position_grad_ema, dim=1)
        
        # --- 1. Pruning ---
        prune_mask = torch.abs(self.learnable_params[:, 0]) > self.prune_threshold_weight
        n_pruned = self.positions.shape[0] - prune_mask.sum().item()
        
        if n_pruned > 0:
            self.positions.data = self.positions.data[prune_mask]
            self.learnable_params.data = self.learnable_params.data[prune_mask]
            pos_grads_magnitude = pos_grads_magnitude[prune_mask]
            self.position_grad_ema = self.position_grad_ema[prune_mask]

        # --- 2. Densification (Splitting) ---
        split_mask = (self.learnable_params[:, 2] > self.split_threshold_sigma) & (pos_grads_magnitude > self.densify_grad_threshold)
        n_split = 0
        if split_mask.any():
            n_split = split_mask.sum().item()
            split_positions = self.positions.data[split_mask]
            split_learnables = self.learnable_params.data[split_mask]

            new_positions = split_positions.repeat(2, 1)
            new_learnables = split_learnables.repeat(2, 1)
            new_learnables[:, 2] *= 0.7
            new_learnables[:, 0] *= 0.7 
            
            stdev = torch.sqrt(new_learnables[:n_split, 2])
            mean = torch.zeros_like(stdev)
            offset_dir = self.positions.grad[prune_mask if n_pruned > 0 else torch.ones_like(prune_mask, dtype=torch.bool)][split_mask]
            offset = torch.normal(mean, stdev).unsqueeze(1) * (offset_dir / (torch.norm(offset_dir, dim=1, keepdim=True) + 1e-9))

            new_positions[:n_split] += offset
            new_positions[n_split:] -= offset
            
            keep_mask = ~split_mask
            self.positions.data = torch.cat((self.positions.data[keep_mask], new_positions), dim=0)
            self.learnable_params.data = torch.cat((self.learnable_params.data[keep_mask], new_learnables), dim=0)

            split_grad_ema = self.position_grad_ema[split_mask].repeat(2, 1)
            self.position_grad_ema = torch.cat((self.position_grad_ema[keep_mask], split_grad_ema), dim=0)

        # --- 3. Densification (Cloning) ---
        clone_mask = (self.learnable_params[:, 2] <= self.split_threshold_sigma) & (pos_grads_magnitude > self.densify_grad_threshold)
        n_cloned = 0
        if clone_mask.any():
            n_cloned = clone_mask.sum().item()
            clone_positions = self.positions.data[clone_mask]
            clone_learnables = self.learnable_params.data[clone_mask]

            self.learnable_params.data[clone_mask, 0] /= 2.0
            
            self.positions.data = torch.cat((self.positions.data, clone_positions), dim=0)
            self.learnable_params.data = torch.cat((self.learnable_params.data, clone_learnables), dim=0)
            self.position_grad_ema = torch.cat((self.position_grad_ema, self.position_grad_ema[clone_mask]), dim=0)

        if n_pruned > 0 or n_split > 0 or n_cloned > 0:
            print(f"Pruned: {n_pruned}, Split: {n_split}, Cloned: {n_cloned}.")
            return True # Optimizer needs reset
            
        return False
# --- NEW: Main Classifier Model ---

class MNISTVolumetricClassifier(nn.Module):
    def __init__(self, initial_positions, initial_learnable_params, device, img_size=28):
        super().__init__()
        self.img_size = img_size
        self.num_rays = img_size * img_size
        self.wp_device = device
        
        # Core volumetric rendering layer
        self.volumetric_layer = VolumetricSplattingLayer(
            initial_positions, initial_learnable_params, device
        )
        
        # Readout layer to classify the rendered feature map
        self.readout_layer = nn.Linear(self.num_rays, 10)

    def forward(self, x: torch.Tensor):
        batch_size = x.shape[0]
        
        # 1. Generate a grid of rays for each image in the batch
        rays_wp = self.generate_rays_for_batch(batch_size)
        
        # 2. Render the feature map using the volumetric layer
        # The output will have shape (batch_size * num_rays)
        rendered_features_flat = self.volumetric_layer(rays_wp)
        
        # 3. Reshape to (batch_size, num_rays) for the readout layer
        rendered_features = rendered_features_flat.view(batch_size, self.num_rays)
        
        # 4. Classify the feature map
        logits = self.readout_layer(rendered_features)
        
        return logits, rendered_features

    def generate_rays_for_batch(self, batch_size: int):
        """Creates a grid of parallel rays for a batch of images."""
        total_rays = batch_size * self.num_rays
        initial_rays_np = np.empty(total_rays, dtype=Ray.numpy_dtype())
        
        scale = 5.0 # How spread out the ray origins are in the XY plane
        
        for b in range(batch_size):
            for j in range(self.img_size):
                for i in range(self.img_size):
                    idx = b * self.num_rays + j * self.img_size + i
                    
                    # Map pixel coords (i, j) to world space (ox, oy)
                    ox = (i / (self.img_size - 1) - 0.5) * scale
                    oy = (j / (self.img_size - 1) - 0.5) * scale
                    
                    initial_rays_np[idx]['origin'] = (ox, oy, -2.0) # Start rays further back
                    initial_rays_np[idx]['dir'] = (0.0, 0.0, 1.0) # All point forward
                    
        return wp.array(initial_rays_np, dtype=Ray, device=self.wp_device)

# --- Main script ---

# Simulation setup
num_drops_initial = 8000 # More drops might be needed for this complex task
batch_size = 16
training_steps = 5000
device = wp.get_preferred_device()
torch_device = 'cuda' if wp.get_device(device).is_cuda else 'cpu'
lambda_l1 = 1e-5 # L1 regularization on drop weights

# --- Create Initial Drop Data ---
rng = np.random.default_rng(42)
positions_np = (rng.random(size=(num_drops_initial, 3), dtype=np.float32) - 0.5) * 8.0
positions_np[:, 2] += 4.0 # Center the cloud along the Z axis
initial_weights = (rng.random(size=(num_drops_initial, 1), dtype=np.float32) - 0.5) * 0.1
initial_energies = np.ones((num_drops_initial, 1), dtype=np.float32)
initial_sigmas = np.full((num_drops_initial, 1), 0.5, dtype=np.float32)
learnable_data_np = np.hstack([initial_weights, initial_energies, initial_sigmas])

# --- Load MNIST Data ---
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = datasets.MNIST('../data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
data_iter = iter(train_loader)

# --- Define the Model and Optimizer ---
model = MNISTVolumetricClassifier(positions_np, learnable_data_np, device).to(torch_device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, betas=(0.9, 0.99))
loss_fn = nn.CrossEntropyLoss()

# --- Training Loop ---
print(f"\n--- Starting Training for {training_steps} steps ---")

for step in range(training_steps):
    # Fetch a new batch of data
    try:
        images, labels = next(data_iter)
    except StopIteration:
        data_iter = iter(train_loader)
        images, labels = next(data_iter)
        
    images, labels = images.to(torch_device), labels.to(torch_device)

    optimizer.zero_grad()
    
    # --- Forward Pass ---
    logits, rendered_features = model(images)
    
    # --- Loss Calculation ---
    classification_loss = loss_fn(logits, labels)
    l1_loss = lambda_l1 * torch.abs(model.volumetric_layer.learnable_params[:, 0]).sum()
    total_loss = classification_loss + l1_loss
    
    # --- Backward Pass ---
    # We need to manually handle the gradient flow from PyTorch back to Warp.
    
    # Step 1: Get the gradient of the loss w.r.t. the output of the volumetric layer.
    # PyTorch's autograd does this for us when we call backward() on the loss.
    # We need to retain the graph to allow backprop for the readout layer's params.
    total_loss.backward(retain_graph=True) 
    
    # The gradient we need is now in rendered_features.grad
    seed_grad_tensor = rendered_features.grad
    
    # Step 2: Feed this gradient back into the Warp tape
    seed_grad_wp = wp.from_torch(seed_grad_tensor.flatten().contiguous())
    
    # Use the tape and arrays stored in the volumetric layer during the forward pass
    vol_layer = model.volumetric_layer
    vol_layer.tape.backward(grads={vol_layer.outputs_wp: seed_grad_wp})
    
    # Step 3: Assign the calculated gradients from Warp to the nn.Parameters
    vol_layer.learnable_params.grad = wp.to_torch(vol_layer.tape.gradients[vol_layer.learnable_params_wp])
    vol_layer.positions.grad = wp.to_torch(vol_layer.tape.gradients[vol_layer.current_positions_wp])
    
    # Step 4: Manually add the L1 regularization gradient to the drop weights
    with torch.no_grad():
        l1_grad_weights = lambda_l1 * torch.sign(vol_layer.learnable_params.data[:, 0])
        l1_grad_full = torch.zeros_like(vol_layer.learnable_params.grad)
        l1_grad_full[:, 0] = l1_grad_weights
        vol_layer.learnable_params.grad += l1_grad_full

    # --- Optimizer Step ---
    # The optimizer will now update the readout_layer params (from loss.backward())
    # and the volumetric_layer params (from our manual gradient assignment).
    optimizer.step()

    # --- Update Layer State (Pruning/Densification) ---
    if vol_layer.update_state(step):
        # If drops were added/removed, reset the optimizer's state
        optimizer = torch.optim.Adam(model.parameters(), lr=0.005, betas=(0.9, 0.99))

    if step % 50 == 0 or step == training_steps - 1:
        # Calculate accuracy for logging
        with torch.no_grad():
            preds = torch.argmax(logits, dim=1)
            accuracy = (preds == labels).float().mean()
        print(f"Step {step:04d}, Loss: {total_loss.item():.4f} (Class: {classification_loss.item():.4f}, L1: {l1_loss.item():.6f}), Acc: {accuracy:.2%}, Drops: {vol_layer.positions.shape[0]}")

print("\n--- Training complete ---")

100%|██████████| 9.91M/9.91M [00:00<00:00, 13.8MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 340kB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 3.15MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 8.71MB/s]



--- Starting Training for 5000 steps ---


  seed_grad_tensor = rendered_features.grad


AttributeError: 'NoneType' object has no attribute 'flatten'