In [1]:
# extension to agent copy 4.ipynb F1-score, pos weight
from typing import Tuple, List
from dataclasses import dataclass
from tqdm import tqdm
from datetime import timedelta
import zipfile
import shutil
import tempfile
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import time
import os
import glob
from torch.utils.data import Subset, DataLoader

In [2]:
class VoxelDataLoader:
    """Loads and processes NPZ voxel data from a zip file"""

    def __init__(self, zip_path: str):
        # Create a temporary directory
        self.temp_dir = tempfile.mkdtemp()
        print(f"Created temporary directory: {self.temp_dir}")

        # Extract zip file
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(self.temp_dir)
        print(f"Extracted zip file to temporary directory")

        # Find all NPZ files
        all_files = glob.glob(os.path.join(self.temp_dir, "**/*.npz"), recursive=True)
        print(f"Found {len(all_files)} total NPZ files")

        if len(all_files) == 0:
            raise ValueError(f"No NPZ files found in zip file")

        random.shuffle(all_files)  # Shuffle before splitting
        cutoff = int(len(all_files))
        self.npz_files = all_files[:cutoff]
        print(f"Using {len(self.npz_files)}")

    def __del__(self):
        """Cleanup temporary directory when object is destroyed"""
        try:
            shutil.rmtree(self.temp_dir)
            print(f"Cleaned up temporary directory: {self.temp_dir}")
        except:
            print(f"Failed to clean up temporary directory: {self.temp_dir}")

    def load_single_file(self, file_path: str) -> Tuple[torch.Tensor, torch.Tensor]:
        data = np.load(file_path)

        # More robust key checking
        if 'complete' not in data or 'partial' not in data:
            raise ValueError(f"NPZ file {file_path} must contain both 'complete' and 'partial' arrays")

        complete = torch.from_numpy(data['complete']).float()
        partial = torch.from_numpy(data['partial']).float()

        # Verify shapes match
        if complete.shape != partial.shape:
            raise ValueError(f"Shape mismatch in {file_path}: complete {complete.shape} vs partial {partial.shape}")

        return complete, partial

    def get_all_data(self) -> List[Tuple[torch.Tensor, torch.Tensor]]:
        """Load all voxel pairs from all NPZ files"""
        all_data = []
        for file_path in self.npz_files:
            complete, partial = self.load_single_file(file_path)
            all_data.append((complete, partial))
        return all_data

    def get_voxel_grids(self, index: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
        """Returns complete and partial voxel grids from a specific file"""
        if index >= len(self.npz_files):
            raise IndexError(f"Index {index} out of range. Only {len(self.npz_files)} files available.")
        return self.load_single_file(self.npz_files[index])


class VoxelDataset(torch.utils.data.Dataset):
    """PyTorch Dataset for voxel completion"""

    def __init__(self, zip_path: str, transform=None):
        self.data_loader = VoxelDataLoader(zip_path)
        self.transform = transform

    def __len__(self):
        return len(self.data_loader.npz_files)

    def __getitem__(self, idx):
        complete, partial = self.data_loader.get_voxel_grids(idx)
        # Normalize to [0,1] if not already
        complete = (complete > 0).float()
        partial = (partial > 0).float()
        if self.transform:
            complete, partial = self.transform(complete, partial)
        return complete, partial


# Update data loader creation function
def create_data_loader(zip_path: str, batch_size: int = 1, shuffle: bool = True, num_workers: int = 0):
    """Create a PyTorch DataLoader for training"""
    dataset = VoxelDataset(zip_path)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        # pin_memory=True
    )


def split_dataset(dataset, train_ratio=0.8, val_ratio=0.2, seed=42):
    n = len(dataset)
    indices = list(range(n))
    random.Random(seed).shuffle(indices)

    # from dataset: 80% train 20% test 
    n_trainval = int(n * 0.8)
    n_test = n - n_trainval
    trainval_indices = indices[:n_trainval]
    test_indices = indices[n_trainval:]
    # from training data: 80% train 20% validation
    n_train = int(len(trainval_indices) * 0.8)
    train_indices = trainval_indices[:n_train]
    val_indices = trainval_indices[n_train:]

    return train_indices, val_indices, test_indices

def create_data_loaders(zip_path, batch_size=1, shuffle=True, num_workers=0, seed=42):
    dataset = VoxelDataset(zip_path)
    print(f"Dataset size: {len(dataset)}")
    train_idx, val_idx, test_idx = split_dataset(dataset, seed=seed)
    train_loader = DataLoader(Subset(dataset, train_idx), batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
    val_loader = DataLoader(Subset(dataset, val_idx), batch_size=batch_size, shuffle=False, num_workers=num_workers)
    test_loader = DataLoader(Subset(dataset, test_idx), batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, val_loader, test_loader

# ------------------------------
# Positional encoding (fixed for [B, D, H, W, d_model])
# ------------------------------
class PositionalEncoding3D(nn.Module):
    def __init__(self, d_model: int, max_grid_size: int = 32):
        super().__init__()
        self.d_model = d_model
        self.max_grid_size = max_grid_size
        # stored as (D, H, W, d_model)
        self.pos_embed = nn.Parameter(torch.zeros(max_grid_size, max_grid_size, max_grid_size, d_model))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, x):
        # x: [B, D, H, W, d_model]
        B, D, H, W, _ = x.shape
        pos = self.pos_embed[:D, :H, :W, :].unsqueeze(0)  # [1, D, H, W, d_model]
        return x + pos

    def get_encoding(self, D, H, W):
        return self.pos_embed[:D, :H, :W, :]

# ------------------------------
# Local attention (your class, unchanged)
# ------------------------------
class LocalAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int = 4, window_size: int = 3):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.window_size = window_size
        assert d_model % num_heads == 0
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.scale = self.head_dim ** -0.5

    def forward(self, target_embedding, neighbor_embeddings, mask):
        """
        Args:
            target_embedding: [B, d_model] - Embedding of the voxel to predict.
            neighbor_embeddings: [B, ws, ws, ws, d_model] - Embeddings of the neighborhood.
            mask: [B, ws, ws, ws] - Boolean mask (True for known, False for unknown).
        """
        B = target_embedding.shape[0]
        ws = self.window_size

        neighbor_flat = neighbor_embeddings.view(B, ws * ws * ws, self.d_model)
        mask_flat = mask.view(B, ws * ws * ws)  # [B, ws^3]

        # Query from target embedding: [B, 1, d_model]
        q = self.q_proj(target_embedding.unsqueeze(1))  # [B, 1, d_model]
        # Keys and values from neighbors: [B, ws^3, d_model]
        k = self.k_proj(neighbor_flat)
        v = self.v_proj(neighbor_flat)

        # Reshape for multi-head attention
        q = q.view(B, 1, self.num_heads, self.head_dim).transpose(1, 2)  # [B, num_heads, 1, head_dim]
        k = k.view(B, ws * ws * ws, self.num_heads, self.head_dim).transpose(1, 2)  # [B, num_heads, ws^3, head_dim]
        v = v.view(B, ws * ws * ws, self.num_heads, self.head_dim).transpose(1, 2)  # [B, num_heads, ws^3, head_dim]

        # Attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale  # [B, num_heads, 1, ws^3]

        # Apply mask (True means allowed)
        mask_expanded = mask_flat.unsqueeze(1).unsqueeze(2).expand(-1, self.num_heads, 1, -1)
        scores = scores.masked_fill(~mask_expanded, float('-inf'))

        attn_weights = F.softmax(scores, dim=-1)
        attn_weights = torch.where(torch.isnan(attn_weights), torch.zeros_like(attn_weights), attn_weights)

        out = torch.matmul(attn_weights, v)  # [B, num_heads, 1, head_dim]
        out = out.transpose(1, 2).contiguous().view(B, 1, self.d_model).squeeze(1)
        return self.out_proj(out)

