In [1]:
import numpy as np
import matplotlib.pyplot as plt
from datasets import load_dataset
from tensorflow.keras.models import load_model
from skimage.segmentation import mark_boundaries
from lime import lime_image
import shap
import os
from PIL import Image
import pandas as pd
import io
from skimage.transform import resize
from PIL import Image

# Dependencies: pip install tensorflow numpy matplotlib datasets scikit-image lime shap pandas pillow pyarrow

  from .autonotebook import tqdm as notebook_tqdm


In [2]:

# Folder that receives the saved explanation figures; an already-existing
# folder is not an error.
os.makedirs(name='explanations', exist_ok=True)

# Notebook-wide settings:
#   img_size    - edge length (pixels) images are resized to in preprocess_image
#   batch_size  - batch size setting
#   num_classes - number of target classes
img_size, batch_size, num_classes = 224, 32, 4

# Function to preprocess images
def preprocess_image(image):
    """Turn one raw image into an array scaled by 1/255.

    Accepted inputs:
      * encoded image bytes (``bytes``, ``bytearray`` or ``memoryview``), which
        are decoded with PIL;
      * a PIL image (anything exposing ``convert``), which is converted to RGB
        and resized to ``img_size`` x ``img_size``;
      * anything else that supports division (e.g. a NumPy array), which is
        only divided by 255 and is NOT resized.

    Args:
        image: Raw image data in one of the forms listed above.

    Returns:
        The image divided by 255.0 (a NumPy array for bytes / PIL input).

    Raises:
        TypeError: If ``image`` is None, e.g. because the source row had no
            image payload.
    """
    # Fail early with an explicit message instead of the opaque
    # "unsupported operand type(s) for /" error the division would produce.
    if image is None:
        raise TypeError("preprocess_image received None instead of image data")

    # Encoded bytes (as stored in parquet) are decoded first; io.BytesIO
    # accepts any bytes-like object, so bytearray/memoryview work as well.
    if isinstance(image, (bytes, bytearray, memoryview)):
        image = Image.open(io.BytesIO(image))

    # PIL images are normalised to 3-channel RGB at the model's input size.
    if hasattr(image, 'convert'):
        image = image.convert('RGB').resize((img_size, img_size))
        image = np.array(image)

    return image / 255.0

# Load the trained model: prefer the best checkpoint and fall back to the
# final model if the checkpoint cannot be loaded.
try:
    model = load_model('best_model.h5')
    print("Best model loaded successfully!")
except Exception as exc:
    # `except Exception` (not a bare `except:`) so KeyboardInterrupt and
    # SystemExit still propagate; the reason for the fallback is reported
    # instead of being silently discarded.
    print(f"Could not load 'best_model.h5' ({exc}); trying 'alzheimer_classifier.h5' instead.")
    # If this load fails as well, the exception propagates and stops the cell.
    model = load_model('alzheimer_classifier.h5')
    print("Fallback model loaded successfully!")

# Human-readable class labels, indexed by integer class id
# (order presumably matches the label encoding used in training - TODO confirm)
class_names = [
    'Mild_Demented',
    'Moderate_Demented',
    'Non_Demented',
    'Very_Mild_Demented',
]

# Load test data from local parquet file
def load_local_test_data(parquet_path, num_samples=20, random_state=40):
    """Read test images and labels from a parquet file and preprocess them.

    Each row must carry a ``label`` column and the image payload in an
    ``image`` column (or, failing that, a ``bytes`` column). The payload may
    be the raw value itself or a dict holding it under the ``'bytes'`` key.

    Args:
        parquet_path: Path of the parquet file to read.
        num_samples: Number of rows to sample at random. A falsy value
            (``None`` or ``0``) keeps every row. Requests larger than the
            file are capped at the number of available rows.
        random_state: Seed passed to ``DataFrame.sample`` so the same rows
            are drawn on every run.

    Returns:
        ``(test_images, test_labels)``: an array stacking the output of
        ``preprocess_image`` for every selected row, and an array of the
        matching labels in the same order.

    Raises:
        ValueError: If a selected row has no image payload.
    """
    df = pd.read_parquet(parquet_path)

    if num_samples:
        # DataFrame.sample raises when n exceeds the row count (without
        # replacement), so never ask for more rows than the file contains.
        sample_size = min(num_samples, len(df))
        df = df.sample(n=sample_size, random_state=random_state)

    test_images_raw = []
    test_labels = []

    for row_index, row in df.iterrows():
        # Handle different possible column names
        img_bytes = row.get('image', row.get('bytes', None))
        if isinstance(img_bytes, dict):  # Stored as a dict with a 'bytes' key
            img_bytes = img_bytes['bytes']
        if img_bytes is None:
            raise ValueError(
                f"Row {row_index!r} of {parquet_path!r} has no image data "
                "(expected an 'image' or 'bytes' column)"
            )

        test_images_raw.append(img_bytes)
        test_labels.append(row['label'])

    test_images = np.array([preprocess_image(img) for img in test_images_raw])
    test_labels = np.array(test_labels)

    return test_images, test_labels

