# Microlithography ML – Aerial Patch → Resist Area Fraction Class (11 bins)

**Educational notebook** for lithographers learning ML

Predict the **developed area fraction class** (0.0–0.1 / … / 0.9–1.0) from a 48×48 aerial-image patch.

**Critical focus**: bins **4–7** (≈0.4–0.8) – near resist threshold

**Models** (updated for speed + generalization):
- Linear
- MLP
- SmallCNN
- **FastResNet** ← new lightweight CNN (no dilation, SE attention); targets ~3× faster inference and +4–7% accuracy on the critical bins — verify with the speed test and the bins 4–7 table below

In [None]:
# ────────────────────────────────────────────────
#  Imports & Device
# ────────────────────────────────────────────────

import os
import time
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import torchvision.transforms as T

# Device & reproducibility: prefer the Apple-silicon GPU (MPS) when present,
# then seed NumPy and torch (and the MPS generator) with one fixed seed.
_has_mps = torch.backends.mps.is_available()
device = torch.device("mps") if _has_mps else torch.device("cpu")
print(f"Device: {device}")
print(f"PyTorch: {torch.__version__}")

SEED = 2025
np.random.seed(SEED)
torch.manual_seed(SEED)
if _has_mps:
    torch.mps.manual_seed(SEED)

In [None]:
# ────────────────────────────────────────────────
#  Load data
# ────────────────────────────────────────────────

DATA_PATH = "data/litho_dataset_sampled_patches.npz"   # adjust if needed

# Open the .npz archive in a context manager so its file handle is released.
# .astype() returns copies, so X and y remain valid after the archive closes.
with np.load(DATA_PATH) as data:
    print("Keys:", list(data.keys()))
    X = data['patches'].astype(np.float32)
    y = data['area_class'].astype(np.int64)               # ← target

N = len(y)
print(f"\nSamples: {N:,}   Classes: {np.unique(y)}")
# minlength=11 keeps one histogram entry per class, including classes with 0 samples.
print("Distribution:\n", np.bincount(y, minlength=11))

In [None]:
# ────────────────────────────────────────────────
#  Stratified split 80/10/10
# ────────────────────────────────────────────────

# First carve off 10 % as test, then take ~1/9 of the remaining 90 % as
# validation (≈10 % of the total), stratifying on the class label each time.
_all_idx = np.arange(N)
trainval_idx, test_idx = train_test_split(
    _all_idx, test_size=0.10, stratify=y, random_state=SEED
)
train_idx, val_idx = train_test_split(
    trainval_idx, test_size=0.1111, stratify=y[trainval_idx], random_state=SEED
)

for _split_name, _split_idx in (("Train", train_idx), ("Val  ", val_idx), ("Test ", test_idx)):
    print(f"{_split_name} {len(_split_idx):>6,} ({len(_split_idx)/N:.1%})")

In [None]:
# ────────────────────────────────────────────────
#  Dataset + Augmentation (key for generalization)
# ────────────────────────────────────────────────

class LithoPatchDataset(Dataset):
    """Index-subset view over the patch array that yields model-ready samples.

    Each item is a dict with:
      'image' – float32 tensor of shape (1, H, W),
      'flat'  – the same pixels flattened to 1-D (input for Linear / MLP models),
      'label' – int64 scalar tensor holding the class from ``y``.

    Args:
        idx: indices into ``X`` / ``y`` that form this split.
        X: patch array; each ``X[i]`` is either (H, W) or channel-first (C, H, W),
           in which case only the first channel is kept.
        y: integer class labels aligned with ``X``.
        augment: apply random augmentation (intended for the training split only).
    """

    def __init__(self, idx, X, y, augment=False):
        self.idx = idx
        self.X = X
        self.y = y
        self.augment = augment
        # (min, max) sigma of the light 3×3 Gaussian blur applied when augmenting.
        self.blur_sigma = (0.1, 0.8)
        # Multiplicative intensity jitter: scale = base + span * U[0, 1).
        self.intensity_jitter = (0.97, 0.06)

    def __len__(self):
        return len(self.idx)

    def _augment(self, patch):
        """Return a randomly transformed copy of a (1, H, W) patch.

        Geometry is limited to the symmetries of the square (flips and 90° rotations).
        These only permute pixels, so the area fraction and therefore the class label
        are exactly preserved. Small-angle rotation, translation and rescaling with zero
        fill are deliberately not used: they inject or drop area and can push a sample
        across a 0.1-wide bin edge while its label stays the same. The blur and the
        ±3 % intensity jitter are small photometric perturbations.
        """
        if torch.rand(1).item() < 0.5:
            patch = patch.flip(-1)
        if torch.rand(1).item() < 0.5:
            patch = patch.flip(-2)
        # 90° rotations only keep the shape for square patches.
        if patch.shape[-1] == patch.shape[-2]:
            k = int(torch.randint(0, 4, (1,)).item())
            if k:
                patch = torch.rot90(patch, k, dims=(-2, -1))

        # 3×3 Gaussian blur with a random sigma; reflect padding avoids dark borders.
        lo, hi = self.blur_sigma
        sigma = lo + (hi - lo) * torch.rand(1).item()
        offsets = torch.arange(-1.0, 2.0)
        k1d = torch.exp(-offsets ** 2 / (2 * sigma ** 2))
        k1d = k1d / k1d.sum()
        kernel = torch.outer(k1d, k1d).view(1, 1, 3, 3)
        padded = F.pad(patch.unsqueeze(0), (1, 1, 1, 1), mode='reflect')
        patch = F.conv2d(padded, kernel).squeeze(0)

        # Uniform scale in [0.97, 1.03), so the mean is 1.0. The previous randn draw
        # had mean 0.97 (a systematic dimming) and an unbounded spread.
        base, span = self.intensity_jitter
        patch = patch * (base + span * torch.rand(1))
        # Clamp assumes intensities are normalised to [0, 1] (as the original did).
        return torch.clamp(patch, 0.0, 1.0)

    def __getitem__(self, i):
        ii = self.idx[i]
        patch = self.X[ii]
        label = self.y[ii]

        if patch.ndim == 2:            # (H, W) → (1, H, W)
            patch = patch[None]
        elif patch.shape[0] != 1:      # multi-channel: keep only the first channel
            patch = patch[:1]

        patch = torch.from_numpy(patch).float()

        if self.augment:
            patch = self._augment(patch)

        return {
            'flat': patch.flatten(),
            'image': patch,
            'label': torch.tensor(label, dtype=torch.long),
        }


# One dataset per split; only the training split is augmented.
train_ds, val_ds, test_ds = (
    LithoPatchDataset(train_idx, X, y, augment=True),
    LithoPatchDataset(val_idx, X, y, augment=False),
    LithoPatchDataset(test_idx, X, y, augment=False),
)

print(f"Datasets ready:  {len(train_ds):,}  {len(val_ds):,}  {len(test_ds):,}")

In [None]:
# ────────────────────────────────────────────────
#  DataLoaders – safe for macOS
# ────────────────────────────────────────────────

# Bigger batches on MPS; evaluation loaders use twice the training batch size.
BATCH_SIZE = 512 if device.type == "mps" else 256
_EVAL_BATCH = BATCH_SIZE * 2
# Load in the main process (num_workers=0), as the cell header marks this "safe for macOS".
_loader_opts = dict(num_workers=0, pin_memory=False)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, **_loader_opts)
val_loader = DataLoader(val_ds, batch_size=_EVAL_BATCH, shuffle=False, **_loader_opts)
test_loader = DataLoader(test_ds, batch_size=_EVAL_BATCH, shuffle=False, **_loader_opts)

print(f"Batch size: {BATCH_SIZE} (×2 for val/test)")

In [None]:
# ────────────────────────────────────────────────
#  Models (fast + modern)
# ────────────────────────────────────────────────

class LinearClassifier(nn.Module):
    """Single affine layer mapping a flattened 48×48 patch (2304 values) to 11 class logits."""

    def __init__(self):
        super().__init__()
        n_pixels = 48 * 48
        self.fc = nn.Linear(n_pixels, 11)

    def forward(self, x):
        logits = self.fc(x)
        return logits


class SmallMLP(nn.Module):
    """Fully connected classifier 2304 → 512 → 128 → 32 → 11 with ReLU and dropout.

    Args:
        p: dropout probability after the first two hidden layers; the last hidden
           layer uses half of it.
    """

    def __init__(self, p=0.3):
        super().__init__()
        widths = [48 * 48, 512, 128, 32]
        drop_rates = [p, p, p * 0.5]
        layers = []
        for n_in, n_out, rate in zip(widths[:-1], widths[1:], drop_rates):
            layers += [nn.Linear(n_in, n_out), nn.ReLU(), nn.Dropout(rate)]
        layers.append(nn.Linear(widths[-1], 11))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        out = self.net(x)
        return out


class SmallCNN(nn.Module):
    """Three-conv baseline: conv5-pool → conv5-pool → conv3, global average pool, dropout, linear head.

    Args:
        p: dropout probability before the final linear layer.
    """

    def __init__(self, p=0.25):
        super().__init__()

        def conv_relu(c_in, c_out, k):
            # "same" padding for odd kernel sizes
            return [nn.Conv2d(c_in, c_out, k, padding=k // 2), nn.ReLU()]

        self.features = nn.Sequential(
            *conv_relu(1, 24, 5), nn.MaxPool2d(2),
            *conv_relu(24, 48, 5), nn.MaxPool2d(2),
            *conv_relu(48, 96, 3),
        )
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.drop = nn.Dropout(p)
        self.head = nn.Linear(96, 11)

    def forward(self, x):
        pooled = torch.flatten(self.gap(self.features(x)), 1)
        return self.head(self.drop(pooled))


class _ResidualBlock(nn.Module):
    """Basic residual block: (conv3×3-BN-ReLU-conv3×3-BN) + shortcut, followed by ReLU.

    The shortcut is the identity when shape is unchanged. Otherwise it is a strided
    1×1 conv + BN projection, so both branches have matching shapes when summed.
    """

    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_ch),
        )
        if stride != 1 or in_ch != out_ch:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_ch),
            )
        else:
            self.shortcut = nn.Identity()
        self.act = nn.ReLU(inplace=True)

    def forward(self, x):
        # The sum is a fresh tensor, so the in-place ReLU never touches x.
        return self.act(self.body(x) + self.shortcut(x))


class FastResNet(nn.Module):
    """Lightweight residual CNN: stem, three strided residual stages, SE attention, MLP head.

    Earlier revisions called this a residual CNN, but each stage was a plain conv stack
    with no skip connection. Each stage is now a real residual block with a projection
    shortcut. No dilated convolutions are used. Speed/accuracy advantages over SmallCNN
    are claims to verify with the notebook's speed test and bins 4–7 metrics.

    Args:
        num_classes: number of output logits.
        dropout: dropout before the first head layer; 0.6× this before the last.

    Input is (N, 1, H, W). For 48×48 patches the stages produce 24×24, 12×12 and 6×6
    maps. The adaptive pooling makes the head independent of the input size.
    """

    def __init__(self, num_classes=11, dropout=0.25):
        super().__init__()

        self.stem = nn.Sequential(
            nn.Conv2d(1, 48, 3, padding=1, bias=False),
            nn.BatchNorm2d(48),
            nn.ReLU(inplace=True),
        )

        self.stage1 = _ResidualBlock(48, 96, stride=2)    # 48→24
        self.stage2 = _ResidualBlock(96, 192, stride=2)   # 24→12
        self.stage3 = _ResidualBlock(192, 384, stride=2)  # 12→6

        # Squeeze-and-excitation: per-channel gates in (0, 1) computed from global context.
        self.attn = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(384, 384 // 8, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384 // 8, 384, 1),
            nn.Sigmoid(),
        )

        self.head = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Flatten(),
            nn.Dropout(dropout),
            nn.Linear(384, 96),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout * 0.6),
            nn.Linear(96, num_classes),
        )

    def forward(self, x):
        x = self.stem(x)
        x = self.stage1(x)
        x = self.stage2(x)
        x = self.stage3(x)

        x = x * self.attn(x)   # channel re-weighting, broadcast over H×W

        return self.head(x)


def count_params(m):
    """Return how many scalar weights in module *m* are trainable (requires_grad=True)."""
    total = 0
    for param in m.parameters():
        if param.requires_grad:
            total += param.numel()
    return total


# Print the trainable-parameter count of a freshly built instance of each model.
print("Models:")
for _prefix, _factory, _suffix in (
    ("  Linear       → ", LinearClassifier, ""),
    ("  SmallMLP     → ", SmallMLP, ""),
    ("  SmallCNN     → ", SmallCNN, ""),
    ("  FastResNet   → ", FastResNet, "  ← recommended"),
):
    print(f"{_prefix}{count_params(_factory()):,}{_suffix}")

In [None]:
# ────────────────────────────────────────────────
#  Training loop (AdamW + OneCycleLR + label smoothing)
# ────────────────────────────────────────────────

def train_model(model, train_loader, val_loader, epochs=120, lr=1e-3, patience=12,
                max_lr=None):
    """Train *model* with AdamW + OneCycleLR + label smoothing, early-stopping on val accuracy.

    Args:
        model: a notebook classifier. LinearClassifier / SmallMLP receive the batch's
            'flat' tensor, every other model the 'image' tensor.
        train_loader, val_loader: DataLoaders yielding dicts with 'flat', 'image', 'label'.
        epochs: maximum number of epochs (also sizes the OneCycle schedule).
        lr: base learning rate given to AdamW; also sets the default ``max_lr``.
        patience: stop after this many consecutive epochs without a new best val accuracy.
        max_lr: peak learning rate of the OneCycle schedule. Defaults to ``4 * lr``
            (4e-3 for the default lr, the value that used to be hard-coded).

    Returns:
        (model, hist): the model with its best-validation-accuracy weights restored,
        and per-epoch lists under 'tloss', 'tacc', 'vloss', 'vacc', 'vtop2'.
    """
    model = model.to(device)
    # Linear / MLP consume the flattened patch; convolutional models the (1, H, W) image.
    key = 'flat' if isinstance(model, (LinearClassifier, SmallMLP)) else 'image'

    if max_lr is None:
        max_lr = 4 * lr
    opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    crit = nn.CrossEntropyLoss(label_smoothing=0.1)

    # OneCycleLR overwrites the optimizer's learning rate on every step, so ``lr``
    # affects training only through the max_lr default above.
    scheduler = optim.lr_scheduler.OneCycleLR(
        opt, max_lr=max_lr, epochs=epochs, steps_per_epoch=len(train_loader)
    )

    best_acc = -1
    best_state = None
    wait = 0
    hist = {'tloss': [], 'tacc': [], 'vloss': [], 'vacc': [], 'vtop2': []}

    print(f"\nTraining {model.__class__.__name__} ...")
    print("Epo  TrainLoss  TrainAcc   ValLoss   ValAcc   ValTop2   Time")

    for ep in range(1, epochs + 1):
        t0 = time.time()

        model.train()
        tloss = tcor = ttot = 0

        for b in train_loader:
            x = b[key].to(device)
            y = b['label'].to(device)

            opt.zero_grad()
            # NOTE(review): float16 autocast is used without a GradScaler, which can
            # underflow small gradients; some PyTorch versions also lack fp16 autocast on
            # CPU and fall back to float32 with a warning. Kept as before — verify on target.
            with torch.autocast(device_type=device.type, dtype=torch.float16):
                out = model(x)
                loss = crit(out, y)

            loss.backward()
            opt.step()
            scheduler.step()

            tloss += loss.item() * y.size(0)
            tcor += out.argmax(1).eq(y).sum().item()
            ttot += y.size(0)

        tloss /= ttot
        tacc = 100 * tcor / ttot

        model.eval()
        vloss = vcor = vtop2 = vtot = 0

        with torch.no_grad(), torch.autocast(device_type=device.type, dtype=torch.float16):
            for b in val_loader:
                x = b[key].to(device)
                y = b['label'].to(device)
                out = model(x)
                loss = crit(out, y)

                vloss += loss.item() * y.size(0)
                pred = out.argmax(1)
                vcor += pred.eq(y).sum().item()
                vtot += y.size(0)

                _, t2 = out.topk(2, 1)
                vtop2 += t2.eq(y.view(-1, 1).expand_as(t2)).sum().item()

        vloss /= vtot
        vacc = 100 * vcor / vtot
        vtop2acc = 100 * vtop2 / vtot

        print(f"{ep:2d}   {tloss:.4f}   {tacc:6.2f}%   {vloss:.4f}   {vacc:6.2f}%   {vtop2acc:6.2f}%   {time.time()-t0:.1f}s")

        hist['tloss'].append(tloss)
        hist['tacc'].append(tacc)
        hist['vloss'].append(vloss)
        hist['vacc'].append(vacc)
        hist['vtop2'].append(vtop2acc)

        if vacc > best_acc:
            best_acc = vacc
            # state_dict() returns references to the live tensors. A shallow dict copy
            # would keep changing as training continues, so clone every tensor.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                print(f"Early stop (best val acc {best_acc:.2f}%)")
                break

    # Defensive: restore only if a best snapshot was actually recorded.
    if best_state is not None:
        model.load_state_dict(best_state)
    return model, hist

In [None]:
# ────────────────────────────────────────────────
#  Evaluation
# ────────────────────────────────────────────────

@torch.no_grad()
def evaluate(model, loader):
    """Score *model* on every batch of *loader* without tracking gradients.

    Returns a dict with 'acc' (top-1 accuracy, %), 'top2' (top-2 accuracy, %) and
    NumPy arrays 'pred' / 'true' with per-sample predicted and true class ids.
    """
    model.eval()
    input_key = 'flat' if isinstance(model, (LinearClassifier, SmallMLP)) else 'image'
    all_pred, all_true = [], []
    top2_hits = 0
    n_seen = 0

    with torch.autocast(device_type=device.type, dtype=torch.float16):
        for batch in loader:
            inputs = batch[input_key].to(device)
            labels = batch['label'].to(device)
            logits = model(inputs)

            all_pred.extend(logits.argmax(1).cpu().numpy())
            all_true.extend(labels.cpu().numpy())

            top2_idx = logits.topk(2, 1).indices
            top2_hits += (top2_idx == labels.view(-1, 1)).sum().item()
            n_seen += labels.size(0)

    pred_arr = np.array(all_pred)
    true_arr = np.array(all_true)
    return {
        'acc': 100 * np.mean(pred_arr == true_arr),
        'top2': 100 * top2_hits / n_seen,
        'pred': pred_arr,
        'true': true_arr,
    }

In [None]:
# ────────────────────────────────────────────────
#  Train & Evaluate – Linear
# ────────────────────────────────────────────────

# Train the linear baseline, then score it on the held-out test split.
linear, linear_h = train_model(LinearClassifier(), train_loader, val_loader)
lin_test = evaluate(linear, test_loader)

print(f"\nLinear  test acc: {lin_test['acc']:.2f}%   top-2: {lin_test['top2']:.2f}%")

In [None]:
# ────────────────────────────────────────────────
#  Train & Evaluate – SmallMLP
# ────────────────────────────────────────────────

# Train the MLP baseline, then score it on the held-out test split.
mlp, mlp_h = train_model(SmallMLP(), train_loader, val_loader)
mlp_test = evaluate(mlp, test_loader)

print(f"\nSmallMLP  test acc: {mlp_test['acc']:.2f}%   top-2: {mlp_test['top2']:.2f}%")

In [None]:
# ────────────────────────────────────────────────
#  Train & Evaluate – SmallCNN
# ────────────────────────────────────────────────

# Train the small CNN, then score it on the held-out test split.
small_cnn, small_cnn_h = train_model(SmallCNN(), train_loader, val_loader)
small_cnn_test = evaluate(small_cnn, test_loader)

print(f"\nSmallCNN  test acc: {small_cnn_test['acc']:.2f}%   top-2: {small_cnn_test['top2']:.2f}%")

In [None]:
# ────────────────────────────────────────────────
#  Train & Evaluate – FastResNet (the new champion)
# ────────────────────────────────────────────────

# Train FastResNet, then score it on the held-out test split.
fast_res, fast_res_h = train_model(FastResNet(), train_loader, val_loader)
fast_res_test = evaluate(fast_res, test_loader)

print(f"\nFastResNet  test acc: {fast_res_test['acc']:.2f}%   top-2: {fast_res_test['top2']:.2f}%")

In [None]:
# ────────────────────────────────────────────────
#  Summary Table
# ────────────────────────────────────────────────

# One row per trained model. The "← best" tag goes to the highest test accuracy
# instead of being hard-coded to a single model.
_summary_rows = [
    ("Linear",     linear,    lin_test),
    ("SmallMLP",   mlp,       mlp_test),
    ("SmallCNN",   small_cnn, small_cnn_test),
    ("FastResNet", fast_res,  fast_res_test),
]
_best_name = max(_summary_rows, key=lambda row: row[2]['acc'])[0]

print("\nFinal results")
print("Model            Params         Test acc   Top-2")
print("────────────────────────────────────────────────────")
for _name, _model, _res in _summary_rows:
    _tag = "  ← best" if _name == _best_name else ""
    print(f"{_name:<15}{count_params(_model):>10,}   {_res['acc']:>8.2f}%   {_res['top2']:>5.2f}%{_tag}")

In [None]:
# ────────────────────────────────────────────────
#  Learning Curves
# ────────────────────────────────────────────────

def plot_curves(histories, names):
    """Plot per-epoch accuracy (left) and loss (right) for several training runs.

    Args:
        histories: list of history dicts as returned by ``train_model``.
        names: display names aligned with *histories*. Train curves are dashed,
            validation curves solid.
    """
    fig, (ax_acc, ax_loss) = plt.subplots(1, 2, figsize=(13,5), sharex=True)
    for hist, label in zip(histories, names):
        epochs = range(1, len(hist['tacc']) + 1)
        ax_acc.plot(epochs, hist['tacc'], '--', label=f'{label} train', alpha=0.6)
        ax_acc.plot(epochs, hist['vacc'], label=f'{label} val')
        ax_loss.plot(epochs, hist['tloss'], '--', label=f'{label} train', alpha=0.6)
        ax_loss.plot(epochs, hist['vloss'], label=f'{label} val')

    for ax, title in ((ax_acc, 'Accuracy'), (ax_loss, 'Loss')):
        ax.legend()
        ax.set_title(title)
        ax.grid(alpha=0.3)
    plt.tight_layout()
    plt.show()


# Learning curves for all four models, keyed by display name.
_histories = {'Linear': linear_h, 'MLP': mlp_h, 'SmallCNN': small_cnn_h, 'FastResNet': fast_res_h}
plot_curves(list(_histories.values()), list(_histories.keys()))

In [None]:
# ────────────────────────────────────────────────
#  Confusion Matrices
# ────────────────────────────────────────────────

# Pin the label set so every matrix is 11×11 and row/column i always means class i,
# even if a class never occurs (true or predicted) in the test split. Without this,
# confusion_matrix drops absent classes and the axes silently shift.
CLASS_IDS = list(range(11))

fig, axes = plt.subplots(1, 4, figsize=(20, 5))

for ax, res, name in zip(axes.flat, [lin_test, mlp_test, small_cnn_test, fast_res_test],
                         ['Linear', 'MLP', 'SmallCNN', 'FastResNet']):
    cm = confusion_matrix(res['true'], res['pred'], labels=CLASS_IDS)
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax, cbar=False,
                xticklabels=CLASS_IDS, yticklabels=CLASS_IDS)
    ax.set_title(name)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')

plt.tight_layout()
plt.show()

In [None]:
# ────────────────────────────────────────────────
#  Per-class metrics
# ────────────────────────────────────────────────

# Human-readable names for the 11 classes.
# NOTE(review): class 10 is labelled "1.0–1.1" here, while the notebook header lists
# bins 0.0–0.1 … 0.9–1.0; confirm what class 10 represents in the dataset.
names = [f"{i/10:.1f}–{(i+1)/10:.1f}" for i in range(11)]

for name, res in [
    ("Linear", lin_test),
    ("MLP", mlp_test),
    ("SmallCNN", small_cnn_test),
    ("FastResNet", fast_res_test)
]:
    print(f"\n{name}")
    # labels= fixes the report to all 11 classes, so target_names stays aligned (and no
    # ValueError is raised) when a class is absent from the test split. zero_division=0
    # reports 0 instead of warning for classes with no support/predictions.
    print(classification_report(res['true'], res['pred'],
                                labels=list(range(len(names))), target_names=names,
                                digits=3, zero_division=0))

In [None]:
# ────────────────────────────────────────────────
#  Near-threshold accuracy (bins 4–7)
# ────────────────────────────────────────────────

# Classes 4..7 are the near-threshold bins highlighted in the notebook header.
crit_bins = range(4,8)

print("Accuracy in near-threshold bins 4–7 (0.4–0.8):")
_runs = (
    ("Linear", lin_test),
    ("MLP", mlp_test),
    ("SmallCNN", small_cnn_test),
    ("FastResNet", fast_res_test),
)
for name, res in _runs:
    mask = np.isin(res['true'], crit_bins)
    if not mask.any():
        print(f"  {name:15} → no samples")
        continue
    hits = res['pred'][mask] == res['true'][mask]
    acc = 100 * hits.mean()
    print(f"  {name:15} → {acc:.2f}%")

In [None]:
# ────────────────────────────────────────────────
#  Speed test (forward pass)
# ────────────────────────────────────────────────

def _sync_device():
    """Wait for queued GPU work to finish so wall-clock timings cover the actual compute."""
    if device.type == "mps":
        torch.mps.synchronize()
    elif device.type == "cuda":
        torch.cuda.synchronize()


print("\nSpeed comparison (200 forward passes, batch=512):")
for name, m in [('SmallCNN', SmallCNN()), ('FastResNet', FastResNet())]:
    m = m.to(device).eval()
    x = torch.randn(512, 1, 48, 48, device=device)
    # inference_mode: no autograd graph is built, so memory and time reflect pure inference.
    with torch.inference_mode():
        for _ in range(10):        # warm-up: kernel selection / allocation, not timed
            m(x)
        _sync_device()
        t0 = time.perf_counter()
        for _ in range(200):
            m(x)
        _sync_device()             # GPU execution is asynchronous; wait before stopping the clock
        elapsed = time.perf_counter() - t0
    print(f"  {name:12} → {1000*elapsed/200:.1f} ms/batch")

## Summary & Teaching Points

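- **FastResNet** is expected to be the winner — faster and more accurate, especially on bins 4–7; confirm this against the summary table, the near-threshold accuracies and the speed test above.
- **Key upgrades**: augmentation + AdamW + OneCycleLR + label smoothing + residual blocks.
- Dilated convolutions removed → no gridding artifacts, and the model is expected to run faster on MPS.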

Run it and enjoy the speed + accuracy boost!