# 05. CNN Wafer Map Classification
## Smart Wafer Yield Optimization Project

This notebook implements deep learning techniques using Convolutional Neural Networks (CNN) for wafer map classification and defect pattern recognition.

### Objectives:
- Generate synthetic wafer map images for demonstration
- Build CNN architecture for spatial feature extraction
- Implement data augmentation techniques
- Train and evaluate CNN model
- Visualize activation maps and learned features
- Compare CNN performance with traditional ML methods

### CNN Architecture:
- **Convolutional Layers**: Spatial feature extraction
- **Batch Normalization**: Training stability
- **Dropout**: Regularization
- **Pooling**: Dimensionality reduction
- **Fully Connected**: Classification layers

### Data Augmentation:
- Rotation, flipping, scaling
- Noise injection
- Brightness/contrast adjustment
- Elastic deformation


In [None]:
# Import required libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import warnings
warnings.filterwarnings('ignore')

# Import our utility functions
import sys
import os

# When the kernel was started inside a directory whose path ends with
# "notebooks", switch the working directory to its parent (the project root)
# before importing the project-local `app` package.
notebook_path = os.path.abspath("")
if notebook_path.endswith("notebooks"):
    project_root = os.path.dirname(notebook_path)
    os.chdir(project_root)
    # Changing the working directory only influences imports when the cwd
    # itself is on sys.path; register the root explicitly so `app` resolves
    # no matter how the kernel populated sys.path.
    if project_root not in sys.path:
        sys.path.insert(0, project_root)
from app.utils import load_data

print("Libraries imported successfully!")
print("Ready to begin CNN wafer map classification...")


## 1. Generate Synthetic Wafer Map Data


In [None]:
# Generate synthetic wafer map images for demonstration
def generate_wafer_map(size=64, defect_type='normal'):
    """
    Generate a synthetic wafer map image.

    The wafer is modelled as a disc of radius ``size // 2`` centred at
    ``(size // 2, size // 2)``. Cells marked defective are set to 1, every
    other cell (including everything outside the disc) stays 0.

    Args:
        size (int): Size of the wafer map (size x size)
        defect_type (str): Type of defect pattern
            ('normal', 'center', 'edge', 'random', 'cluster')

    Returns:
        np.array: Wafer map image of shape (size, size) holding 0.0 / 1.0

    Raises:
        ValueError: If ``defect_type`` is not one of the supported patterns.
    """
    # Create base wafer (circular)
    center = size // 2
    y, x = np.ogrid[:size, :size]
    distance = np.sqrt((x - center)**2 + (y - center)**2)
    wafer_mask = distance <= center

    if defect_type == 'normal':
        # Normal wafer with few random defects (~2% of cells)
        defects = np.random.random((size, size)) < 0.02

    elif defect_type == 'center':
        # Solid disc of defects around the wafer centre
        defects = distance <= 5

    elif defect_type == 'edge':
        # Ring of defects along the outer 3 pixels of the wafer
        defects = distance >= center - 3

    elif defect_type == 'random':
        # Random scattered defects (~10% of cells)
        defects = np.random.random((size, size)) < 0.1

    elif defect_type == 'cluster':
        # Solid blob of radius 8 at a random offset from the centre.
        # np.random.randint excludes the upper bound, so each offset is
        # drawn from -10..9.
        cluster_center = (center + np.random.randint(-10, 10), center + np.random.randint(-10, 10))
        cluster_distance = np.sqrt((x - cluster_center[0])**2 + (y - cluster_center[1])**2)
        defects = cluster_distance <= 8

    else:
        # An unrecognised name used to yield a silent all-zero image, which
        # would end up in the dataset under a wrong label.
        raise ValueError(
            f"Unknown defect_type {defect_type!r}; expected one of "
            "'normal', 'center', 'edge', 'random', 'cluster'"
        )

    # Clip every pattern to the wafer disc in one place so that no branch
    # (e.g. 'center' on a very small map) can mark cells outside the wafer.
    wafer_map = np.zeros((size, size))
    wafer_map[defects & wafer_mask] = 1

    return wafer_map

# Generate dataset
print("Generating synthetic wafer map dataset...")

n_samples = 1000
image_size = 64
defect_types = ['normal', 'center', 'edge', 'random', 'cluster']
n_per_type = n_samples // len(defect_types)

# Build the dataset one class at a time; a sample's label is the position of
# its defect type in `defect_types`.
X_images = []
y_labels = []

for label, defect_type in enumerate(defect_types):
    X_images.extend(
        generate_wafer_map(image_size, defect_type) for _ in range(n_per_type)
    )
    y_labels.extend([label] * n_per_type)

X_images = np.array(X_images)
y_labels = np.array(y_labels)

print(f"✅ Generated {len(X_images)} wafer maps")
print(f"Image shape: {X_images.shape}")
print(f"Label distribution: {np.bincount(y_labels)}")

# Visualize sample wafer maps: one panel per defect type on a 2x3 grid
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
axes = axes.flatten()

