In [None]:
# =================================================================
# 1.1. SETUP AND DEPENDENCIES
# =================================================================
print("Starting environment setup...")
!pip install -q kagglehub imbalanced-learn scikit-image
!pip install -q opencv-python  # Required for image processing and visualization

import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from skimage.transform import resize
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, accuracy_score, classification_report
from imblearn.over_sampling import RandomOverSampler
import cv2 # Used for image processing in visualization

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
import kagglehub

# Set Device
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Seed the global RNGs so weight initialisation, DataLoader shuffling and
# dropout repeat across Restart & Run All. torch.manual_seed seeds all devices.
# (Bit-exact GPU runs additionally need deterministic cuDNN settings.)
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
# =================================================================
# 2.1. DATA ACQUISITION AND CLEANING
# =================================================================
print("\nDownloading and processing WM-811K dataset...")
# Download data using kagglehub; it returns the local dataset directory.
dataset_path = kagglehub.dataset_download("qingyi/wm811k-wafer-map")
pkl_file_path = os.path.join(dataset_path, "LSWMD.pkl")
if not os.path.isfile(pkl_file_path):
    # The pickle is expected at the root of the download; in case the archive
    # layout places it in a subdirectory, search for it, and fail loudly with
    # the directory name if it is absent instead of a bare read_pickle error.
    _pkl_candidates = sorted(
        os.path.join(root, "LSWMD.pkl")
        for root, _dirs, files in os.walk(dataset_path)
        if "LSWMD.pkl" in files
    )
    if not _pkl_candidates:
        raise FileNotFoundError(f"LSWMD.pkl not found under {dataset_path!r}")
    pkl_file_path = _pkl_candidates[0]

# NOTE(security): unpickling can execute arbitrary code stored in the file.
# Only load this pickle from a source you trust.
df = pd.read_pickle(pkl_file_path)

# --- CRITICAL FIXES for Data Integrity ---
# 1. Function to clean the nested array structure in 'failureType'
def clean_failure_type(x):
    """Normalise one raw ``failureType`` cell to a plain label string.

    Array cells, possibly nested (e.g. ``[['Center']]``), are reduced to the
    single label they hold, giving ``'Center'`` rather than the string form of
    an inner array such as ``"['Center']"``. Empty arrays and missing scalars
    (None / NaN) become ``None``, so a later ``dropna`` removes them instead
    of keeping them as the literal strings ``'None'`` / ``'nan'``.

    Args:
        x: raw cell value (np.ndarray, str, None or float NaN).

    Returns:
        The label as ``str``, or ``None`` when the cell carries no label.
    """
    # Unwrap any depth of array nesting, including object arrays of arrays.
    while isinstance(x, np.ndarray):
        if x.size == 0:
            return None
        x = x.ravel()[0]
    if x is None or (isinstance(x, float) and np.isnan(x)):
        return None
    return str(x)
df['failureType'] = df['failureType'].apply(clean_failure_type)

# 2. Keep only the defect patterns: drop the 'none' class (matched in both its
#    "['none']" and 'none' spellings) and any row whose label is missing.
none_spellings = ["['none']", 'none']
df_defects = (
    df.loc[~df['failureType'].isin(none_spellings)]
    .dropna(subset=['failureType'])
    .copy()
)

# Integer-encode the labels in order of first appearance.
defect_types = df_defects['failureType'].unique()
NUM_CLASSES = len(defect_types)
label_of = {name: code for code, name in enumerate(defect_types)}
df_defects['label'] = df_defects['failureType'].map(label_of)

print(f"Final Defect Samples: {len(df_defects)}, Total Classes: {NUM_CLASSES}")
print("Initial Defect Class Distribution (Imbalanced):\n", df_defects['failureType'].value_counts())


# =================================================================
# 2.2. DATA SPLIT AND IMBALANCE MITIGATION
# =================================================================

# 1. Split indices (memory-efficient: only row labels are split; the wafer
#    maps stay in df_defects and are loaded lazily by the Dataset).
X_indices = df_defects.index.values
X_train_indices, X_test_indices, y_train, y_test = train_test_split(
    X_indices, df_defects['label'].values, test_size=0.2, stratify=df_defects['label'].values, random_state=42
)

