In [None]:
# Colab environment setup: replace the preinstalled NumPy with 1.24.4 and install the
# PyTorch 2.1.0 wheels built for CUDA 11.8 from the official PyTorch index.
# NOTE(review): torchvision/torchaudio are not pinned here; confirm pip resolves versions
# that match torch==2.1.0. A runtime restart may be needed before the new packages import.
%pip uninstall -y numpy  # Force clean slate
%pip install numpy==1.24.4  # PyTorch-compatible version
%pip install torch==2.1.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

In [None]:
import torch

# Sanity-check the installed PyTorch build and the attached accelerator.
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
# torch.cuda.get_device_name(0) raises when no CUDA device exists, so guard it
# to keep this cell runnable on a CPU-only runtime.
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
else:
    print("GPU: none (CPU-only runtime)")

## Expected output:
# PyTorch: 2.1.0+cu118
# CUDA available: True
# GPU: Tesla T4

In [None]:
# Mount Google Drive so the dataset files stored there can be copied into the Colab VM.
from google.colab import drive
drive.mount('/content/drive')

# Place the 7-Scenes scene files (e.g. stairs.zip, stairs.mhd, stairs.raw) in the root
# of your Drive ("MyDrive") so the next cell can copy them.

In [None]:
# Copy the scene archive and its TSDF volume (.mhd header + .raw data) from Drive to local disk.
%cp "/content/drive/MyDrive/stairs.zip" /content/
%cp "/content/drive/MyDrive/stairs.mhd" /content/
%cp "/content/drive/MyDrive/stairs.raw" /content/
# You can use any other 7-Scenes scene instead of "stairs"; update the file names in this
# and the following cells accordingly.

In [None]:
# Files are now local; unmount Drive (flushing any pending writes first).
drive.flush_and_unmount()

In [None]:
!unzip -u "/content/stairs.zip" -d "/content/"

# The cells below expect the extracted scene folder at /content/7scenes. Rename the
# extracted folder (presumably /content/stairs — check the archive) to /content/7scenes.

In [None]:
# Remove the archive to free disk space once it has been extracted.
!rm -rf "/content/stairs.zip"

In [None]:
# Extract every per-sequence archive inside the scene folder.
!unzip -u "/content/7scenes/seq-01.zip" -d '/content/7scenes/'
!unzip -u "/content/7scenes/seq-02.zip" -d '/content/7scenes/'
!unzip -u "/content/7scenes/seq-03.zip" -d '/content/7scenes/'
!unzip -u "/content/7scenes/seq-04.zip" -d '/content/7scenes/'
!unzip -u "/content/7scenes/seq-05.zip" -d '/content/7scenes/'
!unzip -u "/content/7scenes/seq-06.zip" -d '/content/7scenes/'
# Add one line per additional sequence archive if your scene has more than six,
# following this template:
# !unzip -u "/content/7scenes/seq-N.zip" -d '/content/7scenes/'

In [None]:
# Spiking (snntorch) and equivariant (e3nn) libraries; numpy is re-pinned in the same
# command so pip keeps it at 1.24.4 while resolving these packages.
%pip install snntorch==0.9.4 e3nn==0.5.0 numpy==1.24.4

In [None]:
# NOTE(review): unpinned — pin a version known to work with torch 2.1.0 for reproducibility.
%pip install kornia

In [None]:
# NOTE(review): unpinned — pin a version known to work with torch 2.1.0 for reproducibility.
%pip install torch-geometric

In [None]:
# torch-scatter wheels must match the installed torch/CUDA build (torch 2.1.0 + cu118 above).
%pip install torch-scatter -f https://data.pyg.org/whl/torch-2.1.0+cu118.html

In [None]:
# Core dependencies
import torch
import torch.nn as nn
import torch.nn.functional as F
import snntorch as snn
from snntorch import surrogate
import e3nn
from e3nn import o3
from e3nn import nn as e3nn_nn
import torch_geometric as pyg
from torch_geometric.nn import MessagePassing
from torch_geometric.data import Data, Batch
import kornia as kn
import glob
from PIL import Image
import kornia.geometry.conversions as KG
import torch.optim as optim
from torchvision import transforms
import os
import numpy as np
from torch.utils.data import Dataset, DataLoader
import math
from torch_scatter import scatter  # Add this line
from torch.optim.lr_scheduler import LambdaLR

In [None]:
# SimpleITK reads the MetaImage (.mhd/.raw) TSDF volume.
# NOTE(review): unpinned — consider pinning for reproducibility.
%pip install SimpleITK

In [None]:
# Open3D writes/reads the extracted scene point cloud (.ply).
# NOTE(review): unpinned — consider pinning for reproducibility.
%pip install open3d

In [None]:
import SimpleITK as sitk
import numpy as np
import open3d as o3d

def tsdf_to_pointcloud(mhd_path, ply_path, threshold=0.05):
    """Extract the near-zero level set of a TSDF volume and save it as a point cloud.

    Voxels with ``|tsdf| < threshold`` are treated as surface samples. Their
    indices are mapped to world coordinates using the spacing and origin from
    the MetaImage header. The image direction matrix is not applied.

    Args:
        mhd_path: Path to the MetaImage (.mhd) TSDF volume.
        ply_path: Output path for the point cloud file.
        threshold: Absolute TSDF value below which a voxel counts as surface.
            It uses the same units as the stored TSDF values (default 0.05).

    Returns:
        An (N, 3) float64 array of surface points in (x, y, z) order. If
        reading or writing fails, the error is printed and None is returned.
    """
    try:
        tsdf = sitk.ReadImage(mhd_path)
        # GetArrayFromImage returns axes in (z, y, x) order, while spacing and
        # origin are given in (x, y, z) order.
        tsdf_vol = sitk.GetArrayFromImage(tsdf)
        spacing = np.asarray(tsdf.GetSpacing(), dtype=np.float64)
        origin = np.asarray(tsdf.GetOrigin(), dtype=np.float64)

        # (N, 3) voxel indices as (z, y, x), in the same row-major order as np.where.
        zyx = np.argwhere(np.abs(tsdf_vol) < threshold)
        # Reverse to (x, y, z) and convert from voxel indices to world coordinates.
        points = zyx[:, ::-1] * spacing + origin
        if len(points) == 0:
            print(f"Warning: no voxels with |tsdf| < {threshold}; writing an empty point cloud")

        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(points)
        o3d.io.write_point_cloud(ply_path, pcd)
        print(f"Saved {len(points)} points to {ply_path}")
        return points
    except Exception as e:
        # Best-effort conversion: report the failure and signal it with None.
        print(f"Error: {str(e)}")
        return None

# Convert the scene's TSDF volume (copied from Drive earlier) into a surface point cloud.
# The .ply is read back in a later cell as the scene points (presumably for the ADD-S loss).
# Change the file names according to the scene:
tsdf_to_pointcloud(
    mhd_path="/content/stairs.mhd",
    ply_path="/content/stairs_points.ply"
)

In [None]:
import torch
import open3d as o3d
def compute_scene_diameter(points, chunk_size=4096):
    """Return the scene diameter: the largest pairwise Euclidean distance between points.

    Distances are computed one block of query points at a time. Peak memory is
    therefore O(chunk_size * N) instead of the O(N^2) of a full distance matrix,
    which matters for dense point clouds extracted from a TSDF.

    Args:
        points: (..., N, D) tensor or array-like of coordinates, typically (N, 3).
            Non-tensor inputs are converted to float32.
        chunk_size: Number of query points compared against all points per step.

    Returns:
        float: The diameter, or 0.0 when there are fewer than two points.
    """
    if not isinstance(points, torch.Tensor):
        points = torch.tensor(points, dtype=torch.float32)

    num_points = points.shape[-2] if points.dim() >= 2 else 0
    if num_points < 2:
        return 0.0

    chunk_size = max(1, int(chunk_size))
    diameter = 0.0
    for start in range(0, num_points, chunk_size):
        # Distances from this block of points to every point: [..., chunk, N].
        block = torch.cdist(points[..., start:start + chunk_size, :], points)
        diameter = max(diameter, block.max().item())
    return diameter

import os

# Same absolute path the tsdf_to_pointcloud(...) call writes to, so this cell does
# not depend on the kernel's current working directory.
SCENE_PLY_PATH = "/content/stairs_points.ply"
if not os.path.exists(SCENE_PLY_PATH):
    raise FileNotFoundError(f"{SCENE_PLY_PATH} not found - run the TSDF-to-point-cloud cell first")

pcd_tensor = o3d.t.io.read_point_cloud(SCENE_PLY_PATH)
# NOTE(review): dividing by 1000 assumes the TSDF spacing/origin are in millimetres,
# so this converts the points to metres - confirm against the .mhd header.
scene_points = pcd_tensor.point["positions"].numpy()/1000.0


In [None]:
class Config:
    """Hyperparameters for training and the model.

    NOTE(review): in the cells visible here only ``input_size`` is read (by the
    dataset transform). Confirm where the other fields are consumed; for
    example, NeuromorphicLieGNN defaults to num_steps=25, not the 32 set here.
    """

    # --- optimisation ---
    num_epochs = 100
    batch_size = 16
    lr = 3e-4
    grad_clip = 1.0

    # --- spiking front end ---
    input_size = (128, 128)
    num_steps = 32  # Increased temporal resolution
    beta = 0.97  # Slower membrane decay

    # --- equivariant head (e3nn irreps) ---
    irreps_in = o3.Irreps("256x0e")
    irreps_hidden = o3.Irreps("128x0e + 64x1o + 32x2e")
    irreps_out = o3.Irreps("1o + 1o")  # Rotation + Translation

    # --- runtime ---
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

class LearnableTemporalPool(nn.Module):
    """Collapse a spike train over time into one value per (batch, channel, pixel).

    Each location's time series is filtered by a shared 1D convolution
    (kernel 3, padding 1, so the length is preserved). It is then weighted per
    time step by a learnable vector and summed over time.

    Args:
        num_steps: Length of the time axis. It must equal the T dimension of
            the spikes passed to :meth:`forward`.
    """

    def __init__(self, num_steps):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels=1,  # every (batch, channel, pixel) series is filtered independently
            out_channels=1,
            kernel_size=3,
            padding=1
        )
        self.weights = nn.Parameter(torch.ones(num_steps))

    def forward(self, spikes):
        """Pool spikes of shape [B, T, C, H, W] into rates of shape [B, C, H, W]."""
        batch_size, num_steps, channels, height, width = spikes.shape

        # Move time to the last axis BEFORE flattening, so each row holds one
        # location's time series: [B, T, C, H, W] -> [B, C, H, W, T] -> [B*C*H*W, 1, T].
        # Reshaping [B, T, ...] directly would interleave time steps with spatial values.
        series = spikes.permute(0, 2, 3, 4, 1).reshape(-1, 1, num_steps)

        # Temporal filtering: [B*C*H*W, 1, T]
        series = self.conv(series)

        # Per-step learnable weighting, broadcast over all series.
        series = series * self.weights.view(1, 1, -1)

        # Sum over time: [B*C*H*W, 1]
        pooled = series.sum(dim=2)

        # Rows are ordered (b, c, h, w), matching the permuted layout.
        return pooled.reshape(batch_size, channels, height, width)

