In [1]:
# extension to agent copy 4.ipynb F1-score, pos weight
from typing import Tuple, List
from dataclasses import dataclass
from tqdm import tqdm
from datetime import timedelta
import zipfile
import shutil
import tempfile
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import time
import os
import glob
from torch.utils.data import Subset, DataLoader

In [20]:
class VoxelDataLoader:
    """Loads and processes NPZ voxel data from a zip file"""

    def __init__(self, zip_path: str):
        # Create a temporary directory
        self.temp_dir = tempfile.mkdtemp()
        print(f"Created temporary directory: {self.temp_dir}")

        # Extract zip file
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(self.temp_dir)
        print(f"Extracted zip file to temporary directory")

        # Find all NPZ files
        all_files = glob.glob(os.path.join(self.temp_dir, "**/*.npz"), recursive=True)
        print(f"Found {len(all_files)} total NPZ files")

        if len(all_files) == 0:
            raise ValueError(f"No NPZ files found in zip file")

        random.shuffle(all_files)  # Shuffle before splitting
        cutoff = int(len(all_files))
        self.npz_files = all_files[:cutoff]
        print(f"Using {len(self.npz_files)}")

    def __del__(self):
        """Cleanup temporary directory when object is destroyed"""
        try:
            shutil.rmtree(self.temp_dir)
            print(f"Cleaned up temporary directory: {self.temp_dir}")
        except:
            print(f"Failed to clean up temporary directory: {self.temp_dir}")

    def load_single_file(self, file_path: str) -> Tuple[torch.Tensor, torch.Tensor]:
        data = np.load(file_path)

        # More robust key checking
        if 'complete' not in data or 'partial' not in data:
            raise ValueError(f"NPZ file {file_path} must contain both 'complete' and 'partial' arrays")

        complete = torch.from_numpy(data['complete']).float()
        partial = torch.from_numpy(data['partial']).float()

        # Verify shapes match
        if complete.shape != partial.shape:
            raise ValueError(f"Shape mismatch in {file_path}: complete {complete.shape} vs partial {partial.shape}")

        return complete, partial

    def get_all_data(self) -> List[Tuple[torch.Tensor, torch.Tensor]]:
        """Load all voxel pairs from all NPZ files"""
        all_data = []
        for file_path in self.npz_files:
            complete, partial = self.load_single_file(file_path)
            all_data.append((complete, partial))
        return all_data

    def get_voxel_grids(self, index: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
        """Returns complete and partial voxel grids from a specific file"""
        if index >= len(self.npz_files):
            raise IndexError(f"Index {index} out of range. Only {len(self.npz_files)} files available.")
        return self.load_single_file(self.npz_files[index])


class VoxelDataset(torch.utils.data.Dataset):
    """PyTorch Dataset for voxel completion"""

    def __init__(self, zip_path: str, transform=None):
        self.data_loader = VoxelDataLoader(zip_path)
        self.transform = transform

    def __len__(self):
        return len(self.data_loader.npz_files)

    def __getitem__(self, idx):
        complete, partial = self.data_loader.get_voxel_grids(idx)
        # Normalize to [0,1] if not already
        complete = (complete > 0).float()
        partial = (partial > 0).float()
        if self.transform:
            complete, partial = self.transform(complete, partial)
        return complete, partial


# Update data loader creation function
def create_data_loader(zip_path: str, batch_size: int = 1, shuffle: bool = True, num_workers: int = 0):
    """Create a PyTorch DataLoader for training"""
    dataset = VoxelDataset(zip_path)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        # pin_memory=True
    )


def split_dataset(dataset, train_ratio=0.8, val_ratio=0.2, seed=42):
    n = len(dataset)
    indices = list(range(n))
    random.Random(seed).shuffle(indices)

    # from dataset: 80% train 20% test 
    n_trainval = int(n * 0.8)
    n_test = n - n_trainval
    trainval_indices = indices[:n_trainval]
    test_indices = indices[n_trainval:]
    # from training data: 80% train 20% validation
    n_train = int(len(trainval_indices) * 0.8)
    train_indices = trainval_indices[:n_train]
    val_indices = trainval_indices[n_train:]

    return train_indices, val_indices, test_indices

