# 10-Fold Cross-Validation with a Pretrained ResNet18 + CBAM Attention

In [None]:
import os
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import datasets, models, transforms
from PIL import Image
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import confusion_matrix, classification_report, precision_score, recall_score, f1_score
from tqdm import tqdm
from timm import create_model
from sklearn.model_selection import train_test_split
from timeit import default_timer as timer
# Pick the compute device: CUDA when a GPU is visible to PyTorch, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

### Three Classes of Cardiac Ultrasound (US) Images: HCM, HTN, Normal

In [None]:
# Set a random seed for reproducibility
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

# Root of the dataset. Layout expected by the loop below:
#   <dataset_path>/fold_<k>/<class folder>/<image files>   for k = 1..10
dataset_path = ""  # Update as needed

# Class name to label mapping (3 classes); class folder names must match these keys.
class_mapping = {
    'HCM': 0,     # Label 0
    'HTN': 1,     # Label 1
    'Normal': 2,  # Label 2
}

# Files accepted as images (the extension set torchvision's ImageFolder accepts).
# Anything else in a class folder (.DS_Store, Thumbs.db, notes, ...) would otherwise
# be handed to PIL by the Dataset and crash the run in the middle of training.
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')


def collect_fold_samples(fold_path, class_mapping, extensions=IMAGE_EXTENSIONS):
    """Return ``(image_paths, labels)`` numpy arrays for one fold directory.

    Every sub-directory of *fold_path* whose name is a key of *class_mapping*
    is scanned. Each regular file whose lower-cased name ends with one of
    *extensions* contributes its path plus the mapped integer label.
    Sub-directories not in the mapping are reported and skipped. Folders and
    files are visited in sorted order, so the result (and any seeded split of
    it) is the same on every filesystem. ``os.listdir`` order is arbitrary.
    """
    image_paths, labels = [], []
    for class_folder in sorted(os.listdir(fold_path)):
        class_folder_path = os.path.join(fold_path, class_folder)
        if not os.path.isdir(class_folder_path):
            continue
        if class_folder not in class_mapping:
            print(f"Warning: Class folder '{class_folder}' not in mapping. Skipping.")
            continue
        label = class_mapping[class_folder]
        for img_name in sorted(os.listdir(class_folder_path)):
            img_path = os.path.join(class_folder_path, img_name)
            if os.path.isfile(img_path) and img_name.lower().endswith(extensions):
                image_paths.append(img_path)
                labels.append(label)
    return np.array(image_paths), np.array(labels, dtype=np.int64)


# Create a list to store fold data (one dict per fold directory that was found)
fold_data = []

for fold in range(1, 11):
    # Path for the current fold
    fold_path = os.path.join(dataset_path, f"fold_{fold}")
    if not os.path.isdir(fold_path):
        print(f"Error: {fold_path} does not exist.")
        continue

    print(f"Processing Fold {fold}:")

    fold_images, fold_labels = collect_fold_samples(fold_path, class_mapping)
    if len(fold_images) == 0:
        # train_test_split raises on empty input; report and move on instead.
        print(f"Error: no images found under {fold_path}. Skipping.")
        continue

    # Split into 80% train and 20% test+validation (stratified by class)
    train_images, temp_images, train_labels, temp_labels = train_test_split(
        fold_images, fold_labels, test_size=0.2, random_state=seed, stratify=fold_labels
    )

    # Split the remaining 20% into 10% validation and 10% test
    val_images, test_images, val_labels, test_labels = train_test_split(
        temp_images, temp_labels, test_size=0.5, random_state=seed, stratify=temp_labels
    )

    # Print the size of train, validation, and test sets.
    # minlength keeps one count per class even if a class is absent from a split.
    n_classes = len(class_mapping)
    print(f"  Train size: {len(train_images)}, Validation size: {len(val_images)}, Test size: {len(test_images)}")
    print(f"  Class distribution in Train: {np.bincount(train_labels, minlength=n_classes)}")
    print(f"  Class distribution in Validation: {np.bincount(val_labels, minlength=n_classes)}")
    print(f"  Class distribution in Test: {np.bincount(test_labels, minlength=n_classes)}\n")

    # Store fold data
    fold_data.append({
        "fold": fold,  # fold directory number; folds can be skipped above
        "train_images": train_images,
        "train_labels": train_labels,
        "val_images": val_images,
        "val_labels": val_labels,
        "test_images": test_images,
        "test_labels": test_labels,
    })