# 2. Random oversampling of the training indices only: minority-class rows are
#    repeated until the classes are balanced. The test split keeps its natural
#    distribution.
ros = RandomOverSampler(random_state=42)
X_train_resampled_indices, y_train_resampled = ros.fit_resample(
    X_train_indices.reshape(-1, 1), y_train
)
X_train_resampled_indices = X_train_resampled_indices.flatten()

# 3. Class weights for the loss, derived from the label distribution the loss
#    actually sees (the oversampled training set). Deriving them from the
#    pre-oversampling counts would correct the imbalance twice and tends to
#    push the model towards over-predicting minority classes. With the default
#    RandomOverSampler strategy the classes end up equally frequent, so these
#    weights are uniform. To rely on loss weighting instead of oversampling,
#    train on X_train_indices / y_train and compute the weights from y_train.
class_counts = np.bincount(np.asarray(y_train_resampled, dtype=np.int64), minlength=NUM_CLASSES)
total_samples = class_counts.sum()
# np.maximum guards against division by zero for a class absent from training.
class_weights_inverse = total_samples / (NUM_CLASSES * np.maximum(class_counts, 1))
class_weights_tensor = torch.tensor(class_weights_inverse, dtype=torch.float).to(device)

print("\nWeighted Loss Tensor Defined (Used in Training).")

# =================================================================
# 2.3. CUSTOM PYTORCH DATASET (LAZY LOADING)
# =================================================================

def preprocess_wafer(wafer_map, target_size=(64, 64)):
    """Turn one raw 2-D wafer map into a 3-channel float32 tensor for the CNN.

    Steps: divide the values by 2, zero-pad the map onto a centred square
    canvas, resize the square to ``target_size`` with anti-aliasing, and
    replicate the single channel three times.

    Args:
        wafer_map: 2-D numeric array.
        target_size: (height, width) of the output image.

    Returns:
        torch.Tensor of shape (3, *target_size) and dtype float32.
    """
    # The original canvas was float64, so keep the padded map float64 too.
    scaled = np.asarray(wafer_map / 2.0, dtype=np.float64)
    rows, cols = scaled.shape
    side = max(rows, cols)
    top = (side - rows) // 2
    left = (side - cols) // 2
    # Zero-pad so the map sits in the middle of a side x side square.
    square = np.pad(
        scaled,
        ((top, side - rows - top), (left, side - cols - left)),
        mode='constant',
    )
    resized = resize(square, target_size, anti_aliasing=True)
    three_channel = np.repeat(resized[np.newaxis, :, :], 3, axis=0)
    return torch.tensor(three_channel.astype(np.float32))

class WaferDataset(Dataset):
    """Map-style dataset that builds wafer images lazily, one item at a time.

    Only row labels are stored up front. For each requested item, the wafer
    map is looked up in ``df`` and run through ``preprocess_wafer``, so the
    pre-processed image set is never held in memory all at once.

    Args:
        df: DataFrame with a ``'waferMap'`` column, indexed by row label.
        indices: sequence of row labels into ``df``; repeats are allowed
            (e.g. from oversampling).
        labels: integer class labels aligned position-by-position with
            ``indices``.

    Raises:
        ValueError: if ``indices`` and ``labels`` differ in length.
    """
    def __init__(self, df, indices, labels):
        # A mismatch would otherwise surface only as an IndexError deep in
        # training, or silently pair images with the wrong labels.
        if len(indices) != len(labels):
            raise ValueError(
                f"indices and labels must have the same length "
                f"(got {len(indices)} and {len(labels)})"
            )
        self.df = df
        self.indices = indices
        self.labels = labels

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        df_idx = self.indices[idx]
        # .at is pandas' fast accessor for a single cell by label.
        wafer_map = self.df.at[df_idx, 'waferMap']
        image = preprocess_wafer(wafer_map)
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return image, label

# Lazily-loading datasets: training uses the oversampled index list, while
# the test set keeps the original stratified split.
train_dataset = WaferDataset(df=df_defects, indices=X_train_resampled_indices, labels=y_train_resampled)
test_dataset = WaferDataset(df=df_defects, indices=X_test_indices, labels=y_test)

