In [None]:
# 🔧 Setup: Run this cell first!
# Reports the available hardware and library versions, then seeds every RNG.

import random
import sys

import numpy as np
import torch

# Pick the compute device; the CPU fallback is handled first
if not torch.cuda.is_available():
    device = torch.device('cpu')
    print("⚠️ No GPU detected. Some cells may run slowly.")
    print("   Go to Runtime → Change runtime type → GPU")
else:
    device = torch.device('cuda')
    print(f"✅ GPU available: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

print(f"\n📦 Python {sys.version.split()[0]}")
print(f"🔥 PyTorch {torch.__version__}")

# Seed Python, NumPy and PyTorch (CPU and, when present, every GPU)
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

print(f"🎲 Random seed set to {SEED}")

# IPython magic: render matplotlib figures inline in the notebook output
%matplotlib inline

# Multimodal Defect Detection at Stratos Manufacturing -- Implementation Notebook

*Case Study Implementation: Multimodal Fusion for Semiconductor Quality Inspection*

## Setup

In [None]:
!pip install torch torchvision matplotlib numpy scikit-learn seaborn -q

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import DataLoader, Dataset
from sklearn.metrics import confusion_matrix, classification_report, f1_score

# Re-seed NumPy and PyTorch, then select the compute device for this notebook
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

## 1. Data Generation (Synthetic)

Since we do not have real semiconductor inspection data, we create a realistic synthetic dataset that captures the key multimodal structure of the problem.

In [None]:
class SyntheticDefectDataset(Dataset):
    """
    Synthetic multimodal defect detection dataset.

    Every sample is an (image, process_log, has_notes, label) tuple in which
    a defect leaves a trace in both the image and the process log:

    * label 0 -- good die: noise-only image, unmodified process log
    * label 1 -- particle contamination: bright discs in channel 0, column 5 raised
    * label 2 -- scratch: dark line across all channels, column 12 raised
    * label 3 -- pattern defect: sinusoid added to channel 1, column 20 raised
    * label 4 -- process deviation: channel 0 raised / channel 2 lowered,
      columns 30 and 31 raised

    The "notes" modality is only a boolean flag drawn independently of the
    label; no text is generated.

    All data is generated eagerly in ``__init__`` using both the global torch
    RNG and the global ``np.random`` RNG, so the result depends on how both
    were seeded and on the order in which the draws happen.
    """
    def __init__(self, n_samples=5000, img_size=64):
        """
        Build and store all samples.

        Args:
            n_samples: Total number of samples to generate.
            img_size: Height and width of the square 3-channel images.
        """
        self.n_samples = n_samples
        self.img_size = img_size

        # Class distribution: 85% good, 15% defective (split among 4 types)
        self.labels = torch.zeros(n_samples, dtype=torch.long)
        n_good = int(n_samples * 0.85)
        # Integer division: any remainder samples keep label 0 (good)
        n_defect_each = (n_samples - n_good) // 4

        # Fill four consecutive runs after the good block with classes 1..4
        idx = n_good
        for cls in range(1, 5):
            self.labels[idx:idx+n_defect_each] = cls
            idx += n_defect_each

        # Shuffle
        perm = torch.randperm(n_samples)
        self.labels = self.labels[perm]

        # Generate images (placeholder zeros, filled per sample below)
        self.images = torch.zeros(n_samples, 3, img_size, img_size)

        # Generate process logs (48 parameters), Gaussian noise scaled by 0.5
        self.process_logs = torch.randn(n_samples, 48) * 0.5

        # Generate operator notes (encoded as category)
        self.has_notes = torch.rand(n_samples) > 0.4  # 60% have notes

        # Fill in the image and the class-specific process-log shifts sample by sample
        for i in range(n_samples):
            label = self.labels[i].item()
            self._generate_sample(i, label)

    def _generate_sample(self, idx, label):
        """
        Generate the image for sample ``idx`` and apply the process-log shift
        that belongs to ``label``.

        Writes ``self.images[idx]`` (clamped to [0, 1]) and modifies
        ``self.process_logs[idx]`` in place; returns nothing.
        """
        # Base image: Gaussian noise (std 0.1) around mid-grey 0.5
        img = torch.randn(3, self.img_size, self.img_size) * 0.1 + 0.5

        if label == 0:  # Good die
            pass  # Clean image, normal process
        elif label == 1:  # Particle contamination
            # Add bright spots
            n_particles = np.random.randint(1, 5)
            for _ in range(n_particles):
                cx, cy = np.random.randint(5, self.img_size-5, 2)
                r = np.random.randint(1, 4)
                # Boolean disc of radius r on a (2r+1) x (2r+1) grid
                y, x = np.ogrid[-r:r+1, -r:r+1]
                mask = x**2 + y**2 <= r**2
                # Brighten channel 0 inside the disc centred on (cy, cx). The slice is a
                # view of img, so the masked += writes through. The mask is cropped to the
                # space left below/right of the window start; with centres at least 5 px
                # from the border and r <= 3 the window stays inside the image.
                img[0, max(0,cy-r):cy+r+1, max(0,cx-r):cx+r+1][mask[:img.shape[1]-max(0,cy-r), :img.shape[2]-max(0,cx-r)]] += 0.5
            self.process_logs[idx, 5] += 1.5  # Particle count sensor
        elif label == 2:  # Scratch
            # Add a line
            y0, x0 = np.random.randint(10, self.img_size-10, 2)
            angle = np.random.uniform(0, np.pi)
            length = np.random.randint(15, 40)
            # Walk along the line one step at a time; steps that leave the image are skipped
            for t in range(length):
                yy = int(y0 + t * np.sin(angle))
                xx = int(x0 + t * np.cos(angle))
                if 0 <= yy < self.img_size and 0 <= xx < self.img_size:
                    # Darken up to 3 horizontally adjacent pixels in every channel
                    img[:, yy, max(0,xx-1):xx+2] -= 0.3
            self.process_logs[idx, 12] += 2.0  # Vibration sensor
        elif label == 3:  # Pattern defect
            # Add a shifted pattern
            x = torch.linspace(-3, 3, self.img_size)
            y = torch.linspace(-3, 3, self.img_size)
            XX, YY = torch.meshgrid(x, y, indexing='ij')
            # Sinusoid along the first image axis with a random phase; YY is unused
            pattern = torch.sin(XX * 5 + np.random.uniform(-1, 1)) * 0.3
            img[1] += pattern
            self.process_logs[idx, 20] += 1.8  # Exposure dose deviation
        elif label == 4:  # Process deviation
            # Discoloration + process anomaly
            img[0] += 0.15  # Red tint
            img[2] -= 0.1
            self.process_logs[idx, 30] += 2.5  # Temperature deviation
            self.process_logs[idx, 31] += 1.5  # Pressure deviation

        # Store the finished image with values limited to [0, 1]
        self.images[idx] = img.clamp(0, 1)

    def __len__(self):
        """Return the number of samples."""
        return self.n_samples

    def __getitem__(self, idx):
        """Return (image, process_log, has_notes as a float, label) for sample ``idx``."""
        return (self.images[idx], self.process_logs[idx],
                self.has_notes[idx].float(), self.labels[idx])

# Build the train / test splits (each one draws its own synthetic samples)
train_data = SyntheticDefectDataset(n_samples=5000)
test_data = SyntheticDefectDataset(n_samples=1000)

# Only the training batches are shuffled; the test order stays fixed
train_loader = DataLoader(train_data, shuffle=True, batch_size=64)
test_loader = DataLoader(test_data, shuffle=False, batch_size=64)

# Number of training samples per class
class_names = ['Good', 'Particle', 'Scratch', 'Pattern', 'Process']
counts = []
for c in range(5):
    counts.append(int((train_data.labels == c).sum()))

print("Class distribution:")
for name, count in zip(class_names, counts):
    share = count/len(train_data)*100
    print(f"  {name}: {count} ({share:.1f}%)")

## 2. Exploratory Data Analysis

In [None]:
# One column per class: an example image on top, its 48 process parameters below
fig, axes = plt.subplots(2, 5, figsize=(20, 8))

for c in range(5):
    # First training sample that carries label c
    class_positions = (train_data.labels == c).nonzero(as_tuple=True)[0]
    idx = int(class_positions[0])
    img = train_data.images[idx]
    image_ax, params_ax = axes[0, c], axes[1, c]

    # Channels-last layout for imshow
    image_ax.imshow(img.permute(1, 2, 0).clamp(0, 1))
    image_ax.set_title(f'{class_names[c]}', fontsize=12, fontweight='bold')
    image_ax.axis('off')

    # Process log deviations
    process = train_data.process_logs[idx].numpy()
    params_ax.bar(range(48), process, color='steelblue', alpha=0.7)
    params_ax.set_title('Process Params')
    params_ax.set_xlabel('Parameter Index')
    params_ax.set_ylim(-4, 4)
    params_ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)

