In [1]:
from typing import Tuple, List
from tqdm import tqdm
import zipfile
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Subset, DataLoader
from torch.amp import autocast, GradScaler
import os
import itertools
import math
from pathlib import Path
import shutil

In [2]:
# Mount Google Drive at /content/drive so later cells can read the zipped
# dataset from it. Requires the google.colab package (a Colab runtime).
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [3]:
def unzip_to_dir(zip_path: str, out_dir: str, overwrite: bool = False, only_ext=(".npz",)):
    """
    Safely extract files from a zip archive into ``out_dir``.

    - Protects against Zip Slip (path traversal / absolute member names): a member
      is skipped unless its fully resolved destination lies inside ``out_dir``.
    - Skips files that already exist unless ``overwrite=True``.
    - Optionally restricts extraction to certain extensions, matched
      case-insensitively (default: ``.npz``). ``only_ext`` may be a single string
      or any iterable of strings; a falsy value (``None``, ``()``) extracts every file.

    Args:
        zip_path: Path to the ``.zip`` archive.
        out_dir: Destination directory (created if missing).
        overwrite: Re-write members whose destination file already exists.
        only_ext: Extension filter, see above.

    Returns:
        extracted_count (int): number of files actually written.
    """
    zip_path = Path(zip_path)
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    out_root = out_dir.resolve()  # loop-invariant

    # str.endswith only accepts a str or a tuple, so normalise to a lowercase tuple.
    if not only_ext:
        exts = None
    elif isinstance(only_ext, str):
        exts = (only_ext.lower(),)
    else:
        exts = tuple(ext.lower() for ext in only_ext)

    extracted = 0
    with zipfile.ZipFile(zip_path, "r") as zf:
        for member in zf.infolist():
            # skip directories and non-matching extensions
            if member.is_dir():
                continue
            if exts and not member.filename.lower().endswith(exts):
                continue

            # Resolve the full destination and require it to be inside out_root.
            # A string-prefix test is NOT enough: "/data_evil" startswith "/data".
            dest = (out_dir / member.filename).resolve()
            try:
                dest.relative_to(out_root)
            except ValueError:
                # Zip Slip detected, skip
                continue

            dest.parent.mkdir(parents=True, exist_ok=True)
            if dest.exists() and not overwrite:
                continue

            with zf.open(member) as src, open(dest, "wb") as dst:
                shutil.copyfileobj(src, dst)
            extracted += 1

    return extracted

class VoxelFileLister:
    """
    Index of ``.npz`` voxel files below a directory, plus a loader for their
    'complete' occupancy arrays.

    Paths are stored as absolute strings sorted lexicographically, so an integer
    index always maps to the same file for a given directory tree.
    """
    def __init__(self, root_dir: str, recursive: bool = True):
        self.root_dir = os.path.abspath(root_dir)
        self.recursive = recursive
        self.npz_files = self._scan_npz(self.root_dir, recursive)
        if not self.npz_files:
            raise ValueError(f"No .npz files found under {self.root_dir}")
        self.npz_files.sort()  # deterministic index -> file mapping

    @staticmethod
    def _scan_npz(root: str, recursive: bool):
        """Return paths under ``root`` whose entry name ends in '.npz' (case-insensitive)."""
        def wanted(name):
            return name.lower().endswith(".npz")

        if not recursive:
            # Top level only; note os.listdir also yields sub-directory names.
            return [os.path.join(root, name) for name in os.listdir(root) if wanted(name)]
        return [
            os.path.join(folder, name)
            for folder, _subdirs, names in os.walk(root)
            for name in names
            if wanted(name)
        ]

    def __len__(self):
        return len(self.npz_files)

    def load_single_file(self, file_path: str) -> torch.Tensor:
        """
        Load and binarise the 'complete' array stored in one NPZ file.

        Returns:
            float32 tensor, [1, D, H, W] for a 3-D array (a channel axis is added)
            or [C, D, H, W] for a 4-D array, with every entry > 0 mapped to 1.

        Raises:
            ValueError: if the file lacks a 'complete' key or the array is not 3-D/4-D.
        """
        with np.load(file_path) as archive:
            if "complete" not in archive:
                raise ValueError(f"NPZ file {file_path} must contain 'complete' array")
            complete = torch.from_numpy(archive["complete"]).float()

        rank = complete.dim()
        if rank == 3:
            complete = complete.unsqueeze(0)
        elif rank != 4:
            raise ValueError(f"Unexpected 'complete' shape in {file_path}: {tuple(complete.shape)}")
        return (complete > 0).float()

    def get_voxel_grid(self, index: int) -> torch.Tensor:
        """Load the grid at ``index`` (position in the sorted file list)."""
        if not 0 <= index < len(self.npz_files):
            raise IndexError(f"Index {index} out of range (0..{len(self.npz_files)-1})")
        return self.load_single_file(self.npz_files[index])


