<a href="https://colab.research.google.com/github/Abhishek-s-kumar/stroke/blob/main/stoke_new_all_data.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ==========================================
# SECTION 1: INSTALLATION AND SETUP
# ==========================================

# Install required packages
!pip -q install kaggle timm einops albumentations==1.4.6 torchmetrics wandb kagglehub grad-cam pandas scikit-learn

# Import libraries
import kagglehub
import os
import glob
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import (confusion_matrix, ConfusionMatrixDisplay, roc_curve, auc,
                           precision_recall_curve, average_precision_score,
                           classification_report, accuracy_score)

print("✅ All packages imported successfully!")

✅ All packages imported successfully!


In [None]:


# ==========================================
# SECTION 2: DATASET DOWNLOAD AND SETUP
# ==========================================

def print_directory_tree(root_dir, max_depth=3):
    """Print every directory under ``root_dir`` exactly once, indented by depth.

    Args:
        root_dir: directory to describe (printed at depth 0).
        max_depth: deepest level to print; subdirectories below it are not walked.

    Each line also shows how many files sit directly in that directory.
    """
    root_dir = os.path.normpath(root_dir)
    for current, dirs, files in os.walk(root_dir):
        rel = os.path.relpath(current, root_dir)
        depth = 0 if rel == os.curdir else rel.count(os.sep) + 1
        # Sorting in place also fixes the order in which os.walk descends.
        dirs.sort()
        print(f"{'  ' * depth}{os.path.basename(current)}/ ({len(files)} files)")
        if depth >= max_depth:
            # Prune (rather than stop) the walk so sibling branches are still listed.
            dirs[:] = []


# Download dataset using kagglehub
print("📥 Downloading dataset...")
path = kagglehub.dataset_download("turkertuncer/multimodal-stroke-image-dataset")
print("✅ Dataset downloaded to:", path)

# All images live under the "deep" sub-directory of the download
DATA_DIR = os.path.join(path, "deep")

# Verify dataset structure
print("\n📁 Dataset structure:")
print_directory_tree(DATA_DIR, max_depth=3)


📥 Downloading dataset...
✅ Dataset downloaded to: /kaggle/input/multimodal-stroke-image-dataset

📁 Dataset structure:
deep/
  test/
  train/
  test/
    strokeMR/
    normalBT/
    strokeBT/
    normalMR/
    strokeMR/
    normalBT/
    strokeBT/
    normalMR/
  train/
    2- Control/
    1- Stroke/
    2- Control/
      NormalBT/
      NormalMR/
      NormalBT/


In [None]:

# ==========================================
# SECTION 3: DATASET CLASS AND UTILITIES
# ==========================================

def gather_images(folders, extensions=('.png', '.jpg', '.jpeg')):
    """Recursively collect image file paths from a list of folders.

    Args:
        folders: iterable of directory paths. Paths that are not existing
            directories are skipped.
        extensions: file suffixes to accept, matched case-insensitively, so
            ``scan.PNG`` and ``scan.png`` are both found. A leading ``*``
            (glob style, e.g. ``'*.png'``) is tolerated.

    Returns:
        Sorted list of unique file paths. Sorting makes the order independent of
        filesystem listing order, so dataset indices are reproducible. Hidden
        files and directories (names starting with '.') are ignored, as the
        default glob patterns did.
    """
    wanted = tuple(ext.lower().lstrip('*') for ext in extensions)
    found = set()
    for folder in folders:
        if not os.path.isdir(folder):
            continue
        for root, dirs, files in os.walk(folder):
            # Skip hidden directories (e.g. .ipynb_checkpoints) like glob's '**' does.
            dirs[:] = [d for d in dirs if not d.startswith('.')]
            for name in files:
                if not name.startswith('.') and name.lower().endswith(wanted):
                    found.add(os.path.join(root, name))
    return sorted(found)

