# Face Recognition model with Cross Validation

In [3]:
import os
import numpy as np
from sklearn.model_selection import StratifiedKFold
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
from sklearn.model_selection import KFold
from sklearn.preprocessing import LabelEncoder
from PIL import Image
from rnn import FaceRecognitionRNN

### Load Dataset

In [4]:
# Custom dataset class
class LFWDataset(Dataset):
    def __init__(self, X, y, transform=None):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(y, dtype=torch.long)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]

lfw_path = r'C:\Users\Anvesha\Documents\Assignments\AWS\Implementation\lfw_data\aug'
input_size = 4096  # Adjusted for 64x64 grayscale image flattened

# Loading function
def load_lfw_data():
    X, y = [], []
    for person_dir in os.listdir(lfw_path):
        person_path = os.path.join(lfw_path, person_dir)
        for image_file in os.listdir(person_path):
            if image_file.endswith('.jpg'):
                image_path = os.path.join(person_path, image_file)
                image = Image.open(image_path).convert('L').resize((64, 64))
                X.append(np.array(image).flatten())
                y.append(person_dir)
    return np.array(X), np.array(y)

# Load the data
X, y = load_lfw_data()


### Label Encoding

In [7]:
# Encode labels
label_encoder = LabelEncoder()
y_encoded = label_encoder.fit_transform(y)

In [8]:
def stratified_k_fold_cross_validation(X, y, model_class, num_epochs=10, batch_size=32, lr=0.001, k=5):
    skf = StratifiedKFold(n_splits=k, shuffle=True, random_state=42)
    best_model_path = 'cross_validation.pth'  # Path to save the best model

    for fold, (train_index, val_index) in enumerate(skf.split(X, y)):
        print(f'Fold {fold + 1}/{k}')
        
        # Create train and validation datasets
        X_train, X_val = X[train_index], X[val_index]
        y_train, y_val = y[train_index], y[val_index]
        
        # Convert to PyTorch tensors and create DataLoader
        train_dataset = TensorDataset(torch.tensor(X_train, dtype=torch.float32), torch.tensor(y_train, dtype=torch.long))
        val_dataset = TensorDataset(torch.tensor(X_val, dtype=torch.float32), torch.tensor(y_val, dtype=torch.long))
        
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
        
        # Initialize model, criterion, and optimizer
        model = model_class(input_size=128, hidden_size=64, num_layers=2, num_classes=len(np.unique(y_train)))  # Update as needed
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=lr)

        # Learning rate scheduler
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, factor=0.1, verbose=True)

        # Variable to track the best validation loss
        best_val_loss = float('inf')

        # Training loop
        for epoch in range(num_epochs):
            model.train()
            for inputs, labels in train_loader:
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()

            # Validation
            model.eval()
            val_loss = 0
            correct = 0
            total = 0
            
            with torch.no_grad():
                for inputs, labels in val_loader:
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)
                    val_loss += loss.item()
                    _, predicted = torch.max(outputs.data, 1)
                    total += labels.size(0)
                    correct += (predicted == labels).sum().item()
            
            val_accuracy = correct / total
            avg_val_loss = val_loss / len(val_loader)
            print(f'Epoch {epoch + 1}/{num_epochs} - Validation Loss: {avg_val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}')

            # Save the model if the current validation loss is the best
            if avg_val_loss < best_val_loss:
                best_val_loss = avg_val_loss
                torch.save(model.state_dict(), best_model_path)  # Save model state dict
                print(f'Saved the best model for fold {fold + 1} at epoch {epoch + 1}')

In [9]:
stratified_k_fold_cross_validation(X, y_encoded, FaceRecognitionRNN, num_epochs=1, batch_size=32, lr=0.001, k=5)

Fold 1/5
Epoch 1/1 - Validation Loss: 8.3160, Validation Accuracy: 0.0005
Saved the best model for fold 1 at epoch 1
Fold 2/5
Epoch 1/1 - Validation Loss: 8.3069, Validation Accuracy: 0.0002
Saved the best model for fold 2 at epoch 1
Fold 3/5
Epoch 1/1 - Validation Loss: 8.3078, Validation Accuracy: 0.0007
Saved the best model for fold 3 at epoch 1
Fold 4/5
Epoch 1/1 - Validation Loss: 8.3227, Validation Accuracy: 0.0002
Saved the best model for fold 4 at epoch 1
Fold 5/5
Epoch 1/1 - Validation Loss: 8.2863, Validation Accuracy: 0.0000
Saved the best model for fold 5 at epoch 1


## Evaluation with Explainable AI

Here, we will have to implement the xai part in each fold