# ------------------------------
# Voxel transformer layer that uses LocalAttention per voxel
# ------------------------------
class VoxelTransformerLayer3D(nn.Module):
    def __init__(self, d_model: int, num_heads: int = 8, window_size: int = 3, dropout: float = 0.1):
        super().__init__()
        self.d_model = d_model
        self.window_size = window_size
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.attention = LocalAttention(d_model, num_heads, window_size)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, d_model * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(d_model * 4, d_model),
            nn.Dropout(dropout)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, neighborhood_fn, mask_fn):
        """
        x: [B, D, H, W, d_model]
        neighborhood_fn: callable(grid, d, h, w, window_size) -> [B, ws, ws, ws, d_model]
        mask_fn: callable(D, H, W, d, h, w, window_size) -> [ws,ws,ws] or [1,ws,ws,ws] or [B,ws,ws,ws]
        """
        B, D, H, W, C = x.shape
        out = torch.zeros_like(x)

        for dd in range(D):
            for hh in range(H):
                for ww in range(W):
                    target = x[:, dd, hh, ww, :]  # [B, d_model]
                    neighbors = neighborhood_fn(x, dd, hh, ww, self.window_size)  # [B, ws, ws, ws, d_model]

                    # === robust mask handling ===
                    mask = mask_fn(D, H, W, dd, hh, ww, self.window_size)
                    # mask can be one of: [ws,ws,ws], [1,ws,ws,ws], [B,ws,ws,ws]
                    if mask.dim() == 3:
                        # [ws,ws,ws] -> [1,ws,ws,ws]
                        mask = mask.unsqueeze(0)
                    if mask.shape[0] == 1 and B > 1:
                        # [1,ws,ws,ws] -> [B,ws,ws,ws]
                        mask = mask.expand(B, -1, -1, -1).contiguous()
                    # now mask is guaranteed to be [B, ws, ws, ws]

                    # attention (per-voxel)
                    tgt_norm = self.norm1(target)
                    attn_out = self.attention(tgt_norm, neighbors, mask)
                    target = target + self.dropout(attn_out)

                    # ffn
                    tgt_norm = self.norm2(target)
                    ffn_out = self.ffn(tgt_norm)
                    target = target + ffn_out

                    out[:, dd, hh, ww, :] = target
        return out

# ------------------------------
# Stack layers
# ------------------------------
class VoxelTransformer3D(nn.Module):
    def __init__(self, num_layers: int, d_model: int, num_heads: int, window_size: int, dropout: float = 0.1):
        super().__init__()
        self.layers = nn.ModuleList([
            VoxelTransformerLayer3D(d_model, num_heads, window_size, dropout)
            for _ in range(num_layers)
        ])

    def forward(self, x, neighborhood_fn, mask_fn):
        """
        x: [B, D, H, W, d_model]
        """
        for layer in self.layers:
            x = layer(x, neighborhood_fn, mask_fn)
        return x

# ------------------------------
# Helper functions for voxel candidate selection
# ------------------------------
def get_voxel_candidates(complete_grid, partial_grid, max_voxels: int = 256):
    """
    Returns a balanced *sample* of candidate voxels to predict.
    
    Args:
        complete_grid: [B, 1, D, H, W]
        partial_grid: [B, 1, D, H, W]
        max_voxels: maximum number of candidate voxels per batch element.
    Returns:
        list of voxel coords: (b, d, h, w, label)
    """
    B, _, D, H, W = complete_grid.shape
    candidates = []

    for b in range(B):
        filled = ((complete_grid[b,0]==1) & (partial_grid[b,0]==0)).nonzero(as_tuple=False)
        empty  = ((complete_grid[b,0]==0) & (partial_grid[b,0]==0)).nonzero(as_tuple=False)

        # balance between filled and empty
        k = max_voxels // 2
        filled_k = min(len(filled), k)
        empty_k  = min(len(empty),  k)

        if filled_k > 0:
            filled_idx = torch.randperm(len(filled))[:filled_k]
            filled = filled[filled_idx]
        else:
            filled = []

        if empty_k > 0:
            empty_idx = torch.randperm(len(empty))[:empty_k]
            empty = empty[empty_idx]
        else:
            empty = []

        for f in filled:
            candidates.append((b, f[0].item(), f[1].item(), f[2].item(), 1))
        for e in empty:
            candidates.append((b, e[0].item(), e[1].item(), e[2].item(), 0))

    return candidates


def compute_density(grid, d, h, w, window_size):
    """Counts filled neighbors around voxel. grid: [D,H,W]"""
    D, H, W = grid.shape
    r = window_size // 2
    d0, d1 = max(0, d-r), min(D, d+r+1)
    h0, h1 = max(0, h-r), min(H, h+r+1)
    w0, w1 = max(0, w-r), min(W, w+r+1)
    patch = grid[d0:d1, h0:h1, w0:w1]
    return float(patch.sum().item())

def sort_voxels(candidates, complete_grid, window_size):
    """Sorts voxel list by density then by distance to origin."""
    sorted_list = []
    for (b,d,h,w,label) in candidates:
        density = compute_density(complete_grid[b,0], d,h,w, window_size)
        dist = d + h + w  # Manhattan distance to origin
        sorted_list.append(((b,d,h,w,label), density, dist))
    sorted_list.sort(key=lambda x: (-x[1], x[2]))  # high density first, then close to origin
    return [item[0] for item in sorted_list]

# ------------------------------
# neighborhood_raw: returns raw patch (channel-first) and known_mask
# ------------------------------
def neighborhood_raw(grid, b, d, h, w, window_size):
    """
    grid: [B, 1, D, H, W] (binary partial grid with 1=known occupied and 0=unknown/empty)
    returns:
        patch: [1, 1, ws, ws, ws]  (channel-first, ready for Conv3d)
        known_mask: [1, ws, ws, ws] (boolean True where neighbor is known)
    """
    # grid: [B, C, D, H, W] where C==1
    _, C, D, H, W = grid.shape
    assert C == 1
    r = window_size // 2
    d0, d1 = max(0, d - r), min(D, d + r + 1)
    h0, h1 = max(0, h - r), min(H, h + r + 1)
    w0, w1 = max(0, w - r), min(W, w + r + 1)

    patch = grid[b:b+1, :, d0:d1, h0:h1, w0:w1]  # [1,1,d_patch,h_patch,w_patch]

    # pads: (W_left,W_right,H_top,H_bottom,D_front,D_back)
    pad_d0 = max(0, r - d)
    pad_h0 = max(0, r - h)
    pad_w0 = max(0, r - w)
    pad_d1 = max(0, (d + r + 1) - D)
    pad_h1 = max(0, (h + r + 1) - H)
    pad_w1 = max(0, (w + r + 1) - W)
    pad = (pad_w0, pad_w1, pad_h0, pad_h1, pad_d0, pad_d1)
    if any(p > 0 for p in pad):
        patch = F.pad(patch, pad)  # pad with zeros

    # patch now [1,1,ws,ws,ws]
    # known mask: True if this position was observed as known in partial grid
    known = (patch[:, 0] == 1).to(torch.bool)  # [1, ws, ws, ws]

    return patch.contiguous(), known.contiguous()