def create_data_loaders(zip_path, batch_size=1, shuffle=True, num_workers=0, seed=42):
    dataset = VoxelDataset(zip_path)
    print(f"Dataset size: {len(dataset)}")
    train_idx, val_idx, test_idx = split_dataset(dataset, seed=seed)
    train_loader = DataLoader(Subset(dataset, train_idx), batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
    val_loader = DataLoader(Subset(dataset, val_idx), batch_size=batch_size, shuffle=False, num_workers=num_workers)
    test_loader = DataLoader(Subset(dataset, test_idx), batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, val_loader, test_loader

class LocalAttention(nn.Module):
    """
    3D local attention: attends to a 3D window [ws, ws, ws] around the target voxel, preserving spatial structure.
    Input: target_embedding [B, d_model], neighbor_embeddings [B, ws, ws, ws, d_model]
    Output: attended_embedding [B, d_model]
    """
    def __init__(self, d_model: int, num_heads: int = 4, window_size: int = 3):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.window_size = window_size
        assert d_model % num_heads == 0
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.scale = self.head_dim ** -0.5

    def forward(self, target_embedding, neighbor_embeddings):
        # target_embedding: [B, d_model]
        # neighbor_embeddings: [B, ws, ws, ws, d_model]
        B, ws1, ws2, ws3, d_model = neighbor_embeddings.shape
        N = ws1 * ws2 * ws3
        Q = self.q_proj(target_embedding).view(B, self.num_heads, self.head_dim)  # [B, num_heads, head_dim]
        K = self.k_proj(neighbor_embeddings).view(B, ws1, ws2, ws3, self.num_heads, self.head_dim)  # [B, ws, ws, ws, num_heads, head_dim]
        V = self.v_proj(neighbor_embeddings).view(B, ws1, ws2, ws3, self.num_heads, self.head_dim)  # [B, ws, ws, ws, num_heads, head_dim]
        # Center voxel is at (ws//2, ws//2, ws//2)
        Q = Q.unsqueeze(1).unsqueeze(1).unsqueeze(1)  # [B,1,1,1,num_heads,head_dim]
        # Compute attention scores: [B, ws, ws, ws, num_heads]
        attn_scores = (Q * K).sum(-1) * self.scale  # [B, ws, ws, ws, num_heads]
        attn_scores = attn_scores.permute(0, 4, 1, 2, 3)  # [B, num_heads, ws, ws, ws]
        attn_weights = attn_scores.reshape(B, self.num_heads, -1).softmax(-1).reshape(B, self.num_heads, ws1, ws2, ws3)
        # Weighted sum over 3D window
        V = V.permute(0, 4, 1, 2, 3, 5)  # [B, num_heads, ws, ws, ws, head_dim]
        attn_weights = attn_weights.unsqueeze(-1)  # [B, num_heads, ws, ws, ws, 1]
        attn_out = (attn_weights * V).sum(dim=(2,3,4))  # [B, num_heads, head_dim]
        attn_out = attn_out.reshape(B, self.d_model)
        return self.out_proj(attn_out)

class VoxelTransformerLayerVoxelwise(nn.Module):
    def __init__(self, d_model: int, num_heads: int = 4, window_size: int = 3, dropout: float = 0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn = LocalAttention(d_model, num_heads, window_size)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 4, d_model),
            nn.Dropout(dropout)
        )
    def forward(self, target_embedding, neighbor_embeddings):
        x = self.norm1(target_embedding)
        attn_out = self.attn(x, neighbor_embeddings)
        x = x + attn_out
        x2 = self.norm2(x)
        ffn_out = self.ffn(x2)
        return x + ffn_out

class PositionalEncoding3D(nn.Module):
    def __init__(self, d_model: int, max_grid_size: int = 32):
        super().__init__()
        self.d_model = d_model
        self.max_grid_size = max_grid_size
        
        # Create learnable positional embeddings
        self.pos_embed = nn.Parameter(
            torch.zeros(max_grid_size, max_grid_size, max_grid_size, d_model)
        )
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
    
    def forward(self, x):
        # x: input tensor with shape [B, D, H, W] or similar
        # Extract spatial dimensions
        *_, D, H, W = x.shape
        return self.pos_embed[:D, :H, :W, :]  # [D, H, W, d_model]
    
    def get_encoding(self, D, H, W):
        """Get positional encoding for specific dimensions"""
        return self.pos_embed[:D, :H, :W, :]  # [D, H, W, d_model]


class VoxelCompletionTransformerVoxelwise(nn.Module):
    def __init__(self, d_model=64, num_heads=4, num_layers=4, window_size=3, dropout=0.1, max_grid_size=32):
        super().__init__()
        self.d_model = d_model
        self.window_size = window_size
        self.input_proj = nn.Linear(1, d_model)
        self.pos_encoding = PositionalEncoding3D(d_model, max_grid_size)
        self.layers = nn.ModuleList([
            VoxelTransformerLayerVoxelwise(d_model, num_heads, window_size, dropout)
            for _ in range(num_layers)
        ])
        self.output_proj = nn.Linear(d_model, 1)
    
    def extract_neighborhood(self, grid, idx):
        """Extract neighborhood windows around target positions"""
        # grid: [B, D, H, W], idx: [B, 3] (x, y, z)
        B = grid.shape[0]
        ws = self.window_size
        pad = ws // 2
        
        # Pad the grid
        padded = F.pad(grid, (pad, pad, pad, pad, pad, pad), mode='constant', value=0)  # [B, D+2p, H+2p, W+2p]
        
        # Extract neighborhoods for all batch items
        neighborhoods = []
        for b in range(B):
            x, y, z = idx[b]
            x, y, z = x + pad, y + pad, z + pad  # Account for padding
            
            # Extract the window
            neighborhood = padded[b, x-pad:x+pad+1, y-pad:y+pad+1, z-pad:z+pad+1]  # [ws, ws, ws]
            neighborhoods.append(neighborhood.unsqueeze(-1))  # [ws, ws, ws, 1]
        
        neighborhoods = torch.stack(neighborhoods, dim=0)  # [B, ws, ws, ws, 1]
        return neighborhoods
    
    def extract_neighborhood_positions(self, pos_enc_grid, target_idx):
        """Extract positional encodings for neighborhood windows around target positions"""
        B = target_idx.shape[0]
        ws = self.window_size
        pad = ws // 2
        
        # Pad the positional encoding grid
        pos_enc_padded = F.pad(pos_enc_grid, (0, 0, pad, pad, pad, pad, pad, pad), mode='constant', value=0)
        
        # Extract neighborhoods for all batch items
        neighbor_positions = []
        for b in range(B):
            x, y, z = target_idx[b]
            x, y, z = x + pad, y + pad, z + pad  # Account for padding
            
            # Extract the positional encoding window
            pos_window = pos_enc_padded[x-pad:x+pad+1, y-pad:y+pad+1, z-pad:z+pad+1, :]  # [ws, ws, ws, d_model]
            neighbor_positions.append(pos_window)
        
        neighbor_positions = torch.stack(neighbor_positions, dim=0)  # [B, ws, ws, ws, d_model]
        return neighbor_positions

    def forward(self, partial_grid, target_idx):
        """
        Args:
            partial_grid: [B, D, H, W] - partially filled voxel grid
            target_idx: [B, 3] - target positions (x, y, z) to predict
        
        Returns:
            logit: [B] - prediction logits for target voxels
        """
        B, D, H, W = partial_grid.shape
        ws = self.window_size
        center = ws // 2
        
        # Get positional encodings for the entire spatial grid
        pos_enc_grid = self.pos_encoding.get_encoding(D, H, W)  # [D, H, W, d_model]
        
        # Extract neighborhoods from the voxel grid
        neighborhoods = self.extract_neighborhood(partial_grid, target_idx)  # [B, ws, ws, ws, 1]
        
        # Get target voxel values (center of neighborhood)
        target_voxel = neighborhoods[:, center, center, center, :]  # [B, 1]
        
        # Get positional encodings for target positions
        # Use advanced indexing to efficiently get positions for all batch items
        target_pos = pos_enc_grid[target_idx[:, 0], target_idx[:, 1], target_idx[:, 2], :]  # [B, d_model]
        
        # Project target voxel and add positional encoding
        x = self.input_proj(target_voxel) + target_pos  # [B, d_model]
        
        # Get positional encodings for neighborhood windows
        neighbor_pos = self.extract_neighborhood_positions(pos_enc_grid, target_idx)  # [B, ws, ws, ws, d_model]
        
        # Project neighbor voxels and add their positional encodings
        neighbor_emb = self.input_proj(neighborhoods) + neighbor_pos  # [B, ws, ws, ws, d_model]
        
        # Apply transformer layers
        for layer in self.layers:
            x = layer(x, neighbor_emb)
        
        # Generate final prediction logit
        logit = self.output_proj(x).squeeze(-1)  # [B]
        return logit

def train_model_voxelwise(
    model: nn.Module,
    train_set,
    val_set,
    num_epochs: int = 50,
    batch_size: int = 1,
    window_size: int = 3,
    num_voxels_per_obj: int = 64,
    seed: int = 42
):
    print(f"Batch Size: {batch_size}, Window Size: {window_size}")
    torch.cuda.empty_cache()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    model.to(device)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=0)
    print(f"Train loader size: {len(train_loader)}, Val loader size: {len(val_loader)}")
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    total_start_time = time.time()
    for epoch in range(num_epochs):
        epoch_start_time = time.time()
        total_loss = 0
        num_samples_processed = 0
        model.train()
        epoch_pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}', leave=True, unit='sample')
        for batch_idx, (complete_grid, partial_grid) in enumerate(epoch_pbar):
            complete_grid = complete_grid.to(device)
            partial_grid = partial_grid.to(device)
            B, D, H, W = partial_grid.shape
            # 1. Sample indices: half truly missing, half always empty
            target_indices = sample_balanced_voxels(partial_grid, complete_grid, num_voxels_per_obj, window_size=window_size)  # [B, num_voxels, 3]
            batch_losses = []
            # Iterative prediction: update partial_grid after each prediction
            for v in range(num_voxels_per_obj):
                idx = target_indices[:, v, :]  # [B, 3]
                logits = model(partial_grid, idx)  # [B]
                labels = complete_grid[torch.arange(B), idx[:,0], idx[:,1], idx[:,2]].float()  # [B]
                loss = criterion(logits, labels)
                batch_losses.append(loss)
                # Update partial_grid for next iteration (if predicted filled)
                with torch.no_grad():
                    pred_filled = (torch.sigmoid(logits) > 0.5).float()
                    for b in range(B):
                        x, y, z = idx[b]
                        if pred_filled[b] == 1.0:
                            partial_grid[b, x, y, z] = 1.0
            total_batch_loss = torch.stack(batch_losses).mean()
            optimizer.zero_grad()
            total_batch_loss.backward()
            optimizer.step()
            total_loss += total_batch_loss.item()
            num_samples_processed += 1
            epoch_pbar.set_postfix({
                'train_loss': f'{total_loss/num_samples_processed:.4f}',
                'samples': num_samples_processed,
                'lr': optimizer.param_groups[0]['lr']
            })
        avg_train_loss = total_loss / max(num_samples_processed, 1)
        print(f"Epoch {epoch+1}/{num_epochs} - Train loss: {avg_train_loss:.4f}, Samples: {num_samples_processed}")
    total_time = time.time() - total_start_time
    print(f"\nTraining completed in {timedelta(seconds=int(total_time))}")
    print(f"Average time per epoch: {timedelta(seconds=int(total_time/num_epochs))}")

