# Start BraCs Multi class Classification project!

In [1]:
import os
os.environ["TORCH_HOME"] = r"O:\O drive\AI\my project\medical image projects\BraCs\New folder"

import torch
print(torch.hub.get_dir())
print(os.environ["TORCH_HOME"])


O:\O drive\AI\my project\medical image projects\BraCs\New folder\hub
O:\O drive\AI\my project\medical image projects\BraCs\New folder


In [2]:

torch.cuda.empty_cache()

os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'

In [3]:
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, models
from torchvision.models import (
    ResNeXt50_32X4D_Weights, DenseNet201_Weights, EfficientNet_B0_Weights,
    ResNet18_Weights, ResNet50_Weights, ResNeXt101_32X8D_Weights
)
import albumentations as A
from albumentations.pytorch import ToTensorV2
from sklearn.metrics import roc_auc_score, confusion_matrix, classification_report
from sklearn.utils.class_weight import compute_class_weight
import matplotlib.pyplot as plt
import seaborn as sns
from torch.amp import GradScaler, autocast
import pandas as pd
from PIL import Image
import random

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# Step 1: Create DataFrames
def create_dataframes(root_dir):
    """
    Create separate DataFrames for train, val, and test splits.

    Directory structure:
    <root_dir>/
        train/
            class1/
            class2/
            ...
        val/
            class1/
            class2/
            ...
        test/
            class1/
            class2/
            ...

    Args:
        root_dir (str): Root directory of the dataset.

    Returns:
        tuple: (train_df, val_df, test_df, class_to_idx)
            - train_df, val_df, test_df: DataFrames with columns "file_path", "label", "class_name"
            - class_to_idx: Dict mapping class names to numeric indices
    """
    splits = ["train", "val", "test"]
    valid_extensions = {'.png', '.jpg', '.jpeg', '.tif', '.bmp'}
    data = {split: [] for split in splits}

    # Get class names from train split to ensure consistent labeling
    train_dir = os.path.join(root_dir, "train")
    class_names = sorted([d for d in os.listdir(train_dir) if os.path.isdir(os.path.join(train_dir, d))])
    class_to_idx = {name: idx for idx, name in enumerate(class_names)}

    for split in splits:
        split_dir = os.path.join(root_dir, split)
        if not os.path.isdir(split_dir):
            print(f"Warning: {split_dir} does not exist!")
            continue

        for class_name in os.listdir(split_dir):
            class_dir = os.path.join(split_dir, class_name)
            if not os.path.isdir(class_dir):
                continue

            # Use numeric label from class_to_idx
            label = class_to_idx.get(class_name, -1)
            if label == -1:
                print(f"Warning: Class {class_name} in {split} not found in train classes!")
                continue

            for file in os.listdir(class_dir):
                if os.path.splitext(file)[1].lower() in valid_extensions:
                    image_path = os.path.join(class_dir, file)
                    data[split].append({
                        "file_path": image_path,
                        "label": label,
                        "class_name": class_name
                    })

    train_df = pd.DataFrame(data["train"])
    val_df = pd.DataFrame(data["val"])
    test_df = pd.DataFrame(data["test"])

    for df, split in zip([train_df, val_df, test_df], splits):
        if df.empty:
            print(f"Warning: No images found in {split} split!")
    
    return train_df, val_df, test_df, class_to_idx