# ------------------------------
# IterativeVoxelModel: projection + positional encoding + small transformer on the patch
# ------------------------------
class IterativeVoxelModel(nn.Module):
    def __init__(self, d_model: int = 64, num_heads: int = 4, num_layers: int = 3,
                 window_size: int = 3, max_grid_size: int = 32, dropout: float = 0.1):
        """
        This model expects per-call patch input:
            neighbors_patch: [B, 1, ws, ws, ws]  (binary observed values or zeros)
            known_mask: [B, ws, ws, ws] (boolean)
        The model:
            - projects neighbors -> d_model via Conv3d(1 -> d_model)
            - permutes to [B, Dp, Hp, Wp, d_model], adds positional encoding
            - runs VoxelTransformer3D on the patch
            - extracts center voxel embedding and returns a scalar logit (per batch)
        """
        super().__init__()
        self.d_model = d_model
        self.window_size = window_size
        self.input_proj = nn.Conv3d(1, d_model, kernel_size=1)  # works on [B,1,ws,ws,ws]
        self.pos_encoding = PositionalEncoding3D(d_model, max_grid_size=max_grid_size)
        # use the same transformer as defined above, but it expects patch in [B, D, H, W, d_model]
        self.transformer = VoxelTransformer3D(num_layers=num_layers, d_model=d_model,
                                              num_heads=num_heads, window_size=window_size,
                                              dropout=dropout)
        # output head: from d_model to one logit
        self.output_head = nn.Linear(d_model, 1)

    def forward(self, neighbors_patch, known_mask):
        """
        neighbors_patch: [B, 1, ws, ws, ws]
        known_mask: [B, ws, ws, ws] boolean - True where neighbor is observed/known
        Returns:
            logits: [B, 1] (logit for center voxel occupancy)
        """
        B = neighbors_patch.shape[0]
        ws = self.window_size
        assert neighbors_patch.shape[2:] == (ws, ws, ws), f"neighbors size mismatch {neighbors_patch.shape}"
        # project
        emb = self.input_proj(neighbors_patch)  # [B, d_model, ws, ws, ws]
        # permute to [B, D, H, W, d_model]
        emb = emb.permute(0, 2, 3, 4, 1).contiguous()
        # add positional encoding
        emb = self.pos_encoding(emb)  # [B, ws, ws, ws, d_model]

        # define patch-local neighborhood_fn and mask_fn used by VoxelTransformerLayer3D
        def neighborhood_fn_patch(grid, dd, hh, ww, window_size):
            # grid: [B, Dp, Hp, Wp, d_model] where Dp=ws
            B2, Dp, Hp, Wp, C = grid.shape
            r = window_size // 2
            d0, d1 = max(0, dd - r), min(Dp, dd + r + 1)
            h0, h1 = max(0, hh - r), min(Hp, hh + r + 1)
            w0, w1 = max(0, ww - r), min(Wp, ww + r + 1)
            patch_local = grid[:, d0:d1, h0:h1, w0:w1, :]  # may be smaller than ws on boundaries
            # pad if necessary to shape [B, ws, ws, ws, d_model]
            pd0 = max(0, r - dd); pd1 = max(0, (dd + r + 1) - Dp)
            ph0 = max(0, r - hh); ph1 = max(0, (hh + r + 1) - Hp)
            pw0 = max(0, r - ww); pw1 = max(0, (ww + r + 1) - Wp)
            if any([pd0,pd1,ph0,ph1,pw0,pw1]):
                # permute to channel-first temporarily to use F.pad on last 3 dims
                # patch_local: [B, d_patch, h_patch, w_patch, C] -> permute to [B, C, d,h,w]
                tmp = patch_local.permute(0, 4, 1, 2, 3).contiguous()
                pad = (pw0, pw1, ph0, ph1, pd0, pd1)
                tmp = F.pad(tmp, pad)
                patch_local = tmp.permute(0, 2, 3, 4, 1).contiguous()
            # ensure final shape is [B, ws, ws, ws, d_model]
            return patch_local

        def mask_fn_patch(Dp, Hp, Wp, dd, hh, ww, window_size):
            # We'll use the known_mask passed to the outer forward to compute mask per patch position
            # known_mask: [B, ws, ws, ws] - but voxel transformer expects mask independent of batch or same across batch
            # We return per-batch masks in VoxelTransformerLayer3D; it will expand as needed.
            # compute boundaries within patch
            r = window_size // 2
            d0, d1 = max(0, dd - r), min(Dp, dd + r + 1)
            h0, h1 = max(0, hh - r), min(Hp, hh + r + 1)
            w0, w1 = max(0, ww - r), min(Wp, ww + r + 1)
            # Because we already have known_mask for the entire patch, we'll extract the appropriate patch of known_mask
            # BUT VoxelTransformerLayer3D will call mask_fn_patch(Dp,Hp,Wp,dd,hh,ww,ws) which doesn't give batch index.
            # To keep simple, return a mask that will be interpreted per-batch inside VoxelTransformerLayer3D.
            # We'll return the mask for a single (arbitrary) batch element shape [ws,ws,ws], but VoxelTransformerLayer3D
            # will expand to [B, ws, ws, ws].
            # Build a mask with all True (attend all) — the LocalAttention will still receive per-batch mask when it is used.
            # Here, deliver a mask that only marks valid (in-range) positions (True) — padding areas will be False.
            mask = torch.ones((window_size, window_size, window_size), dtype=torch.bool, device=emb.device)
            # pad areas relative to patch edges:
            pd0 = max(0, r - dd); pd1 = max(0, (dd + r + 1) - Dp)
            ph0 = max(0, r - hh); ph1 = max(0, (hh + r + 1) - Hp)
            pw0 = max(0, r - ww); pw1 = max(0, (ww + r + 1) - Wp)
            if pd0 > 0:
                mask[:pd0, :, :] = False
            if pd1 > 0:
                mask[-pd1:, :, :] = False
            if ph0 > 0:
                mask[:, :ph0, :] = False
            if ph1 > 0:
                mask[:, -ph1:, :] = False
            if pw0 > 0:
                mask[:, :, :pw0] = False
            if pw1 > 0:
                mask[:, :, -pw1:] = False
            return mask.unsqueeze(0)  # [1, ws, ws, ws]

        # Now run the transformer on emb patch
        # emb: [B, ws, ws, ws, d_model]
        # we must give neighborhood_fn_patch and mask_fn_patch to transformer
        out_patch = self.transformer(emb, neighborhood_fn_patch, mask_fn_patch)  # [B, ws, ws, ws, d_model]

        # extract center voxel index
        center = ws // 2
        center_emb = out_patch[:, center, center, center, :]  # [B, d_model]
        logits = self.output_head(center_emb)  # [B, 1]
        return logits


def neighborhood_fn(grid, b, d, h, w, window_size):
    """
    Extract a cubic neighborhood around voxel (d,h,w) from batch element b.
    grid: [B, C, D, H, W]
    Returns: [1, C, ws, ws, ws]
    """
    assert grid.dim() == 5, f"Expected 5D grid, got {grid.shape}"
    assert window_size % 2 == 1, "window_size must be odd."
    radius = window_size // 2

    # Pad grid symmetrically
    padded = F.pad(grid[b:b+1], (radius, radius, radius, radius, radius, radius), mode="constant", value=0)

    # shift indices because of padding
    d, h, w = d + radius, h + radius, w + radius

    # slice out the patch
    patch = padded[:, :, d-radius:d+radius+1, h-radius:h+radius+1, w-radius:w+radius+1]
    return patch  # [1, C, ws, ws, ws]