class SNNEncoder(nn.Module):
    """Convolutional spiking encoder built from three snnTorch Leaky layers.

    The same image is presented for ``num_steps`` simulation steps. In training
    mode, independent Gaussian noise is added at each step. The input passes
    through three 5x5 stride-2 convolutions, each followed by a leaky
    integrate-and-fire layer, and the spikes of the last layer are recorded at
    every step.

    Args:
        input_dim: Number of input channels.
        hidden_dim: Output channels of conv1 (conv2 uses ``2 * hidden_dim``).
        output_dim: Output channels of conv3, i.e. of the spike maps.
        beta: Unused; the LIF layers use fixed decays (0.85, 0.78, 0.72). Kept
            for interface compatibility.
        num_steps: Window length of ``self.temporal_pool``. Callers that pool
            the output must run :meth:`forward` with the same ``num_steps``.
        spike_grad: Surrogate gradient shared by all LIF layers. Defaults to
            ``surrogate.fast_sigmoid(slope=10.0)``.
        noise_std: Std of the Gaussian noise added to the input at each step
            in training mode (0 disables it).
    """

    def __init__(self, input_dim=3, hidden_dim=64, output_dim=32, beta=0.5, num_steps=25,
                 spike_grad=None, noise_std=0.1):
        super().__init__()
        self.conv1 = nn.Conv2d(input_dim, hidden_dim, kernel_size=5, stride=2, padding=2, padding_mode='replicate')
        self.conv2 = nn.Conv2d(hidden_dim, hidden_dim*2, kernel_size=5, stride=2, padding=2, padding_mode='replicate')
        self.conv3 = nn.Conv2d(hidden_dim*2, output_dim, kernel_size=5, stride=2, padding=2, padding_mode='replicate')

        self.temporal_pool = LearnableTemporalPool(num_steps)
        self.noise_std = noise_std

        # Honour a caller-supplied surrogate; otherwise use a steep fast sigmoid.
        if spike_grad is None:
            spike_grad = surrogate.fast_sigmoid(slope=10.0)

        # Decay decreases and the (learnable) firing threshold increases with depth.
        self.lif1 = snn.Leaky(beta=0.85, threshold=0.5, learn_threshold=True, spike_grad=spike_grad)
        self.lif2 = snn.Leaky(beta=0.78, threshold=0.6, learn_threshold=True, spike_grad=spike_grad)
        self.lif3 = snn.Leaky(beta=0.72, threshold=0.7, learn_threshold=True, spike_grad=spike_grad)

        # Kept for compatibility; forward() holds membrane state in local variables.
        self.mem1 = self.mem2 = self.mem3 = None

    def forward(self, x, num_steps=25):
        """Simulate the encoder for ``num_steps`` steps.

        Args:
            x: Input images, presumably [B, C, H, W] with C == input_dim.
            num_steps: Number of simulation steps.

        Returns:
            Spikes of the last layer, shaped [B, num_steps, output_dim, H', W'].
            Each stride-2 conv maps a size S to ceil(S / 2).
        """
        if torch.isnan(x).any():
            print("NaN in SNN input!")
            raise RuntimeError

        # Fresh membrane state for every call.
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()

        # Input noise is a training-time augmentation; evaluation stays deterministic.
        add_noise = self.training and self.noise_std > 0

        spk3_rec = []
        for _ in range(num_steps):
            # Independent noise per step, without materialising T copies of the input.
            x_t = (x + torch.randn_like(x) * self.noise_std) if add_noise else x

            spk1, mem1 = self.lif1(self.conv1(x_t), mem1)
            spk2, mem2 = self.lif2(self.conv2(spk1), mem2)
            spk3, mem3 = self.lif3(self.conv3(spk2), mem3)

            spk3_rec.append(spk3)

        return torch.stack(spk3_rec, dim=1)
        
class SpikeToGeometric(nn.Module):
    """Map spike-rate feature maps to per-pixel scalar and 3-vector features.

    Two 1x1 convolutions run on the same input. One produces ``output_scalar``
    channels. The other produces ``output_vector * 3`` channels, which are
    regrouped into ``output_vector`` 3-vectors per pixel.

    Args:
        input_dim: Channels of the incoming rate features.
        output_scalar: Number of scalar channels produced.
        output_vector: Number of 3-vectors produced per pixel.
    """

    def __init__(self, input_dim=32, output_scalar=32, output_vector=32):
        super().__init__()
        self.conv_scalar = nn.Conv2d(input_dim, output_scalar, kernel_size=1)
        self.conv_vector = nn.Conv2d(input_dim, output_vector*3, kernel_size=1)
        self.output_vector = output_vector

        # He (fan-in) initialisation, zero biases.
        nn.init.kaiming_normal_(self.conv_scalar.weight,
                              mode='fan_in',
                              nonlinearity='leaky_relu')
        nn.init.kaiming_normal_(self.conv_vector.weight,
                              mode='fan_in',
                              nonlinearity='leaky_relu')
        nn.init.zeros_(self.conv_scalar.bias)
        nn.init.zeros_(self.conv_vector.bias)

    def forward(self, spike_features):
        """Convert rate features into geometric features, keeping spatial layout.

        Args:
            spike_features: [B, input_dim, H, W] rate features (any H, W).

        Returns:
            (scalar_features [B, output_scalar, H, W],
             vector_features [B, output_vector, 3, H, W])
        """
        batch_size, _, height, width = spike_features.shape
        if torch.isnan(spike_features).any():
            print("NaN in spike_features input to SpikeToGeometric")
            raise RuntimeError

        # Per-location invariant features.
        scalar_features = self.conv_scalar(spike_features)
        if torch.isnan(scalar_features).any():
            print("NaN in scalar_features after conv_scalar")
            raise RuntimeError

        # Per-location vector features. Conv channel v*3 + c becomes component c
        # of vector v. The spatial size is taken from the input, not hard-coded.
        vector_features = self.conv_vector(spike_features).view(
            batch_size, self.output_vector, 3, height, width
        )
        if torch.isnan(vector_features).any():
            print("NaN in vector_features after conv_vector")
            raise RuntimeError

        return scalar_features, vector_features

class GraphConstructor(nn.Module):
    """Build a batched PyG graph from per-node features.

    Node features go through a learnable ``nn.Linear`` projection. Edges are
    built per sample:
      * with ``positions``: a fixed 16x16 grid graph when there are exactly 256
        nodes, otherwise k-nearest neighbours in 3D;
      * without ``positions``: k most cosine-similar projected features.

    Args:
        feature_dim: Input and output width of the feature projection.
        k_neighbors: Neighbours per node for the k-NN and similarity graphs.
            The 256-node grid graph ignores it.
    """

    def __init__(self, feature_dim=256, k_neighbors=8):
        super().__init__()
        self.k = k_neighbors
        self.feature_projection = nn.Linear(feature_dim, feature_dim)

    def forward(self, features, positions=None):
        """
        Args:
            features: [B, N, feature_dim] node features.
            positions: Optional [B, N, 3] node coordinates.

        Returns:
            A ``torch_geometric.data.Batch`` of B graphs with ``x`` set to the
            projected features and ``pos`` set to ``positions`` (or None).
        """
        batch_size, _, _ = features.shape
        node_features = self.feature_projection(features)

        graphs = []
        for b in range(batch_size):
            if positions is not None:
                pos = positions[b]
                edge_index = self._knn_graph(pos, self.k)
            else:
                # No geometry available: connect nodes by feature similarity.
                pos = None
                edge_index = self._feature_graph(node_features[b], self.k)
            graphs.append(Data(x=node_features[b], edge_index=edge_index, pos=pos))

        return Batch.from_data_list(graphs)

    def _feature_graph(self, features, k):
        """Return a [2, N*k'] edge index linking each node to its k' most cosine-similar nodes.

        k' = min(k, N - 1). Self-loops are excluded explicitly, by masking the
        diagonal, rather than by assuming each node ranks itself first (which
        fails for duplicate or zero features).
        """
        num_nodes = features.size(0)
        effective_k = max(0, min(k, num_nodes - 1))

        # Edge indices are not differentiable; skip autograd bookkeeping.
        with torch.no_grad():
            normed = F.normalize(features, p=2, dim=1)
            similarity = normed @ normed.t()
            similarity.fill_diagonal_(float('-inf'))
            neighbours = similarity.topk(effective_k, dim=1, largest=True).indices

        rows = torch.arange(num_nodes, device=features.device).unsqueeze(1).expand(-1, effective_k)
        return torch.stack([rows.reshape(-1), neighbours.reshape(-1)], dim=0)

    def _knn_graph(self, points, k):
        """Return a [2, E] edge index over ``points`` ([N, 3]).

        For exactly 256 nodes, the nodes are assumed to be a row-major 16x16
        grid and the fixed grid graph is returned (``k`` and the coordinates are
        ignored). Otherwise each node is linked to its k' = min(k, N - 1)
        nearest other nodes.
        """
        num_nodes = points.size(0)

        if num_nodes == 256:  # 16x16 grid
            return self._grid_graph(16, points.device)

        effective_k = max(0, min(k, num_nodes - 1))
        with torch.no_grad():
            dist = torch.cdist(points, points)
            # Mask self explicitly; duplicate points would otherwise tie with self.
            dist.fill_diagonal_(float('inf'))
            neighbours = dist.topk(effective_k, dim=1, largest=False).indices

        rows = torch.arange(num_nodes, device=points.device).unsqueeze(1).expand(-1, effective_k)
        return torch.stack([rows.reshape(-1), neighbours.reshape(-1)], dim=0)

    @staticmethod
    def _grid_graph(side, device):
        """Bidirectional edges of a row-major ``side x side`` grid.

        Edges are horizontal, vertical and one diagonal direction (down-right
        only). The order is all horizontal, then vertical, then diagonal pairs,
        each emitted as (a->b, b->a).
        """
        idx = torch.arange(side * side, device=device).view(side, side)

        def bidirectional(src, dst):
            src, dst = src.reshape(-1), dst.reshape(-1)
            forward_edges = torch.stack([src, dst], dim=1)
            backward_edges = torch.stack([dst, src], dim=1)
            # Interleave so each pair appears as (a->b) immediately followed by (b->a).
            return torch.stack([forward_edges, backward_edges], dim=1).reshape(-1, 2)

        edges = torch.cat([
            bidirectional(idx[:, :-1], idx[:, 1:]),     # horizontal neighbours
            bidirectional(idx[:-1, :], idx[1:, :]),     # vertical neighbours
            bidirectional(idx[:-1, :-1], idx[1:, 1:]),  # down-right diagonal
        ], dim=0)
        return edges.t().contiguous()

class SE3EquivariantGNN(nn.Module):
    """Stack of e3nn tensor-product message-passing layers that predicts a pose per node.

    Node features use irreps ``64x0e + 64x1o``: 64 scalars followed by 64
    3-vectors. Each ``o3.TensorProduct`` layer combines a sender's features
    with the scalar edge feature ``exp(-|p_s - p_r|^2)``, and the messages are
    summed at the receiver. Hidden layers are followed by a float32
    NormActivation. A final ``o3.Linear`` maps to ``1x1o + 2x1o`` (9 numbers per
    node): 3 for translation and 6 for a 6D rotation representation.

    Args:
        hidden_dim: Unused; kept for interface compatibility.
        num_layers: The stack is one input TensorProduct, ``num_layers - 2``
            (TensorProduct, NormActivation) pairs, and the output Linear.
    """

    def __init__(self, hidden_dim=32, num_layers=3):
        super().__init__()
        self.irreps_in = o3.Irreps("64x0e + 64x1o")
        self.irreps_hidden = o3.Irreps("64x0e + 64x1o")
        self.irreps_edge = o3.Irreps("1x0e")  # Scalar edge features
        self.irreps_out = o3.Irreps("1x1o + 2x1o")  # translation (3) + 6D rotation (6)

        self.layers = nn.ModuleList()

        # Input layer: input irrep i times the scalar edge feature -> output irrep i
        # ('uvu' path with trainable weights).
        self.layers.append(o3.TensorProduct(
            self.irreps_in,
            self.irreps_edge,
            self.irreps_hidden,
            instructions=[(i,0,i,'uvu',True, 1.0) for i in range(len(self.irreps_in))],
            shared_weights=True,
            internal_weights=True
        ))

        # Hidden layers: same edge-weighted tensor product plus a norm-based nonlinearity.
        for _ in range(num_layers-2):
            self.layers.append(o3.TensorProduct(
                self.irreps_hidden,
                self.irreps_edge,
                self.irreps_hidden,
                instructions=[(i,0,i,'uvu', True, 1.0) for i in range(len(self.irreps_hidden))],
                shared_weights=True,
                path_normalization='element',
                internal_weights=True
            ))
            self.layers.append(self.Float32NormActivation(
                self.irreps_hidden,
                scalar_nonlinearity=torch.tanh,
                epsilon=1e-5
            ))

        # Output layer (per-node, no message passing).
        self.layers.append(o3.Linear(self.irreps_hidden, self.irreps_out))

    class Float32NormActivation(e3nn_nn.NormActivation):
        """e3nn NormActivation evaluated in float32, cast back to the input dtype."""

        def forward(self, features):
            original_dtype = features.dtype
            # Compute in float32 for numerical stability (e.g. under reduced precision).
            features = features.float()
            features = super().forward(features)
            return features.to(original_dtype)

    def forward(self, x, edge_index, pos):
        """Run message passing.

        Args:
            x: [num_nodes, irreps_in.dim] node features.
            edge_index: [2, num_edges] (sender, receiver) indices.
            pos: [num_nodes, 3] node positions used for the edge features.

        Returns:
            (translation [num_nodes, 3], rotation_6d [num_nodes, 6]) per node.

        Raises:
            RuntimeError: If a NaN appears at any stage (the stage is printed).
        """
        if torch.isnan(x).any():
            print("NaN in GNN input")
            raise RuntimeError

        senders, receivers = edge_index
        rel_pos = pos[senders] - pos[receivers]
        distances = torch.norm(rel_pos, dim=1, keepdim=True)
        if torch.isnan(rel_pos).any() or torch.isnan(distances).any():
            print("NaN in edge features")
            raise RuntimeError

        # Gaussian of the edge length: one invariant scalar per edge, [num_edges, 1].
        edge_attr = torch.exp(-distances**2)

        for layer_idx, layer in enumerate(self.layers):
            if isinstance(layer, o3.TensorProduct):
                # Message from each sender, summed at its receiver.
                messages = layer(x[senders], edge_attr)
                if torch.isnan(messages).any():
                    print(f"NaN in messages at layer {layer_idx}")
                    raise RuntimeError

                x = scatter(messages, receivers, dim=0, dim_size=x.size(0))
                if torch.isnan(x).any():
                    print(f"NaN after scattering at layer {layer_idx}")
                    raise RuntimeError
            else:
                x = layer(x)
                if torch.isnan(x).any():
                    print(x)
                    print(f"NaN after layer {layer_idx} ({type(layer).__name__})")
                    raise RuntimeError

        # Split the 9 output channels: 1x1o translation, then 2x1o raw 6D rotation.
        translation = x[:, :3]
        rotation_6d = x[:, 3:9]
        return translation, rotation_6d

    def inference_orthogonalize(self, rotation_6d):
        """Map [N, 6] raw 6D rotations to [N, 3, 3] rotation matrices (Gram-Schmidt)."""
        a1, a2 = rotation_6d[:, :3], rotation_6d[:, 3:6]
        return self.sixd_to_rotation_matrix(a1, a2)

    def sixd_to_rotation_matrix(self, a1, a2):
        """Gram-Schmidt 6D -> rotation conversion; the columns are (b1, b2, b1 x b2).

        Args:
            a1, a2: [..., 3] tensors, the first two (unnormalised) columns.

        Returns:
            [..., 3, 3] rotation matrices.
        """
        # Epsilon guards against division by zero for degenerate inputs.
        a1 = a1 / (torch.norm(a1, dim=-1, keepdim=True) + 1e-6)
        a2 = a2 - (a1 * a2).sum(dim=-1, keepdim=True) * a1
        a2 = a2 / (torch.norm(a2, dim=-1, keepdim=True) + 1e-6)
        # dim=-1 is required: without it torch.cross uses the FIRST dimension of
        # size 3, which is the node/batch dimension whenever there are 3 rows.
        b3 = torch.linalg.cross(a1, a2, dim=-1)
        R = torch.stack([a1, a2, b3], dim=-1)

        return R