print("\nCustom PyTorch Datasets created for memory-efficient training.")

In [None]:
# =================================================================
# 3.1. DEFINE CUSTOM CNN MODEL
# =================================================================

BATCH_SIZE = 64
# Background processes that run WaferDataset.__getitem__ (lazy preprocessing).
# NOTE(review): on platforms that spawn rather than fork workers, classes
# defined in a notebook may not be picklable; set NUM_WORKERS = 0 there.
NUM_WORKERS = 2
# Page-locked host buffers speed up host->GPU copies; no benefit on CPU.
PIN_MEMORY = device.type == 'cuda'
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False,
                         num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

class WaferCNN(nn.Module):
    """Six-convolution CNN that classifies 3-channel wafer-map images.

    Three stages of [conv3x3-BN-ReLU, conv3x3-BN-ReLU, maxpool 2x2] with
    32/64, 128/128 and 256/256 output channels, followed by global average
    pooling, dropout (p=0.5) and a linear layer producing ``num_classes``
    logits.
    """

    def __init__(self, num_classes=NUM_CLASSES):
        super(WaferCNN, self).__init__()
        # Registration order fixes parameter order and the order in which
        # random initialisation draws from the RNG, so keep it stable.
        # Stage 1: 3 -> 32 -> 64 channels
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool1 = nn.MaxPool2d(2, 2)
        # Stage 2: 64 -> 128 -> 128 channels
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(128)
        self.pool2 = nn.MaxPool2d(2, 2)
        # Stage 3: 128 -> 256 -> 256 channels
        self.conv5 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.bn5 = nn.BatchNorm2d(256)
        self.conv6 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.bn6 = nn.BatchNorm2d(256)
        self.pool3 = nn.MaxPool2d(2, 2)
        # Head
        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(256, num_classes)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        stages = (
            (self.conv1, self.bn1, self.conv2, self.bn2, self.pool1),
            (self.conv3, self.bn3, self.conv4, self.bn4, self.pool2),
            (self.conv5, self.bn5, self.conv6, self.bn6, self.pool3),
        )
        for conv_a, bn_a, conv_b, bn_b, pool in stages:
            x = F.relu(bn_a(conv_a(x)))
            x = F.relu(bn_b(conv_b(x)))
            x = pool(x)
        # (N, 256, 1, 1) -> (N, 256)
        pooled = torch.flatten(self.global_avg_pool(x), 1)
        return self.fc(self.dropout(pooled))

custom_model = WaferCNN(NUM_CLASSES).to(device)