def mask_fn(grid_shape, d, h, w, window_size):
    """
    Create a binary mask for the cubic neighborhood around (d,h,w).
    grid_shape: (D,H,W)
    Returns: [1, ws, ws, ws]
    """
    D, H, W = grid_shape
    assert window_size % 2 == 1, "window_size must be odd."
    radius = window_size // 2

    mask = torch.ones((window_size, window_size, window_size), dtype=torch.float32)

    # Prevent model from "seeing" the target voxel itself
    mask[radius, radius, radius] = 0.0

    # Handle boundaries (mask out-of-range neighbors)
    for dd in range(window_size):
        for hh in range(window_size):
            for ww in range(window_size):
                gd, gh, gw = d + dd - radius, h + hh - radius, w + ww - radius
                if not (0 <= gd < D and 0 <= gh < H and 0 <= gw < W):
                    mask[dd, hh, ww] = 0.0

    return mask.unsqueeze(0)  # [1, ws, ws, ws]



In [3]:
MODEL_SAVE_PATH = "iterative_model.pth"

In [4]:
import numpy as np
import torch
import glob
import json 

# zip_path = "../../model_data/chunk_data_16_flood_fill_rm_20.zip"
zip_path = "/home/raedfidawi/Documents/thesis/3DLLM/chunk_data_16_flood_fill_rm_20.zip"
dataset = VoxelDataset(zip_path)

test_dir = "../../test_data/test_data_rm_20_RES_32/"
test_indices_file = os.path.join(test_dir, "test_indices.json")

with open(test_indices_file, "r") as f:
    test_idx = json.load(f)

test_samples = []
for idx in range(1000):
    complete, partial = dataset[idx]
    test_samples.append((complete, partial))

print(f"Loaded {len(test_samples)} test samples from {test_dir}")



Created temporary directory: /tmp/tmpn8lpdtnz
Extracted zip file to temporary directory
Found 256571 total NPZ files
Using 256571
Loaded 1000 test samples from ../../test_data/test_data_rm_20_RES_32/


# Voxels given to Model: All empty voxels

In [56]:
import torch
import numpy as np
import random
from collections import deque

def iterative_inference(model_path, test_samples, device):
    """
    Perform iterative inference on a random partial voxel object from test_samples.
    Sorts empty voxels by density in 3x3x3 window (filled neighbors), breaking ties by distance to origin.
    """
    print(f"Using device: {device}")
    # Pick a random sample
    # sample_idx = random.randint(0, len(test_samples)-1)
    sample_idx = 641
    complete, partial = test_samples[sample_idx]

    print(f"Selected sample index: {sample_idx}")
    
    # Load model
    # model = IterativeVoxelModel(
    #     d_model=96,
    #     num_heads=6,
    #     num_layers=3,
    #     max_grid_size=32,
    #     window_size=3,
    #     dropout=0.1
    # ).to(device)

    model = IterativeVoxelModel(
        d_model=48,
        num_heads=4,
        num_layers=3,
        max_grid_size=16,
        window_size=5,
        dropout=0.1
    ).to(device)

    # Load checkpoint
    checkpoint = torch.load(model_path, map_location=device)
    state_dict = checkpoint['model_state_dict']

    # Handle DataParallel prefix
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        new_key = k.replace('module.', '') if k.startswith('module.') else k
        new_state_dict[new_key] = v
    model.load_state_dict(new_state_dict)
    model.eval()

    partial_tensor = partial.unsqueeze(0).unsqueeze(0).to(device)  # [1,1,D,H,W]
    known_mask = (partial_tensor > 0).float()  # initially known voxels

    D, H, W = partial.shape
    window = model.window_size
    half = window // 2

    # Get coordinates of empty voxels as (x,y,z)
    empty_voxels = torch.nonzero(partial_tensor[0,0] == 0, as_tuple=False)  # [N,3]

    # --- Sort empty voxels by density in 3x3x3 window, break ties by distance to origin ---
    def count_filled_neighbors_3x3x3(x, y, z, grid):
        count = 0
        for dx in [-1, 0, 1]:
            for dy in [-1, 0, 1]:
                for dz in [-1, 0, 1]:
                    nx, ny, nz = x + dx, y + dy, z + dz
                    if 0 <= nx < D and 0 <= ny < H and 0 <= nz < W:
                        if dx == 0 and dy == 0 and dz == 0:
                            continue
                        if grid[0,0,nx,ny,nz] > 0:
                            count += 1
        return count

    voxel_scores = []
    for v in empty_voxels:
        x, y, z = v.tolist()
        density = count_filled_neighbors_3x3x3(x, y, z, partial_tensor)
        dist = x**2 + y**2 + z**2  # Euclidean squared distance to origin
        voxel_scores.append((x, y, z, density, dist))

    # Sort: high density first, then close to origin
    voxel_scores.sort(key=lambda x: (-x[3], x[4]))
    sorted_empty_voxels = [(x, y, z) for x, y, z, _, _ in voxel_scores]
    print("Entering prediction...")
    with torch.no_grad():
        i = 0
        for x, y, z in sorted_empty_voxels:
            patch, mask_patch = neighborhood_raw(partial_tensor, 0, x, y, z, window)
            logits = model(patch.to(device), mask_patch.to(device))  # [1,1]
            prob = torch.sigmoid(logits).squeeze().item()
            pred_filled = 1.0 if prob > 0.9 else 0.0
            partial_tensor[0,0,x,y,z] = pred_filled
            known_mask[0,0,x,y,z] = 1.0
            print(f"Predicted {i}th Voxel")
            i += 1

    # Save outputs
    output_voxels = partial_tensor.squeeze().cpu().numpy()
    np.save("output_voxel.npy", output_voxels)
    np.save("partial_voxel.npy", partial.numpy())
    np.save("complete_voxel.npy", complete.numpy())

    print("Inference complete. Voxels saved to output_voxel.npy, partial_voxel.npy, complete_voxel.npy.")


In [57]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
iterative_inference(MODEL_SAVE_PATH, test_samples, device=device)

Using device: cuda
Selected sample index: 641
Entering prediction...
Predicted 0th Voxel
Predicted 1th Voxel
Predicted 2th Voxel
Predicted 3th Voxel
Predicted 4th Voxel
Predicted 5th Voxel
Predicted 6th Voxel
Predicted 7th Voxel
Predicted 8th Voxel
Predicted 9th Voxel
Predicted 10th Voxel
Predicted 11th Voxel
Predicted 12th Voxel
Predicted 13th Voxel
Predicted 14th Voxel
Predicted 15th Voxel
Predicted 16th Voxel
Predicted 17th Voxel
Predicted 18th Voxel
Predicted 19th Voxel
Predicted 20th Voxel
Predicted 21th Voxel
Predicted 22th Voxel
Predicted 23th Voxel
Predicted 24th Voxel
Predicted 25th Voxel
Predicted 26th Voxel
Predicted 27th Voxel
Predicted 28th Voxel
Predicted 29th Voxel
Predicted 30th Voxel
Predicted 31th Voxel
Predicted 32th Voxel
Predicted 33th Voxel
Predicted 34th Voxel
Predicted 35th Voxel
Predicted 36th Voxel
Predicted 37th Voxel
Predicted 38th Voxel
Predicted 39th Voxel
Predicted 40th Voxel
Predicted 41th Voxel
Predicted 42th Voxel
Predicted 43th Voxel
Predicted 44th Vo

# Voxels Given to Model: The missing Voxels