class NeuromorphicLieGNN(nn.Module):
    """End-to-end pose regressor: spiking CNN -> geometric features -> SE(3) GNN.

    Pipeline for a batch of images:
      1. SNNEncoder turns each image into ``num_steps`` frames of spikes.
      2. The encoder's learnable temporal pool collapses time into rates.
      3. SpikeToGeometric produces 64 scalar and 64 vector channels per pixel.
      4. Every pixel of the feature map becomes a graph node on a normalised 2D
         grid (z = 0), and GraphConstructor builds the edges.
      5. SE3EquivariantGNN predicts a translation and a 6D rotation per node.
         These are mean-pooled per image, then orthogonalised.

    NOTE(review): GraphConstructor applies an unconstrained nn.Linear to the
    concatenated scalar+vector channels. This mixes vector components with
    scalars, so the features entering the e3nn layers are presumably not truly
    equivariant. Confirm whether that is intended.

    Args:
        input_channels: Image channels.
        hidden_dim: Width of the first SNN conv layer.
        num_steps: SNN simulation steps. It is also passed to the encoder so
            that its temporal-pool window matches.
        beta: Forwarded to SNNEncoder.
        num_gnn_layers: Forwarded to SE3EquivariantGNN as ``num_layers``.
    """

    def __init__(self, input_channels=3, hidden_dim=64, num_steps=25, beta=0.5, num_gnn_layers=3):
        super().__init__()

        # SNN encoder. num_steps must be forwarded: the temporal pool's
        # per-step weights have to match the number of simulated steps.
        self.snn_encoder = SNNEncoder(
            input_dim=input_channels,
            hidden_dim=hidden_dim,
            output_dim=32,
            beta=beta,
            num_steps=num_steps
        )

        # Spike rates -> 64 scalars + 64 vectors per pixel (256 channels total).
        self.spike_to_geo = SpikeToGeometric(
            input_dim=32,  # must equal the encoder's output_dim
            output_scalar=64,
            output_vector=64
        )

        self.graph_constructor = GraphConstructor(feature_dim=256, k_neighbors=8)

        self.se3_gnn = SE3EquivariantGNN(hidden_dim=32, num_layers=num_gnn_layers)

        self.num_steps = num_steps

    def forward(self, images, positions=None):
        """
        Args:
            images: Batch of input images (B, C, H, W).
            positions: Currently ignored; nodes are always placed on the
                normalised feature-map grid. Kept for interface compatibility.

        Returns:
            dict with:
              'position' [B, 3];
              'rotation_6d' [B, 6], the pooled raw 6D output (for the training loss);
              'rotation_matrix' [B, 3, 3];
              'pose' [B, 4, 4] homogeneous transform.
        """
        batch_size = images.size(0)

        # 1. Spike trains: [B, T, C, H', W']
        spike_features = self.snn_encoder(images, self.num_steps)
        if torch.isnan(spike_features).any():
            raise ValueError("NaN in spike features!")

        # 2. Rate-based representation: [B, C, H', W']
        rate_features = self.snn_encoder.temporal_pool(spike_features).to(torch.float32)
        if torch.isnan(rate_features).any():
            raise ValueError("NaN in rate features!")

        # 3. Geometric features: scalars [B, S, H', W'], vectors [B, V, 3, H', W']
        scalar_features, vector_features = self.spike_to_geo(rate_features)
        if torch.isnan(scalar_features).any() or torch.isnan(vector_features).any():
            print("NaN in geometric features")
            raise RuntimeError

        height, width = scalar_features.shape[-2:]
        num_nodes = height * width

        # Flatten each vector into 3 consecutive channels (v0x, v0y, v0z, v1x, ...),
        # after the S scalar channels: [B, S + 3V, H', W'].
        vector_channels = vector_features.reshape(batch_size, -1, height, width)
        combined_features = torch.cat([scalar_features, vector_channels], dim=1)

        # 4. One node per pixel, in row-major order: [B, H'*W', S + 3V]
        node_features = combined_features.permute(0, 2, 3, 1).reshape(batch_size, num_nodes, -1)

        # Grid coordinates normalised to [-1, 1], in the same row-major order.
        grid = self._get_grid_positions(height, width).to(images.device)
        span = (grid.max() - grid.min()).clamp_min(1.0)  # avoids 0/0 on a 1x1 map
        grid = (grid - grid.min()) / span * 2 - 1
        node_positions = grid.reshape(1, num_nodes, 3).repeat(batch_size, 1, 1)  # [B, N, 3]

        graph = self.graph_constructor(node_features, node_positions)

        # 5. Per-node predictions, mean-pooled per image.
        position, rotation_6d = self.se3_gnn(graph.x, graph.edge_index, graph.pos)
        position = scatter(position, graph.batch, dim=0, reduce='mean')
        rotation_6d = scatter(rotation_6d, graph.batch, dim=0, reduce='mean')
        # Orthogonalise AFTER pooling: an average of rotation matrices is in
        # general not orthonormal, so it would not be a valid rotation/pose.
        rotation_matrix = self.se3_gnn.inference_orthogonalize(rotation_6d)

        return {
            'position': position,
            'rotation_6d': rotation_6d,  # For training loss
            'rotation_matrix': rotation_matrix,  # For inference/validation
            'pose': self._combine_pose(position, rotation_matrix)
        }

    def _combine_pose(self, position, rotation):
        """Build [B, 4, 4] homogeneous transforms from [B, 3] translations and [B, 3, 3] rotations."""
        batch_size = position.size(0)
        T = torch.eye(4, dtype=position.dtype, device=position.device).repeat(batch_size, 1, 1)
        T[:, :3, :3] = rotation
        T[:, :3, 3] = position
        return T

    def _get_grid_positions(self, height=16, width=16):
        """Return [height, width, 3] pseudo-3D coordinates (x = column, y = row, z = 0)."""
        ys, xs = torch.meshgrid(
            torch.arange(height),
            torch.arange(width),
            indexing='ij'
        )
        zs = torch.zeros_like(xs, dtype=torch.float32)  # flat depth
        return torch.stack([xs.float(), ys.float(), zs], dim=-1)


class PoseLoss(nn.Module):
    """Adaptively weighted camera-pose loss.

    Three terms are computed per batch:
      * position: smooth-L1 between predicted and target translations;
      * rotation: mean geodesic angle between predicted and target rotations;
      * ADD-S: mean closest-point distance between the scene points transformed
        by the predicted pose and by the target pose (per-point distances are
        clamped at 1.0).
    Each term is weighted by ``1 / (EMA of the term + 1e-3)``, clipped to
    [0.1, 10], and normalised so the three weights sum to ~1.

    Args:
        scene_points: (N, 3) np.ndarray or tensor of scene points. They are
            randomly subsampled to at most 500 points.
        alpha: EMA decay factor (higher means slower adaptation).
    """

    def __init__(self, scene_points, alpha=0.9):
        super().__init__()
        self.register_buffer('scene_points', self._preprocess_scene_points(scene_points))

        # EMAs for the adaptive weights. +inf marks "no observation yet";
        # _update_ema seeds each average with its first observed value.
        self.alpha = alpha
        self.register_buffer('pos_ema', torch.tensor(float('inf')))
        self.register_buffer('rot_ema', torch.tensor(float('inf')))
        self.register_buffer('add_ema', torch.tensor(float('inf')))

    def _farthest_point_sampling(self, points, n_samples):
        """Memory-efficient FPS without full distance matrix"""
        device = points.device
        n_points = points.shape[0]

        # Initialize with random point
        start_idx = torch.randint(0, n_points, (1,)).item()
        selected_indices = [start_idx]

        # Initialize min distances
        min_dists = torch.full((n_points,), float('inf'), device=device)

        # Compute initial distances
        dists = torch.norm(points - points[start_idx], dim=1)
        min_dists = torch.min(min_dists, dists)

        for _ in range(1, min(n_samples, n_points)):
            # Find farthest point
            next_idx = torch.argmax(min_dists).item()
            selected_indices.append(next_idx)

            # Compute distances to new point
            new_dists = torch.norm(points - points[next_idx], dim=1)

            # Update min distances
            min_dists = torch.min(min_dists, new_dists)
            min_dists[next_idx] = 0  # Avoid reselecting

        return points[selected_indices]

    def _preprocess_scene_points(self, points):
        """Convert to a float32 tensor and randomly keep at most 500 points."""
        if isinstance(points, np.ndarray):
            points = torch.from_numpy(points)
        points = points.float()

        # Simple random sampling instead of FPS
        if len(points) > 500:
            indices = torch.randperm(len(points))[:500]
            points = points[indices]

        return points

    def _matrix_to_6d(self, rotation_matrix):
        """Convert rotation matrix to 6D representation"""
        return rotation_matrix[:, :, :2].reshape(rotation_matrix.size(0), 6)

    def compute_add_s(self, pred_trans, pred_rot, gt_trans, gt_rot):
        """Mean ADD-S distance between scene points under the predicted and target poses.

        Args:
            pred_trans, gt_trans: [B, 3] translations.
            pred_rot, gt_rot: [B, 3, 3] rotation matrices.
        """
        scene_points = self.scene_points.to(pred_trans.device)
        batch_size = pred_trans.size(0)

        # Expand points to [B, N, 3]
        scene_points_batch = scene_points.unsqueeze(0).repeat(batch_size, 1, 1)

        # Stable transformation (avoid einsum)
        pred_points = torch.bmm(pred_rot, scene_points_batch.transpose(1,2)).transpose(1,2) + pred_trans.unsqueeze(1)
        gt_points = torch.bmm(gt_rot, scene_points_batch.transpose(1,2)).transpose(1,2) + gt_trans.unsqueeze(1)

        # Closest-point distances, clamped so outliers cannot dominate.
        dists = torch.cdist(pred_points, gt_points)
        min_dists = dists.min(dim=-1)[0].clamp(max=1.0)
        return min_dists.mean(dim=1).mean()

    def _update_ema(self, ema, value):
        """Return the EMA after observing ``value`` (detached from the graph).

        Started from +inf, the recurrence ``alpha * ema + (1 - alpha) * value``
        would stay +inf forever, pinning every weight at its 0.1 clamp. A
        non-finite EMA is therefore seeded with the observed value.
        """
        value = value.detach()
        if not torch.isfinite(ema):
            return value.clone()
        return self.alpha * ema + (1 - self.alpha) * value

    def forward(self, pred, target, epoch):
        """Compute the adaptively weighted loss.

        Args:
            pred: dict with 'position' [B, 3] and 'rotation_6d' [B, 6].
            target: dict with 'position' [B, 3] and 'rotation_matrix' [B, 3, 3].
            epoch: Unused; kept for interface compatibility.

        Returns:
            dict with 'total', 'position', 'rotation' and 'add_s' tensors, and a
            'weights' dict of Python floats.
        """
        # Position loss
        pos_loss = F.smooth_l1_loss(pred['position'], target['position'])

        # Project the raw 6D prediction onto a proper rotation.
        R_pred = self.sixd_to_rotation_matrix(pred['rotation_6d'])

        # Geodesic rotation loss
        rot_loss = self.geodesic_loss(R_pred, target['rotation_matrix'])

        # ADD-S loss with orthogonalized rotation
        add_s = self.compute_add_s(
            pred['position'], R_pred,
            target['position'], target['rotation_matrix']
        )

        # Update EMAs
        self.pos_ema = self._update_ema(self.pos_ema, pos_loss)
        self.rot_ema = self._update_ema(self.rot_ema, rot_loss)
        self.add_ema = self._update_ema(self.add_ema, add_s)

        # Inverse-magnitude weights, clipped for stability
        w_pos = torch.clamp(1.0 / (self.pos_ema + 1e-3), 0.1, 10.0)
        w_rot = torch.clamp(1.0 / (self.rot_ema + 1e-3), 0.1, 10.0)
        w_add = torch.clamp(1.0 / (self.add_ema + 1e-3), 0.1, 10.0)

        # Normalize
        total = w_pos + w_rot + w_add + 1e-7
        w_pos /= total
        w_rot /= total
        w_add /= total

        total_loss = w_pos * pos_loss + w_rot * rot_loss + w_add * add_s

        return {
            'total': total_loss,
            'position': pos_loss,
            'rotation': rot_loss,
            'add_s': add_s,
            'weights': {
                'pos': w_pos.item(),
                'rot': w_rot.item(),
                'add': w_add.item()
            }
        }

    def geodesic_loss(self, R_pred, R_target):
        """Stable geodesic distance on SO(3)"""
        R_diff = torch.bmm(R_pred, R_target.transpose(1, 2))
        trace = R_diff[:, 0, 0] + R_diff[:, 1, 1] + R_diff[:, 2, 2]
        angle = torch.acos(torch.clamp((trace - 1)/2, -1+1e-6, 1-1e-6))
        return angle.mean()

    def sixd_to_rotation_matrix(self, rotation_6d):
        """Project [B, 6] raw 6D rotations onto the nearest rotation matrices via SVD.

        The 3x3 matrix [a1, a2, a1 x a2] (as columns) is replaced by its polar
        factor ``U @ Vh`` (with A = U diag(S) Vh). The smallest singular
        direction is flipped when needed so that det = +1.
        """
        a1, a2 = rotation_6d[:, :3], rotation_6d[:, 3:6]

        # dim=-1 is required: without it torch.cross picks the FIRST dimension of
        # size 3, which is the batch dimension when B == 3.
        b3 = torch.linalg.cross(a1, a2, dim=-1)
        R = torch.stack([a1, a2, b3], dim=-1)

        U, _, Vh = torch.linalg.svd(R)
        det_sign = torch.sign(torch.linalg.det(U @ Vh))
        det_sign = torch.where(det_sign == 0, torch.ones_like(det_sign), det_sign)
        U = torch.cat([U[..., :2], U[..., 2:] * det_sign[..., None, None]], dim=-1)
        return U @ Vh

