# Siamese Network for Wafer Pass/Fail Prediction
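This notebook builds and trains a Siamese neural network that predicts whether wafer alignment succeeds (pass/fail) from a pair of images: a kernel image and an input image. Both images are embedded by a shared encoder, and the model outputs the probability that the pair passes.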

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from PIL import Image
import torchvision.transforms as transforms
from sklearn.metrics import accuracy_score, f1_score
from tqdm import tqdm
import os


In [None]:
# --- Feature Extractor ---
class FeatureExtractor(nn.Module):
    """Shared convolutional encoder turning a 1-channel image into a 128-d vector.

    Three Conv(3x3) -> BatchNorm -> ReLU stages widen the channels 1 -> 32 -> 64
    -> 128. The first two stages are followed by 2x2 max-pooling; the last one by
    global average pooling, so the embedding size does not depend on H and W.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels) for each conv stage, in forward order.
        stages = ((1, 32), (32, 64), (64, 128))
        layers = []
        for stage, (c_in, c_out) in enumerate(stages):
            layers.append(nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.ReLU())
            if stage == len(stages) - 1:
                # Final stage: collapse every spatial map to 1x1.
                layers.append(nn.AdaptiveAvgPool2d((1, 1)))
            else:
                layers.append(nn.MaxPool2d(2))  # halve H and W between stages
        # A single nn.Sequential named ``encoder`` keeps the state_dict keys
        # (encoder.0.weight, encoder.1.running_mean, ...) checkpoint-compatible.
        self.encoder = nn.Sequential(*layers)

    def forward(self, x):
        """Embed a (batch, 1, H, W) tensor into a (batch, 128) tensor."""
        feats = self.encoder(x)
        batch = feats.size(0)
        return feats.view(batch, -1)


In [None]:
# --- Siamese Classifier ---
class SiameseClassifier(nn.Module):
    """Pass/fail classifier for a (kernel image, input image) pair.

    Both images are embedded by the *same* ``FeatureExtractor`` instance (shared
    weights). The kernel embedding and the input embedding are concatenated in
    that order, and an MLP head maps the 256-d result to one probability.
    """

    def __init__(self):
        super().__init__()
        # One encoder reused for both branches -> identical weights per branch.
        self.feature_extractor = FeatureExtractor()
        self.classifier = nn.Sequential(
            nn.Linear(2 * 128, 128),  # two concatenated 128-d embeddings
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, 1),
        )

    def forward(self, kernel_img, input_img):
        """Return a 1-D tensor of probabilities, one entry per pair in the batch.

        The sigmoid is applied here, so the output is a probability rather than a
        logit; use it with ``nn.BCELoss`` (not ``BCEWithLogitsLoss``).
        """
        kernel_emb, input_emb = (
            self.feature_extractor(img) for img in (kernel_img, input_img)
        )
        logits = self.classifier(torch.cat((kernel_emb, input_emb), dim=1))
        return torch.sigmoid(logits).squeeze(1)


In [None]:
# --- Dataset ---
class WaferPairDataset(Dataset):
    """Dataset of ``(kernel_img, input_img, label)`` triples for pair classification.

    Args:
        kernel_paths: Sequence of image file paths for the kernel images.
        input_paths: Sequence of image file paths; ``input_paths[i]`` is paired
            with ``kernel_paths[i]``.
        labels: Sequence of binary labels (the setup cell uses 1 = pass, 0 = fail).
        image_size: Side length every image is resized to (square output).

    Raises:
        ValueError: If the three sequences do not all have the same length.
    """

    def __init__(self, kernel_paths, input_paths, labels, image_size=256):
        # Fail fast: mismatched inputs would otherwise surface later as an
        # IndexError inside a DataLoader worker, or silently drop samples.
        if not (len(kernel_paths) == len(input_paths) == len(labels)):
            raise ValueError(
                "kernel_paths, input_paths and labels must have the same length; "
                f"got {len(kernel_paths)}, {len(input_paths)} and {len(labels)}"
            )
        self.kernel_paths = kernel_paths
        self.input_paths = input_paths
        self.labels = labels
        # Grayscale -> fixed square size -> [0, 1] tensor -> [-1, 1].
        self.transform = transforms.Compose([
            transforms.Grayscale(num_output_channels=1),
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5], std=[0.5])
        ])

    def __len__(self):
        return len(self.labels)

    def _load(self, path):
        """Open one image file, apply ``self.transform``, and close the file."""
        # Image.open is lazy and keeps the file handle open; the context manager
        # guarantees it is released, including for multi-frame formats and when
        # the transform raises.
        with Image.open(path) as img:
            return self.transform(img)

    def __getitem__(self, idx):
        kernel_img = self._load(self.kernel_paths[idx])
        input_img = self._load(self.input_paths[idx])
        label = torch.tensor(self.labels[idx], dtype=torch.float32)
        return kernel_img, input_img, label


In [None]:
# --- Evaluation ---
def evaluate(model, loader, criterion, device):
    """Compute mean loss, accuracy and F1 of ``model`` over the whole ``loader``.

    Args:
        model: Module called as ``model(kernel, input_img)``; its outputs are
            thresholded at 0.5 to get hard 0/1 predictions.
        loader: DataLoader yielding ``(kernel, input_img, labels)`` batches.
        criterion: Loss applied to ``(outputs, labels)``. Assumed to return a
            batch mean (``nn.BCELoss``'s default), hence the re-weighting by
            batch size when accumulating.
        device: Device the batch tensors are moved to.

    Returns:
        Tuple ``(avg_loss, accuracy, f1)`` over every sample in the dataset.

    Raises:
        ValueError: If ``loader.dataset`` is empty (previously an opaque
            ZeroDivisionError).

    Note: leaves ``model`` in eval mode.
    """
    num_samples = len(loader.dataset)
    if num_samples == 0:
        raise ValueError("evaluate() requires a non-empty dataset")
    model.eval()
    total_loss = 0.0
    all_preds, all_labels = [], []
    with torch.no_grad():
        for kernel, input_img, labels in loader:
            kernel = kernel.to(device)
            input_img = input_img.to(device)
            labels = labels.to(device)
            outputs = model(kernel, input_img)
            loss = criterion(outputs, labels)
            # Undo the batch mean so the final average is per-sample even when
            # the last batch is smaller.
            total_loss += loss.item() * labels.size(0)
            all_preds.extend((outputs > 0.5).int().cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    avg_loss = total_loss / num_samples
    acc = accuracy_score(all_labels, all_preds)
    # F1 is ill-defined without positive predictions/labels (e.g. an early model
    # predicting all "fail"); zero_division=0 returns the same 0.0 as sklearn's
    # default but without an UndefinedMetricWarning on every call.
    f1 = f1_score(all_labels, all_preds, zero_division=0)
    return avg_loss, acc, f1


In [None]:
# --- Training Loop ---
def _train_one_epoch(model, loader, optimizer, criterion, device, desc):
    """Run one optimisation pass over ``loader``; return ``(mean loss, accuracy, F1)``."""
    model.train()
    running_loss = 0.0
    all_preds, all_labels = [], []
    for kernel, input_img, labels in tqdm(loader, desc=desc):
        kernel = kernel.to(device)
        input_img = input_img.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(kernel, input_img)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # Assumes criterion averages over the batch (nn.BCELoss default);
        # re-weight so the epoch loss is a per-sample mean.
        running_loss += loss.item() * labels.size(0)
        all_preds.extend((outputs > 0.5).int().cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
    epoch_loss = running_loss / len(loader.dataset)
    acc = accuracy_score(all_labels, all_preds)
    # Same value as sklearn's default, minus the UndefinedMetricWarning when
    # there are no positive predictions/labels.
    f1 = f1_score(all_labels, all_preds, zero_division=0)
    return epoch_loss, acc, f1


def train(model, train_loader, val_loader, num_epochs, optimizer, scheduler, criterion, device, save_path='best_model.pt'):
    """Train ``model`` and checkpoint the weights with the best validation F1.

    Each epoch runs one training pass, evaluates on ``val_loader`` via
    ``evaluate``, steps ``scheduler`` with the validation loss (a
    ``ReduceLROnPlateau``-style ``step(metric)`` is expected), and saves
    ``model.state_dict()`` to ``save_path`` whenever validation F1 improves.

    Args:
        model: Module called as ``model(kernel, input_img)`` returning probabilities.
        train_loader / val_loader: DataLoaders yielding ``(kernel, input_img, labels)``.
        num_epochs: Number of epochs to run.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: LR scheduler stepped once per epoch with the validation loss.
        criterion: Loss applied to ``(outputs, labels)``.
        device: Device the model and batches are moved to.
        save_path: Where the best checkpoint is written.

    Returns:
        The best validation F1 observed (the score of the checkpoint at
        ``save_path``), or ``float('-inf')`` if no epoch ran.
    """
    model = model.to(device)
    # Start below any attainable F1 so the first epoch always writes a
    # checkpoint; with 0.0 and a strict '>' nothing was ever saved when the
    # validation F1 stayed at 0 (e.g. a validation split with no positives).
    best_f1 = float("-inf")
    for epoch in range(num_epochs):
        train_loss, train_acc, train_f1 = _train_one_epoch(
            model, train_loader, optimizer, criterion, device,
            desc=f"Epoch {epoch+1}/{num_epochs} - Training",
        )
        val_loss, val_acc, val_f1 = evaluate(model, val_loader, criterion, device)
        print(f"\nEpoch {epoch+1}:")
        print(f"Train Loss: {train_loss:.4f} | Acc: {train_acc:.4f} | F1: {train_f1:.4f}")
        print(f"Val   Loss: {val_loss:.4f} | Acc: {val_acc:.4f} | F1: {val_f1:.4f}")
        scheduler.step(val_loss)
        if val_f1 > best_f1:
            best_f1 = val_f1
            torch.save(model.state_dict(), save_path)
            print(f"✅ Best model saved with F1: {best_f1:.4f}")
    return best_f1


In [None]:
# --- Setup Example ---
# Replace with actual file paths and labels
kernel_paths = ['/path/to/kernel1.png', '/path/to/kernel2.png']
input_paths = ['/path/to/input1.png', '/path/to/input2.png']
labels = [1, 0]  # 1 = pass, 0 = fail

# Create dataset and data loaders
dataset = WaferPairDataset(kernel_paths, input_paths, labels)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
# A seeded generator yields the same train/val split on every run, so the
# validation F1 (and the checkpoint selected from it) is reproducible.
split_generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(dataset, [train_size, val_size], generator=split_generator)
# NOTE(review): with num_workers > 0 on platforms that spawn worker processes
# (Windows, macOS), workers may fail to unpickle a Dataset class defined inside
# a notebook; fall back to num_workers=0 if workers crash on startup.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=2)

# Choose the device first and move the model there *before* building the
# optimizer, as the torch.optim documentation recommends.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SiameseClassifier().to(device)
# The model already applies a sigmoid, so plain BCE on probabilities is correct.
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# `verbose` is deprecated for LR schedulers in recent PyTorch releases; the
# current rate can be read from optimizer.param_groups[0]["lr"] if needed.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=2)

# Train
train(model, train_loader, val_loader, num_epochs=10, optimizer=optimizer, scheduler=scheduler, criterion=criterion, device=device)