In [37]:
def inference_missing_voxels(model_path, test_samples, device):
    """
    Perform inference only on voxels that are missing from the complete object (i.e., (partial == 0) & (complete == 1)).
    Sorts these voxels by density in 3x3x3 window (filled neighbors), breaking ties by distance to origin.
    """
    print(f"Using device: {device}")
    sample_idx = random.randint(0, len(test_samples)-1)
    # sample_idx = 5
    complete, partial = test_samples[sample_idx]
    print(f"Selected sample index: {sample_idx}")


    # model = IterativeVoxelModel(
    #     d_model=96,
    #     num_heads=6,
    #     num_layers=3,
    #     max_grid_size=32,
    #     window_size=3,
    #     dropout=0.1
    # ).to(device)

    model = IterativeVoxelModel(
        d_model=48,
        num_heads=4,
        num_layers=3,
        max_grid_size=16,
        window_size=5,
        dropout=0.1
    ).to(device)


    checkpoint = torch.load(model_path, map_location=device)
    state_dict = checkpoint['model_state_dict']
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        new_key = k.replace('module.', '') if k.startswith('module.') else k
        new_state_dict[new_key] = v
    model.load_state_dict(new_state_dict)
    model.eval()
    partial_tensor = partial.unsqueeze(0).unsqueeze(0).to(device)  # [1,1,D,H,W]
    known_mask = (partial_tensor > 0).float()
    complete_tensor = complete.unsqueeze(0).unsqueeze(0).to(device)
    D, H, W = partial.shape
    window = model.window_size
    # Find missing voxels: present in complete, missing in partial
    missing_voxels = torch.nonzero((partial_tensor[0,0] == 0) & (complete_tensor[0,0] == 1), as_tuple=False)  # [N,3]
    # Sort missing voxels by density in 3x3x3 window, break ties by distance to origin
    def count_filled_neighbors_3x3x3(x, y, z, grid):
        count = 0
        for dx in [-1, 0, 1]:
            for dy in [-1, 0, 1]:
                for dz in [-1, 0, 1]:
                    nx, ny, nz = x + dx, y + dy, z + dz
                    if 0 <= nx < D and 0 <= ny < H and 0 <= nz < W:
                        if dx == 0 and dy == 0 and dz == 0:
                            continue
                        if grid[0,0,nx,ny,nz] > 0:
                            count += 1
        return count
    voxel_scores = []
    for v in missing_voxels:
        x, y, z = v.tolist()
        density = count_filled_neighbors_3x3x3(x, y, z, partial_tensor)
        dist = x**2 + y**2 + z**2
        voxel_scores.append((x, y, z, density, dist))
    voxel_scores.sort(key=lambda x: (-x[3], x[4]))
    sorted_missing_voxels = [(x, y, z) for x, y, z, _, _ in voxel_scores]
    print("Entering prediction for missing voxels...")
    with torch.no_grad():
        i = 0
        for x, y, z in sorted_missing_voxels:
            patch, mask_patch = neighborhood_raw(partial_tensor, 0, x, y, z, window)
            logits = model(patch.to(device), mask_patch.to(device))
            prob = torch.sigmoid(logits).squeeze().item()
            pred_filled = 1.0 if prob > 0.5 else 0.0
            partial_tensor[0,0,x,y,z] = pred_filled
            known_mask[0,0,x,y,z] = 1.0
            print(f"Predicted {i}th missing voxel")
            i += 1
    output_voxels = partial_tensor.squeeze().cpu().numpy()
    np.save("output_voxel.npy", output_voxels)
    np.save("partial_voxel.npy", partial.numpy())
    np.save("complete_voxel.npy", complete.numpy())
    print("Inference complete. Voxels saved to output_voxel.npy, partial_voxel.npy, complete_voxel.npy.")

In [38]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
inference_missing_voxels(MODEL_SAVE_PATH, test_samples, device=device)

Using device: cuda
Selected sample index: 373
Entering prediction for missing voxels...
Predicted 0th missing voxel
Predicted 1th missing voxel
Predicted 2th missing voxel
Predicted 3th missing voxel
Predicted 4th missing voxel
Inference complete. Voxels saved to output_voxel.npy, partial_voxel.npy, complete_voxel.npy.


# Voxels given to Model: 50% of the voxels are all truly missing and 50% are randomly empty and irrelevant voxels 

In [41]:
def inference_balanced_missing_empty(model_path, test_samples, device):
    """
    Perform inference on a balanced set of voxels:
    - Half are missing voxels (present in complete, missing in partial)
    - Half are truly empty voxels (empty in both complete and partial), randomly sampled to match the number of missing voxels
    Voxels are sorted by filled neighbor density (3x3x3 window), breaking ties by distance to origin.
    Additionally, save an npy file ('test.npy') where all truly missing voxels are set to filled (1).
    """
    print(f"Using device: {device}")
    sample_idx = random.randint(0, len(test_samples)-1)
    # sample_idx = 5
    complete, partial = test_samples[sample_idx]
    print(f"Selected sample index: {sample_idx}")

    # model = IterativeVoxelModel(
    #     d_model=96,
    #     num_heads=6,
    #     num_layers=3,
    #     max_grid_size=32,
    #     window_size=3,
    #     dropout=0.1
    # ).to(device)

    model = IterativeVoxelModel(
        d_model=48,
        num_heads=4,
        num_layers=3,
        max_grid_size=16,
        window_size=5,
        dropout=0.1
    ).to(device)

    checkpoint = torch.load(model_path, map_location=device)
    state_dict = checkpoint['model_state_dict']
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        new_key = k.replace('module.', '') if k.startswith('module.') else k
        new_state_dict[new_key] = v
    model.load_state_dict(new_state_dict)
    model.eval()
    partial_tensor = partial.unsqueeze(0).unsqueeze(0).to(device)  # [1,1,D,H,W]
    known_mask = (partial_tensor > 0).float()
    complete_tensor = complete.unsqueeze(0).unsqueeze(0).to(device)
    D, H, W = partial.shape
    window = model.window_size
    # Find missing voxels: present in complete, missing in partial
    missing_voxels = torch.nonzero((partial_tensor[0,0] == 0) & (complete_tensor[0,0] == 1), as_tuple=False)  # [N,3]
    num_missing = missing_voxels.shape[0]
    # Find truly empty voxels: empty in both complete and partial
    truly_empty_voxels = torch.nonzero((partial_tensor[0,0] == 0) & (complete_tensor[0,0] == 0), as_tuple=False)  # [M,3]
    # Randomly sample truly empty voxels to match number of missing voxels
    if num_missing > 0 and truly_empty_voxels.shape[0] > 0:
        FACTOR = 3
        perm = torch.randperm(truly_empty_voxels.shape[0])[:num_missing*FACTOR]
        sampled_empty_voxels = truly_empty_voxels[perm]
    else:
        sampled_empty_voxels = torch.empty((0,3), dtype=torch.long)
    # Combine both sets
    candidate_voxels = torch.cat([missing_voxels, sampled_empty_voxels], dim=0)
    # Sort candidates by density and distance to origin
    def count_filled_neighbors_3x3x3(x, y, z, grid):
        count = 0
        for dx in [-1, 0, 1]:
            for dy in [-1, 0, 1]:
                for dz in [-1, 0, 1]:
                    nx, ny, nz = x + dx, y + dy, z + dz
                    if 0 <= nx < D and 0 <= ny < H and 0 <= nz < W:
                        if dx == 0 and dy == 0 and dz == 0:
                            continue
                        if grid[0,0,nx,ny,nz] > 0:
                            count += 1
        return count
    voxel_scores = []
    for v in candidate_voxels:
        x, y, z = v.tolist()
        density = count_filled_neighbors_3x3x3(x, y, z, partial_tensor)
        dist = x**2 + y**2 + z**2
        voxel_scores.append((x, y, z, density, dist))
    voxel_scores.sort(key=lambda x: (-x[3], x[4]))
    sorted_voxels = [(x, y, z) for x, y, z, _, _ in voxel_scores]
    print(f"Entering prediction for {len(sorted_voxels)} balanced voxels...")
    with torch.no_grad():
        i = 0
        for x, y, z in sorted_voxels:
            patch, mask_patch = neighborhood_raw(partial_tensor, 0, x, y, z, window)
            logits = model(patch.to(device), mask_patch.to(device))
            prob = torch.sigmoid(logits).squeeze().item()
            pred_filled = 1.0 if prob > 0.9 else 0.0
            partial_tensor[0,0,x,y,z] = pred_filled
            known_mask[0,0,x,y,z] = 1.0
            print(f"Predicted {i}th balanced voxel")
            i += 1
    output_voxels = partial_tensor.squeeze().cpu().numpy()
    np.save("output_voxel.npy", output_voxels)
    np.save("partial_voxel.npy", partial.numpy())
    np.save("complete_voxel.npy", complete.numpy())
    print("Inference complete. Voxels saved to output_voxel.npy, partial_voxel.npy, complete_voxel.npy.")

    # --- Save test.npy with all truly missing voxels set to filled (1) ---
    test_voxel = partial.clone()
    for v in sampled_empty_voxels:
        x, y, z = v.tolist()
        test_voxel[x, y, z] = 1.0
    np.save("test.npy", test_voxel.cpu().numpy())
    print("test.npy saved: all truly missing voxels set to filled.")