plt.suptitle('Sample Images and Process Parameters by Defect Class', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

In [None]:
# Bar chart of the per-class training counts, each bar annotated with its count
class_colors = ['#4CAF50', '#FF5722', '#FF9800', '#2196F3', '#9C27B0']

fig, ax = plt.subplots(figsize=(8, 5))
bars = ax.bar(class_names, counts, color=class_colors)
ax.set_ylabel('Count')
ax.set_title('Class Distribution (Imbalanced -- 85% Good Dies)')

for bar, count in zip(bars, counts):
    # Label sits 20 units above the bar, horizontally centred on it
    label_x = bar.get_x() + bar.get_width()/2.
    label_y = bar.get_height() + 20
    ax.text(label_x, label_y, str(count),
            ha='center', va='bottom', fontweight='bold')

plt.tight_layout()
plt.show()

## 3. Baseline: Vision-Only Model

In [None]:
class VisionOnlyBaseline(nn.Module):
    """
    Small CNN classifier that looks at the image only.

    ``img_size`` is accepted for signature parity with the multimodal model but
    is not used: the adaptive pooling layer fixes the feature map at 4x4.
    """
    def __init__(self, img_size=64, num_classes=5):
        super().__init__()
        # Three conv stages (32 -> 64 -> 128 channels); the last one pools to 4x4
        feature_layers = [
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(output_size=4),
        ]
        # Two-layer MLP head on the flattened 128*4*4 features
        head_layers = [
            nn.Flatten(),
            nn.Linear(128 * 4 * 4, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes),
        ]
        self.encoder = nn.Sequential(*feature_layers, *head_layers)

    def forward(self, images, process_logs=None, has_notes=None):
        """Return class logits for ``images``; the other two arguments are ignored."""
        logits = self.encoder(images)
        return logits

# Instantiate the vision-only baseline on the selected device and report its size
baseline = VisionOnlyBaseline()
baseline = baseline.to(device)
print(f"Baseline parameters: {sum(weight.numel() for weight in baseline.parameters()):,}")

## 4. Multimodal Fusion Model

In [None]:
class GatedCrossAttention(nn.Module):
    """
    Cross-attention with a learnable scalar gate on the residual branch.

    The query is layer-normalised before attention (keys/values are used as
    given). The gate parameter starts at zero and passes through tanh, so at
    initialisation the layer returns the query unchanged.
    """
    def __init__(self, dim, num_heads=4):
        super().__init__()
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=dim, num_heads=num_heads, batch_first=True
        )
        self.norm = nn.LayerNorm(dim)
        self.gate = nn.Parameter(torch.zeros(1))

    def forward(self, query, kv):
        """
        Attend from ``query`` to ``kv`` (used as both keys and values).

        Returns:
            (query + tanh(gate) * attention output, attention weights)
        """
        normed_query = self.norm(query)
        context, weights = self.cross_attn(normed_query, kv, kv)
        gated_context = torch.tanh(self.gate) * context
        return query + gated_context, weights