class SevenScenesPoseDataset(Dataset):
    """7-Scenes RGB frames paired with 4x4 pose matrices, gathered from several roots.

    Each root may contain ``TrainSplit.txt`` / ``TestSplit.txt`` (one sequence name
    per line, e.g. ``sequence1``) and ``seq-XX/`` folders with ``*color.png`` images
    and matching ``*pose.txt`` files. Missing split files or folders are skipped.

    Args:
        roots: scene directories to scan.
        transform: callable applied to the PIL image in ``__getitem__``. When None, a
            default pipeline is built: resize + ToTensor + ImageNet normalization,
            with colour jitter and a small random rotation for the train split only.
        split: 'val' / 'test' split the TestSplit.txt sequences ('val' takes the first
            ``val_test_split_ratio`` fraction); any other value uses TrainSplit.txt.
        val_test_split_ratio: fraction of test sequences assigned to 'val'.
    """

    def __init__(self, roots=('/content/7scenes', '/content/7scenes2', '/content/7scenes3'),
                 transform=None, split='train', val_test_split_ratio=0.7):
        self.roots = list(roots)
        self.split = split
        # Honour a caller-supplied transform; the default only augments training data
        # so validation/test metrics are deterministic.
        self.transform = transform if transform is not None else self._default_transform(split)
        self.val_test_split_ratio = val_test_split_ratio

        # Per-sample storage (stacked into tensors at the end)
        self.poses_t = []
        self.poses_r = []
        self.samples = []

        # Collect (root, 'seq-XX') pairs from all roots
        all_train_sequences = []
        all_test_sequences = []
        for root in self.roots:
            all_train_sequences += self._read_split_file(root, "TrainSplit.txt")
            all_test_sequences += self._read_split_file(root, "TestSplit.txt")

        if split in ['val', 'test']:
            # Validation and test both come from the test sequences
            split_idx = int(len(all_test_sequences) * self.val_test_split_ratio)
            if split == 'val':
                selected_sequences = all_test_sequences[:split_idx]
            else:
                selected_sequences = all_test_sequences[split_idx:]
        else:
            selected_sequences = all_train_sequences
        print(selected_sequences)

        for root, seq in selected_sequences:
            seq_path = os.path.join(root, seq)
            if not os.path.exists(seq_path):
                continue

            color_files = sorted(f for f in os.listdir(seq_path) if f.endswith('color.png'))
            for cf in color_files:
                pose_path = os.path.join(seq_path, cf.replace('color.png', '') + 'pose.txt')
                if not os.path.exists(pose_path):
                    continue

                pose_mat = np.loadtxt(pose_path)
                self.poses_t.append(torch.tensor(pose_mat[:3, 3]))
                self.poses_r.append(torch.tensor(pose_mat[:3, :3]))
                self.samples.append(os.path.join(seq_path, cf))

        # Convert to tensors
        if self.poses_t:
            self.poses_t = torch.stack(self.poses_t)
            self.poses_r = torch.stack(self.poses_r)
        else:
            self.poses_t = torch.empty(0)
            self.poses_r = torch.empty(0)

    @staticmethod
    def _format_sequence_name(seq_str):
        """Normalize a split-file entry ('sequence1', 'Sequence 12', 'seq-03', '2') to 'seq-XX'."""
        clean_str = seq_str.lower().replace('sequence', '').strip()
        # Accept entries that are already in 'seq-XX' form
        if clean_str.startswith('seq-'):
            clean_str = clean_str[len('seq-'):]
        if clean_str.isdigit():
            return f"seq-{int(clean_str):02d}"
        # Fallback for non-numeric names
        return f"seq-{clean_str.zfill(2)}"

    @classmethod
    def _read_split_file(cls, root, filename):
        """Return [(root, 'seq-XX'), ...] for each non-empty line of root/filename ([] if missing)."""
        path = os.path.join(root, filename)
        if not os.path.exists(path):
            return []
        with open(path) as f:
            return [(root, cls._format_sequence_name(line.strip())) for line in f if line.strip()]

    @staticmethod
    def _default_transform(split):
        """Default image pipeline; colour/rotation augmentation only for the train split."""
        steps = [transforms.Resize(Config.input_size)]
        if split not in ('val', 'test'):
            # NOTE(review): RandomRotation rotates the image but not the pose label.
            steps += [
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
                transforms.RandomRotation(5),
            ]
        steps += [
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],  # RGB means
                std=[0.229, 0.224, 0.225]     # RGB stds
            ),
        ]
        return transforms.Compose(steps)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return {'image', 'rotation_matrix' (3x3), 'translation' (3,), 'path'} for sample idx."""
        img_path = self.samples[idx]
        # Load eagerly and close the file; force 3 channels to match the normalization.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        rotation = self.poses_r[idx].float()
        translation = self.poses_t[idx].float()

        return {
            'image': self.transform(image).float(),
            'rotation_matrix': rotation,
            'translation': translation.clone().detach(),
            'path': img_path
        }

def train_one_epoch(model, dataloader, optimizer, criterion, device, epoch, accumulation_steps=4):
    """Train ``model`` for one pass over ``dataloader`` with gradient accumulation.

    Batches containing NaNs are skipped. The gradients of every ``accumulation_steps``
    *used* batches are summed (each loss divided by ``accumulation_steps``), then
    unscaled, clipped to norm 1.0 and applied; a trailing partial group is applied
    after the loop.

    Args:
        model: network exposing ``snn_encoder.lif{1,2,3}.reset_mem()`` and returning
            a dict with 'position' and 'rotation_6d'.
        dataloader: iterable of batch dicts with 'image', 'translation', 'rotation_matrix'.
        optimizer: optimizer over the model parameters.
        criterion: callable ``(pred, target, epoch) -> dict`` containing 'total'.
        device: device the batch tensors are moved to.
        epoch: current epoch index, forwarded to ``criterion``.
        accumulation_steps: number of used batches per optimizer step.

    Returns:
        Mean of ``loss_dict['total']`` over the batches actually used (skipped NaN
        batches are excluded); 0.0 if no batch was used.
    """
    model.train()
    scaler = torch.cuda.amp.GradScaler()
    total_loss = 0.0
    num_used = 0   # batches that contributed a loss
    pending = 0    # used batches whose gradients are accumulated but not yet applied

    def _apply_accumulated_gradients():
        # Unscale before clipping so the norm is measured on the true gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)

    optimizer.zero_grad()
    for batch in dataloader:
        # Fresh spiking-neuron membrane state for every batch
        model.snn_encoder.lif1.reset_mem()
        model.snn_encoder.lif2.reset_mem()
        model.snn_encoder.lif3.reset_mem()

        images = batch['image'].to(device, dtype=torch.float32)
        target_position = batch['translation'].to(device, dtype=torch.float32)
        target_rotation = batch['rotation_matrix'].to(device, dtype=torch.float32)

        # Skip any NaN batches
        if torch.isnan(images).any() or torch.isnan(target_position).any() or torch.isnan(target_rotation).any():
            continue

        outputs = model(images)
        pred = {
            'position': outputs['position'],
            'rotation_6d': outputs['rotation_6d']
        }
        loss_dict = criterion(pred, {'position': target_position, 'rotation_matrix': target_rotation}, epoch)
        loss = loss_dict['total'] / accumulation_steps

        pending += 1
        at_boundary = pending == accumulation_steps
        # NOTE(review): retain_graph kept from the original; each batch builds its own
        # graph, so it likely only costs memory - confirm before removing.
        scaler.scale(loss).backward(retain_graph=not at_boundary)

        if at_boundary:
            _apply_accumulated_gradients()
            pending = 0

        total_loss += loss_dict['total'].item()
        num_used += 1

    # Apply a trailing partial group only if it actually produced gradients
    if pending:
        _apply_accumulated_gradients()

    return total_loss / num_used if num_used else 0.0

def compute_add_s_accuracy(pred_pts, gt_pts, thresholds=(0.1, 0.2, 0.5)):
    """Fraction of predicted points lying within each threshold of the GT point set.

    Each predicted point is matched to its nearest ground-truth point (symmetric,
    ADD-S style matching). For each threshold, the indicator "nearest distance <
    threshold" is averaged over all points and all batch entries.

    NOTE(review): this is a per-point fraction. The common "ADD-S accuracy" is the
    fraction of *poses* whose mean nearest-point distance is below the threshold,
    so values are not directly comparable with published numbers.

    Args:
        pred_pts: (..., N, D) points transformed by the predicted pose.
        gt_pts: (..., M, D) points transformed by the ground-truth pose.
        thresholds: distance thresholds, in the same units as the points.

    Returns:
        dict mapping ``f'acc@{thr:.1f}'`` to a 0-d tensor in [0, 1]. Thresholds that
        format to the same key overwrite each other.
    """
    nearest = torch.cdist(pred_pts, gt_pts).min(dim=-1).values
    return {f'acc@{thr:.1f}': (nearest < thr).float().mean() for thr in thresholds}

def validate(model, dataloader, criterion, device, epoch):
    """Evaluate ``model`` on ``dataloader`` without changing parameters or loss state.

    Returns the per-batch average (every batch weighted equally) of the criterion's
    'total', 'position', 'rotation' and 'add_s' terms, plus point-level ADD-S
    accuracies 'acc@0.1', 'acc@0.2', 'acc@0.5'. An empty dataloader yields zeros.

    The criterion's state_dict (e.g. the EMA buffers its forward updates on every
    call) is snapshotted before and restored after the pass, so validation batches
    do not alter the adaptive loss weights used for training.
    """
    model.eval()
    loss_sums = {'total': 0.0, 'position': 0.0, 'rotation': 0.0, 'add_s': 0.0}
    acc_sums = {'acc@0.1': 0.0, 'acc@0.2': 0.0, 'acc@0.5': 0.0}
    num_batches = len(dataloader)
    if num_batches == 0:
        return {**loss_sums, **acc_sums}

    saved_state = None
    if isinstance(criterion, torch.nn.Module):
        saved_state = {k: v.detach().clone() for k, v in criterion.state_dict().items()}

    try:
        with torch.no_grad():
            scene_pts = criterion.scene_points.to(device)  # loop-invariant
            for batch in dataloader:
                images = batch['image'].to(device)
                target_pos = batch['translation'].to(device)
                target_rot = batch['rotation_matrix'].to(device)

                outputs = model(images)
                loss = criterion(outputs, {'position': target_pos, 'rotation_matrix': target_rot}, epoch)

                # Accumulate losses
                for k in loss_sums:
                    loss_sums[k] += loss[k].item()

                # ADD-S accuracy on the scene points transformed by both poses
                batch_size = images.size(0)
                pts = scene_pts.unsqueeze(0).expand(batch_size, -1, -1)

                # Use orthogonalized rotation for accuracy computation
                R_pred = criterion.sixd_to_rotation_matrix(outputs['rotation_6d'])
                pred_pts = torch.bmm(R_pred, pts.transpose(1, 2)).transpose(1, 2) + outputs['position'].unsqueeze(1)
                gt_pts = torch.bmm(target_rot, pts.transpose(1, 2)).transpose(1, 2) + target_pos.unsqueeze(1)

                batch_acc = compute_add_s_accuracy(pred_pts, gt_pts)
                for k in acc_sums:
                    acc_sums[k] += batch_acc[k].item()
    finally:
        # Undo any state changes the criterion made while scoring validation data
        if saved_state is not None:
            criterion.load_state_dict(saved_state)

    # Average across batches
    for k in loss_sums:
        loss_sums[k] /= num_batches
    for k in acc_sums:
        acc_sums[k] /= num_batches

    return {**loss_sums, **acc_sums}

def main():
    """Build model, optimizer, loss and data, run sanity checks, then train.

    Relies on names defined in other notebook cells: ``pcd_tensor`` (scene point
    cloud), ``optim``, ``DataLoader``, ``NeuromorphicLieGNN``, ``PoseLoss``,
    ``SevenScenesPoseDataset``, ``train_one_epoch`` and ``validate``. A checkpoint
    is written whenever the validation total loss improves.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = NeuromorphicLieGNN().to(device)

    def create_param_groups(model):
        """Split parameters into disjoint rotation / translation / other groups by name.

        Each parameter goes to exactly one group (first matching rule wins), because
        AdamW raises if a parameter appears in more than one group.
        """
        rot_params, trans_params, other_params = [], [], []
        for name, param in model.named_parameters():
            if 'rot' in name:
                rot_params.append(param)
            elif 'position' in name or 'trans' in name:
                trans_params.append(param)
            else:
                other_params.append(param)

        # Empty groups are kept so the optimizer state layout stays stable
        return [
            {'params': rot_params, 'lr': 1e-4},
            {'params': trans_params, 'lr': 3e-4},
            {'params': other_params, 'lr': 2e-4}
        ]

    optimizer = optim.AdamW(create_param_groups(model), weight_decay=5e-6)

    # NOTE(review): /1000 presumably converts millimetres to metres - confirm pcd units.
    scene_points = pcd_tensor.point["positions"].numpy() / 1000.0
    criterion = PoseLoss(scene_points=scene_points).to(device)

    train_dataset = SevenScenesPoseDataset(split='train')
    val_dataset = SevenScenesPoseDataset(split='val')
    test_dataset = SevenScenesPoseDataset(split='test')
    if len(train_dataset) == 0:
        raise RuntimeError("No training samples found - check the dataset roots and TrainSplit.txt")

    # Smoke-test the forward pass; no_grad because the output is discarded.
    dummy = torch.randn(2, 3, 128, 128).to(device)
    with torch.no_grad():
        model(dummy)

    sample = train_dataset[0]
    print("Image stats after normalization:")
    print("Min:", sample['image'].min().item())
    print("Max:", sample['image'].max().item())
    print("Mean:", sample['image'].mean().item())

    train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=16)
    test_loader = DataLoader(test_dataset, batch_size=16)

    num_epochs = 200
    best_val_loss = float('inf')

    # Sanity check: ADD-S of the ground truth against itself (expected ~0)
    if len(test_dataset) > 0:
        batch = next(iter(test_loader))
        gt_trans = batch['translation'].to(device)
        gt_rot = batch['rotation_matrix'].to(device)
        with torch.no_grad():
            gt_adds = criterion.compute_add_s(gt_trans, gt_rot, gt_trans, gt_rot)
            print("ADD-S (GT vs GT):", gt_adds.mean().item())

    for epoch in range(num_epochs):
        train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device, epoch)
        val_loss = validate(model, val_loader, criterion, device, epoch)

        # TODO: add an LR scheduler (and save its state below) if warm-up is needed.

        # Save best model
        if val_loss['total'] < best_val_loss:
            best_val_loss = val_loss['total']
            torch.save({
                'model_state': model.state_dict(),
                'optimizer_state': optimizer.state_dict()
            }, f'best_s_model_epoch_{epoch}.pth')

        print(f"Epoch {epoch+1}/{num_epochs} | "
              f"Train Loss: {train_loss:.4f} | "
              f"Val Loss: {val_loss['total']:.4f} | "
              f"Pos: {val_loss['position']:.4f} | "
              f"Rot: {val_loss['rotation']:.4f} | "
              f"ADD-S: {val_loss['add_s']:.4f} | "
              f"Acc@0.1: {val_loss['acc@0.1']:.3f} | "
              f"Acc@0.2: {val_loss['acc@0.2']:.3f} | "
              f"Acc@0.5: {val_loss['acc@0.5']:.3f}")

