# Wafer Defect Detection

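This notebook defines a convolutional neural network for classifying defect patterns on semiconductor wafer maps, loads and cleans the wafer map dataset, and visualizes sample defects. Training and evaluation are listed as next steps at the end.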

## 1. Imports

Import all the necessary libraries.

In [None]:
import tensorflow as tf
from tensorflow import keras
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import skimage
from skimage import measure
from skimage.transform import radon
from skimage.transform import probabilistic_hough_line
from scipy import interpolate
from scipy import stats

## 2. Model Definition

Here we define the Convolutional Neural Network (CNN) architecture.

In [None]:
def build_model(input_shape, num_classes):
    """
    Builds a convolutional neural network (CNN) for image classification.

    Args:
        input_shape (tuple): The shape of the input images (height, width, channels).
        num_classes (int): The number of classes for classification.

    Returns:
        keras.Model: The compiled CNN model.
    """
    model = keras.Sequential([
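        # Declare the input explicitly; newer Keras versions warn when
        # input_shape is passed to the first layer instead.
        keras.Input(shape=input_shape),
        keras.layers.Conv2D(32, (3, 3), activation='relu'),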
        keras.layers.MaxPooling2D((2, 2)),
        keras.layers.Conv2D(64, (3, 3), activation='relu'),
        keras.layers.MaxPooling2D((2, 2)),
        keras.layers.Conv2D(64, (3, 3), activation='relu'),
        keras.layers.Flatten(),
        keras.layers.Dense(64, activation='relu'),
        keras.layers.Dense(num_classes, activation='softmax')
    ])

    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model
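
The model is compiled with `sparse_categorical_crossentropy`, which expects integer class indices rather than one-hot vectors. This is why the labels are mapped to integers later in the notebook. As a rough illustration of what that loss computes, here is a minimal NumPy sketch. The helper name and the toy probabilities are made up for this example; it is not the Keras implementation.

```python
import numpy as np

def sparse_categorical_crossentropy(probs, labels):
    """Mean of -log(p[true class]) over a batch (illustrative helper, not the Keras API)."""
    probs = np.asarray(probs, dtype=float)
    labels = np.asarray(labels, dtype=int)
    # For each sample, pick the predicted probability of its true class.
    picked = probs[np.arange(len(labels)), labels]
    return float(-np.log(picked).mean())

# Two samples, three classes; each row is a softmax output that sums to 1.
example_probs = [[0.7, 0.2, 0.1],
                 [0.1, 0.1, 0.8]]
example_labels = [0, 2]  # integer class indices, not one-hot vectors

example_loss = sparse_categorical_crossentropy(example_probs, example_labels)
print(round(example_loss, 4))
```

The loss only looks at the probability assigned to the true class, so a confident correct prediction contributes a value close to zero.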

## 3. Build and Summarize Model

Now, let's instantiate the model with our specific parameters and print its summary.

In [None]:
# Define dataset parameters
INPUT_SHAPE = (128, 128, 1)  # Example: 128x128 grayscale images
NUM_CLASSES = 9  # Example: 8 defect types + 1 normal

# Build the model
model = build_model(INPUT_SHAPE, NUM_CLASSES)

# Print the model summary
model.summary()
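
To sanity-check the summary, the feature-map sizes and parameter counts can be worked out by hand. The sketch below assumes the Keras defaults used in `build_model` (3x3 convolutions with `'valid'` padding and stride 1, 2x2 max pooling with stride 2) and the example values `INPUT_SHAPE = (128, 128, 1)` and `NUM_CLASSES = 9`. The helper names are illustrative only.

```python
def conv_out(size, kernel=3):
    """Spatial size after a 'valid' convolution with stride 1."""
    return size - kernel + 1

def pool_out(size, pool=2):
    """Spatial size after max pooling with stride equal to the pool size."""
    return size // pool

def conv_params(kernel, in_channels, out_channels):
    """Weights plus one bias per output channel."""
    return kernel * kernel * in_channels * out_channels + out_channels

def dense_params(in_units, out_units):
    """Weights plus one bias per output unit."""
    return in_units * out_units + out_units

size = 128
size = conv_out(size)   # 126 after Conv2D(32)
size = pool_out(size)   # 63 after MaxPooling2D
size = conv_out(size)   # 61 after Conv2D(64)
size = pool_out(size)   # 30 after MaxPooling2D
size = conv_out(size)   # 28 after Conv2D(64)
flat_units = size * size * 64  # 50176 values after Flatten

layer_params = {
    "conv1": conv_params(3, 1, 32),         # 320
    "conv2": conv_params(3, 32, 64),        # 18496
    "conv3": conv_params(3, 64, 64),        # 36928
    "dense1": dense_params(flat_units, 64), # 3211328
    "dense2": dense_params(64, 9),          # 585
}
total_params = sum(layer_params.values())
print(flat_units, total_params)
```

Almost all of the parameters sit in the first dense layer, because the flattened feature map is large. A smaller input size or an extra pooling layer would shrink the model considerably.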

## 4. Load Data

Load the wafer map data from the pickle file.

In [None]:
# The notebook is in the 'src' directory, so we go up one level to find the 'data' directory.
file_path = os.path.join('..', 'data', 'LSWMD.pkl')
df = pd.read_pickle(file_path)

# Display the first few rows to verify it loaded correctly
df.head()

## 5. Preprocessing and Cleaning

Prepare the data for the model. This includes dropping unnecessary columns and converting labels into a numerical format.

In [None]:
# The 'waferIndex' column is not needed for classification
df = df.drop(['waferIndex'], axis=1)

# Create new columns with numerical representations of the labels.
# Note: 'trianTestLabel' is the column's actual (misspelled) name in the dataset.
df['failureNum'] = df.failureType
df['trainTestNum'] = df.trianTestLabel
mapping_type = {'Center':0, 'Donut':1, 'Edge-Loc':2, 'Edge-Ring':3, 'Loc':4, 'Random':5, 'Scratch':6, 'Near-full':7, 'none':8}
mapping_traintest = {'Training':0, 'Test':1}
df = df.replace({'failureNum': mapping_type, 'trainTestNum': mapping_traintest})

# Check the data types and non-null counts
df.info()
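
The visualization cell below reads labels as `failureType[i][0][0]`, which suggests each label is wrapped in a nested array. Depending on the pandas version, `replace` may not match values wrapped this way, so it can be safer to unwrap each label explicitly before mapping it. The following is a minimal sketch under that assumption. `encode_label` is a hypothetical helper, and the sample values are toy data rather than rows from the dataset.

```python
import numpy as np

MAPPING_TYPE = {'Center': 0, 'Donut': 1, 'Edge-Loc': 2, 'Edge-Ring': 3, 'Loc': 4,
                'Random': 5, 'Scratch': 6, 'Near-full': 7, 'none': 8}

def encode_label(raw, mapping):
    """Return the integer code for a possibly nested label, or None if it is empty or unknown."""
    flat = np.asarray(raw).ravel()
    if flat.size == 0:
        # Assumption: an empty array means the wafer is unlabeled.
        return None
    return mapping.get(str(flat[0]))

# Toy labels: two nested arrays, one empty (unlabeled) entry, one plain string.
samples = [np.array([['Center']]), np.array([['none']]), np.array([]), 'Scratch']
encoded = [encode_label(s, MAPPING_TYPE) for s in samples]
print(encoded)
```

On the real DataFrame this could be applied as `df['failureNum'] = df['failureType'].apply(encode_label, mapping=mapping_type)`, after which unlabeled rows hold missing values and can be dropped before training.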

## 6. Data Exploration and Visualization

Let's visualize some of the wafer maps to get a feel for the data.

In [None]:
# Filter for wafers that have a defect pattern (i.e., are not 'none')
df_withpattern = df[(df['failureNum'] >= 0) & (df['failureNum'] <= 7)]
df_withpattern = df_withpattern.reset_index()

# Display the first 20 defect patterns
num_to_show = 20
fig, ax = plt.subplots(nrows=2, ncols=10, figsize=(20, 5))
ax = ax.ravel(order='C')
for i in range(num_to_show):
    img = df_withpattern.waferMap[i]
    ax[i].imshow(img)
    ax[i].set_title(df_withpattern.failureType[i][0][0], fontsize=10)
    ax[i].set_xlabel(df_withpattern.index[i], fontsize=8)
    ax[i].set_xticks([])
    ax[i].set_yticks([])
plt.tight_layout()
plt.show()

## 7. Next Steps (TODO)

Now that the data is loaded and the labels are encoded, the next steps are:

- **Preprocess the images**: Resize every wafer map to the model's input shape and scale the pixel values to a fixed range.
- **Split the data**: Divide the data into training and testing sets using the `trainTestNum` column.
- **Train the model**: Use `model.fit()` with the training data.
- **Evaluate the model**: Use `model.evaluate()` with the testing data.
- **Save the trained model**: Persist the trained model for future use.
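
The first two steps could look roughly like the sketch below. It makes several assumptions that are not confirmed by this notebook: wafer maps are 2-D integer arrays with values 0 to 2, nearest-neighbour resizing is acceptable, and unlabeled rows have already been removed. The helper names and the toy arrays are hypothetical stand-ins for `df.waferMap`, `df.failureNum` and `df.trainTestNum`.

```python
import numpy as np

def resize_nearest(wafer_map, target_hw=(128, 128)):
    """Nearest-neighbour resize of a 2-D array; keeps the discrete die values intact."""
    src = np.asarray(wafer_map)
    rows = np.arange(target_hw[0]) * src.shape[0] // target_hw[0]
    cols = np.arange(target_hw[1]) * src.shape[1] // target_hw[1]
    return src[rows[:, None], cols[None, :]]

def prepare_images(wafer_maps, target_hw=(128, 128), max_value=2.0):
    """Resize, scale to [0, 1], and add a channel axis: (N, H, W, 1)."""
    batch = np.stack([resize_nearest(m, target_hw) for m in wafer_maps])
    # Assumption: pixel values range from 0 to max_value.
    batch = batch.astype('float32') / max_value
    return batch[..., np.newaxis]

# Toy stand-ins for the DataFrame columns (maps of different sizes).
rng = np.random.default_rng(0)
toy_maps = [rng.integers(0, 3, size=(26, 26)),
            rng.integers(0, 3, size=(45, 48)),
            rng.integers(0, 3, size=(200, 150))]
toy_labels = np.array([0, 8, 3])  # values of failureNum
toy_split = np.array([0, 1, 0])   # values of trainTestNum: 0 = Training, 1 = Test

x = prepare_images(toy_maps)
train_mask = toy_split == 0
x_train, y_train = x[train_mask], toy_labels[train_mask]
x_test, y_test = x[~train_mask], toy_labels[~train_mask]
print(x_train.shape, x_test.shape)
```

With arrays in this form, training and evaluation would follow the remaining bullets: `model.fit(x_train, y_train, ...)` and `model.evaluate(x_test, y_test)`.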