In [6]:
# Custom dataset class
class CustomDataset(Dataset):
    """Map-style dataset of ``(image, label)`` pairs built from aligned path/label sequences.

    Args:
        image_paths: indexable sequence of image file paths.
        labels: indexable sequence of labels; ``labels[i]`` belongs to ``image_paths[i]``.
        transform: optional callable applied to the RGB ``PIL.Image`` before it is returned.

    Raises:
        ValueError: if ``image_paths`` and ``labels`` have different lengths.
    """

    def __init__(self, image_paths, labels, transform=None):
        # Fail fast: a mismatch would otherwise surface later as an IndexError
        # inside __getitem__, or silently ignore trailing labels.
        if len(image_paths) != len(labels):
            raise ValueError(
                f"image_paths and labels must have the same length, "
                f"got {len(image_paths)} and {len(labels)}"
            )
        self.image_paths = image_paths
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Open inside ``with`` so the file handle is closed deterministically;
        # convert() returns a fully-loaded image that stays valid after close.
        with Image.open(self.image_paths[idx]) as img:
            image = img.convert('RGB')

        # Apply transformations if any
        if self.transform:
            image = self.transform(image)

        return image, self.labels[idx]
# Image preprocessing shared by the train/validation/test splits.
# The backbone starts from ImageNet weights (models.resnet18(weights='IMAGENET1K_V1')).
# Those weights expect inputs normalized with the ImageNet per-channel mean/std below.
# Without Normalize the network sees [0, 1] inputs, a shifted distribution from the
# one its pretrained filters and BatchNorm statistics were fit to.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to the 224x224 input size used for ResNet here
    transforms.ToTensor(),          # PIL image -> float tensor (C, H, W) in [0, 1]
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Channel Attention Module
class ChannelAttention(nn.Module):
    """CBAM channel attention: rescale each channel of ``x`` by a learned gate in (0, 1).

    The average-pooled and max-pooled channel descriptors (each ``(N, C, 1, 1)``)
    pass through a shared two-layer 1x1-conv bottleneck. Their sum goes through a
    sigmoid and is multiplied into ``x``.

    Args:
        in_channels: channel count C of the input feature map.
        reduction: bottleneck ratio. The hidden width is ``in_channels // reduction``,
            clamped to at least 1. Without the clamp, C < reduction would give a
            zero-width bottleneck and a constant 0.5 gate.
    """

    def __init__(self, in_channels, reduction=16):
        super(ChannelAttention, self).__init__()
        hidden_channels = max(in_channels // reduction, 1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(hidden_channels, in_channels, kernel_size=1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # The same fc1 -> ReLU -> fc2 bottleneck is applied to both descriptors.
        avg_out = self.fc2(self.relu(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu(self.fc1(self.max_pool(x))))
        gate = self.sigmoid(avg_out + max_out)
        return x * gate  # per-channel gate broadcast over H and W
# Spatial Attention Module
class SpatialAttention(nn.Module):
    """CBAM spatial attention: rescale every spatial location of ``x`` by a gate in (0, 1).

    The channel-wise mean map and channel-wise max map are stacked into a
    2-channel image. That image is convolved down to 1 channel, squashed with a
    sigmoid, and multiplied into ``x``.

    Args:
        kernel_size: side of the square convolution; padding is ``kernel_size // 2``.
    """

    def __init__(self, kernel_size=7):
        super().__init__()
        half = kernel_size // 2
        self.conv = nn.Conv2d(2, 1, kernel_size=kernel_size, padding=half, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        mean_map = x.mean(dim=1, keepdim=True)
        max_map = x.max(dim=1, keepdim=True).values
        stacked = torch.cat([mean_map, max_map], dim=1)  # (N, 2, H, W)
        return x * self.sigmoid(self.conv(stacked))
# Convolutional Block Attention Module (CBAM)
class CBAM(nn.Module):
    """Channel attention followed by spatial attention, applied in that order.

    Args:
        in_channels: channel count C of the feature map this module receives.
        reduction: bottleneck ratio forwarded to ``ChannelAttention``.
        kernel_size: convolution size forwarded to ``SpatialAttention``.
    """

    def __init__(self, in_channels, reduction=16, kernel_size=7):
        super().__init__()
        self.channel_attention = ChannelAttention(in_channels, reduction)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        return self.spatial_attention(self.channel_attention(x))
class CBAMBasicBlock(nn.Module):
    """ResNet BasicBlock whose residual branch is refined by CBAM before the skip-add.

    Takes over the sub-modules of an existing torchvision-style ``BasicBlock``
    (``conv1``, ``bn1``, ``relu``, ``conv2``, ``bn2``, ``downsample``), so their
    pretrained weights are reused without copying. It then registers a new
    ``cbam`` child. Attribute names match the wrapped block, so ``state_dict``
    keys are those of the original block plus ``cbam.*``.

    Args:
        block: a BasicBlock-like module exposing the attributes listed above.
        reduction: bottleneck ratio forwarded to ``CBAM``.
        kernel_size: spatial-attention kernel size forwarded to ``CBAM``.
    """

    def __init__(self, block, reduction=16, kernel_size=7):
        super().__init__()
        self.conv1 = block.conv1
        self.bn1 = block.bn1
        self.relu = block.relu
        self.conv2 = block.conv2
        self.bn2 = block.bn2
        self.downsample = block.downsample  # None when the block keeps shape
        self.stride = block.stride
        self.cbam = CBAM(block.conv2.out_channels, reduction, kernel_size)

    def forward(self, x):
        identity = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = self.cbam(out)  # attention on the residual branch, before the addition
        if self.downsample is not None:
            identity = self.downsample(x)
        return self.relu(out + identity)


class ResNet18CBAM(nn.Module):
    """ResNet18 backbone with CBAM in every residual block and a new ``num_classes`` head.

    Args:
        num_classes: number of logits produced by the replaced ``fc`` layer.
        weights: weights spec passed to ``models.resnet18``. Defaults to
            ``'IMAGENET1K_V1'``. ``None`` gives random initialization and skips
            the download.
    """

    def __init__(self, num_classes, weights='IMAGENET1K_V1'):
        super(ResNet18CBAM, self).__init__()
        # Load (by default ImageNet-pretrained) ResNet18
        self.resnet18 = models.resnet18(weights=weights)
        # Wrap every residual block so CBAM actually takes part in forward()
        for layer_name in ("layer1", "layer2", "layer3", "layer4"):
            layer = getattr(self.resnet18, layer_name)
            setattr(self.resnet18, layer_name, self.add_cbam_to_layer(layer))
        # Replace the final fully connected layer for num_classes
        num_ftrs = self.resnet18.fc.in_features
        self.resnet18.fc = nn.Linear(num_ftrs, num_classes)

    def add_cbam_to_layer(self, layer):
        """Return a new ``nn.Sequential`` with every block of *layer* wrapped in ``CBAMBasicBlock``.

        Attaching a ``cbam`` child via ``add_module`` is not enough. torchvision's
        ``BasicBlock.forward`` has a fixed body that never calls such a child, so
        the attention would be registered (and counted in ``parameters()``) but
        never applied. Assumes BasicBlock-style blocks, as in ResNet18.
        """
        return nn.Sequential(*(CBAMBasicBlock(block) for block in layer))

    def forward(self, x):
        return self.resnet18(x)
def initialize_model():
    """Build a fresh 3-class ``ResNet18CBAM`` and return it moved to the global ``device``."""
    return ResNet18CBAM(num_classes=3).to(device)
# Set optimizer and loss function
def initialize_optimizer(model):
    """Create the training optimizer and loss for *model*.

    Returns:
        ``(optimizer, loss_fn)``. The optimizer is Adam over
        ``model.parameters()`` with lr=3e-4 and weight_decay=1e-4. The loss is
        a default ``nn.CrossEntropyLoss()``.
    """
    criterion = nn.CrossEntropyLoss()
    adam = torch.optim.Adam(params=model.parameters(), lr=3e-4, weight_decay=0.0001)
    return adam, criterion
# Train step
def train_step(model, dataloader, loss_fn, optimizer, device):
    """Run one training epoch over *dataloader* and return its statistics.

    Args:
        model: module returning logits, or a tuple whose first element is the
            logits (e.g. torchvision Inception's ``(logits, aux_logits)`` output).
        dataloader: iterable of ``(X, y)`` batches.
        loss_fn: criterion called as ``loss_fn(logits, y)``. Assumed to return the
            batch *mean* (the default for ``nn.CrossEntropyLoss``).
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        device: target passed to ``.to()`` for every batch.

    Returns:
        ``(loss, acc, preds, labels)``:
        - ``loss`` and ``acc`` are averaged over samples, not batches, so a short
          final batch is not over-weighted. Both are NaN when the loader yields
          no samples.
        - ``preds`` and ``labels`` are per-sample lists in iteration order.
    """
    model.train()
    loss_sum, correct, seen = 0.0, 0, 0
    all_train_preds, all_train_labels = [], []
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)
        # Forward pass; for tuple outputs keep only the main logits.
        outputs = model(X)
        y_pred = outputs[0] if isinstance(outputs, tuple) else outputs
        loss = loss_fn(y_pred, y)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Accumulate sample-weighted statistics
        batch_size = y.size(0)
        loss_sum += loss.item() * batch_size  # undo the per-batch mean
        y_pred_class = torch.argmax(y_pred, dim=1)
        correct += (y_pred_class == y).sum().item()
        seen += batch_size

        # Store predictions and labels
        all_train_preds.extend(y_pred_class.cpu().numpy())
        all_train_labels.extend(y.cpu().numpy())

    if seen == 0:
        return float("nan"), float("nan"), all_train_preds, all_train_labels
    return loss_sum / seen, correct / seen, all_train_preds, all_train_labels

def _evaluate(model, dataloader, loss_fn, device):
    """Evaluate *model* over *dataloader* in eval mode without gradients.

    Shared implementation of ``val_step`` and ``test_step``. Tuple model outputs
    are reduced to their first element (the main logits).

    Returns:
        ``(loss, acc, preds, labels)``:
        - ``loss`` and ``acc`` are averaged over samples, not batches, so a short
          final batch is not over-weighted. Both are NaN for an empty loader.
        - ``preds`` and ``labels`` are per-sample lists in iteration order.
    """
    model.eval()
    loss_sum, correct, seen = 0.0, 0, 0
    all_preds, all_labels = [], []
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            outputs = model(X)
            logits = outputs[0] if isinstance(outputs, tuple) else outputs
            batch_size = y.size(0)
            # loss_fn is assumed to return the batch mean; re-weight by batch size
            loss_sum += loss_fn(logits, y).item() * batch_size
            pred_labels = torch.argmax(logits, dim=1)
            correct += (pred_labels == y).sum().item()
            seen += batch_size
            all_preds.extend(pred_labels.cpu().numpy())
            all_labels.extend(y.cpu().numpy())
    if seen == 0:
        return float("nan"), float("nan"), all_preds, all_labels
    return loss_sum / seen, correct / seen, all_preds, all_labels


# Validation step
def val_step(model, dataloader, loss_fn, device):
    """Evaluate on the validation loader. Returns ``(loss, acc, preds, labels)``, sample-averaged."""
    return _evaluate(model, dataloader, loss_fn, device)


# Test step
def test_step(model, dataloader, loss_fn, device):
    """Evaluate on the test loader. Returns ``(loss, acc, preds, labels)``, sample-averaged."""
    return _evaluate(model, dataloader, loss_fn, device)


# The ``test_`` prefix would make pytest collect this helper as a test case if it
# were ever imported into a test module; opt it out explicitly.
test_step.__test__ = False
# Train loop
def train(model, train_dataloader, val_dataloader, test_dataloader, optimizer, loss_fn, epochs, device):
    """Train for *epochs* epochs, evaluating on the validation and test loaders after each one.

    Returns:
        ``(results, train_preds, train_labels, val_preds, val_labels, test_preds, test_labels)``.

        ``results`` maps "train_loss", "train_acc", "val_loss", "val_acc",
        "test_loss" and "test_acc" to per-epoch lists.

        The prediction/label lists come from the LAST epoch only, so each split
        is covered exactly once. A confusion matrix built from them therefore
        counts every sample once. Concatenating all epochs would count each
        sample ``epochs`` times and mix in predictions from barely-trained early
        epochs. The lists are empty when ``epochs`` is 0.
    """
    results = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": [], "test_loss": [], "test_acc": []}
    # Predictions/labels of the most recent epoch (overwritten every epoch).
    train_preds, train_labels = [], []
    val_preds, val_labels = [], []
    test_preds, test_labels = [], []
    # NOTE(review): the test split is evaluated every epoch for logging only.
    # Nothing here selects a checkpoint by validation score; the final epoch's
    # model is what the returned predictions describe.
    for epoch in tqdm(range(epochs), desc="Training Epochs"):
        # Training step
        train_loss, train_acc, train_preds, train_labels = train_step(
            model, train_dataloader, loss_fn, optimizer, device
        )
        # Validation step
        val_loss, val_acc, val_preds, val_labels = val_step(
            model, val_dataloader, loss_fn, device
        )
        # Testing step
        test_loss, test_acc, test_preds, test_labels = test_step(
            model, test_dataloader, loss_fn, device
        )
        # Store epoch results
        results["train_loss"].append(train_loss)
        results["train_acc"].append(train_acc)
        results["val_loss"].append(val_loss)
        results["val_acc"].append(val_acc)
        results["test_loss"].append(test_loss)
        results["test_acc"].append(test_acc)
        # Log epoch results
        print(
            f"Epoch {epoch + 1}/{epochs} | "
            f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f} | "
            f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f} | "
            f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}")
    return results, train_preds, train_labels, val_preds, val_labels, test_preds, test_labels

