In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

class ImagePreprocessor:
    """Grayscale preprocessing pipeline: histogram equalization, min-max
    normalization and resizing of a single image loaded from disk.

    Attributes:
        image_path: Path the image was loaded from.
        original_img: Grayscale image as returned by ``cv2.imread``; this is
            ``None`` when the file could not be read or decoded.
        equalized_img: Set by :meth:`equalize_histogram`.
    """

    def __init__(self, image_path):
        self.image_path = image_path
        # cv2.imread does not raise on failure; it returns None instead.
        self.original_img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)

    def equalize_histogram(self):
        """Equalize the histogram of the loaded image.

        Returns:
            The equalized image (also stored on ``self.equalized_img``).

        Raises:
            FileNotFoundError: If the image could not be loaded in ``__init__``.
        """
        if self.original_img is None:
            # Fail with the offending path instead of an opaque OpenCV assertion.
            raise FileNotFoundError(
                f"Could not read image (missing or not a decodable image): {self.image_path!r}"
            )
        self.equalized_img = cv2.equalizeHist(self.original_img)
        return self.equalized_img

    def normalize_image(self, image):
        """Min-max scale ``image`` to the range [0, 1].

        A constant image has zero dynamic range; it is mapped to all zeros
        rather than dividing by zero and producing NaNs.
        """
        image = np.asarray(image)
        lowest = np.min(image)
        value_range = np.max(image) - lowest
        if value_range == 0:
            return np.zeros(image.shape, dtype=float)
        return (image - lowest) / value_range

    def resize_image(self, image, target_size):
        """Resize an image with values in [0, 1] and return it as uint8 in [0, 255].

        Args:
            image: 2-D array, expected to be scaled to [0, 1].
            target_size: Size tuple passed straight to ``PIL.Image.resize``.

        Returns:
            The resized image as a ``uint8`` NumPy array.
        """
        # Clip before the uint8 cast so out-of-range values saturate instead of wrapping.
        scaled = np.clip(np.asarray(image) * 255, 0, 255).astype(np.uint8)
        resized_img = np.array(Image.fromarray(scaled).resize(target_size))
        return resized_img

    def process_and_display(self):
        """Run every preprocessing step, plot each stage side by side and
        return the final resized image."""
        # Perform all preprocessing steps
        equalized_img = self.equalize_histogram()
        normalized_img = self.normalize_image(equalized_img)
        resized_img = self.resize_image(normalized_img, (256, 256))

        # Plot the images with headings
        images = [self.original_img, equalized_img, normalized_img, resized_img]
        titles = ['Original Image', 'Equalized Histogram', 'Normalized Image', 'Resized Image']

        plt.figure(figsize=(12, 4))
        for i, (img, title) in enumerate(zip(images, titles)):
            plt.subplot(1, 4, i + 1)
            plt.imshow(img, cmap='gray')
            plt.title(title)
            plt.axis('off')

        plt.show()

        return resized_img

# Usage: run the whole preprocessing pipeline on one training image and plot
# every intermediate stage as a visual sanity check.
# NOTE(review): the path is relative to the notebook's working directory.
image_path = 'DATASETS/split_data/train/2/9873823L.png'
preprocessor = ImagePreprocessor(image_path)
final_image = preprocessor.process_and_display()

In [None]:
import os
import cv2
import numpy as np
from tqdm import tqdm
import torch

def process_dataset(dataset_dir):
    """Preprocess every image found in numerically named label folders.

    The directory tree below ``dataset_dir`` is walked; each folder whose
    name consists only of digits is treated as a class folder and its name is
    used as the integer label for the files it contains. Every image goes
    through ``ImagePreprocessor`` (equalize -> normalize -> resize to 256x256).

    Files that OpenCV cannot decode are skipped and counted instead of
    aborting the whole run. Folders and files are visited in sorted order so
    the sample order is reproducible across runs and file systems.

    Args:
        dataset_dir: Root directory containing the label folders.

    Returns:
        ``(X_tensor, y_tensor)``: a float tensor of the stacked images and a
        long tensor of the matching labels.
    """
    # Collect the work list first so the progress bar total reflects only the
    # files that are actually processed (files outside label folders are ignored).
    samples = []
    for dirname, subdirs, filenames in os.walk(dataset_dir):
        subdirs.sort()  # in-place sort makes os.walk descend deterministically
        folder_name = os.path.basename(dirname)
        # Check if directory name is numeric, indicating a label folder
        if not folder_name.isdigit():
            continue
        label = int(folder_name)  # Folder name as the label
        samples.extend((os.path.join(dirname, filename), label) for filename in sorted(filenames))

    X = []
    y = []
    skipped = 0

    for image_path, label in tqdm(samples, unit="images", desc="Preprocessing dataset"):
        # The preprocessor loads the image itself, so no separate read is needed here.
        preprocessor = ImagePreprocessor(image_path)
        if preprocessor.original_img is None:
            # Not a decodable image (e.g. a stray metadata file) - skip it.
            skipped += 1
            continue

        equalized_image = preprocessor.equalize_histogram()
        normalized_image = preprocessor.normalize_image(equalized_image)
        resized_image = preprocessor.resize_image(normalized_image, (256, 256))

        X.append(resized_image)
        y.append(label)

    if skipped:
        print(f"Skipped {skipped} unreadable file(s) under {dataset_dir}")

    # Convert lists to NumPy arrays first (one contiguous block), then to tensors
    X_tensor = torch.tensor(np.array(X)).float()  # float tensor of images
    y_tensor = torch.tensor(np.array(y)).long()   # long tensor for labels
    return X_tensor, y_tensor