for i, (ax, defect_type) in enumerate(zip(axes, defect_types)):
    # First sample carrying this class label
    sample_idx = np.flatnonzero(y_labels == i)[0]
    ax.imshow(X_images[sample_idx], cmap='viridis')
    ax.set_title(f'{defect_type.capitalize()} Defect')
    ax.axis('off')

# The grid has six panels but there are only five classes; hide the spare one
axes[5].set_visible(False)

plt.suptitle('Sample Wafer Map Defect Patterns', fontsize=16)
plt.tight_layout()
plt.show()


## 2. Build CNN Architecture


In [None]:
# Define CNN architecture for wafer map classification
class WaferMapCNN(nn.Module):
    """
    Three-block convolutional classifier for square wafer map images.

    Each block is Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU -> MaxPool2d(2),
    so the spatial size is halved three times (e.g. 64 -> 32 -> 16 -> 8). The
    flattened features go through two dropout-regularised fully connected
    layers and a final linear layer that outputs raw class scores (no softmax
    is applied here).

    Args:
        num_classes (int): Number of output classes.
        image_size (int): Height/width of the square input images. Determines
            the input width of the first fully connected layer.
        in_channels (int): Number of channels in the input images.

    Raises:
        ValueError: If ``image_size`` is smaller than 8, because three 2x2
            poolings would leave no spatial features.
    """

    def __init__(self, num_classes=5, image_size=64, in_channels=1):
        super().__init__()

        if image_size < 8:
            raise ValueError(
                f"image_size must be at least 8 for three 2x2 poolings, got {image_size}"
            )

        # Convolutional layers
        self.conv1 = nn.Conv2d(in_channels, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)

        # Pooling layer
        self.pool = nn.MaxPool2d(2, 2)

        # Dropout for regularization
        self.dropout = nn.Dropout(0.5)

        # The padded 3x3 convolutions keep the spatial size; each pooling
        # floor-halves it, so three poolings give image_size // 8
        # (64x64 -> 32x32 -> 16x16 -> 8x8 for the default).
        feature_size = image_size // 8
        self.fc1 = nn.Linear(128 * feature_size * feature_size, 512)
        self.fc2 = nn.Linear(512, 128)
        self.fc3 = nn.Linear(128, num_classes)

        # Activation function
        self.relu = nn.ReLU()

    def forward(self, x):
        """
        Compute class scores for a batch of wafer maps.

        Args:
            x (torch.Tensor): Batch shaped
                (batch, in_channels, image_size, image_size).

        Returns:
            torch.Tensor: Raw class scores shaped (batch, num_classes).
        """
        # First conv block
        x = self.pool(self.relu(self.bn1(self.conv1(x))))

        # Second conv block
        x = self.pool(self.relu(self.bn2(self.conv2(x))))

        # Third conv block
        x = self.pool(self.relu(self.bn3(self.conv3(x))))

        # Flatten everything except the batch dimension
        x = x.view(x.size(0), -1)

        # Fully connected layers
        x = self.dropout(self.relu(self.fc1(x)))
        x = self.dropout(self.relu(self.fc2(x)))
        x = self.fc3(x)

        return x

# Create model
model = WaferMapCNN(num_classes=5)
print("✅ CNN model created")
print("Model architecture:")
print(model)

# Count parameters: collect (element count, trainable?) once per tensor and
# derive both totals from that list
param_sizes = [(p.numel(), p.requires_grad) for p in model.parameters()]
total_params = sum(count for count, _ in param_sizes)
trainable_params = sum(count for count, is_trainable in param_sizes if is_trainable)
print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")


## 3. Train CNN Model


In [None]:
# Prepare data for training
print("Preparing data for CNN training...")

# Split data (stratified so every class keeps the same share in both splits)
X_train, X_test, y_train, y_test = train_test_split(
    X_images, y_labels, test_size=0.2, random_state=42, stratify=y_labels
)

# Convert to PyTorch tensors
X_train = torch.FloatTensor(X_train).unsqueeze(1)  # Add channel dimension
X_test = torch.FloatTensor(X_test).unsqueeze(1)
y_train = torch.LongTensor(y_train)
y_test = torch.LongTensor(y_test)

print(f"Training set: {X_train.shape}")
print(f"Test set: {X_test.shape}")

# Create data loaders
batch_size = 32
train_dataset = torch.utils.data.TensorDataset(X_train, y_train)
test_dataset = torch.utils.data.TensorDataset(X_test, y_test)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

print("✅ Data loaders created")

# Set up training
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)