# Entry point: calls main() (full data loading + training) when this code runs as
# the top-level module, which includes executing this notebook cell.
if __name__ == "__main__":
    main()

TESTING CODE

In [None]:
class Config:
    """Global hyper-parameters shared by the notebook cells.

    Read elsewhere in this notebook: ``input_size`` (dataset resize) and ``device``
    (smoke-test input in ``main``).
    NOTE(review): ``num_steps`` and ``beta`` are not read by NeuromorphicLieGNN /
    SNNEncoder in the visible code (they use their own defaults) - confirm intent.
    """

    # --- optimisation ---
    batch_size = 16
    num_epochs = 100
    lr = 3e-4
    grad_clip = 1.0

    # --- input and spiking dynamics ---
    input_size = (128, 128)
    num_steps = 32  # Increased temporal resolution
    beta = 0.97  # Slower membrane decay

    # --- e3nn irreducible representations ---
    irreps_in = o3.Irreps("256x0e")
    irreps_hidden = o3.Irreps("128x0e + 64x1o + 32x2e")
    irreps_out = o3.Irreps("1o + 1o")  # Rotation + Translation

    # --- runtime ---
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

class LearnableTemporalPool(nn.Module):
    """Learnable pooling of a spike train over its time axis.

    Every (batch, channel, pixel) time series is filtered by one shared 1-D conv
    (kernel 3, same padding), multiplied by a learnable weight per time step and
    summed over time.
    """

    def __init__(self, num_steps):
        super().__init__()
        self.conv = nn.Conv1d(
            in_channels=1,  # Process each feature independently
            out_channels=1,
            kernel_size=3,
            padding=1
        )
        self.weights = nn.Parameter(torch.ones(num_steps))

    def forward(self, spikes):
        """
        Args:
            spikes: (B, T, C, H, W) tensor; T must equal ``num_steps`` from __init__.

        Returns:
            (B, C, H, W) tensor of pooled values.
        """
        batch_size, num_steps, channels, height, width = spikes.shape

        # Move time to the LAST axis before flattening so each row is the time series
        # of one (b, c, h, w) feature. Reshaping (B, T, C, H, W) directly would pack
        # T neighbouring values of a single time step into each row instead.
        x = spikes.permute(0, 2, 3, 4, 1).reshape(-1, 1, num_steps)  # [B*C*H*W, 1, T]

        x = self.conv(x)                          # [B*C*H*W, 1, T]
        x = x * self.weights.view(1, 1, -1)       # per-time-step weighting
        x = x.sum(dim=2)                          # [B*C*H*W, 1]

        # Rows are ordered (b, c, h, w), matching the permute above
        return x.reshape(batch_size, channels, height, width)

class SNNEncoder(nn.Module):
    """Three-layer convolutional spiking encoder built from snnTorch ``Leaky`` neurons.

    The static input frame is repeated for ``num_steps`` time steps (plus Gaussian
    noise, std 0.1) and passed through conv -> Leaky three times; the layer-3 spikes
    of every step are returned. Each conv has stride 2, so H and W shrink by 8.
    """

    def __init__(self, input_dim=3, hidden_dim=64, output_dim=32, beta=0.5, num_steps=25, spike_grad=None):
        super().__init__()
        self.conv1 = nn.Conv2d(input_dim, hidden_dim, kernel_size=5, stride=2, padding=2, padding_mode='replicate')
        self.conv2 = nn.Conv2d(hidden_dim, hidden_dim*2, kernel_size=5, stride=2, padding=2, padding_mode='replicate')
        self.conv3 = nn.Conv2d(hidden_dim*2, output_dim, kernel_size=5, stride=2, padding=2, padding_mode='replicate')

        self.temporal_pool = LearnableTemporalPool(num_steps)

        # Use the caller's surrogate gradient if given; otherwise each layer gets its
        # own fast_sigmoid(slope=10).
        def _surrogate():
            return spike_grad if spike_grad is not None else surrogate.fast_sigmoid(slope=10.0)

        # NOTE(review): `beta` is accepted but each layer uses its own hard-coded decay.
        self.lif1 = snn.Leaky(
            beta=0.85,
            threshold=0.5,
            learn_threshold=True,  # Make threshold trainable
            spike_grad=_surrogate()
        )
        self.lif2 = snn.Leaky(
            beta=0.78,
            threshold=0.6,
            learn_threshold=True,
            spike_grad=_surrogate()
        )
        self.lif3 = snn.Leaky(
            beta=0.72,
            threshold=0.7,
            learn_threshold=True,
            spike_grad=_surrogate()
        )
        # Unused placeholders, kept for attribute compatibility
        self.mem1 = self.mem2 = self.mem3 = None

    def forward(self, x, num_steps=25):
        """
        Args:
            x: (B, C, H, W) input batch.
            num_steps: number of simulated time steps T.

        Returns:
            (B, T, output_dim, H/8, W/8) tensor of layer-3 spikes.
        """
        # Repeat the static frame over time: [B, C, H, W] -> [B, T, C, H, W]
        x = x.unsqueeze(1).repeat(1, num_steps, 1, 1, 1)
        x += torch.randn_like(x) * 0.1  # Add temporal noise
        if torch.isnan(x).any():
            print("NaN in SNN input after reshaping!")
            raise RuntimeError

        # Reset membrane potentials (snnTorch manages their shapes on first use).
        # Pre-allocating full [B, T, ...] zero tensors here is unnecessary: they would
        # be overwritten immediately.
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()

        spk3_rec = []
        for t in range(num_steps):
            x_t = x[:, t]  # [B, C, H, W]

            spk1, mem1 = self.lif1(self.conv1(x_t), mem1)
            spk2, mem2 = self.lif2(self.conv2(spk1), mem2)
            spk3, mem3 = self.lif3(self.conv3(spk2), mem3)

            spk3_rec.append(spk3)

        return torch.stack(spk3_rec, dim=1)