# Step 2: Analyze and Visualize
def analyze_and_plot_dataframes(train_df, val_df, test_df, class_to_idx, save_dir):
    """
    Analyze and visualize each DataFrame.

    Args:
        train_df, val_df, test_df (pd.DataFrame): DataFrames for each split.
        class_to_idx (dict): Mapping of class names to indices.
        save_dir (str): Directory to save plots.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Basic statistics and duplicates
    for df, split in zip([train_df, val_df, test_df], ["Train", "Validation", "Test"]):
        print(f"\n{split} Dataset Statistics:")
        print(f"Total images: {len(df)}")
        print(f"Class distribution:\n{df['class_name'].value_counts()}")
        print(f"Duplicate images: {df.duplicated().sum()}")

    # Dynamic color palette for classes
    class_names = sorted(class_to_idx.keys())
    colors = sns.color_palette("husl", len(class_names))
    color_map = {idx: colors[i] for i, idx in enumerate(class_to_idx.values())}

    # Plot class distribution for each split
    for df, split in zip([train_df, val_df, test_df], ["Train", "Validation", "Test"]):
        if df.empty:
            continue
        plt.figure(figsize=(10, 6))
        counts = df['label'].value_counts().sort_index()
        bars = plt.bar(counts.index, counts.values, color=[color_map[i] for i in counts.index])
        
        plt.ylim(0, counts.max() + counts.max() * 0.1)
        for bar in bars:
            height = bar.get_height()
            plt.annotate(f'{int(height)}',
                         xy=(bar.get_x() + bar.get_width() / 2, height),
                         xytext=(0, 5),
                         textcoords='offset points',
                         ha='center', va='bottom')
        
        plt.xlabel('Class')
        plt.ylabel('Count')
        plt.title(f'Class Distribution in {split} Set')
        plt.xticks(counts.index, [class_names[i] for i in counts.index], rotation=45)
        plt.grid(axis='y', alpha=0.7)
        plt.tight_layout()
        plt.savefig(os.path.join(save_dir, f"{split.lower()}_class_distribution.png"))
        plt.close()

    # Plot random samples (up to 10 per class) for each split
    for df, split in zip([train_df, val_df, test_df], ["Train", "Validation", "Test"]):
        if df.empty:
            continue
        classes = sorted(df['class_name'].unique())
        samples_per_class = min(10, df['class_name'].value_counts().min())
        
        # Calculate grid size
        n_cols = min(5, samples_per_class)
        n_rows = (samples_per_class + n_cols - 1) // n_cols * len(classes)
        
        plt.figure(figsize=(n_cols * 4, n_rows * 4))
        plot_idx = 1
        
        for class_name in classes:
            class_images = df[df['class_name'] == class_name]['file_path'].values
            if len(class_images) == 0:
                continue
            samples = np.random.choice(class_images, min(samples_per_class, len(class_images)), replace=False)
            
            for img_path in samples:
                plt.subplot(n_rows, n_cols, plot_idx)
                img = Image.open(img_path).convert('RGB')
                plt.imshow(img)
                plt.title(class_name, fontsize=12)
                plt.axis('off')
                plot_idx += 1
        
        plt.tight_layout()
        plt.savefig(os.path.join(save_dir, f"{split.lower()}_random_samples.png"))
        plt.close()



In [4]:
#Step 1: Create DataFrames
data_dir = r"O:\O drive\AI\my project\medical image projects\BraCs\Dataset\archive"
train_df, val_df, test_df, class_to_idx = create_dataframes(data_dir)

print("\nTrain DataFrame Preview:")
print(train_df.head())
print(f"Total train images: {len(train_df)}")

print("\nValidation DataFrame Preview:")
print(val_df.head())
print(f"Total validation images: {len(val_df)}")

print("\nTest DataFrame Preview:")
print(test_df.head())
print(f"Total test images: {len(test_df)}")

print("\nClass to Index Mapping:")
print(class_to_idx)




Train DataFrame Preview:
                                           file_path  label class_name
0  O:\O drive\AI\my project\medical image project...      0        0_N
1  O:\O drive\AI\my project\medical image project...      0        0_N
2  O:\O drive\AI\my project\medical image project...      0        0_N
3  O:\O drive\AI\my project\medical image project...      0        0_N
4  O:\O drive\AI\my project\medical image project...      0        0_N
Total train images: 3655

Validation DataFrame Preview:
                                           file_path  label class_name
0  O:\O drive\AI\my project\medical image project...      0        0_N
1  O:\O drive\AI\my project\medical image project...      0        0_N
2  O:\O drive\AI\my project\medical image project...      0        0_N
3  O:\O drive\AI\my project\medical image project...      0        0_N
4  O:\O drive\AI\my project\medical image project...      0        0_N
Total validation images: 311

Test DataFrame Preview:
            

In [5]:
# Step 2: Analyze and Visualize
analyze_and_plot_dataframes(train_df, val_df, test_df, class_to_idx, "data_analysis")


Train Dataset Statistics:
Total images: 3655
Class distribution:
class_name
1_PB      714
5_DCIS    665
3_FEA     624
6_IC      519
2_UDH     389
4_ADH     387
0_N       357
Name: count, dtype: int64
Duplicate images: 0

Validation Dataset Statistics:
Total images: 311
Class distribution:
class_name
3_FEA     49
6_IC      47
0_N       46
2_UDH     46
1_PB      43
4_ADH     41
5_DCIS    39
Name: count, dtype: int64
Duplicate images: 0

Test Dataset Statistics:
Total images: 570
Class distribution:
class_name
5_DCIS    85
3_FEA     83
2_UDH     82
0_N       81
6_IC      81
1_PB      79
4_ADH     79
Name: count, dtype: int64
Duplicate images: 0


In [6]:
train_df.shape

(3655, 3)

In [7]:
val_df.shape

(311, 3)

In [8]:
test_df.shape

(570, 3)

In [9]:
import matplotlib.pyplot as plt
import seaborn as sns
import os

# Pie Chart 1: Dataset Split Distribution
def plot_dataset_split_pie(train_df, val_df, test_df, save_dir="data_analysis"):
    """
    Plot a pie chart showing the distribution of samples across train, val, and test splits.

    Args:
        train_df, val_df, test_df (pd.DataFrame): DataFrames for each split.
        save_dir (str): Directory to save the plot.
    """
    sizes = [len(train_df), len(val_df), len(test_df)]
    labels = ['Training', 'Validation', 'Test']
    total = sum(sizes)
    colors = sns.color_palette("pastel")[:3]

    fig, ax = plt.subplots(figsize=(7, 7))
    wedges, texts, autotexts = ax.pie(
        sizes,
        labels=labels,
        autopct=lambda pct: f"{pct:.1f}%\n({int(pct/100.*total)})",
        startangle=140,
        colors=colors,
        shadow=True,
        wedgeprops={'edgecolor': 'white', 'linewidth': 1}
    )

    for text in texts:
        text.set_fontsize(13)
        text.set_fontweight('bold')
    for autotext in autotexts:
        autotext.set_fontsize(11)

    ax.set_title("Dataset Split Distribution", fontsize=16, fontweight='bold')
    plt.text(0, -1.3, f"Total Samples: {total}", fontsize=11, ha='center', style='italic')
    plt.tight_layout()
    
    os.makedirs(save_dir, exist_ok=True)
    plt.savefig(os.path.join(save_dir, "dataset_split_pie.png"))
    plt.close()

# Pie Chart 2: Test Set Class Distribution
def plot_test_class_pie(test_df, class_to_idx, save_dir="data_analysis"):
    """
    Plot a pie chart showing the class distribution in the test set.

    Args:
        test_df (pd.DataFrame): Test DataFrame with 'label' and 'class_name' columns.
        class_to_idx (dict): Mapping of class names to numeric indices.
        save_dir (str): Directory to save the plot.
    """
    # Count samples per class
    class_counts = test_df['label'].value_counts().sort_index()
    sizes = class_counts.values
    labels = [list(class_to_idx.keys())[list(class_to_idx.values()).index(i)] for i in class_counts.index]
    total = sum(sizes)
    colors = sns.color_palette("pastel", len(labels))

    fig, ax = plt.subplots(figsize=(7, 7))
    wedges, texts, autotexts = ax.pie(
        sizes,
        labels=labels,
        autopct=lambda pct: f"{pct:.1f}%\n({int(pct/100.*total)})",
        startangle=140,
        colors=colors,
        shadow=True,
        wedgeprops={'edgecolor': 'white', 'linewidth': 1}
    )

    for text in texts:
        text.set_fontsize(13)
        text.set_fontweight('bold')
    for autotext in autotexts:
        autotext.set_fontsize(11)

    ax.set_title("Test Set Class Distribution", fontsize=16, fontweight='bold')
    plt.text(0, -1.3, f"Total Test Samples: {total}", fontsize=11, ha='center', style='italic')
    plt.tight_layout()
    
    os.makedirs(save_dir, exist_ok=True)
    plt.savefig(os.path.join(save_dir, "test_class_pie.png"))
    plt.close()

In [10]:
from sklearn.utils import resample
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Assuming train_df and class_to_idx are available from previous step
# Resample train_df to balance classes
def resample_train_df(train_df, class_to_idx, save_dir="data_analysis"):
    """
    Resample the training DataFrame to balance classes by oversampling minority classes.

    Args:
        train_df (pd.DataFrame): Training DataFrame with 'file_path', 'label', 'class_name' columns.
        class_to_idx (dict): Mapping of class names to numeric indices.
        save_dir (str): Directory to save the distribution plot.

    Returns:
        pd.DataFrame: Balanced training DataFrame.
    """
    # Get class counts
    class_counts = train_df['label'].value_counts()
    majority_count = class_counts.max()
    majority_label = class_counts.idxmax()

    # Separate DataFrames for each class
    dfs_by_class = [train_df[train_df['label'] == label] for label in class_counts.index]

    # Oversample minority classes to match majority
    balanced_dfs = []
    for df_class, label in zip(dfs_by_class, class_counts.index):
        if label == majority_label:
            balanced_dfs.append(df_class)
        else:
            df_oversampled = resample(
                df_class,
                replace=True,
                n_samples=majority_count,
                random_state=42
            )
            balanced_dfs.append(df_oversampled)

    # Combine all classes
    train_df_balanced = pd.concat(balanced_dfs)

    # Shuffle the dataset
    train_df_balanced = train_df_balanced.sample(frac=1, random_state=42).reset_index(drop=True)

    # Print new class distribution
    print("New class distribution after oversampling:")
    print(train_df_balanced['label'].value_counts())

    # Visualize new class distribution
    class_names = sorted(class_to_idx, key=class_to_idx.get)  # Sort by index
    colors = sns.color_palette("husl", len(class_names))
    color_map = {i: colors[i] for i in range(len(class_names))}

    plt.figure(figsize=(10, 6))
    counts = train_df_balanced['label'].value_counts().sort_index()
    bars = plt.bar(counts.index, counts.values, color=[color_map[i] for i in counts.index])

    plt.ylim(0, counts.max() + counts.max() * 0.1)
    for bar in bars:
        height = bar.get_height()
        plt.annotate(f'{int(height)}',
                     xy=(bar.get_x() + bar.get_width() / 2, height),
                     xytext=(0, 5),
                     textcoords='offset points',
                     ha='center', va='bottom')

    plt.xlabel('Class')
    plt.ylabel('Count')
    plt.title('Class Distribution in Balanced Train Set')
    plt.xticks(counts.index, [class_names[i] for i in counts.index], rotation=45)
    plt.grid(axis='y', alpha=0.7)
    plt.tight_layout()
    os.makedirs(save_dir, exist_ok=True)
    plt.savefig(os.path.join(save_dir, "train_balanced_class_distribution.png"))
    plt.close()

    return train_df_balanced

In [11]:
# After creating DataFrames and resampling
train_df, val_df, test_df, class_to_idx = create_dataframes(data_dir)
train_df_balanced = resample_train_df(train_df, class_to_idx)

# Plot pie charts
plot_dataset_split_pie(train_df_balanced, val_df, test_df)
plot_test_class_pie(test_df, class_to_idx)

New class distribution after oversampling:
label
1    714
6    714
0    714
5    714
4    714
2    714
3    714
Name: count, dtype: int64


In [12]:
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from PIL import Image, ImageFile
from sklearn.utils.class_weight import compute_class_weight
import albumentations as A
from albumentations.pytorch import ToTensorV2
import os
import gc
import random
import logging

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Handle truncated/corrupted images
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")

# Define preprocessing transforms with albumentations
def get_transforms(split, mean=(0.7531985640525818, 0.599432110786438, 0.7416298985481262), 
                  std=(0.21170948445796967, 0.2636403441429138, 0.19742192327976227)):
    """Define Albumentations transforms for train/val/test splits."""
    if split == "train":
        return A.Compose([
            A.Resize(224, 224, always_apply=True),
            A.HorizontalFlip(p=0.5),
            A.VerticalFlip(p=0.5),
            A.Rotate(limit=90, p=0.5),
            A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2, p=0.5),
            A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, p=0.5),
            A.Cutout(num_holes=8, max_h_size=32, max_w_size=32, p=0.5),
            A.CLAHE(clip_limit=2.0, tile_grid_size=(8, 8), p=0.5),
            A.GaussianBlur(blur_limit=(3, 7), p=0.5),
            A.RandomBrightnessContrast(p=0.5),
            A.RandomResizedCrop(224, 224, scale=(0.8, 1.0), p=0.5),
            A.Normalize(mean=mean, std=std),
            ToTensorV2(),
        ])
    else:
        return A.Compose([
            A.Resize(224, 224, always_apply=True),
            A.Normalize(mean=mean, std=std),
            ToTensorV2(),
        ])

# Wrapper to use albumentations with PyTorch
class AlbumentationsTransform:
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, img):
        img = np.array(img)
        augmented = self.transform(image=img)
        return augmented['image']

# Custom Dataset for DataFrame
class DataFrameDataset(Dataset):
    def __init__(self, df, transform=None):
        self.df = df.reset_index(drop=True)
        self.transform = transform
        self.classes = sorted(self.df['class_name'].unique())
        self.class_to_idx = {cls: idx for idx, cls in enumerate(self.classes)}
        self.df = self._validate_and_filter_dataset()
        unique_labels = sorted(self.df['label'].unique())
        expected_labels = list(range(len(self.classes)))
        if sorted(unique_labels) != expected_labels:
            logger.error(f"Label mismatch: Expected labels {expected_labels}, found {unique_labels}")
            raise ValueError(f"Label mismatch: Expected {expected_labels}, found {unique_labels}")
        logger.info(f"Dataset size after filtering: {len(self.df)}")
        logger.info(f"Classes: {self.classes}")
        logger.info(f"Class to index: {self.class_to_idx}")
        logger.info(f"Class distribution:\n{self.df['class_name'].value_counts()}")

    def _validate_and_filter_dataset(self):
        valid_rows = []
        for idx in range(len(self.df)):
            row = self.df.iloc[idx]
            img_path = row['file_path']
            label = row['label']
            if not os.path.isfile(img_path):
                logger.warning(f"Invalid file path at index {idx}: {img_path}")
                continue
            if not isinstance(label, (int, np.integer)) or label not in range(len(self.classes)):
                logger.warning(f"Invalid label at index {idx}: {label}")
                continue
            try:
                img = Image.open(img_path).convert('RGB')
                img.close()
            except Exception as e:
                logger.warning(f"Corrupted image at index {idx}: {img_path}, error: {e}")
                continue
            valid_rows.append(idx)
        if not valid_rows:
            raise ValueError("No valid items in dataset after filtering")
        filtered_df = self.df.iloc[valid_rows].reset_index(drop=True)
        logger.info(f"Filtered out {len(self.df) - len(filtered_df)} invalid rows")
        return filtered_df

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = row['file_path']
        label = row['label']
        try:
            img = Image.open(img_path).convert('RGB')
        except Exception as e:
            logger.error(f"Error loading image {img_path}: {e}")
            return None
        if self.transform:
            img = self.transform(img)
        label_tensor = torch.tensor(label, dtype=torch.long)
        if label_tensor.numel() != 1:
            logger.error(f"Label at index {idx} is not scalar: {label_tensor}")
            return None
        logger.debug(f"Item {idx}: label {label} -> tensor {label_tensor}")
        return img, label_tensor

# Custom Dataset Wrapper
class CustomDataset(Dataset):
    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        return self.dataset[idx]

    def __getattr__(self, name):
        return getattr(self.dataset, name)

# # CutMix and MixUp Functions
# def rand_bbox(size, lam):
#     W = size[2]
#     H = size[3]
#     cut_rat = np.sqrt(1. - lam)
#     cut_w = int(W * cut_rat)
#     cut_h = int(H * cut_rat)
#     cx = np.random.randint(W)
#     cy = np.random.randint(H)
#     bbx1 = np.clip(cx - cut_w // 2, 0, W)
#     bby1 = np.clip(cy - cut_h // 2, 0, H)
#     bbx2 = np.clip(cx + cut_w // 2, 0, W)
#     bby2 = np.clip(cy + cut_h // 2, 0, H)
#     return bbx1, bby1, bbx2, bby2

# def cutmix(data, targets, alpha=1.0):
#     if not isinstance(targets, torch.Tensor):
#         raise ValueError(f"CutMix expects tensor targets, got {type(targets)}")
#     batch_size = data.size(0)
#     index = torch.randperm(batch_size)
#     shuffled_data = data[index]
#     shuffled_targets = targets[index]
#     lam = np.random.beta(alpha, alpha)
#     bbx1, bby1, bbx2, bby2 = rand_bbox(data.size(), lam)
#     data[:, :, bbx1:bbx2, bby1:bby2] = shuffled_data[:, :, bbx1:bbx2, bby1:bby2]
#     lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (data.size(-1) * data.size(-2)))
#     return data, (targets, shuffled_targets, lam)

# def mixup(data, targets, alpha=0.2):
#     if not isinstance(targets, torch.Tensor):
#         raise ValueError(f"MixUp expects tensor targets, got {type(targets)}")
#     batch_size = data.size(0)
#     index = torch.randperm(batch_size)
#     shuffled_data = data[index]
#     shuffled_targets = targets[index]
#     lam = np.random.beta(alpha, alpha)
#     mixed_data = lam * data + (1 - lam) * shuffled_data
#     return mixed_data, (targets, shuffled_targets, lam)

# Custom Collate Function
# def custom_collate(batch):
#     batch = [item for item in batch if item is not None]
#     if len(batch) == 0:
#         logger.warning("Empty batch after filtering None items")
#         return None, None
#     try:
#         data = torch.stack([item[0] for item in batch])
#         logger.debug(f"Data stacked: shape {data.shape}, type {type(data)}")
#         targets = []
#         for idx, item in enumerate(batch):
#             label = item[1]
#             if not isinstance(label, torch.Tensor):
#                 logger.error(f"Non-tensor label at batch index {idx}: {type(label)}")
#                 return None, None
#             if label.numel() != 1:
#                 logger.error(f"Label at batch index {idx} is not scalar: {label}")
#                 return None, None
#             targets.append(label.item())
#         targets = torch.tensor(targets, dtype=torch.long)
#         logger.debug(f"Targets before augmentation: {targets}, type {type(targets)}")
#         augmentation = random.choices(['none', 'mixup', 'cutmix'], weights=[0.6, 0.2, 0.2], k=1)[0]
#         logger.debug(f"Applying augmentation: {augmentation}")
#         if augmentation == 'mixup':
#             data, targets = mixup(data, targets, alpha=0.2)
#             if not isinstance(targets, tuple) or len(targets) != 3:
#                 logger.error(f"MixUp returned invalid targets: {type(targets)}")
#                 return None, None
#             labels1, labels2, lam = targets
#             if not (isinstance(labels1, torch.Tensor) and isinstance(labels2, torch.Tensor) and isinstance(lam, float)):
#                 logger.error(f"MixUp targets tuple invalid: {type(labels1)}, {type(labels2)}, {type(lam)}")
#                 return None, None
#             logger.debug(f"After MixUp: data shape {data.shape}, targets {type(targets)} (labels1: {labels1.shape}, labels2: {labels2.shape}, lam: {lam})")
#         elif augmentation == 'cutmix':
#             data, targets = cutmix(data, targets, alpha=1.0)
#             if not isinstance(targets, tuple) or len(targets) != 3:
#                 logger.error(f"CutMix returned invalid targets: {type(targets)}")
#                 return None, None
#             labels1, labels2, lam = targets
#             if not (isinstance(labels1, torch.Tensor) and isinstance(labels2, torch.Tensor) and isinstance(lam, float)):
#                 logger.error(f"CutMix targets tuple invalid: {type(labels1)}, {type(labels2)}, {type(lam)}")
#                 return None, None
#             logger.debug(f"After CutMix: data shape {data.shape}, targets {type(targets)} (labels1: {labels1.shape}, labels2: {labels2.shape}, lam: {lam})")
#         else:
#             logger.debug(f"No augmentation, targets remain tensor: {targets.shape}")
#         return data, targets
#     except Exception as e:
#         logger.error(f"Error in custom_collate: {e}")
#         return None, None


# Create Datasets (Assuming train_df_balanced, val_df, test_df are defined)
try:
    train_dataset = DataFrameDataset(train_df_balanced, transform=AlbumentationsTransform(get_transforms("train")))
    val_dataset = DataFrameDataset(val_df, transform=AlbumentationsTransform(get_transforms("val")))
    test_dataset = DataFrameDataset(test_df, transform=AlbumentationsTransform(get_transforms("test")))
except Exception as e:
    logger.error(f"Error creating datasets: {e}")
    raise

# Wrap Datasets
train_dataset_custom = CustomDataset(train_dataset)
val_dataset_custom = CustomDataset(val_dataset)
test_dataset_custom = CustomDataset(test_dataset)

num_classes = len(train_dataset.classes)

# Create Data Loaders
batch_size = 64
train_loader = DataLoader(
    train_dataset_custom,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    pin_memory=True
    #collate_fn=custom_collate
)
val_loader = DataLoader(
    val_dataset_custom,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
    pin_memory=True
)
test_loader = DataLoader(
    test_dataset_custom,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
    pin_memory=True
)

# Dataset Summary
logger.info(f"Number of classes: {num_classes}")
logger.info(f"Classes: {train_dataset.classes}")
logger.info(f"Training dataset size: {len(train_loader.dataset)}")
logger.info(f"Validation dataset size: {len(val_loader.dataset)}")
logger.info(f"Test dataset size: {len(test_loader.dataset)}")

# Memory Optimization
torch.cuda.empty_cache()
gc.collect()


2025-05-29 12:24:29,534 - INFO - Using device: cuda
2025-05-29 12:24:52,227 - INFO - Filtered out 0 invalid rows
2025-05-29 12:24:52,236 - INFO - Dataset size after filtering: 4998
2025-05-29 12:24:52,236 - INFO - Classes: ['0_N', '1_PB', '2_UDH', '3_FEA', '4_ADH', '5_DCIS', '6_IC']
2025-05-29 12:24:52,236 - INFO - Class to index: {'0_N': 0, '1_PB': 1, '2_UDH': 2, '3_FEA': 3, '4_ADH': 4, '5_DCIS': 5, '6_IC': 6}
2025-05-29 12:24:52,236 - INFO - Class distribution:
class_name
1_PB      714
6_IC      714
0_N       714
5_DCIS    714
4_ADH     714
2_UDH     714
3_FEA     714
Name: count, dtype: int64
2025-05-29 12:24:53,656 - INFO - Filtered out 0 invalid rows
2025-05-29 12:24:53,656 - INFO - Dataset size after filtering: 311
2025-05-29 12:24:53,656 - INFO - Classes: ['0_N', '1_PB', '2_UDH', '3_FEA', '4_ADH', '5_DCIS', '6_IC']
2025-05-29 12:24:53,656 - INFO - Class to index: {'0_N': 0, '1_PB': 1, '2_UDH': 2, '3_FEA': 3, '4_ADH': 4, '5_DCIS': 5, '6_IC': 6}
2025-05-29 12:24:53,664 - INFO - Cl

536206

In [17]:
import torch
import time
import logging

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Select CUDA if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")

# Check batch loading time
for i, (inputs, labels) in enumerate(train_loader):
    start_time = time.time()

    # Skip invalid batches
    if inputs is None or labels is None:
        logger.warning(f"Skipping empty batch {i+1}")
        continue

    try:
        # Validate inputs
        if not isinstance(inputs, torch.Tensor):
            logger.error(f"Batch {i+1} inputs is not a tensor: {type(inputs)}")
            continue
        inputs = inputs.to(device)

        # Handle labels (tensor or tuple)
        if isinstance(labels, tuple):
            labels1, labels2, lam = labels
            if not isinstance(labels1, torch.Tensor) or not isinstance(labels2, torch.Tensor):
                logger.error(f"Batch {i+1} labels tuple contains non-tensor: {type(labels1)}, {type(labels2)}")
                continue
            labels = (labels1.to(device), labels2.to(device), lam)
        else:
            if not isinstance(labels, torch.Tensor):
                logger.error(f"Batch {i+1} labels is not a tensor: {type(labels)}")
                continue
            labels = labels.to(device)

        batch_time = time.time() - start_time
        logger.info(f"✅ Batch {i+1} Loaded in {batch_time:.4f} sec")
        logger.info(f"Batch {i+1} inputs shape: {inputs.shape}")
        logger.info(f"Batch {i+1} labels: {labels}")

    except Exception as e:
        logger.error(f"Error processing batch {i+1}: {e}")
        continue

    if i == 20:
        break

# Memory cleanup
torch.cuda.empty_cache()

2025-05-29 12:25:17,146 - INFO - Using device: cuda
2025-05-29 12:25:19,344 - INFO - ✅ Batch 1 Loaded in 0.0256 sec
2025-05-29 12:25:19,351 - INFO - Batch 1 inputs shape: torch.Size([64, 3, 224, 224])
2025-05-29 12:25:19,351 - INFO - Batch 1 labels: tensor([1, 0, 4, 0, 3, 6, 4, 6, 2, 6, 4, 1, 5, 4, 4, 5, 3, 1, 3, 1, 4, 6, 1, 4,
        3, 4, 6, 6, 3, 0, 0, 5, 0, 0, 6, 2, 1, 0, 3, 3, 5, 5, 1, 3, 5, 6, 4, 5,
        4, 2, 5, 5, 3, 3, 4, 6, 2, 6, 3, 5, 1, 0, 2, 5], device='cuda:0')
2025-05-29 12:25:21,337 - INFO - ✅ Batch 2 Loaded in 0.0246 sec
2025-05-29 12:25:21,337 - INFO - Batch 2 inputs shape: torch.Size([64, 3, 224, 224])
2025-05-29 12:25:21,337 - INFO - Batch 2 labels: tensor([0, 1, 1, 3, 0, 5, 6, 6, 0, 5, 1, 0, 3, 6, 0, 4, 6, 4, 1, 4, 4, 2, 0, 1,
        4, 4, 0, 5, 0, 5, 0, 5, 2, 6, 0, 1, 5, 1, 1, 2, 6, 1, 1, 0, 3, 6, 2, 1,
        0, 1, 2, 1, 5, 2, 2, 6, 2, 0, 6, 6, 6, 1, 4, 4], device='cuda:0')
2025-05-29 12:25:23,554 - INFO - ✅ Batch 3 Loaded in 0.0080 sec
2025-05-29 12:25:23,

In [14]:
len(train_loader)

79

In [15]:
len(val_loader)

5

In [16]:
len(test_loader)

9

In [50]:
# def compute_class_weights(dataset, num_classes, device):
#     """
#     Compute balanced class weights for a dataset.
    
