In [None]:
# Import libraries
import sys
sys.path.insert(0, '..')

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from collections import OrderedDict
import cv2

# Project imports
from data import CAMUSDataset
from models import MambaUNet

# Configuration
# Relative paths: they resolve against the kernel's working directory,
# which is assumed to be the notebook's own folder — TODO confirm.
DATA_ROOT = '../data/CAMUS'
CHECKPOINT_PATH = '../checkpoints/mamba_unet_best.pth'
# Use the GPU when one is visible to PyTorch, otherwise fall back to CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Number of segmentation labels, background included; passed to the model
# constructor and used to size figures and loops in later cells.
NUM_CLASSES = 4

# Label index -> human-readable structure name (index 0 is the background).
CLASS_NAMES = {
    0: 'Background',
    1: 'LV Endocardium',
    2: 'Myocardium',
    3: 'Left Atrium'
}

# NOTE(review): the 'seaborn-v0_8-*' style names presumably require a recent
# matplotlib release; older versions expose 'seaborn-whitegrid' — confirm.
plt.style.use('seaborn-v0_8-whitegrid')
%matplotlib inline

print(f"Using device: {DEVICE}")

In [None]:
# Load model and data
model = MambaUNet(in_channels=1, num_classes=NUM_CLASSES)

checkpoint_file = Path(CHECKPOINT_PATH)
if checkpoint_file.exists():
    checkpoint = torch.load(checkpoint_file, map_location=DEVICE)
    # Accept both a training checkpoint ({'model_state_dict': ...}) and a file
    # that contains the bare state dict, instead of failing with a KeyError.
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    print(f"Loaded checkpoint: {CHECKPOINT_PATH}")
else:
    # Everything below still runs, but on an untrained network: the figures
    # and the report are then meaningless.
    print("Warning: Checkpoint not found, using random weights")

# Inference mode: freezes dropout / normalisation statistics.
model = model.to(DEVICE).eval()

# Load test data
test_dataset = CAMUSDataset(root_dir=DATA_ROOT, split='test')
print(f"Test samples: {len(test_dataset)}")

## 1. Grad-CAM Visualization

Grad-CAM (Gradient-weighted Class Activation Mapping) shows which regions of the image the model focuses on for each class prediction.

In [None]:
class GradCAM:
    """Grad-CAM implementation for segmentation models.

    A forward hook and a full backward hook are attached to ``target_layer``.
    Calling the instance runs one forward and one backward pass and combines
    the captured activations and gradients into a heatmap.

    The hooks stay attached until ``remove_hooks()`` is called; call it when
    the object is no longer needed so that repeated construction (e.g. when a
    notebook cell is re-run) does not pile hooks up on the model.

    Args:
        model: Network whose output is indexed as (batch, class, H, W).
        target_layer: Sub-module of ``model`` to explain.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None

        # Keep the handles: without them the hooks can never be detached.
        self._hook_handles = [
            target_layer.register_forward_hook(self._save_activation),
            target_layer.register_full_backward_hook(self._save_gradient),
        ]

    def remove_hooks(self):
        """Detach the hooks from the target layer (safe to call repeatedly)."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def _save_activation(self, module, input, output):
        # Some layers return (tensor, extras...); the first item is used.
        if isinstance(output, (tuple, list)):
            output = output[0]
        self.activations = output.detach()

    def _save_gradient(self, module, grad_input, grad_output):
        self.gradients = grad_output[0].detach()

    def __call__(self, input_tensor, target_class):
        """Generate a Grad-CAM heatmap.

        Args:
            input_tensor: Model input; its dimensions from index 2 onwards
                give the size the heatmap is resized to.
            target_class: Index of the output channel to explain.

        Returns:
            NumPy array with the heatmap scaled to [0, 1], squeezed of
            singleton dimensions.

        Raises:
            RuntimeError: If the hooks captured nothing, i.e. the target layer
                did not take part in the forward/backward pass or the hooks
                were removed.
        """
        self.model.zero_grad()
        # Drop results of a previous call so stale tensors are never reused.
        self.activations = None
        self.gradients = None

        # Forward pass
        output = self.model(input_tensor)

        # Summing the logits of the target channel is equivalent to masking
        # the output with a one-hot class mask, without allocating the mask.
        loss = output[:, target_class, :, :].sum()
        loss.backward()

        if self.activations is None or self.gradients is None:
            raise RuntimeError(
                "Grad-CAM hooks captured no activations/gradients; check that "
                "the target layer is used by the model and hooks are attached."
            )

        # Channel weights = spatially averaged gradients.
        weights = self.gradients.mean(dim=(2, 3), keepdim=True)
        cam = (weights * self.activations).sum(dim=1, keepdim=True)
        cam = F.relu(cam)

        # Resize to input size
        cam = F.interpolate(cam, size=input_tensor.shape[2:], mode='bilinear', align_corners=False)

        # Normalize to [0, 1]; an all-zero map is left as zeros.
        cam = cam - cam.min()
        if cam.max() > 0:
            cam = cam / cam.max()

        return cam.squeeze().cpu().numpy()

In [None]:
# Pick the layer Grad-CAM will explain (usually a late encoder stage).
# The attribute path below is only an example and depends on the architecture.
try:
    target_layer = model.encoder.layer4
except AttributeError:
    # No such attribute on this model: fall back to the 10th module from the end.
    target_layer = [*model.modules()][-10]

gradcam = GradCAM(model, target_layer)
print("Target layer: " + type(target_layer).__name__)

In [None]:
# Load sample and generate Grad-CAM for each class
sample = test_dataset[0]
image = sample['image']
gt = sample['mask'].numpy() if hasattr(sample['mask'], 'numpy') else sample['mask']

# Prepare input. The rest of the notebook tolerates NumPy images, so convert
# explicitly instead of calling tensor-only methods on `image`.
image_tensor = torch.as_tensor(image).float()
if image_tensor.ndim == 2:
    # Add the two leading dimensions the model input needs.
    image_tensor = image_tensor[None, None]
elif image_tensor.ndim == 3:
    # Add one leading dimension.
    image_tensor = image_tensor[None]
input_tensor = image_tensor.to(DEVICE)

input_tensor.requires_grad_(True)

# Generate CAMs for each class
cams = {}
for class_idx in range(1, NUM_CLASSES):  # Skip background
    class_label = CLASS_NAMES[class_idx]
    cams[class_label] = gradcam(input_tensor, class_idx)
    print(f"Generated CAM for {class_label}")

In [None]:
# Visualize Grad-CAM
img_display = image.numpy() if hasattr(image, 'numpy') else image
if img_display.ndim == 3:
    # Keep only the first plane of a 3-D image for display.
    img_display = img_display[0]

fig, axes = plt.subplots(2, NUM_CLASSES, figsize=(5*NUM_CLASSES, 10))

# First column: raw input on top, ground-truth overlay below.
input_ax = axes[0, 0]
input_ax.imshow(img_display, cmap='gray')
input_ax.set_title('Input Image')
input_ax.axis('off')

gt_ax = axes[1, 0]
gt_ax.imshow(img_display, cmap='gray')
gt_ax.imshow(gt, cmap='jet', alpha=0.5)
gt_ax.set_title('Ground Truth')
gt_ax.axis('off')

# Remaining columns: one per class, heatmap on top and overlay below.
for col, (name, cam) in enumerate(cams.items(), start=1):
    heat_ax = axes[0, col]
    heat_ax.imshow(cam, cmap='jet')
    heat_ax.set_title(f'Grad-CAM: {name}')
    heat_ax.axis('off')

    overlay_ax = axes[1, col]
    overlay_ax.imshow(img_display, cmap='gray')
    overlay_ax.imshow(cam, cmap='jet', alpha=0.5)
    overlay_ax.set_title(f'Overlay: {name}')
    overlay_ax.axis('off')

plt.suptitle('Grad-CAM Visualization', fontsize=14)
plt.tight_layout()
plt.show()

## 2. Mamba State Analysis

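Analyze the activations of Mamba blocks to understand how the model processes sequential information. Forward hooks capture the output tensor of every module whose name contains `mamba` or `ssm`; these outputs serve as a proxy for the blocks' internal SSM states.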

In [None]:
class MambaStateExtractor:
    """Capture the forward outputs of Mamba/SSM modules.

    A forward hook is attached to every sub-module whose qualified name
    contains ``mamba`` or ``ssm`` (case-insensitive). The hooks stay attached
    until ``remove_hooks()`` is called; call it when the extractor is no
    longer needed so re-created extractors do not stack hooks on the model.

    Args:
        model: Network to instrument.
    """

    def __init__(self, model):
        self.model = model
        self.states = {}
        self._hook_handles = []
        self._register_hooks()

    def _register_hooks(self):
        """Register forward hooks on modules selected by name."""
        for name, module in self.model.named_modules():
            lowered = name.lower()
            if 'mamba' in lowered or 'ssm' in lowered:
                handle = module.register_forward_hook(
                    # `n=name` freezes the current name for this hook.
                    lambda m, i, o, n=name: self._save_state(n, i, o)
                )
                self._hook_handles.append(handle)

    def remove_hooks(self):
        """Detach all hooks from the model (safe to call repeatedly)."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def _save_state(self, name, input, output):
        """Store a detached CPU copy of a module's output under ``name``.

        For tuple/list outputs the first item is used. Outputs that are not
        tensors (e.g. ``None``) are skipped instead of raising.
        """
        if isinstance(output, (tuple, list)):
            output = output[0] if len(output) > 0 else None
        if torch.is_tensor(output):
            self.states[name] = output.detach().cpu()

    def get_states(self, input_tensor):
        """Run one gradient-free forward pass and return the captured outputs.

        Returns:
            Dict mapping module name to the captured tensor; empty when no
            module matched or the hooks were removed.
        """
        self.states = {}
        with torch.no_grad():
            _ = self.model(input_tensor)
        return self.states

In [None]:
# Extract Mamba states
state_extractor = MambaStateExtractor(model)
states = state_extractor.get_states(input_tensor.detach())

# One header line followed by one line per captured module.
summary_lines = [f"Found {len(states)} Mamba layers:"]
summary_lines.extend(
    f"  {layer_name}: shape {layer_state.shape}"
    for layer_name, layer_state in states.items()
)
print("\n".join(summary_lines))

In [None]:
# Visualize Mamba states
if not states:
    print("No Mamba states found. Check model architecture.")
else:
    # Only the first four captured modules are shown.
    state_names = list(states)[:4]
    n_cols = len(state_names)

    # squeeze=False keeps `axes` two-dimensional even for a single column.
    fig, axes = plt.subplots(2, n_cols, figsize=(5*n_cols, 10), squeeze=False)

    for col, name in enumerate(state_names):
        state = states[name]

        # 3-D / 4-D tensors: take the first batch item and average its first axis.
        if state.ndim in (3, 4):
            state_mean = state[0].mean(0).numpy()
        else:
            state_mean = state.squeeze().numpy()

        # A flat vector is shown as a square when its length allows,
        # otherwise as a single column.
        if state_mean.ndim == 1:
            n_values = len(state_mean)
            side = int(np.sqrt(n_values))
            new_shape = (side, side) if side * side == n_values else (-1, 1)
            state_mean = state_mean.reshape(new_shape)

        # Top row: averaged map.
        map_ax = axes[0, col]
        map_ax.imshow(state_mean, cmap='viridis')
        map_ax.set_title(f'Layer: {name.split(".")[-1]}')
        map_ax.axis('off')

        # Bottom row: histogram over every value of the captured tensor.
        hist_ax = axes[1, col]
        hist_ax.hist(state.flatten().numpy(), bins=50, alpha=0.7)
        hist_ax.set_xlabel('State Value')
        hist_ax.set_ylabel('Count')
        hist_ax.set_title('State Distribution')

    plt.suptitle('Mamba SSM State Analysis', fontsize=14)
    plt.tight_layout()
    plt.show()

## 3. Attention Map Visualization

For models with attention mechanisms, visualize where the model attends.

In [None]:
class AttentionExtractor:
    """Capture attention weights from attention modules.

    A forward hook is attached to every sub-module whose qualified name
    contains ``attention`` or ``attn`` (case-insensitive). The hooks stay
    attached until ``remove_hooks()`` is called; call it when the extractor is
    no longer needed so re-created extractors do not stack hooks on the model.

    Args:
        model: Network to instrument.
    """

    def __init__(self, model):
        self.model = model
        self.attention_maps = {}
        self._hook_handles = []
        self._register_hooks()

    def _register_hooks(self):
        """Register forward hooks on modules selected by name."""
        for name, module in self.model.named_modules():
            lowered = name.lower()
            if 'attention' in lowered or 'attn' in lowered:
                handle = module.register_forward_hook(
                    # `n=name` freezes the current name for this hook.
                    lambda m, i, o, n=name: self._save_attention(n, m, i, o)
                )
                self._hook_handles.append(handle)

    def remove_hooks(self):
        """Detach all hooks from the model (safe to call repeatedly)."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def _save_attention(self, name, module, input, output):
        """Store the attention weights of one module, if it exposes any.

        Looks first at a ``module.attention_weights`` attribute, then at the
        second item of a tuple output. Anything that is not a tensor (for
        instance ``None`` when a module was asked not to return weights) is
        ignored instead of raising.
        """
        weights = getattr(module, 'attention_weights', None)
        if weights is None and isinstance(output, tuple) and len(output) > 1:
            # Many attention modules return (output, attention_weights)
            weights = output[1]
        if torch.is_tensor(weights):
            self.attention_maps[name] = weights.detach().cpu()

    def get_attention(self, input_tensor):
        """Run one gradient-free forward pass and return the captured weights.

        Returns:
            Dict mapping module name to attention tensor; empty when nothing
            was captured or the hooks were removed.
        """
        self.attention_maps = {}
        with torch.no_grad():
            _ = self.model(input_tensor)
        return self.attention_maps

In [None]:
# Extract attention maps
attention_extractor = AttentionExtractor(model)
attention_maps = attention_extractor.get_attention(input_tensor.detach())

if attention_maps:
    print(f"Found {len(attention_maps)} attention layers:")
    for name, attn in attention_maps.items():
        print(f"  {name}: shape {attn.shape}")

    # Visualize at most four layers next to the input image.
    n_show = min(4, len(attention_maps))
    fig, axes = plt.subplots(1, n_show + 1, figsize=(5*(n_show+1), 5))

    axes[0].imshow(img_display, cmap='gray')
    axes[0].set_title('Input')
    axes[0].axis('off')

    for i, (name, attn) in enumerate(list(attention_maps.items())[:n_show]):
        # Average over the two leading axes (assumed batch and heads — TODO
        # confirm). A tensor that is already 2-D is used as is, otherwise the
        # averaging would collapse it to a scalar that cannot be resized.
        if attn.ndim > 2:
            attn_avg = attn.float().mean(dim=(0, 1)).numpy()
        else:
            attn_avg = attn.float().numpy()

        # Reshape to 2D if needed
        if attn_avg.ndim == 1:
            side = int(np.sqrt(len(attn_avg)))
            attn_avg = attn_avg[:side*side].reshape(side, side)
        elif attn_avg.ndim > 2:
            attn_avg = attn_avg.mean(axis=tuple(range(attn_avg.ndim - 2)))

        # cv2.resize takes (width, height) and rejects half-precision arrays.
        attn_resized = cv2.resize(
            np.ascontiguousarray(attn_avg, dtype=np.float32),
            (img_display.shape[1], img_display.shape[0]),
        )

        # Title: parent module name when the path is dotted, otherwise the
        # module's own name (the original [-2] index raised IndexError here).
        name_parts = name.split(".")
        layer_label = name_parts[-2] if len(name_parts) > 1 else name_parts[-1]

        axes[i+1].imshow(img_display, cmap='gray')
        axes[i+1].imshow(attn_resized, cmap='hot', alpha=0.6)
        axes[i+1].set_title(f'Attention: {layer_label}')
        axes[i+1].axis('off')

    plt.suptitle('Attention Map Visualization', fontsize=14)
    plt.tight_layout()
    plt.show()
else:
    print("No attention maps found. Model may not use attention mechanism.")

## 4. Uncertainty Estimation

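Use Monte Carlo Dropout to estimate prediction uncertainty: dropout layers are kept active at inference time, and the softmax outputs of several stochastic forward passes are aggregated into a mean prediction, a predictive-entropy map, and a variance map. The estimate is only informative if the model actually contains dropout layers.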

In [None]:
def enable_dropout(model):
    """Switch every dropout layer of ``model`` to training mode.

    All other modules keep their current mode. Every dropout variant exposed
    by ``torch.nn`` in the installed version is covered, not only ``Dropout``
    and ``Dropout2d``.
    """
    variant_names = (
        'Dropout', 'Dropout1d', 'Dropout2d', 'Dropout3d',
        'AlphaDropout', 'FeatureAlphaDropout',
    )
    # hasattr guard: some variants do not exist in older torch releases.
    dropout_types = tuple(getattr(nn, n) for n in variant_names if hasattr(nn, n))
    for module in model.modules():
        if isinstance(module, dropout_types):
            module.train()

def mc_dropout_inference(model, input_tensor, n_samples=20):
    """Perform MC Dropout inference.

    Runs ``n_samples`` stochastic forward passes with dropout active and
    aggregates the softmax (over dim 1) of the outputs.

    Args:
        model: Network to evaluate. It is left fully in eval mode on return,
            so later forward passes are deterministic again.
        input_tensor: Model input.
        n_samples: Number of stochastic passes; must be at least 1.

    Returns:
        Tuple ``(mean_pred, entropy, variance)`` of CPU tensors: the mean of
        the sampled probabilities, the entropy of that mean over dim 1, and
        the sample variance averaged over dim 1. With a single sample the
        variance is defined as zero instead of NaN.

    Raises:
        ValueError: If ``n_samples`` is smaller than 1.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")

    model.eval()
    enable_dropout(model)

    predictions = []
    try:
        with torch.no_grad():
            for _ in range(n_samples):
                output = model(input_tensor)
                prob = F.softmax(output, dim=1)
                predictions.append(prob.cpu())
    finally:
        # Put the dropout layers back into eval mode; otherwise every later
        # forward pass through this model stays stochastic.
        model.eval()

    predictions = torch.stack(predictions)

    # Mean prediction
    mean_pred = predictions.mean(dim=0)

    # Uncertainty (entropy); the epsilon keeps log() finite at probability 0.
    entropy = -(mean_pred * torch.log(mean_pred + 1e-10)).sum(dim=1)

    # Variance across samples, averaged over dim 1.
    if n_samples > 1:
        variance = predictions.var(dim=0).mean(dim=1)
    else:
        variance = torch.zeros_like(entropy)

    return mean_pred, entropy, variance

In [None]:
# Run MC Dropout
mean_pred, entropy_t, variance_t = mc_dropout_inference(model, input_tensor.detach(), n_samples=20)

# Convert to NumPy arrays for plotting: label map = arg-max over dim 1.
pred = mean_pred.argmax(dim=1).squeeze().numpy()
entropy = entropy_t.squeeze().numpy()
variance = variance_t.squeeze().numpy()

In [None]:
# Visualize uncertainty
fig, axes = plt.subplots(2, 3, figsize=(15, 10))

# Fixed colour range for label maps: without it each panel is auto-scaled to
# the labels it happens to contain, so the same class can get different
# colours in the prediction and in the ground truth.
label_style = dict(cmap='jet', vmin=0, vmax=NUM_CLASSES - 1)

# Row 1
axes[0, 0].imshow(img_display, cmap='gray')
axes[0, 0].set_title('Input Image')
axes[0, 0].axis('off')

axes[0, 1].imshow(pred, **label_style)
axes[0, 1].set_title('Prediction')
axes[0, 1].axis('off')

axes[0, 2].imshow(gt, **label_style)
axes[0, 2].set_title('Ground Truth')
axes[0, 2].axis('off')

# Row 2: Uncertainty
im1 = axes[1, 0].imshow(entropy, cmap='hot')
axes[1, 0].set_title('Entropy (Uncertainty)')
axes[1, 0].axis('off')
plt.colorbar(im1, ax=axes[1, 0], fraction=0.046)

im2 = axes[1, 1].imshow(variance, cmap='hot')
axes[1, 1].set_title('Variance')
axes[1, 1].axis('off')
plt.colorbar(im2, ax=axes[1, 1], fraction=0.046)

# Error map: 1.0 where the predicted label differs from the ground truth.
error = (pred != gt).astype(float)
axes[1, 2].imshow(error, cmap='Reds')
axes[1, 2].set_title('Prediction Errors')
axes[1, 2].axis('off')

plt.suptitle('Uncertainty Analysis', fontsize=14)
plt.tight_layout()
plt.show()

# Correlation between uncertainty and errors. Pearson's r is undefined when
# either map is constant (e.g. a perfect prediction), so report that
# explicitly instead of printing NaN with a runtime warning.
if np.std(entropy) > 0 and np.std(error) > 0:
    correlation = np.corrcoef(entropy.flatten(), error.flatten())[0, 1]
    print(f"Correlation between entropy and errors: {correlation:.3f}")
else:
    correlation = float('nan')
    print("Correlation between entropy and errors: undefined (constant entropy or error map)")

## 5. Feature Map Analysis

In [None]:
class FeatureExtractor:
    """Capture the outputs of convolutional layers.

    A forward hook is attached to every ``nn.Conv2d`` and
    ``nn.ConvTranspose2d`` in the model. The hooks stay attached until
    ``remove_hooks()`` is called; call it when the extractor is no longer
    needed so re-created extractors do not stack hooks on the model.

    Args:
        model: Network to instrument.
    """

    def __init__(self, model):
        self.model = model
        self.features = OrderedDict()
        self._hook_handles = []
        self._register_hooks()

    def _register_hooks(self):
        """Register a forward hook on each (transposed) 2-D convolution."""
        for name, module in self.model.named_modules():
            if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
                handle = module.register_forward_hook(
                    # `n=name` freezes the current name for this hook.
                    lambda m, i, o, n=name: self._save_feature(n, o)
                )
                self._hook_handles.append(handle)

    def remove_hooks(self):
        """Detach all hooks from the model (safe to call repeatedly)."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def _save_feature(self, name, output):
        """Store a detached CPU copy of a layer's output under ``name``."""
        self.features[name] = output.detach().cpu()

    def get_features(self, input_tensor):
        """Run one gradient-free forward pass and return the captured outputs.

        Returns:
            OrderedDict mapping layer name to output tensor, in the order the
            layers were executed; empty when nothing was captured or the hooks
            were removed.
        """
        self.features = OrderedDict()
        with torch.no_grad():
            _ = self.model(input_tensor)
        return self.features

In [None]:
# Extract features
feature_extractor = FeatureExtractor(model)
features = feature_extractor.get_features(input_tensor.detach())

print(f"Extracted {len(features)} feature maps")

# Print the shapes of the first ten captured layers only.
print("\nFeature map sizes (sample):")
for layer_key in list(features)[:10]:
    print(f"  {layer_key}: {features[layer_key].shape}")

In [None]:
# Visualize selected feature maps
layer_names = list(features.keys())

if not layer_names:
    # Nothing was captured (e.g. the model has no Conv2d/ConvTranspose2d).
    print("No feature maps captured; nothing to plot.")
else:
    # Select layers from different stages: first, quarter, middle, last.
    # dict.fromkeys drops repeats (small models map several picks to the same
    # layer) while keeping the order.
    n_layers = len(layer_names)
    picks = [0, n_layers // 4, n_layers // 2, n_layers - 1]
    selected_layers = list(dict.fromkeys(layer_names[i] for i in picks))

    # squeeze=False keeps `axes` two-dimensional even for a single row.
    fig, axes = plt.subplots(
        len(selected_layers), 6,
        figsize=(18, 4*len(selected_layers)),
        squeeze=False,
    )

    for row, layer_name in enumerate(selected_layers):
        feat = features[layer_name][0]  # Remove batch dim
        n_channels = feat.shape[0]

        # Show first 5 channels + average
        for col in range(5):
            ax = axes[row, col]
            if col < n_channels:
                ax.imshow(feat[col].numpy(), cmap='viridis')
                ax.set_title(f'Ch {col}')
            if col == 0:
                # Hide ticks and grid only: axis('off') would also hide the
                # row label set below, which is why it never showed up.
                ax.set_xticks([])
                ax.set_yticks([])
                ax.grid(False)
            else:
                ax.axis('off')

        # Average over channels
        axes[row, 5].imshow(feat.mean(0).numpy(), cmap='viridis')
        axes[row, 5].set_title('Average')
        axes[row, 5].axis('off')

        # Row label
        axes[row, 0].set_ylabel(layer_name.split('.')[-1], fontsize=10)

    plt.suptitle('Feature Maps at Different Depths', fontsize=14)
    plt.tight_layout()
    plt.show()

## 6. Clinical Interpretation Report

In [None]:
def generate_clinical_report(image, prediction, gt, patient_info=None, class_names=None):
    """Generate a plain-text segmentation report.

    Compares a predicted label map with the ground truth: per-structure pixel
    counts, per-structure Dice scores, the mean Dice, and a quality note
    derived from the mean Dice.

    Args:
        image: Input image. Not used in the computation; kept so existing
            calls keep working.
        prediction: Predicted label map (array-like of class indices).
        gt: Ground-truth label map with the same shape as ``prediction``.
        patient_info: Optional dict; the keys ``id``, ``view`` and ``phase``
            are printed when present.
        class_names: Optional mapping of class index to display name. Index 0
            is treated as background and skipped. Defaults to the
            notebook-level ``CLASS_NAMES``.

    Returns:
        The report as a single newline-joined string.

    Raises:
        ValueError: If ``prediction`` and ``gt`` differ in shape.
    """
    if class_names is None:
        class_names = CLASS_NAMES

    prediction = np.asarray(prediction)
    gt = np.asarray(gt)
    if prediction.shape != gt.shape:
        # Broadcast-compatible shapes would otherwise yield silently wrong scores.
        raise ValueError(
            f"prediction shape {prediction.shape} does not match gt shape {gt.shape}"
        )

    # Background (index 0) is excluded from every measurement.
    foreground = [(idx, name) for idx, name in class_names.items() if idx != 0]

    report = []
    report.append("="*60)
    report.append("CARDIAC SEGMENTATION REPORT")
    report.append("="*60)

    if patient_info:
        report.append(f"\nPatient: {patient_info.get('id', 'Unknown')}")
        report.append(f"View: {patient_info.get('view', 'Unknown')}")
        report.append(f"Phase: {patient_info.get('phase', 'Unknown')}")

    # Structure sizes
    report.append("\n" + "-"*40)
    report.append("STRUCTURE MEASUREMENTS (pixels)")
    report.append("-"*40)

    for class_idx, class_name in foreground:
        pred_area = int(np.sum(prediction == class_idx))
        gt_area = int(np.sum(gt == class_idx))
        diff = pred_area - gt_area
        # A relative difference is undefined for a structure that is absent
        # from the ground truth; say so instead of reporting 0%.
        diff_pct = f"{diff / gt_area * 100:+.1f}%" if gt_area > 0 else "n/a"
        report.append(f"\n{class_name}:")
        report.append(f"  Predicted: {pred_area:,} pixels")
        report.append(f"  Ground Truth: {gt_area:,} pixels")
        report.append(f"  Difference: {diff:+,} ({diff_pct})")

    # Dice scores
    report.append("\n" + "-"*40)
    report.append("SEGMENTATION ACCURACY")
    report.append("-"*40)

    dice_scores = []
    for class_idx, class_name in foreground:
        pred_mask = prediction == class_idx
        gt_mask = gt == class_idx

        intersection = int(np.sum(pred_mask & gt_mask))
        total = int(np.sum(pred_mask)) + int(np.sum(gt_mask))
        # Structure absent from both maps counts as perfect agreement.
        dice = 2 * intersection / total if total > 0 else 1.0
        dice_scores.append(dice)

        quality = "Excellent" if dice > 0.9 else "Good" if dice > 0.8 else "Fair" if dice > 0.7 else "Poor"
        report.append(f"  {class_name}: {dice:.3f} ({quality})")

    mean_dice = float(np.mean(dice_scores)) if dice_scores else float('nan')
    report.append(f"\n  Mean Dice: {mean_dice:.3f}")

    # Quality note, using the same thresholds as the per-class labels so a
    # failed segmentation is not reported as "FAIR".
    report.append("\n" + "-"*40)
    report.append("CLINICAL NOTES")
    report.append("-"*40)

    if mean_dice > 0.9:
        report.append("  - Segmentation quality: EXCELLENT")
        report.append("  - Confidence: HIGH")
        report.append("  - Recommendation: Results suitable for clinical use")
    elif mean_dice > 0.8:
        report.append("  - Segmentation quality: GOOD")
        report.append("  - Confidence: MODERATE-HIGH")
        report.append("  - Recommendation: Review recommended before clinical use")
    elif mean_dice > 0.7:
        report.append("  - Segmentation quality: FAIR")
        report.append("  - Confidence: MODERATE")
        report.append("  - Recommendation: Manual verification required")
    else:
        report.append("  - Segmentation quality: POOR")
        report.append("  - Confidence: LOW")
        report.append("  - Recommendation: Do not use; manual segmentation required")

    report.append("\n" + "="*60)

    return "\n".join(report)

In [None]:
# Generate report for sample
# The patient id is only available when the dataset exposes a `patients` list.
if hasattr(test_dataset, 'patients'):
    patient_id = test_dataset.patients[0].patient_id
else:
    patient_id = 'Unknown'

patient_info = dict(
    id=patient_id,
    view=sample.get('view', '4CH'),
    phase=sample.get('phase', 'ED'),
)

report = generate_clinical_report(img_display, pred, gt, patient_info)
print(report)

In [None]:
# Create comprehensive visualization for report
fig = plt.figure(figsize=(16, 12))

# Label maps share one fixed colour range so the overlays and the legend
# agree; without vmin/vmax each overlay is auto-scaled to the labels it
# happens to contain and no longer matches the legend colours.
label_max = max(NUM_CLASSES - 1, 1)
overlay_style = dict(cmap='jet', alpha=0.5, vmin=0, vmax=label_max)

# Main image grid
ax1 = fig.add_subplot(2, 3, 1)
ax1.imshow(img_display, cmap='gray')
ax1.set_title('Input Echocardiogram')
ax1.axis('off')

ax2 = fig.add_subplot(2, 3, 2)
ax2.imshow(img_display, cmap='gray')
ax2.imshow(pred, **overlay_style)
ax2.set_title('Model Prediction')
ax2.axis('off')

ax3 = fig.add_subplot(2, 3, 3)
ax3.imshow(img_display, cmap='gray')
ax3.imshow(gt, **overlay_style)
ax3.set_title('Ground Truth')
ax3.axis('off')

# Uncertainty
ax4 = fig.add_subplot(2, 3, 4)
ax4.imshow(entropy, cmap='hot')
ax4.set_title('Prediction Uncertainty')
ax4.axis('off')

# Error map
ax5 = fig.add_subplot(2, 3, 5)
ax5.imshow(error, cmap='Reds')
ax5.set_title('Error Regions')
ax5.axis('off')

# Legend: one bar per class, coloured exactly as the overlays colour that label.
ax6 = fig.add_subplot(2, 3, 6)
for position, (label, name) in enumerate(CLASS_NAMES.items()):
    ax6.bar(position, 1, color=plt.cm.jet(label / label_max), label=name)
ax6.set_xticks(range(len(CLASS_NAMES)))
ax6.set_xticklabels(list(CLASS_NAMES.values()), rotation=45, ha='right')
ax6.set_title('Class Legend')
ax6.set_ylabel('Class')

plt.suptitle(f'Clinical Explainability Report - Patient: {patient_info["id"]}', fontsize=14)
plt.tight_layout()
plt.show()

## Summary

This notebook demonstrated several explainability techniques:

1. **Grad-CAM**: Shows which image regions drive predictions for each class
2. **Mamba States**: Analyzes the outputs of Mamba/SSM modules (captured with forward hooks) to understand sequential processing
3. **Attention Maps**: Visualizes where the model focuses (if applicable)
4. **Uncertainty Estimation**: Uses MC Dropout to estimate prediction confidence
5. **Feature Maps**: Shows intermediate representations at different network depths
6. **Clinical Reports**: Generates human-readable interpretation summaries

These techniques help clinicians understand and trust model predictions, identify potential failure cases, and make informed decisions about when to rely on automated segmentation.