In [1]:
# extension to agent copy 4.ipynb F1-score, pos weight
from typing import Tuple, List
from dataclasses import dataclass
from tqdm import tqdm
from datetime import timedelta
import zipfile
import shutil
import tempfile
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
import time
import os
import glob
from torch.utils.data import Subset, DataLoader

In [2]:
class VoxelDataLoader:
    """Loads paired complete/partial voxel grids stored as NPZ files in a zip archive.

    The archive is extracted into a fresh temporary directory when the object
    is constructed; the directory is removed again when the object is
    destroyed.

    Attributes set by ``__init__``:
        temp_dir:  path of the temporary extraction directory.
        npz_files: list of paths to every ``.npz`` file found (recursively)
                   under ``temp_dir``, in shuffled order.
    """

    def __init__(self, zip_path: str, seed: int = 42):
        """Extract ``zip_path`` and index the NPZ files it contains.

        Args:
            zip_path: Path to the zip archive holding the ``.npz`` samples.
            seed: Seed for the file-order shuffle. The file list is sorted
                first and then shuffled with a private RNG, so the order (and
                therefore any index-based train/val/test split built on top of
                it) is the same in every session. Pass ``None`` to shuffle
                with the global ``random`` state instead.

        Raises:
            ValueError: If the archive contains no ``.npz`` file.
        """
        # Create a temporary directory
        self.temp_dir = tempfile.mkdtemp()
        print(f"Created temporary directory: {self.temp_dir}")

        # Extract zip file
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(self.temp_dir)
        print(f"Extracted zip file to temporary directory")

        # Find all NPZ files
        all_files = glob.glob(os.path.join(self.temp_dir, "**/*.npz"), recursive=True)
        print(f"Found {len(all_files)} total NPZ files")

        if len(all_files) == 0:
            raise ValueError(f"No NPZ files found in zip file")

        # glob order is filesystem dependent: sort first so the shuffle below
        # starts from the same sequence on every run.
        all_files.sort()
        if seed is None:
            random.shuffle(all_files)
        else:
            random.Random(seed).shuffle(all_files)
        self.npz_files = all_files
        print(f"Using {len(self.npz_files)}")

    def __del__(self):
        """Best-effort removal of the temporary directory on destruction."""
        # __init__ may have failed before temp_dir was assigned.
        temp_dir = getattr(self, "temp_dir", None)
        if temp_dir is None:
            return
        try:
            shutil.rmtree(temp_dir)
            print(f"Cleaned up temporary directory: {temp_dir}")
        except Exception:
            # Deliberately non-fatal: a destructor must not raise.
            print(f"Failed to clean up temporary directory: {temp_dir}")

    def load_single_file(self, file_path: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """Read one NPZ file and return its ``(complete, partial)`` arrays as float tensors.

        Raises:
            ValueError: If either array is missing or their shapes differ.
        """
        # The context manager closes the underlying archive handle; the arrays
        # read inside it are already materialised in memory.
        with np.load(file_path) as data:
            if 'complete' not in data or 'partial' not in data:
                raise ValueError(f"NPZ file {file_path} must contain both 'complete' and 'partial' arrays")

            complete = torch.from_numpy(data['complete']).float()
            partial = torch.from_numpy(data['partial']).float()

        # Verify shapes match
        if complete.shape != partial.shape:
            raise ValueError(f"Shape mismatch in {file_path}: complete {complete.shape} vs partial {partial.shape}")

        return complete, partial

    def get_all_data(self) -> List[Tuple[torch.Tensor, torch.Tensor]]:
        """Load every indexed NPZ file and return the list of ``(complete, partial)`` pairs."""
        return [self.load_single_file(file_path) for file_path in self.npz_files]

    def get_voxel_grids(self, index: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the ``(complete, partial)`` pair stored in the ``index``-th file.

        Raises:
            IndexError: If ``index`` is past the end of the file list.
        """
        if index >= len(self.npz_files):
            raise IndexError(f"Index {index} out of range. Only {len(self.npz_files)} files available.")
        return self.load_single_file(self.npz_files[index])


class VoxelDataset(torch.utils.data.Dataset):
    """Map-style dataset that serves binarized (complete, partial) voxel pairs.

    Samples are read lazily through a ``VoxelDataLoader`` built from
    ``zip_path``. An optional ``transform(complete, partial)`` callable is
    applied to each pair after binarization.
    """

    def __init__(self, zip_path: str, transform=None):
        self.data_loader = VoxelDataLoader(zip_path)
        self.transform = transform

    def __len__(self):
        # One sample per indexed NPZ file.
        indexed_files = self.data_loader.npz_files
        return len(indexed_files)

    def __getitem__(self, idx):
        raw_complete, raw_partial = self.data_loader.get_voxel_grids(idx)
        # Binarize: any strictly positive value becomes 1.0, everything else 0.0.
        complete, partial = ((grid > 0).float() for grid in (raw_complete, raw_partial))
        if not self.transform:
            return complete, partial
        complete, partial = self.transform(complete, partial)
        return complete, partial


# Update data loader creation function
def create_data_loader(zip_path: str, batch_size: int = 1, shuffle: bool = True, num_workers: int = 0):
    """Build a single DataLoader over every sample in the archive at ``zip_path``.

    The three keyword options are forwarded unchanged to
    ``torch.utils.data.DataLoader``.
    """
    loader_options = {
        "batch_size": batch_size,
        "shuffle": shuffle,
        "num_workers": num_workers,
    }
    full_dataset = VoxelDataset(zip_path)
    return torch.utils.data.DataLoader(full_dataset, **loader_options)


def split_dataset(dataset, train_ratio=0.8, val_ratio=0.2, seed=42):
    """Split ``range(len(dataset))`` into shuffled train / validation / test index lists.

    The split is nested:
      1. the shuffled indices are cut into a train+val part holding
         ``train_ratio`` of the data and a test part holding the rest;
      2. the train+val part is cut again, ``val_ratio`` of it becoming the
         validation set and the remainder the training set.

    With the defaults this is 80/20 followed by 80/20, i.e. roughly
    64% / 16% / 20% of the dataset.

    Args:
        dataset: Any object supporting ``len()``.
        train_ratio: Fraction of the whole dataset kept for train+val (0..1).
        val_ratio: Fraction of the train+val part used for validation (0..1).
        seed: Seed of the private RNG used to shuffle the indices.

    Returns:
        ``(train_indices, val_indices, test_indices)`` - three disjoint lists
        that together cover every index exactly once.

    Raises:
        ValueError: If either ratio lies outside ``[0, 1]``.
    """
    if not 0.0 <= train_ratio <= 1.0:
        raise ValueError(f"train_ratio must be in [0, 1], got {train_ratio}")
    if not 0.0 <= val_ratio <= 1.0:
        raise ValueError(f"val_ratio must be in [0, 1], got {val_ratio}")

    indices = list(range(len(dataset)))
    random.Random(seed).shuffle(indices)

    # Outer split: train+val vs. test.
    n_trainval = int(len(indices) * train_ratio)
    trainval_indices = indices[:n_trainval]
    test_indices = indices[n_trainval:]

    # Inner split: train vs. validation.
    n_train = int(n_trainval * (1.0 - val_ratio))
    train_indices = trainval_indices[:n_train]
    val_indices = trainval_indices[n_train:]

    return train_indices, val_indices, test_indices

def create_data_loaders(zip_path, batch_size=1, shuffle=True, num_workers=0, seed=42):
    """Build train / validation / test DataLoaders over one shared dataset.

    The dataset is created once from ``zip_path`` and partitioned with
    ``split_dataset(dataset, seed=seed)``. Only the training loader honours
    ``shuffle``; the validation and test loaders never shuffle.

    Returns:
        ``(train_loader, val_loader, test_loader)``.
    """
    dataset = VoxelDataset(zip_path)
    print(f"Dataset size: {len(dataset)}")
    index_groups = split_dataset(dataset, seed=seed)
    shuffle_flags = (shuffle, False, False)
    train_loader, val_loader, test_loader = [
        DataLoader(
            Subset(dataset, group),
            batch_size=batch_size,
            shuffle=flag,
            num_workers=num_workers,
        )
        for group, flag in zip(index_groups, shuffle_flags)
    ]
    return train_loader, val_loader, test_loader

class SpatialAttention3D(nn.Module):
    """Multi-head local attention over a cubic 3D neighbourhood.

    Every voxel attends to the ``window_size``^3 voxels centred on it. Q, K
    and V come from a single 1x1x1 convolution; keys and values are gathered
    per position with ``unfold`` after zero padding, so the tensor keeps its
    ``[B, C, D, H, W]`` layout on input and output. The window size can be
    overridden per forward call.

    NOTE(review): positions that fall in the zero padding are not masked out
    of the softmax - they contribute a score of 0 and therefore receive
    non-zero attention weight at the grid border. This is kept as-is so
    previously trained weights behave identically.
    """
    def __init__(self, d_model: int, num_heads: int = 4, window_size: int = 3):
        super().__init__()
        # Check divisibility before deriving head_dim from it.
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model = d_model
        self.num_heads = num_heads
        self.window_size = window_size
        self.head_dim = d_model // num_heads
        self.qkv = nn.Conv3d(d_model, d_model * 3, kernel_size=1)
        self.proj = nn.Conv3d(d_model, d_model, kernel_size=1)
        self.scale = self.head_dim ** -0.5

    def forward(self, x, window_size=None):
        """Apply windowed self-attention.

        Args:
            x: Tensor of shape ``[B, d_model, D, H, W]``.
            window_size: Optional odd, positive window edge length overriding
                the value given at construction.

        Returns:
            Tensor of the same shape as ``x``.

        Raises:
            ValueError: If the effective window size is not a positive odd
                integer (an even window has no centre voxel and cannot be
                padded symmetrically).
        """
        B, C, D, H, W = x.shape
        ws = window_size if window_size is not None else self.window_size
        if ws < 1 or ws % 2 == 0:
            raise ValueError(f"window_size must be a positive odd integer, got {ws}")
        pad = ws // 2
        heads, head_dim = self.num_heads, self.head_dim

        # Generate Q, K, V and split channels into heads.
        q, k, v = self.qkv(x).chunk(3, dim=1)  # each: [B, C, D, H, W]
        q = q.view(B, heads, head_dim, D, H, W)
        k = k.view(B, heads, head_dim, D, H, W)
        v = v.view(B, heads, head_dim, D, H, W)

        def extract_windows(tensor):
            # [B, heads, head_dim, D, H, W] -> [B, heads, head_dim, D, H, W, ws^3]
            padded = F.pad(tensor, [pad] * 6, mode='constant', value=0)
            windows = padded.unfold(3, ws, 1).unfold(4, ws, 1).unfold(5, ws, 1)
            return windows.reshape(B, heads, head_dim, D, H, W, ws * ws * ws)

        k_flat = extract_windows(k)
        v_flat = extract_windows(v)

        # The query for a position is the centre of its own window, i.e. the
        # un-padded q itself - no need to unfold q into ws^3 copies.
        q_center = q.permute(0, 1, 3, 4, 5, 2).unsqueeze(-1)  # [B, heads, D, H, W, head_dim, 1]
        k_flat = k_flat.permute(0, 1, 3, 4, 5, 2, 6)          # [B, heads, D, H, W, head_dim, ws^3]
        v_flat = v_flat.permute(0, 1, 3, 4, 5, 2, 6)

        # Scaled dot-product scores against every key in the window.
        attn_scores = (q_center * k_flat).sum(dim=-2) * self.scale  # [B, heads, D, H, W, ws^3]
        attn_weights = F.softmax(attn_scores, dim=-1)

        # Weighted sum of the window's values.
        attn_out = (attn_weights.unsqueeze(-2) * v_flat).sum(dim=-1)  # [B, heads, D, H, W, head_dim]

        # Merge heads back into the channel axis and project.
        attn_out = attn_out.permute(0, 1, 5, 2, 3, 4).contiguous()  # [B, heads, head_dim, D, H, W]
        attn_out = attn_out.view(B, C, D, H, W)
        return self.proj(attn_out)


class VoxelTransformerLayer3D(nn.Module):
    """Pre-norm transformer block operating on ``[B, d_model, D, H, W]`` tensors.

    The block is: GroupNorm -> windowed attention -> dropout -> residual add,
    followed by GroupNorm -> 1x1x1-conv feed-forward network -> residual add.
    The attention window size may be overridden on each forward call.
    """
    def __init__(self, d_model: int, num_heads: int = 8, window_size: int = 3, dropout: float = 0.1):
        super().__init__()
        self.d_model = d_model
        self.window_size = window_size

        # Single-group GroupNorm normalises over channels and space together.
        self.norm1 = nn.GroupNorm(1, d_model)
        self.norm2 = nn.GroupNorm(1, d_model)

        # Local self-attention.
        self.attention = SpatialAttention3D(d_model, num_heads, window_size)

        # Position-wise feed-forward network with a 4x hidden expansion.
        hidden_dim = d_model * 4
        ffn_stages = [
            nn.Conv3d(d_model, hidden_dim, kernel_size=1),
            nn.GELU(),
            nn.Dropout3d(dropout),
            nn.Conv3d(hidden_dim, d_model, kernel_size=1),
            nn.Dropout3d(dropout),
        ]
        self.ffn = nn.Sequential(*ffn_stages)

        # Dropout applied to the attention branch before the residual add.
        self.dropout = nn.Dropout3d(dropout)

    def forward(self, x, window_size=None):
        # Attention branch.
        attended = self.attention(self.norm1(x), window_size=window_size)
        x = x + self.dropout(attended)

        # Feed-forward branch (its dropout lives inside self.ffn).
        return x + self.ffn(self.norm2(x))


class PositionalEncoding3D(nn.Module):
    """Learned positional embedding with one ``d_model`` vector per voxel position.

    The embedding table has shape
    ``[1, d_model, max_grid_size, max_grid_size, max_grid_size]`` and is
    initialised from a truncated normal with std 0.02.
    """
    def __init__(self, d_model: int, max_grid_size: int = 16):
        super().__init__()
        self.d_model = d_model
        self.max_grid_size = max_grid_size
        # Learnable positional embedding for each voxel position
        self.pos_embed = nn.Parameter(
            torch.zeros(1, d_model, max_grid_size, max_grid_size, max_grid_size)
        )
        nn.init.trunc_normal_(self.pos_embed, std=0.02)

    def forward(self, x):
        """Return the embedding cropped to the spatial size of ``x``.

        Args:
            x: Tensor of shape ``[B, d_model, D, H, W]``; only its last three
                dimensions are read.

        Returns:
            Tensor of shape ``[1, d_model, D, H, W]`` (the caller adds it to
            ``x``, relying on broadcasting over the batch axis).

        Raises:
            ValueError: If any spatial dimension exceeds ``max_grid_size``.
                Slicing would otherwise silently return a smaller tensor.
        """
        _, _, D, H, W = x.shape
        if max(D, H, W) > self.max_grid_size:
            raise ValueError(
                f"Input grid {(D, H, W)} exceeds max_grid_size={self.max_grid_size}"
            )
        return self.pos_embed[:, :, :D, :H, :W]


class VoxelCompletionTransformer(nn.Module):
    """Single-resolution 3D transformer mapping a partial voxel grid to completion logits.

    Pipeline: 1x1x1 input convolution -> learned positional embedding ->
    ``num_layers`` windowed-attention transformer blocks -> GroupNorm ->
    1x1x1 output convolution producing one channel per voxel. No sigmoid is
    applied inside the model.
    """
    def __init__(self, d_model: int = 64, num_heads: int = 8, num_layers: int = 4,
                 max_grid_size: int = 16, window_size: int = 3, dropout: float = 0.1):
        super().__init__()
        self.d_model = d_model
        self.max_grid_size = max_grid_size
        self.num_layers = num_layers
        self.window_size = window_size

        # Lift the single occupancy channel to d_model features.
        self.input_proj = nn.Conv3d(1, d_model, kernel_size=1)
        # Learned per-position embedding.
        self.pos_encoding = PositionalEncoding3D(d_model, max_grid_size)
        # Stack of identical transformer blocks.
        blocks = []
        for _ in range(num_layers):
            blocks.append(VoxelTransformerLayer3D(d_model, num_heads, window_size, dropout))
        self.layers = nn.ModuleList(blocks)
        # Output head.
        self.output_norm = nn.GroupNorm(1, d_model)
        self.output_proj = nn.Conv3d(d_model, 1, kernel_size=1)
        # Re-initialise every Conv3d in the tree.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Kaiming-normal weights and zero biases for Conv3d modules; others untouched."""
        if not isinstance(m, nn.Conv3d):
            return
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    def forward(self, x, window_size=None):
        """Map ``x`` of shape ``[B, 1, D, H, W]`` to logits of the same shape.

        ``window_size`` overrides the constructor value for this call only.
        """
        ws = self.window_size if window_size is None else window_size
        features = self.input_proj(x)  # [B, d_model, D, H, W]
        features = features + self.pos_encoding(features)
        for block in self.layers:
            features = block(features, window_size=ws)
        return self.output_proj(self.output_norm(features))  # [B, 1, D, H, W]


def masked_bce_loss(preds, targets, partial_grid, criterion):
    """Average ``criterion`` over the voxels that are empty in ``partial_grid``.

    Predictions and targets are zeroed outside the unknown region before the
    criterion is evaluated, the element-wise result is masked again, and the
    sum is divided by the number of unknown voxels (plus 1e-6 so an all-known
    grid does not divide by zero).
    """
    unknown = partial_grid == 0
    weights = unknown.float()
    per_voxel = criterion(preds * unknown, targets * unknown)
    normaliser = weights.sum() + 1e-6
    return (per_voxel * weights).sum() / normaliser

def consistency_loss(preds, partial_grid):
    """MSE between ``preds`` and ``partial_grid`` restricted to observed voxels.

    Both tensors are zeroed wherever ``partial_grid != 1``; the mean is then
    taken over every element (``F.mse_loss`` default), not only the observed
    ones.

    NOTE(review): the training loop in this notebook appears to pass raw
    logits here, so the target value 1 is compared against an unbounded
    logit rather than a probability - confirm this is intended.
    """
    observed = partial_grid == 1
    kept_preds = preds * observed
    kept_input = partial_grid * observed
    return F.mse_loss(kept_preds, kept_input)

# levels = 16
def compute_pos_weight(dataset, sample_size=100):
    """Estimate the empty/occupied voxel ratio for use as BCE ``pos_weight``.

    Only the first ``min(sample_size, len(dataset))`` items are inspected and
    only the first element of each item (the complete grid) is counted. A
    voxel is occupied when its value exceeds 0.5.

    Returns:
        One-element tensor ``[empty / occupied]``, or ``[1.0]`` when no
        occupied voxel was seen.
    """
    n_inspected = min(sample_size, len(dataset))
    occupied_count, empty_count = 0, 0
    for sample_index in range(n_inspected):
        target, _ = dataset[sample_index]
        occupied_count += (target > 0.5).sum().item()
        empty_count += (target <= 0.5).sum().item()
    if occupied_count != 0:
        return torch.tensor([empty_count / occupied_count])
    return torch.tensor([1.0])


def _to_device_5d(grid, device):
    """Move a batch to ``device`` and insert a channel axis if it is 4-D ([B, D, H, W])."""
    grid = grid.to(device, non_blocking=True)
    if grid.dim() == 4:
        grid = grid.unsqueeze(1)
    return grid


def _completion_loss(model, complete_grid, partial_grid, criterion, window_size, lambda_consistency):
    """Forward ``partial_grid`` and return masked BCE + ``lambda_consistency`` * consistency loss."""
    preds = model(partial_grid, window_size=window_size)
    masked_loss = masked_bce_loss(preds, complete_grid, partial_grid, criterion)
    cons_loss = consistency_loss(preds, partial_grid)
    return masked_loss + lambda_consistency * cons_loss


def train_model(
    model: nn.Module,
    train_set,
    val_set,
    num_epochs: int = 50,
    batch_size: int = 1,
    window_size: int = 3,
    lambda_consistency: float = 1.0,
    seed: int = 42
):
    """Train ``model`` in place and report train/validation loss after every epoch.

    The loss is the masked BCE on unknown voxels plus ``lambda_consistency``
    times the consistency loss on known voxels. ``pos_weight`` for the BCE is
    estimated from ``train_set``. Adam (lr 1e-4) is used; mixed precision with
    a gradient scaler is enabled only when running on CUDA.

    Args:
        model: Network called as ``model(partial, window_size=...)``.
        train_set: Dataset yielding ``(complete, partial)`` pairs for training.
        val_set: Dataset yielding ``(complete, partial)`` pairs for validation.
        num_epochs: Number of passes over ``train_set``.
        batch_size: Batch size for both loaders.
        window_size: Attention window size forwarded to the model.
        lambda_consistency: Weight of the consistency term.
        seed: Seed applied to torch's global RNG before training so the
            shuffle order and dropout masks are repeatable.

    Returns:
        None. Progress is printed.
    """
    # Honour `seed`: the DataLoader shuffle and dropout both draw from
    # torch's global generator.
    torch.manual_seed(seed)
    torch.cuda.empty_cache()
    torch.backends.cuda.enable_mem_efficient_sdp(True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    model.to(device)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader = torch.utils.data.DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=0)
    print(f"Train loader size: {len(train_loader)}, Val loader size: {len(val_loader)}")
    # Compute pos_weight for BCEWithLogitsLoss
    pos_weight = compute_pos_weight(train_set)
    print(f"Using pos_weight for BCEWithLogitsLoss: {pos_weight.item():.2f}")
    criterion = nn.BCEWithLogitsLoss(reduction='none', pos_weight=pos_weight.to(device))
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    # Gradient scaling is only meaningful with CUDA autocast.
    scaler = torch.cuda.amp.GradScaler() if device.type == 'cuda' else None
    total_start_time = time.time()
    for epoch in range(num_epochs):
        epoch_start_time = time.time()
        total_loss = 0
        num_batches = 0            # denominator of the running mean loss
        num_samples_processed = 0  # actual number of samples seen this epoch
        model.train()
        epoch_pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}', leave=True, unit='sample')
        for complete_grid, partial_grid in epoch_pbar:
            complete_grid = _to_device_5d(complete_grid, device)
            partial_grid = _to_device_5d(partial_grid, device)
            optimizer.zero_grad()
            with torch.cuda.amp.autocast(enabled=(scaler is not None)):
                total_batch_loss = _completion_loss(
                    model, complete_grid, partial_grid, criterion, window_size, lambda_consistency
                )
            if scaler is not None:
                scaler.scale(total_batch_loss).backward()
                scaler.step(optimizer)
                scaler.update()
            else:
                total_batch_loss.backward()
                optimizer.step()
            total_loss += total_batch_loss.item()
            num_batches += 1
            num_samples_processed += partial_grid.size(0)
            del complete_grid, partial_grid, total_batch_loss
            torch.cuda.empty_cache()
            epoch_pbar.set_postfix({
                'train_loss': f'{total_loss/num_batches:.4f}',
                'samples': num_samples_processed,
                'lr': optimizer.param_groups[0]['lr']
            })
        avg_train_loss = total_loss / max(num_batches, 1)
        # Validation
        model.eval()
        val_loss = 0
        val_batches = 0
        with torch.no_grad():
            for complete_grid, partial_grid in val_loader:
                complete_grid = _to_device_5d(complete_grid, device)
                partial_grid = _to_device_5d(partial_grid, device)
                total_batch_loss = _completion_loss(
                    model, complete_grid, partial_grid, criterion, window_size, lambda_consistency
                )
                val_loss += total_batch_loss.item()
                val_batches += 1
                del complete_grid, partial_grid, total_batch_loss
                torch.cuda.empty_cache()
        avg_val_loss = val_loss / max(val_batches, 1)
        epoch_time = time.time() - epoch_start_time
        print(f"Epoch {epoch+1}/{num_epochs} - Time: {timedelta(seconds=int(epoch_time))}, Train loss: {avg_train_loss:.4f}, Val loss: {avg_val_loss:.4f}, Samples: {num_samples_processed}")
    total_time = time.time() - total_start_time
    print(f"\nTraining completed in {timedelta(seconds=int(total_time))}")
    # max(..., 1) keeps num_epochs=0 from dividing by zero.
    print(f"Average time per epoch: {timedelta(seconds=int(total_time/max(num_epochs, 1)))}")



In [3]:
# !pip install google-auth-oauthlib

# from google.colab import drive
# drive.mount('/content/drive')

# DRIVE_PATH = "/content/drive/MyDrive/AUB_masters/thesis/data/partial_data_16.zip"  # Adjust this path to match your Drive structure
# LOCAL_PATH = "/content/partial_data"
# !mkdir -p {LOCAL_PATH}

# print("Copying data from Drive to local storage...")
# !cp "{DRIVE_PATH}" "{LOCAL_PATH}/data.zip"
# zip_path = f"{LOCAL_PATH}/data.zip"

In [None]:
# Build the full dataset from the zipped NPZ archive (path is relative to the
# notebook's working directory).
zip_path = "../../chunk_data_16_flood_fill_rm_20.zip"
dataset = VoxelDataset(zip_path)

print(f"Total dataset size: {len(dataset)}")

# Index-based train/val/test partition; the three Subsets share `dataset`.
train_idx, val_idx, test_idx = split_dataset(dataset, seed=42)
train_set = torch.utils.data.Subset(dataset, train_idx)
val_set = torch.utils.data.Subset(dataset, val_idx)
test_set = torch.utils.data.Subset(dataset, test_idx)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Using device: {device}")

# Hyperparameters here must match the ones used by the test functions below
# when they rebuild the model to load a checkpoint.
model = VoxelCompletionTransformer(
    d_model=48,        
    num_heads=6,       
    num_layers=6,      
    window_size=3,    
    dropout=0.1
).to(device)

print("Starting training...")
torch.cuda.empty_cache()
# best model before this was num_epochs=2
# window_size is not passed, so train_model uses its default of 3.
train_model(model, train_set, val_set, num_epochs=5, batch_size=2, lambda_consistency=1)

Created temporary directory: /tmp/tmppmkvwylj
Extracted zip file to temporary directory
Found 256571 total NPZ files
Using 256571
Total dataset size: 256571


In [None]:
# Checkpoint location read by the test_model* functions below. The save step
# itself is commented out, so running this cell does not write a checkpoint;
# the file must already exist for the evaluation cells to work.
MODEL_SAVE_PATH = "../../models/trained_model_rm_20_flood_fill.pth"
# os.makedirs(os.path.dirname(MODEL_SAVE_PATH), exist_ok=True)
# torch.save({
#     'model_state_dict': model.state_dict(),
# }, MODEL_SAVE_PATH)
# print(f"Model saved to {MODEL_SAVE_PATH}")

In [94]:
# --- Updated test_model to use test_set ---
from torchviz import make_dot
import torch
from torchinfo import summary

def test_model(model_path, test_set, sample_idx=0, threshold=0.5, device=None):
    """Run single-pass completion on one test sample and save the grids as ``.npy`` files.

    A model with the fixed hyperparameters below is rebuilt, the checkpoint's
    ``'model_state_dict'`` is loaded, and the partial grid of
    ``test_set[sample_idx]`` is completed. Observed voxels (``partial == 1``)
    are forced to 1 and the result is binarized at ``threshold``.

    Side effects: writes ``output_voxel.npy``, ``complete_voxel.npy`` and
    ``partial_voxel.npy`` to the current working directory and prints
    progress information.

    Args:
        model_path: Path to a checkpoint containing ``'model_state_dict'``.
        test_set: Indexable dataset yielding ``(complete, partial)`` 3-D grids.
        sample_idx: Index of the sample to complete.
        threshold: Probability above which a voxel counts as occupied.
        device: Torch device; defaults to CUDA when available, else CPU.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = VoxelCompletionTransformer(
        d_model=48,
        num_heads=6,
        num_layers=6,
        window_size=3,
        dropout=0.1
    ).to(device)
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source (or pass weights_only=True where the installed torch
    # version supports it).
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    complete, partial = test_set[sample_idx]

    # Fraction of the complete shape's occupied voxels that the partial grid lacks.
    filled_complete = (complete > 0).sum()
    filled_partial = ((complete > 0) & (partial > 0)).sum()
    missing_percent = (filled_complete - filled_partial) / filled_complete
    print("Missing Percentage: ", missing_percent)

    ptl = partial
    partial = partial.unsqueeze(0).unsqueeze(0).to(device)  # [1, 1, D, H, W]
    with torch.no_grad():
        output = torch.sigmoid(model(partial))
        # Observed voxels are kept regardless of the prediction.
        output[0, 0][ptl == 1] = 1.0
        output = output.squeeze().cpu()
    print("Inference complete.")
    print("Partial shape:", partial.shape)
    print("Output shape:", output.shape)
    # Vectorized binarization (replaces a per-voxel Python loop).
    output = (output > threshold).float()
    out_path = "output_voxel.npy"
    complete_path = "complete_voxel.npy"
    partial_path = "partial_voxel.npy"
    np.save(out_path, output.numpy())
    np.save(complete_path, complete)
    np.save(partial_path, ptl)
    print("Sample Index: ", sample_idx)
    print(f"Output saved to {out_path}")

In [None]:
# Complete one randomly chosen test sample; the index is drawn from the
# unseeded global `random` module, so each run picks a different sample.
test_model(
    model_path=MODEL_SAVE_PATH,
    test_set=test_set,
    sample_idx=
    random.randint(0, len(test_set) - 1),
    # 46000,
    threshold=0.5,
    device=device
)
# random.randint(0, len(test_set) - 1)

In [5]:
# --- Updated test_model to use test_set ---
from torchviz import make_dot
import torch
from torchinfo import summary
import numpy as np

def test_model_acc(model_path, test_set, sample_idx=0, device=None):
    """Complete one test sample, report accuracy on its unknown voxels, and save the grids.

    The prediction is binarized at 0.5 after observed voxels
    (``partial == 1``) are forced to 1. Accuracy is the share of voxels with
    ``partial == 0`` whose binarized prediction equals the complete grid.

    Side effects: writes ``output_voxel.npy``, ``complete_voxel.npy`` and
    ``partial_voxel.npy`` to the current working directory and prints
    progress information.

    Args:
        model_path: Path to a checkpoint containing ``'model_state_dict'``.
        test_set: Indexable dataset yielding ``(complete, partial)`` 3-D grids.
        sample_idx: Index of the sample to evaluate.
        device: Torch device; defaults to CUDA when available, else CPU.

    Returns:
        The accuracy as a float, or ``None`` when the sample has no unknown
        voxel.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): window_size=5 here, while the training cell and
    # test_model build the model with window_size=3. The window size has no
    # parameters, so the checkpoint loads either way, but inference then uses
    # a different neighbourhood than training - confirm this is intended.
    model = VoxelCompletionTransformer(
        d_model=48,
        num_heads=6,
        num_layers=6,
        window_size=5,
        dropout=0.1
    ).to(device)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    complete, partial = test_set[sample_idx]
    ptl = partial
    partial = partial.unsqueeze(0).unsqueeze(0).to(device)  # [1, 1, D, H, W]
    with torch.no_grad():
        output = torch.sigmoid(model(partial))
        # Observed voxels are kept regardless of the prediction.
        output[0, 0][ptl == 1] = 1.0
        output = output.squeeze().cpu()
    print("Inference complete.")
    print("Partial shape:", partial.shape)
    print("Output shape:", output.shape)
    # Binarize output
    output_bin = (output > 0.5).float()

    # Only evaluate on unknown voxels (partial == 0)
    unknown_mask = (ptl == 0)
    total_unknown = unknown_mask.sum().item()
    if total_unknown == 0:
        print("No unknown voxels to evaluate accuracy.")
        accuracy = None
    else:
        correct = ((output_bin == complete) & unknown_mask).sum().item()
        accuracy = correct / total_unknown
        print(f"Accuracy on unknown voxels: {accuracy:.4f} ({correct}/{total_unknown})")

    out_path = "output_voxel.npy"
    complete_path = "complete_voxel.npy"
    partial_path = "partial_voxel.npy"
    # output_bin already holds the 0/1 grid; no per-voxel loop needed.
    np.save(out_path, output_bin.numpy())
    np.save(complete_path, complete)
    np.save(partial_path, ptl)
    print("Sample Index: ", sample_idx)
    print(f"Output saved to {out_path}")
    return accuracy