def sample_missing_voxels(partial_grid, complete_grid, num_samples, window_size=3):
    # Only sample from truly missing voxels: (partial == 0) & (complete == 1)
    # Prioritize those with most filled neighbors, break ties by distance to origin
    B, D, H, W = partial_grid.shape
    pad = window_size // 2
    indices = []
    for b in range(B):
        missing = ((partial_grid[b] == 0) & (complete_grid[b] == 1)).nonzero(as_tuple=False)  # [N, 3]
        if len(missing) == 0:
            # fallback: pick random
            x = torch.randint(0, D, (num_samples,))
            y = torch.randint(0, H, (num_samples,))
            z = torch.randint(0, W, (num_samples,))
            idx = torch.stack([x, y, z], dim=1)
            indices.append(idx)
            continue
        # Pad partial grid for easy neighbor counting
        padded = F.pad(partial_grid[b].unsqueeze(0), (pad, pad, pad, pad, pad, pad), mode='constant', value=0)[0]
        neighbor_counts = []
        dists = []
        for coord in missing:
            x, y, z = coord.tolist()
            x_p, y_p, z_p = x+pad, y+pad, z+pad
            window = padded[x_p-pad:x_p+pad+1, y_p-pad:y_p+pad+1, z_p-pad:z_p+pad+1]
            count = window.sum().item()
            neighbor_counts.append(count)
            dists.append(x**2 + y**2 + z**2)
        neighbor_counts = np.array(neighbor_counts)
        dists = np.array(dists)
        # Sort: most filled neighbors first, then closest to origin
        sort_idx = np.lexsort((dists, -neighbor_counts))
        chosen = missing[sort_idx][:num_samples]
        # Pad if not enough missing voxels
        if chosen.shape[0] < num_samples:
            pad_count = num_samples - chosen.shape[0]
            pad_idx = chosen[-1].unsqueeze(0).repeat(pad_count, 1) if chosen.shape[0] > 0 else torch.zeros((pad_count, 3), dtype=torch.long)
            chosen = torch.cat([chosen, pad_idx], dim=0)
        indices.append(chosen)
    indices = torch.stack(indices, dim=0)  # [B, num_samples, 3]
    return indices

def sample_balanced_voxels(partial_grid, complete_grid, num_samples, window_size=3):
    import numpy as np
    import torch
    import torch.nn.functional as F

    # Half from truly missing, half from always empty
    B, D, H, W = partial_grid.shape
    pad = window_size // 2
    indices = []
    n1 = num_samples // 2
    n2 = num_samples - n1

    for b in range(B):
        # Truly missing: present in complete, missing in partial
        missing = ((partial_grid[b] == 0) & (complete_grid[b] == 1)).nonzero(as_tuple=False)  # [N1, 3]
        # Always empty: missing in both
        empty = ((partial_grid[b] == 0) & (complete_grid[b] == 0)).nonzero(as_tuple=False)  # [N2, 3]

        # Prioritize truly missing first
        if len(missing) > 0:
            padded = F.pad(
                partial_grid[b].unsqueeze(0),
                (pad, pad, pad, pad, pad, pad),
                mode="constant",
                value=0
            )[0]
            neighbor_counts = []
            dists = []

            for coord in missing:
                x, y, z = coord.tolist()
                x_p, y_p, z_p = x + pad, y + pad, z + pad
                window = padded[
                    x_p - pad : x_p + pad + 1,
                    y_p - pad : y_p + pad + 1,
                    z_p - pad : z_p + pad + 1,
                ]
                count = window.sum().item()
                neighbor_counts.append(count)
                dists.append(x**2 + y**2 + z**2)

            # Explicit sorting: most neighbors first, then closest to origin
            sort_idx = sorted(
                range(len(missing)),
                key=lambda i: (-neighbor_counts[i], dists[i])
            )
            chosen_missing = missing[sort_idx][:n1]

            # Pad if not enough
            if chosen_missing.shape[0] < n1:
                pad_count = n1 - chosen_missing.shape[0]
                pad_idx = (
                    chosen_missing[-1].unsqueeze(0).repeat(pad_count, 1)
                    if chosen_missing.shape[0] > 0
                    else torch.zeros((pad_count, 3), dtype=torch.long)
                )
                chosen_missing = torch.cat([chosen_missing, pad_idx], dim=0)
        else:
            # Random fallback if no missing voxels
            x = torch.randint(0, D, (n1,))
            y = torch.randint(0, H, (n1,))
            z = torch.randint(0, W, (n1,))
            chosen_missing = torch.stack([x, y, z], dim=1)

        # Randomly sample from always empty
        if len(empty) > 0:
            perm = torch.randperm(len(empty))[:n2]
            chosen_empty = empty[perm]
            if chosen_empty.shape[0] < n2:
                pad_count = n2 - chosen_empty.shape[0]
                pad_idx = (
                    chosen_empty[-1].unsqueeze(0).repeat(pad_count, 1)
                    if chosen_empty.shape[0] > 0
                    else torch.zeros((pad_count, 3), dtype=torch.long)
                )
                chosen_empty = torch.cat([chosen_empty, pad_idx], dim=0)
        else:
            # Random fallback if no empty voxels
            x = torch.randint(0, D, (n2,))
            y = torch.randint(0, H, (n2,))
            z = torch.randint(0, W, (n2,))
            chosen_empty = torch.stack([x, y, z], dim=1)

        # Combine missing + empty
        indices.append(torch.cat([chosen_missing, chosen_empty], dim=0))

    indices = torch.stack(indices, dim=0)  # [B, num_samples, 3]
    return indices