#     Args:
#         dataset: Dataset object containing images and labels.
#         num_classes: Integer, number of classes (e.g., 7 for BraCS).
#         device: torch.device, device to move weights to (e.g., 'cuda').
    
#     Returns:
#         torch.Tensor: Tensor of class weights on the specified device.
#     """
#     import logging
#     logger = logging.getLogger(__name__)
    
#     labels = []
#     for idx, (_, label) in enumerate(dataset):
#         if label is None:
#             logger.warning(f"None label at index {idx}")
#             continue
#         if not isinstance(label, torch.Tensor):
#             logger.error(f"Non-tensor label at index {idx}: {type(label)}")
#             continue
#         labels.append(label.item())
    
#     if not labels:
#         raise ValueError("No valid labels found in dataset")
    
#     unique_labels = sorted(labels)
#     expected_classes = list(range(num_classes))
#     if unique_labels != expected_classes:
#         logger.error(f"Missing labels: Expected {expected_classes}, found {unique_labels}")
#         raise ValueError(f"Dataset labels {unique_labels} do not match expected classes {expected_classes}")
    
#     # Log class distribution
#     from collections import Counter
#     label_counts = Counter(labels)
#     logger.info(f"Class distribution: {dict(sorted(label_counts.items()))}")
    
#     from sklearn.utils.class_weight import compute_class_weight
#     class_weights = compute_class_weight('balanced', classes=np.array(expected_classes), y=labels)
#     logger.info(f"Class weights: {class_weights.tolist()}")
    
