# Triplane Encoder/Decoder
This is our midterm baseline implementation. Following the architecture defined by Shue et al., we created two toy prototypes to test training of
the triplanar representation.

In [None]:
!pip install --upgrade PyMCubes
!pip install trimesh

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import trimesh
import os
from tqdm import tqdm
import matplotlib.pyplot as plt
import mcubes
from IPython.display import display

# Setup

## 1: DATASET LOADING AND PREPROCESSING

In [None]:
#===========================================
# PART 1: DATASET LOADING AND PREPROCESSING
#===========================================

def _next_data_line(lines, line_idx):
    """Return the first index >= line_idx whose line is neither blank nor a '#' comment."""
    while line_idx < len(lines) and (lines[line_idx].startswith('#') or lines[line_idx].strip() == ''):
        line_idx += 1
    return line_idx


def load_off_file(off_path):
    """Load an OFF file and normalize it to [-1, 1].

    The file is parsed by hand first. If that fails for any reason, trimesh's
    own OFF loader is tried, and if that also fails an icosphere of radius 0.5
    is used so callers always receive a mesh.

    After loading, the mesh is centered on its bounding-box centroid and
    uniformly scaled so its largest bounding-box extent is 2 (the longest axis
    spans [-1, 1]). Empty or near-zero-size meshes are replaced by the same
    fallback sphere.

    Args:
        off_path: path to an OFF file.

    Returns:
        The loaded (or fallback) mesh.
    """
    try:
        with open(off_path, 'r') as f:
            lines = f.readlines()

        # Header is either "OFF" on its own line or "OFF<nv> <nf> <ne>" with the
        # counts fused onto it (a quirk of some ModelNet40 files).
        line_idx = _next_data_line(lines, 0)
        header = lines[line_idx].strip()
        if not header.startswith("OFF"):
            raise ValueError(f"Invalid OFF file format: {off_path}")
        counts = header[3:].split()
        line_idx += 1
        if len(counts) != 3:
            # Counts live on the next non-comment line.
            line_idx = _next_data_line(lines, line_idx)
            counts = lines[line_idx].strip().split()
            line_idx += 1
        num_vertices, num_faces, _ = map(int, counts)

        vertices = []
        for _ in range(num_vertices):
            line_idx = _next_data_line(lines, line_idx)
            # Keep x, y, z only; extra columns (e.g. colors) would make rows ragged.
            vertices.append([float(v) for v in lines[line_idx].split()[:3]])
            line_idx += 1

        faces = []
        for _ in range(num_faces):
            line_idx = _next_data_line(lines, line_idx)
            tokens = lines[line_idx].split()
            # First token is the number of vertex indices; anything after them
            # (e.g. float face colors) is ignored.
            n_face_verts = int(tokens[0])
            faces.append([int(tok) for tok in tokens[1:n_face_verts + 1]])
            line_idx += 1

        mesh = trimesh.Trimesh(vertices=vertices, faces=faces)

    except Exception as e:
        # Try using trimesh's loader instead
        print(f"Manual parsing failed: {e}, trying trimesh loader")
        try:
            mesh = trimesh.load(off_path, file_type='off')
        except Exception as e2:
            print(f"Trimesh loading failed: {e2}, creating fallback sphere")
            mesh = trimesh.creation.icosphere(radius=0.5, subdivisions=2)

    # Normalize mesh
    try:
        # Safety check for empty mesh
        if len(mesh.vertices) == 0:
            print("Empty mesh, creating fallback sphere")
            mesh = trimesh.creation.icosphere(radius=0.5, subdivisions=2)

        # Center on the bounding-box centroid.
        center = mesh.bounding_box.centroid
        mesh.vertices -= center

        # Uniform scale so the longest extent becomes 2, i.e. spans [-1, 1].
        extents = mesh.bounding_box.extents
        max_extent = np.max(extents)

        if max_extent < 1e-6:  # degenerate (near-zero-size) mesh
            print("Mesh too small, creating fallback sphere")
            mesh = trimesh.creation.icosphere(radius=0.5, subdivisions=2)
        else:
            scale = 2.0 / max_extent
            mesh.vertices *= scale

    except Exception as e:
        print(f"Normalization error: {e}, creating fallback sphere")
        mesh = trimesh.creation.icosphere(radius=0.5, subdivisions=2)

    return mesh

def load_modelnet40_models(base_path, categories=None, max_models=3, split='train'):
    """Load up to ``max_models`` meshes per category from a ModelNet40 tree.

    Expects files at ``<base_path>/<category>/<split>/*.off``. Discovered
    categories and each category's files are processed in sorted order, so the
    same subset is chosen on every run (``os.listdir`` order is arbitrary).

    Args:
        base_path: dataset root directory.
        categories: one category name, a list of names (kept in the given
            order), or None to use every subdirectory of ``base_path``.
        max_models: maximum number of files taken from each category.
        split: split subdirectory name, e.g. 'train' or 'test'.

    Returns:
        (meshes, mesh_paths): parallel lists of loaded meshes and their paths.
        Files that raise while loading are reported and skipped.
    """
    meshes = []
    mesh_paths = []

    if categories is None:
        categories = sorted(
            d for d in os.listdir(base_path) if os.path.isdir(os.path.join(base_path, d))
        )
    elif isinstance(categories, str):
        categories = [categories]

    for category in categories:
        category_path = os.path.join(base_path, category, split)
        if not os.path.exists(category_path):
            print(f"Warning: Path {category_path} does not exist")
            continue

        # Sorted so the first ``max_models`` files are reproducible across runs.
        off_files = sorted(f for f in os.listdir(category_path) if f.endswith('.off'))
        off_files = off_files[:max_models]

        for off_file in off_files:
            off_path = os.path.join(category_path, off_file)
            try:
                mesh = load_off_file(off_path)
                meshes.append(mesh)
                mesh_paths.append(off_path)
                print(f"Loaded: {off_path}")
            except Exception as e:
                print(f"Error loading {off_path}: {e}")

    return meshes, mesh_paths

def _batch_occupancy(mesh, batch):
    """Return float32 inside (1.0) / outside (0.0) labels for one [k, 3] batch."""
    try:
        if hasattr(mesh, 'proximity'):
            # trimesh's signed distance is POSITIVE inside (and on) the surface
            # and NEGATIVE outside, so occupied means >= 0.
            return (mesh.proximity.signed_distance(batch) >= 0).astype(np.float32)
        return np.asarray(mesh.contains(batch), dtype=np.float32)
    except Exception:
        # Last resort: treat points within 0.05 of a mesh vertex as occupied.
        # nearest.vertex returns (distances, vertex_ids).
        distances, _ = mesh.nearest.vertex(batch)
        return (np.asarray(distances) < 0.05).astype(np.float32)