class MultimodalDefectDetector(nn.Module):
    """
    Multimodal classifier: image patch tokens are refined by one transformer
    encoder layer and then updated, through two gated cross-attention layers,
    with a process-log token and an operator-note token.
    """
    def __init__(self, img_size=64, patch_size=8, embed_dim=128,
                 process_dim=48, num_classes=5, num_heads=4):
        super().__init__()
        self.embed_dim = embed_dim
        self.patch_size = patch_size
        patches_per_side = img_size // patch_size
        self.num_patches = patches_per_side ** 2

        # Vision branch: linear patch embedding, learned positions, one encoder layer
        self.patch_proj = nn.Linear(3 * patch_size * patch_size, embed_dim)
        self.pos_embed = nn.Parameter(0.02 * torch.randn(1, self.num_patches, embed_dim))
        self.vis_transformer = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=num_heads,
            dim_feedforward=4 * embed_dim,
            dropout=0.1,
            batch_first=True,
        )

        # Process-log branch: two-layer MLP producing a single token
        self.process_proj = nn.Sequential(
            nn.Linear(process_dim, embed_dim),
            nn.ReLU(),
            nn.Linear(embed_dim, embed_dim),
        )

        # Note branch (simplified): one learned embedding per has-notes / no-notes state
        self.note_embed = nn.Embedding(2, embed_dim)

        # Fusion: vision tokens query the process token, then the note token
        self.vision_process_xattn = GatedCrossAttention(embed_dim, num_heads)
        self.vision_note_xattn = GatedCrossAttention(embed_dim, num_heads)

        # Classification head applied to the mean-pooled vision tokens
        self.classifier = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, embed_dim),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(embed_dim, num_classes),
        )

    def forward(self, images, process_logs, has_notes):
        """
        Return class logits for a batch.

        ``has_notes`` is cast to long and used as an index into the 2-entry
        note embedding table.
        """
        batch_size = images.shape[0]
        side = self.patch_size

        # Cut the image into non-overlapping side x side tiles, one flat vector per tile
        tiles = images.unfold(2, side, side).unfold(3, side, side)
        tiles = tiles.contiguous().view(batch_size, 3, -1, side, side)
        tiles = tiles.permute(0, 2, 1, 3, 4).reshape(batch_size, self.num_patches, -1)

        # Embed, add positions, run the encoder layer
        tokens = self.vis_transformer(self.patch_proj(tiles) + self.pos_embed)

        # One key/value token per auxiliary modality
        process_kv = self.process_proj(process_logs).unsqueeze(1)
        note_kv = self.note_embed(has_notes.long()).unsqueeze(1)

        # Sequential fusion; the returned attention weights are not used
        tokens, _ = self.vision_process_xattn(tokens, process_kv)
        tokens, _ = self.vision_note_xattn(tokens, note_kv)

        # Mean-pool over patches and classify
        return self.classifier(tokens.mean(dim=1))