class CustomStrokeDataset(Dataset):
    """Binary stroke-vs-normal image dataset assembled from folders of images.

    Args:
        folders_dict: mapping ``class name -> list of folder paths``. A class whose
            name contains "stroke" (case-insensitive) gets label 1, any other 0.
        transform: callable applied to each RGB PIL image (e.g. a torchvision
            ``Compose``). If None, the PIL image itself is returned.

    Attributes:
        paths: image file paths, grouped by class in ``folders_dict`` order.
        labels: integer label for each entry of ``paths``.
    """

    def __init__(self, folders_dict, transform=None):
        self.paths = []
        self.labels = []
        self.transform = transform

        for label_name, folders in folders_dict.items():
            label = 1 if "stroke" in label_name.lower() else 0
            imgs = gather_images(folders)
            self.paths.extend(imgs)
            self.labels.extend([label] * len(imgs))
            print(f"✅ Loaded {len(imgs)} images for class '{label_name}' (label={label})")

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` where label is a 0-dim ``torch.long`` tensor."""
        # The context manager closes the file handle deterministically; convert()
        # returns a new, fully loaded image, so it stays valid after the file closes.
        with Image.open(self.paths[idx]) as src:
            img = src.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return img, label


print("✅ Dataset utilities defined")

✅ Dataset utilities defined


In [None]:


# ==========================================
# SECTION 4: MODEL ARCHITECTURE
# ==========================================

import timm

class ResNet50_ViT(nn.Module):
    """Late-fusion classifier combining a ResNet-50 CNN and a ViT-B/16 transformer.

    Both backbones receive the same input image. The globally average-pooled
    ResNet feature map and the ViT CLS token are concatenated and passed to an
    MLP head.

    Args:
        num_classes: number of output logits (default 2: normal / stroke).
        hidden_dim: width of the hidden layer in the fusion head.
        dropout: dropout probability in the fusion head.
        pretrained: load pretrained backbone weights through timm.

    The defaults reproduce the original architecture, so existing checkpoints
    (``state_dict`` keys and shapes) still load.
    """

    def __init__(self, num_classes=2, hidden_dim=512, dropout=0.3, pretrained=True):
        super().__init__()
        # Load backbones
        self.resnet = timm.create_model('resnet50', pretrained=pretrained)
        self.vit = timm.create_model('vit_base_patch16_224', pretrained=pretrained)

        # Remove original classifiers (forward() only uses forward_features()).
        self.resnet.global_pool = nn.Identity()
        self.resnet.fc = nn.Identity()
        self.vit.head = nn.Identity()

        # Feature widths come from the backbones instead of a hard-coded 2048 + 768.
        fused_dim = self.resnet.num_features + self.vit.num_features

        self.pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Sequential(
            nn.Linear(fused_dim, hidden_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape ``(B, num_classes)`` for an image batch ``x``."""
        # ResNet features: (B, C_r, H, W) -> (B, C_r)
        r_feat = self.pool(self.resnet.forward_features(x)).flatten(1)

        # ViT features: token sequence (B, N, C_v); keep the first (CLS) token.
        v_feat = self.vit.forward_features(x)
        if v_feat.ndim == 3:
            v_feat = v_feat[:, 0, :]

        combined = torch.cat([r_feat, v_feat], dim=1)
        return self.classifier(combined)


print("✅ Model architecture defined")

✅ Model architecture defined


In [None]:


# ==========================================
# SECTION 5: DATA PREPARATION
# ==========================================

# Configuration
IMG_SIZE = 224
BATCH_SIZE = 32
NUM_WORKERS = 2
# Pinned host memory only speeds up host->GPU copies; on CPU-only runtimes it just
# triggers a warning, so enable it only when CUDA is available.
PIN_MEMORY = torch.cuda.is_available()

# Data transforms. Both pipelines map pixels to [-1, 1] (mean = std = 0.5).
# NOTE(review): timm backbones carry their own preprocessing config
# (``model.pretrained_cfg`` mean/std); confirm a shared 0.5 normalisation is
# intended for the ResNet-50 branch.
train_tf = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3),
])

val_tf = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3),
])

# Dataset folders
TRAIN_FOLDERS = {
    "stroke": [f"{DATA_DIR}/train/1- Stroke"],
    "normal": [f"{DATA_DIR}/train/2- Control"],
}

# NOTE: the "validation" set is the dataset's held-out *test* split. It is also
# monitored during training, so scores on it are not a fully independent estimate.
VAL_FOLDERS = {
    "stroke": [f"{DATA_DIR}/test/strokeBT", f"{DATA_DIR}/test/strokeMR"],
    "normal": [f"{DATA_DIR}/test/normalBT", f"{DATA_DIR}/test/normalMR"],
}

# Create datasets
print("📊 Creating datasets...")
train_ds = CustomStrokeDataset(TRAIN_FOLDERS, train_tf)
val_ds = CustomStrokeDataset(VAL_FOLDERS, val_tf)

# Create data loaders. val_dl must stay un-shuffled: later cells match
# predictions to val_ds.paths by position.
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                      num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                    num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)

print(f"✅ Train samples: {len(train_ds)} | Val samples: {len(val_ds)}")

📊 Creating datasets...
✅ Loaded 393 images for class 'stroke' (label=1)
✅ Loaded 1984 images for class 'normal' (label=0)
✅ Loaded 119 images for class 'stroke' (label=1)
✅ Loaded 657 images for class 'normal' (label=0)
✅ Train samples: 2377 | Val samples: 776


In [None]:


# ==========================================
# SECTION 6: MODEL TRAINING
# ==========================================

from torchmetrics.classification import BinaryAccuracy, AUROC

# Initialize model and training components
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"🔧 Using device: {device}")

model = ResNet50_ViT().to(device)

# Training configuration
EPOCHS = 35
CHECKPOINT_EVERY = 10
# NOTE(review): the training set is ~5:1 normal:stroke (1984 vs 393 images) and
# the loss is unweighted; consider class weights or a balanced sampler if stroke
# recall turns out low.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
# Predictions are gathered on the CPU below, so metric state lives there too.
acc_metric = BinaryAccuracy()
auc_metric = AUROC(task='binary')

# Per-epoch history (consumed by the visualisation cell)
train_losses = []
val_losses = []
val_accuracies = []
val_aurocs = []

print("🚀 Starting training...")

for epoch in range(EPOCHS):
    # ---- Training phase ----
    model.train()
    running_train_loss = 0.0

    for xb, yb in train_dl:
        xb, yb = xb.to(device), yb.to(device)

        optimizer.zero_grad()
        loss = criterion(model(xb), yb)
        loss.backward()
        optimizer.step()

        # Weight by batch size: the last batch may be smaller, so this gives a
        # true per-sample mean after dividing by the dataset size.
        running_train_loss += loss.item() * xb.size(0)

    train_loss = running_train_loss / len(train_ds)
    train_losses.append(train_loss)

    # ---- Validation phase ----
    model.eval()
    running_val_loss = 0.0
    val_probs = []
    val_targets = []

    with torch.no_grad():
        for xb, yb in val_dl:
            xb, yb = xb.to(device), yb.to(device)
            logits = model(xb)
            running_val_loss += criterion(logits, yb).item() * xb.size(0)
            val_probs.append(logits.softmax(1)[:, 1].cpu())  # P(stroke)
            val_targets.append(yb.cpu())

    val_loss = running_val_loss / len(val_ds)
    val_losses.append(val_loss)

    # ---- Metrics for this epoch only ----
    val_probs = torch.cat(val_probs)
    val_targets = torch.cat(val_targets)
    acc_metric.update(val_probs, val_targets)
    auc_metric.update(val_probs, val_targets)
    acc = acc_metric.compute()
    auc_score = auc_metric.compute()
    # Reset so state (AUROC keeps every prediction) does not pile up across epochs.
    acc_metric.reset()
    auc_metric.reset()

    val_accuracies.append(acc.item())
    val_aurocs.append(auc_score.item())

    # Print progress
    if (epoch + 1) % 5 == 0 or epoch == 0:
        print(f"Epoch {epoch+1:2d}/{EPOCHS} - Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Acc: {acc:.4f} | AUROC: {auc_score:.4f}")

    # Periodic checkpoints. The final epoch is always saved too: EPOCHS need not be
    # a multiple of CHECKPOINT_EVERY, and the final weights are what later cells evaluate.
    if (epoch + 1) % CHECKPOINT_EVERY == 0 or epoch + 1 == EPOCHS:
        torch.save(model.state_dict(), f"model_epoch_{epoch+1}.pth")

print("✅ Training completed!")

🔧 Using device: cpu


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


🚀 Starting training...


In [None]:


# ==========================================
# SECTION 7: TRAINING VISUALIZATION
# ==========================================

# Each series is plotted against its own epoch count, so a run interrupted
# before EPOCHS epochs still plots instead of raising a length-mismatch error.
if not train_losses:
    print("⚠️ No training history recorded - run the training cell first.")
else:
    fig, (ax_loss, ax_acc, ax_auc) = plt.subplots(1, 3, figsize=(15, 5))

    ax_loss.plot(range(1, len(train_losses) + 1), train_losses, 'b-', label='Train Loss', linewidth=2)
    ax_loss.plot(range(1, len(val_losses) + 1), val_losses, 'r-', label='Val Loss', linewidth=2)
    ax_loss.set_xlabel('Epoch')
    ax_loss.set_ylabel('Loss')
    ax_loss.set_title('Training and Validation Loss')
    ax_loss.legend()
    ax_loss.grid(True, alpha=0.3)

    ax_acc.plot(range(1, len(val_accuracies) + 1), val_accuracies, 'g-', linewidth=2)
    ax_acc.set_xlabel('Epoch')
    ax_acc.set_ylabel('Accuracy')
    ax_acc.set_title('Validation Accuracy')
    ax_acc.grid(True, alpha=0.3)

    ax_auc.plot(range(1, len(val_aurocs) + 1), val_aurocs, color='purple', linewidth=2)
    ax_auc.set_xlabel('Epoch')
    ax_auc.set_ylabel('AUROC')
    ax_auc.set_title('Validation AUROC')
    ax_auc.grid(True, alpha=0.3)

    fig.tight_layout()
    plt.show()

    # NOTE: "best" is picked on the same split that is reported, so it is optimistic.
    print("📊 Final Training Results:")
    if val_accuracies:
        print(f"   Best Validation Accuracy: {max(val_accuracies):.4f}")
        print(f"   Best Validation AUROC: {max(val_aurocs):.4f}")


In [None]:

# ==========================================
# SECTION 8: COMPLETE DATASET EVALUATION
# ==========================================

from torch.utils.data import SequentialSampler

print("\n" + "="*60)
print("🔍 EVALUATING ON COMPLETE VALIDATION DATASET")
print("="*60)

# Image paths are matched to predictions by position, which is only valid if
# the loader visits val_ds in index order (i.e. it was built with shuffle=False).
assert isinstance(val_dl.sampler, SequentialSampler), "val_dl must not shuffle"

# Final evaluation on all validation data
model.eval()
prob_chunks, pred_chunks, target_chunks = [], [], []

with torch.no_grad():
    for xb, yb in val_dl:
        logits = model(xb.to(device))
        prob_chunks.append(logits.softmax(1)[:, 1].cpu().numpy())  # P(stroke)
        pred_chunks.append(logits.argmax(dim=1).cpu().numpy())
        target_chunks.append(yb.numpy())

# NOTE: despite its name (kept because later cells use it), all_logits holds
# softmax probabilities of the stroke class, not raw logits.
all_logits = np.concatenate(prob_chunks)
all_preds = np.concatenate(pred_chunks)
all_targets = np.concatenate(target_chunks)
all_image_paths = list(val_ds.paths)
assert len(all_image_paths) == len(all_targets), "paths/predictions misaligned"

print(f"✅ Processed {len(all_targets)} validation images")


In [None]:

# ==========================================
# SECTION 9: PERFORMANCE METRICS
# ==========================================

from sklearn.metrics import balanced_accuracy_score

# Plain accuracy is dominated by the majority class (the validation split is
# ~5.5:1 normal:stroke), so balanced accuracy (mean per-class recall) is reported too.
overall_accuracy = accuracy_score(all_targets, all_preds)
balanced_accuracy = balanced_accuracy_score(all_targets, all_preds)

print("\n📈 OVERALL PERFORMANCE METRICS:")
print("-" * 40)
print(f"Overall Accuracy: {overall_accuracy:.4f}")
print(f"Balanced Accuracy: {balanced_accuracy:.4f}")

# labels=[0, 1] keeps target_names aligned even if a class is absent from both
# targets and predictions; zero_division=0 reports 0 instead of warning when a
# class is never predicted.
print("\n📋 DETAILED CLASSIFICATION REPORT:")
print(classification_report(all_targets, all_preds, labels=[0, 1],
                            target_names=['Normal', 'Stroke'], zero_division=0))


In [None]:

# ==========================================
# SECTION 10: VISUALIZATION - CONFUSION MATRIX
# ==========================================

# labels=[0, 1] guarantees a 2x2 matrix (needed for the ravel() unpacking
# below) even if one class is missing from targets and predictions.
cm = confusion_matrix(all_targets, all_preds, labels=[0, 1])

# Draw into an explicit axes: ConfusionMatrixDisplay.plot() without ax= opens its
# own figure, which would ignore figsize and leave an empty extra figure.
fig, ax = plt.subplots(figsize=(8, 6))
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=["Normal", "Stroke"])
disp.plot(ax=ax, cmap='Blues', values_format='d')
ax.set_title(f"Confusion Matrix - Complete Dataset\n(n={len(all_targets)} images)", fontsize=14)
plt.show()

# Print confusion matrix values
tn, fp, fn, tp = cm.ravel()
print("📊 Confusion Matrix Breakdown:")
print(f"   True Negatives (Normal → Normal): {tn}")
print(f"   False Positives (Normal → Stroke): {fp}")
print(f"   False Negatives (Stroke → Normal): {fn}")
print(f"   True Positives (Stroke → Stroke): {tp}")

# Clinically relevant rates; NaN when the corresponding class is absent.
sensitivity = tp / (tp + fn) if (tp + fn) else float('nan')
specificity = tn / (tn + fp) if (tn + fp) else float('nan')
print(f"   Sensitivity (Stroke recall): {sensitivity:.4f}")
print(f"   Specificity (Normal recall): {specificity:.4f}")


In [None]:

# ==========================================
# SECTION 11: ROC AND PR CURVES
# ==========================================

# Threshold-free metrics computed from the stroke probabilities in all_logits.
fpr, tpr, roc_thresholds = roc_curve(all_targets, all_logits)
roc_auc = auc(fpr, tpr)
precision, recall, pr_thresholds = precision_recall_curve(all_targets, all_logits)
ap_score = average_precision_score(all_targets, all_logits)

# One row of three panels: ROC, precision-recall, score histograms.
fig, (ax_roc, ax_pr, ax_hist) = plt.subplots(1, 3, figsize=(15, 5))

# ROC curve against the chance diagonal
ax_roc.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC Curve (AUC = {roc_auc:.3f})')
ax_roc.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', alpha=0.6)
ax_roc.set_xlim([0.0, 1.0])
ax_roc.set_ylim([0.0, 1.05])
ax_roc.set_xlabel('False Positive Rate')
ax_roc.set_ylabel('True Positive Rate')
ax_roc.set_title('ROC Curve')
ax_roc.legend(loc="lower right")
ax_roc.grid(True, alpha=0.3)

# Precision-recall curve
ax_pr.plot(recall, precision, color='purple', lw=2, label=f'PR Curve (AP = {ap_score:.3f})')
ax_pr.set_xlabel('Recall')
ax_pr.set_ylabel('Precision')
ax_pr.set_title('Precision-Recall Curve')
ax_pr.legend(loc="best")
ax_pr.grid(True, alpha=0.3)

# Distribution of predicted stroke probability, split by true class
stroke_probs = all_logits[all_targets == 1]
normal_probs = all_logits[all_targets == 0]
ax_hist.hist(normal_probs, bins=30, alpha=0.7, label='Normal', color='blue', density=True)
ax_hist.hist(stroke_probs, bins=30, alpha=0.7, label='Stroke', color='red', density=True)
ax_hist.set_xlabel('Predicted Stroke Probability')
ax_hist.set_ylabel('Density')
ax_hist.set_title('Prediction Probability Distribution')
ax_hist.legend()
ax_hist.grid(True, alpha=0.3)

fig.tight_layout()
plt.show()

print(f"📈 Curve Metrics:")
print(f"   ROC AUC: {roc_auc:.4f}")
print(f"   Average Precision: {ap_score:.4f}")


In [None]:

# ==========================================
# SECTION 12: GRAD-CAM VISUALIZATION
# ==========================================

# Import Grad-CAM (separate cell to handle import issues)
try:
    import subprocess
    import sys

    # Install grad-cam if not available
    try:
        from pytorch_grad_cam import GradCAM
    except ImportError:
        print("📦 Installing pytorch-grad-cam...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", "grad-cam"])
        from pytorch_grad_cam import GradCAM

    from pytorch_grad_cam.utils.image import show_cam_on_image
    from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
    gradcam_available = True
    print("✅ Grad-CAM imported successfully")
except ImportError as e:
    print(f"⚠️  Grad-CAM not available: {e}. Skipping visualizations.")
    gradcam_available = False

if gradcam_available:
    def get_vit_attention_map(model, input_tensor, target_layer_name='vit'):
        """Return the last ViT block's CLS-to-patch attention as a square 2-D grid.

        timm's ``Attention`` module does not expose its attention matrix as an
        attribute (and may compute it inside a fused kernel), so the matrix is
        rebuilt from the output of the last block's ``attn.qkv`` projection,
        captured with a temporary forward hook. Heads are averaged and only the
        first item of the batch is used.

        Args:
            model: wrapper whose ``getattr(model, target_layer_name)`` is a timm-style
                ViT exposing ``blocks[-1].attn`` with ``qkv`` and ``num_heads``.
            input_tensor: preprocessed image batch accepted by ``model``.
            target_layer_name: attribute name of the ViT sub-module on ``model``.

        Returns:
            ``np.ndarray`` of shape ``(g, g)`` over the ``g*g`` patch tokens, or
            ``None`` if the attention cannot be extracted (missing attributes or a
            non-square patch grid).
        """
        model.eval()

        vit = getattr(model, target_layer_name, None)
        blocks = getattr(vit, 'blocks', None)
        if blocks is None or len(blocks) == 0 or not hasattr(blocks[-1], 'attn'):
            return None
        attn_module = blocks[-1].attn
        if not hasattr(attn_module, 'qkv') or not hasattr(attn_module, 'num_heads'):
            return None

        captured = {}

        def qkv_hook(module, inputs, output):
            captured['qkv'] = output.detach()

        handle = attn_module.qkv.register_forward_hook(qkv_hook)
        try:
            with torch.no_grad():
                model(input_tensor)
        finally:
            # Always detach the hook, even if the forward pass raises.
            handle.remove()

        qkv = captured.get('qkv')
        if qkv is None:
            return None

        with torch.no_grad():
            batch, tokens, width = qkv.shape
            num_heads = attn_module.num_heads
            head_dim = width // (3 * num_heads)
            # (B, N, 3*C) -> (3, B, heads, N, head_dim), the layout timm's Attention uses.
            qkv = qkv.reshape(batch, tokens, 3, num_heads, head_dim).permute(2, 0, 3, 1, 4)
            q, k = qkv[0], qkv[1]
            # Newer timm versions apply optional q/k normalisation (Identity by default).
            q_norm = getattr(attn_module, 'q_norm', None)
            k_norm = getattr(attn_module, 'k_norm', None)
            if q_norm is not None:
                q = q_norm(q)
            if k_norm is not None:
                k = k_norm(k)
            scale = getattr(attn_module, 'scale', head_dim ** -0.5)
            attn = ((q * scale) @ k.transpose(-2, -1)).softmax(dim=-1)  # (B, heads, N, N)

        # Average over heads, take the CLS query row of the first image and drop
        # the prefix (CLS / register) tokens so only patch tokens remain.
        num_prefix = getattr(vit, 'num_prefix_tokens', 1)
        cls_attn = attn.mean(dim=1)[0, 0, num_prefix:].cpu().numpy()

        grid_size = int(round(np.sqrt(cls_attn.size)))
        if grid_size * grid_size != cls_attn.size:
            return None
        return cls_attn.reshape(grid_size, grid_size)

    def create_hybrid_gradcam_visualization(model, dataset, num_samples=8, seed=None):
        """Plot image / ResNet Grad-CAM / ViT map / blended map for sampled images.

        Up to ``num_samples // 2`` stroke images are drawn from ``dataset`` and the
        remaining rows are filled with normal images. Every map explains the
        stroke logit (class index 1).

        Reads module-level names from earlier cells: ``val_tf``, ``IMG_SIZE``,
        ``device``, and ``all_logits`` / ``all_preds`` (stroke probabilities and
        predicted labels, assumed to be in ``dataset`` order as produced by the
        sequential evaluation loop).

        Args:
            model: hybrid model exposing ``resnet.layer4`` and a ``vit`` sub-module.
            dataset: object with parallel ``paths`` and ``labels`` lists.
            num_samples: maximum number of rows to draw.
            seed: optional int for a reproducible sample choice; ``None`` draws
                from the global ``np.random`` state.
        """
        from scipy.ndimage import zoom

        def _minmax(arr):
            # Rescale to [0, 1]; a constant map would divide by zero (NaN), so map it to zeros.
            lo, hi = arr.min(), arr.max()
            return (arr - lo) / (hi - lo) if hi > lo else np.zeros_like(arr)

        def _input_saliency(img_tensor):
            # |d(stroke logit)/d(pixel)| averaged over channels, on a detached copy so
            # the caller's tensor is untouched; parameter grads are cleared afterwards.
            inp = img_tensor.detach().clone().requires_grad_(True)
            model.zero_grad(set_to_none=True)
            model(inp)[0, 1].backward()
            sal = inp.grad[0].abs().mean(dim=0).cpu().numpy()
            model.zero_grad(set_to_none=True)
            return _minmax(sal)

        rng = np.random if seed is None else np.random.default_rng(seed)

        # Select diverse samples: half stroke (if available), rest normal
        stroke_indices = [i for i, label in enumerate(dataset.labels) if label == 1]
        normal_indices = [i for i, label in enumerate(dataset.labels) if label == 0]

        sample_indices = []
        half = num_samples // 2
        if len(stroke_indices) >= half:
            sample_indices.extend(rng.choice(stroke_indices, half, replace=False))
        else:
            sample_indices.extend(stroke_indices)

        remaining = num_samples - len(sample_indices)
        if len(normal_indices) >= remaining:
            sample_indices.extend(rng.choice(normal_indices, remaining, replace=False))
        else:
            sample_indices.extend(normal_indices)

        sample_indices = sample_indices[:num_samples]
        if not sample_indices:
            print("⚠️ No samples available for Grad-CAM visualization.")
            return

        print(f"🎨 Generating Enhanced Grad-CAM for {len(sample_indices)} samples...")
        print("   - ResNet feature maps")
        print("   - Vision Transformer attention")
        print("   - Combined visualization")

        cols = 4  # Original, ResNet CAM, ViT attention (or input saliency), Combined
        rows = len(sample_indices)
        plt.figure(figsize=(20, 5 * rows))
        targets = [ClassifierOutputTarget(1)]  # explain the stroke class

        # One CAM object for all samples; the context manager releases its hooks on
        # model.resnet.layer4[-1] explicitly instead of relying on garbage collection.
        with GradCAM(model=model, target_layers=[model.resnet.layer4[-1]]) as resnet_cam:
            for idx, sample_idx in enumerate(sample_indices):
                img_path = dataset.paths[sample_idx]
                true_label = dataset.labels[sample_idx]

                with Image.open(img_path) as src:
                    raw_img = src.convert("RGB")
                img_np = np.array(raw_img.resize((IMG_SIZE, IMG_SIZE))).astype(np.float32) / 255.0
                transformed_img = val_tf(raw_img).unsqueeze(0).to(device)

                # Reuse the evaluation cell's outputs when the index is in range
                if sample_idx < len(all_logits):
                    pred_prob = all_logits[sample_idx]
                    pred_label = all_preds[sample_idx]
                else:
                    model.eval()
                    with torch.no_grad():
                        logits = model(transformed_img)
                    pred_prob = logits.softmax(1)[0, 1].item()
                    pred_label = logits.argmax(dim=1)[0].item()

                # 1. Original image
                plt.subplot(rows, cols, idx * cols + 1)
                plt.imshow(img_np)
                plt.axis('off')
                plt.title(f"Original\n{os.path.basename(img_path)[:15]}...", fontsize=10)

                # 2. ResNet Grad-CAM
                resnet_cam_map = resnet_cam(input_tensor=transformed_img, targets=targets)[0, :]
                plt.subplot(rows, cols, idx * cols + 2)
                plt.imshow(show_cam_on_image(img_np, resnet_cam_map, use_rgb=True))
                plt.axis('off')
                plt.title("ResNet Features", fontsize=10)

                # 3. ViT attention, or input-gradient saliency when attention is unavailable
                vit_attention = get_vit_attention_map(model, transformed_img)
                if vit_attention is not None:
                    second_map = _minmax(zoom(vit_attention,
                                              (IMG_SIZE / vit_attention.shape[0],
                                               IMG_SIZE / vit_attention.shape[1])))
                    second_title = "ViT Attention"
                    resnet_weight, second_weight = 0.6, 0.4
                else:
                    second_map = _input_saliency(transformed_img)
                    # Gradients flow through both backbones, so this is not ViT-specific.
                    second_title = "Input Gradients"
                    resnet_weight, second_weight = 0.7, 0.3

                plt.subplot(rows, cols, idx * cols + 3)
                plt.imshow(show_cam_on_image(img_np, second_map, use_rgb=True))
                plt.title(second_title, fontsize=10)
                plt.axis('off')

                # 4. Weighted blend of both maps
                combined_map = _minmax(resnet_weight * resnet_cam_map + second_weight * second_map)
                plt.subplot(rows, cols, idx * cols + 4)
                plt.imshow(show_cam_on_image(img_np, combined_map, use_rgb=True))
                plt.axis('off')

                # Confidence = probability of the predicted class
                true_class = "Stroke" if true_label == 1 else "Normal"
                pred_class = "Stroke" if pred_label == 1 else "Normal"
                confidence = pred_prob if pred_label == 1 else (1 - pred_prob)
                plt.title(f"Combined View\nTrue: {true_class}\nPred: {pred_class} ({confidence:.3f})",
                          fontsize=10)

        plt.suptitle('Enhanced Grad-CAM: ResNet + Vision Transformer Analysis', fontsize=16)
        plt.tight_layout()
        plt.show()

    def create_attention_summary(model, dataset, num_samples=4, seed=None):
        """Per-sample grid: image, ResNet Grad-CAM, ViT map, blended map, text summary.

        About half the rows (rounded up) are stroke images and the rest normal
        images, i.e. 2 + 2 for the default ``num_samples=4``; if one class runs
        short, the other fills the remaining rows.

        Reads module-level ``val_tf``, ``IMG_SIZE`` and ``device`` from earlier cells.

        Args:
            model: hybrid model exposing ``resnet``, ``vit``, ``pool`` and ``classifier``.
            dataset: object with parallel ``paths`` and ``labels`` lists.
            num_samples: maximum number of rows to draw.
            seed: optional int for a reproducible sample choice; ``None`` draws
                from the global ``np.random`` state.
        """
        from scipy.ndimage import zoom

        def _minmax(arr):
            # Rescale to [0, 1]; a constant map would divide by zero (NaN), so map it to zeros.
            lo, hi = arr.min(), arr.max()
            return (arr - lo) / (hi - lo) if hi > lo else np.zeros_like(arr)

        def _input_saliency(img_tensor):
            # |d(stroke logit)/d(pixel)| averaged over channels, on a detached copy so
            # the caller's tensor is untouched; parameter grads are cleared afterwards.
            inp = img_tensor.detach().clone().requires_grad_(True)
            model.zero_grad(set_to_none=True)
            model(inp)[0, 1].backward()
            sal = inp.grad[0].abs().mean(dim=0).cpu().numpy()
            model.zero_grad(set_to_none=True)
            return _minmax(sal)

        print(f"\n🔍 Creating Component Analysis Summary...")

        rng = np.random if seed is None else np.random.default_rng(seed)
        stroke_indices = [i for i, label in enumerate(dataset.labels) if label == 1]
        normal_indices = [i for i, label in enumerate(dataset.labels) if label == 0]

        n_stroke = min((num_samples + 1) // 2, len(stroke_indices))
        n_normal = min(num_samples - n_stroke, len(normal_indices))

        selected_indices = []
        if n_stroke:
            selected_indices.extend(rng.choice(stroke_indices, n_stroke, replace=False))
        if n_normal:
            selected_indices.extend(rng.choice(normal_indices, n_normal, replace=False))
        if not selected_indices:
            print("⚠️ No samples available for component analysis.")
            return

        n_rows = len(selected_indices)
        # squeeze=False keeps axes 2-D even for a single row.
        fig, axes = plt.subplots(n_rows, 5, figsize=(25, 6 * n_rows), squeeze=False)
        targets = [ClassifierOutputTarget(1)]  # explain the stroke class

        # One CAM object for all rows; the context manager releases its hooks explicitly.
        with GradCAM(model=model, target_layers=[model.resnet.layer4[-1]]) as resnet_cam:
            for idx, sample_idx in enumerate(selected_indices):
                img_path = dataset.paths[sample_idx]
                true_label = dataset.labels[sample_idx]
                row = axes[idx]

                with Image.open(img_path) as src:
                    raw_img = src.convert("RGB")
                img_np = np.array(raw_img.resize((IMG_SIZE, IMG_SIZE))).astype(np.float32) / 255.0
                transformed_img = val_tf(raw_img).unsqueeze(0).to(device)

                # Re-run the model's forward pass by hand to keep per-branch features
                model.eval()
                with torch.no_grad():
                    resnet_features = model.pool(model.resnet.forward_features(transformed_img)).flatten(1)
                    vit_features = model.vit.forward_features(transformed_img)
                    if vit_features.ndim == 3:
                        vit_features = vit_features[:, 0, :]
                    combined_features = torch.cat([resnet_features, vit_features], dim=1)
                    final_logits = model.classifier(combined_features)
                    final_prob = final_logits.softmax(1)[0, 1].item()

                # Original image
                row[0].imshow(img_np)
                row[0].set_title(f"Original\n{os.path.basename(img_path)[:20]}", fontsize=10)
                row[0].axis('off')

                # ResNet Grad-CAM
                resnet_map = resnet_cam(input_tensor=transformed_img, targets=targets)[0]
                row[1].imshow(show_cam_on_image(img_np, resnet_map, use_rgb=True))
                row[1].set_title("ResNet Focus", fontsize=10)
                row[1].axis('off')

                # ViT attention, or input-gradient saliency when attention is unavailable
                vit_attention = get_vit_attention_map(model, transformed_img)
                if vit_attention is not None:
                    second_map = _minmax(zoom(vit_attention,
                                              (IMG_SIZE / vit_attention.shape[0],
                                               IMG_SIZE / vit_attention.shape[1])))
                    second_title = "ViT Attention"
                    resnet_weight, second_weight = 0.6, 0.4
                else:
                    second_map = _input_saliency(transformed_img)
                    # Gradients flow through both backbones, so this is not ViT-specific.
                    second_title = "Input Gradients"
                    resnet_weight, second_weight = 0.7, 0.3

                row[2].imshow(show_cam_on_image(img_np, second_map, use_rgb=True))
                row[2].set_title(second_title, fontsize=10)
                row[2].axis('off')

                # Weighted blend of both maps
                combined_map = _minmax(resnet_weight * resnet_map + second_weight * second_map)
                row[3].imshow(show_cam_on_image(img_np, combined_map, use_rgb=True))
                row[3].set_title("Combined", fontsize=10)
                row[3].axis('off')

                # Prediction summary. Confidence is the probability of the *predicted*
                # class, matching the hybrid visualisation's definition.
                text_ax = row[4]
                pred_is_stroke = final_prob > 0.5
                confidence = final_prob if pred_is_stroke else 1 - final_prob
                text_ax.text(0.1, 0.8, f"True Label: {'Stroke' if true_label == 1 else 'Normal'}",
                             fontsize=12, transform=text_ax.transAxes)
                text_ax.text(0.1, 0.6, f"Prediction: {'Stroke' if pred_is_stroke else 'Normal'}",
                             fontsize=12, transform=text_ax.transAxes)
                text_ax.text(0.1, 0.4, f"Confidence: {confidence:.3f}",
                             fontsize=12, transform=text_ax.transAxes)

                # NOTE: this is the share of L2 norms of the two feature vectors, not an
                # attribution of the prediction to either backbone.
                resnet_norm = torch.norm(resnet_features).item()
                vit_norm = torch.norm(vit_features).item()
                total_norm = resnet_norm + vit_norm
                if total_norm > 0:
                    resnet_share, vit_share = resnet_norm / total_norm, vit_norm / total_norm
                else:
                    resnet_share = vit_share = 0.0

                text_ax.text(0.1, 0.2, f"ResNet contrib: {resnet_share:.2f}",
                             fontsize=10, transform=text_ax.transAxes)
                text_ax.text(0.1, 0.1, f"ViT contrib: {vit_share:.2f}",
                             fontsize=10, transform=text_ax.transAxes)

                text_ax.set_xlim(0, 1)
                text_ax.set_ylim(0, 1)
                text_ax.axis('off')
                text_ax.set_title("Analysis", fontsize=10)

        fig.suptitle('Component Analysis: ResNet vs Vision Transformer Contributions', fontsize=16)
        fig.tight_layout()
        plt.show()

    print("\n" + "="*60)
    print("🎨 ENHANCED GRAD-CAM VISUALIZATIONS")
    print("="*60)

    # Generate comprehensive visualizations
    print("🔍 Generating hybrid ResNet-ViT visualizations...")
    create_hybrid_gradcam_visualization(model, val_ds, num_samples=6)

    print("\n🔬 Generating component analysis summary...")
    create_attention_summary(model, val_ds, num_samples=4)

    print("✅ Enhanced Grad-CAM analysis completed!")

else:
    print("⚠️ Skipping Grad-CAM visualizations due to import issues.")
    print("💡 To manually install: !pip install grad-cam")

In [None]:

# ==========================================
# SECTION 13: ANALYSIS BY IMAGE TYPE
# ==========================================

def analyze_by_modality(image_paths, predictions, targets, probabilities):
    """Analyze performance by imaging modality (BT vs MR).

    Each image is assigned to at most one modality by case-insensitive substring
    matching on its full path:
      * 'BT', 'BRAIN' or 'TOMOGRAPHY'      -> Brain Tomography (BT)
      * otherwise 'MR', 'MAGNETIC' or 'RESONANCE' -> Magnetic Resonance (MR)
    Images matching neither keyword set are ignored.

    For every modality with at least one image, this prints the sample count,
    accuracy, AUROC, average precision and (when both classes are present)
    per-class precision/recall/F1.

    Args:
        image_paths: Sequence of image path strings, aligned index-by-index with
            the other three arguments.
        predictions: Sequence of predicted labels (0 = Normal, 1 = Stroke).
        targets: Sequence of ground-truth labels (0 = Normal, 1 = Stroke).
        probabilities: Sequence of per-image stroke scores used for the ROC and
            precision-recall metrics.

    Returns:
        list[dict]: One dict per non-empty modality with keys 'modality' (str),
        'count' (int), 'accuracy' (float), and 'auroc' / 'ap' (strings: the
        metric to 4 decimals, or "N/A" when the subset holds a single class).
    """
    bt_keywords = ('BT', 'BRAIN', 'TOMOGRAPHY')
    mr_keywords = ('MR', 'MAGNETIC', 'RESONANCE')

    # Assign each image to exactly one modality. BT is checked first, so a path
    # matching both keyword sets is counted once (as BT) rather than in both
    # groups. This mirrors the BT-first 'image_type' rule used for the saved
    # per-image results table.
    bt_indices, mr_indices = [], []
    for i, path in enumerate(image_paths):
        upper_path = path.upper()
        if any(k in upper_path for k in bt_keywords):
            bt_indices.append(i)
        elif any(k in upper_path for k in mr_keywords):
            mr_indices.append(i)

    # Convert once so each modality subset can be taken with fancy indexing.
    preds_arr = np.asarray(predictions)
    targets_arr = np.asarray(targets)
    probs_arr = np.asarray(probabilities)

    print("\n" + "="*60)
    print("🔬 ANALYSIS BY IMAGING MODALITY")
    print("="*60)

    modalities = [
        ("Brain Tomography (BT)", bt_indices),
        ("Magnetic Resonance (MR)", mr_indices),
    ]

    results = []

    for name, indices in modalities:
        if not indices:
            continue

        subset_preds = preds_arr[indices]
        subset_targets = targets_arr[indices]
        subset_probs = probs_arr[indices]

        accuracy = accuracy_score(subset_targets, subset_preds)

        # ROC / PR metrics are undefined when the subset contains a single class.
        has_both_classes = np.unique(subset_targets).size > 1
        if has_both_classes:
            fpr_sub, tpr_sub, _ = roc_curve(subset_targets, subset_probs)
            subset_auc = f"{auc(fpr_sub, tpr_sub):.4f}"
            subset_ap = f"{average_precision_score(subset_targets, subset_probs):.4f}"
        else:
            subset_auc = "N/A"
            subset_ap = "N/A"

        print(f"\n📊 {name} Images:")
        print(f"   Sample Count: {len(indices)}")
        print(f"   Accuracy: {accuracy:.4f}")
        print(f"   AUROC: {subset_auc}")
        print(f"   Average Precision: {subset_ap}")

        if has_both_classes:
            print("   Classification Breakdown:")
            # zero_division=0 yields the same 0.0 sklearn substitutes by default,
            # without emitting UndefinedMetricWarning when a class is never predicted.
            report = classification_report(subset_targets, subset_preds,
                                           target_names=['Normal', 'Stroke'],
                                           output_dict=True,
                                           zero_division=0)
            for class_name in ('Normal', 'Stroke'):
                # With output_dict=True the per-class keys are the target_names
                # verbatim ('Normal', 'Stroke'). A lower-cased lookup never
                # matches and would silently skip this breakdown.
                metrics = report.get(class_name)
                if metrics is None:
                    continue
                print(f"     {class_name}: Precision={metrics['precision']:.3f}, "
                      f"Recall={metrics['recall']:.3f}, F1={metrics['f1-score']:.3f}")

        results.append({
            'modality': name,
            'count': len(indices),
            'accuracy': accuracy,
            'auroc': subset_auc,
            'ap': subset_ap,
        })

    return results

# Break validation performance down by imaging modality. all_logits is passed
# as the `probabilities` argument, i.e. the per-image scores used for ROC/PR
# metrics. NOTE(review): despite its name it is presumably already
# sigmoid-activated — confirm against the evaluation cell.
modality_results = analyze_by_modality(
    image_paths=all_image_paths,
    predictions=all_preds,
    targets=all_targets,
    probabilities=all_logits,
)


In [None]:

# ==========================================
# SECTION 14: SAVE COMPREHENSIVE RESULTS
# ==========================================

print("\n" + "="*60)
print("💾 SAVING COMPREHENSIVE RESULTS")
print("="*60)


def _infer_image_type(image_path):
    """Label an image path 'BT', 'MR' or 'Unknown' by case-insensitive keyword match.

    BT keywords are checked first, so a path matching both keyword sets is 'BT'.
    """
    upper_path = image_path.upper()
    if any(term in upper_path for term in ('BT', 'BRAIN', 'TOMOGRAPHY')):
        return 'BT'
    if any(term in upper_path for term in ('MR', 'MAGNETIC', 'RESONANCE')):
        return 'MR'
    return 'Unknown'


# Coerce evaluation outputs to NumPy arrays. The element-wise comparison and the
# `1 - probs` arithmetic below only work on arrays: with plain Python lists,
# `list == list` yields a single bool and `1 - list` raises TypeError.
# np.asarray returns existing ndarrays unchanged, so this is a no-op for them.
targets_arr = np.asarray(all_targets)
preds_arr = np.asarray(all_preds)
# NOTE(review): assumes all_logits already holds stroke probabilities in [0, 1]
# (it is written out as 'stroke_probability') — confirm in the evaluation cell.
stroke_probs = np.asarray(all_logits)

# Create detailed results DataFrame (one row per evaluated image)
results_data = {
    'image_path': all_image_paths,
    'filename': [os.path.basename(p) for p in all_image_paths],
    'true_label': targets_arr,
    'predicted_label': preds_arr,
    'stroke_probability': stroke_probs,
    'correct_prediction': (targets_arr == preds_arr).astype(int),
    # Confidence in the *predicted* class: the stroke score for stroke
    # predictions, its complement for normal predictions.
    'confidence': np.where(preds_arr == 1, stroke_probs, 1 - stroke_probs)
}

results_df = pd.DataFrame(results_data)

# Add derived columns
results_df['image_type'] = results_df['image_path'].apply(_infer_image_type)

results_df['true_class'] = results_df['true_label'].map({0: 'Normal', 1: 'Stroke'})
results_df['predicted_class'] = results_df['predicted_label'].map({0: 'Normal', 1: 'Stroke'})

# Save main results
results_df.to_csv('stroke_classification_results.csv', index=False)
print("✅ Detailed results saved to 'stroke_classification_results.csv'")

# Create summary statistics (overall_accuracy, roc_auc, ap_score and the
# confusion-matrix counts tp/tn/fp/fn come from the earlier evaluation cell)
summary_stats = {
    'metric': ['Total Images', 'Overall Accuracy', 'AUROC', 'Average Precision',
               'True Positives', 'True Negatives', 'False Positives', 'False Negatives'],
    'value': [len(results_df), f"{overall_accuracy:.4f}", f"{roc_auc:.4f}", f"{ap_score:.4f}",
              int(tp), int(tn), int(fp), int(fn)]
}

summary_df = pd.DataFrame(summary_stats)
summary_df.to_csv('summary_statistics.csv', index=False)
print("✅ Summary statistics saved to 'summary_statistics.csv'")


In [None]:

# ==========================================
# SECTION 15: FINAL SUMMARY REPORT
# ==========================================

# Print the closing report: headline metrics, confusion-matrix counts,
# per-class accuracy/confidence and per-modality accuracy/confidence.
# Relies on results_df, overall_accuracy, roc_auc, ap_score and tp/tn/fp/fn
# defined by earlier cells.
_banner = "=" * 60
print("\n" + _banner)
print("📋 FINAL COMPREHENSIVE REPORT")
print(_banner)

print(f"""
🎯 MODEL PERFORMANCE SUMMARY:
   • Total Validation Images: {len(results_df):,}
   • Overall Accuracy: {overall_accuracy:.4f} ({overall_accuracy*100:.2f}%)
   • AUROC Score: {roc_auc:.4f}
   • Average Precision: {ap_score:.4f}

📊 CONFUSION MATRIX:
   • True Positives (Stroke correctly identified): {tp}
   • True Negatives (Normal correctly identified): {tn}
   • False Positives (Normal misclassified as Stroke): {fp}
   • False Negatives (Stroke misclassified as Normal): {fn}

📈 CLASS-WISE PERFORMANCE:
""")

# Per true class: how many images, how many were right, accuracy, mean confidence.
_class_agg_spec = {
    'correct_prediction': ['count', 'sum', 'mean'],
    'confidence': 'mean'
}
class_breakdown = results_df.groupby('true_class').agg(_class_agg_spec).round(4)

# Scalar .loc lookups (row, column) keep each column's own dtype, so counts
# print as integers rather than being upcast to float with the row.
_present_classes = [c for c in ('Normal', 'Stroke') if c in class_breakdown.index]
for cls_name in _present_classes:
    n_seen = class_breakdown.loc[cls_name, ('correct_prediction', 'count')]
    n_right = class_breakdown.loc[cls_name, ('correct_prediction', 'sum')]
    cls_acc = class_breakdown.loc[cls_name, ('correct_prediction', 'mean')]
    cls_conf = class_breakdown.loc[cls_name, ('confidence', 'mean')]

    print(f"   • {cls_name}: {n_right}/{n_seen} correct ({cls_acc:.4f} accuracy)")
    print(f"     Average Confidence: {cls_conf:.4f}")

print("\n🔬 MODALITY BREAKDOWN:\n")

# Per image_type ('BT' / 'MR' / 'Unknown'): image count, accuracy, mean confidence.
modality_breakdown = (
    results_df
    .groupby('image_type')
    .agg({'correct_prediction': ['count', 'mean'], 'confidence': 'mean'})
    .round(4)
)

for modality_name in modality_breakdown.index:
    n_images = modality_breakdown.loc[modality_name, ('correct_prediction', 'count')]
    mod_acc = modality_breakdown.loc[modality_name, ('correct_prediction', 'mean')]
    mod_conf = modality_breakdown.loc[modality_name, ('confidence', 'mean')]

    print(f"   • {modality_name}: {n_images} images, {mod_acc:.4f} accuracy, {mod_conf:.4f} avg confidence")

_saved_files_lines = [
    "",
    "💾 SAVED FILES:",
    "   • stroke_classification_results.csv - Detailed per-image results",
    "   • summary_statistics.csv - Overall performance metrics",
    "   • model_epoch_*.pth - Model checkpoints",
    "",
    "✅ COMPLETE DATASET PROCESSING FINISHED!",
    "",
]
print("\n".join(_saved_files_lines))

print("🎉 All sections completed successfully!")