# Two-Tower CNN with Cross-Attention Fusion for Wafer Pass/Fail Prediction

This notebook demonstrates how to load image pairs from file paths, train a two-tower CNN in which features of the kernel image act as queries that attend over features of the input image via cross-attention, and predict wafer pass/fail.

In [ ]:
# Imports
import os
from PIL import Image
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from sklearn.metrics import accuracy_score, f1_score

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [ ]:
# Dataset class to load pairs of images and labels
class WaferDataset(Dataset):
    """Dataset of (kernel image, input image, label) triples read from disk.

    Item ``idx`` pairs ``kernel_paths[idx]`` with ``input_paths[idx]`` and
    ``labels[idx]``. Both images are opened with PIL and converted to
    single-channel grayscale (mode ``'L'``) before the optional transform
    is applied.

    Args:
        kernel_paths: Sequence of file paths to the kernel images.
        input_paths: Sequence of file paths to the input images.
        labels: Sequence of numeric labels, one per image pair.
        transform: Optional callable applied to each PIL image separately.
            The same callable is used for the kernel and the input image.

    Raises:
        ValueError: If the three sequences do not have the same length.
    """

    def __init__(self, kernel_paths, input_paths, labels, transform=None):
        # Fail fast: a length mismatch would otherwise show up much later as
        # an IndexError inside a DataLoader worker, or as silently unused
        # trailing paths.
        if not (len(kernel_paths) == len(input_paths) == len(labels)):
            raise ValueError(
                "kernel_paths, input_paths and labels must have the same "
                f"length, got {len(kernel_paths)}, {len(input_paths)} and "
                f"{len(labels)}"
            )
        self.kernel_paths = kernel_paths
        self.input_paths = input_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        """Return the number of image pairs."""
        return len(self.labels)

    @staticmethod
    def _load_grayscale(path):
        """Open ``path`` and return an in-memory grayscale copy of the image."""
        # ``convert`` produces a new, fully loaded image, so the underlying
        # file can be closed as soon as the ``with`` block exits.
        with Image.open(path) as img:
            return img.convert('L')

    def __getitem__(self, idx):
        """Return ``(kernel_img, input_img, label)`` for position ``idx``.

        The label is returned as a 0-dim ``float32`` tensor.
        """
        kernel_img = self._load_grayscale(self.kernel_paths[idx])
        input_img = self._load_grayscale(self.input_paths[idx])
        label = self.labels[idx]
        # Compare against None explicitly so that a transform object that
        # happens to be falsy (e.g. an empty container) is still applied.
        if self.transform is not None:
            kernel_img = self.transform(kernel_img)
            input_img = self.transform(input_img)
        return kernel_img, input_img, torch.tensor(label, dtype=torch.float32)