# Instantiate the fusion model on the selected device and report its size
multimodal = MultimodalDefectDetector()
multimodal = multimodal.to(device)
print(f"Multimodal parameters: {sum(weight.numel() for weight in multimodal.parameters()):,}")

## 5. Training

In [None]:
def focal_loss(logits, targets, gamma=2.0, alpha=None):
    """
    Focal loss for class-imbalanced classification.

    Scales each sample's cross-entropy by (1 - p_true) ** gamma, where p_true
    is the predicted probability of the target class, and optionally by a
    per-class weight ``alpha[target]``. Returns the mean over the batch.
    """
    per_sample_ce = F.cross_entropy(logits, targets, reduction='none')
    # cross-entropy is -log(p_true), so exp(-ce) recovers p_true
    prob_true = (-per_sample_ce).exp()
    weighted = (1 - prob_true).pow(gamma) * per_sample_ce

    if alpha is None:
        return weighted.mean()
    return (alpha[targets] * weighted).mean()

# Inverse-frequency class weights: total / (num_classes * count) for each class
class_counts = torch.tensor(counts, dtype=torch.float32)
class_weights = ((1.0 / class_counts) * class_counts.sum() / len(class_counts)).to(device)
print(f"Class weights: {class_weights.cpu().numpy()}")

def train_model(model, train_loader, test_loader, epochs=30, lr=1e-3, use_focal=True):
    """
    Train ``model`` with Adam and cosine LR annealing, evaluating after every epoch.

    Relies on the module-level ``device`` (where batches are moved) and, when
    ``use_focal`` is True, on the module-level ``class_weights``.

    Args:
        model: Module called as ``model(images, process_logs, has_notes)`` that
            returns class logits.
        train_loader: Iterable of (images, process_logs, has_notes, labels) batches.
        test_loader: Same batch format; fully evaluated after each epoch.
        epochs: Number of passes over ``train_loader`` (also the cosine period).
        lr: Initial Adam learning rate.
        use_focal: If True, optimise ``focal_loss`` (gamma=2.0, alpha=class_weights);
            otherwise plain cross-entropy.

    Returns:
        (train_losses, test_f1s, escape_rates) -- three lists with one entry per
        epoch: mean training loss per batch, macro F1 on the test set, and the
        fraction of defective test samples (label > 0) predicted as class 0.
    """
    optimizer = optim.Adam(model.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    train_losses, test_f1s, escape_rates = [], [], []

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0
        for imgs, plogs, notes, labels in train_loader:
            imgs = imgs.to(device)
            plogs = plogs.to(device)
            notes = notes.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            logits = model(imgs, plogs, notes)

            if use_focal:
                loss = focal_loss(logits, labels, gamma=2.0, alpha=class_weights)
            else:
                loss = F.cross_entropy(logits, labels)

            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

        # The LR schedule advances once per epoch, not per batch
        scheduler.step()
        train_losses.append(epoch_loss / len(train_loader))

        # Evaluate
        model.eval()
        all_preds, all_labels = [], []
        with torch.no_grad():
            for imgs, plogs, notes, labels in test_loader:
                imgs = imgs.to(device)
                plogs = plogs.to(device)
                notes = notes.to(device)
                logits = model(imgs, plogs, notes)
                preds = logits.argmax(1).cpu()
                all_preds.append(preds)
                # Labels stay on the CPU; they are only compared with the CPU predictions
                all_labels.append(labels)

        all_preds = torch.cat(all_preds)
        all_labels = torch.cat(all_labels)

        # zero_division=0 keeps the score identical (a never-predicted class counts as 0)
        # while silencing the per-epoch UndefinedMetricWarning, matching the per-fab analysis.
        f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)
        test_f1s.append(f1)

        # Escape rate: defective dies classified as good
        defective = all_labels > 0
        escaped = (all_preds[defective] == 0).float().mean().item() if defective.sum() > 0 else 0
        escape_rates.append(escaped)

        # Progress line every 10th epoch
        if (epoch + 1) % 10 == 0:
            print(f"  Epoch {epoch+1:3d}: Loss={train_losses[-1]:.4f}, "
                  f"F1={f1:.4f}, Escape={escaped:.4f}")

    return train_losses, test_f1s, escape_rates