In [44]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
inference_balanced_missing_empty(MODEL_SAVE_PATH, test_samples, device=device)

Using device: cuda
Selected sample index: 823
Entering prediction for 52 balanced voxels...
Predicted 0th balanced voxel
Predicted 1th balanced voxel
Predicted 2th balanced voxel
Predicted 3th balanced voxel
Predicted 4th balanced voxel
Predicted 5th balanced voxel
Predicted 6th balanced voxel
Predicted 7th balanced voxel
Predicted 8th balanced voxel
Predicted 9th balanced voxel
Predicted 10th balanced voxel
Predicted 11th balanced voxel
Predicted 12th balanced voxel
Predicted 13th balanced voxel
Predicted 14th balanced voxel
Predicted 15th balanced voxel
Predicted 16th balanced voxel
Predicted 17th balanced voxel
Predicted 18th balanced voxel
Predicted 19th balanced voxel
Predicted 20th balanced voxel
Predicted 21th balanced voxel
Predicted 22th balanced voxel
Predicted 23th balanced voxel
Predicted 24th balanced voxel
Predicted 25th balanced voxel
Predicted 26th balanced voxel
Predicted 27th balanced voxel
Predicted 28th balanced voxel
Predicted 29th balanced voxel
Predicted 30th bal

In [None]:
def inference_balanced_missing_empty_compact(model_path, test_samples, device):
    """
    Perform inference on a balanced set of voxels:
    - Half are missing voxels (present in complete, missing in partial)
    - Half are truly empty voxels (empty in both complete and partial), sampled to be close to the object (high filled neighbor density)
    Voxels are sorted by filled neighbor density (3x3x3 window), breaking ties by distance to origin.
    Additionally, save an npy file ('test.npy') where all truly missing voxels are set to filled (1).
    """
    print(f"Using device: {device}")
    sample_idx = random.randint(0, len(test_samples)-1)
    # sample_idx = 5
    complete, partial = test_samples[sample_idx]
    print(f"Selected sample index: {sample_idx}")

    # model = IterativeVoxelModel(
    #     d_model=96,
    #     num_heads=6,
    #     num_layers=3,
    #     max_grid_size=32,
    #     window_size=3,
    #     dropout=0.1
    # ).to(device)

    model = IterativeVoxelModel(
        d_model=48,
        num_heads=4,
        num_layers=3,
        max_grid_size=16,
        window_size=5,
        dropout=0.1
    ).to(device)

    checkpoint = torch.load(model_path, map_location=device)
    state_dict = checkpoint['model_state_dict']
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        new_key = k.replace('module.', '') if k.startswith('module.') else k
        new_state_dict[new_key] = v
    model.load_state_dict(new_state_dict)
    model.eval()
    partial_tensor = partial.unsqueeze(0).unsqueeze(0).to(device)  # [1,1,D,H,W]
    known_mask = (partial_tensor > 0).float()
    complete_tensor = complete.unsqueeze(0).unsqueeze(0).to(device)
    D, H, W = partial.shape
    window = model.window_size
    # Find missing voxels: present in complete, missing in partial
    missing_voxels = torch.nonzero((partial_tensor[0,0] == 0) & (complete_tensor[0,0] == 1), as_tuple=False).to(device)  # [N,3]
    num_missing = missing_voxels.shape[0]
    # Find truly empty voxels: empty in both complete and partial
    truly_empty_voxels = torch.nonzero((partial_tensor[0,0] == 0) & (complete_tensor[0,0] == 0), as_tuple=False).to(device)  # [M,3]
    # --- Sample truly empty voxels to be close to the object (high filled neighbor density) ---
    def count_filled_neighbors_3x3x3(x, y, z, grid):
        count = 0
        for dx in [-1, 0, 1]:
            for dy in [-1, 0, 1]:
                for dz in [-1, 0, 1]:
                    nx, ny, nz = x + dx, y + dy, z + dz
                    if 0 <= nx < D and 0 <= ny < H and 0 <= nz < W:
                        if dx == 0 and dy == 0 and dz == 0:
                            continue
                        if grid[0,0,nx,ny,nz] > 0:
                            count += 1
        return count
    # Score truly empty voxels by density and distance
    empty_scores = []
    for v in truly_empty_voxels:
        x, y, z = v.tolist()
        density = count_filled_neighbors_3x3x3(x, y, z, partial_tensor)
        dist = x**2 + y**2 + z**2
        empty_scores.append((x, y, z, density, dist))
    # Sort by high density, then close to origin
    empty_scores.sort(key=lambda x: (-x[3], x[4]))
    # Take top num_missing*FACTOR voxels
    FACTOR = 8
    sampled_empty_voxels = torch.tensor([s[:3] for s in empty_scores[:num_missing*FACTOR]], dtype=torch.long, device=device) if num_missing > 0 and len(empty_scores) > 0 else torch.empty((0,3), dtype=torch.long, device=device)
    # Ensure both tensors are on the same device before concatenation
    missing_voxels = missing_voxels.to(device)
    sampled_empty_voxels = sampled_empty_voxels.to(device)
    candidate_voxels = torch.cat([missing_voxels, sampled_empty_voxels], dim=0)
    # Sort candidates by density and distance to origin
    voxel_scores = []
    for v in candidate_voxels:
        x, y, z = v.tolist()
        density = count_filled_neighbors_3x3x3(x, y, z, partial_tensor)
        dist = x**2 + y**2 + z**2
        voxel_scores.append((x, y, z, density, dist))
    voxel_scores.sort(key=lambda x: (-x[3], x[4]))
    sorted_voxels = [(x, y, z) for x, y, z, _, _ in voxel_scores]
    print(f"Entering prediction for {len(sorted_voxels)} balanced voxels...")
    with torch.no_grad():
        i = 0
        for x, y, z in sorted_voxels:
            patch, mask_patch = neighborhood_raw(partial_tensor, 0, x, y, z, window)
            logits = model(patch.to(device), mask_patch.to(device))
            prob = torch.sigmoid(logits).squeeze().item()
            pred_filled = 1.0 if prob > 0.9 else 0.0
            partial_tensor[0,0,x,y,z] = pred_filled
            known_mask[0,0,x,y,z] = 1.0
            print(f"Predicted {i}th balanced voxel")
            i += 1
    output_voxels = partial_tensor.squeeze().cpu().numpy()
    np.save("output_voxel.npy", output_voxels)
    np.save("partial_voxel.npy", partial.numpy())
    np.save("complete_voxel.npy", complete.numpy())
    print("Inference complete. Voxels saved to output_voxel.npy, partial_voxel.npy, complete_voxel.npy.")

    # --- Save test.npy with all sampled truly empty voxels set to filled (1) ---
    test_voxel = partial.clone()
    for v in sampled_empty_voxels:
        x, y, z = v.tolist()
        test_voxel[x, y, z] = 1.0
    np.save("test.npy", test_voxel.cpu().numpy())
    print("test.npy saved: all sampled truly empty voxels set to filled.")