# Usage: preprocess the training split into in-memory tensors.
# process_dataset returns (images as float tensor, labels as long tensor).
dataset_dir = "DATASETS/split_data/train/"
X_train, y_train = process_dataset(dataset_dir)

In [None]:
# Preprocess the data used for validation with the same pipeline as training.
# NOTE(review): `dataset_dir` is rebound in each loading cell, so its value
# depends on which of these cells ran last.
dataset_dir = "DATASETS/OSAIL_KL_Dataset/Labeled/"
X_val, y_val = process_dataset(dataset_dir)

In [None]:
# Preprocess the held-out test split with the same pipeline as training.
dataset_dir = "DATASETS/split_data/test/"
X_test, y_test = process_dataset(dataset_dir)

In [None]:
import matplotlib.pyplot as plt

# Plot random images with labels to check preprocessing
num_samples = 5
plt.figure(figsize=(15, 3))
for i in range(num_samples):
    # Indices are drawn independently, so the same image may appear twice.
    idx = np.random.randint(0, len(X_train))
    plt.subplot(1, num_samples, i + 1)
    plt.imshow(X_train[idx], cmap='gray')
    # y_train[idx] is a 0-d tensor; convert it to a plain int so the title
    # reads "Label: 2" instead of "Label: tensor(2)".
    plt.title(f"Label: {int(y_train[idx])}")
    plt.axis('off')
plt.show()


In [None]:
import collections
import numpy as np

# Convert the PyTorch tensor to a NumPy array so its elements can be counted
y_train_np = y_train.numpy()

# Count the occurrences of each label (keys are NumPy integer scalars)
label_counts = collections.Counter(y_train_np)

# Print the label distribution to check for class imbalance
print("Label distribution:", label_counts)



In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class SquashFunction(nn.Module):
    """Capsule squashing non-linearity.

    Rescales each vector along ``dim`` so that its direction is kept while
    its length becomes ``|x|^2 / (1 + |x|^2)``, i.e. a value in [0, 1).
    """

    def forward(self, x, dim=-1):
        norm_sq = x.pow(2).sum(dim=dim, keepdim=True)
        shrink = norm_sq / (1 + norm_sq)
        # The small epsilon keeps the division finite for all-zero vectors.
        length = torch.sqrt(norm_sq + 1e-8)
        return shrink * x / length

class PrimaryCapsules(nn.Module):
    """Convolutional capsule layer.

    One convolution produces ``out_channels * dim_caps`` feature maps, which
    are regrouped so that every spatial position of every capsule channel
    yields one ``dim_caps``-dimensional capsule vector, then squashed.
    """

    def __init__(self, in_channels, out_channels, dim_caps, kernel_size, stride):
        super().__init__()
        self.dim_caps = dim_caps
        self.out_channels = out_channels
        # padding is fixed to 1 regardless of kernel_size
        self.conv = nn.Conv2d(
            in_channels,
            out_channels * dim_caps,
            kernel_size,
            stride,
            padding=1,
        )
        self.squash = SquashFunction()

    def forward(self, x):
        """Return squashed capsules shaped [batch, out_channels * H * W, dim_caps]."""
        feature_maps = self.conv(x)
        n, _, h, w = feature_maps.shape

        # [n, C*D, h, w] -> [n, C, D, h, w] -> [n, C, h, w, D] -> [n, C*h*w, D]
        grouped = feature_maps.view(n, self.out_channels, self.dim_caps, h, w)
        capsule_last = grouped.permute(0, 1, 3, 4, 2).contiguous()
        capsules = capsule_last.view(n, -1, self.dim_caps)

        return self.squash(capsules)