In [None]:
# Baseline: a freshly initialised vision-only model trained with plain cross-entropy
print("Training Vision-Only Baseline...")
baseline = VisionOnlyBaseline().to(device)
bl_losses, bl_f1s, bl_escapes = train_model(
    model=baseline,
    train_loader=train_loader,
    test_loader=test_loader,
    use_focal=False,
)

# Fusion model: freshly initialised, trained with the class-weighted focal loss
print("\nTraining Multimodal Model...")
multimodal = MultimodalDefectDetector().to(device)
mm_losses, mm_f1s, mm_escapes = train_model(
    model=multimodal,
    train_loader=train_loader,
    test_loader=test_loader,
    use_focal=True,
)

## 6. Evaluation

In [None]:
# Three side-by-side panels comparing the two models epoch by epoch
curve_panels = [
    ('Training Loss', bl_losses, mm_losses),
    ('Macro F1 Score', bl_f1s, mm_f1s),
    ('Defect Escape Rate', bl_escapes, mm_escapes),
]

fig, axes = plt.subplots(1, 3, figsize=(18, 5))

for panel_ax, (panel_title, baseline_curve, multimodal_curve) in zip(axes, curve_panels):
    panel_ax.plot(baseline_curve, label='Vision-Only', color='#FF5722', linewidth=2)
    panel_ax.plot(multimodal_curve, label='Multimodal', color='#2196F3', linewidth=2)
    if panel_title == 'Defect Escape Rate':
        # Reference line for the 1% escape-rate target
        panel_ax.axhline(y=0.01, color='green', linestyle='--', label='Target (1%)')
    panel_ax.set_title(panel_title)
    panel_ax.legend()
    panel_ax.grid(True, alpha=0.3)

plt.suptitle('Vision-Only vs Multimodal Fusion Model', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

# Last-epoch numbers for both models
print("\nFinal Results:")
print(f"  Vision-Only: F1={bl_f1s[-1]:.4f}, Escape Rate={bl_escapes[-1]:.4f}")
print(f"  Multimodal:  F1={mm_f1s[-1]:.4f}, Escape Rate={mm_escapes[-1]:.4f}")

In [None]:
# Confusion matrices for both trained models on the test set
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Pin the label set so the matrix and the report always have one row per class name,
# even if some class is absent from the labels and predictions.
label_ids = list(range(len(class_names)))

for idx, (model, name) in enumerate([(baseline, 'Vision-Only'), (multimodal, 'Multimodal')]):
    model.eval()
    all_preds, all_labels = [], []
    with torch.no_grad():
        for imgs, plogs, notes, labels in test_loader:
            logits = model(imgs.to(device), plogs.to(device), notes.to(device))
            all_preds.append(logits.argmax(1).cpu())
            all_labels.append(labels)

    all_preds = torch.cat(all_preds)
    all_labels = torch.cat(all_labels)

    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=axes[idx],
               xticklabels=class_names, yticklabels=class_names)

    # The first 50 characters of classification_report() are only padding and part of the
    # header row, so the title shows the macro F1 and the full report is printed instead.
    macro_f1 = f1_score(all_labels, all_preds, average='macro', zero_division=0)
    axes[idx].set_title(f'{name} (macro F1 = {macro_f1:.3f})')
    axes[idx].set_xlabel('Predicted')
    axes[idx].set_ylabel('True')

    print(f"{name} classification report:")
    print(classification_report(all_labels, all_preds, labels=label_ids,
                                target_names=class_names, zero_division=0))