In [90]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
inference_balanced_missing_empty_compact(MODEL_SAVE_PATH, test_samples, device=device)

Using device: cuda
Selected sample index: 986
Entering prediction for 935 balanced voxels...
Predicted 0th balanced voxel
Predicted 1th balanced voxel
Predicted 2th balanced voxel
Predicted 3th balanced voxel
Predicted 4th balanced voxel
Predicted 5th balanced voxel
Predicted 6th balanced voxel
Predicted 7th balanced voxel
Predicted 8th balanced voxel
Predicted 9th balanced voxel
Predicted 10th balanced voxel
Predicted 11th balanced voxel
Predicted 12th balanced voxel
Predicted 13th balanced voxel
Predicted 14th balanced voxel
Predicted 15th balanced voxel
Predicted 16th balanced voxel
Predicted 17th balanced voxel
Predicted 18th balanced voxel
Predicted 19th balanced voxel
Predicted 20th balanced voxel
Predicted 21th balanced voxel
Predicted 22th balanced voxel
Predicted 23th balanced voxel
Predicted 24th balanced voxel
Predicted 25th balanced voxel
Predicted 26th balanced voxel
Predicted 27th balanced voxel
Predicted 28th balanced voxel
Predicted 29th balanced voxel
Predicted 30th ba

In [28]:
def inference_balanced_missing_empty_compact(model_path, test_samples, device):
    """
    Perform inference on a balanced set of voxels:
    - Half are missing voxels (present in complete, missing in partial)
    - Half are truly empty voxels (empty in both complete and partial), sampled to be close to the object (high filled neighbor density)
    Voxels are sorted by filled neighbor density using a generic cubic window of size `window_size` (must be odd),
    breaking ties by distance to origin.
    Additionally, save an npy file ('test.npy') where all sampled truly empty voxels are set to filled (1).
    """
    print(f"Using device: {device}")
    sample_idx = random.randint(0, len(test_samples)-1)
    sample_idx = 438
    complete, partial = test_samples[sample_idx]
    print(f"Selected sample index: {sample_idx}")

    # You can tweak these (kept small here)
    model = IterativeVoxelModel(
        d_model=48,
        num_heads=4,
        num_layers=3,
        max_grid_size=16,
        window_size=5,  # <-- generic window size; must be an odd integer
        dropout=0.1
    ).to(device)

    checkpoint = torch.load(model_path, map_location=device)
    state_dict = checkpoint['model_state_dict']
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        new_key = k.replace('module.', '') if k.startswith('module.') else k
        new_state_dict[new_key] = v
    model.load_state_dict(new_state_dict)
    model.eval()

    partial_tensor = partial.unsqueeze(0).unsqueeze(0).to(device)  # [1,1,D,H,W]
    known_mask = (partial_tensor > 0).float()
    complete_tensor = complete.unsqueeze(0).unsqueeze(0).to(device)

    D, H, W = partial.shape
    window = int(model.window_size)
    assert window % 2 == 1 and window >= 1, "window_size must be an odd positive integer"

    # ---- Helper: generic neighbor counter for any odd window size ----
    def count_filled_neighbors(x, y, z, grid, win):
        """
        Count filled neighbors in a win x win x win cube centered at (x,y,z), excluding the center voxel.
        Assumes grid shape [1,1,D,H,W] with >0 treated as filled.
        """
        r = win // 2
        count = 0
        # clamp bounds once to avoid repeated checks
        x0, x1 = max(0, x - r), min(D - 1, x + r)
        y0, y1 = max(0, y - r), min(H - 1, y + r)
        z0, z1 = max(0, z - r), min(W - 1, z + r)
        for nx in range(x0, x1 + 1):
            for ny in range(y0, y1 + 1):
                for nz in range(z0, z1 + 1):
                    if nx == x and ny == y and nz == z:
                        continue  # exclude center
                    # grid is float/byte; >0 => filled
                    if grid[0, 0, nx, ny, nz] > 0:
                        count += 1
        return count

    # Find missing voxels: present in complete, missing in partial
    missing_voxels = torch.nonzero(
        (partial_tensor[0, 0] == 0) & (complete_tensor[0, 0] == 1),
        as_tuple=False
    ).to(device)  # [N,3]
    num_missing = missing_voxels.shape[0]

    # Find truly empty voxels: empty in both complete and partial
    truly_empty_voxels = torch.nonzero(
        (partial_tensor[0, 0] == 0) & (complete_tensor[0, 0] == 0),
        as_tuple=False
    ).to(device)  # [M,3]

    # --- Score truly empty voxels by density (generic window) and distance ---
    empty_scores = []
    for v in truly_empty_voxels:
        x, y, z = v.tolist()
        density = count_filled_neighbors(x, y, z, partial_tensor, window)
        dist = x**2 + y**2 + z**2
        empty_scores.append((x, y, z, density, dist))

    # Sort by high density, then close to origin
    empty_scores.sort(key=lambda t: (-t[3], t[4]))

    # Take top num_missing * FACTOR voxels
    FACTOR = 5
    limit = num_missing * FACTOR if num_missing > 0 else 0
    sampled_empty_voxels = (
        torch.tensor([s[:3] for s in empty_scores[:limit]], dtype=torch.long, device=device)
        if limit > 0 and len(empty_scores) > 0 else
        torch.empty((0, 3), dtype=torch.long, device=device)
    )

    # Candidate voxels: union of missing + sampled empties
    candidate_voxels = (
        torch.cat([missing_voxels, sampled_empty_voxels], dim=0)
        if missing_voxels.numel() > 0 or sampled_empty_voxels.numel() > 0
        else torch.empty((0, 3), dtype=torch.long, device=device)
    )

    # Sort candidates by density (generic window) and distance to origin
    voxel_scores = []
    for v in candidate_voxels:
        x, y, z = v.tolist()
        density = count_filled_neighbors(x, y, z, partial_tensor, window)
        dist = x**2 + y**2 + z**2
        voxel_scores.append((x, y, z, density, dist))
    voxel_scores.sort(key=lambda t: (-t[3], t[4]))
    sorted_voxels = [(x, y, z) for x, y, z, _, _ in voxel_scores]

    print(f"Entering prediction for {len(sorted_voxels)} balanced voxels...")
    with torch.no_grad():
        for i, (x, y, z) in enumerate(sorted_voxels):
            patch, mask_patch = neighborhood_raw(partial_tensor, 0, x, y, z, window)
            logits = model(patch.to(device), mask_patch.to(device))
            prob = torch.sigmoid(logits).squeeze().item()
            pred_filled = 1.0 if prob > 0.5 else 0.0
            partial_tensor[0, 0, x, y, z] = pred_filled
            known_mask[0, 0, x, y, z] = 1.0
            print(f"Predicted {i}th balanced voxel")
            if i %50 == 0:
                np.save(f"output_voxel_{i}.npy", partial_tensor.squeeze().cpu().numpy())

    output_voxels = partial_tensor.squeeze().cpu().numpy()
    np.save("output_voxel.npy", output_voxels)
    np.save("partial_voxel.npy", partial.numpy())
    np.save("complete_voxel.npy", complete.numpy())
    print("Inference complete. Voxels saved to output_voxel.npy, partial_voxel.npy, complete_voxel.npy.")

    # --- Save test.npy with all sampled truly empty voxels set to filled (1) ---
    test_voxel = partial.clone()
    for v in sampled_empty_voxels:
        x, y, z = v.tolist()
        test_voxel[x, y, z] = 1.0
    np.save("test.npy", test_voxel.cpu().numpy())
    print("test.npy saved: all sampled truly empty voxels set to filled.")