In [3]:
# !pip install google-auth-oauthlib

# from google.colab import drive
# drive.mount('/content/drive')

# DRIVE_PATH = "/content/drive/MyDrive/AUB_masters/thesis/data/partial_data_16.zip"  # Adjust this path to match your Drive structure
# LOCAL_PATH = "/content/partial_data"
# !mkdir -p {LOCAL_PATH}

# print("Copying data from Drive to local storage...")
# !cp "{DRIVE_PATH}" "{LOCAL_PATH}/data.zip"
# zip_path = f"{LOCAL_PATH}/data.zip"

In [None]:
zip_path = "../../chunk_data_16_flood_fill_rm_40.zip"
dataset = VoxelDataset(zip_path)

print(f"Total dataset size: {len(dataset)}")

train_idx, val_idx, test_idx = split_dataset(dataset, seed=42)
train_set = torch.utils.data.Subset(dataset, train_idx)
val_set = torch.utils.data.Subset(dataset, val_idx)
test_set = torch.utils.data.Subset(dataset, test_idx)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Using device: {device}")

model = VoxelCompletionTransformer(
    d_model=48,        
    num_heads=6,       
    num_layers=6,      
    window_size=3,    
    dropout=0.1
).to(device)

print("Starting training...")
torch.cuda.empty_cache()
# best model before this was num_epochs=2
train_model(model, train_set, val_set, num_epochs=5, batch_size=2, lambda_consistency=1)

In [3]:
MODEL_SAVE_PATH = "../../models/iterative_model_rm_20/iterative_model_rm_20_dmodel_RES_32.pth"
# os.makedirs(os.path.dirname(MODEL_SAVE_PATH), exist_ok=True)
# torch.save({
#     'model_state_dict': model.state_dict(),
# }, MODEL_SAVE_PATH)
# print(f"Model saved to {MODEL_SAVE_PATH}")

In [5]:
import numpy as np
import torch
import glob
import json 

# load_dir = "../../test_data/test_data_rm_40/"
# test_samples = []

# for file in sorted(glob.glob(os.path.join(load_dir, "test_*.npz"))):
#     data = np.load(file)
#     complete = torch.from_numpy(data['complete']).float()
#     partial = torch.from_numpy(data['partial']).float()
#     test_samples.append((complete, partial))
    
zip_path = "../../model_data/chunk_data_32_flood_fill_rm_20.zip"
dataset = VoxelDataset(zip_path)

test_dir = "../../test_data/test_data_rm_20_iterative/"
test_indices_file = os.path.join(test_dir, "test_indices.json")

with open(test_indices_file, "r") as f:
    test_idx = json.load(f)

test_samples = []
for idx in range(1000):
    complete, partial = dataset[idx]
    test_samples.append((complete, partial))

print(f"Loaded {len(test_samples)} test samples from {test_dir}")



Created temporary directory: /tmp/tmplcoiamyx
Extracted zip file to temporary directory
Found 264091 total NPZ files
Using 264091
Cleaned up temporary directory: /tmp/tmpbnh9we5n
Loaded 1000 test samples from ../../test_data/test_data_rm_20_iterative/


In [None]:
# --- Updated test_model to use test_set ---
from torchviz import make_dot
import torch
from torchinfo import summary