# Cross-entropy weighted per class by class_weights_tensor from the
# data-preparation step; Adam with a small L2 weight decay.
LEARNING_RATE = 5e-4
WEIGHT_DECAY = 1e-5
criterion = nn.CrossEntropyLoss(weight=class_weights_tensor)
optimizer = optim.Adam(custom_model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
NUM_EPOCHS = 15

print("\nStarting Custom CNN Training with Weighted Loss...")
# =================================================================
# 3.2. TRAINING LOOP
# =================================================================
MODEL_WEIGHTS_PATH = 'custom_wafer_cnn_weighted.pth'

custom_model.train()
for epoch in range(NUM_EPOCHS):
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = custom_model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()

    # Reported value is the mean of the per-batch mean losses for the epoch.
    print(f'Epoch {epoch + 1}/{NUM_EPOCHS}, Training Loss: {running_loss / len(train_loader):.4f}')

# Save first, then report, so the message never claims a save that failed.
torch.save(custom_model.state_dict(), MODEL_WEIGHTS_PATH)
print('\nFinished Training! Model weights saved.')

In [None]:
# =================================================================
# 4.1. MODEL EVALUATION
# =================================================================
print("\n\nEvaluating Final Model on Test Data...")
custom_model.eval()
all_preds_custom = []
all_labels_custom = []

with torch.no_grad():
    for images, labels in test_loader:
        images = images.to(device)
        # Predicted class = index of the largest logit.
        predicted = custom_model(images).argmax(dim=1)
        all_preds_custom.extend(predicted.cpu().numpy())
        # Labels come from the CPU-side DataLoader; no device round-trip needed.
        all_labels_custom.extend(labels.numpy())

# Pass the full label set explicitly so per-class metrics always line up with
# defect_types, even if a class is missing from the predictions or the test labels.
class_ids = list(range(NUM_CLASSES))
accuracy_custom = accuracy_score(all_labels_custom, all_preds_custom)
f1_macro_custom = f1_score(all_labels_custom, all_preds_custom, labels=class_ids,
                           average='macro', zero_division=0)

print(f"\n--- FINAL PERFORMANCE (Macro F1: {f1_macro_custom:.4f}) ---")
print(f"Overall Accuracy: {accuracy_custom*100:.2f}%")
print("\nFinal Classification Report:\n", classification_report(
    all_labels_custom, all_preds_custom, labels=class_ids,
    target_names=defect_types, zero_division=0))


# =================================================================
# 4.2. MANUAL GRAD-CAM VISUALIZATION
# =================================================================

print("\n=======================================================")
print("✅ MANUAL GRAD-CAM: Generating Visual Proof...")
print("=======================================================")

# Filled in by the hooks below during forward/backward passes through target_layer.
gradients = None
activations = None

# 1. Hook Functions (Correct Signatures)
def save_gradient(module, grad_input, grad_output):
    """Full backward hook: store the gradient w.r.t. the layer's output.

    grad_output is a tuple with one entry per module output. The hooked
    layer returns a single tensor, so entry 0 is the one Grad-CAM needs.
    """
    global gradients
    gradients = grad_output[0]

def save_activations(module, input, output):
    """Forward hook: store the layer's output feature maps."""
    global activations
    activations = output

# Target the output of bn6, the BatchNorm that follows the last convolution
# (conv6). This is the feature map *before* the ReLU applied in forward().
target_layer = custom_model.bn6

# Register the hooks. register_full_backward_hook replaces the deprecated
# register_backward_hook, which is not guaranteed to report correct grads.
hook_handle_act = target_layer.register_forward_hook(save_activations)
hook_handle_grad = target_layer.register_full_backward_hook(save_gradient)
# ------------------------------------

# Helper function to generate Grad-CAM
def generate_grad_cam(input_tensor, predicted_class_idx):
    """Compute a Grad-CAM heatmap for one image and one target class.

    Runs a forward pass through ``custom_model`` and back-propagates the score
    of ``predicted_class_idx``. The hooked layer's feature maps and gradients
    are read from the module-level ``activations`` / ``gradients`` globals,
    which the registered hooks fill in.

    Args:
        input_tensor: model input on the model's device; expected to be a
            batch of one, shape (1, C, H, W). Not modified.
        predicted_class_idx: index of the class whose score is explained.

    Returns:
        np.ndarray of shape (H, W) matching the input's spatial size, with
        values in [0, 1]; all zeros if no location has positive evidence.
    """
    # Use a detached alias so the caller's tensor is not flagged as requiring
    # grad. Requiring grad here keeps a graph even if parameters are frozen.
    model_input = input_tensor.detach().requires_grad_(True)
    output = custom_model(model_input)

    custom_model.zero_grad()
    one_hot = torch.zeros_like(output)
    one_hot[0, predicted_class_idx] = 1
    # The graph is used exactly once, so it does not need to be retained.
    output.backward(gradient=one_hot)

    # Channel weights: gradients averaged over batch and spatial dimensions.
    pooled_gradients = torch.mean(gradients, dim=[0, 2, 3]).detach().cpu()
    feature_maps = activations.detach().cpu()

    # Weighted sum over channels, computed out of place. On CPU, .cpu() returns
    # the same storage, so an in-place scale would mutate the hooked tensor.
    heatmap = (feature_maps * pooled_gradients.view(1, -1, 1, 1)).sum(dim=1).squeeze()
    heatmap = torch.relu(heatmap)

    peak = torch.max(heatmap)
    if peak > 0:
        heatmap = heatmap / peak

    # Upsample to the input's spatial size; cv2.resize takes (width, height).
    height, width = input_tensor.shape[-2:]
    np_heatmap = heatmap.numpy()
    np_heatmap = cv2.resize(np_heatmap, (int(width), int(height)))

    return np_heatmap

# Helper function for visualization
def visualize_cam(heatmap, rgb_img):
    """Overlay a Grad-CAM heatmap on the wafer image.

    Args:
        heatmap: 2-D float array, expected in [0, 1], with the same height and
            width as the image.
        rgb_img: channels-first float array of shape (3, H, W), expected in
            [0, 1].

    Returns:
        (img, superimposed_img): the image as uint8 (H, W, 3), and a 60/40
        blend of the image with the JET-coloured heatmap in RGB channel order.
    """
    # (C, H, W) -> (H, W, C). Clip so out-of-range values saturate instead of
    # wrapping in the uint8 cast; OpenCV wants a C-contiguous buffer.
    img_float = np.clip(rgb_img.transpose(1, 2, 0), 0.0, 1.0)
    img = np.ascontiguousarray((img_float * 255).astype(np.uint8))

    # applyColorMap needs a single-channel 8-bit (CV_8UC1) input.
    heatmap_uint8 = np.uint8(255 * np.clip(heatmap, 0.0, 1.0))
    heatmap_colored = cv2.applyColorMap(heatmap_uint8, cv2.COLORMAP_JET)

    # applyColorMap returns BGR. The wafer image has identical channels, so
    # blending in BGR and converting the result to RGB for matplotlib is safe.
    superimposed_img = cv2.addWeighted(img, 0.6, heatmap_colored, 0.4, 0)
    superimposed_img = cv2.cvtColor(superimposed_img, cv2.COLOR_BGR2RGB)

    return img, superimposed_img

# --- Generate and Display Loop ---
VISUALS_DIR = 'XAI_VISUALS'
os.makedirs(VISUALS_DIR, exist_ok=True)
sample_indices_to_view = [0, 50, 100]

try:
    for i, test_idx in enumerate(sample_indices_to_view):
        # Skip requested samples that fall outside the test set instead of crashing.
        if not 0 <= test_idx < len(test_dataset):
            print(f"Skipping test index {test_idx}: test set has {len(test_dataset)} samples.")
            continue

        input_tensor, true_label_idx = test_dataset[test_idx]
        input_tensor = input_tensor.unsqueeze(0).to(device)

        with torch.no_grad():
            output_no_grad = custom_model(input_tensor)
            predicted_class_idx = output_no_grad.argmax(dim=1).item()

        # Generate Heatmap
        np_heatmap = generate_grad_cam(input_tensor, predicted_class_idx)

        # Prepare image for display; strip list-style brackets/quotes that the
        # raw failureType strings may carry.
        true_label = defect_types[true_label_idx.item()].strip("[]'")
        predicted_label = defect_types[predicted_class_idx].strip("[]'")
        rgb_img = input_tensor.squeeze(0).detach().cpu().numpy()

        original_img, superimposed_img = visualize_cam(np_heatmap, rgb_img)

        # Display and save the results
        fig, (ax_orig, ax_cam) = plt.subplots(1, 2, figsize=(12, 6))
        fig.suptitle(f'Sample {i+1} | True: {true_label} | Predicted: {predicted_label}', fontsize=14)

        ax_orig.imshow(original_img[:, :, 0], cmap='gray')
        ax_orig.set_title(f'Wafer Map: {true_label}')
        ax_orig.axis('off')

        ax_cam.imshow(superimposed_img)
        ax_cam.set_title('Grad-CAM: Model Focus on Defect')
        ax_cam.axis('off')

        fig.tight_layout()

        filename = f"grad_cam_sample_{i+1}_{predicted_label}.png"
        save_path = os.path.join(VISUALS_DIR, filename)
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()
        # Release the figure so figures do not accumulate across iterations.
        plt.close(fig)
        print(f"Saved: {save_path}")
        print("-" * 60)
finally:
    # Remove hooks after use (MANDATORY), even when a sample fails, so later
    # forward/backward passes through the model are not intercepted.
    hook_handle_act.remove()
    hook_handle_grad.remove()