class VoxelDataset(torch.utils.data.Dataset):
    """
    Map-style dataset over a directory of NPZ voxel files.

    Item ``i`` is the binarised 'complete' occupancy grid of the i-th file (sorted
    order), a float tensor [1, D, H, W], optionally passed through ``transform``.
    """
    def __init__(self, root_dir: str, transform=None, recursive: bool = True):
        self.files = VoxelFileLister(root_dir, recursive=recursive)
        self.transform = transform

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return grid ``idx``; ``transform`` is applied only when it is truthy."""
        grid = self.files.get_voxel_grid(idx)  # [1,D,H,W], float {0,1}
        return self.transform(grid) if self.transform else grid

def create_data_loaders_from_dir(
    root_dir: str,
    batch_size: int = 1,
    shuffle: bool = True,
    num_workers: int = 4,              # 0 = load in the main process
    seed: int = 42,
    recursive: bool = True,
    persistent_workers: bool = True,
    prefetch_factor: int = 2,
):
    """
    Build train/val/test DataLoaders from an UNZIPPED directory of NPZ files.

    The file indices are shuffled with ``random.Random(seed)`` and split 80/20
    into train+val vs. test, then train+val is split 80/20 into train vs. val
    (i.e. roughly 64% / 16% / 20%).

    Args:
        root_dir: Directory scanned for ``.npz`` files.
        batch_size: Batch size for all three loaders.
        shuffle: Shuffle the train loader each epoch (val/test never shuffle).
        num_workers: Worker processes per loader (negative values are treated
            as 0). With 0 workers, ``persistent_workers`` and ``prefetch_factor``
            are ignored because DataLoader rejects them without workers.
        seed: Seed for the deterministic split.
        recursive: Scan ``root_dir`` recursively.
        persistent_workers, prefetch_factor: Forwarded to DataLoader when
            ``num_workers > 0``.

    Returns:
        (train_loader, val_loader, test_loader)
    """
    dataset = VoxelDataset(root_dir, recursive=recursive)
    print(f"Dataset size: {len(dataset)}")

    n = len(dataset)
    indices = list(range(n))
    random.Random(seed).shuffle(indices)

    n_trainval = int(n * 0.8)
    trainval_indices = indices[:n_trainval]
    test_indices     = indices[n_trainval:]

    n_train = int(len(trainval_indices) * 0.8)
    train_indices = trainval_indices[:n_train]
    val_indices   = trainval_indices[n_train:]

    workers = max(0, num_workers)
    use_workers = workers > 0
    common_kwargs = dict(
        batch_size=batch_size,
        # The worker-only options below must be consistent with the worker
        # count actually passed, otherwise DataLoader raises ValueError.
        num_workers=workers,
        # Pinned memory only helps host->GPU copies; skip it on CPU-only runtimes.
        pin_memory=torch.cuda.is_available(),
        persistent_workers=persistent_workers if use_workers else False,
        prefetch_factor=prefetch_factor if use_workers else None,
    )

    train_loader = DataLoader(Subset(dataset, train_indices), shuffle=shuffle, **common_kwargs)
    val_loader   = DataLoader(Subset(dataset, val_indices),   shuffle=False,  **common_kwargs)
    test_loader  = DataLoader(Subset(dataset, test_indices),  shuffle=False,  **common_kwargs)
    return train_loader, val_loader, test_loader

class CachedVoxelSubset(torch.utils.data.Dataset):
    """
    In-memory snapshot of selected items of a VoxelDataset.

    Each ``base_dataset[i]`` for ``i`` in ``indices`` is materialised once at
    construction, so ``__getitem__`` is a plain list lookup with no file I/O.
    Item ``j`` of this subset is ``base_dataset[indices[j]]``.
    """
    def __init__(self, base_dataset: VoxelDataset, indices):
        # one base_dataset[i] call (hence one np.load) per requested index, up front
        self.data = [base_dataset[i] for i in indices]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

# ------------------------------
# Relative position bias: learned per-head scores added to attention logits
# ------------------------------
class RelPosBias(nn.Module):
    """
    Learned 3D relative attention bias for a fixed window size.
    - One bias per (dz,dy,dx) offset per head.
    - Shape during use: [1, num_heads, 1, ws^3] so it can be added to attention scores.
    """
    def __init__(self, num_heads: int, window_size: int):
        super().__init__()
        self.num_heads = num_heads
        self.window_size = window_size
        self.num_positions = window_size ** 3  # matches flattening order used in LocalAttention
        self.bias = nn.Parameter(torch.zeros(num_heads, self.num_positions))
        nn.init.trunc_normal_(self.bias, std=0.02)

    def forward(self):
        # return [1, h, 1, ws^3] broadcastable to [B, h, 1, ws^3]
        return self.bias.unsqueeze(0).unsqueeze(2)


# ------------------------------
# Local attention: the center token attends to its ws^3 neighborhood
# ------------------------------
class LocalAttention(nn.Module):
    """
    Multi-head attention from one query token to a ws x ws x ws neighborhood.

    Keys/values are the neighborhood tokens flattened row-major; positions whose
    mask value is <= 0.5 are excluded. A sample with no attendable position gets
    an all-zero output (zeroed after ``out_proj``, so its bias is not added).
    """
    def __init__(self, d_model: int, num_heads: int = 4, window_size: int = 3):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        assert d_model % num_heads == 0
        self.window_size = window_size
        # registered in q, k, v, out order (determines seeded initialisation)
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.scale = self.head_dim ** -0.5

    def forward(self, target_embedding, neighbor_embeddings, mask, rel_pos_bias=None):
        """
        Args:
            target_embedding:    [B, d_model] query token.
            neighbor_embeddings: [B, ws, ws, ws, d_model] key/value tokens.
            mask:                [B, ws, ws, ws] numeric; > 0.5 means "may attend".
            rel_pos_bias:        optional additive score bias broadcastable to
                                 [B, num_heads, 1, ws**3] (e.g. ``RelPosBias()()``).

        Returns:
            [B, d_model] projected attention output; exactly zero for samples whose
            mask has no attendable position.
        """
        batch = target_embedding.shape[0]
        n_keys = self.window_size ** 3
        heads, hdim = self.num_heads, self.head_dim

        def to_heads(t, length):
            # [B, length, d_model] -> [B, heads, length, head_dim]
            return t.view(batch, length, heads, hdim).transpose(1, 2)

        tokens = neighbor_embeddings.contiguous().view(batch, n_keys, self.d_model)
        allowed = (mask > 0.5).contiguous().view(batch, n_keys)              # [B, ws^3] bool

        query = to_heads(self.q_proj(target_embedding.unsqueeze(1)), 1)      # [B,h,1,d]
        key = to_heads(self.k_proj(tokens), n_keys)                          # [B,h,ws^3,d]
        value = to_heads(self.v_proj(tokens), n_keys)                        # [B,h,ws^3,d]

        scores = torch.matmul(query, key.transpose(-2, -1)) * self.scale     # [B,h,1,ws^3]
        if rel_pos_bias is not None:
            scores = scores + rel_pos_bias
        scores = scores.masked_fill(~allowed[:, None, None, :], float('-inf'))

        # A sample with nothing to attend to would take softmax over all -inf (NaN).
        # Neutralise its scores first, then zero its weights and its final output.
        starved = ~allowed.any(dim=1)                                        # [B]
        starved4 = starved[:, None, None, None]
        scores = torch.where(starved4, torch.zeros_like(scores), scores)
        weights = F.softmax(scores, dim=-1)
        weights = torch.where(starved4, torch.zeros_like(weights), weights)

        context = torch.matmul(weights, value)                               # [B,h,1,d]
        context = context.transpose(1, 2).reshape(batch, self.d_model)       # [B,d_model]
        projected = self.out_proj(context)
        return torch.where(starved[:, None], torch.zeros_like(projected), projected)

# ------------------------------
# Transformer layer that refines only the window's center token via LocalAttention
# ------------------------------
class TransformerLayer(nn.Module):
    """
    Pre-norm transformer layer that updates only the center token of a window.

    The center token (index window_size // 2 on each spatial axis) queries the
    neighborhood returned by ``neighborhood_fn`` through LocalAttention with a
    learned RelPosBias, then passes through a feed-forward network; both
    sub-blocks are residual.
    """
    def __init__(self, d_model: int, num_heads: int = 8, window_size: int = 3, dropout: float = 0.1):
        super().__init__()
        # creation order kept stable (affects seeded init)
        self.attn = LocalAttention(d_model, num_heads, window_size)
        self.rel_bias = RelPosBias(num_heads, window_size)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        hidden = d_model * 4
        self.ffn = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, d_model),
            nn.Dropout(dropout),
        )
        self.dropout = nn.Dropout(dropout)  # applied to the attention branch
        self.window_size = window_size
        self.d_model = d_model

    def forward(self, x, neighborhood_fn, mask_fn):
        """
        Args:
            x: [B, ws, ws, ws, d_model] window token embeddings.
            neighborhood_fn: (grid, dd, hh, ww, ws) -> [B, ws, ws, ws, d_model]
            mask_fn:         (Dp, Hp, Wp, dd, hh, ww, ws) -> [B, ws, ws, ws] (0/1 known-mask)

        Returns:
            [B, d_model] updated center embedding (other tokens are not updated).
        """
        _, Dp, Hp, Wp, _ = x.shape
        c = self.window_size // 2
        center = x[:, c, c, c, :]

        context = neighborhood_fn(x, c, c, c, self.window_size)   # [B, ws, ws, ws, d_model]
        attend = mask_fn(Dp, Hp, Wp, c, c, c, self.window_size)   # [B, ws, ws, ws]

        attended = self.attn(self.norm1(center), context, attend, rel_pos_bias=self.rel_bias())
        center = center + self.dropout(attended)
        return center + self.ffn(self.norm2(center))


# ------------------------------
# Stack layers
# ------------------------------
class Transformer(nn.Module):
    """
    Stack of TransformerLayer blocks that refine the window's center token.

    After each layer, the refined center embedding is written back into a copy
    of the window, so the next layer queries with (and attends over) the updated
    center. The caller's tensor is never modified in place.
    """
    def __init__(self, num_layers: int, d_model: int, num_heads: int, window_size: int, dropout: float = 0.1):
        super().__init__()
        self.layers = nn.ModuleList([
            TransformerLayer(d_model, num_heads, window_size, dropout)
            for _ in range(num_layers)
        ])

    def forward(self, x, neighborhood_fn, mask_fn):
        """
        Args:
            x: [B, ws, ws, ws, d_model] window embeddings.
            neighborhood_fn, mask_fn: forwarded unchanged to every layer.

        Returns:
            [B, d_model] refined center embedding. With zero layers this is the
            input's center token unchanged (instead of ``None``).
        """
        r = x.shape[1] // 2
        center = x[:, r, r, r, :]
        for layer in self.layers:
            center = layer(x, neighborhood_fn, mask_fn)
            # Write back at the exact index the layer read its target from.
            c = layer.window_size // 2
            x = x.clone()  # avoid in-place edits of the caller's / autograd's tensor
            x[:, c, c, c, :] = center
        return center

# ------------------------------
# neighborhood_raw: returns raw patch (channel-first) and known_mask
# ------------------------------
def neighborhood_raw(occ_grid, known_grid, b, d, h, w, window_size):
    """
    Cut the window centered at (d, h, w) of sample ``b`` out of an occupancy grid
    and its known-mask, zero-padding wherever the window leaves the volume.

    occ_grid:   [B,1,D,H,W] float {0,1}
    known_grid: [B,1,D,H,W] float {0,1}

    Returns:
        patch_occ:   [1,1,S,S,S] occupancy (0 in padded positions)
        patch_known: [1,  S,S,S] known flags (0 in padded positions = unknown)
        with S = 2*(window_size//2)+1, i.e. window_size when it is odd.
    """
    _, C, D, H, W = occ_grid.shape
    assert C == 1
    r = window_size // 2

    def bounds(c, n):
        return max(0, c - r), min(n, c + r + 1)

    (d0, d1), (h0, h1), (w0, w1) = bounds(d, D), bounds(h, H), bounds(w, W)
    region = (slice(b, b + 1), slice(None), slice(d0, d1), slice(h0, h1), slice(w0, w1))
    patch_occ = occ_grid[region]      # [1,1,dp,hp,wp]
    patch_known = known_grid[region]  # [1,1,dp,hp,wp]

    # F.pad takes (before, after) pairs starting from the LAST dim: W, H, D.
    pad = tuple(
        amount
        for c, n in ((w, W), (h, H), (d, D))
        for amount in (max(0, r - c), max(0, (c + r + 1) - n))
    )
    if any(p > 0 for p in pad):
        # Outside the volume: occupancy 0 and known 0 (nothing is known there).
        patch_occ = F.pad(patch_occ, pad, value=0.0)
        patch_known = F.pad(patch_known, pad, value=0.0)

    return patch_occ.contiguous(), patch_known[:, 0].contiguous()

class IterativeVoxelModel(nn.Module):
    """
    Predict the occupancy logit of a window's center voxel from its neighbors.

    Per call the model receives one window per sample:
        neighbors_patch: [B, 1, ws, ws, ws]  occupancy values (the center entry is
                                             overwritten with 0 inside forward)
        known_mask:      [B,    ws, ws, ws]  1 = known/visible, 0 = unknown/masked

    Pipeline:
        [occupancy, known_flag] --Conv3d(2 -> d_model, kernel 1)--> per-voxel tokens
        + learned absolute 3D position embedding [1, ws, ws, ws, d_model]
        --> Transformer (center token attends to neighbors where known_mask is 1)
        --> Linear(d_model -> 1) center logit.
    """
    def __init__(self, d_model: int = 64, num_heads: int = 4, num_layers: int = 3,
                 window_size: int = 3, max_grid_size: int = 16, dropout: float = 0.1):
        """
        Args:
            d_model: token width.
            num_heads, num_layers, dropout: Transformer hyper-parameters.
            window_size: side ws of the input window.
            max_grid_size: accepted for compatibility; not used by this class.
        """
        super().__init__()
        self.d_model = d_model
        self.window_size = window_size
        ws = window_size

        # Creation order (input_proj, transformer, pos_embed, output_head) is kept
        # fixed so seeded initialisation is reproducible.
        self.input_proj = nn.Conv3d(2, d_model, kernel_size=1)  # [occupancy, known_flag] -> d_model
        self.transformer = Transformer(
            num_layers=num_layers,
            d_model=d_model,
            num_heads=num_heads,
            window_size=window_size,
            dropout=dropout,
        )
        self.pos_embed = nn.Parameter(torch.zeros(1, ws, ws, ws, d_model))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        self.output_head = nn.Linear(d_model, 1)

    def forward(self, neighbors_patch, known_mask):
        """
        Args:
            neighbors_patch: [B, 1, ws, ws, ws] float {0,1}
            known_mask:      [B,    ws, ws, ws] float {0,1} (1 = known, 0 = unknown);
                             also used directly as the attention mask.

        Returns:
            [B, 1] logits for the center voxel.
        """
        ws = self.window_size
        mid = ws // 2

        # Copy so the caller's tensor is untouched, then blank the center so the
        # target value never enters the input.
        occupancy = neighbors_patch.clone()
        occupancy[:, :, mid, mid, mid] = 0.0

        stacked = torch.cat([occupancy, known_mask.unsqueeze(1)], dim=1)       # [B,2,ws,ws,ws]
        tokens = self.input_proj(stacked).permute(0, 2, 3, 4, 1).contiguous()  # [B,ws,ws,ws,d_model]
        tokens = tokens + self.pos_embed                                       # broadcast over B

        def gather_window(grid, cd, ch, cw, size):
            """[B, size, size, size, C] window of ``grid`` around (cd, ch, cw), zero-padded at borders."""
            _, Dp, Hp, Wp, _ = grid.shape
            half = size // 2
            centre = (cd, ch, cw)
            extent = (Dp, Hp, Wp)
            lo = [max(0, c - half) for c in centre]
            hi = [min(n, c + half + 1) for c, n in zip(centre, extent)]
            window = grid[:, lo[0]:hi[0], lo[1]:hi[1], lo[2]:hi[2], :]

            before = [max(0, half - c) for c in centre]
            after = [max(0, c + half + 1 - n) for c, n in zip(centre, extent)]
            if any(before) or any(after):
                # F.pad pads trailing dims first: (W, H, D) on a channel-first view.
                pad = (before[2], after[2], before[1], after[1], before[0], after[0])
                window = F.pad(window.permute(0, 4, 1, 2, 3), pad).permute(0, 2, 3, 4, 1).contiguous()
            return window

        def attend_mask(Dp, Hp, Wp, cd, ch, cw, size):
            # known_mask already holds 1 for attendable (known, in-bounds) voxels
            # and 0 for padded / masked ones, so it is the attention mask as-is.
            return known_mask

        center = self.transformer(tokens, gather_window, attend_mask)  # [B, d_model]
        return self.output_head(center)                                 # [B, 1]

def _extract_patch_with_pad_complete(complete_batch, b, d, h, w, ws):
    """
    complete_batch: [B,1,D,H,W] float {0,1}
    Returns:
      patch_occ:   [1,1,ws,ws,ws]
      base_mask:   [1,ws,ws,ws]  (1=in-bounds, 0=pad); center set to 0 later by caller
    """
    _, _, D, H, W = complete_batch.shape
    r = ws // 2

    d0, d1 = max(0, d-r), min(D, d+r+1)
    h0, h1 = max(0, h-r), min(H, h+r+1)
    w0, w1 = max(0, w-r), min(W, w+r+1)

    patch_occ = complete_batch[b:b+1, :, d0:d1, h0:h1, w0:w1]  # [1,1,dp,hp,wp]
    base_mask = torch.ones_like(patch_occ, dtype=patch_occ.dtype)  # [1,1,dp,hp,wp]

    # pad to ws on (W,H,D)
    pad = (
        max(0, r - (w - w0)), max(0, (w + r + 1) - w1),   # W_left, W_right
        max(0, r - (h - h0)), max(0, (h + r + 1) - h1),   # H_top, H_bottom
        max(0, r - (d - d0)), max(0, (d + r + 1) - d1),   # D_front, D_back
    )
    if any(pad):
        patch_occ  = F.pad(patch_occ,  pad, value=0.0)  # padded voxels are empty values
        base_mask  = F.pad(base_mask,  pad, value=0.0)  # padded positions cannot attend

    base_mask = base_mask[:, 0]  # [1,ws,ws,ws]
    return patch_occ.contiguous(), base_mask.contiguous()

def _valid_positions_from_base(base_mask):
    """Return a list of (z,y,x) valid positions (in-bounds, non-center)."""
    _, ws, _, _ = base_mask.shape
    r = ws // 2
    valid = (base_mask > 0.5).clone()  # [1,ws,ws,ws]
    valid[:, r, r, r] = False          # center never attends
    idxs = torch.nonzero(valid[0], as_tuple=False)  # [M,3]
    return idxs  # on same device

def _attend_mask_from_combo(base_mask, valid_idxs, combo_indices):
    """
    base_mask: [1,ws,ws,ws] with 1=in-bounds
    valid_idxs: [M,3] positions eligible to attend/mask
    combo_indices: iterable (tuple/list) or 1D LongTensor of indices into valid_idxs to MASK OUT
    Returns attend_mask: [1,ws,ws,ws] with 1=attend, 0=masked, center=0
    """
    attend = (base_mask > 0.5).float()  # start with in-bounds as 1

    # normalize indices -> 1D LongTensor
    if combo_indices is not None:
        if isinstance(combo_indices, torch.Tensor):
            idx = combo_indices.to(device=valid_idxs.device, dtype=torch.long)
        else:
            # tuple/list from itertools / random.sample
            idx = torch.as_tensor(combo_indices, device=valid_idxs.device, dtype=torch.long)

        if idx.numel() > 0:
            sel = valid_idxs.index_select(0, idx)  # [k,3]
            attend[0, sel[:, 0], sel[:, 1], sel[:, 2]] = 0.0

    # force center to 0
    ws = base_mask.shape[1]
    r  = ws // 2
    attend[:, r, r, r] = 0.0
    return attend

def _nonempty_window_centers(volume: torch.Tensor, ws: int, stride: int):
    """
    List centers (d, h, w) on a ``stride`` lattice whose zero-padded window is not all-zero.

    Helper for build_masked_window_batch_complete. ``volume`` is one sample,
    [C, D, H, W] (C=1 expected). The window around each center has side
    2*(ws//2)+1 and is zero-padded at the borders (same region as
    _extract_patch_with_pad_complete); a center is kept when the window sum,
    over all channels, is nonzero. Centers are Python ints in (d, h, w)
    lexicographic order, identical to a nested d/h/w loop.
    """
    if stride < 1:
        raise ValueError(f"stride must be >= 1, got {stride}")
    r = ws // 2
    side = 2 * r + 1
    vol = volume if volume.is_floating_point() else volume.float()
    # With count_include_pad=True every output is (zero-padded window sum) / side**3,
    # so for {0,1} data "mean != 0" is exactly "window sum != 0".
    window_mean = F.avg_pool3d(
        vol.unsqueeze(0), kernel_size=side, stride=1, padding=r, count_include_pad=True
    ).sum(dim=(0, 1))                                         # [D, H, W]
    lattice = window_mean[::stride, ::stride, ::stride]
    hits = torch.nonzero(lattice != 0, as_tuple=False) * stride
    return [tuple(c) for c in hits.tolist()]


def build_masked_window_batch_complete(
    complete_batch: torch.Tensor,   # [B,1,D,H,W], float {0,1}
    window_size: int,
    *,
    stride: int = 1,
    mask_ratio: float = 0.2,
    max_masks_per_center: int = 8,
    max_centers_per_sample: int = None, # Caps how many centers per volume to take (AFTER skipping empty windows)
    rng: random.Random = None,
):
    """
    Build (patch, attend-mask, label) triples with per-volume random center sampling.

    For each volume b:
      1) Find every center (d, h, w) on a ``stride`` lattice whose zero-padded
         window holds at least one occupied voxel (all-zero windows are skipped).
         This uses one pooling pass per volume instead of slicing every window.
      2) If ``max_centers_per_sample`` is a positive int smaller than the number
         of candidates, randomly keep that many centers.
      3) For each kept center, sample up to ``max_masks_per_center`` UNIQUE
         masking combinations of size k = round(mask_ratio * M), where M is the
         number of valid neighbors (non-center, in-bounds). At most
         10 * max_masks_per_center draws are tried, so fewer masks are emitted
         when few distinct k-subsets exist. A None / non-positive budget means 1.
         If k == 0, a single unmasked attend map is emitted (center still 0).

    Args:
        complete_batch: [B,1,D,H,W] ground-truth occupancy.
        window_size: window parameter ws; windows have side 2*(ws//2)+1.
        stride: lattice step for candidate centers (must be >= 1).
        mask_ratio, max_masks_per_center, max_centers_per_sample: see above.
        rng: a ``random.Random`` or the ``random`` module (default). Mask
            combinations are drawn from ``rng``; center subsampling uses ``rng``
            when it is a ``random.Random`` and the global ``random`` otherwise.

    Returns:
      patches:      [N,1,ws,ws,ws]
      attend_masks: [N,  ws,ws,ws]  (1=attend, 0=masked; center forced 0)
      labels:       [N,1]           (center GT from `complete_batch`)
      or (None, None, None) when nothing was produced.

    Raises:
        ValueError: if ``stride`` < 1.
    """
    if rng is None:
        rng = random
    center_rng = rng if isinstance(rng, random.Random) else random

    device = complete_batch.device
    B, _, _, _, _ = complete_batch.shape
    ws = window_size

    # Resolve the per-center mask budget once.
    mask_budget = max_masks_per_center
    if mask_budget is None or mask_budget <= 0:
        mask_budget = 1
    max_trials = mask_budget * 10  # soft cap on retries

    patches, masks, labels = [], [], []

    def emit(patch_occ, attend_mask, label):
        patches.append(patch_occ)
        masks.append(attend_mask)
        labels.append(label)

    for b in range(B):
        # 1) Candidate centers: windows with at least one occupied voxel.
        candidate_centers = _nonempty_window_centers(complete_batch[b], ws, stride)

        # 2) Optional random subsampling of centers.
        if (max_centers_per_sample is not None and max_centers_per_sample > 0
                and len(candidate_centers) > max_centers_per_sample):
            chosen_centers = center_rng.sample(candidate_centers, max_centers_per_sample)
        else:
            chosen_centers = candidate_centers

        # 3) Patches + masks + labels for each chosen center.
        for (d, h, w) in chosen_centers:
            patch_occ, base_mask = _extract_patch_with_pad_complete(
                complete_batch, b, d, h, w, ws
            )
            label = complete_batch[b:b+1, :, d:d+1, h:h+1, w:w+1].view(1, 1)  # {0,1}

            valid_idxs = _valid_positions_from_base(base_mask)  # [M,3]
            M = valid_idxs.shape[0]
            k = max(0, min(int(round(mask_ratio * M)), M))

            if k == 0:
                emit(patch_occ, _attend_mask_from_combo(base_mask, valid_idxs, []), label)
                continue

            seen = set()
            trials = 0
            while len(seen) < mask_budget and trials < max_trials:
                combo = tuple(sorted(rng.sample(range(M), k)))
                if combo not in seen:
                    seen.add(combo)
                    emit(patch_occ, _attend_mask_from_combo(base_mask, valid_idxs, combo), label)
                trials += 1

    # 4) Collate across all volumes.
    if not patches:
        return None, None, None

    return (
        torch.cat(patches, dim=0).to(device),  # [N,1,ws,ws,ws]
        torch.cat(masks, dim=0).to(device),    # [N,ws,ws,ws]
        torch.cat(labels, dim=0).to(device),   # [N,1]
    )

# Training

In [9]:
def run_validation(
    model: nn.Module,
    val_set,
    device,
    *,
    window_size: int = 5,
    stride: int = 1,
    mask_ratio: float = 0.2,
    masks_per_center: int = 4,      # smaller than train for speed
    max_centers: int = 64,          # fewer centers per volume than train
    pos_weight: float = 1.0,
    batch_size: int = 4,            # smaller BS is fine for val
    max_batches: int = 50,          # limit how many val volumes you touch
    patch_batch: int = 512,         # windows per forward pass
    seed: int = 0,
):
    """
    Fast validation on a capped subset of ``val_set``.

    Uses the same build_masked_window_batch_complete as training, but with fewer
    centers per volume, fewer masks per center and at most ``max_batches``
    loader batches.

    Windows and masks are drawn from a private ``random.Random(seed)``. Every
    call therefore evaluates the same windows, so val_loss is comparable across
    epochs, and the global ``random`` stream used by training is left untouched.
    Pass ``seed=None`` to draw from the global ``random`` module instead.

    The model's train/eval mode is restored on exit, even if an exception is raised.

    Returns:
        dict with "val_loss" (sample-weighted BCE), "val_acc", "val_fill_acc"
        (accuracy on label 1) and "val_empty_acc" (accuracy on label 0).
        All values are NaN when no window was evaluated. A per-class entry is
        NaN when that class never occurred.
    """
    was_training = model.training
    model.eval()
    rng = random.Random(seed) if seed is not None else None

    val_loader = torch.utils.data.DataLoader(
        val_set,
        batch_size=batch_size,
        shuffle=False,
        num_workers=0,
        # pinning only helps host->GPU copies
        pin_memory=torch.device(device).type == "cuda",
    )

    pw = torch.tensor([pos_weight], device=device)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pw)

    loss_sum = 0.0
    sample_count = 0
    correct_total = 0
    filled_correct = filled_total = 0
    empty_correct = empty_total = 0

    try:
        with torch.no_grad():
            for i_batch, complete_batch in enumerate(val_loader):
                if max_batches is not None and i_batch >= max_batches:
                    break

                # Ensure shapes [B,1,D,H,W] and keep on CPU
                if complete_batch.dim() == 4:
                    complete_batch = complete_batch.unsqueeze(1)
                complete_batch = complete_batch.float()

                patches, attend_masks, labels = build_masked_window_batch_complete(
                    complete_batch,
                    window_size=window_size,
                    stride=stride,
                    mask_ratio=mask_ratio,
                    max_masks_per_center=masks_per_center,
                    max_centers_per_sample=max_centers,
                    rng=rng,
                )
                if patches is None:
                    continue

                N = patches.size(0)
                for start in range(0, N, patch_batch):
                    end = min(start + patch_batch, N)
                    p = patches[start:end].to(device, non_blocking=True)
                    m = attend_masks[start:end].to(device, non_blocking=True)
                    y = labels[start:end].to(device, non_blocking=True)

                    logits = model(p, m)
                    loss = criterion(logits, y)

                    n_chunk = y.numel()
                    loss_sum += loss.item() * n_chunk  # criterion is a chunk mean
                    sample_count += n_chunk

                    preds = (torch.sigmoid(logits) >= 0.5).float()
                    correct = preds == y
                    correct_total += correct.sum().item()

                    filled = y == 1
                    empty = y == 0
                    filled_correct += correct[filled].sum().item()
                    filled_total += filled.sum().item()
                    empty_correct += correct[empty].sum().item()
                    empty_total += empty.sum().item()
    finally:
        model.train(was_training)

    if sample_count == 0:
        return {
            "val_loss": float("nan"),
            "val_acc": float("nan"),
            "val_fill_acc": float("nan"),
            "val_empty_acc": float("nan"),
        }

    return {
        "val_loss": loss_sum / sample_count,
        "val_acc": correct_total / sample_count,
        "val_fill_acc": filled_correct / filled_total if filled_total > 0 else float('nan'),
        "val_empty_acc": empty_correct / empty_total if empty_total > 0 else float('nan'),
    }

def _chunk_binary_counts(logits: torch.Tensor, y: torch.Tensor) -> Tuple[int, int, int, int, int, int]:
    """
    Count thresholded (sigmoid(logit) >= 0.5) prediction hits for one chunk.

    Returns:
        (correct, total, filled_correct, filled_total, empty_correct, empty_total)
        where "filled" means label == 1 and "empty" means label == 0.

    Call under ``torch.no_grad()``; nothing here needs gradients.
    """
    preds = (torch.sigmoid(logits) >= 0.5).float().reshape(-1)
    labels_flat = y.reshape(-1)
    hits = preds == labels_flat
    filled = labels_flat == 1
    empty = labels_flat == 0
    # Stack the reductions so only one device->host sync happens per chunk.
    correct, filled_correct, filled_total, empty_correct, empty_total = torch.stack([
        hits.sum(),
        (hits & filled).sum(),
        filled.sum(),
        (hits & empty).sum(),
        empty.sum(),
    ]).tolist()
    return correct, labels_flat.numel(), filled_correct, filled_total, empty_correct, empty_total


def train(
    model: nn.Module,
    train_set,
    val_set=None,
    *,
    num_epochs: int = 20,
    batch_size: int = 8,
    window_size: int = 5,
    lr: float = 1e-4,
    weight_decay: float = 0.0,
    pos_weight: float = 1.0,
    amp: bool = True,
    seed: int = 42,
    progress: bool = True,
    stride: int = 1,
    masks_per_center: int = 16,
    max_centers: int = None,
    mask_ratio: float = 0.2,
    # --- new optional args ---
    checkpoint_dir: str = None,
    val_fast: bool = True,
    val_max_centers: int = 64,
    val_masks_per_center: int = 4,
    val_max_batches: int = 50,
    patch_batch_size: int = 512,
):
    """
    Train ``model`` on masked-window voxel prediction with BCE-with-logits loss.

    Each DataLoader batch of complete volumes is expanded into masked-window
    patches by ``build_masked_window_batch_complete``. The patches are fed to the
    model in chunks of at most ``patch_batch_size``, with one optimizer step per
    chunk.

    Args:
        model: called as ``model(patches, attend_masks)``; must return logits with
            the same shape as the labels (required by ``BCEWithLogitsLoss``).
        train_set: dataset of complete volumes. Rank-4 batches get a channel dim
            inserted at dim 1 (presumably items are [D,H,W] - TODO confirm).
        val_set: optional validation dataset, evaluated by ``run_validation`` when
            ``val_fast`` is True.
        amp: request mixed precision + gradient scaling. Only honored when the
            training device is CUDA; ignored on CPU.
        seed: passed to ``torch.manual_seed`` (also drives DataLoader shuffling).
        checkpoint_dir: if set, ``checkpoint_epoch_{epoch}.pth`` is written each
            epoch and ``best_checkpoint.pth`` whenever val loss improves.
        patch_batch_size: max number of patches moved to the device per step.
        (remaining args are forwarded to the patch builder / validation.)

    Returns:
        The trained model (same object, left on the training device).

    Raises:
        ValueError: if ``patch_batch_size`` is not positive.
    """
    if patch_batch_size <= 0:
        raise ValueError(f"patch_batch_size must be positive, got {patch_batch_size}")

    torch.manual_seed(seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    # A 'cuda' autocast/GradScaler on a CPU-only machine only warns and disables
    # itself; gate AMP explicitly on the real device instead.
    use_amp = amp and device.type == "cuda"

    print("---------------")
    print(f"Using {max_centers} centers \n {masks_per_center} masks")
    print("Started training...")
    print(device)
    print()
    model.to(device)
    model.train()

    train_loader = torch.utils.data.DataLoader(
        train_set,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        # Pinned host memory only helps (and only avoids a warning) with a GPU.
        pin_memory=(device.type == "cuda"),
    )
    pw = torch.tensor([pos_weight], device=device)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pw)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scaler = GradScaler(enabled=use_amp)

    if checkpoint_dir is not None:
        os.makedirs(checkpoint_dir, exist_ok=True)

    best_val_loss = float("inf")
    best_ckpt_path = None

    for epoch in range(1, num_epochs + 1):
        model.train()

        # --- epoch-level accumulators (sample-weighted) ---
        loss_sum = 0.0
        sample_count = 0
        correct_total = 0
        total_total = 0

        filled_correct = 0
        filled_total   = 0
        empty_correct  = 0
        empty_total    = 0

        pbar = tqdm(train_loader, desc=f"Epoch {epoch}/{num_epochs}", disable=not progress)
        for complete_batch in pbar:
            # Ensure shapes [B,1,D,H,W] and keep on CPU
            if complete_batch.dim() == 4:
                complete_batch = complete_batch.unsqueeze(1)
            complete_batch = complete_batch.float()

            # Build masked-window batch from COMPLETE volumes (CPU tensors)
            patches, attend_masks, labels = build_masked_window_batch_complete(
                complete_batch,
                window_size=window_size,
                stride=stride,
                mask_ratio=mask_ratio,
                max_masks_per_center=masks_per_center,
                max_centers_per_sample=max_centers,
            )

            if patches is None:
                continue

            N = patches.size(0)

            # Move patches to the device in bounded chunks to cap GPU memory.
            for start in range(0, N, patch_batch_size):
                end = min(start + patch_batch_size, N)

                p = patches[start:end].to(device, non_blocking=True)
                m = attend_masks[start:end].to(device, non_blocking=True)
                y = labels[start:end].to(device, non_blocking=True)

                optimizer.zero_grad(set_to_none=True)
                with autocast(device_type=device.type, enabled=use_amp):
                    logits = model(p, m)              # [patch_batch, 1]
                    loss = criterion(logits, y)       # scalar for this patch chunk

                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()

                # --- sample-weighted loss aggregation ---
                n_chunk = y.numel()
                loss_sum += loss.item() * n_chunk
                sample_count += n_chunk

                # --- accuracy aggregation ---
                with torch.no_grad():
                    c, t, fc, ft, ec, et = _chunk_binary_counts(logits, y)
                correct_total  += c
                total_total    += t
                filled_correct += fc
                filled_total   += ft
                empty_correct  += ec
                empty_total    += et

            # --- update tqdm ONCE per complete_batch, using epoch-level stats so far ---
            avg_loss = loss_sum / max(sample_count, 1)
            avg_acc  = correct_total / max(total_total, 1)
            fill_acc = filled_correct / filled_total if filled_total > 0 else float('nan')
            empty_acc = empty_correct / empty_total if empty_total > 0 else float('nan')

            pbar.set_postfix({
                "loss": f"{avg_loss:.4f}",
                "acc": f"{avg_acc:.3f}",
                "fill_acc": f"{fill_acc:.3f}",
                "empty_acc": f"{empty_acc:.3f}",
            })

        # --- end of epoch: compute final train metrics ---
        if sample_count > 0:
            epoch_loss = loss_sum / sample_count
            epoch_acc  = correct_total / total_total if total_total > 0 else 0.0
            fill_acc   = filled_correct / filled_total if filled_total > 0 else float('nan')
            empty_acc  = empty_correct / empty_total if empty_total > 0 else float('nan')
        else:
            epoch_loss = float("nan")
            epoch_acc  = float("nan")
            fill_acc   = float("nan")
            empty_acc  = float("nan")

        print(
            f"Epoch {epoch}: "
            f"train_loss = {epoch_loss:.4f} | "
            f"train_acc = {epoch_acc:.4f} | "
            f"fill_acc = {fill_acc:.4f} | "
            f"empty_acc = {empty_acc:.4f}"
        )

        # --- quick validation ---
        val_stats = None
        if val_set is not None and val_fast:
            val_stats = run_validation(
                model,
                val_set,
                device,
                window_size=window_size,
                stride=stride,
                mask_ratio=mask_ratio,
                masks_per_center=val_masks_per_center,
                max_centers=val_max_centers,
                pos_weight=pos_weight,
                batch_size=batch_size,      # reuse train BS, or make smaller
                max_batches=val_max_batches,
            )
            print(
                f"  [VAL] loss = {val_stats['val_loss']:.4f} | "
                f"acc = {val_stats['val_acc']:.4f} | "
                f"fill_acc = {val_stats['val_fill_acc']:.4f} | "
                f"empty_acc = {val_stats['val_empty_acc']:.4f}"
            )

        # --- checkpointing ---
        if checkpoint_dir is not None:
            ckpt = {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scaler_state_dict": scaler.state_dict(),
                "train_loss": epoch_loss,
                "train_acc": epoch_acc,
                "train_fill_acc": fill_acc,
                "train_empty_acc": empty_acc,
            }
            if val_stats is not None:
                ckpt.update(val_stats)

            ckpt_path = os.path.join(checkpoint_dir, f"checkpoint_epoch_{epoch}.pth")
            torch.save(ckpt, ckpt_path)
            print(f"Saved checkpoint to {ckpt_path}")

            # Track best by val_loss (if val is available); NaN never counts as better.
            if val_stats is not None:
                if val_stats["val_loss"] < best_val_loss:
                    best_val_loss = val_stats["val_loss"]
                    best_ckpt_path = os.path.join(checkpoint_dir, "best_checkpoint.pth")
                    torch.save(ckpt, best_ckpt_path)
                    print(f"Updated best checkpoint: {best_ckpt_path}")

    if checkpoint_dir is not None and best_ckpt_path is not None:
        print(f"Best checkpoint: {best_ckpt_path} with val_loss = {best_val_loss:.4f}")

    return model


In [None]:
# ------------------------------
# Example main (instantiate and run)
# ------------------------------
if __name__ == "__main__":
    # 1) Extract the voxel chunks (existing files are skipped, so re-runs are cheap)
    ZIP_PATH   = "/content/drive/MyDrive/AUB_masters/thesis/data/chunk_data_16_flood_fill_rm_20.zip"
    UNZIP_PATH = "./chunk_data_16_flood_fill_rm_20"
    train_list = "./train_list.txt"
    extracted = unzip_to_dir(ZIP_PATH, UNZIP_PATH, overwrite=False)
    print(f"Unzip complete. New/updated files extracted: {extracted}")

    # 2) Build the dataset from the extracted directory
    dataset = VoxelDataset(UNZIP_PATH)
    print(f"Total dataset size: {len(dataset)}")

    n = len(dataset)
    indices = list(range(n))
    random.Random(42).shuffle(indices)  # fixed seed -> reproducible split

    DATA_PER   = 0.05   # fraction of the full dataset used for train + val
    TRAIN_FRAC = 0.8    # within that subset: 80% training, 20% validation

    # Take a subset of the dataset (DATA_PER of the total)
    n_subset = max(1, int(n * DATA_PER))
    subset_indices = indices[:n_subset]

    # Split that subset into train / val
    n_train = max(1, int(n_subset * TRAIN_FRAC))
    train_idx = subset_indices[:n_train]
    val_idx   = subset_indices[n_train:]

    # Everything outside the subset is held out (may be used as "test" later)
    test_idx = indices[n_subset:]

    print("Caching train subset into RAM...")
    train_set = CachedVoxelSubset(dataset, train_idx)
    print("Caching val subset into RAM...")
    val_set   = CachedVoxelSubset(dataset, val_idx)

    print(f"Training samples: {len(train_idx)} | Validation samples: {len(val_idx)}")

    # 3) Save the training file names; the inference cell excludes these basenames
    npz_files = dataset.files.npz_files
    train_files = [npz_files[i] for i in train_idx]

    # dirname("train_list.txt") is "" and os.makedirs("") raises -> fall back to "."
    os.makedirs(os.path.dirname(train_list) or ".", exist_ok=True)
    with open(train_list, "w") as f:
        for name in train_files:
            f.write(f"{name}\n")  # f-string also works if entries are Path objects

    print(f"Train List Length: {len(train_files)}")
    print(f"Saved training file list to: {train_list}")

    # === Model parameters ===
    WINDOW_SIZE = 11
    D_MODEL = 64
    NUM_HEADS = 4
    NUM_LAYERS = 4
    DROPOUT = 0.1
    GRID_SIZE = 16
    # ========================
    model = IterativeVoxelModel(
        d_model=D_MODEL,
        num_heads=NUM_HEADS,
        num_layers=NUM_LAYERS,
        window_size=WINDOW_SIZE,
        dropout=DROPOUT,
        max_grid_size=GRID_SIZE,
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # === Training parameters ===
    EPOCHS = 4
    BATCH_SIZE = 8
    MASKS_PER_CENTER = 8
    MAX_CENTERS = 128
    # ===========================
    checkpoint_dir = "./checkpoints_it"

    train(
        model,
        train_set,
        # An empty split would make validation report NaN every epoch; skip it instead.
        val_set=val_set if len(val_idx) > 0 else None,
        num_epochs=EPOCHS,
        batch_size=BATCH_SIZE,
        window_size=WINDOW_SIZE,
        masks_per_center=MASKS_PER_CENTER,
        max_centers=MAX_CENTERS,
        pos_weight=1.0,               # or >1 later
        checkpoint_dir=checkpoint_dir,
        val_fast=True,
        val_max_centers=64,           # cheap val
        val_masks_per_center=4,
        val_max_batches=50,           # at most 50 val batches per epoch
    )

    MODEL_SAVE_PATH = "./it.pth"
    torch.save({'model_state_dict': model.state_dict()}, MODEL_SAVE_PATH)
    print(f"Model saved to {MODEL_SAVE_PATH}")


Unzip complete. New/updated files extracted: 0
Total dataset size: 256571
Caching train subset into RAM...
Caching val subset into RAM...
Training samples: 10262 | Validation samples: 2566
Train List Length: 10262
Saved training file list to: ./train_list.txt
Training samples: 10262 | Validation samples: 2566
---------------
Using 128 centers 
 8 masks
Started training...
cuda



Epoch 1/4: 100%|██████████| 1283/1283 [1:32:47<00:00,  4.34s/it, loss=0.0678, acc=0.973, fill_acc=0.852, empty_acc=0.986]


Epoch 1: train_loss = 0.0678 | train_acc = 0.9731 | fill_acc = 0.8516 | empty_acc = 0.9862
  [VAL] loss = 0.0492 | acc = 0.9808 | fill_acc = 0.9229 | empty_acc = 0.9865
Saved checkpoint to ./checkpoints_it/checkpoint_epoch_1.pth
Updated best checkpoint: ./checkpoints_it/best_checkpoint.pth


Epoch 2/4:  13%|█▎        | 168/1283 [12:13<1:20:30,  4.33s/it, loss=0.0526, acc=0.979, fill_acc=0.902, empty_acc=0.988]

# Inference

In [59]:
def to_numpy(x):
    """
    Convert ``x`` to a NumPy array.

    - ``np.ndarray`` inputs are returned unchanged (same object, no copy).
    - ``torch.Tensor`` inputs are detached from autograd and copied to CPU.
    - Anything else (lists, tuples, scalars) is converted with ``np.asarray``.
    """
    if isinstance(x, np.ndarray):
        return x
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return np.asarray(x)

import os
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap


def to_numpy(x):
    """
    Convert a PyTorch tensor (or any array-like) to a NumPy array safely.

    - ``np.ndarray`` inputs are returned unchanged (same object, no copy).
    - ``torch.Tensor`` inputs are detached from autograd and copied to CPU.
    - Anything else (lists, tuples, scalars) is converted with ``np.asarray``.
    """
    if isinstance(x, np.ndarray):
        return x
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return np.asarray(x)


def _plot_voxels(grid, colors, out_path, title=None):
    """
    Internal helper to plot a 3D voxel grid with per-voxel RGBA colors.

    grid     : [D,H,W] boolean or 0/1 occupancy (NumPy or Tensor); non-zero
               cells are drawn
    colors   : [D,H,W,4] RGBA numeric (NumPy or Tensor)
    out_path : destination image file (saved at 200 dpi)
    title    : optional axes title

    Voxel edges are drawn in black. The figure is always closed, even if
    rendering or saving fails, so repeated calls do not leak open figures.
    """
    grid = to_numpy(grid)
    colors = to_numpy(colors)

    D, H, W = grid.shape

    fig = plt.figure(figsize=(6, 6))
    try:
        ax = fig.add_subplot(111, projection='3d')
        ax.voxels(grid.astype(bool), facecolors=colors, edgecolor='k')

        ax.set_xlim(0, D)
        ax.set_ylim(0, H)
        ax.set_zlim(0, W)

        try:
            ax.set_box_aspect([1, 1, 1])
        except Exception:
            # Best effort: presumably guards Matplotlib versions without
            # set_box_aspect. A bare except would also swallow Ctrl-C.
            pass

        if title:
            ax.set_title(title)

        fig.tight_layout()
        fig.savefig(out_path, dpi=200)
    finally:
        plt.close(fig)
    print(f"[VIS] Saved {out_path}")


# ------------------------------------------------------------
# 1) FULL OBJECT (GT)
# ------------------------------------------------------------
def visualize_full_object(voxel_grid, out_dir, name="gt"):
    """
    Save a render of every filled (> 0.5) voxel of ``voxel_grid``, drawn in blue,
    to ``<out_dir>/<name>_full.png``.

    voxel_grid: [D,H,W], [1,D,H,W] or [1,1,D,H,W] (tensor or ndarray).
    """
    os.makedirs(out_dir, exist_ok=True)

    volume = to_numpy(voxel_grid)
    if volume.ndim in (4, 5):
        # Drop leading batch/channel axes by taking their first entry.
        volume = volume[(0,) * (volume.ndim - 3)]

    rgba = np.zeros(volume.shape + (4,))
    rgba[volume > 0.5] = [0.0, 0.7, 1.0, 1.0]  # blue

    _plot_voxels(
        volume,
        rgba,
        os.path.join(out_dir, f"{name}_full.png"),
        title="Full Object (GT)",
    )



# ------------------------------------------------------------
# 2) MASKED OBJECT (masked voxels = red)
# ------------------------------------------------------------
def visualize_masked_object(complete_gt, masked_coords, out_dir, name="masked"):
    """
    Save a render of the GT object with its masked voxels highlighted to
    ``<out_dir>/<name>_masked.png``.

    complete_gt   : [1,1,D,H,W], [1,D,H,W] or [D,H,W] (tensor or ndarray)
    masked_coords : Tensor [N,3] or sequence of (z, y, x); entries outside the
                    volume are ignored.

    Color coding:
      - GT-filled voxels: blue
      - masked voxels on the object: solid red
      - masked voxels off the object: dark red at alpha 0.003 (a faint halo)
    """
    os.makedirs(out_dir, exist_ok=True)

    volume = to_numpy(complete_gt)
    if volume.ndim == 5:
        volume = volume[0, 0]
    elif volume.ndim == 4:
        volume = volume[0]

    D, H, W = volume.shape
    occupied = volume > 0.5

    coord_rows = (
        masked_coords.detach().cpu().numpy()
        if isinstance(masked_coords, torch.Tensor)
        else np.array(masked_coords, dtype=np.int64)
    )
    masked = np.zeros((D, H, W), dtype=bool)
    for zz, yy, xx in coord_rows:
        if not (0 <= zz < D and 0 <= yy < H and 0 <= xx < W):
            continue
        masked[int(zz), int(yy), int(xx)] = True

    rgba = np.zeros((D, H, W, 4))
    rgba[occupied] = [0.0, 0.7, 1.0, 1.0]              # GT: blue
    rgba[masked & occupied] = [1.0, 0.0, 0.0, 1.0]     # masked on object: solid red
    rgba[masked & ~occupied] = [0.3, 0.0, 0.0, 0.003]  # masked off object: faint halo

    _plot_voxels(
        occupied | masked,
        rgba,
        os.path.join(out_dir, f"{name}_masked.png"),
        title="Masked (GT=blue, masked red; halo faint)",
    )


# ------------------------------------------------------------
# 3) PREDICTION QUALITY (correct=green, incorrect=red)
# ------------------------------------------------------------
def visualize_predictions(complete_gt, pred_grid, masked_coords, out_dir, name="pred_quality"):
    """
    Save a render comparing a predicted grid to ground truth at the masked voxels,
    as ``<out_dir>/<name>_pred_quality.png``.

    complete_gt   : [1,1,D,H,W] or [1,D,H,W] or [D,H,W]
    pred_grid     : same shape; values > 0.5 count as filled
    masked_coords : Tensor [N,3] or list of (z,y,x); out-of-range entries ignored

    Drawn voxels are the union of GT-filled and predicted-filled cells:
    - GT-filled voxels: blue
    - masked voxels whose predicted occupancy matches GT: solid green
    - masked voxels whose predicted occupancy differs from GT: solid red
    Masked voxels that are empty in both GT and prediction are not drawn.

    Raises:
        ValueError: if GT and prediction shapes differ after squeezing.
    """
    os.makedirs(out_dir, exist_ok=True)

    gt = to_numpy(complete_gt)
    pr = to_numpy(pred_grid)

    if gt.ndim == 5:
        gt = gt[0, 0]
        pr = pr[0, 0]
    elif gt.ndim == 4:
        gt = gt[0]
        pr = pr[0]

    if gt.shape != pr.shape:
        raise ValueError(f"GT shape {gt.shape} != prediction shape {pr.shape}")

    D, H, W = gt.shape

    # Threshold both sides the same way so correctness agrees with occupancy
    # (raw equality breaks for bool or probability-valued predictions).
    gt_filled = gt > 0.5
    pr_filled = pr > 0.5

    # Masked centers
    is_masked = np.zeros((D, H, W), dtype=bool)
    if isinstance(masked_coords, torch.Tensor):
        coords = masked_coords.detach().cpu().numpy()
    else:
        coords = np.array(masked_coords, dtype=np.int64)
    for zz, yy, xx in coords:
        if 0 <= zz < D and 0 <= yy < H and 0 <= xx < W:
            is_masked[int(zz), int(yy), int(xx)] = True

    correct = gt_filled == pr_filled
    masked_correct = is_masked & correct
    masked_wrong   = is_masked & ~correct

    # Show union of GT and prediction to see the full shape
    grid = gt_filled | pr_filled

    colors = np.zeros((D, H, W, 4))

    # Base GT structure = blue
    colors[gt_filled] = [0.0, 0.7, 1.0, 1.0]

    # Masked centers - solid, no transparency (only visible where `grid` is set)
    colors[masked_correct] = [0.0, 1.0, 0.0, 1.0]  # green
    colors[masked_wrong]   = [1.0, 0.0, 0.0, 1.0]  # red

    out_path = os.path.join(out_dir, f"{name}_pred_quality.png")
    _plot_voxels(
        grid,
        colors,
        out_path,
        title="Prediction Quality (blue=GT, green=correct, red=wrong masked)",
    )




In [61]:
import os
import zipfile
import random
import numpy as np
import matplotlib.pyplot as plt
import torch


# =========================
# 1) Utils: Tensor -> numpy
# =========================
def to_numpy(x):
    """
    Convert ``x`` to a NumPy array.

    - ``np.ndarray`` inputs are returned unchanged (same object, no copy).
    - ``torch.Tensor`` inputs are detached from autograd and copied to CPU.
    - Anything else (lists, tuples, scalars) is converted with ``np.asarray``.
    """
    if isinstance(x, np.ndarray):
        return x
    if isinstance(x, torch.Tensor):
        return x.detach().cpu().numpy()
    return np.asarray(x)


# =========================
# 2) Visualization helpers
# =========================
def _plot_voxels(grid, colors, out_path, title=None):
    """
    Render a voxel grid with per-voxel RGBA face colors and save it to ``out_path``.

    grid     : [D,H,W] occupancy, NumPy or Tensor; non-zero cells are drawn
    colors   : [D,H,W,4] RGBA, NumPy or Tensor
    out_path : destination image file (saved at 200 dpi)
    title    : optional axes title

    Voxel edges are drawn in a translucent light grey. The figure is always
    closed, even if rendering or saving fails, so repeated calls in a loop do
    not accumulate open figures.
    """
    grid = to_numpy(grid)
    colors = to_numpy(colors)

    D, H, W = grid.shape

    fig = plt.figure(figsize=(6, 6))
    try:
        ax = fig.add_subplot(111, projection="3d")
        # very light grey borders instead of black
        ax.voxels(grid.astype(bool), facecolors=colors, edgecolor=(0.85, 0.85, 0.85, 0.4))

        ax.set_xlim(0, D)
        ax.set_ylim(0, H)
        ax.set_zlim(0, W)

        try:
            ax.set_box_aspect([1, 1, 1])
        except Exception:
            # Best effort: presumably guards Matplotlib versions without set_box_aspect.
            pass

        if title:
            ax.set_title(title)

        fig.tight_layout()
        fig.savefig(out_path, dpi=200)
    finally:
        plt.close(fig)
    print(f"[VIS] Saved {out_path}")

def visualize_gt_vs_pred_side_by_side(complete_gt, pred_grid, out_dir, name="gt_vs_pred"):
    """
    Save a single PNG (``<out_dir>/<name>_gt_vs_pred.png``) with:
      - Left:  Ground truth (only GT filled voxels, in blue)
      - Right: Prediction
          * Predicted voxels = blue
          * Voxels where GT != pred = yellow (includes GT voxels the model missed)

    No transparent voxels in either view. Values > 0.5 count as filled.

    complete_gt / pred_grid : [1,1,D,H,W], [1,D,H,W] or [D,H,W] (tensor or ndarray)

    Raises:
        ValueError: if GT and prediction shapes differ after squeezing.
    """
    os.makedirs(out_dir, exist_ok=True)

    gt = to_numpy(complete_gt)
    pr = to_numpy(pred_grid)

    if gt.ndim == 5:
        gt = gt[0, 0]
        pr = pr[0, 0]
    elif gt.ndim == 4:
        gt = gt[0]
        pr = pr[0]

    # Without this check, mismatched grids could broadcast silently into garbage.
    if gt.shape != pr.shape:
        raise ValueError(f"GT shape {gt.shape} != prediction shape {pr.shape}")

    D, H, W = gt.shape

    gt_filled = gt > 0.5
    pr_filled = pr > 0.5
    diff = gt_filled != pr_filled

    blue = [0.0, 0.7, 1.0, 1.0]
    yellow = [1.0, 1.0, 0.0, 1.0]

    # --- GT view: only GT voxels drawn ---
    colors_gt = np.zeros((D, H, W, 4))
    colors_gt[gt_filled] = blue

    # --- Pred view: predicted voxels plus every disagreement ---
    grid_pr = pr_filled | diff
    colors_pr = np.zeros((D, H, W, 4))
    colors_pr[pr_filled] = blue
    colors_pr[diff] = yellow  # overrides blue where the prediction is wrong

    panels = (
        (121, gt_filled, colors_gt, "Ground Truth"),
        (122, grid_pr, colors_pr, "Prediction (yellow = difference)"),
    )
    edge_color = (0.85, 0.85, 0.85, 0.4)
    out_path = os.path.join(out_dir, f"{name}_gt_vs_pred.png")

    fig = plt.figure(figsize=(12, 6))
    try:
        for position, occupancy, facecolors, panel_title in panels:
            ax = fig.add_subplot(position, projection="3d")
            ax.voxels(occupancy, facecolors=facecolors, edgecolor=edge_color)
            ax.set_title(panel_title)
            ax.set_xlim(0, D)
            ax.set_ylim(0, H)
            ax.set_zlim(0, W)
            try:
                ax.set_box_aspect([1, 1, 1])
            except Exception:
                # Best effort: presumably guards Matplotlib versions without set_box_aspect.
                pass

        fig.tight_layout()
        fig.savefig(out_path, dpi=200)
    finally:
        # Always release the figure, even if rendering or saving fails.
        plt.close(fig)
    print(f"[VIS] Saved {out_path}")


def visualize_full_object(voxel_grid, out_dir, name="gt"):
    """
    Render the filled voxels (> 0.5) of a ground-truth grid in blue and write
    the image to ``<out_dir>/<name>_full.png``.

    voxel_grid: [D,H,W], [1,D,H,W] or [1,1,D,H,W] (tensor or ndarray).
    """
    os.makedirs(out_dir, exist_ok=True)

    grid_np = to_numpy(voxel_grid)
    # Peel off up to two leading singleton-style axes (batch, channel).
    while 3 < grid_np.ndim <= 5:
        grid_np = grid_np[0]

    filled_mask = grid_np > 0.5
    face_rgba = np.zeros(grid_np.shape + (4,))
    face_rgba[filled_mask] = [0.0, 0.7, 1.0, 1.0]

    target = os.path.join(out_dir, f"{name}_full.png")
    _plot_voxels(grid_np, face_rgba, target, title="Full Object (GT)")


def visualize_masked_object(complete_gt, masked_coords, out_dir, name="masked"):
    """
    Render a GT volume with its masked voxels highlighted and save it as
    ``<out_dir>/<name>_masked.png``.

    complete_gt   : [1,1,D,H,W], [1,D,H,W] or [D,H,W] (tensor or ndarray)
    masked_coords : Tensor [N,3] or iterable of (z, y, x)

    - GT voxels: blue
    - Masked ON-object voxels: solid red
    - Masked OFF-object voxels: dark red at alpha 0.003 (practically invisible halo)

    Coordinates outside the volume are ignored.
    """
    os.makedirs(out_dir, exist_ok=True)

    gt = to_numpy(complete_gt)
    if gt.ndim == 5:
        gt = gt[0, 0]
    elif gt.ndim == 4:
        gt = gt[0]

    D, H, W = gt.shape

    # Mask array
    is_masked = np.zeros((D, H, W), dtype=bool)
    if isinstance(masked_coords, torch.Tensor):
        coords = masked_coords.detach().cpu().numpy()
    else:
        coords = np.array(masked_coords, dtype=np.int64)
    for zz, yy, xx in coords:
        # Skip out-of-range coords: negative ones would otherwise wrap around
        # silently and too-large ones would raise IndexError.
        if 0 <= zz < D and 0 <= yy < H and 0 <= xx < W:
            is_masked[int(zz), int(yy), int(xx)] = True

    gt_filled = gt > 0.5
    masked_on_object  = is_masked & gt_filled
    masked_off_object = is_masked & ~gt_filled

    grid = gt_filled | is_masked

    colors = np.zeros((D, H, W, 4))

    # GT = blue
    colors[gt_filled] = [0.0, 0.7, 1.0, 1.0]

    # ON-object masked = solid red
    colors[masked_on_object] = [1.0, 0.0, 0.0, 1.0]

    # OFF-object masked = ultra transparent dark red
    colors[masked_off_object] = [0.3, 0.0, 0.0, 0.003]

    out_path = os.path.join(out_dir, f"{name}_masked.png")
    _plot_voxels(grid, colors, out_path, title="Masked Object")

def visualize_predictions(complete_gt, pred_grid, masked_coords, out_dir, name="pred_quality"):
    """
    Save a render of prediction quality at the masked voxels to
    ``<out_dir>/<name>_pred_quality.png``.

    complete_gt   : [1,1,D,H,W] or [1,D,H,W] or [D,H,W]
    pred_grid     : same shape; values > 0.5 count as filled
    masked_coords : Tensor [N,3] or list of (z,y,x); out-of-range entries ignored

    Color rules:
    - GT-filled voxels: blue
    - masked, GT=1, pred=1: green
    - masked, GT=1, pred=0: red
    - masked, GT=0, pred=1: red
    - masked, GT=0, pred=0: not drawn
    """
    os.makedirs(out_dir, exist_ok=True)

    gt_np = to_numpy(complete_gt)
    pr_np = to_numpy(pred_grid)
    if gt_np.ndim == 5:
        gt_np, pr_np = gt_np[0, 0], pr_np[0, 0]
    elif gt_np.ndim == 4:
        gt_np, pr_np = gt_np[0], pr_np[0]

    depth, height, width = gt_np.shape
    on_gt = gt_np > 0.5
    on_pred = pr_np > 0.5

    # Boolean map of masked centers that fall inside the volume
    masked = np.zeros((depth, height, width), dtype=bool)
    if isinstance(masked_coords, torch.Tensor):
        coord_rows = masked_coords.detach().cpu().numpy()
    else:
        coord_rows = np.array(masked_coords, dtype=np.int64)
    for zz, yy, xx in coord_rows:
        inside = 0 <= zz < depth and 0 <= yy < height and 0 <= xx < width
        if inside:
            masked[int(zz), int(yy), int(xx)] = True

    agree = on_gt == on_pred
    hit_on = masked & on_gt & agree      # GT=1, pred=1
    miss_on = masked & on_gt & ~agree    # GT=1, pred=0
    miss_off = masked & ~on_gt & ~agree  # GT=0, pred=1

    draw = on_gt | hit_on | miss_on | miss_off

    rgba = np.zeros((depth, height, width, 4))
    rgba[on_gt] = [0.0, 0.7, 1.0, 1.0]              # blue base structure
    rgba[hit_on] = [0.0, 1.0, 0.0, 1.0]             # green
    rgba[miss_on | miss_off] = [1.0, 0.0, 0.0, 1.0]  # red

    _plot_voxels(
        draw,
        rgba,
        os.path.join(out_dir, f"{name}_pred_quality.png"),
        title="Prediction Quality (blue=GT, green=correct, red=wrong)",
    )


# =======================================
# 3) Rebuild full predicted voxel grid
# =======================================
def reconstruct_predicted_grid(complete_gt, masked_coords, logits_for_masked):
    """
    Build a full predicted grid: a copy of the GT with the masked voxels replaced
    by the model's thresholded predictions (sigmoid(logit) >= 0.5 -> 1, else 0).

    complete_gt      : [1,1,D,H,W] tensor (not modified)
    masked_coords    : Tensor [N,3] or list of (z,y,x)
    logits_for_masked: [N,1] (or [N]) tensor of raw logits, row i for coord i
    Returns:
        pred_grid [1,1,D,H,W], same dtype/device as complete_gt, with values in {0,1}
        at the masked positions.
    Raises:
        ValueError: if the number of logits does not match the number of coords.
    """
    pred_grid = complete_gt.clone()

    if isinstance(masked_coords, torch.Tensor):
        coords = masked_coords.detach().cpu().numpy()
    else:
        coords = np.asarray(masked_coords, dtype=np.int64)
    # reshape handles the empty case ([] -> shape (0, 3)); astype truncates like int().
    coords = coords.reshape(-1, 3).astype(np.int64)

    probs = torch.sigmoid(logits_for_masked.detach()).reshape(-1)
    if probs.numel() != coords.shape[0]:
        raise ValueError(
            f"got {probs.numel()} logits for {coords.shape[0]} masked coordinates"
        )
    if coords.shape[0] == 0:
        return pred_grid

    preds = (probs >= 0.5).to(device=pred_grid.device, dtype=pred_grid.dtype)
    idx = torch.as_tensor(coords, device=pred_grid.device)
    # One vectorized scatter instead of N scalar writes (each a kernel launch on GPU).
    # NOTE: if a coordinate repeats, which prediction wins is unspecified.
    pred_grid[0, 0, idx[:, 0], idx[:, 1], idx[:, 2]] = preds

    return pred_grid


# ======================================================
# 4) Inference: random objects, mask near-object voxels
# ======================================================
def _infer_read_basenames(list_path: str) -> set:
    """Return the basenames of the paths listed one per line in `list_path` (blank lines ignored)."""
    with open(list_path, "r") as f:
        return {os.path.basename(line.strip()) for line in f if line.strip()}


def _infer_load_complete(zf: zipfile.ZipFile, member: zipfile.ZipInfo):
    """
    Load the 'complete' array of an NPZ member as a binarized float tensor [1,1,D,H,W].

    Accepts arrays shaped [D,H,W], [1,D,H,W] or any 5-D shape. Returns None after
    printing a warning when the key is missing or the shape is not supported.
    """
    import io

    # Read the member fully into memory: np.load on a ZipExtFile relies on seeking,
    # which zipfile emulates by re-decompressing from the start. The `with` closes
    # the NpzFile handle.
    with np.load(io.BytesIO(zf.read(member))) as data:
        if "complete" not in data:
            print("  [WARN] 'complete' key missing, skipping")
            return None
        complete_np = data["complete"]

    complete = torch.from_numpy(complete_np).float()
    if complete.dim() == 3:
        complete = complete.unsqueeze(0)  # [1,D,H,W]
    if complete.dim() == 4:
        if complete.shape[0] != 1:
            print(f"  [WARN] Unexpected channels {complete.shape}, skipping")
            return None
        complete = complete.unsqueeze(0)  # [1,1,D,H,W]
    elif complete.dim() != 5:
        print(f"  [WARN] Unexpected shape {complete.shape}, skipping")
        return None
    return (complete > 0.5).float()


def _infer_near_object_centers(complete_gt: torch.Tensor, window_size: int) -> np.ndarray:
    """
    Return an int64 array [K,3] of (z,y,x) voxels whose cube of radius window_size//2
    (clipped at the grid border) contains at least one filled voxel, in (z,y,x)
    lexicographic order.
    """
    r = window_size // 2
    occupancy = complete_gt[:1, :1].float()
    # Max-pool with kernel 2r+1, stride 1, padding r is a binary dilation of the
    # {0,1} grid. Output > 0 iff the clipped window holds a filled voxel.
    # max_pool3d pads with -inf, so padding never counts as filled.
    dilated = F.max_pool3d(occupancy, kernel_size=2 * r + 1, stride=1, padding=r)
    return np.argwhere(dilated[0, 0].cpu().numpy() > 0).astype(np.int64)


def _infer_predict_masked(model, occ_grid, known_grid, complete_gt, masked_coords_t,
                          window_size, patch_batch_size, device):
    """
    Run `model` on the neighborhood of every masked voxel, `patch_batch_size` at a time.

    Returns:
        (logits, labels): CPU tensors, concatenated along dim 0 in `masked_coords_t`
        order; labels are the ground-truth values taken from `complete_gt`.
    Assumes masked_coords_t has at least one row (torch.cat fails on an empty list).
    """
    all_logits, all_labels = [], []
    with torch.no_grad():
        for start in range(0, masked_coords_t.shape[0], patch_batch_size):
            batch_coords = masked_coords_t[start:start + patch_batch_size]  # [B',3]
            patch_list, known_list, label_list = [], [], []
            for zz, yy, xx in batch_coords.tolist():
                patch_occ, patch_known = neighborhood_raw(
                    occ_grid, known_grid, b=0, d=zz, h=yy, w=xx, window_size=window_size
                )
                patch_list.append(patch_occ)
                known_list.append(patch_known)
                label_list.append(complete_gt[0, 0, zz, yy, xx].view(1, 1))  # [1,1]
            patches = torch.cat(patch_list, dim=0).to(device)
            knowns = torch.cat(known_list, dim=0).to(device)
            all_logits.append(model(patches, knowns).cpu())
            all_labels.append(torch.cat(label_list, dim=0))
    return torch.cat(all_logits, dim=0), torch.cat(all_labels, dim=0)


def _infer_binary_counts(preds_flat: torch.Tensor, labels_flat: torch.Tensor) -> dict:
    """Overall / filled-voxel / empty-voxel hit counts for flat 0/1 prediction and label tensors."""
    hits = preds_flat == labels_flat
    filled = labels_flat == 1
    empty = labels_flat == 0
    return {
        "correct": int(hits.sum().item()),
        "count": labels_flat.numel(),
        "fill_correct": int(hits[filled].sum().item()),
        "fill_total": int(filled.sum().item()),
        "empty_correct": int(hits[empty].sum().item()),
        "empty_total": int(empty.sum().item()),
    }


def _infer_ratio(num, den) -> float:
    """num / den, or NaN when den is 0."""
    return num / den if den > 0 else float("nan")


def run_inference_random_mask_from_zip_with_viz(
    model: torch.nn.Module,
    zip_path: str,
    train_list_path: str,
    window_size: int,
    mask_fraction: float = 0.2,
    max_files: int = 3,
    patch_batch_size: int = 512,
    device: torch.device = None,
    out_dir: str = "./vis_infer",
    seed: int = None,
):
    """
    Masked-voxel inference with visualization on held-out NPZ volumes.

    - Randomly samples NPZs from `zip_path` whose basenames are NOT in `train_list_path`
      (at most `max_files`; None means all).
    - For each selected volume:
        * Loads 'complete' as a binarized [1,1,D,H,W] ground truth.
        * Finds voxels "near the object": centers whose window (radius window_size//2)
          contains at least one filled voxel.
        * Randomly masks `mask_fraction` of those centers (at least 1, at most all):
              occ -> 0, known -> 0.
        * Predicts every masked voxel using model(neighbors_patch, known_mask).
        * Prints per-file accuracy (overall / filled / empty) and reconstructs the
          full predicted grid.
        * Writes visualizations to `out_dir`.
    - Prints global statistics over all processed files.

    `seed` (optional) makes file sampling and masking reproducible through private
    RNGs. With seed=None the global `random` / `np.random` state is used.
    """
    model.eval()
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    os.makedirs(out_dir, exist_ok=True)

    # Both objects expose the same shuffle/choice API as the global modules.
    py_rng = random.Random(seed) if seed is not None else random
    np_rng = np.random.default_rng(seed) if seed is not None else np.random

    train_basenames = _infer_read_basenames(train_list_path)
    print(f"[INF] Loaded {len(train_basenames)} train basenames from {train_list_path}")

    totals = {}

    with zipfile.ZipFile(zip_path, "r") as zf:
        test_members = [
            m for m in zf.infolist()
            if m.filename.lower().endswith(".npz")
            and os.path.basename(m.filename) not in train_basenames
        ]
        print(f"[INF] Found {len(test_members)} NPZ files in zip not in train list")

        # randomize order so we don't always use the same files
        py_rng.shuffle(test_members)
        if max_files is not None and len(test_members) > max_files:
            test_members = test_members[:max_files]
            print(f"[INF] Randomly sampling {max_files} test files")

        for idx, member in enumerate(test_members):
            print(f"\n[INF] File {idx+1}/{len(test_members)}: {member.filename}")
            base_name = os.path.splitext(os.path.basename(member.filename))[0]

            complete_gt = _infer_load_complete(zf, member)
            if complete_gt is None:
                continue

            candidates = _infer_near_object_centers(complete_gt, window_size)
            if len(candidates) == 0:
                print("  [WARN] No near-object centers found, skipping file")
                continue

            num_candidates = len(candidates)
            # Clamp so mask_fraction > 1 cannot request more samples than exist.
            num_mask = min(num_candidates, max(1, int(mask_fraction * num_candidates)))
            chosen = np_rng.choice(num_candidates, size=num_mask, replace=False)
            masked_coords_t = torch.from_numpy(candidates[chosen]).long()  # [N,3]

            # apply mask: occ->0, known->0 for masked centers
            occ_grid = complete_gt.clone()
            known_grid = torch.ones_like(occ_grid)
            zs, ys, xs = masked_coords_t.unbind(dim=1)
            occ_grid[0, 0, zs, ys, xs] = 0.0
            known_grid[0, 0, zs, ys, xs] = 0.0

            logits_all, labels_all = _infer_predict_masked(
                model, occ_grid, known_grid, complete_gt, masked_coords_t,
                window_size, patch_batch_size, device,
            )

            preds_flat = (torch.sigmoid(logits_all) >= 0.5).float().view(-1)
            stats = _infer_binary_counts(preds_flat, labels_all.view(-1))

            file_acc = stats["correct"] / stats["count"]
            file_fill_acc = _infer_ratio(stats["fill_correct"], stats["fill_total"])
            file_empty_acc = _infer_ratio(stats["empty_correct"], stats["empty_total"])
            print(
                f"  [FILE STATS] acc={file_acc:.4f} | "
                f"fill_acc={file_fill_acc:.4f} | empty_acc={file_empty_acc:.4f}"
            )

            for key, value in stats.items():
                totals[key] = totals.get(key, 0) + value

            # --- reconstruct full predicted grid & visualize ---
            pred_grid = reconstruct_predicted_grid(
                complete_gt=complete_gt,
                masked_coords=masked_coords_t,
                logits_for_masked=logits_all,
            )

            name = f"file{idx+1}_{base_name}"
            visualize_full_object(complete_gt, out_dir, name=name)
            visualize_masked_object(complete_gt, masked_coords_t, out_dir, name=name)
            visualize_predictions(complete_gt, pred_grid, masked_coords_t, out_dir, name=name)
            visualize_gt_vs_pred_side_by_side(complete_gt, pred_grid, out_dir, name=name)

    if totals.get("count", 0) > 0:
        g_acc = totals["correct"] / totals["count"]
        g_fill_acc = _infer_ratio(totals["fill_correct"], totals["fill_total"])
        g_empty_acc = _infer_ratio(totals["empty_correct"], totals["empty_total"])

        print("\n[GLOBAL INFERENCE STATS]")
        print(f"  acc       = {g_acc:.4f}")
        print(f"  fill_acc  = {g_fill_acc:.4f}")
        print(f"  empty_acc = {g_empty_acc:.4f}")
    else:
        print("[INF] No samples processed in inference.")


In [76]:
# ---- Inference configuration ----
# The default ZIP path is machine-specific; override it with the VOXEL_ZIP_PATH
# environment variable instead of editing the notebook.
ZIP_PATH = os.environ.get(
    "VOXEL_ZIP_PATH",
    "/home/raedfidawi/Documents/thesis/3DLLM/chunk_data_16_flood_fill_rm_20.zip",
)
VIS_PATH = "./vis_infer_16"

# Model hyper-parameters: these must match the ones used to train the checkpoint.
WINDOW_SIZE = 11
D_MODEL = 64
NUM_HEADS = 4
NUM_LAYERS = 4
DROPOUT = 0.1
GRID_SIZE = 16

model = IterativeVoxelModel(
    d_model=D_MODEL,
    num_heads=NUM_HEADS,
    num_layers=NUM_LAYERS,
    window_size=WINDOW_SIZE,
    dropout=DROPOUT,
    max_grid_size=GRID_SIZE,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

TRAIN_LIST = "./train_list.txt"
MODEL_CKPT = "./checkpoint_epoch_1.pth"

ckpt = torch.load(MODEL_CKPT, map_location=device)
model.load_state_dict(ckpt["model_state_dict"])
print(f"Loaded model from {MODEL_CKPT}, epoch {ckpt.get('epoch', '?')}")

# Clear PNGs left by a previous run. The directory may not exist yet on a fresh
# machine/kernel, so create it first rather than letting os.listdir raise.
vis_dir = Path(VIS_PATH)
vis_dir.mkdir(parents=True, exist_ok=True)
for png_path in vis_dir.glob("*.png"):
    if png_path.is_file():
        png_path.unlink()

print("[!] Deleted Previous Files")

run_inference_random_mask_from_zip_with_viz(
    model,
    zip_path=ZIP_PATH,
    train_list_path=TRAIN_LIST,
    window_size=WINDOW_SIZE,
    mask_fraction=0.2,        # 20% of near-object voxels
    max_files=20,             # change if you want more objects
    patch_batch_size=512,
    device=device,
    out_dir=VIS_PATH,
)


Loaded model from ./checkpoint_epoch_1.pth, epoch 1
[!] Deleted Previous Files
[INF] Loaded 10262 train basenames from ./train_list.txt


KeyboardInterrupt: 