#     return torch.tensor(class_weights, dtype=torch.float).to(device)

In [51]:
import time

In [21]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import GradScaler, autocast
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report, roc_auc_score
import os
from torchvision import models
from torchvision.models import ResNeXt50_32X4D_Weights, DenseNet201_Weights, EfficientNet_B0_Weights, ResNet18_Weights, ResNet50_Weights, ResNeXt101_32X8D_Weights

# Assuming these are from your previous code
# from your_data_preparation import train_dataset, train_loader, val_loader, test_loader, DataFrameDataset
# from your_modeling import compute_class_weights

# CBAM Components
class ChannelAttention(nn.Module):
    def __init__(self, in_channels, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels // ratio, bias=False),
            nn.ReLU(),
            nn.Linear(in_channels // ratio, in_channels, bias=False)
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        b, c, _, _ = x.size()
        avg_out = self.fc(self.avg_pool(x).view(b, c))
        max_out = self.fc(self.max_pool(x).view(b, c))
        out = avg_out + max_out
        return self.sigmoid(out).view(b, c, 1, 1)

class SpatialAttention(nn.Module):
    def __init__(self):
        super(SpatialAttention, self).__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size=7, padding=3, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avg_out, max_out], dim=1)
        out = self.conv(out)
        return self.sigmoid(out)