### SHAP (SHapley Additive exPlanations) and LIME (Local Interpretable Model-agnostic Explanations)

In [None]:
import shap
import lime
import lime.lime_tabular

class ModelWrapper:
    def __init__(self, model):
        self.model = model
        
    def predict_proba(self, X):
        X = torch.FloatTensor(X)
        self.model.eval()
        with torch.no_grad():
            outputs = self.model(X)
            probas = torch.softmax(outputs, dim=1)
        return probas.numpy()

def analyze_interpretability(model, X, feature_names=None, 
                           use_sample=True, n_background=100, 
                           n_test_samples=5,
                           lime_samples=1):
    
    print("Starting interpretability analysis...")
    
    # Initialize feature names if not provided
    if feature_names is None:
        feature_names = [f"feature_{i}" for i in range(X.shape[1])]
    
    # Prepare data
    scaler = StandardScaler()
    X_scaled = scaler.fit_transform(X)
    
    # Wrap model
    model_wrapper = ModelWrapper(model)
    
    # SHAP Analysis
    print("\nPerforming SHAP analysis...")
    if use_sample:
        print(f"Using {n_background} background samples...")
        background_data = shap.kmeans(X_scaled, n_background)
    else:
        print("Using full dataset as background...")
        background_data = X_scaled
    
    explainer = shap.KernelExplainer(model_wrapper.predict_proba, background_data)
    
    # Calculate SHAP values
    test_samples = X_scaled[:n_test_samples]
    shap_values = explainer.shap_values(test_samples)
    
    # LIME Analysis
    print("\nPerforming LIME analysis...")
    lime_explainer = lime.lime_tabular.LimeTabularExplainer(
        X_scaled,
        feature_names=feature_names,
        class_names=[f"Class_{i}" for i in range(model.fc.out_features)],
        mode="classification"
    )
    
    lime_explanations = []
    for i in range(lime_samples):
        print(f"Generating LIME explanation for sample {i+1}/{lime_samples}")
        exp = lime_explainer.explain_instance(
            X_scaled[i],
            model_wrapper.predict_proba,
            num_features=10
        )
        lime_explanations.append(exp)
    
    results = {
        'shap_values': shap_values,
        'shap_test_samples': test_samples,
        'lime_explanations': lime_explanations,
        'feature_names': feature_names,
        'background_size': background_data.shape[0]
    }
    
    print("\nAnalysis completed!")
    print(f"SHAP values shape: {[sv.shape for sv in shap_values]}")
    print(f"Number of LIME explanations: {len(lime_explanations)}")
    
    return results

def analyze_both_datasets(model, X_original, X_cleaned, feature_names=None, 
                         use_sample=True, n_background=100, 
                         n_test_samples=5, lime_samples=1):
    print("Starting comparative analysis...")
    
    # Analyze original data
    print("\n=== Analyzing Original Data ===")
    original_results = analyze_interpretability(
        model=model,
        X=X_original,
        feature_names=feature_names,
        use_sample=use_sample,
        n_background=n_background,
        n_test_samples=n_test_samples,
        lime_samples=lime_samples
    )
    
    # Analyze cleaned data
    print("\n=== Analyzing Cleaned Data ===")
    cleaned_results = analyze_interpretability(
        model=model,
        X=X_cleaned,
        feature_names=feature_names,
        use_sample=use_sample,
        n_background=n_background,
        n_test_samples=n_test_samples,
        lime_samples=lime_samples
    )
    
    return {
        'original': original_results,
        'cleaned': cleaned_results
    }

def plot_interpretability_results(results):
    """
    Plot SHAP and LIME results
    """
    # Plot SHAP summary
    plt.figure(figsize=(12, 8))
    shap.summary_plot(
        results['shap_values'], 
        results['shap_test_samples'],
        feature_names=results['feature_names'],
        show=False
    )
    plt.title("SHAP Feature Importance")
    plt.tight_layout()
    plt.show()
    
    # Plot LIME explanations
    for i, exp in enumerate(results['lime_explanations']):
        plt.figure(figsize=(10, 6))
        exp.as_pyplot_figure()
        plt.title(f"LIME Explanation for Instance {i+1}")
        plt.tight_layout()
        plt.show()

In [None]:
# Then run the analysis
results = analyze_both_datasets(
    model = FaceRecognitionRNN(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, num_classes=num_classes),
    X_original=X_train,
    X_cleaned=X_train_cleaned,
    use_sample=True,
    n_background=10,
    n_test_samples=5,
    lime_samples=1
)