In [7]:
def calculate_metrics(cm):
    """Compute accuracy and macro-averaged one-vs-rest metrics from a confusion matrix.

    Args:
        cm: square array-like in which ``cm[i, j]`` counts samples of true class
            ``i`` predicted as class ``j`` (sklearn ``confusion_matrix`` layout).

    Returns:
        dict with these keys:
        - ``accuracy``: correct / total, or NaN for an all-zero matrix.
        - ``precision_macro``, ``recall_macro``, ``f1_macro``,
          ``specificity_macro``: unweighted means over classes. Any per-class
          ratio whose denominator is 0 counts as 0.
    """
    cm = np.asarray(cm)
    total_samples = cm.sum()

    # Per-class one-vs-rest counts
    tp = np.diag(cm).astype(float)
    fp = cm.sum(axis=0) - tp  # predicted as class i, truly another class
    fn = cm.sum(axis=1) - tp  # truly class i, predicted as another class
    tn = total_samples - (tp + fp + fn)

    def _safe_ratio(num, den):
        # Element-wise num / den, with 0 wherever den == 0.
        return np.divide(num, den, out=np.zeros_like(num, dtype=float), where=den > 0)

    precision = _safe_ratio(tp, tp + fp)
    recall = _safe_ratio(tp, tp + fn)
    f1 = _safe_ratio(2 * precision * recall, precision + recall)
    specificity = _safe_ratio(tn, tn + fp)

    # Multiclass accuracy is trace / N. Averaging the per-class one-vs-rest
    # accuracies, sum_i(TP_i + TN_i) / (K * N), is NOT the same thing.
    # It equals 1 - 2 * (1 - acc) / K, which overstates accuracy for K > 2
    # (e.g. 80% correct reads as 86.7% with 3 classes).
    accuracy = float(np.trace(cm) / total_samples) if total_samples > 0 else np.nan

    return {
        'accuracy': accuracy,
        'precision_macro': np.mean(precision),
        'recall_macro': np.mean(recall),
        'f1_macro': np.mean(f1),
        'specificity_macro': np.mean(specificity),
    }