plt.tight_layout()
plt.show()

## 7. Error Analysis

In [None]:
# Analyze gate values -- tanh(gate) scales how much each cross-attention output
# is added to the vision tokens (0 means the modality is ignored)
multimodal.eval()
with torch.no_grad():
    vp_gate = torch.tanh(multimodal.vision_process_xattn.gate).item()
    vn_gate = torch.tanh(multimodal.vision_note_xattn.gate).item()

print(f"Learned Gate Values:")
print(f"  Vision-Process gate: {vp_gate:.4f}")
print(f"  Vision-Note gate:    {vn_gate:.4f}")

# Gates are initialised at exactly zero, so guard the ratio against division by zero
# (e.g. when this cell is run on an untrained model).
if vn_gate == 0:
    print("\nThe note gate is exactly zero, so the process/note ratio is undefined.")
else:
    print(f"\nThe model learned to weight process data {abs(vp_gate)/abs(vn_gate):.1f}x more than notes.")

## 8. Deployment Optimization

In [None]:
# Latency benchmark: single-sample inference time of the multimodal model
import time

multimodal.eval()
dummy_imgs = torch.randn(1, 3, 64, 64).to(device)
dummy_plogs = torch.randn(1, 48).to(device)
dummy_notes = torch.tensor([1.0]).to(device)

# Warmup
with torch.no_grad():
    for _ in range(10):
        _ = multimodal(dummy_imgs, dummy_plogs, dummy_notes)

# CUDA kernels run asynchronously: drain the warm-up work so it is not
# attributed to the first timed iteration.
if device.type == 'cuda':
    torch.cuda.synchronize()

# Benchmark
times = []
with torch.no_grad():
    for _ in range(100):
        # perf_counter is monotonic and high-resolution, unlike time.time()
        start = time.perf_counter()
        _ = multimodal(dummy_imgs, dummy_plogs, dummy_notes)
        # Wait for the GPU to finish before stopping the clock
        if device.type == 'cuda':
            torch.cuda.synchronize()
        times.append((time.perf_counter() - start) * 1000)

print(f"Inference Latency (single die):")
print(f"  Mean:  {np.mean(times):.2f} ms")
print(f"  P50:   {np.percentile(times, 50):.2f} ms")
print(f"  P95:   {np.percentile(times, 95):.2f} ms")
print(f"  P99:   {np.percentile(times, 99):.2f} ms")
print(f"  Target: <200 ms {'PASS' if np.percentile(times, 99) < 200 else 'FAIL'}")

## 9. Ethics and Fairness Considerations

In [None]:
# Simulate per-fab performance analysis
print("Per-Fab Performance Analysis (Simulated)")
print("=" * 50)

# Fab membership is assigned uniformly at random to each test sample
fab_names = ['Fab A', 'Fab B', 'Fab C']
fab_ids = np.random.choice(fab_names, size=len(test_data))

# Collect the multimodal model's predictions over the (unshuffled) test loader
multimodal.eval()
pred_batches, label_batches = [], []
with torch.no_grad():
    for imgs, plogs, notes, labels in test_loader:
        batch_logits = multimodal(imgs.to(device), plogs.to(device), notes.to(device))
        pred_batches.append(batch_logits.argmax(1).cpu())
        label_batches.append(labels)

all_preds = torch.cat(pred_batches).numpy()
all_labels = torch.cat(label_batches).numpy()

for fab in fab_names:
    in_fab = fab_ids == fab
    fab_labels, fab_preds = all_labels[in_fab], all_preds[in_fab]

    # Escape rate: share of defective dies predicted as good
    is_defective = fab_labels > 0
    escape = (fab_preds[is_defective] == 0).mean() if is_defective.sum() > 0 else 0

    # False-positive rate: share of good dies predicted as any defect
    is_good = fab_labels == 0
    fpr = (fab_preds[is_good] > 0).mean() if is_good.sum() > 0 else 0

    f1 = f1_score(fab_labels, fab_preds, average='macro', zero_division=0)
    print(f"  {fab}: F1={f1:.4f}, Escape={escape:.4f}, FPR={fpr:.4f}")

In [None]:
# Closing message, one line per argument
print(
    "\nCongratulations! You have built a multimodal defect detection system.",
    "Key achievement: combining vision, process data, and operator notes",
    "reduced the defect escape rate compared to vision-only inspection.",
    sep="\n",
)