# ILT Patch Extraction – 48×48 design → 24×24 ILT center

Generates two HDF5 files for training VAE / GAN / Transformer:
- `ilt_patches_design_only.h5` (1 channel)
- `ilt_patches_3ch.h5` (design + aerial + resist)

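Windows are sampled with stride 8 and kept at least 24 px away from every image edge. Windows whose 24×24 design center is (nearly) empty are skipped. The only augmentation is horizontal/vertical flipping, which gives 4 variants per kept window.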

In [None]:
# Cell 1 – Imports & paths
import torch
import numpy as np
import h5py
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display, clear_output
from tqdm.notebook import tqdm
import os
from scipy.ndimage import gaussian_filter
import glob

# Run tensor work on the MPS backend when PyTorch reports it available; otherwise CPU.
if torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(f"Device: {device}")

# ── Edit this path to your raw ILT dataset ─────────────────────
# Glob pattern (or an exact file path) for the raw .npz archives.
RAW_NPZ_PATH = "data/ilt_dataset_v1_m4/*/ilt_dataset_500.npz"
# Both HDF5 outputs are written here; created up front so saving cannot fail on it.
OUTPUT_DIR = "data/ilt_patches"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Patch geometry, in pixels.
PATCH_SIZE = 48         # side of each input window (design / aerial / resist)
TARGET_SIZE = 24        # side of the centred crop used as the ILT target
STRIDE = 8              # step between neighbouring windows
INNER_MARGIN = 24       # windows stay at least this far from every image edge
# Windows whose centred design crop has a mean below this are treated as empty.
EMPTY_THRESHOLD = 0.01

In [None]:
# Cell 2 – Load all raw ILT samples
# Sorted so that `sample_idx` (recorded in the patch metadata) is reproducible
# between runs – glob() returns files in arbitrary filesystem order.
files = sorted(glob.glob(RAW_NPZ_PATH))
if not files:
    raise FileNotFoundError(
        f"No raw ILT .npz files found matching {RAW_NPZ_PATH!r}. Update RAW_NPZ_PATH."
    )

all_samples = []
for f in files:
    # allow_pickle=True is needed to read object arrays (the per-sample records
    # indexed by string keys later on). Unpickling can execute arbitrary code, so
    # only load archives from a trusted source.
    # The `with` block closes the underlying zip file handle after reading.
    with np.load(f, allow_pickle=True) as data:
        if not data.files:
            print(f"Skipping empty archive: {f}")
            continue
        key = 'samples' if 'samples' in data.files else data.files[0]
        # .tolist() materializes the array before the archive is closed, and
        # turns both 1-D and 0-d object arrays into plain Python objects.
        samples = data[key].tolist()
    if isinstance(samples, dict):  # a single sample stored as a 0-d object array
        samples = [samples]
    all_samples.extend(samples)

print(f"Loaded {len(all_samples)} raw ILT samples.")

In [None]:
# Cell 3 – Fast simulation functions (reuse from previous notebook)
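# TODO: paste the fast simulation helpers from the ILT notebook here
# (gaussian_blur_torch, simple_resist, pad/unpad, simulate_aerial_socs, ...).
# The patch extraction below reads only the precomputed 'aerial' and 'resist'
# arrays stored with each sample. These helpers are needed only for on-the-fly
# simulation, e.g. the browser in Cell 6.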

In [None]:
# Cell 4 – Patch extraction
def _patches_from_sample(s, offset):
    """Slice and flip-augment every non-empty window of one raw sample.

    Args:
        s: one raw sample, indexed by the keys 'target_raster', 'aerial',
            'resist' and 'ilt_mask'. The four rasters are assumed to share
            one 2-D pixel grid (TODO confirm for new datasets). torch.stack
            below raises if the three input rasters differ in shape.
        offset: row/column where the TARGET_SIZE centre crop starts inside a
            PATCH_SIZE window.

    Returns:
        None when no window passes the emptiness filter. Otherwise a tuple
        (inputs, targets, coords):
          inputs  – float32 ndarray [n, 3, PATCH_SIZE, PATCH_SIZE], with
                    channels design, aerial, resist;
          targets – float32 ndarray [n, TARGET_SIZE, TARGET_SIZE] ILT crops;
          coords  – list of n (y, x, flip_h, flip_v) tuples, in the same order.
    """
    # One [3, H, W] stack, so a single slice/flip keeps all input channels aligned.
    inputs = torch.stack([
        torch.from_numpy(np.asarray(s[key])).float()
        for key in ('target_raster', 'aerial', 'resist')
    ]).to(device)
    ilt = torch.from_numpy(np.asarray(s['ilt_mask'])).float().to(device)

    H, W = inputs.shape[-2:]
    kept_inputs, kept_targets, coords = [], [], []
    # Window starts stay INNER_MARGIN away from the top/left edges, and window
    # ends stay INNER_MARGIN away from the bottom/right edges.
    for y in range(INNER_MARGIN, H - PATCH_SIZE - INNER_MARGIN + 1, STRIDE):
        for x in range(INNER_MARGIN, W - PATCH_SIZE - INNER_MARGIN + 1, STRIDE):
            patch = inputs[:, y:y + PATCH_SIZE, x:x + PATCH_SIZE]
            centre = patch[0, offset:offset + TARGET_SIZE, offset:offset + TARGET_SIZE]
            if centre.mean().item() < EMPTY_THRESHOLD:
                continue  # design centre is (near-)empty
            target = ilt[y + offset:y + offset + TARGET_SIZE,
                         x + offset:x + offset + TARGET_SIZE]
            for flip_h in (False, True):
                for flip_v in (False, True):
                    # -1 = columns (horizontal flip), -2 = rows (vertical flip).
                    # Applying identical dims to inputs and target keeps them registered.
                    dims = [d for d, on in ((-1, flip_h), (-2, flip_v)) if on]
                    kept_inputs.append(patch.flip(dims) if dims else patch)
                    kept_targets.append(target.flip(dims) if dims else target)
                    coords.append((y, x, flip_h, flip_v))

    if not coords:
        return None
    # One device->host transfer per sample instead of one per patch.
    # torch.stack copies, so the results do not keep the full rasters alive.
    return (
        torch.stack(kept_inputs).cpu().numpy(),
        torch.stack(kept_targets).cpu().numpy(),
        coords,
    )


def extract_patches():
    """Cut PATCH_SIZE input windows and their TARGET_SIZE ILT targets from all samples.

    Iterates over the module-level ``all_samples``. Windows are taken with step
    STRIDE, at least INNER_MARGIN px from every image edge. Windows whose centred
    design crop has mean < EMPTY_THRESHOLD are skipped. Every kept window is
    emitted 4 times: identity, horizontal flip, vertical flip, and both. The flip
    is applied identically to design, aerial, resist and the ILT target.

    Returns:
        design_only: float32 ndarray [N, PATCH_SIZE, PATCH_SIZE] (design channel).
        three_ch:    float32 ndarray [N, 3, PATCH_SIZE, PATCH_SIZE]
                     (design, aerial, resist).
        metadata:    list of N dicts with keys 'sample_idx', 'patch_y',
                     'patch_x', 'structure', 'flip_h', 'flip_v', plus
                     'design_24' and 'ilt_24'. The last two are float32
                     [TARGET_SIZE, TARGET_SIZE] centre crops of the (flipped)
                     design and ILT mask; 'ilt_24' is the training target.
        When nothing is extracted, N == 0 and the arrays keep their trailing dims.
    """
    offset = (PATCH_SIZE - TARGET_SIZE) // 2  # 12 for 48 -> 24
    input_batches = []
    metadata = []

    for sample_idx, s in enumerate(tqdm(all_samples, desc="Extracting patches")):
        result = _patches_from_sample(s, offset)
        if result is None:
            continue
        inputs, targets, coords = result
        input_batches.append(inputs)
        for (y, x, flip_h, flip_v), inp, tgt in zip(coords, inputs, targets):
            metadata.append({
                'sample_idx': sample_idx,
                'patch_y': y,
                'patch_x': x,
                'structure': s['structure'],
                'flip_h': flip_h,
                'flip_v': flip_v,
                # .copy() so metadata does not pin the whole per-sample batch.
                'design_24': inp[0, offset:offset + TARGET_SIZE,
                                 offset:offset + TARGET_SIZE].copy(),
                'ilt_24': tgt,
            })

    if input_batches:
        three_ch = np.concatenate(input_batches).astype(np.float32, copy=False)
    else:
        three_ch = np.empty((0, 3, PATCH_SIZE, PATCH_SIZE), dtype=np.float32)
    design_only = three_ch[:, 0].copy()  # contiguous copy of the design channel
    return design_only, three_ch, metadata

# Run the extraction over every loaded sample and report the resulting count.
print("Starting patch extraction...")
design_only, three_ch, meta = extract_patches()
print("Extracted {} patches (after flips and filtering).".format(len(design_only)))

In [None]:
# Cell 5 – Save to two HDF5 files
def save_hdf5(data, filename, meta):
    """Write one patch set, its centre crops and per-patch metadata to HDF5.

    Args:
        data: patch array, either [N, H, W] (design only) or [N, 3, H, W]
            (design, aerial, resist).
        filename: output file name. It is created (or overwritten) inside OUTPUT_DIR.
        meta: list of N per-patch dicts aligned one-to-one with ``data``
            (as returned by extract_patches()).

    Datasets written:
        design_48            – design channel, [N, H, W]
        aerial_48, resist_48 – only for 4-D input
        design_24            – centred TARGET_SIZE×TARGET_SIZE crop of design_48
        ilt_24               – stacked meta['ilt_24'] arrays, written only if
                               every entry has one
        meta/<key>           – one dataset per scalar metadata field present in
                               every entry ('structure' is stored as strings)
    Image datasets are gzip-compressed.

    Raises:
        ValueError: if ``data`` is not 3-D/4-D, holds no patches, or its length
            differs from ``len(meta)``.
    """
    data = np.asarray(data)
    if data.ndim not in (3, 4):
        raise ValueError(f"Expected a 3-D or 4-D patch array, got shape {data.shape}")
    n = data.shape[0]
    if n == 0:
        raise ValueError(f"No patches to save to {filename}")
    if len(meta) != n:
        raise ValueError(f"{len(meta)} metadata entries for {n} patches in {filename}")

    design_48 = data[:, 0] if data.ndim == 4 else data
    off = (design_48.shape[-1] - TARGET_SIZE) // 2
    # Derived from the pixels, so it does not depend on a 'design_24' key in meta.
    design_24 = design_48[:, off:off + TARGET_SIZE, off:off + TARGET_SIZE]

    with h5py.File(os.path.join(OUTPUT_DIR, filename), 'w') as f:
        f.create_dataset('design_48', data=design_48, compression='gzip')
        if data.ndim == 4:  # 3ch
            f.create_dataset('aerial_48', data=data[:, 1], compression='gzip')
            f.create_dataset('resist_48', data=data[:, 2], compression='gzip')
        f.create_dataset('design_24', data=design_24, compression='gzip')

        # The ILT crop is the training target; it is available only through meta.
        if all('ilt_24' in m for m in meta):
            f.create_dataset('ilt_24',
                             data=np.stack([np.asarray(m['ilt_24']) for m in meta]),
                             compression='gzip')
        else:
            print(f"Warning: metadata has no 'ilt_24' arrays - {filename} has no ILT targets")

        # Per-patch bookkeeping, so each row can be traced back to its source window.
        for key in ('sample_idx', 'patch_y', 'patch_x', 'flip_h', 'flip_v'):
            if all(key in m for m in meta):
                f.create_dataset(f'meta/{key}', data=np.array([m[key] for m in meta]))
        if all('structure' in m for m in meta):
            f.create_dataset('meta/structure',
                             data=[str(m['structure']) for m in meta],
                             dtype=h5py.string_dtype())

        print(f"Saved {filename} ({n} patches)")

# Both files share the same per-patch metadata; only the pixel channels differ.
save_hdf5(data=design_only, filename="ilt_patches_design_only.h5", meta=meta)
save_hdf5(data=three_ch, filename="ilt_patches_3ch.h5", meta=meta)
print("Both HDF5 files created in", OUTPUT_DIR)

In [None]:
# Cell 6 – Interactive Browser (with on-the-fly resist simulation)
# TODO: add the widget code here. It should show the 48×48 design, the 24×24 ILT
# target, and the resist simulated from each of them (design and ILT).