class CapsuleLayer(nn.Module):
    """Fully connected capsule layer with dynamic routing-by-agreement.

    Args:
        num_caps_in: Number of input capsules.
        num_caps_out: Number of output capsules.
        dim_caps_in: Dimension of each input capsule vector.
        dim_caps_out: Dimension of each output capsule vector.
        num_iterations: Number of routing iterations (must be >= 1).

    Raises:
        ValueError: If ``num_iterations`` is smaller than 1; no output would
            be computed in that case.
    """

    def __init__(self, num_caps_in, num_caps_out, dim_caps_in, dim_caps_out, num_iterations=3):
        super(CapsuleLayer, self).__init__()
        if num_iterations < 1:
            raise ValueError(f"num_iterations must be >= 1, got {num_iterations}")
        self.num_iterations = num_iterations
        self.num_caps_in = num_caps_in
        self.num_caps_out = num_caps_out

        # One [dim_caps_out x dim_caps_in] transform per (input, output) capsule pair;
        # the leading 1 broadcasts over the batch.
        self.W = nn.Parameter(torch.randn(1, num_caps_in, num_caps_out, dim_caps_out, dim_caps_in))
        self.squash = SquashFunction()

    def forward(self, u):
        """Route input capsules to output capsules.

        Args:
            u: Input capsules shaped [batch, num_caps_in, dim_caps_in].

        Returns:
            Output capsules shaped [batch, num_caps_out, dim_caps_out].
        """
        batch_size = u.size(0)
        # [B, in, d_in] -> [B, in, 1, d_in, 1] so matmul broadcasts against W
        u = u.unsqueeze(2).unsqueeze(4)
        # Prediction vectors: [B, in, out, d_out]
        u_hat = torch.matmul(self.W, u).squeeze(-1)

        # Routing logits start at zero (uniform coupling); allocate directly on the
        # input's device instead of creating on the host and copying.
        b = torch.zeros(batch_size, self.num_caps_in, self.num_caps_out, device=u.device)

        for i in range(self.num_iterations):
            # Coupling coefficients: each input capsule distributes over the outputs
            c = F.softmax(b, dim=2).unsqueeze(3)
            s = (c * u_hat).sum(dim=1)
            v = self.squash(s)
            # The last iteration's logits would never be used, so skip the update
            if i < self.num_iterations - 1:
                # Agreement = dot product between predictions and current outputs
                b = b + (u_hat * v.unsqueeze(1)).sum(dim=-1)

        return v

class DenseBlock(nn.Module):
    """Densely connected convolution block.

    Every stage (BatchNorm -> ReLU -> 3x3 Conv -> Dropout) receives the
    channel-wise concatenation of the block input and all previous stage
    outputs and contributes ``growth_rate`` new channels. The block output
    therefore has ``in_channels + num_layers * growth_rate`` channels.
    """

    def __init__(self, in_channels, num_layers, growth_rate=12):
        super().__init__()
        stages = []
        for depth in range(num_layers):
            # Input width grows by growth_rate with every preceding stage
            width = in_channels + depth * growth_rate
            stages.append(
                nn.Sequential(
                    nn.BatchNorm2d(width),
                    nn.ReLU(inplace=True),
                    nn.Conv2d(width, growth_rate, kernel_size=3, padding=1),
                    nn.Dropout(0.2),
                )
            )
        self.layers = nn.ModuleList(stages)

    def forward(self, x):
        """Return the input concatenated with the output of every stage."""
        collected = [x]
        for stage in self.layers:
            collected.append(stage(torch.cat(collected, 1)))
        return torch.cat(collected, 1)