class CBAM(nn.Module):
    def __init__(self, in_channels):
        super(CBAM, self).__init__()
        self.channel_attention = ChannelAttention(in_channels)
        self.spatial_attention = SpatialAttention()

    def forward(self, x):
        ca = self.channel_attention(x)
        x = x * ca
        sa = self.spatial_attention(x)
        x = x * sa
        return x

# Insert CBAM into Model
def insert_cbam(model, cbam_channels):
    cbam = CBAM(cbam_channels)
    if isinstance(model, models.ResNet):
        model.cbam = cbam
        original_forward = model.forward
        def new_forward(self, x):
            x = self.conv1(x)
            x = self.cbam(x)
            x = self.bn1(x)
            x = self.relu(x)
            x = self.maxpool(x)
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            x = self.layer4(x)
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)
            return x
        model.forward = new_forward.__get__(model)
    elif isinstance(model, (models.DenseNet, models.EfficientNet)):
        features = model.features
        new_features = nn.Sequential(
            features[0],
            cbam,
            *features[1:]
        )
        model.features = new_features
    else:
        raise ValueError("Model type not supported for CBAM insertion.")
    return model


# Early Stopping
class EarlyStopping:
    def __init__(self, patience=10, verbose=False, delta=0, save_dir=None):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_loss = float('inf')
        self.best_epoch = 0
        self.early_stop = False
        self.delta = delta
        self.save_dir = save_dir

    def __call__(self, val_loss, epoch, model_weights, model_name_prefix):
        if val_loss < self.best_loss - self.delta:
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.counter = 0
            self.save_best_weights(model_weights, model_name_prefix)
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
                if self.verbose:
                    print(f"Early stopping triggered after {self.counter} epochs of no improvement.")

    def save_best_weights(self, model_weights, model_name_prefix):
        os.makedirs(self.save_dir, exist_ok=True)
        model_name = os.path.join(self.save_dir, f"{model_name_prefix}_epoch_{self.best_epoch + 1}.pth")
        torch.save(model_weights, model_name)
        if self.verbose:
            print(f"✅ Best model weights saved to {model_name}.")