In [ ]:
# Preprocessing applied to every image: resize to 256x256, convert to a
# tensor, then normalise the single grayscale channel as (x - 0.5) / 0.5.
transform = transforms.Compose(
    transforms=[
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)


### Define the Model Components
This section defines the feature-extractor tower, the cross-attention fusion module, and the classifier.

In [ ]:
class FeatureExtractor(nn.Module):
    """Convolutional tower mapping a 1-channel image to a 128x8x8 feature map.

    Three Conv(3x3, padding=1) -> BatchNorm -> ReLU stages widen the channels
    1 -> 32 -> 64 -> 128. The first two stages are followed by 2x2 max
    pooling; the last is followed by adaptive average pooling to a fixed
    8x8 grid.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, followed_by_max_pool) for each stage.
        stage_specs = (
            (1, 32, True),
            (32, 64, True),
            (64, 128, False),
        )
        modules = []
        for in_ch, out_ch, downsample in stage_specs:
            modules.append(nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1))
            modules.append(nn.BatchNorm2d(out_ch))
            modules.append(nn.ReLU())
            if downsample:
                modules.append(nn.MaxPool2d(2))
        # Fixed output grid regardless of the spatial size reaching this layer.
        modules.append(nn.AdaptiveAvgPool2d((8, 8)))
        self.encoder = nn.Sequential(*modules)

    def forward(self, x):
        """Encode ``x`` of shape (B, 1, H, W) into features of shape (B, 128, 8, 8)."""
        features = self.encoder(x)
        return features


In [ ]:
class CrossAttentionFusion(nn.Module):
    """Fuse two feature maps with single-head scaled dot-product cross-attention.

    Every spatial location of the kernel feature map becomes a query token;
    every spatial location of the input feature map becomes a key/value
    token. The attended tokens are averaged into one vector per sample.

    The two feature maps must share batch size and channel count but may
    have different spatial sizes: each map is tokenised from its own shape.

    Args:
        feature_dim: Channel count ``C`` of both feature maps; also the width
            of the query/key/value projections and the source of the
            ``sqrt(C)`` score scaling.
        spatial_dim: Stored as an attribute only; ``forward`` reads the
            spatial sizes from its inputs.
    """

    def __init__(self, feature_dim=128, spatial_dim=8):
        super().__init__()
        self.feature_dim = feature_dim
        self.spatial_dim = spatial_dim
        self.query_proj = nn.Linear(feature_dim, feature_dim)
        self.key_proj = nn.Linear(feature_dim, feature_dim)
        self.value_proj = nn.Linear(feature_dim, feature_dim)
        self.scale = feature_dim ** 0.5

    @staticmethod
    def _to_tokens(feat):
        """Reshape a (B, C, H, W) feature map into (B, H*W, C) tokens."""
        return feat.flatten(2).transpose(1, 2)

    def forward(self, feat_kernel, feat_input):
        """Attend from kernel features to input features.

        Args:
            feat_kernel: Tensor of shape (B, C, Hk, Wk) supplying the queries.
            feat_input: Tensor of shape (B, C, Hi, Wi) supplying keys/values.

        Returns:
            Tensor of shape (B, C): the attention output averaged over all
            query positions.
        """
        queries = self.query_proj(self._to_tokens(feat_kernel))  # (B, Nq, C)
        # Keys and values are two projections of the same input tokens.
        input_tokens = self._to_tokens(feat_input)               # (B, Nk, C)
        keys = self.key_proj(input_tokens)
        values = self.value_proj(input_tokens)

        attn_scores = torch.bmm(queries, keys.transpose(1, 2)) / self.scale  # (B, Nq, Nk)
        attn_weights = torch.softmax(attn_scores, dim=-1)
        attn_output = torch.bmm(attn_weights, values)  # (B, Nq, C)

        # Averaging over the token axis equals global average pooling over
        # the query grid, without reshaping back to (B, C, H, W) first.
        return attn_output.mean(dim=1)


In [ ]:
class TwoTowerCrossAttentionClassifier(nn.Module):
    """Binary classifier built from two CNN towers joined by cross-attention.

    The kernel image and the input image are encoded by separate
    ``FeatureExtractor`` instances (weights are not shared). The two feature
    maps are fused by ``CrossAttentionFusion`` into a 128-d vector, which a
    small MLP head turns into a single probability per sample.
    """

    def __init__(self):
        super().__init__()
        self.kernel_tower = FeatureExtractor()
        self.input_tower = FeatureExtractor()
        self.attention = CrossAttentionFusion(feature_dim=128, spatial_dim=8)
        head_layers = [
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(64, 1),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, kernel_img, input_img):
        """Return one sigmoid probability per sample, shape (B,).

        The output is already passed through a sigmoid, so it should be
        paired with a probability-based loss rather than a logit-based one.
        """
        fused = self.attention(
            self.kernel_tower(kernel_img),
            self.input_tower(input_img),
        )
        probabilities = self.classifier(fused).sigmoid()
        # Drop the trailing singleton dimension: (B, 1) -> (B,).
        return probabilities.squeeze(1)


### Training and Validation Loop with Metrics

In [ ]:
def train_epoch(model, dataloader, optimizer, criterion, device):
    """Run one optimisation pass over ``dataloader`` and report metrics.

    Args:
        model: Module called as ``model(kernel_img, input_img)``. Its outputs
            are thresholded at 0.5 to obtain class predictions.
        dataloader: Iterable yielding ``(kernel_img, input_img, label)``
            batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss called as ``criterion(outputs, label)``.
        device: Device every batch tensor is moved to.

    Returns:
        Tuple ``(loss, accuracy, f1)`` where ``loss`` is the sample-weighted
        mean of the per-batch losses.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.train()
    running_loss = 0.0
    num_samples = 0
    preds, targets = [], []
    for k_img, i_img, label in dataloader:
        k_img, i_img, label = k_img.to(device), i_img.to(device), label.to(device)
        optimizer.zero_grad()
        outputs = model(k_img, i_img)
        loss = criterion(outputs, label)
        loss.backward()
        optimizer.step()
        batch_size = label.size(0)
        # Weight by batch size so a smaller final batch is not over-counted
        # (assumes the criterion returns a per-batch mean — TODO confirm).
        running_loss += loss.item() * batch_size
        num_samples += batch_size
        preds.extend((outputs > 0.5).long().cpu().tolist())
        targets.extend(label.long().cpu().tolist())
    if num_samples == 0:
        raise ValueError("dataloader yielded no samples")
    # Divide by the samples actually processed: len(dataloader.dataset) is
    # wrong when the loader drops or subsamples items, and is not defined
    # for iterable-style datasets.
    epoch_loss = running_loss / num_samples
    epoch_acc = accuracy_score(targets, preds)
    epoch_f1 = f1_score(targets, preds)
    return epoch_loss, epoch_acc, epoch_f1

def eval_epoch(model, dataloader, criterion, device):
    """Evaluate ``model`` over ``dataloader`` without updating any weights.

    The model is switched to eval mode and gradients are disabled for the
    whole pass.

    Args:
        model: Module called as ``model(kernel_img, input_img)``. Its outputs
            are thresholded at 0.5 to obtain class predictions.
        dataloader: Iterable yielding ``(kernel_img, input_img, label)``
            batches.
        criterion: Loss called as ``criterion(outputs, label)``.
        device: Device every batch tensor is moved to.

    Returns:
        Tuple ``(loss, accuracy, f1)`` where ``loss`` is the sample-weighted
        mean of the per-batch losses.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.eval()
    running_loss = 0.0
    num_samples = 0
    preds, targets = [], []
    with torch.no_grad():
        for k_img, i_img, label in dataloader:
            k_img, i_img, label = k_img.to(device), i_img.to(device), label.to(device)
            outputs = model(k_img, i_img)
            loss = criterion(outputs, label)
            batch_size = label.size(0)
            # Weight by batch size so a smaller final batch is not
            # over-counted (assumes the criterion returns a per-batch
            # mean — TODO confirm).
            running_loss += loss.item() * batch_size
            num_samples += batch_size
            preds.extend((outputs > 0.5).long().cpu().tolist())
            targets.extend(label.long().cpu().tolist())
    if num_samples == 0:
        raise ValueError("dataloader yielded no samples")
    # Divide by the samples actually processed: len(dataloader.dataset) is
    # wrong when the loader drops or subsamples items, and is not defined
    # for iterable-style datasets.
    epoch_loss = running_loss / num_samples
    epoch_acc = accuracy_score(targets, preds)
    epoch_f1 = f1_score(targets, preds)
    return epoch_loss, epoch_acc, epoch_f1


### Full Training Loop

In [ ]:
def train(model, train_loader, val_loader, epochs, optimizer, scheduler, criterion, device):
    """Train ``model`` for ``epochs`` epochs, validating after each one.

    Each epoch runs ``train_epoch`` on ``train_loader`` and ``eval_epoch``
    on ``val_loader``, passes the validation loss to ``scheduler.step``,
    and prints loss, accuracy and F1 for both splits.

    Args:
        model: Module to train; moved to ``device`` before the first epoch.
        train_loader: Loader of training batches.
        val_loader: Loader of validation batches.
        epochs: Number of epochs to run.
        optimizer: Optimizer used by ``train_epoch``.
        scheduler: Learning-rate scheduler whose ``step`` accepts a metric
            value; it is called with the validation loss.
        criterion: Loss function shared by training and validation.
        device: Device used for the model and all batches.
    """
    model.to(device)
    for epoch in range(1, epochs + 1):
        train_loss, train_acc, train_f1 = train_epoch(model, train_loader, optimizer, criterion, device)
        val_loss, val_acc, val_f1 = eval_epoch(model, val_loader, criterion, device)
        # The scheduler is driven by validation loss, not by epoch count.
        scheduler.step(val_loss)
        print(f"Epoch {epoch}/{epochs}")
        print(f"Train loss: {train_loss:.4f} Acc: {train_acc:.4f} F1: {train_f1:.4f}")
        print(f"Val   loss: {val_loss:.4f} Acc: {val_acc:.4f} F1: {val_f1:.4f}")


### Example Usage

- Prepare lists of filepaths and labels (`kernel_paths`, `input_paths`, `labels`)
- Split into training and validation
- Create datasets and dataloaders
- Instantiate model, optimizer, scheduler, criterion
- Call `train()`

In [ ]:
# Placeholder file paths and labels; substitute real data before training.
kernel_paths = [
    '/path/to/kernel1.png',
    '/path/to/kernel2.png',
]
input_paths = [
    '/path/to/input1.png',
    '/path/to/input2.png',
]
labels = [1, 0]  # 1=pass, 0=fail

# Demo only: no real train/val split is made here. The same dataset backs
# both loaders, so the "validation" loader iterates over the training data.
train_dataset = WaferDataset(
    kernel_paths=kernel_paths,
    input_paths=input_paths,
    labels=labels,
    transform=transform,
)
train_loader = DataLoader(dataset=train_dataset, batch_size=2, shuffle=True)
val_loader = DataLoader(dataset=train_dataset, batch_size=2)

# Model and training objects. BCELoss is used because the model's forward
# pass already applies a sigmoid and returns probabilities.
model = TwoTowerCrossAttentionClassifier().to(device)
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)
# Halve the learning rate after the monitored value fails to improve for
# more than 2 consecutive scheduler steps.
# NOTE(review): `verbose` is reported as deprecated in recent PyTorch
# releases — confirm against the installed version.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=optimizer,
    factor=0.5,
    patience=2,
    verbose=True,
)

# Training is left commented out because the placeholder paths do not exist.
# train(model, train_loader, val_loader, epochs=10, optimizer=optimizer, scheduler=scheduler, criterion=criterion, device=device)
print("Ready to train — replace dummy data with your dataset and call train()")