In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here are several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# Running this cell (by clicking run or pressing Shift+Enter) walks the input directory and
# prints the first few files of every folder, followed by a count of the files it left out.
# The listing is capped per folder so that large image datasets do not flood the cell
# output with thousands of lines.

import os

MAX_FILES_LISTED_PER_DIR = 5  # raise this to see more file names per directory

for dirname, _, filenames in os.walk('/kaggle/input'):
    if not filenames:
        continue
    # Sort so the preview is the same on every run (os.walk order is arbitrary)
    filenames = sorted(filenames)
    for filename in filenames[:MAX_FILES_LISTED_PER_DIR]:
        print(os.path.join(dirname, filename))
    hidden_count = len(filenames) - MAX_FILES_LISTED_PER_DIR
    if hidden_count > 0:
        print(f"  ... and {hidden_count} more file(s) in {dirname}")

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import os
import time
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from PIL import Image
from sklearn.metrics import (
    roc_auc_score, average_precision_score, accuracy_score, 
    f1_score, classification_report, confusion_matrix, precision_recall_curve, roc_curve
)
import matplotlib.pyplot as plt
import seaborn as sns
import json
from collections import defaultdict
import warnings
# Silence every warning category for the rest of the session
warnings.filterwarnings(action='ignore')

# Pick the compute device: the first CUDA GPU when one is visible, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# -----------------------------
# Define CBAM Module
# -----------------------------
class ChannelAttention(nn.Module):
    """Channel-attention branch of CBAM.

    The input feature map is squeezed to 1x1 with both average and max
    pooling. Each pooled descriptor goes through the same two-layer 1x1-conv
    bottleneck (fc1 -> ReLU -> fc2); the two results are summed and passed
    through a sigmoid to give one gate per channel.

    Args:
        in_planes: Number of channels of the incoming feature map.
        ratio: Reduction factor of the bottleneck. The hidden width is
            ``in_planes // ratio``, clamped to at least 1 so that narrow
            inputs (``in_planes < ratio``) do not produce an empty layer.
    """

    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        # Without the clamp, in_planes < ratio would yield a 0-channel bottleneck.
        # For in_planes >= ratio the width (and thus checkpoint shapes) is unchanged.
        hidden_planes = max(in_planes // ratio, 1)

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc1 = nn.Conv2d(in_planes, hidden_planes, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(hidden_planes, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return per-channel gates for ``x`` (1x1 spatially, ``in_planes`` channels)."""
        # The bottleneck weights are shared between the two pooled descriptors
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        out = avg_out + max_out
        return self.sigmoid(out)

class SpatialAttention(nn.Module):
    """Spatial-attention branch of CBAM.

    Collapses the channel axis into two maps (per-pixel mean and per-pixel
    max), stacks them, and runs a single bias-free convolution followed by a
    sigmoid to obtain one gate per spatial location.

    Args:
        kernel_size: Side length of the convolution kernel. Padding is
            ``kernel_size // 2``, which keeps the spatial size for odd kernels.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        half_kernel = kernel_size // 2
        self.conv1 = nn.Conv2d(
            2,
            1,
            kernel_size,
            padding=half_kernel,
            bias=False,
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return a single-channel attention map computed from ``x``."""
        channel_mean = x.mean(dim=1, keepdim=True)
        channel_max = x.max(dim=1, keepdim=True).values
        # Order matters: conv1's first input channel is the mean, the second the max
        pooled = torch.cat((channel_mean, channel_max), dim=1)
        return self.sigmoid(self.conv1(pooled))

class CBAM(nn.Module):
    """Convolutional Block Attention Module: channel gating, then spatial gating.

    Args:
        in_planes: Channel count of the feature map being refined.
        ratio: Bottleneck reduction factor forwarded to ``ChannelAttention``.
        kernel_size: Convolution kernel size forwarded to ``SpatialAttention``.
    """

    def __init__(self, in_planes, ratio=16, kernel_size=7):
        super().__init__()
        self.ca = ChannelAttention(in_planes, ratio)
        self.sa = SpatialAttention(kernel_size)

    def forward(self, x):
        """Scale ``x`` by its channel gates, then scale that result by its spatial gates."""
        channel_refined = self.ca(x) * x
        # The spatial map is computed from the channel-refined features, not the raw input
        return self.sa(channel_refined) * channel_refined

In [None]:
# -----------------------------
# Load PCam-trained DenseNet121-CBAM Model
# -----------------------------
def load_pcam_model(model_path, device):
    """Build a DenseNet121 with a CBAM after every dense block and load its weights.

    The torchvision DenseNet121 is created without pretrained weights, a CBAM
    module is registered on ``model.features`` after each of the four dense
    blocks, the feature extractor's ``forward`` is replaced so those modules
    are actually applied, and the classifier is swapped for a single-logit
    linear layer. The checkpoint is then loaded with ``strict=False``.

    Args:
        model_path: Path to the checkpoint file. It may be a bare state dict or
            a dict wrapping one under ``'state_dict'`` or ``'model_state_dict'``.
        device: Device the checkpoint is mapped to and the model is moved to.

    Returns:
        The model on ``device``, in eval mode.

    Raises:
        RuntimeError: If no tensor of the checkpoint matches the model
            (a randomly initialised model would otherwise be returned), or if
            a matching key has an incompatible shape.
        Exception: Any error raised while reading the checkpoint is printed
            and re-raised.
    """
    model = models.densenet121(weights=None)

    # Channel count at the OUTPUT of each dense block, i.e. the width each CBAM
    # must match (the CBAM runs before the following transition layer).
    cbam_positions = {
        'denseblock1': 256,   # After first dense block
        'denseblock2': 512,   # After second dense block
        'denseblock3': 1024,  # After third dense block
        'denseblock4': 1024   # After fourth dense block
    }

    # Register the CBAM modules on the feature extractor so their parameters
    # live under keys such as 'features.denseblock1_cbam.*' in the state dict.
    for name, channels in cbam_positions.items():
        setattr(model.features, f'{name}_cbam', CBAM(channels))

    def new_forward(x):
        """Feature-extractor forward pass with a CBAM applied after each dense block."""
        features = model.features
        x = features.conv0(x)
        x = features.norm0(x)
        x = features.relu0(x)
        x = features.pool0(x)

        x = features.denseblock1_cbam(features.denseblock1(x))
        x = features.transition1(x)

        x = features.denseblock2_cbam(features.denseblock2(x))
        x = features.transition2(x)

        x = features.denseblock3_cbam(features.denseblock3(x))
        x = features.transition3(x)

        x = features.denseblock4_cbam(features.denseblock4(x))

        return features.norm5(x)

    # The default forward would also call the CBAM modules, but in registration
    # order (i.e. after norm5), so it has to be replaced.
    model.features.forward = new_forward

    # Modify classifier for binary classification (single logit)
    num_ftrs = model.classifier.in_features
    model.classifier = nn.Linear(num_ftrs, 1)

    try:
        # NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
        checkpoint = torch.load(model_path, map_location=device)

        # Unwrap the common "training checkpoint" layouts
        state_dict = checkpoint
        if isinstance(checkpoint, dict):
            for wrapper_key in ('state_dict', 'model_state_dict'):
                if wrapper_key in checkpoint:
                    state_dict = checkpoint[wrapper_key]
                    break

        # Remove the 'module.' prefix that nn.DataParallel adds to every key
        new_state_dict = {
            (k[7:] if k.startswith('module.') else k): v
            for k, v in state_dict.items()
        }

        # strict=False tolerates architectural differences, so the outcome has
        # to be inspected explicitly instead of assuming everything was loaded.
        load_result = model.load_state_dict(new_state_dict, strict=False)
        missing_keys = list(load_result.missing_keys)
        unexpected_keys = list(load_result.unexpected_keys)
        total_keys = len(model.state_dict())
        matched_keys = total_keys - len(missing_keys)

        if matched_keys == 0:
            raise RuntimeError(
                f"No checkpoint tensor matched the model "
                f"({len(unexpected_keys)} unexpected key(s), e.g. {unexpected_keys[:3]})"
            )
        if missing_keys:
            print(f"⚠️  {len(missing_keys)}/{total_keys} model tensors were NOT found in the "
                  f"checkpoint and keep their random initialisation, e.g. {missing_keys[:3]}")
        if unexpected_keys:
            print(f"⚠️  {len(unexpected_keys)} checkpoint tensors were ignored, "
                  f"e.g. {unexpected_keys[:3]}")

        print("✅ Model loaded successfully with CBAM integration")

    except Exception as e:
        print(f"Error loading model: {e}")
        raise

    model = model.to(device)
    model.eval()
    return model

In [None]:
# -----------------------------
# Enhanced BreakHis Dataset Class
# -----------------------------
class BreakHisDataset(Dataset):
    """Image dataset over a BreakHis-style directory tree.

    Expected layout below ``root_dir``::

        <class>/[SOB/]<tumor_type>/<patient>/<magnification>/<image files>

    where ``<class>`` is ``benign`` (label 0) or ``malignant`` (label 1) and the
    ``SOB`` level is optional. Every directory level is traversed in sorted
    order, so sample indices are reproducible across runs and filesystems.

    Args:
        root_dir: Directory containing the ``benign`` / ``malignant`` folders.
        transform: Optional callable applied to each loaded RGB PIL image.
        magnification: If given, only magnification folders with exactly this
            name are scanned; otherwise all of them are.

    Attributes are filled during construction:
        samples        list of ``(image_path, label)`` tuples
        metadata       list of dicts (path, class, tumor_type, patient,
                       magnification, filename), parallel to ``samples``
        magnifications set of magnification folder names that were scanned
        tumor_types    ``"<class>_<tumor_type>"`` -> number of images
    """

    # File extensions (compared case-insensitively) that count as images
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, transform=None, magnification=None):
        self.samples = []
        self.metadata = []
        self.transform = transform
        self.class_to_idx = {'benign': 0, 'malignant': 1}
        self.magnifications = set()
        self.tumor_types = defaultdict(int)

        print(f"Scanning dataset directory: {root_dir}")

        for class_name in ['benign', 'malignant']:
            class_dir = os.path.join(root_dir, class_name)

            if not os.path.isdir(class_dir):
                print(f"  ⚠️  Class directory not found: {class_dir}")
                continue

            print(f"  ✅ Found {class_name} directory")

            # Some copies of the dataset have an extra "SOB" level under the class folder
            sob_dir = os.path.join(class_dir, "SOB")
            tumor_base_dir = sob_dir if os.path.isdir(sob_dir) else class_dir

            self._scan_class(tumor_base_dir, class_name, magnification)

        self._print_statistics()

    @staticmethod
    def _sorted_subdirs(path):
        """Return the names of the sub-directories of ``path`` in sorted order."""
        return sorted(
            name for name in os.listdir(path)
            if os.path.isdir(os.path.join(path, name))
        )

    def _scan_class(self, tumor_base_dir, class_name, magnification):
        """Collect every image below ``tumor_base_dir`` as a sample of ``class_name``."""
        label = self.class_to_idx[class_name]

        for tumor_type_dir in self._sorted_subdirs(tumor_base_dir):
            tumor_type_path = os.path.join(tumor_base_dir, tumor_type_dir)

            for patient_dir in self._sorted_subdirs(tumor_type_path):
                patient_path = os.path.join(tumor_type_path, patient_dir)

                for mag in self._sorted_subdirs(patient_path):
                    if magnification and mag != magnification:
                        continue

                    # Recorded even when the folder turns out to hold no images
                    self.magnifications.add(mag)
                    mag_dir = os.path.join(patient_path, mag)

                    for fname in sorted(os.listdir(mag_dir)):
                        if not fname.lower().endswith(self.IMAGE_EXTENSIONS):
                            continue

                        img_path = os.path.join(mag_dir, fname)
                        self.samples.append((img_path, label))
                        self.metadata.append({
                            'path': img_path,
                            'class': class_name,
                            'tumor_type': tumor_type_dir,
                            'patient': patient_dir,
                            'magnification': mag,
                            'filename': fname
                        })
                        self.tumor_types[f"{class_name}_{tumor_type_dir}"] += 1

    def _print_statistics(self):
        """Print the image count, magnifications and per-tumor-type distribution."""
        print("\n📊 Dataset Statistics:")
        print(f"   Total images: {len(self.samples)}")
        if len(self.samples) > 0:
            print(f"   Magnifications: {sorted(self.magnifications)}")
            print("   Tumor type distribution:")
            for tumor_type, count in sorted(self.tumor_types.items()):
                print(f"     {tumor_type}: {count}")
        else:
            print("   ❌ No images found!")

    def __len__(self):
        """Return the number of collected images."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label, idx)`` for sample ``idx``.

        The image is converted to RGB and passed through ``transform`` when one
        was given. If loading or transforming fails, the error is printed and a
        blank 128x128 placeholder is returned with the sample's real label.
        """
        img_path, label = self.samples[idx]
        try:
            # The context manager closes the underlying file as soon as the
            # pixel data has been copied by convert()
            with Image.open(img_path) as raw_image:
                image = raw_image.convert('RGB')

            if self.transform:
                image = self.transform(image)

            return image, label, idx
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            # NOTE(review): the placeholder keeps the true label, so unreadable
            # images still count towards any metric computed downstream, and its
            # 128x128 size is hard-coded rather than derived from the transform.
            dummy_image = torch.zeros(3, 128, 128) if self.transform else Image.new('RGB', (128, 128))
            return dummy_image, label, idx


In [None]:
# -----------------------------
# Evaluation Functions
# -----------------------------
def _magnification_sort_key(mag):
    """Sort key that orders names such as '40X' < '100X' < '400X' by their number."""
    digits = ''.join(ch for ch in str(mag) if ch.isdigit())
    return (int(digits) if digits else float('inf'), str(mag))


def evaluate_by_magnification(model, dataset, device, batch_size=32, num_workers=2):
    """Evaluate model performance by magnification level.

    For every magnification in ``dataset.magnifications`` the matching samples
    (selected through ``dataset.metadata``) are run through ``model`` and
    binary-classification metrics are computed with a 0.5 threshold on the
    sigmoid of the model output.

    Args:
        model: Network returning one logit per image. It is called as-is; the
            caller is responsible for putting it in eval mode.
        dataset: Dataset exposing ``magnifications`` and ``metadata`` and
            yielding ``(image, label, idx)`` items.
        device: Device the image batches are moved to.
        batch_size: Batch size of the per-magnification loader.
        num_workers: Worker processes of the per-magnification loader.

    Returns:
        Dict mapping magnification -> dict with ``roc_auc``, ``pr_auc``,
        ``accuracy``, ``f1_score``, ``sample_count``, ``predictions``,
        ``labels`` and ``probabilities``. ``roc_auc`` and ``pr_auc`` are NaN
        for a magnification whose samples all share one class.
    """
    results_by_mag = {}

    for mag in sorted(dataset.magnifications, key=_magnification_sort_key):
        print(f"\nEvaluating magnification {mag}...")

        # Create subset for this magnification
        mag_indices = [i for i, meta in enumerate(dataset.metadata)
                       if meta['magnification'] == mag]

        if not mag_indices:
            print(f"  No samples found for magnification {mag}")
            continue

        mag_dataset = torch.utils.data.Subset(dataset, mag_indices)
        mag_loader = DataLoader(mag_dataset, batch_size=batch_size,
                                shuffle=False, num_workers=num_workers, pin_memory=True)

        all_preds = []
        all_labels = []
        all_probs = []

        with torch.no_grad():
            for images, labels, _ in mag_loader:
                images = images.to(device, non_blocking=True)

                outputs = model(images)
                probabilities = torch.sigmoid(outputs).cpu().numpy().flatten()
                predictions = (probabilities > 0.5).astype(int)

                all_probs.extend(probabilities)
                all_preds.extend(predictions)
                # Labels are only needed on the host, so they never go to the device
                all_labels.extend(labels.float().numpy())

        all_probs = np.array(all_probs)
        all_preds = np.array(all_preds)
        all_labels = np.array(all_labels)

        # Ranking metrics need both classes; roc_auc_score raises otherwise
        if np.unique(all_labels).size < 2:
            print("  ⚠️  Only one class present - ROC-AUC / PR-AUC are undefined")
            roc_auc = float('nan')
            pr_auc = float('nan')
        else:
            roc_auc = roc_auc_score(all_labels, all_probs)
            pr_auc = average_precision_score(all_labels, all_probs)

        results_by_mag[mag] = {
            'roc_auc': roc_auc,
            'pr_auc': pr_auc,
            'accuracy': accuracy_score(all_labels, all_preds),
            'f1_score': f1_score(all_labels, all_preds),
            'sample_count': len(all_labels),
            'predictions': all_preds,
            'labels': all_labels,
            'probabilities': all_probs
        }

        print(f"  Accuracy: {results_by_mag[mag]['accuracy']:.4f}")
        print(f"  ROC-AUC: {results_by_mag[mag]['roc_auc']:.4f}")
        print(f"  Samples: {results_by_mag[mag]['sample_count']}")

    return results_by_mag

In [None]:
# -----------------------------
# Main Evaluation Script
# -----------------------------
def main(
    model_path="/kaggle/input/densenet121-cbam/pytorch/default/1/densenet121_chunked.pth",
    breakhis_root="/kaggle/input/breakhis/BreaKHis_v1/BreaKHis_v1/histology_slides/breast",
    batch_size=32,
    num_workers=2,
):
    """Evaluate the PCam-trained DenseNet121-CBAM model on BreakHis.

    Loads the model and the dataset, runs inference over every image, prints
    overall metrics and the confusion matrix, then prints a per-magnification
    breakdown. Every failure (model cannot be loaded, dataset path missing,
    no images, error during inference) is reported with a message and ends the
    function early instead of raising.

    Args:
        model_path: Checkpoint passed to ``load_pcam_model``.
        breakhis_root: Directory containing the ``benign`` / ``malignant`` folders.
        batch_size: Batch size used for inference.
        num_workers: Worker processes of the overall-evaluation data loader.
    """

    def mag_sort_key(mag):
        """Order '40X' < '100X' < '400X' by number instead of alphabetically."""
        digits = ''.join(ch for ch in str(mag) if ch.isdigit())
        return (int(digits) if digits else float('inf'), str(mag))

    # Data preprocessing: 128x128 input with ImageNet mean/std normalisation
    transform = transforms.Compose([
        transforms.Resize((128, 128)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # Load model
    print("Loading PCam-trained DenseNet121-CBAM model...")

    try:
        model = load_pcam_model(model_path, device)
    except Exception as e:
        print(f"❌ Failed to load model: {e}")
        return

    # Load BreakHis dataset
    print(f"\nLoading BreakHis dataset from: {breakhis_root}")

    if not os.path.exists(breakhis_root):
        print(f"❌ Error: Dataset path does not exist: {breakhis_root}")
        return

    breakhis_dataset = BreakHisDataset(
        root_dir=breakhis_root,
        transform=transform,
        magnification=None  # Use all magnifications
    )

    if len(breakhis_dataset) == 0:
        print("❌ Error: No samples found in dataset!")
        return

    print(f"\n{'='*60}")
    print("CROSS-DATASET VALIDATION: PCam → BreakHis")
    print(f"{'='*60}")

    breakhis_loader = DataLoader(
        breakhis_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )

    # Overall evaluation
    start_time = time.time()
    all_preds = []
    all_labels = []
    all_probabilities = []

    print("Starting model evaluation...")

    try:
        with torch.no_grad():
            for batch_idx, (images, labels, _) in enumerate(breakhis_loader):
                images = images.to(device, non_blocking=True)

                outputs = model(images)
                probabilities = torch.sigmoid(outputs).cpu().numpy().flatten()
                predictions = (probabilities > 0.5).astype(int)

                all_probabilities.extend(probabilities)
                all_preds.extend(predictions)
                # Labels are only used on the host, so they are not sent to the device
                all_labels.extend(labels.numpy())

                if batch_idx % 50 == 0:
                    print(f"Processed batch {batch_idx}/{len(breakhis_loader)} "
                          f"({len(all_labels)} samples so far)")

    except Exception as e:
        print(f"❌ Error during evaluation: {e}")
        return

    evaluation_time = time.time() - start_time
    all_probabilities = np.array(all_probabilities)
    all_preds = np.array(all_preds)
    all_labels = np.array(all_labels)

    if len(all_labels) == 0:
        print("❌ Error: No samples were processed!")
        return

    print(f"✅ Processed {len(all_labels)} samples successfully")

    # Calculate overall metrics. Ranking metrics need both classes to be
    # present; roc_auc_score raises otherwise.
    if np.unique(all_labels).size < 2:
        print("⚠️  Only one class present - ROC-AUC / PR-AUC are undefined")
        roc_auc = float('nan')
        pr_auc = float('nan')
    else:
        roc_auc = roc_auc_score(all_labels, all_probabilities)
        pr_auc = average_precision_score(all_labels, all_probabilities)
    accuracy = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds)
    # labels=[0, 1] keeps the matrix 2x2 even if a class is never seen/predicted
    cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])

    print(f"\n{'='*40}")
    print("OVERALL RESULTS")
    print(f"{'='*40}")
    print(f"Evaluation time    : {evaluation_time:.1f} seconds")
    print(f"Total samples      : {len(all_labels)}")
    print(f"ROC-AUC           : {roc_auc:.4f}")
    print(f"PR-AUC            : {pr_auc:.4f}")
    print(f"Accuracy          : {accuracy:.4f}")
    print(f"F1-Score          : {f1:.4f}")
    print("Confusion matrix (rows = true, cols = predicted; order: benign, malignant):")
    print(cm)

    # Detailed evaluation by magnification
    print(f"\n{'='*40}")
    print("MAGNIFICATION-SPECIFIC ANALYSIS")
    print(f"{'='*40}")

    results_by_mag = evaluate_by_magnification(
        model, breakhis_dataset, device, batch_size=batch_size
    )

    # Print summary table. The first column is 8 wide so that the 7-character
    # "Overall" label lines up with the magnification rows.
    print(f"\n{'='*80}")
    print("PERFORMANCE SUMMARY BY MAGNIFICATION")
    print(f"{'='*80}")
    print(f"{'Mag':<8} {'Samples':<8} {'Accuracy':<10} {'ROC-AUC':<10} {'PR-AUC':<10} {'F1-Score':<10}")
    print(f"{'-'*80}")

    for mag in sorted(results_by_mag.keys(), key=mag_sort_key):
        r = results_by_mag[mag]
        print(f"{mag:<8} {r['sample_count']:<8} {r['accuracy']:<10.4f} "
              f"{r['roc_auc']:<10.4f} {r['pr_auc']:<10.4f} {r['f1_score']:<10.4f}")

    print(f"{'-'*80}")
    print(f"{'Overall':<8} {len(all_labels):<8} {accuracy:<10.4f} "
          f"{roc_auc:<10.4f} {pr_auc:<10.4f} {f1:<10.4f}")

    print(f"\n✅ Analysis complete!")

if __name__ == "__main__":
    main()