# ILT Patch Extraction – Fast Version for 16 GB M4
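Each 48×48 design patch is paired with the continuous ILT mask of its central 24×24 region.  
Patches are taken with stride 12; flip augmentation is disabled for speed, the outer rim of each sample is skipped, and patches whose design center is empty are filtered out.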

In [None]:
# Cell 1 – Imports & settings
import torch
import numpy as np
import h5py
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output
from tqdm.notebook import tqdm
import os
import glob
from datetime import datetime

# Run on the Metal (MPS) backend when PyTorch reports it as available,
# otherwise fall back to the CPU.
if torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f"Device: {device}")

# Input: glob pattern matching the raw .npz sample archives.
RAW_NPZ_PATH = "data/ilt_dataset_v1_m4/*/ilt_dataset_500.npz"

# Output: directory for the extracted patch files (created if missing).
OUTPUT_DIR = "data/ilt_patches"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Patch geometry, in raster pixels.
PATCH_SIZE = 48         # side of the design patch
TARGET_SIZE = 24        # side of the central target region
STRIDE = 12             # step between neighbouring patch origins
INNER_MARGIN = 24       # border of each sample that is never sampled
EMPTY_THRESHOLD = 0.01  # patches whose center mean is below this are dropped

In [None]:
# Cell 2 – Load raw samples
# glob.glob returns matches in arbitrary order; sorting makes the sample order
# (and therefore the order of the extracted patches) reproducible between runs.
files = sorted(glob.glob(RAW_NPZ_PATH))
all_samples = []
for f in files:
    # NOTE(security): allow_pickle=True unpickles object arrays, which can run
    # arbitrary code. Only load archives from a trusted source.
    # The context manager closes the underlying zip file once the arrays are read.
    with np.load(f, allow_pickle=True) as data:
        if 'samples' in data:
            samples = data['samples'].tolist()
        else:
            # Fall back to the first stored array. `data.files` is available on
            # every NumPy version and avoids loading all arrays just to pick one.
            samples = data[data.files[0]]
    all_samples.extend(samples)
print(f"Loaded {len(all_samples)} raw ILT samples.")

In [None]:
# Cell 3 – Fast simulation functions (pure PyTorch)
def gaussian_blur_torch(x, sigma):
    """Blur a 2-D tensor with a separable Gaussian of standard deviation ``sigma``.

    The 1-D kernel is truncated at ``int(4 * sigma)`` taps on each side,
    normalised to sum to one, and applied along the last axis and then along
    the first axis. Borders use the zero padding of ``conv2d(padding='same')``,
    so values near the edge are attenuated.

    Args:
        x: 2-D tensor ``(H, W)``. The kernel is created on ``x``'s device and,
            for floating-point inputs, in ``x``'s dtype.
        sigma: Gaussian standard deviation in pixels.

    Returns:
        A tensor with the same shape as ``x``. When ``sigma < 0.3`` the input
        is returned unchanged (the same object, not a copy).
    """
    if sigma < 0.3:
        return x

    # sigma >= 0.3 guarantees int(4 * sigma) >= 1, i.e. at least a 3-tap kernel.
    radius = int(4 * sigma)
    # Follow the input's device/dtype rather than a module-level device, so the
    # function also works on tensors that live elsewhere (e.g. CPU tensors in a
    # session whose default device is MPS) or that are float64.
    dtype = x.dtype if x.is_floating_point() else torch.float32
    taps = torch.arange(-radius, radius + 1, dtype=dtype, device=x.device)
    kernel = torch.exp(-0.5 * (taps / sigma) ** 2)
    kernel = kernel / kernel.sum()

    # conv2d expects (N, C, H, W); add and later remove the two leading axes.
    x = x.unsqueeze(0).unsqueeze(0)
    x = torch.nn.functional.conv2d(x, kernel.view(1, 1, 1, -1), padding='same')
    x = torch.nn.functional.conv2d(x, kernel.view(1, 1, -1, 1), padding='same')
    return x.squeeze(0).squeeze(0)

def simple_resist(aerial, thresh=0.25, steep=20.0, diff=2.8, load_sigma=16.0, gain=1.8):
    """Map an aerial image to a resist response with a blurred logistic threshold.

    The aerial image is blurred twice with ``gaussian_blur_torch``: once with
    sigma ``diff`` and once with sigma ``load_sigma``. The second blur is added
    with a fixed weight of 0.20, the sum is passed through a logistic function
    centred on ``thresh`` with slope ``steep``, scaled by ``gain`` and clamped
    to ``[0, 2]``.

    Args:
        aerial: 2-D aerial-image tensor.
        thresh: Level of the combined signal at which the logistic equals 0.5.
        steep: Slope of the logistic.
        diff: Sigma of the short-range blur.
        load_sigma: Sigma of the long-range blur.
        gain: Multiplier applied to the logistic output before clamping.

    Returns:
        Tensor with the same shape as ``aerial`` and values in ``[0, 2]``.
    """
    short_range = gaussian_blur_torch(aerial, diff)
    long_range = gaussian_blur_torch(aerial, load_sigma)
    exponent = -steep * ((short_range + 0.20 * long_range) - thresh)
    response = 1.0 / (1.0 + torch.exp(exponent))
    return torch.clamp(response * gain, 0.0, 2.0)

def pad_mask_for_pbc(mask, pad_ratio=0.3):
    """Embed ``mask`` in a larger zero-filled canvas with a symmetric border.

    The border is ``int(H * pad_ratio)`` rows at the top and bottom and
    ``int(W * pad_ratio)`` columns at the left and right, so
    ``unpad(padded, pad_h, pad_w)`` recovers the original region. Non-square
    masks are supported: the canvas is ``(H + 2*pad_h, W + 2*pad_w)`` instead
    of a square sized from the height alone.

    Args:
        mask: 2-D tensor ``(H, W)``.
        pad_ratio: Border width as a fraction of the corresponding mask side.

    Returns:
        Tuple ``(padded, pad_h, pad_w, padded_size)``. ``padded`` lives on
        ``mask``'s device and uses torch's default dtype. ``padded_size`` is
        the padded height ``H + 2*pad_h``, which equals the padded width for
        square masks.
    """
    H, W = mask.shape
    pad_h = int(H * pad_ratio)
    pad_w = int(W * pad_ratio)
    padded_size = H + 2 * pad_h
    # Size the width from W; using the height for both axes would crash or pad
    # asymmetrically whenever H != W.
    padded = torch.zeros((padded_size, W + 2 * pad_w), device=mask.device)
    padded[pad_h:pad_h + H, pad_w:pad_w + W] = mask
    return padded, pad_h, pad_w, padded_size

def unpad(tensor, pad_h, pad_w):
    """Crop ``pad_h`` rows from top and bottom and ``pad_w`` columns from each side.

    Args:
        tensor: 2-D tensor.
        pad_h: Number of rows to remove at each of the top and bottom edges.
        pad_w: Number of columns to remove at each of the left and right edges.

    Returns:
        The central region of ``tensor`` (a view, not a copy).
    """
    n_rows, n_cols = tensor.shape
    keep_rows = slice(pad_h, n_rows - pad_h)
    keep_cols = slice(pad_w, n_cols - pad_w)
    return tensor[keep_rows, keep_cols]

def simulate_aerial_socs(mask, kernels_freq, eigenvalues):
    """Compute a clear-field-normalised aerial image as a weighted sum of kernel fields.

    For every kernel ``k`` the mask spectrum is multiplied by
    ``kernels_freq[:, :, k]``, transformed back with an inverse FFT, and the
    squared magnitude is accumulated with weight ``eigenvalues[k]``. The same
    sum is evaluated for an all-ones mask, and the image is divided by the mean
    of that clear-field result.

    Args:
        mask: Mask tensor; a 2-D input gets a leading axis of size 1 added.
        kernels_freq: Frequency-domain kernels indexed as ``[:, :, k]``
            (presumably ``(H, W, K)`` matching the mask size; confirm against
            the kernel precompute code).
        eigenvalues: 1-D weights; ``eigenvalues.shape[0]`` kernels are used.

    Returns:
        Non-negative intensity tensor with a leading axis of size 1 squeezed
        away.
    """
    if mask.dim() == 2:
        mask = mask.unsqueeze(0)
    num_kernels = eigenvalues.shape[0]

    def weighted_intensity(spectrum):
        # Accumulate eigenvalue-weighted |field|^2 over all kernels.
        intensity = torch.zeros_like(spectrum, dtype=torch.float)
        for idx in range(num_kernels):
            amplitude = torch.fft.ifft2(spectrum * kernels_freq[:, :, idx])
            intensity += eigenvalues[idx] * (amplitude.abs() ** 2)
        return intensity

    aerial = weighted_intensity(torch.fft.fft2(mask))
    clear_field = weighted_intensity(torch.fft.fft2(torch.ones_like(mask)))

    # The accumulators are already real-valued; the small constant guards
    # against division by zero.
    aerial /= (clear_field.mean() + 1e-10)
    return aerial.squeeze(0).clamp_(0)

In [None]:
# Cell 4 – Precompute kernels (run once)
GRID_SIZE, PAD_RATIO = 256, 0.3

# Padded grid side: the grid plus a border of int(GRID_SIZE * PAD_RATIO) on each
# side, the same arithmetic pad_mask_for_pbc applies to a mask.
pad_px = int(GRID_SIZE * PAD_RATIO)
padded_size = GRID_SIZE + 2 * pad_px

# NOTE: create_pupil_and_freq, create_source_points and precompute_socs_kernels
# are not defined in the cells above; define them before running this cell.
pupil, FX, FY = create_pupil_and_freq(padded_size)
source_pts = create_source_points()
kernels_freq, eigenvalues = precompute_socs_kernels(
    source_pts, pupil, FX, FY, num_kernels=40
)
print("Kernels ready.")

In [None]:
# Cell 5 – Ultra-fast patch extraction (numpy + batched write)
#
# Slides a PATCH_SIZE window over every sample (skipping INNER_MARGIN pixels at
# each border), drops windows whose design center is nearly empty, and writes:
#   ilt_patches_design_only.h5 : design_48 (design patch), design_24 (its
#                                center), ilt_24 (ILT mask at the same center)
#   ilt_patches_3ch.h5         : patches_48 (design / aerial / resist stacked)
# Patches are buffered in RAM and appended to the datasets in batches.

FLUSH_EVERY = 512  # patches buffered in RAM between HDF5 writes
# Offset of the TARGET_SIZE center inside a patch (12 for a 24 center in a 48 patch).
CENTER_OFFSET = (PATCH_SIZE - TARGET_SIZE) // 2
CENTER_END = CENTER_OFFSET + TARGET_SIZE


def _flush_batches(outputs, start):
    """Append each buffered batch to its dataset at row ``start`` and empty it.

    ``outputs`` is an iterable of ``(dataset, rows)`` pairs whose ``rows`` lists
    all have the same length. Returns the row index following the last row
    written.
    """
    end = start
    for dataset, rows in outputs:
        block = np.array(rows)
        end = start + len(block)
        dataset.resize((end,) + dataset.shape[1:])
        dataset[start:end] = block
        rows.clear()
    return end


count = 0    # patches accepted so far
written = 0  # rows already flushed to disk; written + len(batch) == count
batch_d48, batch_i24, batch_ilt, batch_3ch = [], [], [], []

# The context managers close both files even if extraction raises midway, so a
# failed run does not leave open HDF5 handles behind.
with h5py.File(os.path.join(OUTPUT_DIR, "ilt_patches_design_only.h5"), 'w') as f_d, \
        h5py.File(os.path.join(OUTPUT_DIR, "ilt_patches_3ch.h5"), 'w') as f3:
    d48_d = f_d.create_dataset(
        "design_48", (0, PATCH_SIZE, PATCH_SIZE),
        maxshape=(None, PATCH_SIZE, PATCH_SIZE),
        dtype='float32', compression="gzip", chunks=(256, PATCH_SIZE, PATCH_SIZE))
    i24_d = f_d.create_dataset(
        "design_24", (0, TARGET_SIZE, TARGET_SIZE),
        maxshape=(None, TARGET_SIZE, TARGET_SIZE),
        dtype='float32', compression="gzip", chunks=(256, TARGET_SIZE, TARGET_SIZE))
    ilt24 = f_d.create_dataset(
        "ilt_24", (0, TARGET_SIZE, TARGET_SIZE),
        maxshape=(None, TARGET_SIZE, TARGET_SIZE),
        dtype='float32', compression="gzip", chunks=(256, TARGET_SIZE, TARGET_SIZE))
    d48_3 = f3.create_dataset(
        "patches_48", (0, 3, PATCH_SIZE, PATCH_SIZE),
        maxshape=(None, 3, PATCH_SIZE, PATCH_SIZE),
        dtype='float32', compression="gzip", chunks=(128, 3, PATCH_SIZE, PATCH_SIZE))

    outputs = (
        (d48_d, batch_d48),
        (i24_d, batch_i24),
        (ilt24, batch_ilt),
        (d48_3, batch_3ch),
    )

    for s in tqdm(all_samples, desc="Extracting patches"):
        design = s['target_raster'].astype(np.float32)
        ilt    = s['ilt_mask'].astype(np.float32)
        # Samples without stored aerial/resist images get all-zero channels.
        aerial = s.get('aerial', np.zeros_like(design)).astype(np.float32)
        resist = s.get('resist', np.zeros_like(design)).astype(np.float32)

        # Take the raster size from the sample instead of assuming 256x256.
        height, width = design.shape
        y_stop = height - PATCH_SIZE - INNER_MARGIN + 1
        x_stop = width - PATCH_SIZE - INNER_MARGIN + 1

        for y in range(INNER_MARGIN, y_stop, STRIDE):
            for x in range(INNER_MARGIN, x_stop, STRIDE):
                d48 = design[y:y + PATCH_SIZE, x:x + PATCH_SIZE]
                center = d48[CENTER_OFFSET:CENTER_END, CENTER_OFFSET:CENTER_END]
                if center.mean() < EMPTY_THRESHOLD:
                    continue

                # ILT mask over the same region as `center`.
                i24 = ilt[y + CENTER_OFFSET:y + CENTER_END,
                          x + CENTER_OFFSET:x + CENTER_END]

                batch_d48.append(d48)
                batch_i24.append(center)
                batch_ilt.append(i24)
                batch_3ch.append(np.stack([
                    d48,
                    aerial[y:y + PATCH_SIZE, x:x + PATCH_SIZE],
                    resist[y:y + PATCH_SIZE, x:x + PATCH_SIZE],
                ]))
                count += 1

                if len(batch_d48) >= FLUSH_EVERY:
                    written = _flush_batches(outputs, written)

    # Write whatever is left in the buffers after the last full batch.
    if batch_d48:
        written = _flush_batches(outputs, written)

print(f"\n✅ Finished! {count:,} patches written.")

In [None]:
# Cell 6 – Interactive Browser (on-the-fly resist)
# TODO: the interactive browser widget code is not included in this notebook export; paste it into this cell before running.