# Train, validation, and test loop with manual metric calculation

# Fixed label order for every confusion matrix. Each matrix is then KxK with
# rows/columns in label order, even when a class appears in neither a split's
# labels nor its predictions. Without it sklearn drops that row/column and the
# macro averages are taken over fewer classes.
cm_labels = sorted(class_mapping.values())


def print_split_metrics(split_name, metrics):
    """Print the dict returned by ``calculate_metrics`` under a '<split_name> Metrics:' header."""
    print(f"\n{split_name} Metrics:")
    print(f"Accuracy: {metrics['accuracy']:.4f}")
    print(f"Precision (Macro Avg): {metrics['precision_macro']:.4f}")
    print(f"Recall/Sensitivity (Macro Avg): {metrics['recall_macro']:.4f}")
    print(f"F1-Score (Macro Avg): {metrics['f1_macro']:.4f}")
    print(f"Specificity (Macro Avg): {metrics['specificity_macro']:.4f}")


fold_test_metrics = []  # test-split metrics of every fold, for the cross-fold summary

for fold_index, fold in enumerate(fold_data):
    # Prefer the fold directory number if the loader recorded it (folds may be skipped).
    fold_number = fold.get("fold", fold_index + 1)
    print(f"\nProcessing Fold {fold_number}...")

    # Get the data for the current fold
    train_images, train_labels = fold["train_images"], fold["train_labels"]
    val_images, val_labels = fold["val_images"], fold["val_labels"]
    test_images, test_labels = fold["test_images"], fold["test_labels"]

    # Create datasets
    train_dataset = CustomDataset(train_images, train_labels, transform=transform)
    val_dataset = CustomDataset(val_images, val_labels, transform=transform)
    test_dataset = CustomDataset(test_images, test_labels, transform=transform)
    torch.manual_seed(seed)  # same RNG state (shuffling, head init) for every fold

    # Create dataloaders
    train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False)
    test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    # Initialize model, optimizer, and loss function
    model = initialize_model()
    optimizer, loss_fn = initialize_optimizer(model)

    # Start training
    start_time = timer()
    results, train_preds, train_labels, val_preds, val_labels, test_preds, test_labels = train(
        model, train_dataloader, val_dataloader, test_dataloader, optimizer, loss_fn, epochs=10, device=device)
    end_time = timer()

    # Flatten labels and predictions to 1D arrays
    train_labels = np.array(train_labels).flatten()
    train_preds = np.array(train_preds).flatten()
    val_labels = np.array(val_labels).flatten()
    val_preds = np.array(val_preds).flatten()
    test_labels = np.array(test_labels).flatten()
    test_preds = np.array(test_preds).flatten()

    # Confusion Matrices (always len(cm_labels) x len(cm_labels))
    train_cm = confusion_matrix(train_labels, train_preds, labels=cm_labels)
    val_cm = confusion_matrix(val_labels, val_preds, labels=cm_labels)
    test_cm = confusion_matrix(test_labels, test_preds, labels=cm_labels)

    print(f"\nFold {fold_number} Confusion Matrices:")
    print("\nTrain Confusion Matrix:")
    print(train_cm)
    print("\nValidation Confusion Matrix:")
    print(val_cm)
    print("\nTest Confusion Matrix:")
    print(test_cm)

    # Metrics computed manually from each confusion matrix
    train_metrics = calculate_metrics(train_cm)
    print_split_metrics("Train", train_metrics)
    val_metrics = calculate_metrics(val_cm)
    print_split_metrics("Validation", val_metrics)
    test_metrics = calculate_metrics(test_cm)
    print_split_metrics("Test", test_metrics)
    fold_test_metrics.append(test_metrics)

    # Print training time for each fold
    print(f"Training time for fold {fold_number}: {end_time - start_time:.2f} seconds")

# Cross-fold summary of the test metrics
if fold_test_metrics:
    print(f"\nTest metrics across {len(fold_test_metrics)} folds (mean ± std):")
    for metric_name in ("accuracy", "precision_macro", "recall_macro", "f1_macro", "specificity_macro"):
        values = np.array([m[metric_name] for m in fold_test_metrics], dtype=float)
        print(f"{metric_name}: {values.mean():.4f} ± {values.std():.4f}")