# Pull a 20-image sample out of the local parquet file (adjust the path as needed)
test_images, test_labels = load_local_test_data(
    'Dataset/Data/test.parquet',
    num_samples=20,
)

# Run the model once over the sample; for each row the index of the largest
# output value is taken as the predicted class.
predictions = model.predict(test_images)
predicted_classes = np.argmax(predictions, axis=1)

print(f"Loaded {len(test_images)} test images for explanation")

FileNotFoundError: [Errno 2] Unable to synchronously open file (unable to open file: name = 'alzheimer_classifier.h5', errno = 2, error message = 'No such file or directory', flags = 0, o_flags = 0)

In [None]:
print("\n===== Generating LIME explanations =====")

# One LIME image explainer shared by every explanation below:
#   kernel_width=0.25 - width of the kernel LIME uses to weight perturbed samples
#   verbose=False     - switch to True for debugging output
#   random_state=42   - fixed seed for reproducibility
explainer = lime_image.LimeImageExplainer(kernel_width=0.25, verbose=False, random_state=42)

def lime_predict(images):
    """Prediction callback handed to LIME.

    Args:
        images: One image (3-D) or a batch of images (4-D), as an array or
            anything ``np.asarray`` can convert (e.g. nested lists).

    Returns:
        Whatever ``model.predict`` returns for the prepared float32 batch.
    """
    # asarray both accepts non-array input and converts to float32
    # (no copy is made when the input already is a float32 array).
    images = np.asarray(images, dtype=np.float32)

    # Promote a single image to a batch of one.
    if images.ndim == 3:
        images = np.expand_dims(images, axis=0)

    # Inputs are expected to be scaled to [0, 1] already; values above 1 are
    # taken to mean 0-255 data and are rescaled. The size check avoids calling
    # max() on an empty batch, which would raise.
    if images.size and images.max() > 1.0:
        images = np.clip(images / 255.0, 0, 1)

    # verbose=0: LIME calls this once per batch of perturbed samples, so the
    # default progress bar would flood the notebook output.
    return model.predict(images, verbose=0)

# Explain the first few test images. The count is capped at the number of
# images actually loaded so a small test set cannot cause an IndexError.
num_lime_examples = min(5, len(test_images))

for i in range(num_lime_examples):
    image = test_images[i]
    true_label = test_labels[i]
    pred_label = predicted_classes[i]
    pred_prob = predictions[i][pred_label]

    print(f"True: {class_names[true_label]} | Predicted: {class_names[pred_label]} ({pred_prob:.2f})")

    try:
        # Fit LIME's local surrogate model around this image.
        explanation = explainer.explain_instance(
            image.astype('double'),   # pass the image as float64
            lime_predict,
            top_labels=3,             # explain the 3 highest-scoring classes
            hide_color=0,             # masked-out superpixels are filled with 0
            num_samples=2000,         # number of perturbed samples
            batch_size=batch_size,    # perturbed samples per lime_predict call
            distance_metric='cosine',
            segmentation_fn=None      # None -> LIME's default segmentation
        )

        # Image and superpixel mask for the predicted class.
        temp, mask = explanation.get_image_and_mask(
            pred_label,
            positive_only=False,      # keep both supporting and opposing regions
            num_features=4,           # at most 4 superpixels
            hide_rest=False,
            min_weight=0.05           # drop superpixels with small weights
        )

        # Outline the selected superpixels in white.
        lime_explanation = mark_boundaries(temp, mask, color=(1, 1, 1))

        # Side-by-side figure: original image (left), LIME overlay (right).
        fig, (ax_original, ax_lime) = plt.subplots(1, 2, figsize=(12, 6))
        try:
            ax_original.imshow(image)
            ax_original.set_title(f"Original\n{class_names[true_label]}")
            ax_original.axis('off')

            ax_lime.imshow(lime_explanation)
            ax_lime.set_title(f"Explanation\n{class_names[pred_label]}")
            ax_lime.axis('off')

            fig.tight_layout()
            fig.savefig(f"explanations/lime_explanation_{i}.png", bbox_inches='tight', dpi=150)
        finally:
            # Always release the figure, even if drawing or saving failed,
            # so figures do not accumulate in memory across iterations.
            plt.close(fig)

    except Exception as e:
        # Best effort: report the failure and carry on with the next image.
        print(f"✗ Failed for image {i}: {e}")
        print(f"Image shape: {image.shape}, dtype: {image.dtype}, range: [{image.min()}, {image.max()}]")

print("\nLIME explanation generation completed!")