class EnhancedMedicalCapsCNN(nn.Module):
    """CNN front end (multi-scale convs, attention, dense block) followed by
    two routed capsule layers and a linear classifier head.

    The capsule sizes are hard-coded for 256x256 inputs: the primary capsule
    layer is expected to produce a 16x16 grid (32 * 16 * 16 = 8192 capsules).

    Args:
        num_classes: Number of output classes / diagnostic capsules.
        in_channels: Number of input image channels.
    """

    def __init__(self, num_classes, in_channels=1):
        super(EnhancedMedicalCapsCNN, self).__init__()

        # Initial feature extraction (256x256 -> 64x64)
        self.init_conv = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        # Multi-scale feature extraction (64x64 -> 64x64)
        self.multi_scale = nn.ModuleList([
            nn.Conv2d(32, 16, kernel_size=k, padding=k//2)
            for k in [3, 5, 7]
        ])

        # Dense block (64x64 -> 64x64)
        self.dense_block = DenseBlock(48, num_layers=4, growth_rate=12)
        dense_out_channels = 48 + 4 * 12  # 96 channels

        # Transition layer (64x64 -> 32x32)
        self.transition = nn.Sequential(
            nn.BatchNorm2d(dense_out_channels),
            nn.Conv2d(dense_out_channels, dense_out_channels // 2, kernel_size=1),
            nn.AvgPool2d(kernel_size=2, stride=2)
        )

        transition_out_channels = dense_out_channels // 2  # 48 channels

        # Primary capsules (32x32 -> 16x16 with padding=1)
        self.primary_caps = PrimaryCapsules(
            in_channels=transition_out_channels,
            out_channels=32,
            dim_caps=8,
            kernel_size=3,
            stride=2
        )

        # Calculate primary capsules output size
        # After PrimaryCapsules: 16x16 feature maps with 32 channels
        # Total capsules = 32 * 16 * 16 = 8192
        primary_caps_size = 32 * 16 * 16  # 8192

        # Medical feature capsules
        self.medical_caps = CapsuleLayer(
            num_caps_in=primary_caps_size,
            num_caps_out=16,
            dim_caps_in=8,
            dim_caps_out=16
        )

        # Diagnostic capsules
        self.diagnostic_caps = CapsuleLayer(
            num_caps_in=16,
            num_caps_out=num_classes,
            dim_caps_in=16,
            dim_caps_out=16
        )

        # Attention mechanism: 1x1 conv + sigmoid gives one gate per spatial
        # position, applied to the 48-channel multi-scale features.
        self.attention = nn.Sequential(
            nn.Conv2d(48, 1, kernel_size=1),
            nn.Sigmoid()
        )

        # Classifier
        self.classifier = nn.Sequential(
            nn.Linear(num_classes * 16, 512),
            nn.ReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes)
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Initialise conv, batch-norm and linear layers.

        Only ``nn.Conv2d``, ``nn.BatchNorm2d`` and ``nn.Linear`` modules are
        touched; the capsule transformation matrices keep their own init.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x, return_caps=True):
        """Run the network.

        Args:
            x: Image batch shaped [batch, in_channels, 256, 256].
            return_caps: When True (default) return ``(logits, diagnostic_caps)``
                in both training and evaluation mode. Pass False to get the
                logits tensor only.

        Returns:
            ``(logits, diagnostic_caps)`` with shapes [batch, num_classes] and
            [batch, num_classes, 16], or just ``logits`` if ``return_caps`` is
            False.

        Note:
            The capsule output used to be returned only while ``self.training``
            was True. Evaluation code that computes the margin loss or predicts
            from capsule lengths needs it in eval mode as well, so the return
            structure no longer depends on the train/eval flag.
        """
        # Initial feature extraction (256x256 -> 64x64)
        x = self.init_conv(x)  # Output: [batch, 32, 64, 64]

        # Multi-scale feature extraction
        multi_scale_features = [conv(x) for conv in self.multi_scale]
        x = torch.cat(multi_scale_features, dim=1)  # Output: [batch, 48, 64, 64]

        # Apply attention
        attention = self.attention(x)
        x = x * attention

        # Dense feature extraction with transition (64x64 -> 32x32)
        x = self.dense_block(x)  # Output: [batch, 96, 64, 64]
        x = self.transition(x)   # Output: [batch, 48, 32, 32]

        # Primary capsules (32x32 -> 16x16)
        primary_caps = self.primary_caps(x)  # Output: [batch, 8192, 8]

        # Medical feature capsules
        medical_caps = self.medical_caps(primary_caps)  # Output: [batch, 16, 16]

        # Diagnostic capsules
        diagnostic_caps = self.diagnostic_caps(medical_caps)  # Output: [batch, num_classes, 16]

        # Final classification
        x = diagnostic_caps.view(diagnostic_caps.size(0), -1)
        output = self.classifier(x)

        if return_caps:
            return output, diagnostic_caps
        return output

class MarginLoss(nn.Module):
    """Capsule margin loss.

    For the target class the capsule length is pushed above ``m_pos``; for
    every other class it is pushed below ``m_neg``, with the latter term
    down-weighted by ``lambda_``.
    """

    def __init__(self, m_pos=0.9, m_neg=0.1, lambda_=0.5):
        super().__init__()
        self.m_pos = m_pos
        self.m_neg = m_neg
        self.lambda_ = lambda_

    def forward(self, caps_output, target):
        """Compute the loss.

        Args:
            caps_output: Class capsules shaped [batch, num_classes, dim].
            target: Integer class indices shaped [batch].

        Returns:
            Scalar tensor: per-sample loss summed over classes, averaged over
            the batch.
        """
        # Length of every class capsule: [batch, num_classes]
        lengths = caps_output.pow(2).sum(dim=-1).sqrt()

        # Penalty when the true class is too short / a wrong class is too long
        present_term = F.relu(self.m_pos - lengths) ** 2
        absent_term = F.relu(lengths - self.m_neg) ** 2

        # One-hot mask selecting the true class per sample
        num_classes = caps_output.size(1)
        mask = F.one_hot(target, num_classes=num_classes).float()

        per_class = mask * present_term + self.lambda_ * (1.0 - mask) * absent_term
        return per_class.sum(dim=1).mean()


def print_model_summary(model, input_size=(1, 256, 256)):
    """Print the architecture, parameter counts and an estimated size of
    ``model``, then try one forward pass on random input.

    Args:
        model: Module to inspect; it is moved to CUDA when available, else CPU.
        input_size: Shape of a single sample (without the batch dimension).

    Any exception raised by the forward pass is caught and printed rather
    than propagated.
    """
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(target_device)

    # A batch containing one random sample
    input_shape = (1, *input_size)
    dummy_input = torch.randn(input_shape).to(target_device)

    print("\nModel Architecture:")
    print(model)

    total_params = 0
    trainable_params = 0
    for param in model.parameters():
        count = param.numel()
        total_params += count
        if param.requires_grad:
            trainable_params += count

    print(f"\nTotal Parameters: {total_params:,}")
    print(f"Trainable Parameters: {trainable_params:,}")

    # 4 bytes per parameter, converted to MiB
    model_size_mb = total_params * 4 / (1024 * 1024)
    print(f"Estimated Model Size: {model_size_mb:.2f} MB")

    try:
        output = model(dummy_input)
    except Exception as e:
        print(f"\nError during forward pass: {str(e)}")
        return

    try:
        print(f"\nInput Shape: {input_shape}")
        if isinstance(output, tuple):
            print(f"Output Shapes: {[o.shape for o in output]}")
        else:
            print(f"Output Shape: {output.shape}")
        print("\nModel summary test passed successfully!")
    except Exception as e:
        print(f"\nError during forward pass: {str(e)}")

def test_model(num_classes=5):
    """Build an ``EnhancedMedicalCapsCNN``, print its summary and return it."""
    network = EnhancedMedicalCapsCNN(num_classes=num_classes)
    print_model_summary(network)
    return network


# Smoke-test the architecture when this cell / script is executed directly.
if __name__ == "__main__":
    model = test_model(num_classes=5)

In [None]:
import os
from datetime import datetime
from typing import Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from sklearn.metrics import (
    auc,
    balanced_accuracy_score,
    classification_report,
    confusion_matrix,
    roc_auc_score,
    roc_curve,
)
from torch.amp import autocast, GradScaler
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm

class CustomDataset(Dataset):
    """In-memory dataset of single-channel images and integer labels.

    Args:
        X: Images shaped [N, H, W]; a NumPy array, tensor or nested sequence
            of any numeric dtype. Stored as float32 with a channel dimension
            added: [N, 1, H, W].
        y: Labels shaped [N]; stored as int64.
        transform: Optional callable applied to each image tensor on access.
    """

    def __init__(self, X: np.ndarray, y: np.ndarray, transform=None):
        # torch.as_tensor converts dtype when needed; the legacy FloatTensor /
        # LongTensor constructors reject tensors of a different dtype.
        # When the dtype already matches, the input's memory may be shared.
        self.X = torch.as_tensor(X, dtype=torch.float32).unsqueeze(1)  # Add channel dimension
        self.y = torch.as_tensor(y, dtype=torch.long)
        self.transform = transform

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for index ``idx``."""
        image = self.X[idx]
        if self.transform is not None:
            image = self.transform(image)
        return image, self.y[idx]

class EarlyStopping:
    """Stop training once a monitored value has been worse than the best
    observed value for ``patience`` consecutive checks.

    Args:
        patience: Number of consecutive degraded checks tolerated before
            ``early_stop`` is set.
        min_delta: Tolerance; a value only counts as degraded when it is worse
            than the best value by more than ``min_delta``.
        mode: ``'min'`` when lower is better (e.g. loss), ``'max'`` when
            higher is better (e.g. accuracy).

    Attributes:
        best_loss: Best value seen so far (name kept for compatibility; in
            ``'max'`` mode it holds the highest value).
        counter: Current number of consecutive degraded checks.
        early_stop: True once patience has been exhausted.

    Raises:
        ValueError: If ``mode`` is neither ``'min'`` nor ``'max'``.
    """

    def __init__(self, patience=7, min_delta=0, mode='min'):
        if mode not in ('min', 'max'):
            raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.counter = 0
        self.best_loss = None
        self.early_stop = False

    def __call__(self, current_val):
        """Record a new observation and update ``counter`` / ``early_stop``."""
        if self.best_loss is None:
            self.best_loss = current_val
            return

        if self.mode == 'min':
            degraded = current_val > self.best_loss + self.min_delta
            best_so_far = min(self.best_loss, current_val)
        else:
            degraded = current_val < self.best_loss - self.min_delta
            best_so_far = max(self.best_loss, current_val)

        if degraded:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # Only ever move the reference value in the improving direction;
            # overwriting it with a slightly worse value would let it drift and
            # keep early stopping from ever triggering when min_delta > 0.
            self.best_loss = best_so_far
            self.counter = 0
class MedicalCapsTrainer:
    """Training / validation loop for a capsule classifier with W&B logging.

    The model may return either ``(logits, capsules)`` or logits only. When
    capsules are returned, the margin loss on the capsules is optimised and
    predictions are the class with the longest capsule; otherwise
    cross-entropy on the logits is used.

    Args:
        model: Network to train.
        config: Dict with at least ``learning_rate``, ``weight_decay``,
            ``early_stopping_patience``, ``project_name``, ``model_name`` and
            ``save_dir``; ``wandb_mode`` is optional.
        train_loader: Loader yielding ``(inputs, targets)`` training batches.
        val_loader: Loader yielding ``(inputs, targets)`` validation batches.
        margin_loss_params: Keyword arguments for ``MarginLoss``.
        optimizer: Optional optimizer; defaults to AdamW built from ``config``.
        scheduler: Optional scheduler; defaults to ``ReduceLROnPlateau``. It is
            stepped with the validation loss.
        device: Device to train on; defaults to CUDA when available.
    """

    def __init__(
        self,
        model: nn.Module,
        config: Dict,
        train_loader: DataLoader,
        val_loader: DataLoader,
        margin_loss_params: Dict = None,
        optimizer: optim.Optimizer = None,
        scheduler = None,
        device: str = None
    ):
        self.config = config
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader

        # Initialize MarginLoss with default or custom parameters
        margin_loss_params = margin_loss_params or {
            'm_pos': 0.9,
            'm_neg': 0.1,
            'lambda_': 0.5
        }
        self.criterion = MarginLoss(**margin_loss_params)

        self.optimizer = optimizer or optim.AdamW(
            model.parameters(),
            lr=config['learning_rate'],
            weight_decay=config['weight_decay']
        )
        self.scheduler = scheduler or optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='min', patience=3, factor=0.1
        )
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)

        # Mixed precision is only meaningful on CUDA; disable it explicitly
        # elsewhere so CPU runs take the plain fp32 path without warnings.
        self.use_amp = torch.device(self.device).type == 'cuda' and torch.cuda.is_available()
        self.scaler = GradScaler('cuda', enabled=self.use_amp)
        self.early_stopping = EarlyStopping(patience=config['early_stopping_patience'])

        # Initialize metrics tracking
        self.best_val_loss = float('inf')
        self.best_val_acc = 0.0

        # Setup WandB
        self.run = wandb.init(
            project=config['project_name'],
            config=config,
            name=f"{config['model_name']}_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
            mode=config.get('wandb_mode', 'online')
        )

        # Save paths
        self.save_dir = config['save_dir']
        os.makedirs(self.save_dir, exist_ok=True)

    def _loss_and_scores(self, outputs, targets):
        """Return ``(loss, class_scores)`` for a model output that is either
        ``(logits, capsules)`` or logits only."""
        if isinstance(outputs, tuple):
            logits, caps_output = outputs
        else:
            logits = outputs
            caps_output = None

        # If we have capsule output, use it for the loss
        if caps_output is not None:
            loss = self.criterion(caps_output, targets)
            class_scores = torch.sqrt((caps_output ** 2).sum(dim=-1))
        else:
            loss = nn.functional.cross_entropy(logits, targets)
            class_scores = logits
        return loss, class_scores

    def train_epoch(self) -> Tuple[float, float]:
        """Train for one epoch.

        Returns:
            ``(mean batch loss, accuracy in percent)``.
        """
        self.model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        pbar = tqdm(self.train_loader, desc='Training')
        for batch_idx, (inputs, targets) in enumerate(pbar):
            inputs, targets = inputs.to(self.device), targets.to(self.device)

            self.optimizer.zero_grad()

            # Mixed precision training (no-op when AMP is disabled)
            with autocast('cuda', enabled=self.use_amp):
                outputs = self.model(inputs)
                loss, class_outputs = self._loss_and_scores(outputs, targets)

            # Backward pass with gradient scaling
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()

            # Metrics
            running_loss += loss.item()
            _, predicted = class_outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            # Update progress bar
            pbar.set_postfix({
                'loss': running_loss/(batch_idx+1),
                'acc': 100.*correct/total
            })

        epoch_loss = running_loss / len(self.train_loader)
        epoch_acc = 100. * correct / total
        return epoch_loss, epoch_acc

    def validate_epoch(self) -> Tuple[float, float]:
        """Evaluate on the validation loader and log a confusion matrix.

        Returns:
            ``(mean batch loss, accuracy in percent)``.

        NOTE(review): if the model returns logits only in eval mode, this
        falls back to cross-entropy on the logits, which is not the quantity
        optimised during training when capsules are available there.
        """
        self.model.eval()
        running_loss = 0.0
        correct = 0
        total = 0
        all_preds = []
        all_targets = []

        with torch.no_grad():
            pbar = tqdm(self.val_loader, desc='Validation')
            for batch_idx, (inputs, targets) in enumerate(pbar):
                inputs, targets = inputs.to(self.device), targets.to(self.device)

                outputs = self.model(inputs)
                loss, class_outputs = self._loss_and_scores(outputs, targets)

                running_loss += loss.item()
                _, predicted = class_outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

                all_preds.extend(predicted.cpu().numpy())
                all_targets.extend(targets.cpu().numpy())

                pbar.set_postfix({
                    'loss': running_loss/(batch_idx+1),
                    'acc': 100.*correct/total
                })

        epoch_loss = running_loss / len(self.val_loader)
        epoch_acc = 100. * correct / total

        # Log confusion matrix
        cm = confusion_matrix(all_targets, all_preds)
        plt.figure(figsize=(10,8))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues')
        plt.title('Validation Confusion Matrix')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        wandb.log({"confusion_matrix": wandb.Image(plt)})
        plt.close()

        return epoch_loss, epoch_acc

    def train(self, epochs: int):
        """Run up to ``epochs`` epochs with LR scheduling, checkpointing of
        the best model (by validation accuracy) and early stopping (on
        validation loss). The W&B run is finished even if training fails."""
        try:
            for epoch in range(epochs):
                print(f'\nEpoch {epoch+1}/{epochs}')

                # Training phase
                train_loss, train_acc = self.train_epoch()

                # Validation phase
                val_loss, val_acc = self.validate_epoch()

                # Learning rate scheduling
                self.scheduler.step(val_loss)
                current_lr = self.optimizer.param_groups[0]['lr']

                # Logging
                wandb.log({
                    "train_loss": train_loss,
                    "train_acc": train_acc,
                    "val_loss": val_loss,
                    "val_acc": val_acc,
                    "learning_rate": current_lr
                })

                # Track the best validation loss so checkpoints record a real
                # value instead of the initial infinity.
                if val_loss < self.best_val_loss:
                    self.best_val_loss = val_loss

                # Save best model
                if val_acc > self.best_val_acc:
                    self.best_val_acc = val_acc
                    self.save_model('best_model.pth')

                # Early stopping check
                self.early_stopping(val_loss)
                if self.early_stopping.early_stop:
                    print("Early stopping triggered")
                    break

            # Save final model
            self.save_model('final_model.pth')
        finally:
            self.run.finish()

    def save_model(self, filename: str):
        """Save model with config and metrics"""
        save_path = os.path.join(self.save_dir, filename)
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'config': self.config,
            'best_val_acc': self.best_val_acc,
            'best_val_loss': self.best_val_loss
        }, save_path)
        wandb.save(save_path)

# Example usage
def main():
    """Build the datasets, loaders, model and trainer from the notebook's
    ``X_train`` / ``y_train`` / ``X_val`` / ``y_val`` tensors and train."""
    # Configuration
    config = dict(
        project_name='medical_capsule_classification',
        model_name='EnhancedMedicalCapsCNN',
        learning_rate=1e-4,
        weight_decay=1e-4,
        batch_size=16,
        early_stopping_patience=10,
        save_dir='./models',
        num_epochs=100,
    )

    # Settings shared by both loaders; only shuffling differs between them
    loader_options = dict(
        batch_size=config['batch_size'],
        num_workers=4,
        pin_memory=True,
    )

    # Create datasets and dataloaders
    train_dataset = CustomDataset(X_train, y_train)
    val_dataset = CustomDataset(X_val, y_val)
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_options)

    # Initialize model and trainer; the class count is the number of distinct
    # labels present in the training set
    num_classes = len(np.unique(y_train))
    model = EnhancedMedicalCapsCNN(num_classes=num_classes)
    trainer = MedicalCapsTrainer(
        model=model,
        config=config,
        train_loader=train_loader,
        val_loader=val_loader,
    )

    # Train model
    trainer.train(epochs=config['num_epochs'])


if __name__ == "__main__":
    main()

In [None]:
def evaluate_model(
    model: nn.Module,
    test_loader: DataLoader,
    device: str = None,
    save_path: str = None,
    class_names: Optional[List[str]] = None
) -> Tuple[pd.DataFrame, np.ndarray]:
    """
    Evaluate capsule network model on test set and generate detailed metrics.

    The model may return ``(logits, capsules)`` or logits only. With capsules,
    class scores are the capsule lengths; otherwise the logits are used.
    Scores are turned into pseudo-probabilities with a softmax for ROC AUC.

    Args:
        model: The capsule network model
        test_loader: DataLoader for test data
        device: Computing device (cuda/cpu)
        save_path: Directory to save evaluation results; when given, the
            confusion matrices, the report CSV and ROC curves are written there
        class_names: List of class names for better visualization

    Returns:
        Tuple of (classification report DataFrame, confusion matrix)
    """
    device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    model.eval()

    all_preds = []
    all_targets = []
    all_probs = []

    with torch.no_grad():
        for inputs, targets in tqdm(test_loader, desc='Testing'):
            inputs, targets = inputs.to(device), targets.to(device)

            # Get model outputs; accept both return conventions instead of
            # unpacking blindly (a bare logits tensor would be split along
            # its batch dimension).
            outputs = model(inputs)
            if isinstance(outputs, tuple):
                _, caps_output = outputs
                # Calculate capsule magnitudes for classification
                class_outputs = torch.sqrt((caps_output ** 2).sum(dim=-1))
            else:
                class_outputs = outputs
            probs = torch.softmax(class_outputs, dim=1)

            _, predicted = class_outputs.max(1)

            all_preds.extend(predicted.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())

    # Convert to numpy arrays
    all_preds = np.array(all_preds)
    all_targets = np.array(all_targets)
    all_probs = np.array(all_probs)

    # Generate classification report
    report = classification_report(
        all_targets,
        all_preds,
        output_dict=True,
        target_names=class_names if class_names else None
    )
    report_df = pd.DataFrame(report).transpose()

    # Calculate and add additional metrics (scalars broadcast to every row)
    report_df['balanced_accuracy'] = balanced_accuracy_score(all_targets, all_preds)
    report_df['roc_auc'] = calculate_multiclass_roc_auc(all_targets, all_probs)

    cm = confusion_matrix(all_targets, all_preds)

    # Normalize confusion matrix for percentage view; rows without any true
    # samples are left at zero instead of dividing by zero.
    row_totals = cm.sum(axis=1)[:, np.newaxis]
    cm_normalized = cm.astype('float') / np.where(row_totals == 0, 1, row_totals)

    # Create confusion matrix plot
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))

    panels = (
        (ax1, cm, 'd', 'Test Set Confusion Matrix (Counts)'),               # Raw counts
        (ax2, cm_normalized, '.2%', 'Test Set Confusion Matrix (Normalized)'),  # Percentages
    )
    for ax, matrix, fmt, title in panels:
        sns.heatmap(matrix, annot=True, fmt=fmt, cmap='Blues', ax=ax)
        ax.set_title(title)
        ax.set_ylabel('True Label')
        ax.set_xlabel('Predicted Label')
        if class_names:
            ax.set_xticklabels(class_names, rotation=45)
            ax.set_yticklabels(class_names, rotation=45)

    plt.tight_layout()

    if save_path:
        # Create directory if it doesn't exist
        os.makedirs(save_path, exist_ok=True)

        # Save results; save through the figure handles so the files do not
        # depend on which figure happens to be current.
        fig.savefig(os.path.join(save_path, 'confusion_matrices.png'))
        report_df.to_csv(os.path.join(save_path, 'classification_report.csv'))

        # Save ROC curves
        roc_fig = plot_roc_curves(all_targets, all_probs, class_names)
        roc_fig.savefig(os.path.join(save_path, 'roc_curves.png'))

    plt.show()
    print("\nClassification Report:")
    print(report_df)

    return report_df, cm

def calculate_multiclass_roc_auc(y_true, y_prob):
    """Return the ROC AUC for binary or multiclass predictions.

    With exactly two probability columns the second column is scored as the
    positive class; otherwise a one-vs-rest AUC over all columns is computed.
    """
    is_binary = y_prob.shape[1] == 2
    if is_binary:
        return roc_auc_score(y_true, y_prob[:, 1])
    return roc_auc_score(y_true, y_prob, multi_class='ovr')

def plot_roc_curves(y_true, y_prob, class_names=None):
    """Draw one-vs-rest ROC curves for every class on a new figure.

    Args:
        y_true: Array of integer class labels.
        y_prob: Array shaped [n_samples, n_classes] of class scores.
        class_names: Optional labels for the legend; defaults to
            ``Class 0``, ``Class 1``, ...

    Returns:
        The matplotlib figure holding the curves.
    """
    n_classes = y_prob.shape[1]

    # Fall back to generic legend labels when none were supplied
    if class_names is None:
        class_names = [f'Class {i}' for i in range(n_classes)]

    plt.figure(figsize=(10, 8))

    for cls in range(n_classes):
        # One-vs-rest: the current class is positive, all others negative
        is_cls = (y_true == cls).astype(int)
        fpr, tpr, _ = roc_curve(is_cls, y_prob[:, cls])
        roc_auc = auc(fpr, tpr)
        legend_entry = f'{class_names[cls]} (AUC = {roc_auc:.2f})'
        plt.plot(fpr, tpr, label=legend_entry)

    # Chance-level reference line
    plt.plot([0, 1], [0, 1], 'k--')

    # Axes, labels and legend
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC) Curves')
    plt.legend(loc="lower right")

    return plt.gcf()

In [None]:
# Create test dataset and loader
test_dataset = CustomDataset(X_test, y_test)
test_loader = DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=1
)
model = EnhancedMedicalCapsCNN(num_classes=5)
# Load best model; map to CPU first so a checkpoint saved on a GPU machine
# also loads on a CPU-only one (evaluate_model moves the model to its device).
checkpoint = torch.load('./models/best_model.pth', map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])

# Evaluate
# The matrix is bound to its own name: assigning it to `confusion_matrix`
# would shadow sklearn's function and break every later call that uses it.
report_df, test_confusion_matrix = evaluate_model(
    model,
    test_loader,
    save_path='./results'
)