import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.models import ResNeXt50_32X4D_Weights, DenseNet201_Weights, EfficientNet_B0_Weights, ResNet18_Weights, ResNet50_Weights, ResNeXt101_32X8D_Weights
import logging

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
import torch
import torch.nn as nn
import torchvision.models as models
from torchvision.models import ResNeXt50_32X4D_Weights, DenseNet201_Weights, EfficientNet_B0_Weights, ResNet18_Weights, ResNet50_Weights, ResNeXt101_32X8D_Weights
import logging

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Placeholder for insert_cbam
def insert_cbam(model, cbam_channels):
    """
    Insert CBAM modules into the model.
    Args:
        model: PyTorch model (e.g., ResNeXt, EfficientNet).
        cbam_channels: Number of channels for CBAM.
    Returns:
        Modified model with CBAM.
    """
    logger.info(f"Inserting CBAM with {cbam_channels} channels")
    return model  # Replace with actual CBAM insertion

# CustomClassifier (unchanged)
class CustomClassifier(nn.Module):
    def __init__(self, in_features, num_classes):
        super(CustomClassifier, self).__init__()
        self.fc1 = nn.Linear(in_features, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, num_classes)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.6)
    
    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.relu(self.fc2(x))
        x = self.dropout(x)
        x = self.fc3(x)
        return x

# Placeholder for SupConLoss
class SupConLoss(nn.Module):
    def __init__(self, temperature=0.07):
        super(SupConLoss, self).__init__()
        self.temperature = temperature
    
    def forward(self, features, labels):
        # Replace with actual implementation
        return torch.tensor(0.0, device=features.device)

# Updated ContrastiveModel
class ContrastiveModel(nn.Module):
    def __init__(self, base_model, num_classes, feature_dim):
        super(ContrastiveModel, self).__init__()
        self.num_classes = num_classes
        self.feature_dim = feature_dim
        
        # Store backbone
        self.backbone = base_model
        
        # Determine classifier attribute
        if hasattr(self.backbone, 'fc'):
            self.classifier_attr = 'fc'
        elif hasattr(self.backbone, 'classifier'):
            self.classifier_attr = 'classifier'
        else:
            raise ValueError("Backbone has neither 'fc' nor 'classifier' attribute")
        
        # Create feature extractor by removing the final classifier
        self.feature_extractor = nn.Sequential(*list(self.backbone.children())[:-1])
        
        # Projection head uses feature_dim (e.g., 2048 for resnext50)
        self.projection_head = nn.Sequential(
            nn.Linear(self.feature_dim, 512),
            nn.ReLU(),
            nn.Linear(512, 128)
        )

    def forward(self, x):
        # Extract features before the final classifier
        features = self.feature_extractor(x)
        features = features.view(features.size(0), -1)  # Flatten: [batch_size, feature_dim]
        
        # Get logits from the classifier
        classifier = getattr(self.backbone, self.classifier_attr)
        logits = classifier(features)  # [batch_size, num_classes]
        
        # Get projections
        proj = self.projection_head(features)  # [batch_size, 128]
        
        return logits, proj

# Updated get_model Function
def get_model(model_name, num_classes, device):
    """
    Load and configure a model with CBAM and contrastive learning.
    
    Args:
        model_name: Name of the model ('resnext50', 'efficientnet_b0', etc.).
        num_classes: Number of output classes (7 for BraCS).
        device: Device to run the model on ('cuda' or 'cpu').
    
    Returns:
        Configured model.
    """
    logger.info(f"Loading model: {model_name} with {num_classes} classes on {device}")

    # Load base model with pretrained weights
    if model_name == "resnext50":
        base_model = models.resnext50_32x4d(weights=ResNeXt50_32X4D_Weights.DEFAULT)
        cbam_channels = 64
    elif model_name == "densenet201":
        base_model = models.densenet201(weights=DenseNet201_Weights.DEFAULT)
        cbam_channels = 64
    elif model_name == "efficientnet_b0":
        base_model = models.efficientnet_b0(weights=EfficientNet_B0_Weights.DEFAULT)
        cbam_channels = 32
    elif model_name == "resnet18":
        base_model = models.resnet18(weights=ResNet18_Weights.DEFAULT)
        cbam_channels = 64
    elif model_name == "resnet50":
        base_model = models.resnet50(weights=ResNet50_Weights.DEFAULT)
        cbam_channels = 64
    elif model_name == "resnext101":
        base_model = models.resnext101_32x8d(weights=ResNeXt101_32X8D_Weights.DEFAULT)
        cbam_channels = 64
    else:
        raise ValueError(f"Model {model_name} not supported.")

    # Insert CBAM
    base_model = insert_cbam(base_model, cbam_channels)

    # Determine input features for classifier
    if hasattr(base_model, 'fc'):
        in_features = base_model.fc.in_features
        classifier_attr = 'fc'
    elif hasattr(base_model, 'classifier'):
        if isinstance(base_model.classifier, nn.Linear):
            in_features = base_model.classifier.in_features
        elif isinstance(base_model.classifier, nn.Sequential):
            for layer in reversed(base_model.classifier):
                if isinstance(layer, nn.Linear):
                    in_features = layer.in_features
                    break
            else:
                raise ValueError("No Linear layer found in classifier")
        classifier_attr = 'classifier'
    else:
        raise ValueError("Cannot find classifier layer (neither 'fc' nor 'classifier')")

    # Replace classifier with CustomClassifier
    custom_classifier = CustomClassifier(in_features, num_classes)
    setattr(base_model, classifier_attr, custom_classifier)

    # Initialize ContrastiveModel with feature_dim
    try:
        model = ContrastiveModel(base_model, num_classes, feature_dim=in_features)
    except Exception as e:
        logger.error(f"Failed to initialize ContrastiveModel: {e}")
        raise

    logger.info(f"Model {model_name} loaded successfully with {in_features} input features to classifier")
    return model.to(device)

import torch
import torch.nn as nn
import os
import logging
from torch.amp import GradScaler, autocast

# Set up logging
logger = logging.getLogger(__name__)

import torch
import torch.nn as nn
import os
import logging
from torch.amp import GradScaler, autocast

logger = logging.getLogger(__name__)

