# Explainability Analysis for Multimodal Rare Disease Diagnosis

This notebook provides visualization tools for understanding model predictions:
1. **Grad-CAM**: Visualize facial regions influencing predictions
2. **Attention Weights**: Identify the clinical terms the text encoder attends to most
3. **Cross-modal Analysis**: See how modalities interact

In [None]:
import sys
sys.path.append('..')

import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2
from pathlib import Path

from src.config import get_config
from src.multimodal_classifier import MultimodalClassifier
from src.predict import MultimodalPredictor

%matplotlib inline
plt.style.use('seaborn-v0_8-whitegrid')

## 1. Grad-CAM Visualization

Gradient-weighted Class Activation Mapping (Grad-CAM) highlights which regions of the facial image contribute most to the prediction.

In [None]:
class GradCAM:
    """
    Grad-CAM implementation for CNN encoder visualization.

    Hooks ``target_layer`` to capture its forward activations and the gradient
    of a class logit with respect to those activations, then combines them into
    a class-discriminative heatmap (Selvaraju et al., 2017).

    The hooks remain registered on ``target_layer`` until ``remove_hooks()`` is
    called. The object can also be used as a context manager, which removes the
    hooks on exit::

        with GradCAM(model, layer) as gradcam:
            cam, cls, conf = gradcam.generate_cam(image, ids, mask)
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None

        # Keep the handles so the hooks can be detached later; otherwise every
        # GradCAM created on the same layer stacks another pair of hooks.
        self._hook_handles = [
            self.target_layer.register_forward_hook(self._save_activation),
            self.target_layer.register_full_backward_hook(self._save_gradient),
        ]

    def remove_hooks(self):
        """Detach the forward/backward hooks from ``target_layer`` (safe to call twice)."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        self.remove_hooks()
        return False

    def _save_activation(self, module, input, output):
        # Forward hook: keep the layer output (detached) for the CAM weighting.
        self.activations = output.detach()

    def _save_gradient(self, module, grad_input, grad_output):
        # Backward hook: gradient of the target logit w.r.t. the layer output.
        self.gradients = grad_output[0].detach()

    def generate_cam(self, input_image, input_ids, attention_mask, target_class=None):
        """
        Generate a Grad-CAM heatmap.

        Args:
            input_image: Image tensor [1, 3, H, W] (not modified).
            input_ids: Token IDs passed through to the model.
            attention_mask: Attention mask passed through to the model.
            target_class: Target class index (uses the predicted class if None).

        Returns:
            Tuple ``(cam, target_class, confidence)``:
            - ``cam``: NumPy array at the target layer's spatial resolution,
              ReLU-ed and min-max normalised to [0, 1].
            - ``target_class``: The class index that was explained.
            - ``confidence``: ``output['probs'][0, target_class]`` as a float.

        Raises:
            RuntimeError: If the hooks did not capture activations/gradients
                (e.g. ``target_layer`` is not used in the forward pass).
        """
        self.model.eval()
        # Clear state from a previous call so stale maps can never be reused.
        self.gradients = None
        self.activations = None

        # Use a grad-requiring alias of the image (the caller's tensor is left
        # untouched). This guarantees a gradient path through target_layer even
        # when the CNN weights are frozen. enable_grad() keeps this working when
        # the caller is inside torch.no_grad().
        input_image = input_image.detach().requires_grad_(True)
        with torch.enable_grad():
            output = self.model(input_image, input_ids, attention_mask)
            logits = output['logits']

            if target_class is None:
                target_class = logits.argmax(dim=1).item()

            # Backward pass from the chosen class logit only.
            self.model.zero_grad()
            logits[0, target_class].backward()

        if self.activations is None or self.gradients is None:
            raise RuntimeError(
                "Grad-CAM hooks did not capture activations/gradients; "
                "check that target_layer participates in the model's forward pass."
            )

        gradients = self.gradients[0]  # [C, H, W]
        activations = self.activations[0]  # [C, H, W]

        # Global average pooling of gradients -> one weight per channel.
        weights = gradients.mean(dim=(1, 2), keepdim=True)

        # Weighted sum of activations; ReLU keeps only positive evidence.
        cam = (weights * activations).sum(dim=0)
        cam = F.relu(cam)

        # Min-max normalise to [0, 1]; epsilon guards an all-zero map.
        cam = cam - cam.min()
        cam = cam / (cam.max() + 1e-8)

        return cam.cpu().numpy(), target_class, output['probs'][0, target_class].item()

    def visualize(self, image_path, text, config=None, figsize=(15, 5),
                  image_size=224, max_length=128):
        """
        Compute Grad-CAM for one (image, text) pair and plot it.

        Shows three panels: the resized input image, the raw heatmap, and a
        60/40 blend of the image with the JET-coloured heatmap.

        Args:
            image_path: Path to the facial image.
            text: Clinical narrative for the text encoder.
            config: Project config; ``get_config()`` is used when None.
            figsize: Matplotlib figure size.
            image_size: Side length (pixels) the image is resized to.
            max_length: Tokenizer padding/truncation length.

        Returns:
            ``(cam, target_class, confidence)`` as returned by ``generate_cam``.
        """
        if config is None:
            config = get_config()

        from torchvision import transforms
        from transformers import AutoTokenizer

        size = (image_size, image_size)

        # Load image; the context manager closes the file handle.
        with Image.open(image_path) as img:
            original_image = img.convert('RGB')

        # ImageNet mean/std normalisation.
        transform = transforms.Compose([
            transforms.Resize(size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        image_tensor = transform(original_image).unsqueeze(0)

        # Tokenize text
        tokenizer = AutoTokenizer.from_pretrained(config.text_encoder.model_name)
        encoding = tokenizer(
            text,
            max_length=max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )

        # Run everything on the model's device.
        device = next(self.model.parameters()).device
        image_tensor = image_tensor.to(device)
        input_ids = encoding['input_ids'].to(device)
        attention_mask = encoding['attention_mask'].to(device)

        cam, target_class, confidence = self.generate_cam(
            image_tensor, input_ids, attention_mask
        )

        # Upsample CAM to image size; clip guards the uint8 cast against drift.
        cam_resized = cv2.resize(cam, size)
        heatmap = cv2.applyColorMap(
            np.uint8(255 * np.clip(cam_resized, 0.0, 1.0)), cv2.COLORMAP_JET
        )
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)  # OpenCV is BGR

        # Overlay
        original_resized = original_image.resize(size)
        original_array = np.array(original_resized)
        overlay = (0.6 * original_array + 0.4 * heatmap).astype(np.uint8)

        # Plot
        fig, axes = plt.subplots(1, 3, figsize=figsize)

        axes[0].imshow(original_resized)
        axes[0].set_title('Original Image')
        axes[0].axis('off')

        axes[1].imshow(cam_resized, cmap='jet')
        axes[1].set_title('Grad-CAM Heatmap')
        axes[1].axis('off')

        axes[2].imshow(overlay)
        axes[2].set_title(f'Overlay\nPredicted: Class {target_class} ({confidence:.1%})')
        axes[2].axis('off')

        plt.tight_layout()
        plt.show()

        return cam, target_class, confidence

## 2. Text Attention Visualization

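Visualize how strongly the text encoder's [CLS] token attends to each token of the clinical narrative (last encoder layer, averaged over attention heads). Attention weights are a rough proxy for token importance, not a faithful attribution of the prediction.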

In [None]:
def visualize_text_attention(tokens, attention_weights, figsize=(15, 3),
                             special_tokens=('[PAD]', '[CLS]', '[SEP]')):
    """
    Visualize attention weights over tokens as colour-highlighted text.

    Each non-special token is drawn in a box shaded with the ``Reds`` colormap
    according to its weight, scaled so the largest shown weight maps to 1.
    Tokens flow left to right and wrap onto a new row at the right edge of the
    axes; enlarge ``figsize`` for long narratives.

    Args:
        tokens: List of token strings
        attention_weights: One weight per token (array-like or CPU tensor)
        figsize: Figure size
        special_tokens: Tokens to hide. The default covers BERT-style
            vocabularies; pass e.g. ('<s>', '</s>', '<pad>') for RoBERTa-style.

    Raises:
        ValueError: If the weights are not 1-D with one entry per token.
    """
    weights = np.asarray(attention_weights, dtype=float)
    hidden = set(special_tokens)
    keep = np.array([tok not in hidden for tok in tokens], dtype=bool)
    if weights.ndim != 1 or weights.shape[0] != keep.shape[0]:
        raise ValueError(
            f"Expected {len(tokens)} attention weights, got shape {weights.shape}"
        )

    # Filter out special tokens
    shown_tokens = [tok for tok, k in zip(tokens, keep) if k]
    shown_weights = weights[keep]
    if not shown_tokens:
        print("No non-special tokens to visualize.")
        return

    # Normalize weights so the strongest token gets the darkest colour;
    # an all-zero vector is left as-is instead of dividing by zero.
    peak = shown_weights.max()
    if peak > 0:
        shown_weights = shown_weights / peak

    fig, ax = plt.subplots(figsize=figsize)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.axis('off')
    ax.set_title('Token Importance (Attention Weights)', fontsize=14)
    # Fix the axes geometry *before* measuring text; running tight_layout
    # afterwards would resize the axes and invalidate the computed positions.
    fig.tight_layout()

    renderer = fig.canvas.get_renderer()
    axes_px = ax.get_window_extent(renderer=renderer)
    colors = plt.cm.Reds(shown_weights)

    gap = 0.01  # horizontal gap between boxes, in axes fraction
    x_pos, y_pos = 0.0, 0.9
    for token, color in zip(shown_tokens, colors):
        label = ax.text(
            x_pos, y_pos, token + ' ',
            fontsize=12,
            ha='left', va='center',
            transform=ax.transAxes,
            bbox=dict(boxstyle='round', facecolor=color, alpha=0.8)
        )

        # Convert the rendered pixel extent into axes fractions, which is the
        # coordinate system the text is positioned in.
        extent = label.get_window_extent(renderer=renderer)
        width = extent.width / axes_px.width
        if x_pos > 0 and x_pos + width > 1:
            # Wrap to the start of the next row.
            x_pos = 0.0
            y_pos -= 1.8 * extent.height / axes_px.height
            label.set_position((x_pos, y_pos))
        x_pos += width + gap

    plt.show()


def extract_text_attention(model, text, tokenizer, device='cpu', layer=-1, max_length=128):
    """
    Extract attention from the first ([CLS]) position to every token.

    Args:
        model: Multimodal classifier whose ``text_encoder.get_attention_weights``
            returns ``(embedding, attentions)``.
        text: Clinical narrative
        tokenizer: Tokenizer matching the text encoder
        device: Device the model lives on
        layer: Index of the encoder layer to read attention from (default: the
            last layer). Pass ``None`` to average over all layers.
        max_length: Tokenizer padding/truncation length.

    Returns:
        ``(tokens, cls_attention)``: the token strings (including special and
        padding tokens) and a NumPy array with the head-averaged attention from
        position 0 to each position.
    """
    encoding = tokenizer(
        text,
        max_length=max_length,
        padding='max_length',
        truncation=True,
        return_tensors='pt'
    )

    input_ids = encoding['input_ids'].to(device)
    attention_mask = encoding['attention_mask'].to(device)

    # Get attention weights; the embedding is not needed here.
    with torch.no_grad():
        _, attentions = model.text_encoder.get_attention_weights(
            input_ids, attention_mask
        )

    # attentions: one tensor per layer, indexed below as (batch, heads, seq, seq).
    if layer is None:
        # Average over all layers: [layers, heads, seq, seq] -> [heads, seq, seq]
        layer_attention = torch.stack(tuple(attentions))[:, 0].mean(dim=0)
    else:
        layer_attention = attentions[layer][0]  # [heads, seq, seq]
    avg_attention = layer_attention.mean(dim=0)  # average heads -> [seq, seq]

    # Row 0: attention from the first ([CLS]) token to every token.
    cls_attention = avg_attention[0].cpu().numpy()  # [seq]

    # Decode tokens
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])

    return tokens, cls_attention

## 3. Cross-Modal Attention Analysis

Visualize how the fusion module combines information from image and text modalities.

In [None]:
def visualize_cross_modal_attention(attention_info, figsize=(10, 4)):
    """
    Visualize cross-modal attention weights.
    
    Args:
        attention_info: Dictionary with attention weights from fusion
        figsize: Figure size
    """
    if attention_info is None:
        print("No attention info available (using concatenation fusion?)")
        return
    
    fig, axes = plt.subplots(1, 2, figsize=figsize)
    
    # Image-to-text attention
    if 'image_to_text_attention' in attention_info:
        img2txt = attention_info['image_to_text_attention']
        if torch.is_tensor(img2txt):
            img2txt = img2txt.mean(dim=(0, 1)).cpu().numpy()  # Average across heads
        axes[0].bar(range(len(img2txt)), img2txt)
        axes[0].set_title('Image → Text Attention')
        axes[0].set_xlabel('Position')
        axes[0].set_ylabel('Attention Weight')
    
    # Text-to-image attention
    if 'text_to_image_attention' in attention_info:
        txt2img = attention_info['text_to_image_attention']
        if torch.is_tensor(txt2img):
            txt2img = txt2img.mean(dim=(0, 1)).cpu().numpy()
        axes[1].bar(range(len(txt2img)), txt2img)
        axes[1].set_title('Text → Image Attention')
        axes[1].set_xlabel('Position')
        axes[1].set_ylabel('Attention Weight')
    
    plt.tight_layout()
    plt.show()

## 4. Complete Explainability Pipeline

Run the complete explainability analysis for a sample.

In [None]:
def full_explainability_analysis(
    image_path,
    clinical_text,
    checkpoint_path=None,
    device='cpu'
):
    """
    Print and plot a complete explainability report for one sample.

    Sections, in order: top-3 prediction, Grad-CAM overlay, text attention
    (failures there are reported and skipped), and embedding norms (only when
    the predictor returns embeddings).

    Args:
        image_path: Path to facial image
        clinical_text: Clinical narrative
        checkpoint_path: Path to model checkpoint (weights under
            'model_state_dict'); untrained weights are used when None.
        device: Device for the locally loaded model used by Grad-CAM and
            text attention.
    """
    from transformers import AutoTokenizer

    def _section(heading):
        # Numbered heading preceded by a blank line and followed by a rule.
        print(f"\n{heading}")
        print("-" * 40)

    banner = "=" * 60
    config = get_config()

    # Local model instance for the gradient/attention based views.
    model = MultimodalClassifier(config)
    if checkpoint_path:
        state = torch.load(checkpoint_path, map_location=device)
        model.load_state_dict(state['model_state_dict'])
    model.to(device)
    model.eval()

    for line in (banner, "MULTIMODAL EXPLAINABILITY ANALYSIS", banner):
        print(line)

    _section("1. PREDICTION")
    predictor = MultimodalPredictor(checkpoint_path=checkpoint_path, config=config)
    result = predictor.predict(image_path, clinical_text, top_k=3, return_embeddings=True)
    for entry in result['predictions']:
        print(f"   {entry['syndrome']}: {entry['probability_percent']:.1f}%")

    _section("2. GRAD-CAM VISUALIZATION")
    GradCAM(model, model.cnn_encoder.get_attention_layer()).visualize(
        image_path, clinical_text, config
    )

    _section("3. TEXT ATTENTION")
    tokenizer = AutoTokenizer.from_pretrained(config.text_encoder.model_name)
    try:
        token_list, weights = extract_text_attention(model, clinical_text, tokenizer, device)
        visualize_text_attention(token_list, weights)
    except Exception as err:
        print(f"Could not visualize text attention: {err}")

    _section("4. EMBEDDING ANALYSIS")
    if 'embeddings' in result:
        embeddings = result['embeddings']
        # Materialise all three arrays before printing any norm.
        vectors = [np.array(embeddings[key]) for key in ('image', 'text', 'fused')]
        for label, vector in zip(("Image", "Text", "Fused"), vectors):
            print(f"   {label} embedding norm: {np.linalg.norm(vector):.4f}")

    print("\n" + banner)

## 5. Example Usage

Run the analysis on a sample case.

In [None]:
# Example usage (uncomment and modify paths as needed)

# sample_image = "../data/sample_image.jpg"
# sample_text = "Patient presents with hypertelorism, seizures, delayed speech, and characteristic facial features."
# checkpoint = "../checkpoints/multimodal_best.pt"

# full_explainability_analysis(
#     image_path=sample_image,
#     clinical_text=sample_text,
#     checkpoint_path=checkpoint,
#     device='cuda'
# )

## 6. Batch Analysis for Multiple Samples

Generate explanations for multiple samples and save results.

In [None]:
def batch_explainability_analysis(samples, output_dir, checkpoint_path=None, device='cpu'):
    """
    Run Grad-CAM explainability analysis on multiple samples and save figures.

    For every ``(image_path, clinical_text)`` pair, a three-panel figure
    (original image, Grad-CAM heatmap, overlay) is written to
    ``output_dir / f"sample_{idx:03d}_gradcam.png"``. A failure on one sample
    is recorded in the results and does not stop the batch.

    Args:
        samples: Iterable of (image_path, clinical_text) tuples
        output_dir: Directory to save visualizations (created if missing)
        checkpoint_path: Model checkpoint path (weights under 'model_state_dict')
        device: Device to run the model on (default 'cpu')

    Returns:
        List of per-sample dicts. Each has 'sample_idx' and 'status'
        ('success' or 'error'). Successful entries add 'predicted_class',
        'confidence' and 'output_path'; failed ones add 'error'.
    """
    from torchvision import transforms
    from transformers import AutoTokenizer

    samples = list(samples)  # allow generators; len() is used for progress
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    config = get_config()
    model = MultimodalClassifier(config)

    if checkpoint_path:
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])

    model.to(device)
    model.eval()
    gradcam = GradCAM(model, model.cnn_encoder.get_attention_layer())

    # Loaded once for the whole batch instead of once per sample.
    tokenizer = AutoTokenizer.from_pretrained(config.text_encoder.model_name)
    # Same preprocessing as GradCAM.visualize (224x224, ImageNet normalisation).
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    results = []

    for idx, (image_path, text) in enumerate(samples):
        print(f"\nProcessing sample {idx + 1}/{len(samples)}...")
        fig = None

        try:
            with Image.open(image_path) as img:
                original = img.convert('RGB')
            image_tensor = transform(original).unsqueeze(0).to(device)
            encoding = tokenizer(
                text,
                max_length=128,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )

            cam, target_class, confidence = gradcam.generate_cam(
                image_tensor,
                encoding['input_ids'].to(device),
                encoding['attention_mask'].to(device)
            )

            # Heatmap + 60/40 overlay, as in GradCAM.visualize.
            original_resized = original.resize((224, 224))
            cam_resized = cv2.resize(cam, (224, 224))
            heatmap = cv2.applyColorMap(
                np.uint8(255 * np.clip(cam_resized, 0.0, 1.0)), cv2.COLORMAP_JET
            )
            heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
            overlay = (0.6 * np.array(original_resized) + 0.4 * heatmap).astype(np.uint8)

            fig, axes = plt.subplots(1, 3, figsize=(15, 5))
            panels = (
                (original_resized, {}, 'Original Image'),
                (cam_resized, {'cmap': 'jet'}, 'Grad-CAM Heatmap'),
                (overlay, {}, f'Overlay\nPredicted: Class {target_class} ({confidence:.1%})'),
            )
            for ax, (data, imshow_kwargs, title) in zip(axes, panels):
                ax.imshow(data, **imshow_kwargs)
                ax.set_title(title)
                ax.axis('off')
            fig.tight_layout()

            output_path = output_dir / f'sample_{idx:03d}_gradcam.png'
            fig.savefig(output_path, bbox_inches='tight')

            results.append({
                'sample_idx': idx,
                'status': 'success',
                'predicted_class': target_class,
                'confidence': confidence,
                'output_path': str(output_path)
            })

        except Exception as e:
            # Best-effort batch: record the failure and continue.
            results.append({
                'sample_idx': idx,
                'status': 'error',
                'error': str(e)
            })
        finally:
            # Always release the figure; open figures otherwise accumulate.
            if fig is not None:
                plt.close(fig)

    print(f"\nCompleted analysis for {len(samples)} samples")
    return results

---

## Summary

This notebook provides tools for:

1. **Grad-CAM**: Understand which facial regions influence the model's decision
2. **Text Attention**: See which clinical terms receive the most attention from the text encoder
3. **Cross-Modal Attention**: Analyze how image and text information is combined
4. **Embedding Analysis**: Examine the learned representations

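These tools help clinicians inspect the evidence behind the model's predictions; they support interpretation but do not by themselves establish that a prediction is correct.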