class SpikeToGeometric(nn.Module):
    """1x1-conv head mapping rate-coded spike maps to per-pixel scalar and 3-vector features."""

    def __init__(self, input_dim=32, output_scalar=32, output_vector=32):
        super().__init__()
        self.conv_scalar = nn.Conv2d(input_dim, output_scalar, kernel_size=1)
        self.conv_vector = nn.Conv2d(input_dim, output_vector*3, kernel_size=1)
        self.output_vector = output_vector

        # Kaiming init for the 1x1 convs, zero biases
        nn.init.kaiming_normal_(self.conv_scalar.weight,
                                mode='fan_in',
                                nonlinearity='leaky_relu')
        nn.init.kaiming_normal_(self.conv_vector.weight,
                                mode='fan_in',
                                nonlinearity='leaky_relu')
        nn.init.zeros_(self.conv_scalar.bias)
        nn.init.zeros_(self.conv_vector.bias)

    def forward(self, spike_features):
        """Convert spike-rate maps to geometric features, preserving the spatial grid.

        Args:
            spike_features: (B, input_dim, H, W) tensor (16x16 for the default model).

        Returns:
            scalar_features: (B, output_scalar, H, W)
            vector_features: (B, output_vector, 3, H, W)

        Raises:
            RuntimeError: if the input or any produced feature contains NaN.
        """
        batch_size, _, height, width = spike_features.shape
        if torch.isnan(spike_features).any():
            print("NaN in spike_features input to SpikeToGeometric")
            raise RuntimeError

        scalar_features = self.conv_scalar(spike_features)
        if torch.isnan(scalar_features).any():
            print("NaN in scalar_features after conv_scalar")
            raise RuntimeError

        # Split the 3*output_vector channels into output_vector 3-vectors per pixel;
        # spatial size comes from the input instead of a hard-coded 16x16.
        vector_features = self.conv_vector(spike_features).view(
            batch_size, self.output_vector, 3, height, width
        )
        if torch.isnan(vector_features).any():
            print("NaN in vector_features after conv_vector")
            raise RuntimeError

        return scalar_features, vector_features

class GraphConstructor(nn.Module):
    """Turns per-node feature vectors into a batched torch_geometric graph.

    Node features pass through a learned ``nn.Linear``. Edges come from node
    positions when given - a fixed 6-connected grid (horizontal, vertical and
    down-right diagonal neighbours, both directions) when there are
    ``GRID_SIDE**2`` nodes, k-nearest neighbours otherwise - or from feature
    cosine similarity when no positions are given.

    NOTE(review): the unconstrained ``feature_projection`` mixes all input channels,
    which is not an equivariant operation on scalar/vector features - confirm this
    is intended before relying on SE(3) equivariance downstream.
    """

    # Side of the square node grid that receives the fixed grid topology (16x16 = 256)
    GRID_SIDE = 16

    def __init__(self, feature_dim=256, k_neighbors=8):
        super().__init__()
        self.k = k_neighbors
        self.feature_projection = nn.Linear(feature_dim, feature_dim)

    def forward(self, features, positions=None):
        """
        Args:
            features: (B, N, feature_dim) node features.
            positions: optional (B, N, 3) node coordinates.

        Returns:
            ``Batch`` of B graphs with projected ``x``, ``edge_index`` and ``pos``.
        """
        batch_size, _, _ = features.size()
        node_features = self.feature_projection(features)

        graphs = []
        for b in range(batch_size):
            if positions is not None:
                pos = positions[b]
                edge_index = self._knn_graph(pos, self.k)
            else:
                # Fallback to feature similarity graph
                pos = None
                edge_index = self._feature_graph(node_features[b], self.k)
            graphs.append(Data(x=node_features[b], edge_index=edge_index, pos=pos))

        return Batch.from_data_list(graphs)

    def _feature_graph(self, features, k):
        """Directed k-NN graph by cosine similarity: edges node -> its k most similar nodes."""
        device = features.device
        num_nodes = features.size(0)
        effective_k = min(k, num_nodes - 1)

        features_norm = F.normalize(features, p=2, dim=1)
        similarity = torch.mm(features_norm, features_norm.t())

        # Top k+1 includes the node itself (similarity 1); drop that column
        _, indices = torch.topk(similarity, k=effective_k+1, largest=True)
        indices = indices[:, 1:]

        rows = torch.arange(num_nodes, device=device)[:, None].expand(-1, effective_k)
        return torch.stack([rows.reshape(-1), indices.reshape(-1)], dim=0)

    def _knn_graph(self, points, k):
        """Edge index (2, E) from positions: fixed grid for GRID_SIDE**2 nodes, else k-NN."""
        device = points.device
        num_nodes = points.size(0)

        if num_nodes == self.GRID_SIDE ** 2:
            return self._grid_edges(self.GRID_SIDE, device)

        # Geometric k-NN for non-grid structures; the nearest "neighbour" is the node itself
        effective_k = min(k, num_nodes - 1)
        dist = torch.cdist(points, points)
        _, indices = torch.topk(dist, k=effective_k+1, largest=False)
        indices = indices[:, 1:]

        rows = torch.arange(num_nodes, device=device)[:, None].expand(-1, effective_k)
        return torch.stack([rows.reshape(-1), indices.reshape(-1)], dim=0)

    @staticmethod
    def _grid_edges(side, device):
        """Vectorized grid edges in the order: horizontal, vertical, diagonal; each pair as (a->b, b->a)."""
        idx = torch.arange(side * side, device=device).view(side, side)

        def _bidirectional(src, dst):
            # Interleave a->b and b->a per pair, row-major over the grid
            src, dst = src.reshape(-1), dst.reshape(-1)
            forward = torch.stack([src, dst], dim=1)
            backward = torch.stack([dst, src], dim=1)
            return torch.stack([forward, backward], dim=1).reshape(-1, 2)

        edges = torch.cat([
            _bidirectional(idx[:, :-1], idx[:, 1:]),     # horizontal
            _bidirectional(idx[:-1, :], idx[1:, :]),     # vertical
            _bidirectional(idx[:-1, :-1], idx[1:, 1:]),  # diagonal (down-right)
        ], dim=0)
        return edges.t().contiguous()

class SE3EquivariantGNN(nn.Module):
    """Message-passing network built from e3nn tensor products.

    Node features use ``64x0e + 64x1o`` (64 scalars + 64 vectors = 256 dims). Each
    ``o3.TensorProduct`` layer combines sender features with a scalar edge weight
    ``exp(-|p_s - p_r|^2)`` and sums the messages at the receiver. Hidden layers are
    followed by a norm-based activation, and a final ``o3.Linear`` maps to
    ``1x1o + 2x1o`` (9 dims), read as translation (3) + raw 6D rotation (6).
    """

    def __init__(self, hidden_dim=32, num_layers=3):
        super().__init__()
        # NOTE(review): hidden_dim is accepted but not used.
        self.irreps_in = o3.Irreps("64x0e + 64x1o")
        self.irreps_hidden = o3.Irreps("64x0e + 64x1o")
        self.irreps_edge = o3.Irreps("1x0e")  # Scalar edge features
        self.irreps_out = o3.Irreps("1x1o + 2x1o")

        self.layers = nn.ModuleList()

        # Initial layer with edge integration
        self.layers.append(o3.TensorProduct(
            self.irreps_in,
            self.irreps_edge,
            self.irreps_hidden,
            instructions=[(i, 0, i, 'uvu', True, 1.0) for i in range(len(self.irreps_in))],
            shared_weights=True,
            internal_weights=True
        ))

        # Hidden layers: tensor product + norm activation
        for _ in range(num_layers - 2):
            self.layers.append(o3.TensorProduct(
                self.irreps_hidden,
                self.irreps_edge,
                self.irreps_hidden,
                instructions=[(i, 0, i, 'uvu', True, 1.0) for i in range(len(self.irreps_hidden))],
                shared_weights=True,
                path_normalization='element',
                internal_weights=True
            ))
            self.layers.append(self.Float32NormActivation(
                self.irreps_hidden,
                scalar_nonlinearity=torch.tanh,
                epsilon=1e-5
            ))

        # Output layer
        self.layers.append(o3.Linear(self.irreps_hidden, self.irreps_out))

    class Float32NormActivation(e3nn_nn.NormActivation):
        """NormActivation evaluated in float32, with the result cast back to the input dtype."""

        def forward(self, features):
            original_dtype = features.dtype
            features = super().forward(features.float())
            return features.to(original_dtype)

    def forward(self, x, edge_index, pos):
        """
        Args:
            x: (N, 256) node features laid out as ``irreps_in``.
            edge_index: (2, E) tensor, row 0 = senders, row 1 = receivers.
            pos: (N, 3) node positions.

        Returns:
            translation (N, 3) and raw 6D rotation (N, 6) per node.

        Raises:
            RuntimeError: if NaNs appear in inputs, edge features or any layer output.
        """
        if torch.isnan(x).any():
            print("NaN in GNN input")
            raise RuntimeError

        senders, receivers = edge_index
        rel_pos = pos[senders] - pos[receivers]
        distances = torch.norm(rel_pos, dim=1, keepdim=True)
        if torch.isnan(rel_pos).any() or torch.isnan(distances).any():
            print("NaN in edge features")
            raise RuntimeError

        edge_attr = torch.exp(-distances**2)  # [num_edges, 1] Gaussian edge weight

        for layer_idx, layer in enumerate(self.layers):
            if isinstance(layer, o3.TensorProduct):
                # Per-edge messages, summed at the receiving node
                messages = layer(x[senders], edge_attr)
                if torch.isnan(messages).any():
                    print(f"NaN in messages at layer {layer_idx}")
                    raise RuntimeError

                x = scatter(messages, receivers, dim=0, dim_size=x.size(0))
                if torch.isnan(x).any():
                    print(f"NaN after scattering at layer {layer_idx}")
                    raise RuntimeError
            else:
                x = layer(x)
                if torch.isnan(x).any():
                    print(x)
                    print(f"NaN after layer {layer_idx} ({type(layer).__name__})")
                    raise RuntimeError

        translation = x[:, :3]   # first 1o
        rotation_6d = x[:, 3:9]  # remaining 2x1o, raw 6D representation
        return translation, rotation_6d

    def inference_orthogonalize(self, rotation_6d):
        """Gram-Schmidt a (N, 6) raw rotation into (N, 3, 3) rotation matrices."""
        a1, a2 = rotation_6d[:, :3], rotation_6d[:, 3:6]
        return self.sixd_to_rotation_matrix(a1, a2)

    def sixd_to_rotation_matrix(self, a1, a2):
        """Gram-Schmidt 6D -> rotation; columns are (a1, a2, a1 x a2) after normalization."""
        # Epsilon guards avoid division by zero for near-zero vectors
        a1 = a1 / (torch.norm(a1, dim=-1, keepdim=True) + 1e-6)
        a2 = a2 - (a1 * a2).sum(dim=-1, keepdim=True) * a1
        a2 = a2 / (torch.norm(a2, dim=-1, keepdim=True) + 1e-6)
        # Explicit dim: torch.cross without it picks the first size-3 axis (wrong if N == 3)
        b3 = torch.linalg.cross(a1, a2, dim=-1)
        return torch.stack([a1, a2, b3], dim=-1)