def _run_epoch(model, loader, criterion, device, optimizer=None):
    """
    Run one full pass over ``loader``.

    With an ``optimizer`` the model is put in training mode and updated after
    every batch; without one the model is put in eval mode and gradients are
    disabled.

    Args:
        model (nn.Module): Network to train or evaluate.
        loader (DataLoader): Yields ``(data, target)`` batches.
        criterion: Loss returning the mean loss of a batch (as
            ``nn.CrossEntropyLoss`` does with its default reduction).
        device (torch.device): Device the batches are moved to.
        optimizer: Optimizer to step after each batch, or ``None`` to evaluate.

    Returns:
        tuple: ``(mean loss per sample, accuracy in percent)``.
    """
    training = optimizer is not None
    model.train(training)

    loss_sum = 0.0
    correct = 0
    total = 0

    with torch.set_grad_enabled(training):
        for data, target in loader:
            data, target = data.to(device), target.to(device)

            if training:
                optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            if training:
                loss.backward()
                optimizer.step()

            # Weight the batch-mean loss by the batch size so that a short
            # final batch does not count as much as a full one.
            n_batch = target.size(0)
            loss_sum += loss.item() * n_batch
            _, predicted = torch.max(output, 1)
            correct += (predicted == target).sum().item()
            total += n_batch

    return loss_sum / total, 100. * correct / total


# Training loop
num_epochs = 20
train_losses = []
train_accuracies = []
test_losses = []
test_accuracies = []

print(f"\nStarting training for {num_epochs} epochs...")

for epoch in range(num_epochs):
    # Training phase
    train_loss, train_acc = _run_epoch(model, train_loader, criterion, device, optimizer)

    # Test phase
    test_loss, test_acc = _run_epoch(model, test_loader, criterion, device)

    train_losses.append(train_loss)
    train_accuracies.append(train_acc)
    test_losses.append(test_loss)
    test_accuracies.append(test_acc)

    if epoch % 5 == 0:
        print(f'Epoch {epoch:2d}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%')

    # Step the LR schedule once per epoch, after that epoch's optimizer updates
    scheduler.step()

print("✅ Training completed!")

# Plot training results
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

ax1.plot(train_losses, label='Train Loss')
ax1.plot(test_losses, label='Test Loss')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training and Test Loss')
ax1.legend()
ax1.grid(True)

ax2.plot(train_accuracies, label='Train Accuracy')
ax2.plot(test_accuracies, label='Test Accuracy')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy (%)')
ax2.set_title('Training and Test Accuracy')
ax2.legend()
ax2.grid(True)

plt.tight_layout()
plt.show()

print(f"Final Test Accuracy: {test_acc:.2f}%")


## 4. Model Evaluation and Saving


In [None]:
# Evaluate model on test set
print("Evaluating CNN model...")

model.eval()
all_predictions = []
all_targets = []

# Collect predicted and true class indices for the whole test set
with torch.no_grad():
    for data, target in test_loader:
        data, target = data.to(device), target.to(device)
        output = model(data)
        _, predicted = torch.max(output, 1)

        all_predictions.extend(predicted.cpu().numpy())
        all_targets.extend(target.cpu().numpy())

# Convert to numpy arrays
y_pred = np.array(all_predictions)
y_true = np.array(all_targets)

# Calculate metrics
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

# Store a plain Python float so the value is JSON-serialisable below
accuracy = float(accuracy_score(y_true, y_pred))
print(f"Test Accuracy: {accuracy:.3f}")

# Classification report
defect_type_names = ['normal', 'center', 'edge', 'random', 'cluster']
# Pin the label set so that rows/columns always line up with the class names,
# even if some class never appears in y_true or y_pred.
class_labels = list(range(len(defect_type_names)))
print("\nClassification Report:")
print(classification_report(y_true, y_pred, labels=class_labels, target_names=defect_type_names))

# Confusion matrix
cm = confusion_matrix(y_true, y_pred, labels=class_labels)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
            xticklabels=defect_type_names, yticklabels=defect_type_names)
plt.title('CNN Confusion Matrix')
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.tight_layout()
plt.show()

# Save the trained model
# NOTE(review): these paths are relative to the current working directory.
# The setup cell changes the cwd to the project root when started from
# "notebooks", in which case '../models' lands outside the project root —
# confirm this is the intended location.
import os
os.makedirs('../models', exist_ok=True)

model_path = '../models/cnn_wafer_classifier.pth'
torch.save(model.state_dict(), model_path)

# Save model metadata (derived from the objects used above rather than
# repeated as literals, so it cannot drift from the trained model)
metadata = {
    'model_type': 'CNN',
    'num_classes': len(defect_type_names),
    'class_names': defect_type_names,
    'test_accuracy': accuracy,
    'image_size': image_size,
    'num_parameters': total_params,
    'training_epochs': num_epochs
}

import json
with open('../models/cnn_metadata.json', 'w', encoding='utf-8') as f:
    json.dump(metadata, f, indent=2)

print(f"✅ CNN model saved to {model_path}")
print("✅ Model metadata saved")
print("📊 Final CNN Performance:")
print(f"   Test Accuracy: {accuracy:.3f}")
print(f"   Parameters: {total_params:,}")
print(f"   Classes: {len(defect_type_names)}")

print("\n🎯 CNN wafer map classification completed successfully!")