def sample_points_from_mesh(mesh, n_points=5000):
    """Sample query points around a mesh and label them inside/outside.

    On success this draws ``n_points // 2`` surface points, the same points
    jittered with N(0, 0.01) noise, and ``n_points // 2`` uniform points in
    [-1, 1]^3, so ``3 * (n_points // 2)`` points are returned. If anything
    fails, ``n_points`` uniform points labelled by a radius-0.5 sphere are
    returned instead.

    Args:
        mesh: mesh to sample.
        n_points: controls the sample count (see above).

    Returns:
        (points, occupancy): [M, 3] array and [M, 1] float32 array with 1.0 for
        points inside the mesh.
    """
    try:
        surface_points, _ = trimesh.sample.sample_surface(mesh, n_points // 2)
        near_surface_points = surface_points + np.random.normal(0, 0.01, size=surface_points.shape)
        volume_points = np.random.uniform(-1, 1, size=(n_points // 2, 3))
        all_points = np.vstack([surface_points, near_surface_points, volume_points])

        occupancy = np.zeros(len(all_points), dtype=np.float32)

        # Process in batches to bound memory use of the proximity queries.
        batch_size = 1000
        for i in tqdm(range(0, len(all_points), batch_size), desc="Computing occupancy"):
            batch = all_points[i:i + batch_size]
            occupancy[i:i + len(batch)] = _batch_occupancy(mesh, batch)

        return all_points, occupancy.reshape(-1, 1)

    except Exception as e:
        print(f"Error sampling points: {e}, creating fallback points")
        # Create fallback points (sphere)
        points = np.random.uniform(-1, 1, size=(n_points, 3))
        occupancy = (np.linalg.norm(points, axis=1) < 0.5).astype(np.float32).reshape(-1, 1)
        return points, occupancy

# Encode & Decode triplane

## 2: TRIPLANE REPRESENTATION

In [None]:
#===========================================
# PART 2: TRIPLANE REPRESENTATION
#===========================================

class FourierFeatureTransform(nn.Module):
    """Random Fourier feature embedding: x -> [sin(2*pi*x@B), cos(2*pi*x@B)].

    ``B`` is a Gaussian projection of shape [input_channels, mapping_size]
    multiplied by ``scale`` and stored as a Parameter with requires_grad=False.
    The output therefore has ``2 * mapping_size`` channels.
    """
    def __init__(self, input_channels, mapping_size, scale=1.0):
        super().__init__()
        self.input_channels = input_channels
        self.mapping_size = mapping_size
        projection = torch.randn((input_channels, mapping_size)) * scale
        self.B = nn.Parameter(projection, requires_grad=False)

    def forward(self, x):
        """Map x of shape [batch, points, input_channels] to [batch, points, 2 * mapping_size]."""
        n_batch, n_pts, n_ch = x.shape
        phase = 2 * np.pi * (x.reshape(n_batch * n_pts, n_ch) @ self.B)
        phase = phase.reshape(n_batch, n_pts, -1)
        return torch.cat((phase.sin(), phase.cos()), dim=-1)

class MiniTriplane(nn.Module):
    """Triplane representation with an MLP decoder.

    Three learnable planes (XY, YZ, XZ), each [1, feature_dim, resolution,
    resolution], are bilinearly sampled at every query point's 2-D
    projections. The three feature vectors are summed and decoded by a
    Fourier-feature MLP into ``output_dim`` values per point.
    """
    def __init__(self, feature_dim=32, resolution=128, output_dim=1):
        super().__init__()

        # The three feature planes, initialised with small (0.001-scaled) noise.
        self.embeddings = nn.ParameterList([
            nn.Parameter(torch.randn(1, feature_dim, resolution, resolution) * 0.001)
            for _ in range(3)
        ])

        # Decoder: 64 random Fourier projections -> 128 features (sin + cos).
        self.net = nn.Sequential(
            FourierFeatureTransform(feature_dim, 64, scale=1),
            nn.Linear(128, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, output_dim),
        )

    def sample_plane(self, coords2d, plane):
        """Bilinearly sample ``plane`` at 2-D coordinates.

        Args:
            coords2d: [batch_size, n_points, 2] coordinates; clamped to [-1, 1]
                and NaNs replaced with 0.
            plane: [1 or batch_size, C, H, W] feature plane.

        Returns:
            [batch_size, n_points, C] sampled features.
        """
        coords2d = torch.clamp(coords2d, min=-1.0, max=1.0)

        if torch.isnan(coords2d).any():
            coords2d = torch.nan_to_num(coords2d, nan=0.0)

        # grid_sample expects a [B, H_out, W_out, 2] grid; use H_out=1, W_out=n_points.
        grid = coords2d.reshape(coords2d.shape[0], 1, -1, 2)

        # Broadcast a single shared plane across the batch.
        if plane.shape[0] != grid.shape[0]:
            plane = plane.repeat(grid.shape[0], 1, 1, 1)

        sampled = F.grid_sample(
            plane,
            grid,
            mode='bilinear',
            padding_mode='zeros',
            align_corners=True
        )

        # [B, C, 1, n_points] -> [B, n_points, C]
        batch_size, channels, _, n_points = sampled.shape
        sampled = sampled.reshape(batch_size, channels, n_points).permute(0, 2, 1)

        return sampled

    def forward(self, coords):
        """Decode values at query points.

        Args:
            coords: [batch_size, n_points, 3] coordinates, clamped to [-1, 1].

        Returns:
            [batch_size, n_points, output_dim] decoder outputs. If sampling or
            decoding raises, the error is printed and zeros of that shape are
            returned (best-effort fallback; the zeros carry no gradient).
        """
        batch_size, n_points, _ = coords.shape

        coords = torch.clamp(coords, min=-1.0, max=1.0)

        try:
            xy_features = self.sample_plane(coords[..., 0:2], self.embeddings[0])  # XY plane
            yz_features = self.sample_plane(coords[..., 1:3], self.embeddings[1])  # YZ plane
            xz_features = self.sample_plane(coords[..., ::2], self.embeddings[2])  # XZ plane

            # Aggregate the three planes by summation.
            features = xy_features + yz_features + xz_features

            return self.net(features)

        except Exception as e:
            print(f"Error in forward pass: {e}")
            # Match the decoder's real output width instead of assuming 1.
            return torch.zeros(batch_size, n_points, self.net[-1].out_features, device=coords.device)

## 3: TRAINING THE TRIPLANE ENCODER

In [None]:
#===========================================
# PART 3: TRAINING THE TRIPLANE ENCODER
#===========================================

def train_encoder(mesh, output_path=None, epochs=200, feature_dim=32, resolution=128, device="cuda"):
    """Fit a MiniTriplane to one mesh's occupancy field.

    Points and inside/outside labels are sampled from ``mesh`` once. The
    triplane planes and MLP decoder are then optimised jointly with
    BCE-with-logits reconstruction loss plus total-variation (weight 0.1) and
    L2 (weight 0.01) penalties on the planes.

    Args:
        mesh: mesh passed to ``sample_points_from_mesh``.
        output_path: if given, the trained planes are saved there with
            ``np.save`` as a [3, 1, feature_dim, resolution, resolution] array,
            and a feature visualisation PNG is written next to it.
        epochs: number of passes over the sampled points.
        feature_dim: channels per plane.
        resolution: plane height and width.
        device: requested device; CPU is used when CUDA is unavailable.

    Returns:
        The trained MiniTriplane.
    """
    device = torch.device(device if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    print("Sampling points from mesh...")
    points, occupancy = sample_points_from_mesh(mesh)
    print(f"Sampled {len(points)} points")

    encoder = MiniTriplane(feature_dim=feature_dim, resolution=resolution).to(device)
    print("Created encoder")

    optimizer = optim.Adam(encoder.parameters(), lr=1e-4)
    criterion = nn.BCEWithLogitsLoss()

    n_points = len(points)
    batch_size = max(1, min(10000, n_points))
    # Ceil division: the trailing partial batch is trained on, not dropped.
    n_batches = (n_points + batch_size - 1) // batch_size

    points_tensor = torch.FloatTensor(points).to(device)
    occupancy_tensor = torch.FloatTensor(occupancy).to(device)

    lambda_tv = 0.1   # Total variation
    lambda_l2 = 0.01  # L2 regularization

    print("Starting training...")
    losses = []

    for epoch in range(epochs):
        epoch_loss = 0.0
        completed = 0  # batches that finished without an exception
        perm = torch.randperm(n_points)

        for b in range(n_batches):
            try:
                idx = perm[b * batch_size:(b + 1) * batch_size]

                batch_points = points_tensor[idx].unsqueeze(0)  # [1, batch, 3]
                batch_occupancy = occupancy_tensor[idx]          # [batch, 1]

                pred_occupancy = encoder(batch_points).squeeze(0)

                if torch.isnan(pred_occupancy).any():
                    pred_occupancy = torch.nan_to_num(pred_occupancy, nan=0.0)

                recon_loss = criterion(pred_occupancy, batch_occupancy)

                # Total variation along both spatial axes of every plane.
                tv_loss = 0
                for plane in encoder.embeddings:
                    tv_loss += torch.mean(torch.abs(plane[:, :, 1:, :] - plane[:, :, :-1, :]))
                    tv_loss += torch.mean(torch.abs(plane[:, :, :, 1:] - plane[:, :, :, :-1]))

                l2_loss = sum(torch.sum(plane ** 2) for plane in encoder.embeddings)

                loss = recon_loss + lambda_tv * tv_loss + lambda_l2 * l2_loss

                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(encoder.parameters(), max_norm=1.0)
                optimizer.step()

                epoch_loss += loss.item()
                completed += 1

            except Exception as e:
                print(f"Error in batch {b}: {e}")
                continue

        # Average over batches that actually ran; NaN if none did.
        avg_loss = epoch_loss / completed if completed else float('nan')
        losses.append(avg_loss)

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.6f}")

    if output_path:
        out_dir = os.path.dirname(output_path)
        if out_dir:  # a bare filename has no directory to create
            os.makedirs(out_dir, exist_ok=True)
        features = np.stack([embed.data.cpu().numpy() for embed in encoder.embeddings])
        np.save(output_path, features)
        print(f"Saved triplane features to: {output_path}")

        visualize_features(features, f"{os.path.splitext(output_path)[0]}_features.png")

    print("Training complete!")
    return encoder

def visualize_features(features, save_path=None):
    """Plot the channel-mean of each of the three triplane planes side by side.

    Args:
        features: array where ``features[i, 0]`` is a [C, H, W] plane for
            i in 0..2 (the [3, 1, C, H, W] layout saved by ``train_encoder``).
        save_path: if given, the figure is saved there, creating the parent
            directory when the path has one; otherwise it is shown.
    """
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    plane_names = ['XY Plane', 'YZ Plane', 'XZ Plane']

    for i in range(3):
        # Average across feature channels -> [H, W] image.
        avg_feature = features[i, 0].mean(axis=0)

        im = axes[i].imshow(avg_feature, cmap='viridis')
        axes[i].set_title(f"{plane_names[i]}")
        plt.colorbar(im, ax=axes[i])

    plt.tight_layout()

    if save_path:
        out_dir = os.path.dirname(save_path)
        if out_dir:  # os.makedirs('') raises for bare filenames
            os.makedirs(out_dir, exist_ok=True)
        plt.savefig(save_path)
        print(f"Saved feature visualization to {save_path}")
    else:
        plt.show()

## 4: MESH GENERATION

In [None]:
#===========================================
# PART 4: MESH GENERATION
#===========================================

def _fallback_sphere_mesh():
    """Return (vertices, triangles) of a radius-0.5 UV sphere on a 30 x 20 grid."""
    phi = np.linspace(0, 2 * np.pi, 30)
    theta = np.linspace(0, np.pi, 20)

    x = 0.5 * np.outer(np.cos(phi), np.sin(theta)).flatten()
    y = 0.5 * np.outer(np.sin(phi), np.sin(theta)).flatten()
    z = 0.5 * np.outer(np.ones_like(phi), np.cos(theta)).flatten()
    vertices = np.vstack([x, y, z]).T

    # Vertex (i, j) sits at index i * 20 + j. phi's first and last samples
    # coincide, so no wrap-around quads are needed.
    triangles = []
    for i in range(29):
        for j in range(19):
            p1 = i * 20 + j
            p2 = (i + 1) * 20 + j
            p3 = i * 20 + j + 1
            p4 = (i + 1) * 20 + j + 1
            triangles.append([p1, p2, p3])
            triangles.append([p2, p4, p3])

    return vertices, np.array(triangles)


def create_mesh(model, res=128, threshold=0.0, device="cuda"):
    """Extract a surface mesh from a trained occupancy model with marching cubes.

    The model is evaluated on a res^3 grid spanning [-1, 1]^3 and its first
    output channel is passed to ``mcubes.marching_cubes``. For a model trained
    with BCEWithLogitsLoss (as ``train_encoder`` does), outputs are logits, so
    threshold=0.0 is the 0.5-probability surface. If no surface is found, the
    thresholds -0.2, 0.2, -0.5 and 0.5 are tried in turn. On any error a
    radius-0.5 sphere is returned instead.

    Args:
        model: module mapping [1, N, 3] coordinates to [1, N, out_dim] values.
            It is switched to eval mode (and left there).
        res: grid resolution per axis.
        threshold: marching-cubes iso-level.
        device: used only when the model has no parameters to infer its device
            from; CPU is used if CUDA is unavailable.

    Returns:
        (vertices, triangles): [V, 3] positions in [-1, 1] and [T, 3] indices.
    """
    model.eval()

    # Evaluate on the model's own device so inputs and weights always agree;
    # a "cuda" default with a CPU model used to fall through to the sphere.
    try:
        run_device = next(model.parameters()).device
    except (StopIteration, AttributeError):
        run_device = torch.device(device if torch.cuda.is_available() else "cpu")

    try:
        axis = torch.linspace(-1, 1, res)
        X, Y, Z = torch.meshgrid(axis, axis, axis, indexing='ij')
        coords = torch.stack([X, Y, Z], dim=-1).reshape(-1, 3)

        batch_size = 10000
        predictions = []

        with torch.no_grad():
            for i in range(0, coords.shape[0], batch_size):
                batch_coords = coords[i:i + batch_size].to(run_device).unsqueeze(0)
                # [..., 0].reshape(-1) instead of squeeze(): a 1-point final batch
                # would otherwise become 0-d and break torch.cat.
                batch_output = model(batch_coords)[..., 0].reshape(-1).cpu()
                predictions.append(batch_output)

        all_predictions = torch.cat(predictions).reshape(res, res, res).numpy()

        vertices, triangles = mcubes.marching_cubes(all_predictions, threshold)

        if len(vertices) == 0 or len(triangles) == 0:
            for alt_threshold in [-0.2, 0.2, -0.5, 0.5]:
                vertices, triangles = mcubes.marching_cubes(all_predictions, alt_threshold)
                if len(vertices) > 0 and len(triangles) > 0:
                    break

        # Grid indices [0, res-1] -> world coordinates [-1, 1].
        vertices = vertices / (res - 1) * 2 - 1

        return vertices, triangles

    except Exception as e:
        print(f"Error creating mesh: {e}")
        return _fallback_sphere_mesh()

def save_obj(vertices, triangles, filename):
    """Write a triangle mesh to a Wavefront OBJ file.

    Args:
        vertices: iterable of (x, y, z) positions.
        triangles: iterable of 0-based (a, b, c) vertex indices; written
            1-based, as OBJ requires.
        filename: output path; its parent directory is created if the path
            has one.
    """
    out_dir = os.path.dirname(filename)
    if out_dir:  # os.makedirs('') raises for bare filenames
        os.makedirs(out_dir, exist_ok=True)

    with open(filename, 'w') as f:
        for v in vertices:
            f.write(f'v {v[0]} {v[1]} {v[2]}\n')
        for t in triangles:
            f.write(f'f {t[0]+1} {t[1]+1} {t[2]+1}\n')

    print(f"Saved mesh to {filename}")

# Diffusion

## 5: DIFFUSION MODEL

In [None]:
#===========================================
# PART 5: DIFFUSION MODEL
#===========================================

# Diffusion model components
class ResBlock(nn.Module):
    """Pre-activation residual block conditioned on a time embedding.

    GroupNorm -> SiLU -> 3x3 conv, plus a per-channel bias projected from the
    time embedding, then GroupNorm -> SiLU -> 3x3 conv. The input is added
    back, through a 1x1 conv when the channel count changes.
    """
    def __init__(self, in_channels, out_channels, time_emb_dim):
        super().__init__()
        # First stage.
        self.norm1 = nn.GroupNorm(8, in_channels)
        self.act1 = nn.SiLU()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        # temb [B, time_emb_dim] -> one additive bias per output channel.
        self.time_mlp = nn.Linear(time_emb_dim, out_channels)
        # Second stage.
        self.norm2 = nn.GroupNorm(8, out_channels)
        self.act2 = nn.SiLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        # Residual path: project only when widths differ.
        self.shortcut = (
            nn.Conv2d(in_channels, out_channels, 1)
            if in_channels != out_channels
            else nn.Identity()
        )

    def forward(self, x, temb):
        time_bias = self.time_mlp(temb).unsqueeze(-1).unsqueeze(-1)  # [B, C_out, 1, 1]
        hidden = self.conv1(self.act1(self.norm1(x))) + time_bias
        hidden = self.conv2(self.act2(self.norm2(hidden)))
        return hidden + self.shortcut(x)

class SelfAttention(nn.Module):
    """Single-head self-attention over spatial positions, with a residual add."""
    def __init__(self, channels):
        super().__init__()
        self.norm = nn.GroupNorm(8, channels)
        self.qkv = nn.Conv2d(channels, channels * 3, 1)
        self.proj = nn.Conv2d(channels, channels, 1)

    def forward(self, x):
        n, c, height, width = x.shape

        # Each of query/key/value flattened to [B, C, H*W].
        query, key, value = (part.flatten(2) for part in self.qkv(self.norm(x)).chunk(3, dim=1))

        # [B, HW, C] @ [B, C, HW] -> [B, HW, HW] scaled dot-product weights.
        weights = torch.bmm(query.transpose(1, 2), key) * (1.0 / (c ** 0.5))
        weights = torch.softmax(weights, dim=-1)

        attended = torch.bmm(weights, value.transpose(1, 2))  # [B, HW, C]
        attended = attended.transpose(1, 2).reshape(n, c, height, width)

        return self.proj(attended) + x

class Downsample(nn.Module):
    """Halve spatial size (rounding up) with a stride-2 3x3 conv; channels unchanged."""
    def __init__(self, channels):
        super(Downsample, self).__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        downsampled = self.conv(x)
        return downsampled

class Upsample(nn.Module):
    """Double spatial size: nearest-neighbour 2x interpolation followed by a 3x3 conv."""
    def __init__(self, channels):
        super(Upsample, self).__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(F.interpolate(x, scale_factor=2, mode="nearest"))

def timestep_embedding(timesteps, dim, max_period=10000):
    """Sinusoidal embedding of diffusion timesteps.

    Args:
        timesteps: 1-D tensor of B timestep values (any numeric dtype).
        dim: embedding width. For odd widths the first column is repeated at
            the end to pad.
        max_period: sets the lowest frequency, 1 / max_period.

    Returns:
        float32 tensor [B, dim] on ``timesteps``' device: cosines in the first
        ``dim // 2`` columns, sines in the next ``dim // 2``.
    """
    half = dim // 2
    device = timesteps.device
    # Build the frequency table on the timesteps' device: a CPU table times
    # CUDA timesteps raises a device-mismatch error.
    log_max_period = torch.log(torch.tensor(max_period, dtype=torch.float32, device=device))
    freqs = torch.exp(-log_max_period * torch.arange(start=0, end=half, dtype=torch.float32, device=device) / half)
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

    if dim % 2:
        embedding = torch.cat([embedding, embedding[:, :1]], dim=-1)

    return embedding

class TriplaneUNet(nn.Module):
    """Timestep-conditioned UNet whose output has the same shape as its input.

    ``forward(x, timesteps)`` takes x of shape [B, in_channels, H, W] and a
    [B] tensor of timesteps; ``conv_out`` maps back to ``in_channels``.
    The encoder has ``len(channel_mults)`` levels of two ResBlocks each,
    with a Downsample between consecutive levels. A ResBlock /
    SelfAttention / ResBlock stack sits in the middle. The decoder mirrors the
    levels with three ResBlocks each, and every decoder ResBlock concatenates
    one stored skip tensor. An Upsample sits between decoder levels.

    NOTE(review): presumably H and W must be divisible by
    2 ** (len(channel_mults) - 1) so upsampled maps line up with their skips
    in ``torch.cat`` — TODO confirm for non-power-of-two resolutions.
    """
    def __init__(self, in_channels, base_channels=64, channel_mults=(1, 2, 4, 8), time_emb_dim=256):
        """Build the network.

        Args:
            in_channels: channels of the input and output tensors.
            base_channels: width after ``input_conv``. Also the width of the
                sinusoidal timestep features fed to ``time_embed`` (forward
                reads it back from ``time_embed[0].in_features``).
            channel_mults: per-level width multipliers of ``base_channels``.
            time_emb_dim: width of the projected time embedding passed to
                every ResBlock.

        NOTE(review): the GroupNorm(8, ...) layers require every level width
        to be divisible by 8.
        """
        super().__init__()
        
        # Time embedding: sinusoidal features (width base_channels) -> 2-layer MLP
        self.time_embed = nn.Sequential(
            nn.Linear(base_channels, time_emb_dim),
            nn.SiLU(),
            nn.Linear(time_emb_dim, time_emb_dim)
        )
        
        # Input processing
        self.input_conv = nn.Conv2d(in_channels, base_channels, 3, padding=1)
        
        # Downsampling path.
        # down_channels mirrors, in push order, the channel count of every
        # tensor forward() pushes onto its skip stack: the input_conv output,
        # then every down_block output.
        self.down_blocks = nn.ModuleList()
        current_channels = base_channels
        down_channels = [current_channels]
        
        for level, mult in enumerate(channel_mults):
            out_channels = base_channels * mult
            
            # Two ResBlocks per level
            for _ in range(2):
                self.down_blocks.append(ResBlock(current_channels, out_channels, time_emb_dim))
                current_channels = out_channels
                down_channels.append(current_channels)
            
            # Downsampling except at last level (its output is also a skip)
            if level < len(channel_mults) - 1:
                self.down_blocks.append(Downsample(current_channels))
                down_channels.append(current_channels)
        
        # Middle blocks
        self.mid_block1 = ResBlock(current_channels, current_channels, time_emb_dim)
        self.mid_attn = SelfAttention(current_channels)
        self.mid_block2 = ResBlock(current_channels, current_channels, time_emb_dim)
        
        # Upsampling path.
        # Skips total 1 + 2*L + (L - 1) = 3*L for L levels, so three ResBlocks
        # per level consume exactly all of them. Popping down_channels here
        # follows the same LIFO order as skips.pop() in forward(), so each
        # ResBlock's input width matches the tensor it will receive.
        self.up_blocks = nn.ModuleList()
        
        for level, mult in reversed(list(enumerate(channel_mults))):
            out_channels = base_channels * mult
            
            # Three ResBlocks per level with skip connections
            for _ in range(3):
                skip_channels = down_channels.pop()
                self.up_blocks.append(ResBlock(current_channels + skip_channels, out_channels, time_emb_dim))
                current_channels = out_channels
            
            # Upsampling except at first level
            if level > 0:
                self.up_blocks.append(Upsample(current_channels))
        
        # Output processing: back to in_channels
        self.norm_out = nn.GroupNorm(8, current_channels)
        self.act_out = nn.SiLU()
        self.conv_out = nn.Conv2d(current_channels, in_channels, 3, padding=1)
    
    def forward(self, x, timesteps):
        """Run the UNet.

        Args:
            x: [B, in_channels, H, W] input.
            timesteps: [B] tensor of timesteps.

        Returns:
            Tensor with the same number of channels as ``x``.
        """
        # Time embedding (sinusoidal width taken from the first Linear layer)
        temb = self.time_embed(timestep_embedding(timesteps, self.time_embed[0].in_features))
        
        # Initial convolution
        h = self.input_conv(x)
        
        # Store skip connections (stack; popped in reverse by the decoder)
        skips = [h]
        
        # Downsampling: every ResBlock and Downsample output is pushed
        for module in self.down_blocks:
            if isinstance(module, ResBlock):
                h = module(h, temb)
            else:  # Downsample
                h = module(h)
            skips.append(h)
        
        # Middle blocks
        h = self.mid_block1(h, temb)
        h = self.mid_attn(h)
        h = self.mid_block2(h, temb)
        
        # Upsampling with skip connections (channel-concatenate, then ResBlock)
        for module in self.up_blocks:
            if isinstance(module, ResBlock):
                skip = skips.pop()
                h = torch.cat([h, skip], dim=1)
                h = module(h, temb)
            else:  # Upsample
                h = module(h)
        
        # Output
        h = self.norm_out(h)
        h = self.act_out(h)
        h = self.conv_out(h)
        
        return h

In [None]:
class GaussianDiffusion:
    """Diffusion model for triplane generation"""
    def __init__(self, model, timesteps=1000, beta_start=0.0001, beta_end=0.02, device="cuda"):
        self.model = model
        self.timesteps = timesteps
        self.device = device
        
        # Define noise schedule
        self.betas = torch.linspace(beta_start, beta_end, timesteps, device=device)
        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        
        # Pre-compute coefficients for diffusion and denoising
        self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
        self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)
        
        # Posterior variance calculation
        self.posterior_variance = self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        self.posterior_log_variance_clipped = torch.log(
            torch.cat([self.posterior_variance[1:2], self.posterior_variance[1:]])
        )
        self.posterior_mean_coef1 = (
            self.betas * torch.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
        )
        self.posterior_mean_coef2 = (
            (1.0 - self.alphas_cumprod_prev) * torch.sqrt(self.alphas) / (1.0 - self.alphas_cumprod)
        )
    
    def q_sample(self, x_start, t, noise=None):
        """Forward diffusion process: q(x_t | x_0)"""
        if noise is None:
            noise = torch.randn_like(x_start)
        
        # Extract coefficients based on timestep
        sqrt_alphas_cumprod_t = extract(self.sqrt_alphas_cumprod, t, x_start.shape)
        sqrt_one_minus_alphas_cumprod_t = extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape)
        
        # Formula: x_t = sqrt(alpha_cumprod_t) * x_0 + sqrt(1-alpha_cumprod_t) * noise
        return sqrt_alphas_cumprod_t * x_start + sqrt_one_minus_alphas_cumprod_t * noise
    
    def p_losses(self, x_start, t, noise=None):
        """Calculate training loss"""
        if noise is None:
            noise = torch.randn_like(x_start)
        
        # Add noise to input
        x_noisy = self.q_sample(x_start, t, noise=noise)
        
        # Predict noise using the model
        predicted_noise = self.model(x_noisy, t)
        
        # Loss is MSE between actual and predicted noise
        loss = F.mse_loss(predicted_noise, noise)
        
        return loss
    
    def p_mean_variance(self, x, t):
        """Compute mean and variance for the reverse process"""
        # Predict noise
        pred_noise = self.model(x, t)
        
        # Calculate mean
        sqrt_recip_alphas_t = extract(self.sqrt_recip_alphas, t, x.shape)
        sqrt_one_minus_alphas_cumprod_t = extract(self.sqrt_one_minus_alphas_cumprod, t, x.shape)
        
        model_mean = sqrt_recip_alphas_t * (x - sqrt_one_minus_alphas_cumprod_t * pred_noise)
        
        # Calculate variance
        posterior_variance_t = extract(self.posterior_variance, t, x.shape)
        posterior_log_variance_t = extract(self.posterior_log_variance_clipped, t, x.shape)
        
        return model_mean, posterior_variance_t, posterior_log_variance_t
    
    @torch.no_grad()
    def p_sample(self, x, t):
        """Sample from p(x_{t-1} | x_t)"""
        model_mean, _, model_log_variance = self.p_mean_variance(x, t)
        
        # No noise at timestep 0
        if t[0] == 0:
            return model_mean
        
        # Add noise scaled by the variance
        noise = torch.randn_like(x)
        return model_mean + torch.exp(0.5 * model_log_variance) * noise
    
    @torch.no_grad()
    def p_sample_loop(self, shape):
        """Generate samples by iterative denoising"""
        device = self.device
        b = shape[0]
        
        # Start from pure noise
        img = torch.randn(shape, device=device)
        
        # Iteratively denoise
        for i in tqdm(reversed(range(0, self.timesteps)), desc="Sampling", total=self.timesteps):
            img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long))
        
        return img
    
    @torch.no_grad()
    def sample(self, batch_size=1, resolution=128, feature_channels=32):
        """Generate triplane samples"""
        # Define shape for stacked features
        channels = feature_channels * 3  # Three planes stacked in channel dimension
        shape = (batch_size, channels, resolution, resolution)
        
        # Generate samples
        samples = self.p_sample_loop(shape)
        
        # Reshape to triplane format [B, 3, 1, C, H, W]
        result = []
        for i in range(batch_size):
            # Split channels into three parts (one for each plane)
            planes = torch.split(samples[i], feature_channels, dim=0)
            # Add dimensions to match expected format
            planes = [p.unsqueeze(0) for p in planes]  # [1, C, H, W]
            # Stack planes
            stacked = torch.stack(planes, dim=0)  # [3, 1, C, H, W]
            result.append(stacked)
        
        # Stack batch dimension
        result = torch.stack(result, dim=0)  # [B, 3, 1, C, H, W]
        
        return result
    
    def train_step(self, x_batch, optimizer):
        """Perform one training step"""
        device = self.device
        
        # Select random timesteps
        t = torch.randint(0, self.timesteps, (x_batch.shape[0],), device=device, dtype=torch.long)
        
        # Calculate loss
        loss = self.p_losses(x_batch, t)
        
        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        return loss.item()