class NeuromorphicLieGNN(nn.Module):
    """Pose regressor: image -> spiking encoder -> rate map -> pixel graph -> e3nn GNN -> pose.

    For the default 128x128 input the three stride-2 encoder convs give a 16x16
    feature map, i.e. 256 graph nodes with 64 scalar + 64x3 vector features each.
    Per-node translation and 6D rotation predictions are mean-pooled per image.
    """

    def __init__(self, input_channels=3, hidden_dim=64, num_steps=25, beta=0.5, num_gnn_layers=3):
        super().__init__()

        # SNN encoder; num_steps is forwarded so the temporal pool has one weight per
        # simulated time step (otherwise a non-default num_steps mismatches).
        self.snn_encoder = SNNEncoder(
            input_dim=input_channels,
            hidden_dim=hidden_dim,
            output_dim=32,
            beta=beta,
            num_steps=num_steps
        )

        # Spike to geometric conversion
        self.spike_to_geo = SpikeToGeometric(
            input_dim=32,
            output_scalar=64,
            output_vector=64
        )

        # Graph construction (feature_dim = 64 scalars + 64*3 vector components)
        self.graph_constructor = GraphConstructor(feature_dim=256, k_neighbors=8)

        # SE(3) equivariant GNN
        self.se3_gnn = SE3EquivariantGNN(hidden_dim=32, num_layers=num_gnn_layers)

        self.num_steps = num_steps

    def forward(self, images, positions=None):
        """Predict a pose for each image.

        Args:
            images: (B, C, H, W) image batch.
            positions: accepted for API compatibility but ignored; node positions are
                always the normalized feature-map pixel grid.

        Returns:
            dict with
              'position'        (B, 3) mean-pooled translation,
              'rotation_6d'     (B, 6) mean-pooled raw 6D rotation (used by the loss),
              'rotation_matrix' (B, 3, 3) Gram-Schmidt orthonormalization of the
                                pooled 6D rotation (a proper rotation matrix),
              'pose'            (B, 4, 4) homogeneous transform [R | t].
        """
        batch_size = images.size(0)

        # 1. Spike trains [B, T, C, H', W']
        spike_features = self.snn_encoder(images, self.num_steps)
        if torch.isnan(spike_features).any():
            raise ValueError("NaN in spike features!")

        # 2. Rate-coded features [B, C, H', W']
        rate_features = self.snn_encoder.temporal_pool(spike_features).to(torch.float32)
        if torch.isnan(rate_features).any():
            raise ValueError("NaN in rate features!")

        # 3. Geometric features per pixel
        scalar_features, vector_features = self.spike_to_geo(rate_features)
        if torch.isnan(scalar_features).any() or torch.isnan(vector_features).any():
            print("NaN in geometric features")
            raise RuntimeError

        # 4. One node per feature-map pixel
        height, width = rate_features.shape[-2:]
        num_nodes = height * width
        vector_flat = vector_features.view(
            batch_size, self.spike_to_geo.output_vector * 3, height, width
        )
        combined_features = torch.cat([scalar_features, vector_flat], dim=1)  # [B, 256, H', W']
        node_features = combined_features.permute(0, 2, 3, 1).reshape(batch_size, num_nodes, -1)

        grid = self._get_grid_positions(height, width).to(images.device)
        grid = (grid - grid.min()) / (grid.max() - grid.min()) * 2 - 1  # Normalize to [-1, 1]
        positions = grid.reshape(1, num_nodes, 3).repeat(batch_size, 1, 1)  # [B, N, 3]

        graph = self.graph_constructor(node_features, positions)

        # 5. Per-node predictions, mean-pooled per image
        node_position, node_rotation_6d = self.se3_gnn(graph.x, graph.edge_index, graph.pos)
        position = scatter(node_position, graph.batch, dim=0, reduce='mean')
        rotation_6d = scatter(node_rotation_6d, graph.batch, dim=0, reduce='mean')
        # Orthonormalize AFTER pooling: a mean of rotation matrices is not a rotation.
        rotation_matrix = self.se3_gnn.inference_orthogonalize(rotation_6d)

        return {
            'position': position,
            'rotation_6d': rotation_6d,  # For training loss
            'rotation_matrix': rotation_matrix,  # For inference/validation
            'pose': self._combine_pose(position, rotation_matrix)
        }

    def _combine_pose(self, position, rotation):
        """Stack (B, 3) translation and (B, 3, 3) rotation into (B, 4, 4) transforms."""
        batch_size = position.size(0)
        T = torch.eye(4, device=position.device).unsqueeze(0).repeat(batch_size, 1, 1)
        T[:, :3, :3] = rotation
        T[:, :3, 3] = position
        return T

    def _get_grid_positions(self, height=16, width=16):
        """(height, width, 3) pixel coordinates (x, y, z=0) used as node positions."""
        y, x = torch.meshgrid(
            torch.arange(height),
            torch.arange(width),
            indexing='ij'
        )
        z = torch.zeros_like(x, dtype=torch.float32)  # flat depth dimension
        return torch.stack([x.float(), y.float(), z], dim=-1)


class PoseLoss(nn.Module):
    """Adaptively weighted pose loss: position + geodesic rotation + ADD-S.

    Each term is weighted by the inverse of an exponential moving average (EMA)
    of its own magnitude, clipped to [0.1, 10] and normalized to sum to one.
    The EMAs only move on steps run with autograd enabled (training), so
    evaluation under ``torch.no_grad()`` does not alter the training weights.

    Args:
        scene_points: (N, 3) numpy array or tensor of scene points used for
            ADD-S; randomly subsampled to at most 500 points.
        alpha: EMA decay factor; larger values adapt more slowly.
    """

    def __init__(self, scene_points, alpha=0.9):
        super().__init__()
        self.register_buffer('scene_points', self._preprocess_scene_points(scene_points))

        # EMA state for adaptive weighting. ``inf`` is a sentinel for
        # "no observation yet"; the first update replaces it with the observed
        # loss (a plain EMA update would keep it at inf forever).
        self.alpha = alpha
        self.register_buffer('pos_ema', torch.tensor(float('inf')))
        self.register_buffer('rot_ema', torch.tensor(float('inf')))
        self.register_buffer('add_ema', torch.tensor(float('inf')))

    def _farthest_point_sampling(self, points, n_samples):
        """Greedy farthest point sampling without a full pairwise distance matrix.

        Starts from a random point and repeatedly picks the point farthest from
        the already-selected set. Returns at most ``min(n_samples, N)`` rows of
        ``points`` (always at least one).
        """
        device = points.device
        n_points = points.shape[0]

        start_idx = torch.randint(0, n_points, (1,)).item()
        selected_indices = [start_idx]

        # Distance from every point to its nearest selected point.
        min_dists = torch.full((n_points,), float('inf'), device=device)
        dists = torch.norm(points - points[start_idx], dim=1)
        min_dists = torch.min(min_dists, dists)

        for _ in range(1, min(n_samples, n_points)):
            next_idx = torch.argmax(min_dists).item()
            selected_indices.append(next_idx)

            new_dists = torch.norm(points - points[next_idx], dim=1)
            min_dists = torch.min(min_dists, new_dists)
            min_dists[next_idx] = 0  # Avoid reselecting

        return points[selected_indices]

    def _preprocess_scene_points(self, points):
        """Convert scene points to float32 and randomly keep at most 500 of them.

        Random subsampling is used instead of farthest point sampling for speed.
        """
        if isinstance(points, np.ndarray):
            points = torch.from_numpy(points)
        points = points.float()

        if len(points) > 500:
            indices = torch.randperm(len(points))[:500]
            points = points[indices]

        return points

    def _matrix_to_6d(self, rotation_matrix):
        """Convert (B, 3, 3) rotation matrices to the (B, 6) representation.

        Layout is ``[column0, column1]``, the exact inverse of the layout
        :meth:`sixd_to_rotation_matrix` reads, so a round trip returns the
        original rotation.
        """
        return rotation_matrix[:, :, :2].transpose(1, 2).reshape(rotation_matrix.size(0), 6)

    def compute_add_s(self, pred_trans, pred_rot, gt_trans, gt_rot):
        """Symmetric ADD-S distance between the scene transformed by two poses.

        Args:
            pred_trans, gt_trans: (B, 3) translations.
            pred_rot, gt_rot: (B, 3, 3) rotation matrices.

        Returns:
            Scalar: for each predicted point, the distance to its nearest
            ground-truth point (clamped at 1.0 to limit outliers), averaged
            over points and then over the batch.
        """
        scene_points = self.scene_points.to(pred_trans.device)
        batch_size = pred_trans.size(0)

        # Broadcast view [B, N, 3]; no copy is needed because bmm only reads it.
        scene_points_batch = scene_points.unsqueeze(0).expand(batch_size, -1, -1)
        pts_t = scene_points_batch.transpose(1, 2)

        pred_points = torch.bmm(pred_rot, pts_t).transpose(1, 2) + pred_trans.unsqueeze(1)
        gt_points = torch.bmm(gt_rot, pts_t).transpose(1, 2) + gt_trans.unsqueeze(1)

        dists = torch.cdist(pred_points, gt_points)
        min_dists = dists.min(dim=-1)[0].clamp(max=1.0)  # Prevent outliers
        return min_dists.mean(dim=1).mean()

    def _update_ema(self, name, value):
        """Blend a detached loss ``value`` into the EMA buffer ``name``.

        The first observation (buffer still at the ``inf`` sentinel) initializes
        the EMA directly.
        """
        ema = getattr(self, name)
        value = value.detach().to(device=ema.device, dtype=ema.dtype)
        blended = self.alpha * ema + (1 - self.alpha) * value
        setattr(self, name, torch.where(torch.isinf(ema), value, blended))

    def forward(self, pred, target, epoch=None):
        """Compute the adaptively weighted pose loss.

        Args:
            pred: dict with 'position' (B, 3) and 'rotation_6d' (B, 6).
            target: dict with 'position' (B, 3) and 'rotation_matrix' (B, 3, 3).
            epoch: unused; accepted for call-site compatibility.

        Returns:
            dict with tensors 'total', 'position', 'rotation', 'add_s' and a
            'weights' dict of python floats (for logging).
        """
        pos_loss = F.smooth_l1_loss(pred['position'], target['position'])

        # Orthogonalize the 6D prediction, then compare on SO(3).
        R_pred = self.sixd_to_rotation_matrix(pred['rotation_6d'])
        rot_loss = self.geodesic_loss(R_pred, target['rotation_matrix'])

        add_s = self.compute_add_s(
            pred['position'], R_pred,
            target['position'], target['rotation_matrix']
        )

        # Only training steps (autograd enabled) move the EMAs; evaluation
        # under torch.no_grad() must not change the weights used for training.
        if torch.is_grad_enabled():
            self._update_ema('pos_ema', pos_loss)
            self._update_ema('rot_ema', rot_loss)
            self._update_ema('add_ema', add_s)

        # Inverse-magnitude weights, clipped for stability, then normalized.
        w_pos = torch.clamp(1.0 / (self.pos_ema + 1e-3), 0.1, 10.0)
        w_rot = torch.clamp(1.0 / (self.rot_ema + 1e-3), 0.1, 10.0)
        w_add = torch.clamp(1.0 / (self.add_ema + 1e-3), 0.1, 10.0)

        total = w_pos + w_rot + w_add + 1e-7
        w_pos, w_rot, w_add = w_pos / total, w_rot / total, w_add / total

        total_loss = w_pos * pos_loss + w_rot * rot_loss + w_add * add_s

        return {
            'total': total_loss,
            'position': pos_loss,
            'rotation': rot_loss,
            'add_s': add_s,
            'weights': {
                'pos': w_pos.item(),
                'rot': w_rot.item(),
                'add': w_add.item()
            }
        }

    def geodesic_loss(self, R_pred, R_target):
        """Mean geodesic angle (radians) between batches of rotation matrices.

        The acos argument is clamped slightly inside [-1, 1] to keep the
        gradient finite.
        """
        R_diff = torch.bmm(R_pred, R_target.transpose(1, 2))
        trace = R_diff[:, 0, 0] + R_diff[:, 1, 1] + R_diff[:, 2, 2]
        angle = torch.acos(torch.clamp((trace - 1)/2, -1+1e-6, 1-1e-6))
        return angle.mean()

    def sixd_to_rotation_matrix(self, rotation_6d):
        """Map a (B, 6) vector to the nearest proper rotation matrix via SVD.

        ``rotation_6d[:, :3]`` and ``rotation_6d[:, 3:6]`` form the first two
        columns of an unnormalized matrix, and the third column is their cross
        product. The result is the orthogonal Procrustes solution ``U @ Vh``,
        with a determinant correction so that det = +1.
        """
        a1, a2 = rotation_6d[:, :3], rotation_6d[:, 3:6]

        # Explicit dim: without it torch.cross uses the first axis of size 3,
        # which is the batch axis when B == 3.
        b3 = torch.linalg.cross(a1, a2, dim=-1)
        R = torch.stack([a1, a2, b3], dim=-1)

        # R = U diag(S) Vh, so the closest orthogonal matrix is U @ Vh.
        # U @ Vh^T is orthogonal too, but it is not the nearest one.
        U, _, Vh = torch.linalg.svd(R)
        det = torch.linalg.det(U @ Vh)
        ones = torch.ones_like(det)
        sign = torch.where(det < 0, -ones, ones)
        D = torch.diag_embed(torch.stack([ones, ones, sign], dim=-1))
        return U @ D @ Vh