# # For full dataset analysis:
# results = analyze_both_datasets(
#     model = FaceRecognitionRNN(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, num_classes=num_classes),
#     X_original=X_train,
#     X_cleaned=X_train_cleaned,
#     use_sample=False,
#     n_test_samples=20,
#     lime_samples=5
# )

# Plot results for both datasets
print("\nPlotting Original Data Results:")
plot_interpretability_results(results['original'])

print("\nPlotting Cleaned Data Results:")
plot_interpretability_results(results['cleaned'])

### S-RISE (Saliency-guided Random Input Sampling for Explanation)

In [None]:
class SRISE:
    def __init__(self, model, input_size, n_samples=1000, s=8, p1=0.5, sigma=2.0):
        """
        Initialize S-RISE.
        
        Args:
            model: The model to explain
            input_size: Size of input features
            n_samples: Number of mask samples
            s: Stride size
            p1: Base probability
            sigma: Gaussian smoothing parameter
        """
        self.model = model
        # Convert input_size to 2D if it's 1D
        if isinstance(input_size, int):
            self.height = int(np.sqrt(input_size))
            self.width = self.height
            if self.height * self.width != input_size:
                raise ValueError(f"Input size {input_size} must be a perfect square for 2D reshaping")
        else:
            self.height, self.width = input_size
            
        self.n_samples = n_samples
        self.s = s
        self.p1 = p1
        self.sigma = sigma
        self.device = next(model.parameters()).device
        self.masks = self.generate_masks()
        
    def gaussian_kernel(self, sigma):
        """Generate Gaussian kernel."""
        kernel_size = int(4 * sigma + 1)
        if kernel_size % 2 == 0:
            kernel_size += 1
            
        center = kernel_size // 2
        x, y = np.meshgrid(np.arange(kernel_size), np.arange(kernel_size))
        kernel = np.exp(-((x - center) ** 2 + (y - center) ** 2) / (2 * sigma ** 2))
        kernel = kernel / kernel.sum()
        
        return torch.FloatTensor(kernel).unsqueeze(0).unsqueeze(0)
    
    def gaussian_blur(self, x, sigma):
        """Apply Gaussian blur to tensor."""
        kernel = self.gaussian_kernel(sigma)
        kernel = kernel.to(x.device)
        channels = x.shape[1]
        
        # Ensure kernel size is smaller than input dimensions
        if kernel.shape[-1] >= min(x.shape[-2:]):
            kernel = self.gaussian_kernel(sigma/2)  # Use smaller kernel
            
        padding = kernel.shape[-1]//2
        x_padded = torch.nn.functional.pad(x, (padding, padding, padding, padding), mode='reflect')
        return torch.nn.functional.conv2d(x_padded, kernel, groups=channels)
    
    def generate_masks(self):
        """Generate random masks."""
        masks = []
        for _ in range(self.n_samples):
            # Generate 2D mask
            h = int(np.ceil(self.height/self.s))
            w = int(np.ceil(self.width/self.s))
            mask = np.random.choice([0, 1], size=(h, w), p=[1-self.p1, self.p1])
            
            # Upsample to full size
            mask = torch.FloatTensor(mask)
            mask = nn.functional.interpolate(
                mask.unsqueeze(0).unsqueeze(0),
                size=(self.height, self.width),
                mode='nearest'
            )
            
            # Apply Gaussian smoothing
            mask = self.gaussian_blur(mask, self.sigma).squeeze()
            masks.append(mask)
            
        return torch.stack(masks).to(self.device)
    
    def explain(self, x, batch_size=32):
        """Generate saliency map for input x."""
        if isinstance(x, np.ndarray):
            x = torch.from_numpy(x)

        # Convert the tensor to float32
        x = x.float() 

        if len(x.shape) == 1:
            # Reshape 1D input to 2D
            x = x.unsqueeze(0).unsqueeze(0).view(1, 1, self.height, self.width)
        elif len(x.shape) == 2:
            # If it's already 2D, add the batch and channel dimensions
            x = x.unsqueeze(0).unsqueeze(0)  # Shape becomes (1, 1, height, width)
        elif len(x.shape) == 3:
            # If it's 3D, assume the input shape is (channels, height, width)
            x = x.unsqueeze(0)  # Add batch dimension

        self.model.eval()
        predictions = []
        
        with torch.no_grad():
            # Get original prediction
            original_pred = self.model(x)
            pred_class = original_pred.argmax().item()
            
            # Process masks in batches
            for i in range(0, self.n_samples, batch_size):
                batch_masks = self.masks[i:i+batch_size]
                batch_size = len(batch_masks)
                
                # Apply masks
                masked_inputs = x.repeat(batch_size, 1, 1, 1) * batch_masks.unsqueeze(1)
                batch_preds = self.model(masked_inputs.view(batch_size, -1))
                predictions.append(batch_preds[:, pred_class])
                
        # Calculate saliency map
        predictions = torch.cat(predictions)
        saliency_map = torch.zeros(self.height * self.width, dtype=torch.float32)
        
        for i, pred in enumerate(predictions):
            saliency_map += pred * self.masks[i].view(-1)
            
        return saliency_map.view(self.height, self.width)
    
    def visualize(self, x, saliency_map, save_path):
        """Visualize original input and saliency map."""
        plt.figure(figsize=(10, 5))

        plt.subplot(1, 2, 1)
        plt.title('Original Input')
        
        # Check if x is a tensor and print its shape
        if isinstance(x, torch.Tensor):
            # Ensure x is reshaped and moved to CPU before converting to NumPy
            if len(x.shape) == 1:
                x = x.view(1, -1)  # Ensure it's at least 2D
            original_input = x.view(self.height, self.width).cpu().detach().numpy()
        else:
            raise ValueError("Input x must be a torch.Tensor")
        
        plt.imshow(original_input, cmap='gray')
        plt.axis('off')

        # Plot saliency map
        plt.subplot(1, 2, 2)
        plt.title('Saliency Map')
        
        # Ensure saliency_map is moved to CPU before converting to NumPy
        if isinstance(saliency_map, torch.Tensor):
            saliency_map_cpu = saliency_map.cpu().detach().numpy()
        else:
            raise ValueError("Saliency map must be a torch.Tensor")

        plt.imshow(saliency_map_cpu, cmap='hot')
        plt.axis('off')

        plt.savefig(save_path)
        plt.show()