def test_model(model_path, test_set, sample_idx=0, threshold=0.5, device=None):
    """
    For every empty voxel in the partial grid, predict if it should be filled (probability > threshold).
    Voxels are sorted by number of filled neighbors in a 3x3x3 window (descending), breaking ties by distance to origin (ascending).
    After each prediction, update the output grid for the next prediction.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = VoxelCompletionTransformerVoxelwise(
        d_model=96,        
        num_heads=6,       
        num_layers=6,
        max_grid_size=32,      
        window_size=3,    
        dropout=0.1
    ).to(device)

    checkpoint = torch.load(model_path, map_location=device)
    state_dict = checkpoint['model_state_dict']
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        new_key = k.replace('module.', '') if k.startswith('module.') else k
        new_state_dict[new_key] = v
    model.load_state_dict(new_state_dict)
    model.eval()

    complete, partial = test_set[sample_idx]
    complete = (complete > 0).float()
    partial = (partial > 0).float()
    B, D, H, W = 1, *partial.shape
    partial_grid = partial.unsqueeze(0).to(device)
    complete_grid = complete.unsqueeze(0).to(device)
    output_grid = partial_grid.clone()

    # Find all empty voxels in the partial grid
    empty_voxels = (partial_grid[0] == 0).nonzero(as_tuple=False)  # [N, 3]

    # Sort empty voxels by number of filled neighbors in a 3x3x3 window, break ties by distance to origin
    def count_filled_neighbors_3x3x3(voxel, grid):
        x, y, z = voxel
        D, H, W = grid.shape[1:]
        count = 0
        for dx in [-1, 0, 1]:
            for dy in [-1, 0, 1]:
                for dz in [-1, 0, 1]:
                    nx, ny, nz = x + dx, y + dy, z + dz
                    if 0 <= nx < D and 0 <= ny < H and 0 <= nz < W:
                        if dx == 0 and dy == 0 and dz == 0:
                            continue  # skip the voxel itself
                        if grid[0, nx, ny, nz] > 0:
                            count += 1
        return count

    voxel_scores = []
    for v in empty_voxels:
        filled_neighbors = count_filled_neighbors_3x3x3(v.tolist(), output_grid)
        distance = sum([coord**2 for coord in v.tolist()])  # squared Euclidean distance to origin
        voxel_scores.append((v, filled_neighbors, distance))
    # Sort: more neighbors first, then closer to origin
    voxel_scores.sort(key=lambda x: (-x[1], x[2]))
    sorted_empty_voxels = [v[0] for v in voxel_scores]

    with torch.no_grad():
        for idx in sorted_empty_voxels:
            idx = idx.unsqueeze(0)  # [1, 3]
            logits = model(output_grid, idx)  # [1]
            pred_filled = (torch.sigmoid(logits) > threshold).float()
            # Update output grid for next prediction
            if pred_filled[0] == 1.0:
                x, y, z = idx[0]
                output_grid[0, x, y, z] = 1.0
    output = output_grid.squeeze(0).cpu().numpy()
    np.save("output_voxel.npy", output)
    np.save("complete_voxel.npy", complete.cpu().numpy())
    np.save("partial_voxel.npy", partial.cpu().numpy())
    print("Inference complete. Output saved.")

In [39]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
test_model(
    model_path=MODEL_SAVE_PATH,
    test_set=test_samples,
    sample_idx=
    # 74,
    random.randint(0, 1000),
    threshold=0.9,
    device=device
)


Inference complete. Output saved.


# Test Function w/ Metrics

In [87]:
# --- Updated test_model_metrics to compute CD, IoU, F1, HD ---
from torchviz import make_dot
import torch
from torchinfo import summary
import numpy as np
from scipy.spatial.distance import cdist, directed_hausdorff

def test_model_metrics(model_path, test_set, sample_idx=0, threshold=0.5, device=None):
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = VoxelCompletionTransformer(
        d_model=96,        
        num_heads=6,       
        num_layers=6,
        max_grid_size=32,      
        window_size=3,    
        dropout=0.1
    ).to(device)
    # checkpoint = torch.load(model_path, map_location=device)
    # model.load_state_dict(checkpoint['model_state_dict'])

    checkpoint = torch.load(model_path, map_location=device)
    state_dict = checkpoint['model_state_dict']
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        new_key = k.replace('module.', '') if k.startswith('module.') else k
        new_state_dict[new_key] = v

    model.load_state_dict(new_state_dict)

    model.eval()
    complete, partial = test_set[sample_idx]

    filled_complete = (complete > 0).sum()
    filled_partial = ((complete > 0) & (partial > 0)).sum()
    missing_percent = (filled_complete - filled_partial) / filled_complete
    print("Missing Percentage: ", missing_percent)

    ptl = partial
    partial = partial.unsqueeze(0).unsqueeze(0).to(device)  # [1, 1, D, H, W]
    with torch.no_grad():
        output = model(partial)
        output = torch.sigmoid(output)
        # output[0, 0][ptl == 1] = 1.0
        output = output.squeeze().cpu()

    for i in range(output.shape[0]):
        for j in range(output.shape[1]):
            for k in range(output.shape[2]):
                if output[i, j, k] > threshold:
                    output[i, j, k] = 1.0
                else:
                    output[i, j, k] = 0.0
    print("Inference complete.")
    print("Partial shape:", partial.shape)
    print("Output shape:", output.shape)

    # Binarize output
    pred = (output > threshold).float()
    gt = (complete > 0).float()

    # IoU
    intersection = ((pred == 1) & (gt == 1)).sum().item()
    union = ((pred == 1) | (gt == 1)).sum().item()
    iou = intersection / (union + 1e-8)
    print(f"IoU: {iou:.4f}")

    # F1 Score
    tp = ((pred == 1) & (gt == 1)).sum().item()
    fp = ((pred == 1) & (gt == 0)).sum().item()
    fn = ((pred == 0) & (gt == 1)).sum().item()
    precision = tp / (tp + fp + 1e-8)
    recall = tp / (tp + fn + 1e-8)
    f1 = 2 * precision * recall / (precision + recall + 1e-8)
    print(f"F1 Score: {f1:.4f}")

    # Chamfer Distance (CD)
    def get_points(voxel):
        return np.argwhere(voxel.numpy() > 0.5)
    pred_points = get_points(pred)
    gt_points = get_points(gt)
    if len(pred_points) > 0 and len(gt_points) > 0:
        dist_pred_to_gt = cdist(pred_points, gt_points)
        dist_gt_to_pred = cdist(gt_points, pred_points)
        chamfer = np.mean(np.min(dist_pred_to_gt, axis=1)) + np.mean(np.min(dist_gt_to_pred, axis=1))
        print(f"Chamfer Distance: {chamfer:.4f}")
    else:
        print("Chamfer Distance: N/A (empty prediction or ground truth)")

    # Hausdorff Distance (HD, UHD)
    if len(pred_points) > 0 and len(gt_points) > 0:
        hd_pred_gt = directed_hausdorff(pred_points, gt_points)[0]
        hd_gt_pred = directed_hausdorff(gt_points, pred_points)[0]
        hausdorff = max(hd_pred_gt, hd_gt_pred)
        print(f"Hausdorff Distance: {hausdorff:.4f}")
        print(f"Unidirectional HD (pred→gt): {hd_pred_gt:.4f}, (gt→pred): {hd_gt_pred:.4f}")
    else:
        print("Hausdorff Distance: N/A (empty prediction or ground truth)")

    out_path = "output_voxel.npy"
    complete_path = "complete_voxel.npy"
    partial_path = "partial_voxel.npy"
    np.save(out_path, output.numpy())
    np.save(complete_path, complete)
    np.save(partial_path, ptl)
    print("Sample Index: ", sample_idx)
    print(f"Output saved to {out_path}")

In [210]:
test_model_metrics(
    model_path=MODEL_SAVE_PATH,
    test_set=test_samples,
    sample_idx=random.randint(0, len(test_samples) - 1),
    # sample_idx=293,
    threshold=0.7
)
# 25451
# 31878
# 41617

Missing Percentage:  tensor(0.1999)
Inference complete.
Partial shape: torch.Size([1, 1, 32, 32, 32])
Output shape: torch.Size([32, 32, 32])
IoU: 0.8934
F1 Score: 0.9437
Chamfer Distance: 0.1213
Hausdorff Distance: 2.4495
Unidirectional HD (pred→gt): 2.4495, (gt→pred): 1.0000
Sample Index:  765
Output saved to output_voxel.npy


# Attention Visualization

In [139]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.colors import LinearSegmentedColormap

def test_model_with_attention(model_path, test_set, sample_idx=0, threshold=0.5, device=None):
    """Modified inference function that captures attention weights"""
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    model = VoxelCompletionTransformer(
        d_model=48,
        num_heads=6,
        num_layers=6,
        window_size=3,
        dropout=0.1
    ).to(device)
    
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    
    # Hook to capture attention weights
    attention_weights = []
    
    def attention_hook(module, input, output):
        # Capture attention weights from each attention layer
        if hasattr(module, 'attn_weights'):
            attention_weights.append(module.attn_weights.detach().cpu())
    
    # Register hooks on attention layers
    for name, module in model.named_modules():
        if 'attention' in name.lower() or 'attn' in name.lower():
            module.register_forward_hook(attention_hook)
    
    complete, partial = test_set[sample_idx]
    filled_complete = (complete > 0).sum()
    filled_partial = ((complete > 0) & (partial > 0)).sum()
    missing_percent = (filled_complete - filled_partial) / filled_complete
    
    print("Missing Percentage: ", missing_percent)
    ptl = partial
    partial = partial.unsqueeze(0).unsqueeze(0).to(device)
    
    with torch.no_grad():
        output = model(partial)
    
    output = torch.sigmoid(output)
    output[0, 0][ptl == 1] = 1.0
    output = output.squeeze().cpu()
    
    # Process output
    for i in range(output.shape[0]):
        for j in range(output.shape[1]):
            for k in range(output.shape[2]):
                if output[i, j, k] > threshold:
                    output[i, j, k] = 1.0
                else:
                    output[i, j, k] = 0.0
    
    # Save outputs
    out_path = "output_voxel.npy"
    complete_path = "complete_voxel.npy"
    partial_path = "partial_voxel.npy"
    
    np.save(out_path, output.numpy())
    np.save(complete_path, complete)
    np.save(partial_path, ptl)
    
    print("Sample Index: ", sample_idx)
    print(f"Output saved to {out_path}")
    
    return output, attention_weights, partial.squeeze(), complete, ptl

def visualize_attention_weights(attention_weights, layer_idx=0, head_idx=0, save_path="attention_viz.png"):
    """Visualize attention weights for a specific layer and head"""
    if len(attention_weights) == 0:
        print("No attention weights captured. Make sure your model has attention layers with 'attention_weights' attribute.")
        return
    
    if layer_idx >= len(attention_weights):
        print(f"Layer {layer_idx} not found. Available layers: 0-{len(attention_weights)-1}")
        return
    
    attn = attention_weights[layer_idx]
    print(f"Attention weights shape: {attn.shape}")
    
    # Handle different attention weight shapes
    if len(attn.shape) == 6:  # [batch, heads, D, H, W, window_size]
        attn = attn[0, head_idx]  # Select first batch, specific head -> [D, H, W, window_size]
        # Reshape to 2D for visualization: [D*H*W, window_size]
        attn = attn.reshape(-1, attn.shape[-1])
    elif len(attn.shape) == 5:  # [heads, D, H, W, window_size]
        attn = attn[head_idx]  # Select specific head -> [D, H, W, window_size]
        attn = attn.reshape(-1, attn.shape[-1])
    elif len(attn.shape) == 4:  # [batch, heads, seq_len, seq_len]
        attn = attn[0, head_idx]  # Select first batch, specific head
    elif len(attn.shape) == 3:  # [heads, seq_len, seq_len]
        attn = attn[head_idx]  # Select specific head
    
    plt.figure(figsize=(12, 10))
    
    # Create heatmap
    plt.subplot(2, 2, 1)
    sns.heatmap(attn.numpy(), cmap='Blues', cbar=True, square=True)
    plt.title(f'Attention Weights - Layer {layer_idx}, Head {head_idx}')
    plt.xlabel('Key Position')
    plt.ylabel('Query Position')
    
    # Attention distribution histogram
    plt.subplot(2, 2, 2)
    plt.hist(attn.numpy().flatten(), bins=50, alpha=0.7, color='skyblue')
    plt.title('Attention Weight Distribution')
    plt.xlabel('Attention Weight')
    plt.ylabel('Frequency')
    
    # Max attention per query
    plt.subplot(2, 2, 3)
    max_attn = torch.max(attn, dim=1)[0]
    plt.plot(max_attn.numpy(), marker='o', linewidth=2, markersize=4)
    plt.title('Maximum Attention per Query Position')
    plt.xlabel('Query Position')
    plt.ylabel('Max Attention Weight')
    plt.grid(True, alpha=0.3)
    
    # Attention entropy (attention spread)
    plt.subplot(2, 2, 4)
    entropy = -torch.sum(attn * torch.log(attn + 1e-9), dim=1)
    plt.plot(entropy.numpy(), marker='s', linewidth=2, markersize=4, color='red')
    plt.title('Attention Entropy per Query Position')
    plt.xlabel('Query Position')
    plt.ylabel('Entropy')
    plt.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

def visualize_3d_attention(attention_weights, voxel_shape=(16, 16, 16), layer_idx=0, head_idx=0, 
                          query_pos=None, save_path="attention_3d.png"):
    """Visualize 3D attention patterns for voxel data"""
    if len(attention_weights) == 0:
        print("No attention weights captured.")
        return
    
    attn = attention_weights[layer_idx]
    print(f"Original attention shape: {attn.shape}")
    
    # Handle 6D attention weights [batch, heads, D, H, W, window_size]
    if len(attn.shape) == 6:
        attn = attn[0, head_idx]  # Select first batch and head -> [D, H, W, window_size]
    elif len(attn.shape) == 5:
        attn = attn[head_idx]  # Select head -> [D, H, W, window_size]
    elif len(attn.shape) == 4:
        attn = attn[0, head_idx]
    elif len(attn.shape) == 3:
        attn = attn[head_idx]
    
    print(f"After selection: {attn.shape}")
    
    # For windowed attention, we need to handle this differently
    if len(attn.shape) == 4:  # [D, H, W, window_size]
        D, H, W, window_size = attn.shape
        
        # Select a specific voxel position to visualize its attention
        if query_pos is None:
            query_d, query_h, query_w = D//2, H//2, W//2  # Center voxel
        else:
            query_d, query_h, query_w = query_pos
        
        # Get attention weights for this query position
        attention_values = attn[query_d, query_h, query_w].numpy()  # [window_size]
        
        # Create 3D visualization showing attention within the local window
        fig = plt.figure(figsize=(15, 5))
        
        # Calculate window positions relative to query
        window_radius = int(np.cbrt(window_size)) // 2  # Assuming cubic window
        positions = []
        
        idx = 0
        for dz in range(-window_radius, window_radius + 1):
            for dy in range(-window_radius, window_radius + 1):
                for dx in range(-window_radius, window_radius + 1):
                    if idx < window_size:
                        # Absolute positions
                        abs_d = max(0, min(D-1, query_d + dz))
                        abs_h = max(0, min(H-1, query_h + dy))
                        abs_w = max(0, min(W-1, query_w + dx))
                        positions.append([abs_w, abs_h, abs_d])  # x, y, z
                        idx += 1
        
        positions = np.array(positions)
        
        # 3D scatter plot
        ax1 = fig.add_subplot(131, projection='3d')
        scatter = ax1.scatter(positions[:, 0], positions[:, 1], positions[:, 2], 
                             c=attention_values, cmap='viridis', s=50, alpha=0.8)
        
        # Highlight query position
        query_3d = [query_w, query_h, query_d]
        ax1.scatter(query_3d[0], query_3d[1], query_3d[2], 
                   c='red', s=200, marker='*', edgecolors='black', linewidth=2)
        
        ax1.set_xlabel('X (W)')
        ax1.set_ylabel('Y (H)')
        ax1.set_zlabel('Z (D)')
        ax1.set_title(f'3D Attention Pattern\nQuery at ({query_w}, {query_h}, {query_d})')
        
        # Add colorbar
        cbar = plt.colorbar(scatter, ax=ax1, shrink=0.5)
        cbar.set_label('Attention Weight')
        
        # 2D projections
        ax2 = fig.add_subplot(132)
        scatter2 = ax2.scatter(positions[:, 0], positions[:, 1], c=attention_values, cmap='viridis', s=30)
        ax2.scatter(query_3d[0], query_3d[1], c='red', s=100, marker='*', edgecolors='black')
        ax2.set_xlabel('X (W)')
        ax2.set_ylabel('Y (H)')
        ax2.set_title('XY Projection')
        ax2.grid(True, alpha=0.3)
        
        ax3 = fig.add_subplot(133)
        scatter3 = ax3.scatter(positions[:, 0], positions[:, 2], c=attention_values, cmap='viridis', s=30)
        ax3.scatter(query_3d[0], query_3d[2], c='red', s=100, marker='*', edgecolors='black')
        ax3.set_xlabel('X (W)')
        ax3.set_ylabel('Z (D)')
        ax3.set_title('XZ Projection')
        ax3.grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()
        
        return query_d, query_h, query_w
    
    else:
        # Fallback for other shapes
        print(f"Unexpected attention shape: {attn.shape}")
        return None

def visualize_attention_across_layers(attention_weights, save_path="attention_layers.png"):
    """Compare attention patterns across different layers"""
    if len(attention_weights) == 0:
        print("No attention weights captured.")
        return
    
    num_layers = len(attention_weights)
    fig, axes = plt.subplots(2, (num_layers + 1) // 2, figsize=(4 * num_layers, 8))
    
    if num_layers == 1:
        axes = [axes]
    elif num_layers > 1:
        axes = axes.flatten()
    
    for i, attn in enumerate(attention_weights):
        # Handle 6D attention weights
        if len(attn.shape) == 6:  # [batch, heads, D, H, W, window_size]
            attn = attn[0, 0]  # First batch, first head -> [D, H, W, window_size]
            # Average over spatial dimensions to get [window_size] pattern
            attn = attn.mean(dim=(0, 1, 2))  # Average over D, H, W
        elif len(attn.shape) == 5:  # [heads, D, H, W, window_size]
            attn = attn[0]  # First head -> [D, H, W, window_size]
            attn = attn.mean(dim=(0, 1, 2))
        elif len(attn.shape) == 4:  # [batch, heads, seq_len, seq_len]
            attn = attn[0, 0]  # First batch, first head
        elif len(attn.shape) == 3:  # [heads, seq_len, seq_len]
            attn = attn[0]  # First head
        
        ax = axes[i] if num_layers > 1 else axes
        
        # If we have a 1D attention pattern (from averaging), show as bar plot
        if len(attn.shape) == 1:
            ax.bar(range(len(attn)), attn.numpy(), alpha=0.7, color='skyblue')
            ax.set_title(f'Layer {i} - Avg Attention Pattern')
            ax.set_xlabel('Window Position')
            ax.set_ylabel('Average Attention Weight')
        else:
            # Create heatmap for 2D attention
            im = ax.imshow(attn.numpy(), cmap='Blues', aspect='auto')
            ax.set_title(f'Layer {i}')
            ax.set_xlabel('Key Position')
            ax.set_ylabel('Query Position')
            
            # Add colorbar
            plt.colorbar(im, ax=ax)
    
    # Hide unused subplots
    for i in range(num_layers, len(axes)):
        axes[i].set_visible(False)
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

def analyze_windowed_attention(attention_weights, layer_idx=0, head_idx=0, save_path="windowed_attention.png"):
    """Analyze attention patterns in windowed (local) attention"""
    if len(attention_weights) == 0:
        print("No attention weights captured.")
        return
    
    attn = attention_weights[layer_idx]
    
    # Handle 6D attention weights [batch, heads, D, H, W, window_size]
    if len(attn.shape) == 6:
        attn = attn[0, head_idx]  # Select first batch and head -> [D, H, W, window_size]
    elif len(attn.shape) == 5:
        attn = attn[head_idx]  # Select head -> [D, H, W, window_size]
    
    print(f"Analyzing windowed attention with shape: {attn.shape}")
    
    D, H, W, window_size = attn.shape
    
    # Convert to numpy for easier manipulation
    attn_np = attn.numpy()
    
    # Create comprehensive visualization
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    
    # 1. Average attention pattern across all spatial positions
    avg_attention = np.mean(attn_np, axis=(0, 1, 2))  # Average over D, H, W
    axes[0, 0].bar(range(window_size), avg_attention, alpha=0.7, color='skyblue')
    axes[0, 0].set_title('Average Attention Pattern')
    axes[0, 0].set_xlabel('Window Position')
    axes[0, 0].set_ylabel('Average Attention Weight')
    axes[0, 0].grid(True, alpha=0.3)
    
    # 2. Attention variance across spatial positions
    var_attention = np.var(attn_np, axis=(0, 1, 2))
    axes[0, 1].bar(range(window_size), var_attention, alpha=0.7, color='lightcoral')
    axes[0, 1].set_title('Attention Variance')
    axes[0, 1].set_xlabel('Window Position')
    axes[0, 1].set_ylabel('Variance')
    axes[0, 1].grid(True, alpha=0.3)
    
    # 3. Attention entropy distribution
    # Calculate entropy for each spatial position
    entropy_map = np.zeros((D, H, W))
    for d in range(D):
        for h in range(H):
            for w in range(W):
                probs = attn_np[d, h, w]
                probs = probs + 1e-9  # Add small epsilon to avoid log(0)
                entropy_map[d, h, w] = -np.sum(probs * np.log(probs))
    
    # Show entropy for middle slice
    middle_slice = entropy_map[D//2, :, :]
    im1 = axes[0, 2].imshow(middle_slice, cmap='viridis', aspect='auto')
    axes[0, 2].set_title(f'Attention Entropy (Slice D={D//2})')
    axes[0, 2].set_xlabel('W')
    axes[0, 2].set_ylabel('H')
    plt.colorbar(im1, ax=axes[0, 2])
    
    # 4. Attention focus (max attention weight) across space
    max_attention = np.max(attn_np, axis=-1)  # Max over window_size
    middle_slice_max = max_attention[D//2, :, :]
    im2 = axes[1, 0].imshow(middle_slice_max, cmap='Blues', aspect='auto')
    axes[1, 0].set_title(f'Max Attention (Slice D={D//2})')
    axes[1, 0].set_xlabel('W')
    axes[1, 0].set_ylabel('H')
    plt.colorbar(im2, ax=axes[1, 0])
    
    # 5. Attention distribution histogram
    axes[1, 1].hist(attn_np.flatten(), bins=50, alpha=0.7, color='green', edgecolor='black')
    axes[1, 1].set_title('Attention Weight Distribution')
    axes[1, 1].set_xlabel('Attention Weight')
    axes[1, 1].set_ylabel('Frequency')
    axes[1, 1].grid(True, alpha=0.3)
    
    # 6. Window position preference
    # Reshape window to 3D coordinates assuming cubic window
    window_cube_size = int(np.cbrt(window_size))
    if window_cube_size ** 3 == window_size:
        window_3d = avg_attention.reshape(window_cube_size, window_cube_size, window_cube_size)
        # Show middle slice of the window
        middle_window_slice = window_3d[window_cube_size//2, :, :]
        im3 = axes[1, 2].imshow(middle_window_slice, cmap='Reds', aspect='auto')
        axes[1, 2].set_title(f'Window Attention Pattern (Middle Slice)')
        axes[1, 2].set_xlabel('Window W')
        axes[1, 2].set_ylabel('Window H')
        plt.colorbar(im3, ax=axes[1, 2])
    else:
        # If not cubic, show as 1D
        axes[1, 2].plot(avg_attention, marker='o', linewidth=2, markersize=4)
        axes[1, 2].set_title('Window Attention Pattern')
        axes[1, 2].set_xlabel('Window Position')
        axes[1, 2].set_ylabel('Attention Weight')
        axes[1, 2].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()
    
    # Print some statistics
    print(f"\nAttention Statistics:")
    print(f"Average attention range: {avg_attention.min():.4f} - {avg_attention.max():.4f}")
    print(f"Most attended window position: {np.argmax(avg_attention)}")
    print(f"Least attended window position: {np.argmin(avg_attention)}")
    print(f"Average entropy: {np.mean(entropy_map):.4f}")
    print(f"Entropy range: {entropy_map.min():.4f} - {entropy_map.max():.4f}")


output, attention_weights, partial, complete, ptl = test_model_with_attention(
    model_path=MODEL_SAVE_PATH,
    test_set=test_set,
    sample_idx=0
)

# Visualize attention weights
if attention_weights:
    analyze_windowed_attention(attention_weights, layer_idx=0, head_idx=0)
    visualize_attention_weights(attention_weights, layer_idx=0, head_idx=0)
    visualize_3d_attention(attention_weights, layer_idx=0, head_idx=0)
    visualize_attention_across_layers(attention_weights)
else:
    print("No attention weights captured. You may need to modify your model to store attention weights.")

NameError: name 'test_set' is not defined

# multi-step inference

In [74]:
def test_model_multi_step(
    model_path, 
    test_set, 
    sample_idx=0, 
    start_threshold=0.7, 
    threshold_step=0.05, 
    num_steps=3, 
    device=None
):
    """
    Run multi-step inference on a test sample, binarizing and feeding output as next partial.
    Args:
        model_path: Path to model checkpoint
        test_set: List of (complete, partial) tuples
        sample_idx: Index of test sample
        start_threshold: Initial threshold for binarization
        threshold_step: Amount to decrease threshold each step
        num_steps: Number of thresholds to try
        device: torch.device
    """
    import torch
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = VoxelCompletionTransformer(
        d_model=96,        
        num_heads=6,       
        num_layers=6,      
        window_size=3,    
        dropout=0.1
    ).to(device)
    checkpoint = torch.load(model_path, map_location=device)
    state_dict = checkpoint['model_state_dict']
    # Remove 'module.' prefix if present (from DataParallel)
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        new_key = k.replace('module.', '') if k.startswith('module.') else k
        new_state_dict[new_key] = v
    model.load_state_dict(new_state_dict)
    model.eval()
    complete, partial = test_set[sample_idx]
    ptl = partial
    partial = partial.unsqueeze(0).unsqueeze(0).to(device)  # [1, 1, D, H, W]
    thresholds = [start_threshold - i * threshold_step for i in range(num_steps)]
    with torch.no_grad():
        for i, threshold in enumerate(thresholds):
            output = model(partial)
            output = torch.sigmoid(output)
            output = output.squeeze().cpu()
            # Binarize output and use as next partial
            for x in range(output.shape[0]):
                for y in range(output.shape[1]):
                    for z in range(output.shape[2]):
                        if output[x, y, z] > threshold:
                            output[x, y, z] = 1.0
                        else:
                            output[x, y, z] = 0.0
            print(f"\nThreshold {threshold:.2f} (step {i+1}/{num_steps}):")
            print(f"  Predicted filled voxels: {int(output.sum().item())}")
            # Prepare for next step
            partial = output.unsqueeze(0).unsqueeze(0).to(device)
    out_path = "output_voxel.npy"
    complete_path = "complete_voxel.npy"
    partial_path = "partial_voxel.npy"
    np.save(out_path, output.numpy())
    np.save(complete_path, complete)
    np.save(partial_path, ptl)
    print("Sample Index: ", sample_idx)
    print(f"Output saved to {out_path}")

In [75]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
test_model_multi_step(
    model_path=MODEL_SAVE_PATH,
    test_set=test_samples,
    sample_idx=
    0,
    # random.randint(0, len(test_samples) - 1),
    start_threshold=0.7, 
    threshold_step=0.05, 
    num_steps=3, 
    device=device

)


Threshold 0.70 (step 1/3):
  Predicted filled voxels: 434

Threshold 0.65 (step 2/3):
  Predicted filled voxels: 1205

Threshold 0.60 (step 3/3):
  Predicted filled voxels: 2246
Sample Index:  0
Output saved to output_voxel.npy
