
# 🟢 Wafer Defect Classification (WM-811K)

This project applies **deep learning** to classify wafer maps from the **WM-811K dataset** into nine classes (eight defect patterns plus `none`) using **ResNet18 (transfer learning)**. 

It demonstrates:
- Data preparation (parsing waferMap → images)
- Training with transfer learning
- Evaluation with precision/recall/F1, ROC, PR curves
- Interpretability using **Grad-CAM**
- Insights into imbalance challenges in semiconductor yield analysis



## 📊 Dataset: WM-811K

- **Size:** 811,457 wafer maps
- **Labels:** ~172,950 wafers labeled into 9 classes (8 defect patterns plus *none*):  
  *center, donut, edge-loc, edge-ring, loc, near-full, random, scratch, none*  
- **Imbalance:** majority class ("none") ≈ 147k wafers vs. rarest class ("near-full") = 149  
- Each wafer is stored as a 2D matrix (`waferMap`):  
  - 0 = background  
  - 1 = normal die  
  - 2 = defective die  
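A small helper makes the 0/1/2 encoding concrete. This is a minimal sketch: `die_stats` is a hypothetical name (not part of the dataset tooling), and it assumes the map can be converted to a NumPy array.

```python
import numpy as np

def die_stats(wafer_map):
    """Summarise a WM-811K waferMap using the 0/1/2 encoding above."""
    arr = np.asarray(wafer_map)
    n_normal = int(np.count_nonzero(arr == 1))   # normal dies
    n_defect = int(np.count_nonzero(arr == 2))   # defective dies
    n_dies = n_normal + n_defect                 # background (0) is not a die
    ratio = n_defect / n_dies if n_dies else 0.0
    return {"dies": n_dies, "defective": n_defect, "defect_ratio": ratio}

# Toy 4x4 wafer: corners are background, one defective die near the centre.
toy = [[0, 1, 1, 0],
       [1, 2, 1, 1],
       [1, 1, 1, 1],
       [0, 1, 1, 0]]
stats = die_stats(toy)
print(stats)
```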



## 🧮 Mathematical Foundations

### Softmax + Cross-Entropy Loss
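$$
\hat{p}_{i,k} = \frac{\exp(z_{i,k})}{\sum_{j=1}^{K} \exp(z_{i,j})}
$$
$$
\mathcal{L}_{CE} = - \frac{1}{N} \sum_{i=1}^N \log \hat{p}_{i,y_i}
$$

where $z_{i,k}$ is the logit for class $k$ on sample $i$, $K = 9$ is the number of classes, and $y_i$ is the true label of sample $i$.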
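The softmax and cross-entropy formulas translate directly into NumPy. This sketch (hypothetical helper names) subtracts the row-wise maximum logit before exponentiating, which leaves the probabilities unchanged but avoids overflow; PyTorch's `nn.CrossEntropyLoss` likewise takes raw logits and applies log-softmax internally.

```python
import numpy as np

def softmax(z):
    # Subtracting the row max cancels in the ratio but prevents exp() overflow.
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def cross_entropy(logits, y):
    """Mean negative log-probability of the true class (L_CE)."""
    # Log-softmax computed directly, so very confident logits never give log(0).
    z = logits - logits.max(axis=1, keepdims=True)
    log_p = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return -log_p[np.arange(len(y)), y].mean()

logits = np.array([[2.0, 0.5, -1.0],
                   [0.1, 0.2, 3.0]])
y = np.array([0, 2])
print(softmax(logits).sum(axis=1))   # each row sums to 1 (up to rounding)
print(cross_entropy(logits, y))
```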

### Precision, Recall, and F1
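$$
\text{Precision} = \frac{TP}{TP + FP}, \quad \text{Recall} = \frac{TP}{TP + FN}, \quad
\text{F1} = \frac{2 \cdot \text{Precision} \cdot \text{Recall}}{\text{Precision} + \text{Recall}}
$$

These are computed per class (one-vs-rest); macro-F1 is the unweighted mean of the per-class F1 scores.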
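Per-class precision, recall, and F1 can all be read off a confusion matrix: the diagonal holds TP, column sums minus the diagonal give FP, and row sums minus the diagonal give FN. The sketch below assumes the scikit-learn convention (rows = true class, columns = predicted class), uses a hypothetical helper name, and uses a made-up two-class matrix to show why accuracy and macro-F1 (the unweighted mean of per-class F1) can diverge under imbalance.

```python
import numpy as np

def per_class_prf(cm):
    """Per-class precision, recall, F1 from a confusion matrix.

    Assumed convention: cm[i, j] = samples with true class i predicted as j.
    Classes with an empty denominator get a score of 0.
    """
    cm = np.asarray(cm, dtype=float)
    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp        # predicted as k but truly another class
    fn = cm.sum(axis=1) - tp        # truly k but predicted as another class
    with np.errstate(divide="ignore", invalid="ignore"):
        precision = np.where(tp + fp > 0, tp / (tp + fp), 0.0)
        recall = np.where(tp + fn > 0, tp / (tp + fn), 0.0)
        f1 = np.where(precision + recall > 0,
                      2 * precision * recall / (precision + recall), 0.0)
    return precision, recall, f1

# Toy example: a large majority class and a small rare class.
cm = np.array([[95, 5],
               [ 3, 2]])
p, r, f1 = per_class_prf(cm)
accuracy = np.trace(cm) / cm.sum()
macro_f1 = f1.mean()                # every class counts equally
print(accuracy, macro_f1)
```

Here accuracy is 97/105 ≈ 0.92 while macro-F1 is only ≈ 0.65, because the rare class's F1 of 1/3 counts as much as the majority class's.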

### ROC & PR Curves
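- **ROC**: plots TPR against FPR as the decision threshold is swept  
$$TPR = \frac{TP}{TP+FN}, \quad FPR = \frac{FP}{FP+TN}$$  

- **PR**: plots precision against recall; it is more informative under heavy imbalance because it does not use true negatives, whereas FPR divides by the large number of negatives and can stay small even when false positives outnumber true positives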
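For a one-vs-rest curve, each class's softmax probability is treated as a binary score and the threshold is swept from high to low. This is a minimal sketch with a hypothetical helper name; in practice scikit-learn's `roc_curve` and `precision_recall_curve` compute the same quantities. Tied scores are collapsed so each distinct threshold yields one point.

```python
import numpy as np

def roc_pr_points(scores, is_pos):
    """Thresholds with the TPR, FPR and precision obtained at each one.

    For one-vs-rest curves, `scores` is the predicted probability of class k
    and `is_pos` marks samples whose true label is k.
    """
    scores = np.asarray(scores, dtype=float)
    is_pos = np.asarray(is_pos, dtype=bool)
    order = np.argsort(-scores, kind="stable")     # descending score
    s, pos = scores[order], is_pos[order]
    tp = np.cumsum(pos)                            # positives with score >= threshold
    fp = np.cumsum(~pos)                           # negatives with score >= threshold
    # Keep only the last index of each run of tied scores.
    last = np.r_[np.flatnonzero(np.diff(s) != 0), len(s) - 1]
    tp, fp = tp[last], fp[last]
    tpr = tp / pos.sum()                           # TPR is also the recall
    fpr = fp / (~pos).sum()
    precision = tp / (tp + fp)
    return s[last], tpr, fpr, precision

thr, tpr, fpr, prec = roc_pr_points([0.9, 0.8, 0.7, 0.6, 0.2], [1, 0, 1, 0, 0])
print(tpr, fpr, prec)
```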

### Convolutional Layers
$$
y_{i,j,k} = \sigma\!\Bigg(\sum_{m=1}^{M}\sum_{u}\sum_{v} W_{u,v,m,k}\; x_{i+u, j+v, m} + b_k\Bigg)
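$$

where $x$ is the input with $M$ channels, $W_{u,v,m,k}$ is the kernel weight connecting input channel $m$ to output channel $k$ at offset $(u, v)$, $b_k$ is the bias of output channel $k$, and $\sigma$ is a nonlinearity such as ReLU.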
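The convolution equation can be checked with a direct, loop-based implementation. This sketch assumes stride 1, no padding ("valid" output), offsets $u, v$ starting at 0, and ReLU for $\sigma$; like deep-learning libraries it computes a cross-correlation (no kernel flip). It is far too slow for real use and only mirrors the formula term by term.

```python
import numpy as np

def conv2d_naive(x, W, b, act=lambda t: np.maximum(t, 0.0)):
    """Direct implementation of the convolution equation.

    x: (H, W_in, M) input, W: (U, V, M, K) kernels, b: (K,) biases.
    """
    H, Win, M = x.shape
    U, V, M2, K = W.shape
    assert M == M2, "input channels must match kernel channels"
    out = np.empty((H - U + 1, Win - V + 1, K))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            patch = x[i:i + U, j:j + V, :]       # (U, V, M) window at (i, j)
            # Sum over u, v, m for every output channel k at once.
            out[i, j, :] = np.tensordot(patch, W, axes=([0, 1, 2], [0, 1, 2])) + b
    return act(out)

x = np.ones((3, 3, 2))
W = np.ones((2, 2, 2, 1))
print(conv2d_naive(x, W, b=np.array([-1.0]))[..., 0])   # every entry is 8 - 1 = 7
```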

### Residual Block (ResNet18)
$$
y = F(x, \{W_i\}) + x
$$

- $F(x)$ = learned residual mapping  
- $x$ = identity skip connection  
- Benefit: mitigates vanishing gradients and allows deeper networks
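A ResNet18 basic block computes $F(x)$ with two 3×3 convolutions (each followed by batch norm) and adds the input back before the final ReLU. This is a minimal PyTorch sketch of the identity-shortcut case only (same channel count, stride 1); the real ResNet18 also uses a 1×1 projection shortcut where the spatial size or channel count changes. The class name is illustrative, not torchvision's.

```python
import torch
import torch.nn as nn

class BasicBlock(nn.Module):
    """Identity-shortcut residual block: y = relu(F(x) + x)."""

    def __init__(self, channels):
        super().__init__()
        # F(x): conv3x3 -> BN -> ReLU -> conv3x3 -> BN
        self.f = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # Residual mapping F(x) plus the identity skip connection x.
        return self.relu(self.f(x) + x)

block = BasicBlock(64)
x = torch.randn(2, 64, 56, 56)
print(block(x).shape)   # torch.Size([2, 64, 56, 56])
```

If the last batch-norm scale is set to zero, $F(x) = 0$ and the block reduces to the identity path (followed by ReLU), which illustrates why stacking such blocks does not make optimization harder.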



## 💻 Data Preparation
Wafer maps are parsed into 224×224 RGB images grouped into `train/val/test` directories by class.


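```python
from pathlib import Path
from PIL import Image
import numpy as np

def render_wafer(arr):
    """Render a waferMap (0 = background, 1 = normal die, 2 = defective die) as a 224x224 RGB image."""
    arr = np.asarray(arr)
    # Map die states to grey levels, then cast: the boolean arithmetic yields
    # int64, which does not match an 8-bit greyscale ('L') image.
    vis = ((arr == 1) * 180 + (arr == 2) * 255).astype(np.uint8)
    # A 2-D uint8 array is converted to an 'L' image automatically.
    return Image.fromarray(vis).resize((224, 224)).convert('RGB')
```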
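The text says images are grouped into `train/val/test` directories by class but does not show how samples are split. One reasonable approach, sketched here under stated assumptions, is a stratified split: shuffle each class separately so rare classes such as near-full appear in every split. The 70/15/15 fractions, the seed, the file-naming scheme, and both helper names are hypothetical.

```python
from pathlib import Path
import numpy as np

def stratified_split(labels, fracs=(0.7, 0.15, 0.15), seed=0):
    """Assign each sample index to 'train', 'val' or 'test', class by class."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    names = ("train", "val", "test")
    split = np.empty(len(labels), dtype=object)
    for cls in np.unique(labels):
        idx = rng.permutation(np.flatnonzero(labels == cls))
        n_train = int(round(fracs[0] * len(idx)))
        n_val = int(round(fracs[1] * len(idx)))
        # Whatever remains after train and val goes to test.
        for name, part in zip(names, np.split(idx, [n_train, n_train + n_val])):
            split[part] = name
    return split

def image_path(root, split, label, index):
    # Hypothetical layout: <root>/<split>/<class>/<index>.png
    return Path(root) / split / label / f"{index:06d}.png"

labels = ["a"] * 20 + ["b"] * 10
split = stratified_split(labels)
print(image_path("data", split[0], labels[0], 0))
```

Each split directory then has the `<class>/<image>` layout that `torchvision.datasets.ImageFolder` expects.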



## 🧠 Model: ResNet18 (Transfer Learning)

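We use **ResNet18 pretrained on ImageNet**, replacing the final FC layer with a 9-way linear layer. With `freeze_backbone=True` (the default), all pretrained weights are frozen and only the new head is trained.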

Training objective: **Cross-Entropy Loss** with class weights for imbalance.
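The README does not state which class weights were used. A common choice, shown here as an assumption, is the "balanced" inverse-frequency heuristic $w_k = N / (K \cdot n_k)$ (the same formula as scikit-learn's `compute_class_weight("balanced", ...)`). The example uses only the two extreme counts quoted in the dataset section; a real run would use the exact training counts of all nine classes.

```python
import numpy as np

def balanced_class_weights(counts):
    """Inverse-frequency weights w_k = N / (K * n_k)."""
    counts = np.asarray(counts, dtype=float)
    return counts.sum() / (len(counts) * counts)

# "none" (~147k, rounded) vs. "near-full" (149), from the dataset summary.
w = balanced_class_weights([147_000, 149])
print(w / w.min())   # near-full is weighted ~987x more than none
```

These weights would then be passed as `torch.nn.CrossEntropyLoss(weight=torch.tensor(w, dtype=torch.float32))`, one entry per class in label order.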


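```python
import torch
import torch.nn as nn
import torchvision.models as models

def build_model(num_classes=9, pretrained=True, freeze_backbone=True):
    # ImageNet weights (torchvision weights API) or random initialisation
    weights = models.ResNet18_Weights.DEFAULT if pretrained else None
    m = models.resnet18(weights=weights)
    if freeze_backbone:
        # Freeze every pretrained parameter; only the new head will be trained
        for name, param in m.named_parameters():
            if not name.startswith('fc.'):
                param.requires_grad = False
    # Replace the 1000-class ImageNet head with a num_classes-way linear layer
    # (a freshly created layer has requires_grad=True)
    in_feats = m.fc.in_features
    m.fc = nn.Linear(in_feats, num_classes)
    return m
```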



## 🔄 Training Loop

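Parameters $\theta$ are optimized with Adam. Plain gradient descent would apply $\theta \leftarrow \theta - \eta \cdot \nabla_\theta \mathcal{L}(\theta)$; Adam instead scales each step using bias-corrected running estimates of the gradient's first and second moments:

$$
\begin{aligned}
g_t &= \nabla_\theta \mathcal{L}(\theta_{t-1}) \\
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\, g_t, \qquad v_t = \beta_2 v_{t-1} + (1-\beta_2)\, g_t^2 \\
\hat{m}_t &= \frac{m_t}{1-\beta_1^t}, \qquad \hat{v}_t = \frac{v_t}{1-\beta_2^t} \\
\theta_t &= \theta_{t-1} - \eta\, \frac{\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
\end{aligned}
$$

with $m_0 = v_0 = 0$ and all operations element-wise.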
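Adam keeps exponential moving averages $m$ and $v$ of the gradient and the squared gradient, bias-corrects them (they start at zero), and steps by $\eta\,\hat{m}/(\sqrt{\hat{v}}+\epsilon)$. This NumPy sketch (hypothetical helper name) implements that update on a toy quadratic; the default hyperparameters match `torch.optim.Adam`'s defaults (`lr=1e-3`, `betas=(0.9, 0.999)`, `eps=1e-8`), while the `lr=0.05` in the demo is an arbitrary choice for fast convergence.

```python
import numpy as np

def adam_step(theta, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update; the step counter t starts at 1."""
    m = beta1 * m + (1 - beta1) * grad          # first-moment estimate
    v = beta2 * v + (1 - beta2) * grad ** 2     # second-moment estimate
    m_hat = m / (1 - beta1 ** t)                # bias corrections
    v_hat = v / (1 - beta2 ** t)
    theta = theta - lr * m_hat / (np.sqrt(v_hat) + eps)
    return theta, m, v

# Minimise f(theta) = ||theta - 3||^2, whose gradient is 2 * (theta - 3).
theta = np.zeros(2)
m = np.zeros_like(theta)
v = np.zeros_like(theta)
for t in range(1, 2001):
    grad = 2 * (theta - 3.0)
    theta, m, v = adam_step(theta, grad, m, v, t, lr=0.05)
print(theta)   # close to [3, 3]
```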


```python
def train_epoch(model, loader, criterion, optimizer, device='cuda'):
    """Run one training epoch; return the mean loss and accuracy over all samples."""
    model.train()
    total, correct, loss_sum = 0, 0, 0.0
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        logits = model(xb)
        loss = criterion(logits, yb)
        loss.backward()
        optimizer.step()
        # Scale the batch-mean loss by batch size so the epoch average is per sample
        loss_sum += loss.item() * xb.size(0)
        correct += (logits.argmax(1) == yb).sum().item()
        total += xb.size(0)
    return loss_sum / total, correct / total
```
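The evaluation metrics and ROC/PR curves need predicted probabilities on the validation or test set. This sketch of an evaluation counterpart to `train_epoch` is an assumption (the README does not show its evaluation code): it switches to `eval()` mode so batch norm uses its running statistics, disables gradient tracking, and returns the stacked softmax probabilities and labels alongside loss and accuracy. The toy linear model and random data exist only to exercise the function.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

@torch.no_grad()
def eval_epoch(model, loader, criterion, device="cuda"):
    """Return mean loss, accuracy, softmax probabilities and true labels."""
    model.eval()                     # running BN statistics, no dropout
    total, correct, loss_sum = 0, 0, 0.0
    all_probs, all_labels = [], []
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        logits = model(xb)
        loss_sum += criterion(logits, yb).item() * xb.size(0)
        correct += (logits.argmax(1) == yb).sum().item()
        total += xb.size(0)
        all_probs.append(torch.softmax(logits, dim=1).cpu())
        all_labels.append(yb.cpu())
    return loss_sum / total, correct / total, torch.cat(all_probs), torch.cat(all_labels)

# Toy check on CPU (illustrative only): 10 random samples, 3 classes.
torch.manual_seed(0)
toy_loader = DataLoader(TensorDataset(torch.randn(10, 4), torch.randint(0, 3, (10,))), batch_size=4)
loss, acc, probs, labels = eval_epoch(nn.Linear(4, 3), toy_loader, nn.CrossEntropyLoss(), device="cpu")
print(loss, acc, probs.shape)
```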



## 📈 Results

### Training Dynamics
![Train Loss](models/train_loss.png)  
![Train Acc](models/train_acc.png)

### Classification Report
![Classification Report](models/classification_report.png)

### Confusion Matrix
![Confusion Matrix](models/confusion_matrix_norm.png)

### ROC & PR Curves
![ROC](models/ovr_roc.png)  
![PR](models/ovr_pr.png)

### Grad-CAM
![Grad-CAM Example](models/gradcam/example_cam.png)
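Grad-CAM (Selvaraju et al., 2017) weights each feature map $A^k$ of a convolutional layer by $\alpha_k^c$, the spatial average of $\partial y^c / \partial A^k$, and keeps the positive part of $\sum_k \alpha_k^c A^k$. This NumPy sketch assumes the activations and gradients have already been captured (for example with forward/backward hooks on ResNet18's `layer4`, which is 512×7×7 for a 224×224 input); the function name is hypothetical.

```python
import numpy as np

def grad_cam(activations, gradients):
    """Grad-CAM heatmap from (K, H, W) feature maps and their gradients."""
    alpha = gradients.mean(axis=(1, 2))              # alpha_k: pooled gradients
    cam = np.tensordot(alpha, activations, axes=1)   # sum_k alpha_k * A^k -> (H, W)
    cam = np.maximum(cam, 0.0)                       # ReLU keeps positive evidence
    if cam.max() > 0:
        cam = cam / cam.max()                        # scale to [0, 1] for overlay
    return cam
```

The resulting low-resolution map is then upsampled to 224×224 (for example with PIL's `resize`) and blended over the wafer image.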



## 🔎 Discussion

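- **Strong baseline**: the pretrained ResNet18 reached ~93% overall accuracy.  
- **Imbalance issue**: macro-F1 is lower than accuracy because it averages F1 equally over all nine classes, so weak scores on rare classes such as near-full pull it down, while accuracy is dominated by the majority "none" class.  
- **Interpretability**: Grad-CAM overlays show the model focusing on defect regions.  
- **Industry relevance**: severe class imbalance and the need for interpretable predictions mirror real-world yield-analysis challenges in fabs.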

## 🚀 Next Steps
- Use focal loss or oversampling for imbalance  
- Try deeper backbones  
- Add K-fold cross-validation  
- Deploy via Streamlit or SageMaker
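As a starting point for the first item, focal loss (Lin et al., 2017) multiplies each sample's cross-entropy term by $(1 - \hat{p}_{i,y_i})^\gamma$, shrinking the contribution of easy, confidently classified samples (mostly "none" wafers here) so rare classes get relatively more gradient. This NumPy sketch is illustrative only; the function name is hypothetical, `gamma=2` is the value that paper found to work well, and the optional `alpha` is a per-class weight vector. With `gamma=0` and no `alpha` it reduces to the cross-entropy loss defined in the Mathematical Foundations section.

```python
import numpy as np

def focal_loss(logits, y, gamma=2.0, alpha=None):
    """Mean of -alpha_y * (1 - p_y)^gamma * log p_y over the batch."""
    z = logits - logits.max(axis=1, keepdims=True)
    log_p = z - np.log(np.exp(z).sum(axis=1, keepdims=True))   # log-softmax
    log_pt = log_p[np.arange(len(y)), y]                       # log prob of true class
    pt = np.exp(log_pt)
    loss = -((1.0 - pt) ** gamma) * log_pt                     # down-weights easy samples
    if alpha is not None:
        loss = np.asarray(alpha)[y] * loss                     # optional class weights
    return loss.mean()
```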



## 📌 References
- He et al., 2016: *Deep Residual Learning for Image Recognition*  
- WM-811K Dataset (Kaggle)  