def extract(a, t, x_shape):
    """Gather per-example schedule values and broadcast them to ``x_shape``.

    Args:
        a: Tensor of per-timestep coefficients, indexed along its last dim
            (expected to be 1-D so it matches the 1-D index ``t``).
        t: Long tensor of timestep indices, one per batch element.
        x_shape: Target shape; its first entry must equal ``len(t)``.

    Returns:
        ``a[t]`` reshaped to ``(B, 1, ..., 1)`` and expanded (a view, not a
        copy) to ``x_shape``.
    """
    trailing_ones = (1,) * (len(x_shape) - 1)
    picked = torch.gather(a, -1, t)
    return picked.reshape(t.shape[0], *trailing_ones).expand(x_shape)

class TriplaneDataset(torch.utils.data.Dataset):
    """Dataset of triplane features flattened for a 2-D diffusion UNet.

    Each input holds three feature planes. They are concatenated along the
    channel axis, so every item is a single ``[3*C, H, W]`` float32 tensor.

    Accepted input layouts (NumPy arrays, torch tensors, or anything
    ``torch.as_tensor`` understands):

    * ``[3, 1, C, H, W]``: three planes, each with a leading singleton dim;
    * ``[3, C, H, W]``: three planes without the singleton dim.

    Inputs with any other shape are skipped with a printed warning.
    """

    def __init__(self, features_list):
        self.features = []

        for feature in features_list:
            # Coerce every input to float32. Previously only NumPy arrays were
            # cast; torch tensors kept their dtype (e.g. float64), which then
            # mismatches float32 model weights at training time.
            tensor = torch.as_tensor(feature).float()

            if tensor.ndim == 5 and tensor.shape[0] == 3 and tensor.shape[1] == 1:
                planes = tensor[:, 0]  # [3, C, H, W]
            elif tensor.ndim == 4 and tensor.shape[0] == 3:
                planes = tensor  # [3, C, H, W]
            else:
                print(f"Warning: Unsupported feature shape {tensor.shape}")
                continue

            # torch.cat copies, so stored items never alias the caller's arrays.
            self.features.append(torch.cat(tuple(planes), dim=0))  # [3*C, H, W]

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return self.features[idx]