class SevenScenesPoseDataset(Dataset):
    """7-Scenes RGB frames paired with their camera poses.

    Sequences are listed in ``TrainSplit.txt`` / ``TestSplit.txt`` under each
    root. The 'val' and 'test' splits partition the test sequences: the first
    ``val_test_split_ratio`` fraction goes to 'val' and the rest to 'test'.
    Any other split name selects the train sequences.

    Args:
        roots: dataset root directories to scan.
        transform: optional image transform. If None, a default pipeline is
            used; random augmentation is applied only for training splits.
        split: 'train', 'val' or 'test'.
        val_test_split_ratio: fraction of test sequences assigned to 'val'.
    """

    def __init__(self, roots=('/content/7scenes', '/content/7scenes2', '/content/7scenes3'),
                 transform=None, split='train', val_test_split_ratio=0.7):
        self.roots = list(roots)
        self.split = split
        # Honor a caller-supplied transform. Otherwise avoid random augmentation
        # on val/test, which would make evaluation non-deterministic.
        self.transform = transform if transform is not None else self._default_transform(split)
        self.val_test_split_ratio = val_test_split_ratio

        # Per-sample storage, filled by _load_sequence and stacked below.
        self.poses_t = []
        self.poses_r = []
        self.samples = []

        train_sequences, test_sequences = self._collect_sequences(self.roots)

        if split in ('val', 'test'):
            split_idx = int(len(test_sequences) * self.val_test_split_ratio)
            if split == 'val':
                selected_sequences = test_sequences[:split_idx]
            else:
                selected_sequences = test_sequences[split_idx:]
        else:
            selected_sequences = train_sequences

        for root, seq in selected_sequences:
            self._load_sequence(os.path.join(root, seq))

        if self.poses_t:
            self.poses_t = torch.stack(self.poses_t)
            self.poses_r = torch.stack(self.poses_r)
        else:
            self.poses_t = torch.empty(0)
            self.poses_r = torch.empty(0)

    @staticmethod
    def _default_transform(split):
        """Default image pipeline. Color jitter and rotation are applied only for training splits."""
        steps = [transforms.Resize(Config.input_size)]
        if split not in ('val', 'test'):
            steps += [
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
                transforms.RandomRotation(5),
            ]
        steps += [
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],  # RGB means
                std=[0.229, 0.224, 0.225]     # RGB stds
            ),
        ]
        return transforms.Compose(steps)

    @staticmethod
    def _format_sequence_name(seq_str):
        """Normalize a split-file entry ('sequence1', 'Sequence 2', 'seq-03') to 'seq-XX'."""
        clean_str = seq_str.strip().lower()
        # Strip a leading 'sequence' / 'seq-' / 'seq' prefix; longest first.
        for prefix in ('sequence', 'seq-', 'seq'):
            if clean_str.startswith(prefix):
                clean_str = clean_str[len(prefix):].strip()
                break

        if clean_str.isdigit():
            return f"seq-{int(clean_str):02d}"  # Always 2-digit format
        # Fallback for non-numeric sequences
        return f"seq-{clean_str.zfill(2)}"

    @classmethod
    def _collect_sequences(cls, roots):
        """Read the split files of every root into (root, 'seq-XX') lists for train and test."""
        train_sequences, test_sequences = [], []
        for root in roots:
            for split_name, bucket in (("TrainSplit.txt", train_sequences),
                                       ("TestSplit.txt", test_sequences)):
                split_file = os.path.join(root, split_name)
                if not os.path.exists(split_file):
                    continue
                with open(split_file) as f:
                    for line in f:
                        raw_seq = line.strip()
                        if raw_seq:
                            bucket.append((root, cls._format_sequence_name(raw_seq)))
        return train_sequences, test_sequences

    def _load_sequence(self, seq_path):
        """Register every '*color.png' frame in ``seq_path`` that has a matching '*pose.txt'."""
        if not os.path.isdir(seq_path):
            return

        color_files = sorted(f for f in os.listdir(seq_path) if f.endswith('color.png'))
        for cf in color_files:
            base = cf.replace('color.png', '')
            pose_path = os.path.join(seq_path, base + 'pose.txt')
            if not os.path.exists(pose_path):
                continue

            # The pose file holds a 4x4 matrix; split it into rotation and translation.
            pose_mat = np.loadtxt(pose_path)
            self.poses_t.append(torch.tensor(pose_mat[:3, 3]))
            self.poses_r.append(torch.tensor(pose_mat[:3, :3]))
            self.samples.append(os.path.join(seq_path, cf))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return the transformed image, its float32 pose, and its path."""
        img_path = self.samples[idx]
        # Context manager closes the file. convert('RGB') guarantees the three
        # channels the Normalize step expects.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        return {
            'image': self.transform(image).float(),
            'rotation_matrix': self.poses_r[idx].float(),
            'translation': self.poses_t[idx].float().clone().detach(),
            'path': img_path  # For debugging
        }

def train_one_epoch(model, dataloader, optimizer, criterion, device, epoch, accumulation_steps=4):
    """Run one training epoch with gradient accumulation and gradient clipping.

    The optimizer steps at every ``accumulation_steps``-th batch index and once
    more at the end if gradients are still pending. Batches containing NaNs are
    skipped, but an accumulation boundary is still honored when it lands on a
    skipped batch, so two windows never merge into one update.

    Args:
        model: network exposing ``snn_encoder.lif1..lif3`` (membranes reset per batch).
        dataloader: iterable of dicts with 'image', 'translation', 'rotation_matrix'.
        optimizer: torch optimizer over ``model``'s parameters.
        criterion: callable ``(pred, target, epoch) -> dict`` with a 'total' loss.
        device: device the batch tensors are moved to.
        epoch: forwarded to ``criterion``.
        accumulation_steps: number of batches accumulated per optimizer step.

    Returns:
        Mean un-scaled 'total' loss over the batches actually used, or 0.0 if
        none were used.
    """
    model.train()
    # Built per call, so the loss scale restarts each epoch. With autocast
    # disabled (see below) it only scales/unscales fp32 gradients.
    scaler = torch.cuda.amp.GradScaler()
    total_loss = 0.0
    used_batches = 0
    pending = 0  # backward passes accumulated since the last optimizer step

    def _optimizer_step():
        # Unscale before clipping so max_norm applies to the true gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)

    optimizer.zero_grad()
    for batch_idx, batch in enumerate(dataloader):
        at_boundary = (batch_idx + 1) % accumulation_steps == 0

        model.snn_encoder.lif1.reset_mem()
        model.snn_encoder.lif2.reset_mem()
        model.snn_encoder.lif3.reset_mem()

        images = batch['image'].to(device, dtype=torch.float32)
        target_position = batch['translation'].to(device, dtype=torch.float32)
        target_rotation = batch['rotation_matrix'].to(device, dtype=torch.float32)

        if torch.isnan(images).any() or torch.isnan(target_position).any() or torch.isnan(target_rotation).any():
            if at_boundary and pending:
                _optimizer_step()
                pending = 0
            continue

        # Forward + loss (AMP autocast intentionally left disabled).
        outputs = model(images)
        pred = {
            'position': outputs['position'],
            'rotation_6d': outputs['rotation_6d']
        }
        loss_dict = criterion(pred, {'position': target_position, 'rotation_matrix': target_rotation}, epoch)

        # retain_graph on non-boundary steps is kept from the original.
        # NOTE(review): presumably some graph state is shared across batches;
        # confirm it is needed, since it keeps activations alive.
        scaler.scale(loss_dict['total'] / accumulation_steps).backward(retain_graph=not at_boundary)
        pending += 1
        total_loss += loss_dict['total'].item()
        used_batches += 1

        if at_boundary:
            _optimizer_step()
            pending = 0

    # Flush gradients left over from a final partial window.
    if pending:
        _optimizer_step()

    return total_loss / used_batches if used_batches else 0.0

def compute_add_s_accuracy(pred_pts, gt_pts, thresholds=(0.1, 0.2, 0.5)):
    """Point-wise ADD-S accuracy at several distance thresholds.

    For every predicted point, take the distance to its nearest ground-truth
    point. The accuracy at ``thr`` is the fraction of those distances below
    ``thr``, over all points and batch entries.

    Args:
        pred_pts, gt_pts: point sets accepted by ``torch.cdist``, e.g. (B, N, 3).
        thresholds: distance thresholds (same units as the points).

    Returns:
        dict mapping 'acc@<thr>' to a scalar tensor. Thresholds with one decimal
        keep the historical labels ('acc@0.1', 'acc@1.0'); finer thresholds use
        their shortest exact form (e.g. 'acc@0.25'), so keys never collide.
    """
    min_dists = torch.cdist(pred_pts, gt_pts).min(dim=-1).values

    accuracies = {}
    for thr in thresholds:
        label = f'{thr:.1f}' if round(thr, 1) == thr else f'{thr:g}'
        accuracies[f'acc@{label}'] = (min_dists < thr).float().mean()
    return accuracies

def validate(model, dataloader, criterion, device, epoch):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    SNN membranes are reset before every batch, exactly as in training.
    Losses and accuracies are averaged over samples (each batch mean is
    weighted by its batch size), so a smaller final batch does not skew the result.

    Args:
        model: network exposing ``snn_encoder.lif1..lif3``.
        dataloader: iterable of dicts with 'image', 'translation', 'rotation_matrix'.
        criterion: loss module exposing ``scene_points`` and ``sixd_to_rotation_matrix``.
        device: device the batch tensors are moved to.
        epoch: forwarded to ``criterion``.

    Returns:
        dict with 'total', 'position', 'rotation', 'add_s', 'acc@0.1',
        'acc@0.2', 'acc@0.5'. All values are 0.0 for an empty dataloader.
    """
    model.eval()
    loss_sums = {'total': 0.0, 'position': 0.0, 'rotation': 0.0, 'add_s': 0.0}
    acc_sums = {'acc@0.1': 0.0, 'acc@0.2': 0.0, 'acc@0.5': 0.0}
    n_samples = 0

    with torch.no_grad():
        scene_pts = criterion.scene_points.to(device)  # loop-invariant
        for batch in dataloader:
            model.snn_encoder.lif1.reset_mem()
            model.snn_encoder.lif2.reset_mem()
            model.snn_encoder.lif3.reset_mem()

            images = batch['image'].to(device)
            target_pos = batch['translation'].to(device)
            target_rot = batch['rotation_matrix'].to(device)
            batch_size = images.size(0)

            outputs = model(images)
            loss = criterion(outputs, {'position': target_pos, 'rotation_matrix': target_rot}, epoch)
            for k in loss_sums:
                loss_sums[k] += loss[k].item() * batch_size

            # Transform the scene by predicted and true poses for ADD-S accuracy,
            # using the orthogonalized predicted rotation.
            pts_t = scene_pts.unsqueeze(0).expand(batch_size, -1, -1).transpose(1, 2)
            R_pred = criterion.sixd_to_rotation_matrix(outputs['rotation_6d'])
            pred_pts = torch.bmm(R_pred, pts_t).transpose(1, 2) + outputs['position'].unsqueeze(1)
            gt_pts = torch.bmm(target_rot, pts_t).transpose(1, 2) + target_pos.unsqueeze(1)

            batch_acc = compute_add_s_accuracy(pred_pts, gt_pts)
            for k in acc_sums:
                acc_sums[k] += batch_acc[k].item() * batch_size

            n_samples += batch_size

    if n_samples:
        for sums in (loss_sums, acc_sums):
            for k in sums:
                sums[k] /= n_samples

    return {**loss_sums, **acc_sums}

In [None]:
def test(model, test_loader, criterion, device):
    """Evaluate ``model`` on ``test_loader`` and return the metrics dict.

    Thin wrapper over ``validate`` with ``epoch=None`` (the loss's ``epoch``
    argument defaults to None), so validation and testing share one
    evaluation loop. The result has the keys 'total', 'position', 'rotation',
    'add_s', 'acc@0.1', 'acc@0.2', 'acc@0.5'.
    """
    return validate(model, test_loader, criterion, device, None)


# Evaluate a trained checkpoint on the held-out 'test' split.
CHECKPOINT_PATH = 'best_s_model_epoch_62.pth'  # Replace file name with your trained model checkpoint
TEST_BATCH_SIZE = 16

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = NeuromorphicLieGNN().to(device)

# NOTE(review): the /1000.0 presumably converts the point cloud from
# millimetres to metres to match the pose translations; confirm.
scene_points = pcd_tensor.point["positions"].numpy() / 1000.0
criterion = PoseLoss(scene_points=scene_points).to(device)

# NOTE: torch.load unpickles arbitrary objects; only load checkpoints you trust.
checkpoint = torch.load(CHECKPOINT_PATH, map_location=device)
model.load_state_dict(checkpoint['model_state'])

test_dataset = SevenScenesPoseDataset(split='test')
test_loader = DataLoader(test_dataset, batch_size=TEST_BATCH_SIZE)
test_loss = test(model, test_loader, criterion, device)
print(test_loss)