In [None]:
# AI-Driven Defect Detection in Chemically Decapsulated NAND Flash Memory

## Abstract
This notebook presents a comprehensive implementation of deep learning techniques for automated defect detection in semiconductor wafer maps. The approach utilizes convolutional neural networks (CNNs) to classify various defect patterns in chemically decapsulated NAND flash memory devices, enabling efficient quality control in semiconductor manufacturing.

## Motivation
Traditional manual inspection of semiconductor wafers is time-consuming and prone to human error. This implementation demonstrates how artificial intelligence can enhance defect detection accuracy while reducing inspection time (LeCun et al., 2015).

## Dataset
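We utilize the WM-811K wafer map dataset, which contains 811,457 wafer maps with various defect patterns commonly found in semiconductor manufacturing (Nakazawa & Kulkarni, 2018). Only a subset of these maps carries a failure-type label, which is why unlabeled maps are removed during preprocessing.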

ImportError: Traceback (most recent call last):
  File "c:\Users\HP\AppData\Local\Programs\Python\Python310\lib\site-packages\tensorflow\python\pywrap_tensorflow.py", line 73, in <module>
    from tensorflow.python._pywrap_tensorflow_internal import *
ImportError: DLL load failed while importing _pywrap_tensorflow_internal: A dynamic link library (DLL) initialization routine failed.


Failed to load the native TensorFlow runtime.
See https://www.tensorflow.org/install/errors for some common causes and solutions.
If you need help, create an issue at https://github.com/tensorflow/tensorflow/issues and include the entire stack trace above this error message.

In [None]:
import os
import pickle
import time
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import tensorflow as tf
from sklearn.metrics import classification_report, confusion_matrix, f1_score, precision_score, recall_score
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_class_weight
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Dense, Flatten, Dropout, BatchNormalization, Input
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.utils import to_categorical

# Configuration
# NOTE: this silences *all* warnings for the rest of the session, including
# deprecation notices from TensorFlow and scikit-learn.
warnings.filterwarnings('ignore')

# Seed both libraries so weight initialisation and the dummy-data fallback
# are repeatable across "Restart & Run All".
tf.random.set_seed(42)
np.random.seed(42)

print("TensorFlow version:", tf.__version__)
# tf.test.is_gpu_available() is deprecated; querying the physical device list
# is the supported way to check for a GPU.
print("GPU Available:", bool(tf.config.list_physical_devices('GPU')))
print("Libraries imported successfully!")

## Requirements and Dependencies

### Core Libraries:

- **TensorFlow 2.x**: Deep learning framework for CNN implementation (Abadi et al., 2016)
- **NumPy**: Numerical computing for array operations (Harris et al., 2020)
- **Pandas**: Data manipulation and analysis (McKinney, 2010)
- **Scikit-learn**: Machine learning utilities and metrics (Pedregosa et al., 2011)
- **Matplotlib/Seaborn**: Data visualization (Hunter, 2007)

### Installation:

```bash
pip install tensorflow numpy pandas scikit-learn matplotlib seaborn
```

### Hardware Recommendations:

- GPU with CUDA support for accelerated training
- Minimum 8GB RAM for dataset handling
- 10GB+ storage for dataset and model checkpoints


In [None]:
class WaferDataLoader:
    """
    Loader and quick-look explorer for a pickled wafer-map dataset.

    The methods below treat the unpickled object as an indexable sequence of
    dict-like records, each carrying a ``'failureType'`` label (``None`` when
    unlabeled) and a ``'waferMap'`` image array.

    NOTE(review): the publicly distributed WM-811K pickle is usually a pandas
    DataFrame - confirm the file at ``data_path`` was converted to records.
    """

    def __init__(self, data_path):
        """
        Args:
            data_path: Path of the pickle file to read in ``load_data``.
        """
        self.data_path = data_path
        self.data = None
        self.label_encoder = LabelEncoder()

    def load_data(self):
        """
        Load wafer map data from the pickle file at ``self.data_path``.

        Returns:
            bool: True when the file was read, False on any error (the error
            is printed rather than raised so the notebook can fall back to
            dummy data).
        """
        # SECURITY: pickle.load can execute arbitrary code embedded in the
        # file. Only open dataset files obtained from a trusted source.
        try:
            with open(self.data_path, 'rb') as f:
                self.data = pickle.load(f)
            print("Dataset loaded successfully!")
            print(f"Total samples: {len(self.data)}")
            return True
        except FileNotFoundError:
            print(f"Error: File not found at {self.data_path}")
            print("Please ensure the WM-811K dataset is downloaded and path is correct")
            return False
        except Exception as e:
            print(f"Error loading data: {str(e)}")
            return False

    def explore_data(self):
        """
        Print the structure of the first record and the label distribution.

        Returns:
            pandas.Series | None: Count of samples per failure type (labeled
            records only), or None when no data has been loaded. An empty
            dataset yields an empty Series instead of raising.
        """
        if self.data is None:
            print("Please load data first using load_data()")
            return

        # Sample data structure (skipped for an empty dataset, where indexing
        # the first record would raise IndexError)
        if len(self.data) > 0:
            sample = self.data[0]
            print("Data structure:")
            for key, value in sample.items():
                if isinstance(value, np.ndarray):
                    print(f"- {key}: shape {value.shape}, dtype {value.dtype}")
                else:
                    print(f"- {key}: {type(value)} - {value}")
        else:
            print("Dataset is empty - no sample structure to show")

        # Defect pattern distribution over labeled records only
        defect_patterns = [item['failureType'] for item in self.data if item['failureType'] is not None]
        pattern_counts = pd.Series(defect_patterns).value_counts()

        print("\nDefect Pattern Distribution:")
        print(pattern_counts)

        return pattern_counts

    def visualize_samples(self, n_samples=9):
        """
        Show the first wafer map found for each defect type.

        Args:
            n_samples: Maximum number of defect types to draw. The grid is
                three columns wide and grows by rows as needed, so values
                above 9 no longer overflow a fixed 3x3 layout.
        """
        if self.data is None:
            print("Please load data first")
            return

        # Single pass over the data: remember the first record of every label.
        first_by_type = {}
        for item in self.data:
            label = item['failureType']
            if label is not None and label not in first_by_type:
                first_by_type[label] = item

        # Sort so the panel order is the same on every run (set/dict-of-str
        # iteration order is otherwise hash-dependent).
        defect_types = sorted(first_by_type, key=str)[:n_samples]
        if not defect_types:
            print("No labeled samples available to visualize")
            return

        n_cols = 3
        n_rows = (len(defect_types) + n_cols - 1) // n_cols  # ceiling division
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows), squeeze=False)
        axes = axes.ravel()

        for ax, defect_type in zip(axes, defect_types):
            ax.imshow(first_by_type[defect_type]['waferMap'], cmap='viridis')
            ax.set_title(f'Defect Type: {defect_type}')

        # Hide ticks/frames everywhere, including panels left without an image.
        for ax in axes:
            ax.axis('off')

        plt.tight_layout()
        plt.show()

# Point the loader at the local copy of the WM-811K pickle.
# Change the file name below if your download lives elsewhere.
data_loader = WaferDataLoader(data_path='WAFERM811K_25013.pkl')

# Explore and plot only when the file could actually be read; load_data()
# prints guidance itself when the file is missing.
if data_loader.load_data():
    pattern_distribution = data_loader.explore_data()
    data_loader.visualize_samples(n_samples=9)

## Data Preprocessing Pipeline

The preprocessing pipeline addresses several key challenges in semiconductor defect detection:

### 1. **Data Cleaning** (Ng, 2016)

- Remove samples with missing labels
- Handle irregular wafer map sizes
- Filter out extremely rare defect types

### 2. **Normalization and Standardization**

- Normalize pixel values to [0, 1] range
- Optionally apply z-score normalization for consistent feature scaling (the `WaferPreprocessor` below applies only per-map min-max scaling)

### 3. **Data Augmentation** (Shorten & Khoshgoftaar, 2019)

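- Rotation: Account for wafer orientation variations
- Flipping: Increase dataset diversity
- Noise injection: Improve model robustness

> **Note:** augmentation is described here as part of the intended pipeline; the `WaferPreprocessor` class in the next cell does not implement it.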

### 4. **Class Balancing**

- Address class imbalance through stratified sampling
- Implement weighted loss functions for minority classes


In [None]:
class WaferPreprocessor:
    """
    Preprocessor for wafer map records: label filtering, resizing,
    per-map min-max scaling, label encoding and class-weight computation.
    """

    def __init__(self, target_size=(64, 64), min_samples_per_class=100):
        """
        Args:
            target_size: (height, width) every wafer map is resized to.
            min_samples_per_class: Classes with fewer labeled samples than
                this are dropped by ``clean_data``.
        """
        self.target_size = target_size
        self.min_samples_per_class = min_samples_per_class
        self.label_encoder = LabelEncoder()
        # Filled by prepare_dataset(): {encoded class index: weight}
        self.class_weights = None

    def clean_data(self, data):
        """
        Drop unlabeled records and records of under-represented classes.

        Args:
            data: Iterable of records with a ``'failureType'`` entry
                (``None`` meaning unlabeled).

        Returns:
            list: The remaining records, in their original order.
        """
        print("Cleaning dataset...")

        # Remove samples without failure type
        cleaned_data = [item for item in data if item['failureType'] is not None]
        print(f"Removed {len(data) - len(cleaned_data)} samples without labels")

        # Count samples per class
        failure_types = [item['failureType'] for item in cleaned_data]
        type_counts = pd.Series(failure_types).value_counts()

        # Keep only classes with sufficient samples (set -> O(1) membership)
        valid_types = set(type_counts[type_counts >= self.min_samples_per_class].index)
        cleaned_data = [item for item in cleaned_data if item['failureType'] in valid_types]

        print(f"Keeping {len(valid_types)} classes with >= {self.min_samples_per_class} samples")
        print(f"Final dataset size: {len(cleaned_data)} samples")

        return cleaned_data

    def normalize_wafer_maps(self, wafer_maps):
        """
        Resize each 2-D wafer map to ``target_size`` and min-max scale it.

        Each map is scaled independently to [0, 1]. A constant map has no
        range to scale by, so it is mapped to all zeros (previously it was
        passed through unchanged and could lie outside [0, 1]).

        Args:
            wafer_maps: Iterable of 2-D array-likes.

        Returns:
            numpy.ndarray: float32 array of shape (n_maps, height, width).
        """
        target_size = tuple(self.target_size)
        normalized_maps = []

        for wafer_map in wafer_maps:
            wafer = np.asarray(wafer_map, dtype=np.float32)

            # Resizing to the same size is an identity operation, so skip the
            # TensorFlow round trip for maps that already match.
            if wafer.shape != target_size:
                wafer = tf.image.resize(
                    wafer[..., np.newaxis],
                    target_size
                ).numpy().squeeze()

            lowest, highest = wafer.min(), wafer.max()
            if highest > lowest:
                wafer = (wafer - lowest) / (highest - lowest)
            else:
                wafer = np.zeros_like(wafer)

            normalized_maps.append(wafer)

        return np.array(normalized_maps)

    def prepare_dataset(self, data):
        """
        Run the full pipeline and produce model-ready arrays.

        Side effects: fits ``self.label_encoder`` and sets
        ``self.class_weights`` ('balanced' weights keyed by encoded class).

        Args:
            data: Records accepted by ``clean_data``; each must also carry a
                ``'waferMap'`` entry.

        Returns:
            tuple: ``(X, y_categorical, failure_types)`` where ``X`` has a
            trailing channel axis, ``y_categorical`` is one-hot encoded and
            ``failure_types`` is the list of original labels.
        """
        # Clean data
        cleaned_data = self.clean_data(data)

        # Extract features and labels
        wafer_maps = [item['waferMap'] for item in cleaned_data]
        failure_types = [item['failureType'] for item in cleaned_data]

        # Normalize wafer maps
        X = self.normalize_wafer_maps(wafer_maps)
        X = np.expand_dims(X, axis=-1)  # Add channel dimension

        # Encode labels
        y_encoded = self.label_encoder.fit_transform(failure_types)
        y_categorical = to_categorical(y_encoded)

        # Calculate class weights for imbalanced dataset.
        # NOTE: weights are derived from the full cleaned dataset, before any
        # train/validation/test split is made by the caller.
        class_weights_array = compute_class_weight(
            'balanced',
            classes=np.unique(y_encoded),
            y=y_encoded
        )
        self.class_weights = dict(enumerate(class_weights_array))

        print(f"Dataset shape: {X.shape}")
        print(f"Number of classes: {len(self.label_encoder.classes_)}")
        print(f"Classes: {list(self.label_encoder.classes_)}")

        return X, y_categorical, failure_types

# Build the train/validation/test arrays.
# Without a loaded dataset, fall back to random tensors so the remaining
# cells can still be executed end to end.
if 'data_loader' not in globals() or data_loader.data is None:
    print("Creating dummy data for demonstration purposes...")
    # Images first, then labels - each in train/val/test order.
    X_train, X_val, X_test = (
        np.random.random((n_rows, 64, 64, 1)) for n_rows in (1000, 300, 200)
    )
    y_train, y_val, y_test = (
        to_categorical(np.random.randint(0, 9, n_rows), num_classes=9)
        for n_rows in (1000, 300, 200)
    )
else:
    preprocessor = WaferPreprocessor(target_size=(64, 64))
    X, y, original_labels = preprocessor.prepare_dataset(data_loader.data)

    # 70% train; the held-out 30% is halved into validation and test.
    X_train, X_temp, y_train, y_temp = train_test_split(
        X, y, test_size=0.3, random_state=42, stratify=y
    )
    X_val, X_test, y_val, y_test = train_test_split(
        X_temp, y_temp, test_size=0.5, random_state=42, stratify=y_temp
    )

    print(f"Training set: {X_train.shape}")
    print(f"Validation set: {X_val.shape}")
    print(f"Test set: {X_test.shape}")

## Convolutional Neural Network Architecture

Our CNN architecture is specifically designed for semiconductor defect detection, incorporating modern deep learning principles:

### **Architecture Components:**

1. **Convolutional Layers** (Krizhevsky et al., 2012)

   - Multiple `Conv2D` layers with increasing filter counts
   - ReLU activation for non-linearity
   - Batch normalization for training stability

2. **Pooling Layers**

   - MaxPooling for spatial dimension reduction
   - Preserves most important features while reducing computation

3. **Regularization Techniques** (Srivastava et al., 2014)

   - Dropout layers to prevent overfitting
   - Batch normalization for internal covariate shift reduction

4. **Dense Classification Head**
   - Fully connected layers for final classification
   - Softmax activation for multi-class probability distribution

### **Design Rationale:**

The architecture balances model complexity with computational efficiency, suitable for industrial deployment while maintaining high accuracy.


In [None]:
class DefectDetectionCNN:
    """
    Four-block convolutional classifier for wafer-map defect patterns.

    Wraps construction, summary printing and diagram export of a Keras
    ``Sequential`` model.
    """

    # (filters, dropout rate, whether the block has a second Conv2D layer)
    _CONV_BLOCKS = (
        (32, 0.25, True),
        (64, 0.25, True),
        (128, 0.25, True),
        (256, 0.5, False),
    )

    def __init__(self, input_shape=(64, 64, 1), num_classes=9):
        """
        Args:
            input_shape: (height, width, channels) of one input image.
            num_classes: Width of the softmax output layer.
        """
        self.input_shape = input_shape
        self.num_classes = num_classes
        self.model = None
        self.history = None

    @staticmethod
    def _conv_block(filters, dropout_rate, double_conv=True):
        """Return the layers of one Conv -> BN -> [Conv] -> Pool -> Dropout block."""
        block = [
            Conv2D(filters, (3, 3), activation='relu', padding='same'),
            BatchNormalization(),
        ]
        if double_conv:
            block.append(Conv2D(filters, (3, 3), activation='relu', padding='same'))
        block.extend([MaxPooling2D((2, 2)), Dropout(dropout_rate)])
        return block

    def build_model(self):
        """
        Build and compile the CNN, store it on ``self.model`` and return it.

        Precision and recall are passed as explicitly named metric objects:
        the lowercase strings 'precision'/'recall' are not resolvable by name
        on every Keras version, and fixed names keep the history keys stable
        ('precision', 'recall', 'val_precision', 'val_recall') even when the
        model is rebuilt in the same session.
        """
        layers = [Input(shape=self.input_shape)]

        # Convolutional feature extractor
        for filters, dropout_rate, double_conv in self._CONV_BLOCKS:
            layers.extend(self._conv_block(filters, dropout_rate, double_conv))

        # Classification head
        layers.extend([
            Flatten(),
            Dense(512, activation='relu'),
            BatchNormalization(),
            Dropout(0.5),
            Dense(256, activation='relu'),
            Dropout(0.5),
            Dense(self.num_classes, activation='softmax'),
        ])

        model = Sequential(layers)

        # categorical_crossentropy expects one-hot encoded targets
        model.compile(
            optimizer=Adam(learning_rate=0.001),
            loss='categorical_crossentropy',
            metrics=[
                'accuracy',
                tf.keras.metrics.Precision(name='precision'),
                tf.keras.metrics.Recall(name='recall'),
            ]
        )

        self.model = model
        return model

    def get_model_summary(self):
        """
        Print the layer summary and total parameter count.

        Builds the model first if that has not happened yet. Returns None,
        like ``Model.summary()`` itself; the summary is printed once (it used
        to be printed a second time by the return statement).
        """
        if self.model is None:
            self.build_model()

        self.model.summary()

        # Calculate total parameters
        total_params = self.model.count_params()
        print(f"\nTotal Parameters: {total_params:,}")

        return None

    def visualize_architecture(self):
        """
        Write a diagram of the model to 'model_architecture.png'.

        NOTE: ``plot_model`` relies on the optional pydot/graphviz packages
        being installed.
        """
        if self.model is None:
            self.build_model()

        tf.keras.utils.plot_model(
            self.model,
            to_file='model_architecture.png',
            show_shapes=True,
            show_layer_names=True
        )
        print("Model architecture saved as 'model_architecture.png'")

# Size the network for 64x64 single-channel maps; take the class count from
# the one-hot training labels when they exist, otherwise assume 9 classes.
cnn_model = DefectDetectionCNN(
    (64, 64, 1),
    y_train.shape[1] if 'y_train' in globals() else 9,
)

# Compile the network and print its layer table.
model = cnn_model.build_model()
cnn_model.get_model_summary()

In [None]:
class TrainingManager:
    """
    Manages the training run: callbacks, fitting and history plots.
    """

    def __init__(self, model):
        """
        Args:
            model: Compiled Keras model exposing ``fit``.
        """
        self.model = model
        self.callbacks = []
        self.history = None

    def setup_callbacks(self, model_save_path='best_model.h5'):
        """
        Create the early-stopping, LR-reduction and checkpoint callbacks.

        Args:
            model_save_path: File the best model (by validation accuracy) is
                written to.

        Returns:
            list: The callbacks, also stored on ``self.callbacks``.
        """
        # Early stopping to prevent overfitting
        early_stopping = EarlyStopping(
            monitor='val_loss',
            patience=10,
            restore_best_weights=True,
            verbose=1
        )

        # Reduce learning rate on plateau
        reduce_lr = ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.2,
            patience=5,
            min_lr=1e-7,
            verbose=1
        )

        # Save best model
        model_checkpoint = ModelCheckpoint(
            model_save_path,
            monitor='val_accuracy',
            save_best_only=True,
            verbose=1
        )

        self.callbacks = [early_stopping, reduce_lr, model_checkpoint]
        return self.callbacks

    def train_model(self, X_train, y_train, X_val, y_val,
                   epochs=50, batch_size=32, class_weights=None,
                   model_save_path='best_model.h5'):
        """
        Fit the model and keep the returned history.

        Args:
            X_train, y_train: Training inputs and targets.
            X_val, y_val: Validation inputs and targets.
            epochs: Upper bound on epochs (early stopping may end sooner).
            batch_size: Mini-batch size.
            class_weights: Optional {class index: weight} mapping forwarded
                to ``fit`` as ``class_weight``.
            model_save_path: Checkpoint file, forwarded to
                ``setup_callbacks`` (previously fixed to the default).

        Returns:
            The Keras History object, also stored on ``self.history``.
        """
        # Setup callbacks
        self.setup_callbacks(model_save_path)

        print("Starting training...")
        print(f"Training samples: {len(X_train)}")
        print(f"Validation samples: {len(X_val)}")
        print(f"Batch size: {batch_size}")
        print(f"Maximum epochs: {epochs}")

        # Train model
        self.history = self.model.fit(
            X_train, y_train,
            batch_size=batch_size,
            epochs=epochs,
            validation_data=(X_val, y_val),
            callbacks=self.callbacks,
            class_weight=class_weights,
            verbose=1
        )

        print("Training completed!")
        return self.history

    @staticmethod
    def _find_metric_key(history, base_name):
        """Return the history key for a metric, accepting suffixed names like 'precision_1'."""
        if base_name in history:
            return base_name
        prefix = base_name + '_'
        for key in history:
            if key.startswith(prefix) and key[len(prefix):].isdigit():
                return key
        return None

    def plot_training_history(self):
        """
        Plot accuracy, loss, precision and recall per epoch in a 2x2 grid.

        Metrics missing from the history get a hidden panel instead of an
        empty frame, and validation curves are drawn only when present.
        """
        if self.history is None:
            print("No training history available. Train model first.")
            return

        history = self.history.history
        # (history key, display label), in panel order left-to-right, top-down
        panels = (
            ('accuracy', 'Accuracy'),
            ('loss', 'Loss'),
            ('precision', 'Precision'),
            ('recall', 'Recall'),
        )

        fig, axes = plt.subplots(2, 2, figsize=(15, 10))

        for ax, (metric, label) in zip(axes.ravel(), panels):
            key = self._find_metric_key(history, metric)
            if key is None:
                ax.axis('off')
                continue

            ax.plot(history[key], label=f'Training {label}')
            val_key = f'val_{key}'
            if val_key in history:
                ax.plot(history[val_key], label=f'Validation {label}')
            ax.set_title(f'Model {label}')
            ax.set_xlabel('Epoch')
            ax.set_ylabel(label)
            ax.legend()
            ax.grid(True)

        plt.tight_layout()
        plt.show()

# Wrap the compiled network in a TrainingManager for the training cell below
trainer = TrainingManager(model=model)

In [None]:
# Banner for the training phase
print("=" * 50, "TRAINING PHASE", "=" * 50, sep="\n")

# Class weights exist only when the real dataset went through the
# preprocessor; in the dummy-data path this resolves to None.
class_weights = getattr(globals().get('preprocessor'), 'class_weights', None)

# Run the fit (30 epochs keeps the demonstration short)
history = trainer.train_model(
    X_train,
    y_train,
    X_val,
    y_val,
    epochs=30,
    batch_size=32,
    class_weights=class_weights,
)

# Plot the per-epoch curves
trainer.plot_training_history()

## Model Evaluation and Performance Analysis

### **Evaluation Metrics** (Sokolova & Lapalme, 2009)

1. **Accuracy**: Overall correctness of predictions

   - Formula: (TP + TN) / (TP + TN + FP + FN)

2. **Precision**: Ability to avoid false positives

   - Formula: TP / (TP + FP)

3. **Recall (Sensitivity)**: Ability to find all positive instances

   - Formula: TP / (TP + FN)

4. **F1-Score**: Harmonic mean of precision and recall

   - Formula: 2 × (Precision × Recall) / (Precision + Recall)

5. **Confusion Matrix**: Detailed breakdown of classification performance

### **Industrial Relevance:**

In semiconductor manufacturing, minimizing false negatives (missed defects) is critical for product quality, while controlling false positives (unnecessary rejections) is important for cost efficiency.


In [None]:
class ModelEvaluator:
    """
    Evaluation suite for the defect detection model: summary metrics,
    confusion matrix and prediction-confidence plots.
    """

    def __init__(self, model, label_encoder=None):
        """
        Args:
            model: Trained model exposing ``predict``.
            label_encoder: Optional fitted LabelEncoder used to recover
                class names.
        """
        self.model = model
        self.label_encoder = label_encoder
        # All three are filled by evaluate_model()
        self.predictions = None
        self.true_labels = None
        self.predicted_labels = None

    def _resolve_class_names(self, class_names, n_classes):
        """Pick display names: explicit names, else encoder classes, else 'Class_i'."""
        if class_names is None:
            # getattr also covers a missing or not-yet-fitted encoder
            class_names = getattr(self.label_encoder, 'classes_', None)
        if class_names is None:
            return [f"Class_{i}" for i in range(n_classes)]
        return [str(name) for name in class_names]

    def evaluate_model(self, X_test, y_test, class_names=None):
        """
        Predict on the test set and print/return summary metrics.

        Args:
            X_test: Test inputs.
            y_test: One-hot encoded test targets.
            class_names: Optional display names, one per model output.

        Returns:
            dict: accuracy, weighted precision/recall/F1, the raw
            probability matrix and the predicted/true class indices.
        """
        print("="*50)
        print("MODEL EVALUATION")
        print("="*50)

        # Make predictions
        self.predictions = np.asarray(self.model.predict(X_test))
        predicted_classes = np.argmax(self.predictions, axis=1)
        true_classes = np.argmax(y_test, axis=1)

        self.true_labels = true_classes
        self.predicted_labels = predicted_classes

        # Calculate metrics (zero_division=0: classes that are never
        # predicted contribute 0 instead of triggering a warning)
        accuracy = np.mean(predicted_classes == true_classes)
        precision = precision_score(true_classes, predicted_classes,
                                    average='weighted', zero_division=0)
        recall = recall_score(true_classes, predicted_classes,
                              average='weighted', zero_division=0)
        f1 = f1_score(true_classes, predicted_classes,
                      average='weighted', zero_division=0)

        print(f"Test Accuracy: {accuracy:.4f}")
        print(f"Weighted Precision: {precision:.4f}")
        print(f"Weighted Recall: {recall:.4f}")
        print(f"Weighted F1-Score: {f1:.4f}")

        # Per-class metrics. The label set is fixed to every model output so
        # the number of names always matches, even when some class is absent
        # from both the true and the predicted labels.
        n_classes = self.predictions.shape[1]
        class_names = self._resolve_class_names(class_names, n_classes)

        print("\nDetailed Classification Report:")
        print(classification_report(true_classes, predicted_classes,
                                    labels=np.arange(n_classes),
                                    target_names=class_names,
                                    zero_division=0))

        return {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1_score': f1,
            'predictions': self.predictions,
            'predicted_classes': predicted_classes,
            'true_classes': true_classes
        }

    def plot_confusion_matrix(self, class_names=None, normalize=False):
        """
        Plot the confusion matrix as a heatmap.

        Args:
            class_names: Optional tick labels; when None they are taken from
                the label encoder or generated ('Class_i'), because the
                heatmap cannot take None as tick labels.
            normalize: When True, each row is divided by its total; rows
                without any true sample stay at 0 instead of becoming NaN.

        Returns:
            numpy.ndarray | None: The (possibly normalized) matrix, or None
            when ``evaluate_model`` has not been run.
        """
        if self.true_labels is None or self.predicted_labels is None:
            print("Please run evaluate_model first")
            return

        # Compute confusion matrix over every model output class
        n_classes = self.predictions.shape[1]
        cm = confusion_matrix(self.true_labels, self.predicted_labels,
                              labels=np.arange(n_classes))

        if normalize:
            row_totals = cm.sum(axis=1, keepdims=True)
            cm = np.divide(cm.astype('float'), row_totals,
                           out=np.zeros(cm.shape, dtype=float),
                           where=row_totals != 0)
            title = 'Normalized Confusion Matrix'
            fmt = '.2f'
        else:
            title = 'Confusion Matrix'
            fmt = 'd'

        tick_labels = self._resolve_class_names(class_names, n_classes)

        # Plot
        plt.figure(figsize=(12, 10))
        sns.heatmap(cm, annot=True, fmt=fmt, cmap='Blues',
                   xticklabels=tick_labels, yticklabels=tick_labels)
        plt.title(title)
        plt.xlabel('Predicted Label')
        plt.ylabel('True Label')
        plt.tight_layout()
        plt.show()

        return cm

    def plot_prediction_confidence(self):
        """
        Plot the distribution of the top predicted probability, overall and
        split by whether the prediction was correct, then print the mean
        confidence of each group.
        """
        if self.predictions is None:
            print("Please run evaluate_model first")
            return

        # Get confidence scores (max probability for each prediction)
        confidence_scores = np.max(self.predictions, axis=1)
        correct_mask = self.predicted_labels == self.true_labels
        correct_scores = confidence_scores[correct_mask]
        incorrect_scores = confidence_scores[~correct_mask]

        # Plot distribution
        plt.figure(figsize=(12, 5))

        plt.subplot(1, 2, 1)
        plt.hist(confidence_scores, bins=50, alpha=0.7, edgecolor='black')
        plt.title('Prediction Confidence Distribution')
        plt.xlabel('Confidence Score')
        plt.ylabel('Frequency')
        plt.grid(True, alpha=0.3)

        plt.subplot(1, 2, 2)
        # Confidence vs correctness
        plt.hist(correct_scores, bins=30, alpha=0.7,
                label='Correct Predictions', color='green')
        plt.hist(incorrect_scores, bins=30, alpha=0.7,
                label='Incorrect Predictions', color='red')
        plt.title('Confidence by Prediction Correctness')
        plt.xlabel('Confidence Score')
        plt.ylabel('Frequency')
        plt.legend()
        plt.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()

        # An empty group (e.g. no incorrect predictions) reports nan without
        # taking the mean of an empty slice.
        mean_correct = float(np.mean(correct_scores)) if correct_scores.size else float('nan')
        mean_incorrect = float(np.mean(incorrect_scores)) if incorrect_scores.size else float('nan')
        print(f"Average confidence for correct predictions: {mean_correct:.4f}")
        print(f"Average confidence for incorrect predictions: {mean_incorrect:.4f}")

# Evaluate model; the label encoder exists only when the real dataset was
# preprocessed (it is absent in the dummy-data path)
evaluator = ModelEvaluator(
    model,
    label_encoder=preprocessor.label_encoder if 'preprocessor' in globals() else None
)

# Run comprehensive evaluation
evaluation_results = evaluator.evaluate_model(X_test, y_test)

# Plot confusion matrix. Without an encoder, build generic names from the
# one-hot label width: passing None through as heatmap tick labels is not
# supported by seaborn.
class_names = (
    preprocessor.label_encoder.classes_
    if 'preprocessor' in globals()
    else [f"Class_{i}" for i in range(y_test.shape[1])]
)
evaluator.plot_confusion_matrix(class_names=class_names, normalize=True)

# Analyze prediction confidence
evaluator.plot_prediction_confidence()

## Practical Implementation for Industrial Deployment

### **Model Optimization Techniques:**

1. **Model Quantization** (Jacob et al., 2018)

   - Reduce model size for edge deployment
   - Maintain accuracy while improving inference speed

2. **TensorFlow Lite Conversion**

   - Mobile and embedded device compatibility
   - Optimized for resource-constrained environments

3. **Batch Processing Pipeline**
   - Efficient handling of multiple wafer maps
   - Queue management for production lines

### **Integration Considerations:**

- **Real-time Processing**: Sub-second inference for production lines
- **Quality Assurance**: Confidence thresholds for automated decisions
- **Human-in-the-Loop**: Expert review for borderline cases
- **Continuous Learning**: Model updates with new defect patterns


In [None]:
class ProductionDeployment:
    """
    Single-wafer inference pipeline with confidence gating and timing stats.
    """

    def __init__(self, model, preprocessor, confidence_threshold=0.8):
        """
        Args:
            model: Trained model exposing ``predict``.
            preprocessor: Optional object providing ``target_size`` and a
                fitted ``label_encoder``; may be None.
            confidence_threshold: Minimum top probability for a prediction
                to be flagged as reliable.
        """
        self.model = model
        self.preprocessor = preprocessor
        self.confidence_threshold = confidence_threshold
        # Wall-clock duration of every predict_defect() call, in seconds
        self.inference_times = []

    def preprocess_single_wafer(self, wafer_map):
        """
        Turn one wafer map into a normalized batch of size one.

        The target size comes from the preprocessor when it defines one
        (falling back to 64x64), so inference uses the same resolution as
        training. Values are min-max scaled to [0, 1]; a constant map is
        mapped to zeros.

        Args:
            wafer_map: 2-D (H, W) or 3-D (H, W, C) array-like.

        Returns:
            tuple: ``(batch, seconds)`` - float32 array of shape
            (1, height, width, channels) and the preprocessing time.
        """
        start_time = time.perf_counter()

        target_size = tuple(getattr(self.preprocessor, 'target_size', (64, 64)))

        wafer = np.asarray(wafer_map, dtype=np.float32)
        if wafer.ndim == 2:
            wafer = wafer[..., np.newaxis]  # add channel axis
        batch = wafer[np.newaxis, ...]      # add batch axis

        # Resizing to the same size is an identity operation; skip it.
        if batch.shape[1:3] != target_size:
            batch = tf.image.resize(batch, target_size).numpy()

        # Normalize
        lowest, highest = batch.min(), batch.max()
        if highest > lowest:
            normalized_map = (batch - lowest) / (highest - lowest)
        else:
            normalized_map = np.zeros_like(batch)

        preprocessing_time = time.perf_counter() - start_time
        return normalized_map, preprocessing_time

    def predict_defect(self, wafer_map, return_confidence=True):
        """
        Classify one wafer map.

        Args:
            wafer_map: Input accepted by ``preprocess_single_wafer``.
            return_confidence: When True return the full result dict,
                otherwise only the predicted defect name.

        Returns:
            dict | str: With ``return_confidence`` the dict holds
            'predicted_defect', 'confidence', 'reliable', 'inference_time'
            (includes preprocessing), 'preprocessing_time' and
            'raw_probabilities' as plain Python values.
        """
        start_time = time.perf_counter()

        # Preprocess wafer map
        processed_wafer, prep_time = self.preprocess_single_wafer(wafer_map)

        # Make prediction
        prediction = self.model.predict(processed_wafer, verbose=0)
        confidence = float(np.max(prediction))
        predicted_class = int(np.argmax(prediction))

        # Use the encoder only when it exists AND has been fitted; an
        # unfitted encoder has no classes_ and inverse_transform would raise.
        encoder = getattr(self.preprocessor, 'label_encoder', None)
        if encoder is not None and hasattr(encoder, 'classes_'):
            predicted_defect = encoder.inverse_transform([predicted_class])[0]
        else:
            predicted_defect = f"Class_{predicted_class}"

        inference_time = time.perf_counter() - start_time
        self.inference_times.append(inference_time)

        result = {
            'predicted_defect': predicted_defect,
            'confidence': confidence,
            # plain bool so the result stays JSON-serialisable
            'reliable': bool(confidence >= self.confidence_threshold),
            'inference_time': inference_time,
            'preprocessing_time': prep_time,
            'raw_probabilities': prediction[0].tolist()
        }

        if return_confidence:
            return result
        else:
            return predicted_defect

    def generate_performance_report(self):
        """
        Print and return timing statistics over all recorded predictions.

        Returns:
            dict | None: Count, mean/min/max/std inference time and the
            implied throughput; None when nothing has been recorded yet.
        """
        if not self.inference_times:
            print("No inference times recorded. Run predictions first.")
            return

        mean_time = np.mean(self.inference_times)
        report = {
            'total_predictions': len(self.inference_times),
            'avg_inference_time': mean_time,
            'min_inference_time': np.min(self.inference_times),
            'max_inference_time': np.max(self.inference_times),
            'std_inference_time': np.std(self.inference_times),
            'throughput_per_second': 1.0 / mean_time if mean_time > 0 else float('inf')
        }

        print("="*50)
        print("DEPLOYMENT PERFORMANCE REPORT")
        print("="*50)
        print(f"Total Predictions: {report['total_predictions']}")
        print(f"Average Inference Time: {report['avg_inference_time']:.4f} seconds")
        print(f"Min/Max Inference Time: {report['min_inference_time']:.4f}/{report['max_inference_time']:.4f} seconds")
        print(f"Standard Deviation: {report['std_inference_time']:.4f} seconds")
        print(f"Estimated Throughput: {report['throughput_per_second']:.2f} wafers/second")

        return report

# Wire the trained model into the deployment wrapper; the preprocessor is
# passed along only when the real dataset produced one (None otherwise).
deployment = ProductionDeployment(
    model,
    globals().get('preprocessor'),
    confidence_threshold=0.8,
)

# Banner for the single-wafer demo
print("=" * 50, "SINGLE WAFER PREDICTION DEMO", "=" * 50, sep="\n")

# Take the first test image and drop its channel axis to mimic a raw 2-D map
demo_wafer = np.squeeze(X_test[0])
result = deployment.predict_defect(demo_wafer)

print(f"Predicted Defect: {result['predicted_defect']}")
print(f"Confidence: {result['confidence']:.4f}")
print(f"Reliable Prediction: {result['reliable']}")
print(f"Inference Time: {result['inference_time']:.4f} seconds")

# Summarise the timings collected so far
deployment.generate_performance_report()