In [8]:
# Evaluate unknown-voxel accuracy on one randomly chosen test sample (the
# index comes from the unseeded global `random` module).
test_model_acc(
    model_path=MODEL_SAVE_PATH,
    test_set=test_set,
    sample_idx=random.randint(0, len(test_set) - 1),
    device=device
)
# random.randint(0, len(test_set) - 1)

Inference complete.
Partial shape: torch.Size([1, 1, 16, 16, 16])
Output shape: torch.Size([16, 16, 16])
Accuracy on unknown voxels: 0.9875 (3879/3928)
Sample Index:  18971
Output saved to output_voxel.npy


In [78]:
# --- Iterative refinement inference ---
from torchviz import make_dot
import torch
from torchinfo import summary
import numpy as np

def test_model_iterative_ref(model_path, test_set, sample_idx=0, device=None, n_steps=5, verbose=True):
    """Complete one test sample by feeding the model its own output ``n_steps`` times.

    Each step applies the model, converts logits to probabilities with a
    sigmoid, forces observed voxels (``partial == 1``) back to 1, and uses
    the resulting probabilities - not a binarized grid - as the next input.
    The final grid is binarized at 0.5, scored on the voxels with
    ``partial == 0``, and saved.

    Side effects: writes ``output_voxel.npy``, ``complete_voxel.npy`` and
    ``partial_voxel.npy`` to the current working directory and prints
    progress information.

    Args:
        model_path: Path to a checkpoint containing ``'model_state_dict'``.
        test_set: Indexable dataset yielding ``(complete, partial)`` 3-D grids.
        sample_idx: Index of the sample to refine.
        device: Torch device; defaults to CUDA when available, else CPU.
        n_steps: Number of refinement passes (0 scores the partial grid itself).
        verbose: Print a line after every refinement step.

    Returns:
        The accuracy as a float, or ``None`` when the sample has no unknown
        voxel.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): window_size=5 here, while the training cell builds the
    # model with window_size=3 - confirm the mismatch is intended.
    model = VoxelCompletionTransformer(
        d_model=48,
        num_heads=6,
        num_layers=6,
        window_size=5,
        dropout=0.1
    ).to(device)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    complete, partial = test_set[sample_idx]
    ptl = partial.clone()
    # Mask of observed voxels, moved to the device once instead of per step.
    known_mask = (ptl == 1).to(device)
    # Initial input is the partial observation.
    current_input = partial.clone().unsqueeze(0).unsqueeze(0).to(device)  # [1, 1, D, H, W]
    with torch.no_grad():
        for step in range(n_steps):
            output = torch.sigmoid(model(current_input))
            # Always keep known voxels fixed (partial==1)
            output[0, 0][known_mask] = 1.0
            # Soft refinement: probabilities stay on the device and become
            # the next input directly (no CPU round trip).
            current_input = output
            if verbose:
                print(f"Refinement step {step+1}/{n_steps} complete.")
    current = current_input[0, 0].cpu()
    # Final output: binarize for saving and evaluation
    output_bin = (current > 0.5).float()
    # Accuracy on unknown voxels
    unknown_mask = (ptl == 0)
    total_unknown = unknown_mask.sum().item()
    if total_unknown == 0:
        print("No unknown voxels to evaluate accuracy.")
        accuracy = None
    else:
        correct = ((output_bin == complete) & unknown_mask).sum().item()
        accuracy = correct / total_unknown
        print(f"Accuracy on unknown voxels after refinement: {accuracy:.4f} ({correct}/{total_unknown})")
    # Save results
    out_path = "output_voxel.npy"
    complete_path = "complete_voxel.npy"
    partial_path = "partial_voxel.npy"
    np.save(out_path, output_bin.numpy())
    np.save(complete_path, complete)
    np.save(partial_path, ptl)
    print(f"Refined output saved to {out_path}")
    print("Current Sample Index: ", sample_idx)
    return accuracy

In [85]:
# Two-pass iterative refinement on one randomly chosen test sample (the
# index comes from the unseeded global `random` module).
test_model_iterative_ref(
    model_path=MODEL_SAVE_PATH,
    test_set=test_set,
    sample_idx=random.randint(0, len(test_set) - 1),
    device=device,
    n_steps=2
)
# random.randint(0, len(test_set) - 1)

Refinement step 1/2 complete.
Refinement step 2/2 complete.
Accuracy on unknown voxels after refinement: 0.9333 (3626/3885)
Refined output saved to output_voxel.npy
Current Sample Index:  5729