def train_and_validate(model, train_loader, val_loader, optimizer, scheduler, 
                     model_name_prefix, epochs=25, device=None, early_stopping=None, 
                     save_dir=r"O:\O drive\AI\my project\medical image projects\BraCs", accum_steps=6):
    """
    Train and validate the model with manually defined class weights, CutMix/MixUp, and fixed autocast.

    Args:
        model: PyTorch model (ResNeXt50 with CBAM and contrastive learning).
        train_loader: DataLoader for training data.
        val_loader: DataLoader for validation data.
        optimizer: PyTorch optimizer.
        scheduler: Learning rate scheduler.
        model_name_prefix: Prefix for saved model files.
        epochs: Number of training epochs (default: 25).
        device: Device to run the model on (e.g., torch.device('cuda')).
        early_stopping: Optional early stopping object.
        save_dir: Directory to save model weights and metrics.
        accum_steps: Number of batches to accumulate gradients (default: 4).
    
    Returns:
        Trained model.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Validate device type
    if device is None or device.type != 'cuda':
        logger.warning(f"Device is {device}. Mixed precision requires CUDA. Falling back to float32.")
        use_amp = False
    else:
        use_amp = True
        logger.info("Using mixed precision training with CUDA.")

    # Manually define class weights for 7 classes
    class_weights = torch.tensor([1.0, 3.0, 3.0, 1.0, 3.0, 1.0, 1.0], dtype=torch.float).to(device)
    logger.info(f"Class weights: {class_weights.tolist()}")

    # Loss functions
    ce_criterion = nn.CrossEntropyLoss(weight=class_weights)
    supcon_criterion = SupConLoss(temperature=0.7)  # Assuming SupConLoss is defined
    lambda_supcon = 0.4

    train_losses = []
    valid_losses = []
    train_accuracies = []
    valid_accuracies = []

    # Initialize scaler
    scaler = GradScaler() if use_amp else None

    for epoch in range(epochs):
        print(f'Epoch {epoch + 1}/{epochs}')
        print('-' * 50)

        # Training phase
        model.train()
        running_loss = 0.0
        correct_preds = 0
        total_preds = 0.0
        optimizer.zero_grad()  # Clear gradients at start


        for batch_idx, (inputs, labels) in enumerate(train_loader):
            # Skip invalid batches
            if inputs is None or labels is None:
                logger.debug("Skipping invalid batch due to None values")
                continue

            logger.debug(f"Batch labels type: {type(labels)}")
            # Handle CutMix/MixUp labels
            if isinstance(labels, tuple):
                labels1, labels2, lam = labels
                labels1, labels2 = labels1.to(device), labels2.to(device)

            else:
                if not isinstance(labels, torch.Tensor):
                    logger.error(f"Invalid labels type: {type(labels)}, expected tensor")
                    continue
                labels = labels.to(device)
            inputs = inputs.to(device)

            if use_amp:
                with autocast(device_type='cuda', dtype=torch.float16, enabled=True):
                    logits, proj = model(inputs)
                    if isinstance(labels, tuple):
                        ce_loss = lam * ce_criterion(logits, labels1) + (1 - lam) * ce_criterion(logits, labels2)
                    else:
                        ce_loss = ce_criterion(logits, labels)
                    supcon_loss = supcon_criterion(proj, labels if not isinstance(labels, tuple) else labels1)
                    total_loss = (ce_loss + lambda_supcon * supcon_loss) / accum_steps  # Scale loss
            else:
                logits, proj = model(inputs)
                if isinstance(labels, tuple):
                    ce_loss = lam * ce_criterion(logits, labels1) + (1 - lam) * ce_criterion(logits, labels2)
                else:
                    ce_loss = ce_criterion(logits, labels)
                supcon_loss = supcon_criterion(proj, labels if not isinstance(labels, tuple) else labels1)
                total_loss = (ce_loss + lambda_supcon * supcon_loss) / accum_steps

            if use_amp:
                scaler.scale(total_loss).backward()
                if (batch_idx + 1) % accum_steps == 0 or (batch_idx + 1) == len(train_loader):
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()
            else:
                total_loss.backward()
                if (batch_idx + 1) % accum_steps == 0 or (batch_idx + 1) == len(train_loader):
                    optimizer.step()
                    optimizer.zero_grad()

            running_loss += total_loss.item() * inputs.size(0) * accum_steps
            _, preds = torch.max(logits, 1)
            if isinstance(labels, tuple):
                correct_preds += lam * (preds == labels1).sum().item() + (1 - lam) * (preds == labels2).sum().item()
            else:
                correct_preds += (preds == labels).sum().item()
            total_preds += labels.size(0) if not isinstance(labels, tuple) else labels1.size(0)

        epoch_loss = running_loss / len(train_loader.dataset)
        epoch_acc = correct_preds / total_preds
        train_losses.append(epoch_loss)
        train_accuracies.append(epoch_acc)
        print(f'Training Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}')

        # Validation phase
        model.eval()
        running_loss = 0.0
        correct_preds = 0
        total_preds = 0
        all_labels = []
        all_preds = []

        with torch.no_grad():
            for inputs, labels in val_loader:
                # Skip invalid batches
                if inputs is None or labels is None:
                    logger.debug("Skipping invalid validation batch")
                    continue

                inputs, labels = inputs.to(device), labels.to(device)
                if use_amp:
                    with autocast(device_type='cuda', dtype=torch.float16, enabled=True):
                        logits, _ = model(inputs)
                        loss = ce_criterion(logits, labels)
                else:
                    logits, _ = model(inputs)
                    loss = ce_criterion(logits, labels)
                running_loss += loss.item() * inputs.size(0)
                _, preds = torch.max(logits, 1)
                correct_preds += (preds == labels).sum().item()
                total_preds += labels.size(0)
                all_labels.extend(labels.cpu().numpy())
                all_preds.extend(preds.cpu().numpy())

        epoch_loss = running_loss / len(val_loader.dataset)
        epoch_acc = correct_preds / total_preds
        valid_losses.append(epoch_loss)
        valid_accuracies.append(epoch_acc)
        print(f'Validation Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.4f}')

        scheduler.step(epoch_loss)

        if early_stopping:
            early_stopping(epoch_loss, epoch, model.state_dict(), model_name_prefix)
            if early_stopping.early_stop:
                print(f"🚨 Early stopping triggered at epoch {epoch+1}.")
                break

        if epoch % 5 == 0 or (early_stopping and early_stopping.early_stop):
            compute_metrics(model, val_loader, device, epoch + 1, "val", save_dir)  # Assuming compute_metrics is defined

        torch.cuda.empty_cache()  # Clear memory after each epoch

    # Load best model if early stopping was used
    if early_stopping and early_stopping.best_epoch is not None:
        print(f"Loading best model weights from epoch {early_stopping.best_epoch + 1}...")
        best_model_path = os.path.join(save_dir, f"{model_name_prefix}_epoch_{early_stopping.best_epoch + 1}.pth")
        model.load_state_dict(torch.load(best_model_path))
    else:
        print("No early stopping triggered or best epoch not set. Keeping final epoch weights.")

    plot_loss_accuracy(train_losses, valid_losses, train_accuracies, valid_accuracies, save_dir)  # Assuming plot_loss_accuracy is defined
    return model


In [22]:

# Testing
def test_model(model, test_loader, device, model_name_prefix, 
               save_dir=r"O:\O drive\AI\my project\medical image projects\BraCs"):
    os.makedirs(save_dir, exist_ok=True)

    model.eval()
    running_loss = 0.0
    correct_preds = 0
    total_preds = 0
    all_labels = []
    all_preds = []

    criterion = nn.CrossEntropyLoss()

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            with torch.amp.autocast('cuda', dtype=torch.float16, enabled=True):
                logits, _ = model(inputs)
                loss = criterion(logits, labels)
            running_loss += loss.item() * inputs.size(0)
            _, preds = torch.max(logits, 1)
            correct_preds += (preds == labels).sum().item()
            total_preds += labels.size(0)
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())

    test_loss = running_loss / len(test_loader.dataset)
    test_accuracy = correct_preds / total_preds
    print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")

    compute_metrics(model, test_loader, device, "test", "test", save_dir)

# Metrics Computation
def compute_metrics(model, dataloader, device, epoch, split_name, save_dir):
    os.makedirs(save_dir, exist_ok=True)

    model.eval()
    all_probs = []
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            logits, _ = model(inputs)
            probs = torch.softmax(logits, dim=1)
            _, preds = torch.max(logits, 1)
            all_probs.extend(probs.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    all_probs = np.array(all_probs)
    all_labels = np.array(all_labels)
    all_preds = np.array(all_preds)

    auc_scores = {}
    for i, class_name in enumerate(dataloader.dataset.classes):
        auc_scores[class_name] = roc_auc_score(
            (all_labels == i).astype(int), all_probs[:, i]
        )
    print(f"{split_name.capitalize()} AUC-ROC Scores: {auc_scores}")

    cm = confusion_matrix(all_labels, all_preds)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues",
                xticklabels=dataloader.dataset.classes,
                yticklabels=dataloader.dataset.classes)
    plt.title(f"{split_name.capitalize()} Confusion Matrix - Epoch {epoch}")
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.savefig(os.path.join(save_dir, f"{split_name}_confusion_matrix_epoch_{epoch}.png"))
    plt.close()

    report = classification_report(all_labels, all_preds, 
                                  target_names=dataloader.dataset.classes, digits=4)
    with open(os.path.join(save_dir, f"{split_name}_classification_report_epoch_{epoch}.txt"), "w") as f:
        f.write(report)
    print(f"{split_name.capitalize()} Classification Report:\n{report}")

    return auc_scores

# Plotting Loss and Accuracy
def plot_loss_accuracy(train_losses, valid_losses, train_accuracies, valid_accuracies, save_dir):
    os.makedirs(save_dir, exist_ok=True)

    epochs = range(1, len(train_losses) + 1)
    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(epochs, train_losses, label='Training Loss', color='blue', marker='o')
    plt.plot(epochs, valid_losses, label='Validation Loss', color='red', marker='o')
    plt.title('Loss per Epoch')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(epochs, train_accuracies, label='Training Accuracy', color='blue', marker='o')
    plt.plot(epochs, valid_accuracies, label='Validation Accuracy', color='red', marker='o')
    plt.title('Accuracy per Epoch')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.legend()

    plt.tight_layout()
    plt.savefig(os.path.join(save_dir, 'training_validation_metrics.png'))
    plt.close()


In [23]:
torch.cuda.empty_cache()
gc.collect()

56

In [24]:

# Main Execution
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    num_classes = len(train_dataset.classes)
    
    model = get_model("resnet18", num_classes, device)
    optimizer = optim.AdamW(model.parameters(), lr=0.00008, weight_decay=0.008)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.8, patience=10)
    early_stopping = EarlyStopping(
        patience=10, verbose=True, delta=0, 
        save_dir=r"O:\O drive\AI\my project\medical image projects\BraCs"
    )
    model_name_prefix = "resnet_18_contrastive_bracs"
    
    print(f"Training resnext50 with CBAM and Contrastive Learning for {num_classes} classes...")
    model = train_and_validate(
        model, train_loader, val_loader, optimizer, scheduler,
        model_name_prefix, epochs=50, device=device, early_stopping=early_stopping
    )

2025-05-29 12:28:18,880 - INFO - Loading model: resnet18 with 7 classes on cuda
Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to O:\O drive\AI\my project\medical image projects\BraCs\New folder\hub\checkpoints\resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:16<00:00, 2.90MB/s]
2025-05-29 12:28:36,195 - INFO - Inserting CBAM with 64 channels
2025-05-29 12:28:36,203 - INFO - Model resnet18 loaded successfully with 512 input features to classifier
2025-05-29 12:28:37,469 - INFO - Using mixed precision training with CUDA.
2025-05-29 12:28:37,469 - INFO - Class weights: [1.0, 3.0, 3.0, 1.0, 3.0, 1.0, 1.0]


Training resnext50 with CBAM and Contrastive Learning for 7 classes...
Epoch 1/50
--------------------------------------------------
Training Loss: 1.8573, Accuracy: 0.1573
Validation Loss: 1.8325, Accuracy: 0.2154
✅ Best model weights saved to O:\O drive\AI\my project\medical image projects\BraCs\resnet_18_contrastive_bracs_epoch_1.pth.
Val AUC-ROC Scores: {'0_N': 0.5357670221493027, '1_PB': 0.7378514404720584, '2_UDH': 0.6690730106644791, '3_FEA': 0.48426546190995484, '4_ADH': 0.5968383017163505, '5_DCIS': 0.7213423831070889, '6_IC': 0.9539007092198581}


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


Val Classification Report:
              precision    recall  f1-score   support

         0_N     0.0000    0.0000    0.0000        46
        1_PB     0.2044    0.6512    0.3111        43
       2_UDH     0.2437    0.6304    0.3515        46
       3_FEA     0.0000    0.0000    0.0000        49
       4_ADH     0.1818    0.2439    0.2083        41
      5_DCIS     0.0000    0.0000    0.0000        39
        6_IC     0.0000    0.0000    0.0000        47

    accuracy                         0.2154       311
   macro avg     0.0900    0.2179    0.1244       311
weighted avg     0.0883    0.2154    0.1225       311

Epoch 2/50
--------------------------------------------------
Training Loss: 1.7383, Accuracy: 0.1867
Validation Loss: 1.7599, Accuracy: 0.2283
✅ Best model weights saved to O:\O drive\AI\my project\medical image projects\BraCs\resnet_18_contrastive_bracs_epoch_2.pth.
Epoch 3/50
--------------------------------------------------
Training Loss: 1.6452, Accuracy: 0.2187
Valid

In [25]:

print("Testing the model...")
test_model(model, test_loader, device, model_name_prefix)

Testing the model...
Test Loss: 1.2268, Test Accuracy: 0.5965
Test AUC-ROC Scores: {'0_N': 0.9420838698275644, '1_PB': 0.8069297996854778, '2_UDH': 0.8395141943222711, '3_FEA': 0.9512629573736424, '4_ADH': 0.7915130578256723, '5_DCIS': 0.8842935112189206, '6_IC': 0.9796258426115276}
Test Classification Report:
              precision    recall  f1-score   support

         0_N     0.7463    0.6173    0.6757        81
        1_PB     0.4157    0.4684    0.4405        79
       2_UDH     0.4211    0.5854    0.4898        82
       3_FEA     0.7586    0.7952    0.7765        83
       4_ADH     0.4133    0.3924    0.4026        79
      5_DCIS     0.6406    0.4824    0.5503        85
        6_IC     0.9054    0.8272    0.8645        81

    accuracy                         0.5965       570
   macro avg     0.6144    0.5954    0.6000       570
weighted avg     0.6162    0.5965    0.6013       570



# XXX