def compare_srise_explanations(model, X_original, X_cleaned, n_samples=3):
    """Compare SRISE explanations for original and cleaned data."""
    input_size = X_original.shape[1]
    srise = SRISE(model, input_size)
    
    plt.figure(figsize=(12, 4*n_samples))
    for i in range(n_samples):
        # Original data
        smap_original = srise.explain(X_original[i])
        plt.subplot(n_samples, 2, 2*i + 1)
        plt.title(f'Original Data - Sample {i+1}')
        plt.imshow(smap_original.cpu().numpy(), cmap='hot')
        plt.colorbar()
        plt.axis('off')
        
        # Cleaned data
        smap_cleaned = srise.explain(X_cleaned[i])
        plt.subplot(n_samples, 2, 2*i + 2)
        plt.title(f'Cleaned Data - Sample {i+1}')
        plt.imshow(smap_cleaned.cpu().numpy(), cmap='hot')
        plt.colorbar()
        plt.axis('off')
    
    plt.tight_layout()
    plt.show()

In [None]:
# Initialize S-RISE with your model
srise = SRISE(
    model = FaceRecognitionRNN(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, num_classes=num_classes),
    input_size=4096,  # Will be automatically reshaped to (64, 64)
    n_samples=1000,
    s=8,
    p1=0.5,
    sigma=2.0
)

# Generate saliency map for a sample
sample = X_train[0]

# Convert ndarray to PyTorch tensor
if isinstance(sample, np.ndarray):
    sample = torch.tensor(sample)

# Ensure the tensor is of type float32
sample = sample.float()
    
saliency_map = srise.explain(sample, batch_size=32)

# Visualize results
srise.visualize(sample, saliency_map, save_path='srise_visualization.png')

In [None]:
# Compare original and cleaned data
def compare_srise_explanations(model, X_original, X_cleaned, n_samples=3):
    srise = SRISE(model, X_original.shape[1])
    
    for i in range(n_samples):
        # Original data
        original = X_original[i]
        smap_original = srise.explain(original)
        print(f"\nOriginal Data - Sample {i+1}:")
        if isinstance(original, np.ndarray):
            original = torch.tensor(original)
        original = original.float()
        srise.visualize(original, smap_original, save_path='srise_visualization_original.png')
        
        # Cleaned data
        cleaned =  X_cleaned[i]
        smap_cleaned = srise.explain(cleaned)
        print(f"\nCleaned Data - Sample {i+1}:")
        if isinstance(cleaned, np.ndarray):
            cleaned = torch.tensor(cleaned)
        cleaned = cleaned.float()
        srise.visualize(cleaned, smap_cleaned, save_path='srise_visualization_cleaned.png')
        
# Run comparison
compare_srise_explanations(
    model = FaceRecognitionRNN(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, num_classes=num_classes),
    X_original=X_train,
    X_cleaned=X_train_cleaned,
    n_samples=3
)