## 6: TRAINING FUNCTIONS

In [None]:
#===========================================
# PART 6: TRAINING FUNCTIONS
#===========================================

def create_triplane_dataset(meshes, mesh_paths, output_dir="triplane_features",
                            epochs=200, feature_dim=32, resolution=128,
                            device=None, mesh_resolution=128):
    """Fit one triplane per mesh and collect the resulting feature arrays.

    For every ``(mesh, path)`` pair an encoder is trained with
    ``train_encoder`` (given ``output_path=<output_dir>/<name>.npy``). That
    file is then loaded with ``np.load`` and appended to the returned list.
    A reconstruction mesh is also extracted with ``create_mesh`` and written
    to ``<output_dir>/reconstructions/<name>.obj`` as a sanity check.

    Processing is best-effort. A failure on one mesh is printed and the
    remaining meshes are still processed.

    Args:
        meshes: Meshes to encode.
        mesh_paths: Source paths (or file names) used to name the outputs.
        output_dir: Directory for the ``.npy`` features and reconstructions.
        epochs: Encoder training epochs per mesh.
        feature_dim: Channels per triplane plane.
        resolution: Spatial resolution of each triplane plane.
        device: Torch device; defaults to ``"cuda"`` when available, else
            ``"cpu"``.
        mesh_resolution: ``res`` passed to ``create_mesh`` for the
            reconstruction.

    Returns:
        List of loaded feature arrays, one per successfully encoded mesh.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    recon_dir = os.path.join(output_dir, "reconstructions")
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(recon_dir, exist_ok=True)

    features_list = []

    for i, (mesh, path) in enumerate(zip(meshes, mesh_paths)):
        # splitext keeps inner dots ("chair.v2.off" -> "chair.v2"). The old
        # split('.')[0] truncated to "chair", so distinct meshes could
        # overwrite each other's feature files.
        name = os.path.splitext(os.path.basename(path))[0]
        output_path = os.path.join(output_dir, f"{name}.npy")

        try:
            print(f"Processing mesh {i+1}/{len(meshes)}: {name}")

            encoder = train_encoder(
                mesh=mesh,
                output_path=output_path,
                epochs=epochs,
                feature_dim=feature_dim,
                resolution=resolution,
                device=device,
            )

            # train_encoder is expected to have written the planes to output_path.
            features = np.load(output_path)
            features_list.append(features)

            try:
                vertices, triangles = create_mesh(encoder, res=mesh_resolution, device=device)
                save_obj(vertices, triangles, os.path.join(recon_dir, f"{name}.obj"))
            except Exception as e:
                # The reconstruction is only a sanity check; keep the features.
                print(f"Error creating reconstruction mesh: {e}")

        except Exception as e:
            print(f"Error processing mesh {name}: {e}")

    print(f"Created {len(features_list)} triplane features")
    return features_list

def train_diffusion_model(dataset, epochs=100, batch_size=4, lr=1e-4, device="cuda",
                          checkpoint_dir="checkpoints", sample_dir="samples",
                          checkpoint_every=10):
    """Train a ``GaussianDiffusion`` model (``TriplaneUNet`` denoiser) on triplanes.

    Every ``checkpoint_every`` epochs, and after the final epoch, the model
    and optimizer state are saved to ``checkpoint_dir``. In the second half
    of training a preview sample is also generated and written to
    ``sample_dir``.

    Args:
        dataset: Dataset whose items are stacked triplanes ``[3*C, H, W]``
            with square planes (``H == W``).
        epochs: Number of training epochs.
        batch_size: Requested batch size (capped at ``len(dataset)``).
        lr: Adam learning rate.
        device: Torch device for the model and batches.
        checkpoint_dir: Directory for ``diffusion_epoch<N>.pt`` checkpoints.
        sample_dir: Directory for preview ``.npy``/``.png``/``.obj`` files.
        checkpoint_every: Checkpoint interval in epochs (must be >= 1).

    Returns:
        ``(model, diffusion)``, or ``(None, None)`` if ``dataset`` is empty.
    """
    if len(dataset) == 0:
        print("Error: Empty dataset")
        return None, None

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=min(batch_size, len(dataset)),
        shuffle=True,
    )

    sample = dataset[0]  # [3*C, H, W]
    in_channels = sample.shape[0]
    # Planes are square, so H *is* the resolution. The previous
    # int(np.sqrt(sample.shape[1])) treated H as H*W (turning 128 into 11).
    resolution = int(sample.shape[1])
    feature_channels = in_channels // 3

    model = TriplaneUNet(in_channels=in_channels).to(device)
    diffusion = GaussianDiffusion(model, device=device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(epochs):
        epoch_loss = 0.0
        batch_count = 0

        for batch_idx, batch in enumerate(tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")):
            try:
                batch = batch.to(device)
                # Replace NaNs so one corrupt feature map cannot poison the weights.
                if torch.isnan(batch).any():
                    batch = torch.nan_to_num(batch, nan=0.0)
                epoch_loss += diffusion.train_step(batch, optimizer)
                batch_count += 1
            except Exception as e:
                print(f"Error in batch {batch_idx}: {e}")
                continue

        avg_loss = epoch_loss / max(1, batch_count)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.6f}")

        if (epoch + 1) % checkpoint_every == 0 or epoch == epochs - 1:
            _save_diffusion_checkpoint(model, optimizer, epoch, avg_loss, checkpoint_dir)

            # Previews only in the latter half of training.
            if epoch >= epochs // 2:
                _save_diffusion_preview(
                    diffusion, epoch, resolution, feature_channels, device, sample_dir
                )

    return model, diffusion


def _save_diffusion_checkpoint(model, optimizer, epoch, avg_loss, checkpoint_dir):
    """Save model/optimizer state for 0-based ``epoch`` to ``checkpoint_dir``."""
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(checkpoint_dir, f"diffusion_epoch{epoch+1}.pt")
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': avg_loss,
    }, checkpoint_path)
    print(f"Saved checkpoint to {checkpoint_path}")


def _save_diffusion_preview(diffusion, epoch, resolution, feature_channels, device, sample_dir):
    """Sample one triplane, save it as .npy/.png, and try to mesh it as .obj.

    Best-effort: failures are printed and swallowed so training continues.
    """
    try:
        test_sample = diffusion.sample(
            batch_size=1,
            resolution=resolution,
            feature_channels=feature_channels,
        )  # [1, 3, 1, C, H, W]

        os.makedirs(sample_dir, exist_ok=True)
        sample_np = test_sample.cpu().numpy()
        np.save(os.path.join(sample_dir, f"sample_epoch{epoch+1}.npy"), sample_np)
        visualize_features(
            sample_np[0],
            save_path=os.path.join(sample_dir, f"sample_epoch{epoch+1}.png"),
        )

        try:
            # NOTE(review): the decoder is constructed fresh here, so any of its
            # weights other than the embeddings are presumably untrained; confirm
            # whether a trained decoder should be loaded for meaningful meshes.
            decoder = MiniTriplane(
                feature_dim=feature_channels,
                resolution=resolution,
            ).to(device)

            for j in range(3):
                # test_sample[0, j] is [1, C, H, W]; the old [0, j, 0] indexing
                # dropped the leading singleton dim of the embedding.
                decoder.embeddings[j].data = test_sample[0, j].to(device)

            vertices, triangles = create_mesh(decoder, res=128, device=device)
            save_obj(
                vertices,
                triangles,
                os.path.join(sample_dir, f"sample_epoch{epoch+1}.obj"),
            )
        except Exception as e:
            print(f"Error creating mesh from sample: {e}")

    except Exception as e:
        print(f"Error generating sample: {e}")

# Experiment

## 7: FULL PIPELINE

In [None]:
#===========================================
# PART 7: FULL PIPELINE
#===========================================

def run_modelnet40_pipeline(modelnet40_path, output_dir="results",
                            categories=None, max_models=3,
                            epochs=50, device="cuda", n_samples=5):
    """Run the full ModelNet40 -> triplane -> diffusion -> mesh pipeline.

    Steps:
        1. Load up to ``max_models`` training meshes from ``categories``.
           If none load, fall back to a box, a sphere and a cylinder.
        2. Fit one triplane per mesh (``create_triplane_dataset``).
        3. Train the diffusion model on the stacked triplanes.
        4. Sample ``n_samples`` triplanes and write ``.npy``, ``.png`` and
           ``.obj`` files to ``<output_dir>/generated_samples``.

    Args:
        modelnet40_path: Root directory of ModelNet40.
        output_dir: Directory that receives all pipeline outputs.
        categories: Categories to load; defaults to ``["airplane"]``.
        max_models: Maximum number of meshes to load.
        epochs: Diffusion training epochs.
        device: Preferred torch device; CPU is used when CUDA is unavailable.
        n_samples: Number of triplanes to generate at the end.

    Returns:
        ``(model, diffusion)``, or ``(None, None)`` if an earlier step failed.
    """
    # A None default avoids one list object being shared across calls.
    if categories is None:
        categories = ["airplane"]
    os.makedirs(output_dir, exist_ok=True)

    print("=== Step 1: Loading models from ModelNet40 ===")
    meshes, mesh_paths = load_modelnet40_models(
        modelnet40_path,
        categories=categories,
        max_models=max_models,
        split='train'
    )

    if len(meshes) == 0:
        print("No models loaded! Adding fallback shapes")
        meshes = [
            trimesh.creation.box(extents=[1.0, 1.0, 1.0]),
            trimesh.creation.icosphere(radius=0.5),
            trimesh.creation.cylinder(radius=0.5, height=1.0),
        ]
        mesh_paths = ["box.obj", "sphere.obj", "cylinder.obj"]

    print(f"Loaded {len(meshes)} models")

    print("\n=== Step 2: Creating triplane features ===")
    triplane_dir = os.path.join(output_dir, "triplane_features")
    features_list = create_triplane_dataset(meshes, mesh_paths, triplane_dir)

    if len(features_list) == 0:
        print("Failed to create triplane features!")
        return None, None

    print("\n=== Step 3: Creating dataset and training diffusion model ===")
    dataset = TriplaneDataset(features_list)
    if len(dataset) == 0:
        # Every feature array had an unsupported shape; dataset[0] would fail.
        print("No triplane features with a supported shape!")
        return None, None

    print(f"Dataset size: {len(dataset)}")
    print(f"Sample shape: {dataset[0].shape}")

    device = torch.device(device if torch.cuda.is_available() else "cpu")
    model, diffusion = train_diffusion_model(
        dataset,
        epochs=epochs,
        batch_size=min(4, len(dataset)),
        device=device
    )

    if model is None or diffusion is None:
        print("Training failed!")
        return None, None

    print("\n=== Step 4: Generating new 3D models ===")
    sample = dataset[0]  # [3*C, H, W]
    # Planes are square, so H is the resolution. The old
    # int(np.sqrt(sample.shape[1])) treated H as H*W (128 -> 11).
    resolution = int(sample.shape[1])
    feature_channels = sample.shape[0] // 3

    samples_dir = os.path.join(output_dir, "generated_samples")
    os.makedirs(samples_dir, exist_ok=True)

    for i in range(n_samples):
        try:
            print(f"Generating sample {i+1}/{n_samples}")

            generated = diffusion.sample(
                batch_size=1,
                resolution=resolution,
                feature_channels=feature_channels
            )  # [1, 3, 1, C, H, W]

            sample_np = generated.cpu().numpy()
            np.save(os.path.join(samples_dir, f"sample_{i}.npy"), sample_np)

            visualize_features(
                sample_np[0],
                save_path=os.path.join(samples_dir, f"sample_{i}_features.png")
            )

            # NOTE(review): a freshly constructed decoder; weights other than
            # the embeddings are presumably untrained. Confirm intent.
            decoder = MiniTriplane(
                feature_dim=feature_channels,
                resolution=resolution
            ).to(device)

            for j in range(3):
                # generated[0, j] is [1, C, H, W]; the old [0, j, 0] indexing
                # dropped the leading singleton dim of the embedding.
                decoder.embeddings[j].data = generated[0, j].to(device)

            vertices, triangles = create_mesh(decoder, res=128, device=device)
            save_obj(
                vertices,
                triangles,
                os.path.join(samples_dir, f"sample_{i}.obj")
            )

        except Exception as e:
            print(f"Error generating sample {i}: {e}")

    print(f"\nPipeline complete! Results saved to {output_dir}")
    return model, diffusion

In [None]:
# Fixed sample method for the GaussianDiffusion class
@torch.no_grad()
def sample_fixed(self, batch_size=1, resolution=128, feature_channels=32):
    """Generate triplane samples and split them into per-plane tensors.

    Intended to be bound to a ``GaussianDiffusion`` instance passed as
    ``self``; only ``self.p_sample_loop`` is used.

    Args:
        batch_size: Number of samples to draw.
        resolution: Height and width of each plane.
        feature_channels: Channels per plane.

    Returns:
        Tensor of shape
        ``[batch_size, 3, 1, feature_channels, resolution, resolution]``.
    """
    # The three planes are stacked along the channel axis for the UNet.
    channels = feature_channels * 3
    samples = self.p_sample_loop((batch_size, channels, resolution, resolution))

    result = []
    for i in range(batch_size):
        # [3*C, H, W] -> 3 x [C, H, W] -> [3, C, H, W] -> [3, 1, C, H, W]
        planes = torch.chunk(samples[i], 3, dim=0)
        result.append(torch.stack(planes, dim=0).unsqueeze(1))

    return torch.stack(result, dim=0)  # [B, 3, 1, C, H, W]

# Fixed function to use the sample in the training pipeline
def generate_samples_fixed(model, diffusion, output_dir, n_samples=5, resolution=128,
                           feature_channels=32, device="cuda", mesh_resolution=128):
    """Sample triplanes from ``diffusion`` and export features and meshes.

    For each sample this writes ``sample_<i>.npy``, ``sample_<i>_features.png``
    and ``sample_<i>.obj`` into ``<output_dir>/generated_samples``. A failure
    on one sample is printed with its traceback and the next sample is tried.

    Args:
        model: Not used; kept for call-site compatibility.
        diffusion: Diffusion object. Its ``sample_fixed`` attribute is used
            when present, otherwise its ``sample`` method. Either should
            return ``[B, 3, 1, C, H, W]``.
        output_dir: Parent directory for the ``generated_samples`` folder.
        n_samples: Number of samples to generate.
        resolution: Plane height/width.
        feature_channels: Channels per plane.
        device: Torch device for the decoder.
        mesh_resolution: ``res`` passed to ``create_mesh``.
    """
    import traceback

    samples_dir = os.path.join(output_dir, "generated_samples")
    os.makedirs(samples_dir, exist_ok=True)

    for i in range(n_samples):
        try:
            print(f"Generating sample {i+1}/{n_samples}")

            # Fall back to the regular sampler if sample_fixed was never attached.
            sampler = getattr(diffusion, "sample_fixed", None) or diffusion.sample
            sample = sampler(
                batch_size=1,
                resolution=resolution,
                feature_channels=feature_channels
            )

            print(f"Generated sample shape: {sample.shape}")

            sample_np = sample.cpu().numpy()
            np.save(os.path.join(samples_dir, f"sample_{i}.npy"), sample_np)

            visualize_features(
                sample_np[0],
                save_path=os.path.join(samples_dir, f"sample_{i}_features.png")
            )

            decoder = MiniTriplane(
                feature_dim=feature_channels,
                resolution=resolution
            ).to(device)

            for j in range(3):
                plane = sample[0, j]
                # Expected layout is [1, C, H, W] (4-D). The old check tested
                # dim() == 3, so it always fell through to the reshape branch.
                if plane.dim() != 4 or plane.shape[0] != 1:
                    print(f"Reshaping plane {j} from shape {plane.shape}")
                    plane = plane.reshape(1, feature_channels, resolution, resolution)
                decoder.embeddings[j].data = plane.to(device)

            vertices, triangles = create_mesh(decoder, res=mesh_resolution, device=device)

            save_obj(
                vertices,
                triangles,
                os.path.join(samples_dir, f"sample_{i}.obj")
            )

        except Exception as e:
            print(f"Error generating sample {i}: {e}")
            traceback.print_exc()

# Fixed run_modelnet40_pipeline function
def run_modelnet40_pipeline_fixed(modelnet40_path, output_dir="results",
                                 categories=None, max_models=3,
                                 epochs=50, device="cuda", n_samples=5):
    """ModelNet40 -> triplane -> diffusion pipeline using ``sample_fixed``.

    Steps:
        1. Load up to ``max_models`` training meshes from ``categories``.
           If none load, fall back to a box, a sphere and a cylinder.
        2. Fit one triplane per mesh (``create_triplane_dataset``).
        3. Train the diffusion model on the stacked triplanes.
        4. Attach ``sample_fixed`` to the diffusion object and generate
           ``n_samples`` samples with ``generate_samples_fixed``.

    Args:
        modelnet40_path: Root directory of ModelNet40.
        output_dir: Directory that receives all pipeline outputs.
        categories: Categories to load; defaults to ``["airplane"]``.
        max_models: Maximum number of meshes to load.
        epochs: Diffusion training epochs.
        device: Preferred torch device; CPU is used when CUDA is unavailable.
        n_samples: Number of triplanes to generate at the end.

    Returns:
        ``(model, diffusion)``, or ``(None, None)`` if an earlier step failed.
    """
    # A None default avoids one list object being shared across calls.
    if categories is None:
        categories = ["airplane"]
    os.makedirs(output_dir, exist_ok=True)

    print("=== Step 1: Loading models from ModelNet40 ===")
    meshes, mesh_paths = load_modelnet40_models(
        modelnet40_path,
        categories=categories,
        max_models=max_models,
        split='train'
    )

    if len(meshes) == 0:
        print("No models loaded! Adding fallback shapes")
        meshes = [
            trimesh.creation.box(extents=[1.0, 1.0, 1.0]),
            trimesh.creation.icosphere(radius=0.5),
            trimesh.creation.cylinder(radius=0.5, height=1.0),
        ]
        mesh_paths = ["box.obj", "sphere.obj", "cylinder.obj"]

    print(f"Loaded {len(meshes)} models")

    print("\n=== Step 2: Creating triplane features ===")
    triplane_dir = os.path.join(output_dir, "triplane_features")
    features_list = create_triplane_dataset(meshes, mesh_paths, triplane_dir)

    if len(features_list) == 0:
        print("Failed to create triplane features!")
        return None, None

    print("\n=== Step 3: Creating dataset and training diffusion model ===")
    dataset = TriplaneDataset(features_list)
    if len(dataset) == 0:
        # Every feature array had an unsupported shape; dataset[0] would fail.
        print("No triplane features with a supported shape!")
        return None, None

    print(f"Dataset size: {len(dataset)}")
    print(f"Sample shape: {dataset[0].shape}")

    device = torch.device(device if torch.cuda.is_available() else "cpu")
    model, diffusion = train_diffusion_model(
        dataset,
        epochs=epochs,
        batch_size=min(4, len(dataset)),
        device=device
    )

    if model is None or diffusion is None:
        print("Training failed!")
        return None, None

    # Attach sample_fixed to this instance (monkey patch).
    diffusion.sample_fixed = lambda batch_size=1, resolution=128, feature_channels=32: sample_fixed(
        diffusion, batch_size, resolution, feature_channels
    )

    print("\n=== Step 4: Generating new 3D models ===")
    sample = dataset[0]  # [3*C, H, W]
    # Planes are square, so H is the resolution. The old
    # int(np.sqrt(sample.shape[1])) treated H as H*W (128 -> 11).
    resolution = int(sample.shape[1])
    feature_channels = sample.shape[0] // 3

    generate_samples_fixed(
        model,
        diffusion,
        output_dir,
        n_samples=n_samples,
        resolution=resolution,
        feature_channels=feature_channels,
        device=device
    )

    print(f"\nPipeline complete! Results saved to {output_dir}")
    return model, diffusion

In [None]:
# Fixed version of the sample method of the GaussianDiffusion class
@torch.no_grad()
def fixed_sample(self, batch_size=1, resolution=128, feature_channels=32):
    """Sample triplanes, rounding the resolution up to a power of two.

    Meant to be bound to a ``GaussianDiffusion`` instance; only
    ``self.p_sample_loop`` is used. A ``resolution`` that is not a power of
    two is replaced (with a printed warning) by the next larger power of two.
    Shapes are printed at every stage for debugging.

    Args:
        batch_size: Number of samples to draw.
        resolution: Requested plane height/width.
        feature_channels: Channels per plane.

    Returns:
        Tensor of shape ``[batch_size, 3, 1, feature_channels, R, R]``, where
        ``R`` is the (possibly adjusted) resolution.
    """
    # Smallest power of two that is >= resolution.
    rounded = 1 << (resolution - 1).bit_length()
    if rounded != resolution:
        print(f"Warning: Adjusting resolution from {resolution} to {rounded} (power of 2) for UNet compatibility")
        resolution = rounded

    shape = (batch_size, feature_channels * 3, resolution, resolution)
    print(f"Generating sample with shape {shape}")
    raw = self.p_sample_loop(shape)
    print(f"Generated raw samples with shape {raw.shape}")

    triplanes = []
    for idx in range(batch_size):
        flat = raw[idx]  # [3*C, H, W]
        print(f"Processing sample {idx} with shape {flat.shape}")

        parts = flat.split(flat.shape[0] // 3, dim=0)  # 3 x [C, H, W]
        print(f"Split into {len(parts)} planes of shape {parts[0].shape}")

        grouped = torch.stack(parts, dim=0).unsqueeze(1)  # [3, 1, C, H, W]
        print(f"Stacked planes shape: {grouped.shape}")
        triplanes.append(grouped)

    out = torch.stack(triplanes, dim=0)  # [B, 3, 1, C, H, W]
    print(f"Final output shape: {out.shape}")
    return out
# Standalone function for generating samples
def generate_samples(model, diffusion, output_dir, n_samples=5, resolution=128, feature_channels=32, device="cuda"):
    """Generate ``n_samples`` triplanes and export each as .npy/.png/.obj.

    ``diffusion.sample`` is rebound to ``fixed_sample`` on the passed
    instance, and this rebinding persists after the call returns. Each stage
    of a sample is guarded so that one failure does not abort the run:
    visualisation, loading each plane into a ``MiniTriplane`` decoder (a
    failing plane is replaced with small random features), and mesh
    extraction. Tracebacks are printed for mesh and sample failures.

    Args:
        model: Not used by this function.
        diffusion: Diffusion object whose ``sample`` attribute is replaced.
        output_dir: Parent directory for the ``generated_samples`` folder.
        n_samples: Number of samples to attempt.
        resolution: Plane height/width requested from the sampler.
        feature_channels: Channels per plane.
        device: Torch device for the decoder.
    """
    import traceback

    out_dir = os.path.join(output_dir, "generated_samples")
    os.makedirs(out_dir, exist_ok=True)

    # Bind fixed_sample as this instance's sample method.
    diffusion.sample = fixed_sample.__get__(diffusion, type(diffusion))

    for idx in range(n_samples):
        try:
            print(f"\nGenerating sample {idx+1}/{n_samples}")

            triplane = diffusion.sample(
                batch_size=1,
                resolution=resolution,
                feature_channels=feature_channels
            )

            triplane_np = triplane.cpu().numpy()
            np.save(os.path.join(out_dir, f"sample_{idx}.npy"), triplane_np)
            print(f"Feature shape: {triplane_np.shape}")

            try:
                can_plot = hasattr(triplane_np, 'shape') and len(triplane_np.shape) >= 5
                if not can_plot:
                    print(f"Cannot visualize features with shape {triplane_np.shape}")
                else:
                    visualize_features(
                        triplane_np[0],
                        save_path=os.path.join(out_dir, f"sample_{idx}_features.png")
                    )
            except Exception as e:
                print(f"Visualization error: {e}")

            decoder = MiniTriplane(
                feature_dim=feature_channels,
                resolution=resolution
            ).to(device)

            for j in range(3):
                try:
                    if triplane.dim() < 5 or triplane.shape[2] != 1:
                        # Unexpected layout: force the plane into [1, C, H, W].
                        print(f"Reshaping plane {j}, original shape: {triplane[0, j].shape}")
                        reshaped = triplane[0, j].reshape(1, feature_channels, resolution, resolution)
                        decoder.embeddings[j].data = reshaped.to(device)
                    else:
                        # [B, 3, 1, C, H, W] -> plane j is [1, C, H, W].
                        decoder.embeddings[j].data = triplane[0, j].to(device)
                except Exception as e:
                    print(f"Failed to load plane {j}: {e}")
                    # Fallback: small random features so meshing can still run.
                    decoder.embeddings[j].data = torch.randn(1, feature_channels, resolution, resolution, device=device) * 0.01

            try:
                vertices, triangles = create_mesh(decoder, res=128, device=device)
                obj_path = os.path.join(out_dir, f"sample_{idx}.obj")
                save_obj(vertices, triangles, obj_path)
                print(f"Saved OBJ file to: {obj_path}")
            except Exception as e:
                print(f"Failed to create mesh: {e}")
                traceback.print_exc()

        except Exception as e:
            print(f"Failed to generate sample {idx}: {e}")
            traceback.print_exc()

    print(f"Completed generating {n_samples} samples")

### Fixed: TriplaneDataset

In [None]:
class FixedTriplaneDataset(torch.utils.data.Dataset):
    """Triplane dataset with optional bilinear resizing of every plane.

    Each input holds three feature planes in one of two layouts:
    ``[3, 1, C, H, W]`` or ``[3, C, H, W]``. Inputs may be NumPy arrays,
    torch tensors, or anything ``torch.as_tensor`` accepts. If
    ``target_resolution`` is given and differs from ``H`` or ``W``, each
    plane is resized to ``target_resolution x target_resolution``. The
    resize is bilinear with ``align_corners=True``. The planes are then
    concatenated along channels, so each item is a float32
    ``[3*C, H', W']`` tensor. Inputs of any other shape are skipped with a
    printed warning.

    Args:
        features_list: Iterable of triplane feature arrays/tensors.
        target_resolution: Optional square output size for every plane.
    """

    def __init__(self, features_list, target_resolution=None):
        self.features = []

        for feature in features_list:
            # Always float32: interpolate needs floating input, and tensors
            # previously kept their own dtype (e.g. float64) unlike NumPy input.
            tensor = torch.as_tensor(feature).float()

            # Normalise both supported layouts to [3, C, H, W].
            if tensor.ndim == 5 and tensor.shape[0] == 3 and tensor.shape[1] == 1:
                planes = tensor[:, 0]
            elif tensor.ndim == 4 and tensor.shape[0] == 3:
                planes = tensor
            else:
                print(f"Warning: Unsupported feature shape {tensor.shape}")
                continue

            H, W = planes.shape[-2], planes.shape[-1]
            if target_resolution is not None and (H != target_resolution or W != target_resolution):
                # The three planes form a batch of 3 for interpolate, which
                # resamples each batch element independently.
                planes = F.interpolate(
                    planes,
                    size=(target_resolution, target_resolution),
                    mode='bilinear',
                    align_corners=True,
                )

            # torch.cat copies, so stored items never alias the caller's arrays.
            self.features.append(torch.cat(tuple(planes), dim=0))  # [3*C, H', W']

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return self.features[idx]

In [None]:
def _encode_meshes_to_triplanes(meshes, mesh_paths, triplane_dir, device,
                                feature_dim=32, resolution=128,
                                encoder_epochs=200, recon_resolution=128):
    """Fit one triplane encoder per mesh and collect the saved feature arrays.

    For each mesh, ``train_encoder`` is asked to write its features to
    ``<triplane_dir>/<name>.npy``; that file is loaded back with ``np.load``.
    A reconstruction is then extracted with ``create_mesh`` and written to
    ``<triplane_dir>/reconstructions/<name>_recon.obj`` for visual checking.

    Meshes whose encoding fails are reported and skipped, so the returned list
    may be shorter than ``meshes``. A failed reconstruction does not drop the
    features that were already loaded for that mesh.

    Returns:
        list: One loaded ``.npy`` array per successfully encoded mesh.
    """
    os.makedirs(triplane_dir, exist_ok=True)
    recon_dir = os.path.join(triplane_dir, "reconstructions")
    features_list = []

    for i, (mesh, path_or_name) in enumerate(zip(meshes, mesh_paths)):
        # splitext keeps dotted stems intact ("a.b.off" -> "a.b"), unlike split('.')[0].
        name = os.path.splitext(os.path.basename(path_or_name))[0]
        output_path = os.path.join(triplane_dir, f"{name}.npy")

        try:
            print(f"Processing model {i+1}/{len(meshes)}: {name}")

            # Train encoder; it is expected to save features to output_path.
            encoder = train_encoder(
                mesh=mesh,
                output_path=output_path,
                epochs=encoder_epochs,
                feature_dim=feature_dim,
                resolution=resolution,  # Use power of 2 for resolution
                device=device,
            )
            features = np.load(output_path)
        except Exception as e:
            print(f"Failed to process model {name}: {e}")
            continue

        features_list.append(features)

        # Best-effort reconstruction for validation only.
        try:
            os.makedirs(recon_dir, exist_ok=True)
            vertices, triangles = create_mesh(encoder, res=recon_resolution, device=device)
            save_obj(
                vertices,
                triangles,
                os.path.join(recon_dir, f"{name}_recon.obj"),
            )
        except Exception as e:
            print(f"Failed to create reconstruction mesh: {e}")

    return features_list


def _train_diffusion_model_fixed(dataset, output_dir, epochs=30, batch_size=1,
                                 lr=1e-4, device="cuda", checkpoint_every=10):
    """Train a ``TriplaneUNet``-backed ``GaussianDiffusion`` on ``dataset``.

    Checkpoints (model/optimizer state, epoch, average loss) are written to
    ``<output_dir>/checkpoints/diffusion_epoch{N}.pt`` every
    ``checkpoint_every`` epochs and after the final epoch. After training, the
    diffusion object's ``sample`` method is replaced by ``fixed_sample``.

    Args:
        dataset: Indexable dataset whose items are ``[C, H, W]`` tensors; the
            channel count of item 0 sets the UNet's ``in_channels``.
        output_dir: Root directory for the ``checkpoints`` subfolder.
        epochs: Number of training epochs.
        batch_size: Requested batch size; clamped to ``[1, len(dataset)]``.
        lr: Adam learning rate.
        device: Torch device for the model and batches.
        checkpoint_every: Save a checkpoint every this many epochs.

    Returns:
        tuple: ``(model, diffusion)``, or ``(None, None)`` for an empty dataset.
    """
    if len(dataset) == 0:
        print("Error: Empty dataset")
        return None, None

    # Clamp so an oversized or zero batch size never reaches DataLoader.
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=max(1, min(batch_size, len(dataset))),
        shuffle=True,
    )

    in_channels = dataset[0].shape[0]
    print(f"Input channels: {in_channels}")

    model = TriplaneUNet(in_channels=in_channels).to(device)
    diffusion = GaussianDiffusion(model, device=device)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    checkpoint_dir = os.path.join(output_dir, "checkpoints")

    for epoch in range(epochs):
        epoch_loss = 0
        batch_count = 0

        for batch_idx, batch in enumerate(tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")):
            try:
                batch = batch.to(device)

                # NaNs in the stored features would poison the loss.
                if torch.isnan(batch).any():
                    batch = torch.nan_to_num(batch, nan=0.0)

                loss = diffusion.train_step(batch, optimizer)
                epoch_loss += loss
                batch_count += 1
            except Exception as e:
                # Best-effort: report the failing batch and keep training.
                print(f"Error in batch {batch_idx}: {e}")
                continue

        avg_loss = epoch_loss / max(1, batch_count)
        print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.6f}")

        if (epoch + 1) % checkpoint_every == 0 or epoch == epochs - 1:
            os.makedirs(checkpoint_dir, exist_ok=True)
            checkpoint_path = os.path.join(checkpoint_dir, f"diffusion_epoch{epoch+1}.pt")
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_loss,
            }, checkpoint_path)
            print(f"Saved checkpoint to {checkpoint_path}")

    # Replace sample method with fixed version
    diffusion.sample = fixed_sample.__get__(diffusion, type(diffusion))

    return model, diffusion


def triplane_diffusion_pipeline(modelnet40_path, output_dir="results", 
                                    categories=None, max_models=1, 
                                    epochs=30, device="cuda"):
    """Run the complete triplane diffusion pipeline end to end.

    Steps:
      1. Load meshes from ModelNet40. If nothing loads, fall back to a unit
         box and a radius-0.5 icosphere.
      2. Fit one triplane encoder per mesh. Features are saved as ``.npy``
         under ``<output_dir>/triplane_features``.
      3. Wrap the features in ``FixedTriplaneDataset`` at 128x128 and train
         the diffusion model. Checkpoints go under ``<output_dir>/checkpoints``.
      4. Generate 5 samples with ``generate_samples``.

    Args:
        modelnet40_path: Root directory of the ModelNet40 dataset.
        output_dir: Directory for all artifacts; created if missing.
        categories: Categories passed to ``load_modelnet40_models``.
            Defaults to ``["airplane"]``.
        max_models: Passed through to ``load_modelnet40_models``.
        epochs: Number of diffusion training epochs.
        device: Requested torch device. Falls back to ``"cpu"`` when CUDA is
            unavailable, and the fallback applies to every stage, including
            encoder training.

    Returns:
        tuple: ``(model, diffusion)`` on success, or ``(None, None)`` if no
        triplane features were created or diffusion training failed.
    """
    if categories is None:
        categories = ["airplane"]
    os.makedirs(output_dir, exist_ok=True)

    # Resolve the device once, before any stage runs. The encoder stage
    # previously received the raw "cuda" string even on CPU-only machines.
    if not torch.cuda.is_available():
        device = "cpu"
    device = torch.device(device)

    # Shared settings: the encoder output and the diffusion input must agree.
    feature_dim = 32
    target_resolution = 128  # Power of 2 so the UNet's down/up sampling lines up

    # Step 1: Load ModelNet40 models
    print("=== Step 1: Loading ModelNet40 models ===")
    meshes, mesh_paths = load_modelnet40_models(
        modelnet40_path,
        categories=categories,
        max_models=max_models,
        split='train'
    )

    if not meshes:
        print("No models loaded, creating default shapes")
        box = trimesh.creation.box(extents=[1.0, 1.0, 1.0])
        sphere = trimesh.creation.icosphere(radius=0.5)
        meshes = [box, sphere]
        mesh_paths = ["box.obj", "sphere.obj"]

    print(f"Loaded {len(meshes)} models")

    # Step 2: Create triplane features
    print("\n=== Step 2: Creating triplane features ===")
    triplane_dir = os.path.join(output_dir, "triplane_features")
    features_list = _encode_meshes_to_triplanes(
        meshes,
        mesh_paths,
        triplane_dir,
        device,
        feature_dim=feature_dim,
        resolution=target_resolution,
    )

    if not features_list:
        print("Failed to create triplane features!")
        return None, None

    print(f"Created {len(features_list)} triplane features")

    # Step 3: Create dataset and train diffusion model
    print("\n=== Step 3: Training diffusion model ===")
    dataset = FixedTriplaneDataset(features_list, target_resolution=target_resolution)
    print(f"Dataset size: {len(dataset)}")
    if len(dataset) > 0:
        print(f"Sample shape: {dataset[0].shape}")

    model, diffusion = _train_diffusion_model_fixed(
        dataset,
        output_dir,
        epochs=epochs,
        batch_size=min(4, len(dataset)),
        device=device,
    )

    if model is None or diffusion is None:
        print("Training failed!")
        return None, None

    # Step 4: Generate new 3D models
    print("\n=== Step 4: Generating new 3D models ===")
    feature_channels = feature_dim
    resolution = target_resolution  # Use the same resolution as training

    print(f"Using resolution: {resolution}, feature channels: {feature_channels}")

    generate_samples(
        model,
        diffusion,
        output_dir,
        n_samples=5,
        resolution=resolution,
        feature_channels=feature_channels,
        device=device
    )

    print(f"\nPipeline complete! Results saved to {output_dir}")
    return model, diffusion

In [None]:

if __name__ == "__main__":
    # Location of the Princeton ModelNet40 dataset as mounted on Kaggle;
    # change this path when running outside Kaggle.
    modelnet40_path = (
        "/kaggle/input/modelnet40-princeton-3d-object-dataset/ModelNet40"
    )

    # Small smoke run: a single airplane mesh and 30 diffusion epochs.
    model, diffusion = triplane_diffusion_pipeline(
        modelnet40_path,
        output_dir="fixed_results",
        categories=["airplane"],
        max_models=1,
        epochs=30,
    )