In [29]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
inference_balanced_missing_empty_compact(MODEL_SAVE_PATH, test_samples, device=device)

Using device: cuda
Selected sample index: 438
Entering prediction for 216 balanced voxels...
Predicted 0th balanced voxel
Predicted 1th balanced voxel
Predicted 2th balanced voxel
Predicted 3th balanced voxel
Predicted 4th balanced voxel
Predicted 5th balanced voxel
Predicted 6th balanced voxel
Predicted 7th balanced voxel
Predicted 8th balanced voxel
Predicted 9th balanced voxel
Predicted 10th balanced voxel
Predicted 11th balanced voxel
Predicted 12th balanced voxel
Predicted 13th balanced voxel
Predicted 14th balanced voxel
Predicted 15th balanced voxel
Predicted 16th balanced voxel
Predicted 17th balanced voxel
Predicted 18th balanced voxel
Predicted 19th balanced voxel
Predicted 20th balanced voxel
Predicted 21th balanced voxel
Predicted 22th balanced voxel
Predicted 23th balanced voxel
Predicted 24th balanced voxel
Predicted 25th balanced voxel
Predicted 26th balanced voxel
Predicted 27th balanced voxel
Predicted 28th balanced voxel
Predicted 29th balanced voxel
Predicted 30th ba

In [16]:
import math

def iterative_inference(model_path, test_samples, device, no_change_patience=100, prob_threshold=0.9):
    """
    Perform iterative inference on a random partial voxel object from test_samples.
    Sorts empty voxels by density in a window (model.window_size), breaking ties by distance to origin.
    Stops early if there are `no_change_patience` consecutive voxels with no change.
    """
    print(f"Using device: {device}")
    sample_idx = random.randint(0, len(test_samples)-1)
    # sample_idx = 300
    complete, partial = test_samples[sample_idx]
    print(f"Selected sample index: {sample_idx}")

    model = IterativeVoxelModel(
        d_model=48,
        num_heads=4,
        num_layers=3,
        max_grid_size=16,
        window_size=5,   # uses generic neighbor window you added earlier
        dropout=0.1
    ).to(device)

    checkpoint = torch.load(model_path, map_location=device)
    state_dict = checkpoint['model_state_dict']
    from collections import OrderedDict
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        new_key = k.replace('module.', '') if k.startswith('module.') else k
        new_state_dict[new_key] = v
    model.load_state_dict(new_state_dict)
    model.eval()

    partial_tensor = partial.unsqueeze(0).unsqueeze(0).to(device)  # [1,1,D,H,W]
    known_mask = (partial_tensor > 0).float()

    D, H, W = partial.shape
    window = int(model.window_size)

    # --- neighbor density (generic window) ---
    def count_filled_neighbors(x, y, z, grid, win):
        r = win // 2
        x0, x1 = max(0, x - r), min(D - 1, x + r)
        y0, y1 = max(0, y - r), min(H - 1, y + r)
        z0, z1 = max(0, z - r), min(W - 1, z + r)
        cnt = 0
        for nx in range(x0, x1 + 1):
            for ny in range(y0, y1 + 1):
                for nz in range(z0, z1 + 1):
                    if nx == x and ny == y and nz == z:
                        continue
                    if grid[0, 0, nx, ny, nz] > 0:
                        cnt += 1
        return cnt

    empty_voxels = torch.nonzero(partial_tensor[0, 0] == 0, as_tuple=False)  # [N,3]
    voxel_scores = []
    for v in empty_voxels:
        x, y, z = v.tolist()
        density = count_filled_neighbors(x, y, z, partial_tensor, window)
        dist = x**2 + y**2 + z**2
        voxel_scores.append((x, y, z, density, dist))
    voxel_scores.sort(key=lambda t: (-t[3], t[4]))
    sorted_empty_voxels = [(x, y, z) for x, y, z, _, _ in voxel_scores]

    # --- early-stop state ---
    no_change_run = 0

    print("Entering prediction...")
    with torch.inference_mode():  # slightly faster than no_grad
        for i, (x, y, z) in enumerate(sorted_empty_voxels):
            old_val = partial_tensor[0, 0, x, y, z]  # tensor scalar on device

            patch, mask_patch = neighborhood_raw(partial_tensor, 0, x, y, z, window)
            logits = model(patch.to(device), mask_patch.to(device))
            prob = torch.sigmoid(logits).squeeze().item()
            pred_filled = 1.0 if prob > prob_threshold else 0.0
            partial_tensor[0, 0, x, y, z] = pred_filled
            changed = (pred_filled != old_val).item()
            known_mask[0, 0, x, y, z] = 1.0

            if changed:
                no_change_run = 0
            else:
                no_change_run += 1
                if no_change_run >= no_change_patience:
                    print(f"No change for {no_change_patience} consecutive voxels. Early stopping at i={i}.")
                    break

            if (i % 500) == 0:
                print(f"Predicted {i}/{len(sorted_empty_voxels)} (no-change run={no_change_run})")

    # Save outputs
    output_voxels = partial_tensor.squeeze().cpu().numpy()
    np.save("output_voxel.npy", output_voxels)
    np.save("partial_voxel.npy", partial.numpy())
    np.save("complete_voxel.npy", complete.numpy())
    print("Inference complete. Voxels saved to output_voxel.npy, partial_voxel.npy, complete_voxel.npy.")


In [17]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
iterative_inference(MODEL_SAVE_PATH, test_samples, device=device, no_change_patience=1000, prob_threshold=0.9)
# interesting 641
# 300

Using device: cuda
Selected sample index: 292
Entering prediction...
Predicted 0/3842 (no-change run=1)
Predicted 500/3842 (no-change run=501)
No change for 1000 consecutive voxels. Early stopping at i=999.
Inference complete. Voxels saved to output_voxel.npy, partial_voxel.npy, complete_voxel.npy.


In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
iterative_inference("iterative_model_ws_5.pth", test_samples, device=device, no_change_patience=1000, prob_threshold=0.9)

Using device: cuda
Selected sample index: 300
Entering prediction...
Predicted 0/3546 (no-change run=1)
Predicted 500/3546 (no-change run=501)
No change for 1000 consecutive voxels. Early stopping at i=999.
Inference complete. Voxels saved to output_voxel.npy, partial_voxel.npy, complete_voxel.npy.
