In [None]:
# Confirm which Python version the runtime is using
!python --version

In [None]:
# Show the installed PyTorch version (bare expression -> cell output)
import torch
torch.__version__

In [None]:
# Mount Google Drive so files stored there are reachable under /content/drive
from google.colab import drive
drive.mount('/content/drive')

In [None]:
cd /content/drive/MyDrive

In [None]:
import matplotlib.pyplot as plt
import torch
import torchvision

from torch import nn
from torchvision import transforms

In [None]:
# Prefer the GPU when one is attached (Colab: Edit -> Notebook settings -> GPU)
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

In [None]:
# Train/test image folders, relative to the working directory chosen by `cd` above
train_dir, test_dir = (
    'AI_demos/custom_dataset/train',
    'AI_demos/custom_dataset/test',
)

In [None]:
#  Create Datasets and DataLoaders

import os

from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# DataLoader worker count; os.cpu_count() may return None, so fall back to 2
NUM_WORKERS = os.cpu_count() or 2

def create_dataloaders(
    train_dir: str,
    test_dir: str,
    transform: transforms.Compose,
    batch_size: int,
    num_workers: int=NUM_WORKERS
):
    """Build train/test DataLoaders from ImageFolder-style directories.

    Args:
        train_dir: Root of the training images, one sub-folder per class.
        test_dir: Root of the test images, with the same class sub-folders.
        transform: Transform applied to every image of both splits.
        batch_size: Number of samples per batch.
        num_workers: Worker processes per DataLoader.

    Returns:
        (train_dataloader, test_dataloader, class_names), where class_names is
        the list of class folder names found in ``train_dir``.

    Raises:
        ValueError: If the train and test folders define different classes.
            ImageFolder numbers classes from the folders it finds, so a
            mismatch would silently give the same label index different
            meanings in the two splits.
    """
    train_data = datasets.ImageFolder(train_dir, transform=transform)
    test_data = datasets.ImageFolder(test_dir, transform=transform)

    class_names = train_data.classes
    if test_data.classes != class_names:
        raise ValueError(
            f"Train classes {class_names} do not match test classes {test_data.classes}"
        )

    # Pinned host memory only helps copies to a CUDA device
    pin_memory = torch.cuda.is_available()

    train_dataloader = DataLoader(
        train_data,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )
    test_dataloader = DataLoader(
        test_data,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
    )

    return train_dataloader, test_dataloader, class_names

In [None]:
# Every image is resized to a fixed IMG_SIZE x IMG_SIZE square, then turned into a tensor
IMG_SIZE = 224

manual_transforms = transforms.Compose(
    [
        transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
        transforms.ToTensor(),
    ]
)
print(f"Manually created transforms: {manual_transforms}")

In [None]:
# Batch size shared by the train and test loaders
BATCH_SIZE = 32

# Build ImageFolder-based loaders from the directories defined above.
# NOTE: a later cell redefines create_dataloaders with a different signature
# (no transform argument); re-running this cell after that one fails.
train_dataloader, test_dataloader, class_names = create_dataloaders(
    batch_size=BATCH_SIZE,
    transform=manual_transforms,
    train_dir=train_dir,
    test_dir=test_dir,
)

# Display the created objects as the cell output
train_dataloader, test_dataloader, class_names

In [None]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os
import torch
from torchvision import transforms
import matplotlib.pyplot as plt

class BreastCancerDataset(Dataset):
    """Images stored as ``data_dir/benign/*`` and ``data_dir/malignant/*``.

    Each item is ``(image, label, filename)``: the (optionally transformed)
    RGB image, the class index (0 = benign, 1 = malignant) and the file's
    base name.

    Args:
        data_dir: Folder containing one sub-folder per class.
        transform: Optional callable applied to the PIL image.
    """

    # Accepted image extensions, matched case-insensitively
    IMAGE_EXTENSIONS = ('.jpg', '.png', '.jpeg')

    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.classes = ['benign', 'malignant']

        # Parallel lists indexed by sample position
        self.image_paths = []
        self.labels = []
        self.filenames = []

        for class_idx, class_name in enumerate(self.classes):
            class_path = os.path.join(data_dir, class_name)
            # os.listdir order is filesystem-dependent; sort for a stable index
            for filename in sorted(os.listdir(class_path)):
                if filename.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.image_paths.append(os.path.join(class_path, filename))
                    self.labels.append(class_idx)
                    self.filenames.append(filename)

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        # Close the file handle as soon as the pixels have been converted
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label, self.filenames[idx]

# Plain resize + tensor conversion (no normalisation) so images display as-is
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

# Create datasets for train and test
train_dataset = BreastCancerDataset(data_dir='AI_demos/custom_dataset/train', transform=transform)
test_dataset = BreastCancerDataset(data_dir='AI_demos/custom_dataset/test', transform=transform)

# Create dataloaders.
# NOTE: this rebinds train_dataloader / test_dataloader, replacing the
# ImageFolder-based loaders built by an earlier cell.
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Visualize the first image of one training batch
image_batch, label_batch, filename_batch = next(iter(train_dataloader))
image, label, filename = image_batch[0], label_batch[0], filename_batch[0]
class_name = train_dataset.classes[int(label)]

print(f"Displaying image: {filename}")
print(f"Label: {class_name}")
print(f"Image shape: {image.shape}")

plt.figure(figsize=(8, 8))
# Channels-first tensor -> channels-last array for imshow
plt.imshow(image.permute(1, 2, 0))
plt.title(f"Class: {class_name}\nFilename: {filename}")
plt.axis('off')
plt.show()

# Print dataset sizes
print(f"\nDataset sizes:")
print(f"Training set: {len(train_dataset)} images")
print(f"Test set: {len(test_dataset)} images")

# Count labels from the list built in __init__; iterating train_dataset would
# open and transform every image just to read its label.
train_labels = train_dataset.labels
benign_count = train_labels.count(0)
malignant_count = train_labels.count(1)

print(f"\nTraining set class distribution:")
print(f"Benign: {benign_count} images")
print(f"Malignant: {malignant_count} images")

In [None]:
# Breast Cancer Classification System
# A comprehensive system for classifying breast cancer images as benign or malignant
# using a Vision Transformer (ViT) and CNN ensemble approach

import os
import math
import numpy as np
import cv2
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import transforms
import timm

from sklearn.metrics import confusion_matrix, classification_report

import logging
import datetime

# 1. Device Configuration: CUDA when available, otherwise CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# 2. Logging Setup
def setup_logger(log_filename=None):
    """Configure and return the 'BreastCancerClassifier' logger.

    Records at INFO and above go to the console and, when ``log_filename`` is
    given, are also appended to that file. ``logging.getLogger`` returns the
    same object on every call, so handlers left over from a previous call
    (e.g. a re-run notebook cell) are removed first. Without this, every record
    would be emitted once per accumulated handler.

    Args:
        log_filename: Optional path of a file to append log records to.

    Returns:
        The configured ``logging.Logger``.
    """
    logger = logging.getLogger('BreastCancerClassifier')
    logger.setLevel(logging.INFO)

    # Detach (and close) handlers added by earlier calls
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')

    handlers = [logging.StreamHandler()]
    if log_filename:
        handlers.append(logging.FileHandler(log_filename))

    for handler in handlers:
        handler.setLevel(logging.INFO)
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    return logger

# 3. Dataset Class with Advanced Medical Image Preprocessing
class BreastCancerDataset(torch.utils.data.Dataset):
    """Breast-cancer image dataset with optional contrast/edge enhancement.

    Images are read from ``data_dir/benign/`` (label 0) and
    ``data_dir/malignant/`` (label 1). Each item is ``(image, label, filename)``.
    When ``apply_preprocessing`` is true, every image goes through
    ``apply_advanced_preprocessing`` before ``transform``.

    Args:
        data_dir: Folder containing one sub-folder per class.
        transform: Optional callable applied to the PIL image.
        apply_preprocessing: Whether to run the OpenCV enhancement step.
    """

    # Accepted image extensions, matched case-insensitively
    IMAGE_EXTENSIONS = ('.jpg', '.png', '.jpeg')

    def __init__(self, data_dir, transform=None, apply_preprocessing=True):
        self.data_dir = data_dir
        self.transform = transform
        self.classes = ['benign', 'malignant']
        self.apply_preprocessing = apply_preprocessing

        # Parallel lists indexed by sample position
        self.image_paths = []
        self.labels = []
        self.filenames = []

        for class_idx, class_name in enumerate(self.classes):
            class_path = os.path.join(data_dir, class_name)
            # os.listdir order is filesystem-dependent; sort for a stable index
            for filename in sorted(os.listdir(class_path)):
                if filename.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.image_paths.append(os.path.join(class_path, filename))
                    self.labels.append(class_idx)
                    self.filenames.append(filename)

    def __len__(self):
        return len(self.image_paths)

    def apply_advanced_preprocessing(self, img):
        """Enhance an RGB image with CLAHE, bilateral filtering and unsharp masking.

        Steps:
        1. CLAHE on the L channel in LAB space (clipLimit=3.0, 8x8 tiles).
        2. A bilateral filter (d=9, sigmaColor=75, sigmaSpace=75).
        3. Unsharp masking: 1.5 * image - 0.5 * Gaussian blur (sigma 3.0).

        Ordinary errors are not raised. Non-3-channel input or a failed colour
        conversion returns ``img`` unchanged. A failed filter step keeps the
        result of the earlier steps. Problems are reported with ``print``.

        Args:
            img: PIL image (anything ``np.array`` can convert).

        Returns:
            A new PIL image, or ``img`` itself when enhancement was skipped.
        """
        try:
            np_img = np.array(img)

            if np_img.ndim != 3 or np_img.shape[2] != 3:
                print(f"Warning: Image has unexpected dimensions {np_img.shape}. Skipping advanced preprocessing.")
                return img

            try:
                # Equalise local contrast on lightness only, leaving colour intact
                lab = cv2.cvtColor(np_img, cv2.COLOR_RGB2LAB)
                clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
                lab[:,:,0] = clahe.apply(lab[:,:,0])
                enhanced = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
            except cv2.error as e:
                print(f"OpenCV error during color conversion: {e}. Using original image.")
                return img

            try:
                # Noise reduction that preserves edges
                enhanced = cv2.bilateralFilter(enhanced, 9, 75, 75)
            except Exception as e:
                print(f"Error during bilateral filtering: {e}. Continuing with partial preprocessing.")

            try:
                # Unsharp masking for edge enhancement
                gaussian = cv2.GaussianBlur(enhanced, (0, 0), 3.0)
                enhanced = cv2.addWeighted(enhanced, 1.5, gaussian, -0.5, 0)
            except Exception as e:
                print(f"Error during unsharp masking: {e}. Using partially preprocessed image.")

            return Image.fromarray(enhanced)

        except Exception as e:
            print(f"Unexpected error in preprocessing: {e}. Using original image.")
            return img

    def __getitem__(self, idx):
        image_path = self.image_paths[idx]
        # Close the file handle as soon as the pixels have been converted
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        label = self.labels[idx]

        # apply_advanced_preprocessing catches its own errors and falls back to
        # the input image, so no extra try/except is needed here.
        if self.apply_preprocessing:
            image = self.apply_advanced_preprocessing(image)

        if self.transform:
            image = self.transform(image)

        return image, label, self.filenames[idx]

# 4. Custom Dataset for Split Data
class BreastCancerDatasetFromList(torch.utils.data.Dataset):
    """Dataset over an explicit list of ``(image_path, label)`` pairs.

    Used for the train/val/test lists produced by ``create_dataset_splits``.
    Each item is ``(image, label, filename)``. When ``apply_preprocessing`` is
    true, the same CLAHE + bilateral + unsharp-mask enhancement as
    ``BreastCancerDataset`` runs before ``transform``.

    Args:
        file_list: Sequence of ``(image_path, label)`` tuples.
        transform: Optional callable applied to the PIL image.
        apply_preprocessing: Whether to run the OpenCV enhancement step.
    """
    def __init__(self, file_list, transform=None, apply_preprocessing=True):
        self.file_list = file_list
        self.transform = transform
        self.apply_preprocessing = apply_preprocessing
        self.classes = ['benign', 'malignant']

    def __len__(self):
        return len(self.file_list)

    def apply_advanced_preprocessing(self, img):
        """Enhance an RGB image; return ``img`` unchanged on any error.

        CLAHE on the LAB L channel (clipLimit=3.0, 8x8 tiles), bilateral filter
        (9, 75, 75), then unsharp masking (1.5 * image - 0.5 * blur, sigma 3.0).
        Unlike ``BreastCancerDataset``, failures are silent and all-or-nothing.
        """
        try:
            np_img = np.array(img)

            # Only 3-channel images are enhanced
            if np_img.ndim != 3 or np_img.shape[2] != 3:
                return img

            # Equalise local contrast on lightness only
            lab = cv2.cvtColor(np_img, cv2.COLOR_RGB2LAB)
            clahe = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
            lab[:,:,0] = clahe.apply(lab[:,:,0])
            enhanced = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)

            # Edge-preserving noise reduction
            enhanced = cv2.bilateralFilter(enhanced, 9, 75, 75)

            # Unsharp masking for edge enhancement
            gaussian = cv2.GaussianBlur(enhanced, (0, 0), 3.0)
            enhanced = cv2.addWeighted(enhanced, 1.5, gaussian, -0.5, 0)

            return Image.fromarray(enhanced)
        except Exception:
            return img

    def __getitem__(self, idx):
        image_path, label = self.file_list[idx]
        # Close the file handle as soon as the pixels have been converted
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        filename = os.path.basename(image_path)

        # apply_advanced_preprocessing already falls back to the original image
        # on any Exception; no bare except here, so Ctrl-C still interrupts.
        if self.apply_preprocessing:
            image = self.apply_advanced_preprocessing(image)

        if self.transform:
            image = self.transform(image)

        return image, label, filename

# 5. Data Transforms
def get_transforms(train=True, img_size=224):
    """Return the torchvision pipeline for training or evaluation images.

    Args:
        train: If True, return the augmenting training pipeline: random
            resized crop, flips, rotation, blur, colour jitter, random erasing.
            Otherwise return the deterministic evaluation pipeline.
        img_size: Side length of the square output image (default 224).

    Both pipelines output a normalised tensor using ImageNet mean/std.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    if train:
        return transforms.Compose([
            # PIL-image augmentations
            transforms.RandomResizedCrop(img_size, scale=(0.7, 1.0)),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),  # orientation carries no label information here
            transforms.RandomRotation(90),
            transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0)),  # simulate focus variation
            transforms.ColorJitter(brightness=0.3, contrast=0.4, saturation=0.2, hue=0.1),

            # Must precede the tensor-only transforms below
            transforms.ToTensor(),

            transforms.RandomErasing(p=0.4, scale=(0.02, 0.25)),
            transforms.Normalize(mean=imagenet_mean, std=imagenet_std)
        ])

    # Resize the shorter side to img_size / 0.875 before the centre crop
    # (256 for the default 224, as in the original 256 -> 224 recipe).
    resize_size = int(round(img_size / 0.875))
    return transforms.Compose([
        transforms.Resize(resize_size),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std)
    ])

# 6. Model Architecture: ViT-CNN Ensemble
class ViTCNNEnsemble(nn.Module):
    """Two-branch classifier that fuses ViT and CNN image features.

    Both timm backbones are created with ``num_classes=0`` so they return
    pooled feature vectors. Each vector is multiplied element-wise by a
    learned sigmoid gate. The two gated vectors are concatenated, passed
    through a fusion MLP, and a classification head produces the logits.

    Args:
        num_classes: Number of output logits.
        drop_path_rate: Stochastic-depth rate for the ViT branch. It is passed
            to ``timm.create_model`` because timm builds its DropPath layers
            from it in the constructor. Assigning a ``drop_path_rate``
            attribute after construction would not change them.
    """
    def __init__(self, num_classes=2, drop_path_rate=0.2):
        super().__init__()

        # ViT branch (pooled features only)
        self.vit = timm.create_model(
            'vit_base_patch16_224_dino',
            pretrained=True,
            num_classes=0,
            drop_path_rate=drop_path_rate,
        )

        # CNN branch (pooled features only)
        self.cnn = timm.create_model(
            'resnet50d',
            pretrained=True,
            num_classes=0,
        )

        # Feature widths, read from the backbones instead of hard-coded
        self.vit_dim = self.vit.embed_dim
        self.cnn_dim = self.cnn.num_features

        # Element-wise sigmoid gates over each branch's features
        self.vit_attention = self._make_gate(self.vit_dim)
        self.cnn_attention = self._make_gate(self.cnn_dim)

        # Fuse the concatenated gated features
        self.fusion = nn.Sequential(
            nn.Linear(self.vit_dim + self.cnn_dim, 1024),
            nn.LayerNorm(1024),
            nn.GELU(),
            nn.Dropout(0.4)
        )

        # Classification head with strong regularization
        self.classifier = nn.Sequential(
            nn.Linear(1024, 512),
            nn.LayerNorm(512),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(512, num_classes)
        )

    @staticmethod
    def _make_gate(dim):
        """MLP producing a per-feature weight in (0, 1) for a ``dim``-wide vector."""
        return nn.Sequential(
            nn.Linear(dim, dim),
            nn.LayerNorm(dim),
            nn.GELU(),
            nn.Linear(dim, dim),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return class logits for the image batch ``x``."""
        vit_features = self.vit(x)
        cnn_features = self.cnn(x)

        # Re-weight each feature vector by its learned gate
        vit_features = vit_features * self.vit_attention(vit_features)
        cnn_features = cnn_features * self.cnn_attention(cnn_features)

        combined_features = torch.cat([vit_features, cnn_features], dim=1)

        fused = self.fusion(combined_features)
        return self.classifier(fused)

# 7. Custom loss function: Focal Loss
class FocalLoss(nn.Module):
    """Focal Loss for addressing class imbalance.

    For sample i with true class y_i and softmax probability p_i of that class:

        loss_i = alpha[y_i] * (1 - p_i) ** gamma * (-log p_i)

    ``alpha[y_i]`` is omitted when ``alpha`` is None. p_i is computed from the
    *unweighted* cross-entropy. Passing the class weights to
    ``F.cross_entropy`` first would make ``exp(-ce)`` equal
    ``p_i ** alpha[y_i]`` rather than p_i. The focusing term would then be
    wrong for every class whose weight is not 1.

    Args:
        alpha: Optional per-class weights (list, ndarray or tensor, one entry
            per class).
        gamma: Focusing parameter; gamma=0 gives per-sample alpha-weighted
            cross-entropy.
        reduction: 'mean' or 'sum'; any other value returns per-sample losses.
        device: Accepted for backward compatibility only. ``alpha`` is kept
            as a non-persistent buffer and moved to the logits' device in
            ``forward``, so construction also works on CPU-only runtimes.
    """
    def __init__(self, alpha=None, gamma=2.0, reduction='mean', device='cuda'):
        super().__init__()
        self.gamma = gamma
        self.reduction = reduction

        alpha_tensor = None if alpha is None else torch.as_tensor(alpha, dtype=torch.float32)
        # Buffer so module.to(...) moves it; non-persistent keeps state_dict unchanged
        self.register_buffer('alpha', alpha_tensor, persistent=False)

    def forward(self, inputs, targets):
        # Unweighted per-sample cross-entropy, i.e. -log p_t
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)

        focal_loss = (1 - pt) ** self.gamma * ce_loss

        if self.alpha is not None:
            alpha = self.alpha.to(device=focal_loss.device, dtype=focal_loss.dtype)
            focal_loss = alpha[targets] * focal_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

# 8. Custom Learning Rate Scheduler
class WarmupCosineScheduler:
    """Implements learning rate warmup and cosine decay.

    ``step(epoch)`` sets every param group's LR to the LR it had at
    construction time, multiplied by a factor depending on ``epoch`` (0-based):

    * ``epoch < warmup_epochs``: linear ramp ``epoch / warmup_epochs``.
      The factor is therefore 0 at epoch 0.
    * otherwise: cosine decay from 1 down to ``min_lr_ratio`` over the
      remaining ``total_epochs - warmup_epochs`` epochs.

    Decay progress is clamped to [0, 1]. Epochs past ``total_epochs`` stay at
    ``min_lr_ratio`` instead of the cosine climbing back up. A schedule with no
    decay phase (``total_epochs == warmup_epochs``) goes straight to
    ``min_lr_ratio`` instead of dividing by zero.

    Args:
        optimizer: Object exposing ``param_groups`` (list of dicts with 'lr').
        warmup_epochs: Number of warmup epochs.
        total_epochs: Epoch at which the decay reaches ``min_lr_ratio``.
        min_lr_ratio: Final LR as a fraction of the initial LR.
    """
    def __init__(self, optimizer, warmup_epochs, total_epochs, min_lr_ratio=0.1):
        self.optimizer = optimizer
        self.warmup_epochs = warmup_epochs
        self.total_epochs = total_epochs
        self.min_lr_ratio = min_lr_ratio

        # Initial learning rate of each parameter group
        self.base_lrs = [param_group['lr'] for param_group in optimizer.param_groups]

    def step(self, epoch):
        """Update learning rate based on current epoch"""
        if epoch < self.warmup_epochs:
            lr_factor = epoch / self.warmup_epochs
        else:
            decay_epochs = self.total_epochs - self.warmup_epochs
            if decay_epochs > 0:
                progress = (epoch - self.warmup_epochs) / decay_epochs
            else:
                progress = 1.0
            progress = min(max(progress, 0.0), 1.0)
            lr_factor = self.min_lr_ratio + 0.5 * (1 - self.min_lr_ratio) * (1 + math.cos(math.pi * progress))

        for i, param_group in enumerate(self.optimizer.param_groups):
            param_group['lr'] = self.base_lrs[i] * lr_factor

# 9. MixUp Data Augmentation
def mixup_data(x, y, alpha=0.2):
    '''Mix each sample with a randomly chosen partner from the same batch.

    Draws ``lam ~ Beta(alpha, alpha)``. When ``alpha <= 0`` it uses ``lam = 1``,
    i.e. no mixing. The permutation is created on ``x.device``, so CPU and GPU
    batches both work without any global device setting.

    Args:
        x: Batch tensor with samples along dim 0.
        y: Targets aligned with ``x`` along dim 0.
        alpha: Beta-distribution concentration.

    Returns:
        (mixed_x, y_a, y_b, lam) with
        ``mixed_x = lam * x + (1 - lam) * x[perm]``, ``y_a = y``, ``y_b = y[perm]``.
    '''
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size(0)
    index = torch.randperm(batch_size, device=x.device)

    # x[index] (rather than x[index, :]) also works for 1-D inputs
    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index.to(y.device)]
    return mixed_x, y_a, y_b, lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Mixup loss: ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``."""
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

# 10. Dataset Splitting Function
def create_dataset_splits(data_dir, val_split=0.15, test_split=0.15, random_seed=42):
    """Create stratified train, validation, and test file lists.

    Each class folder (``data_dir/benign``, ``data_dir/malignant``) is split
    independently, so class proportions are preserved. After shuffling, the
    first ``int(test_split * n)`` files go to test, the next
    ``int(val_split * n)`` to validation, and the rest to training.

    File names are sorted before shuffling with a private RNG seeded by
    ``random_seed``. ``os.listdir`` order is filesystem-dependent, so without
    the sort the same seed could still yield different splits. NumPy's global
    RNG is not reseeded.

    Args:
        data_dir: Folder with ``benign/`` and ``malignant/`` sub-folders.
        val_split: Fraction of each class used for validation.
        test_split: Fraction of each class used for testing.
        random_seed: Seed for the shuffle.

    Returns:
        (train_files, val_files, test_files), each a list of
        ``(path, class_index)`` with benign=0 and malignant=1.

    Raises:
        ValueError: If a fraction is negative or they sum to more than 1.
    """
    if val_split < 0 or test_split < 0 or val_split + test_split > 1:
        raise ValueError(
            f"val_split ({val_split}) and test_split ({test_split}) must be "
            "non-negative and sum to at most 1"
        )

    rng = np.random.RandomState(random_seed)

    classes = ['benign', 'malignant']
    train_files, val_files, test_files = [], [], []

    for class_idx, class_name in enumerate(classes):
        class_path = os.path.join(data_dir, class_name)
        all_files = [
            os.path.join(class_path, f)
            for f in sorted(os.listdir(class_path))
            if f.lower().endswith(('.jpg', '.png', '.jpeg'))
        ]
        rng.shuffle(all_files)

        n_files = len(all_files)
        n_test = int(test_split * n_files)
        n_val = int(val_split * n_files)

        test_files.extend((f, class_idx) for f in all_files[:n_test])
        val_files.extend((f, class_idx) for f in all_files[n_test:n_test + n_val])
        train_files.extend((f, class_idx) for f in all_files[n_test + n_val:])

    print(f"Split created: {len(train_files)} training, {len(val_files)} validation, {len(test_files)} test files")

    return train_files, val_files, test_files

# 11. DataLoader Creation Functions
def create_dataloaders(batch_size=16,
                       train_dir='AI_demos/custom_dataset/train',
                       test_dir='AI_demos/custom_dataset/test'):
    """Create a class-balanced training loader and a sequential test loader.

    Training images use ``get_transforms(train=True)``. They are drawn by a
    ``WeightedRandomSampler`` whose per-sample weight is the inverse frequency
    of the sample's class, so each class is drawn equally often in
    expectation. Test images use ``get_transforms(train=False)``. Both splits
    use ``apply_preprocessing=True``.

    NOTE: this shadows the ImageFolder-based ``create_dataloaders`` defined in
    an earlier cell, which takes different arguments.

    Args:
        batch_size: Samples per batch for both loaders.
        train_dir: Folder with ``benign/`` and ``malignant/`` training images.
        test_dir: Folder with ``benign/`` and ``malignant/`` test images.

    Returns:
        (train_dataloader, test_dataloader, class_names)

    Raises:
        ValueError: If no training images are found.
    """
    train_dataset = BreastCancerDataset(
        data_dir=train_dir,
        transform=get_transforms(train=True),
        apply_preprocessing=True
    )
    test_dataset = BreastCancerDataset(
        data_dir=test_dir,
        transform=get_transforms(train=False),
        apply_preprocessing=True  # same enhancement as training
    )
    if len(train_dataset) == 0:
        raise ValueError(f"No training images found under {train_dir!r}")

    # Read labels from the list built in __init__. Iterating the dataset would
    # decode, enhance and augment every image just to get its label.
    train_labels = torch.tensor(train_dataset.labels, dtype=torch.long)
    # minlength keeps one entry per class even if a class folder is empty
    class_counts = torch.bincount(train_labels, minlength=len(train_dataset.classes))
    total_samples = class_counts.sum()

    # Inverse-frequency class weights, looked up per sample
    inverse_weights = total_samples / (class_counts * len(class_counts))
    sample_weights = inverse_weights[train_labels]

    sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(train_dataset),
        replacement=True
    )

    num_workers = os.cpu_count() or 2
    # Pinned host memory only helps copies to a CUDA device
    pin_memory = torch.cuda.is_available()

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        sampler=sampler,
        num_workers=num_workers,
        pin_memory=pin_memory
    )

    test_dataloader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory
    )

    class_names = train_dataset.classes

    return train_dataloader, test_dataloader, class_names

def create_split_dataloaders(data_dir, batch_size=16, val_split=0.15, test_split=0.15, random_seed=42):
    """Create train/val/test dataloaders from a single image folder.

    ``create_dataset_splits`` partitions ``data_dir`` into stratified lists.
    Training uses augmentation and inverse-class-frequency weighted sampling;
    validation and test use the evaluation transforms in a fixed order. All
    three apply the image enhancement step.

    Args:
        data_dir: Folder with ``benign/`` and ``malignant/`` sub-folders.
        batch_size: Samples per batch for every loader.
        val_split: Fraction of each class used for validation.
        test_split: Fraction of each class used for testing.
        random_seed: Seed for the split shuffle.

    Returns:
        (train_dataloader, val_dataloader, test_dataloader, class_names)

    Raises:
        ValueError: If the training split is empty.
    """
    train_files, val_files, test_files = create_dataset_splits(
        data_dir, val_split=val_split, test_split=test_split, random_seed=random_seed
    )
    if not train_files:
        raise ValueError(f"Training split of {data_dir!r} is empty")

    train_dataset = BreastCancerDatasetFromList(
        file_list=train_files,
        transform=get_transforms(train=True),
        apply_preprocessing=True
    )
    val_dataset = BreastCancerDatasetFromList(
        file_list=val_files,
        transform=get_transforms(train=False),
        apply_preprocessing=True
    )
    test_dataset = BreastCancerDatasetFromList(
        file_list=test_files,
        transform=get_transforms(train=False),
        apply_preprocessing=True
    )

    # Labels are already in the file list; iterating the dataset would decode,
    # enhance and augment every training image just to read them.
    train_labels = [label for _, label in train_files]
    class_counts = np.bincount(train_labels, minlength=len(train_dataset.classes))
    total_samples = class_counts.sum()

    # Inverse-frequency class weights; an empty class gives inf but is never looked up
    with np.errstate(divide='ignore'):
        inverse_weights = total_samples / (class_counts * len(class_counts))
    sample_weights = [inverse_weights[label] for label in train_labels]

    sampler = WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(train_dataset),
        replacement=True
    )

    # Settings shared by all three loaders; pinned memory only helps with CUDA
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=os.cpu_count() or 2,
        pin_memory=torch.cuda.is_available(),
    )

    train_dataloader = DataLoader(train_dataset, sampler=sampler, **loader_kwargs)
    val_dataloader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
    test_dataloader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    class_names = train_dataset.classes

    return train_dataloader, val_dataloader, test_dataloader, class_names

# 12. Parameter Group Creation for Optimizer
def get_optimizer_grouped_parameters(model, weight_decay=0.05, base_lr=None):
    """Build optimizer parameter groups with per-component LR multipliers.

    Each parameter is assigned to a component using the first dotted part of
    its name (from ``model.named_parameters()``) that equals a known
    submodule name:

        vit                          -> lr_mult 0.05
        cnn                          -> lr_mult 0.1
        vit_attention, cnn_attention -> lr_mult 0.3
        fusion, classifier           -> lr_mult 1.0
        anything else                -> lr_mult 1.0

    Whole-part matching matters. A substring test such as ``"vit" in name``
    also matches ``vit_attention.*`` and would put the attention gates in the
    backbone group. Because the first matching part wins, wrapped models
    (e.g. names starting with ``module.``) are still grouped correctly.

    Within each component, 1-D parameters (biases and normalisation weights)
    get no weight decay and all others get ``weight_decay``. Matching on the
    name ``"LayerNorm.weight"`` would miss norms registered inside
    ``nn.Sequential``, whose weights are named like ``fusion.1.weight``.

    Args:
        model: Module whose parameters are grouped.
        weight_decay: Decay applied to multi-dimensional parameters.
        base_lr: If given, each group also gets ``"lr": base_lr * lr_mult`` so
            the multipliers take effect when the list is passed straight to an
            optimizer. If None (default), groups carry only ``"lr_mult"`` and
            use the optimizer's default LR.

    Returns:
        List of param-group dicts with keys "params", "weight_decay",
        "lr_mult" (plus "lr" when ``base_lr`` is given). Empty groups are
        omitted. Groups are ordered component by component, decayed group
        first.
    """
    component_lr_mults = [
        (("vit",), 0.05),
        (("cnn",), 0.1),
        (("vit_attention", "cnn_attention"), 0.3),
        (("fusion", "classifier"), 1.0),
    ]
    fallback_index = len(component_lr_mults)
    lr_mults = [mult for _, mult in component_lr_mults] + [1.0]

    def component_index(param_name):
        # First dotted part that names a known component decides the group
        for part in param_name.split("."):
            for idx, (names, _) in enumerate(component_lr_mults):
                if part in names:
                    return idx
        return fallback_index

    # (component index, use_decay) -> params, in named_parameters order.
    # named_parameters() already yields each shared parameter only once.
    buckets = {}
    for name, param in model.named_parameters():
        use_decay = param.ndim > 1
        buckets.setdefault((component_index(name), use_decay), []).append(param)

    optimizer_grouped_parameters = []
    for idx, lr_mult in enumerate(lr_mults):
        for use_decay in (True, False):
            params = buckets.get((idx, use_decay))
            if not params:
                continue
            group = {
                "params": params,
                "weight_decay": weight_decay if use_decay else 0.0,
                "lr_mult": lr_mult,
            }
            if base_lr is not None:
                group["lr"] = base_lr * lr_mult
            optimizer_grouped_parameters.append(group)

    return optimizer_grouped_parameters

# 13. Model Checkpointing
def save_checkpoint(model, optimizer, epoch, train_loss, val_loss, val_acc, filename):
    """Save model/optimizer state and training metrics to ``filename``.

    The checkpoint is a dict with keys 'epoch', 'model_state_dict',
    'optimizer_state_dict', 'train_loss', 'val_loss' and 'val_acc', written
    with ``torch.save``. The parent directory is created if it does not exist,
    so paths like ``checkpoints/best.pth`` work on a fresh runtime.
    """
    parent_dir = os.path.dirname(filename)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_loss,
        'val_loss': val_loss,
        'val_acc': val_acc,
    }
    torch.save(checkpoint, filename)
    print(f"Checkpoint saved: {filename}")

# 14. Visualization Functions
def plot_training_history(train_losses, train_accs, test_losses, test_accs):
    """Show loss and accuracy curves side by side.

    The left panel shows training vs. validation loss; the right panel shows
    training vs. validation accuracy in percent. The ``test_*`` sequences are
    drawn with the "Validation" label.
    """
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(15, 5))

    panels = [
        (loss_ax, train_losses, test_losses,
         'Training Loss', 'Validation Loss', 'Loss Over Epochs', 'Loss'),
        (acc_ax, train_accs, test_accs,
         'Training Accuracy', 'Validation Accuracy', 'Accuracy Over Epochs', 'Accuracy (%)'),
    ]
    for ax, train_series, eval_series, train_label, eval_label, title, y_label in panels:
        ax.plot(train_series, label=train_label)
        ax.plot(eval_series, label=eval_label)
        ax.set_title(title)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.legend()
        ax.grid(True)

    fig.tight_layout()
    plt.show()

def plot_prediction_samples(model, test_dataloader, device, class_names):
    """
    Plot sample predictions for each category:
    - True Positives (predicted = malignant, actual = malignant)
    - False Positives (predicted = malignant, actual = benign)
    - True Negatives (predicted = benign, actual = benign)
    - False Negatives (predicted = benign, actual = malignant)

    Batches are scanned until every category has been seen at least
    ``max_samples`` times, or the loader is exhausted. Only the first
    ``max_samples`` images per category are kept in memory and shown. The
    printed counts cover every image scanned. Images are min-max rescaled to
    [0, 1] for display.

    Args:
        model: Classifier whose output is reduced with ``max(1)`` to predictions.
        test_dataloader: Yields ``(images, labels, filenames)`` batches.
        device: Device the batches are moved to before the forward pass.
        class_names: Class names; must contain 'benign' and 'malignant'.
    """
    max_samples = 3
    model.eval()

    class_to_idx = {name: i for i, name in enumerate(class_names)}
    malignant = class_to_idx['malignant']
    benign = class_to_idx['benign']

    # (predicted index, true index) -> category key
    category_of = {
        (malignant, malignant): 'tp',
        (malignant, benign): 'fp',
        (benign, benign): 'tn',
        (benign, malignant): 'fn',
    }
    samples = {key: [] for key in ('tp', 'fp', 'tn', 'fn')}
    counts = {key: 0 for key in samples}

    def to_displayable(image):
        """Channels-first tensor -> channels-last array rescaled to [0, 1]."""
        img = image.permute(1, 2, 0).numpy()
        value_range = img.max() - img.min()
        # A constant image would otherwise divide by zero and become all-NaN
        if value_range == 0:
            return np.zeros_like(img)
        return (img - img.min()) / value_range

    with torch.no_grad():
        for images, labels, filenames in test_dataloader:
            images, labels = images.to(device), labels.to(device)
            _, predicted = model(images).max(1)
            images_cpu = images.cpu()

            for i in range(len(images)):
                key = category_of.get((predicted[i].item(), labels[i].item()))
                if key is None:
                    continue
                counts[key] += 1
                # Keep only what will be displayed: when one category stays
                # rare the scan can cover the whole loader.
                if len(samples[key]) < max_samples:
                    samples[key].append((to_displayable(images_cpu[i]), filenames[i]))

            # Stop once every category has been seen often enough
            if all(count >= max_samples for count in counts.values()):
                break

    def display_category(category_samples, title):
        if not category_samples:
            print(f"No samples found for {title}")
            return

        samples_to_show = min(max_samples, len(category_samples))
        fig, axes = plt.subplots(1, samples_to_show, figsize=(5*samples_to_show, 5))

        # plt.subplots returns a single Axes (not an array) for one column
        if samples_to_show == 1:
            axes = [axes]

        fig.suptitle(title, fontsize=16)

        for ax, (img, filename) in zip(axes, category_samples[:samples_to_show]):
            ax.imshow(img)
            ax.set_title(f"File: {filename}")
            ax.axis('off')

        plt.tight_layout()
        plt.show()

    print("\nSample Predictions:")

    display_category(samples['tp'], "True Positives (Predicted: Malignant, Actual: Malignant)")
    display_category(samples['fp'], "False Positives (Predicted: Malignant, Actual: Benign)")
    display_category(samples['tn'], "True Negatives (Predicted: Benign, Actual: Benign)")
    display_category(samples['fn'], "False Negatives (Predicted: Benign, Actual: Malignant)")

    print(f"\nSample counts:")
    print(f"True Positives (correctly identified malignant): {counts['tp']}")
    print(f"False Positives (benign incorrectly classified as malignant): {counts['fp']}")
    print(f"True Negatives (correctly identified benign): {counts['tn']}")
    print(f"False Negatives (malignant incorrectly classified as benign): {counts['fn']}")

def predict_single_image(model, image_path, device, class_names, apply_preprocessing=True):
    """
    Classify a single image file with a trained model.

    The image is loaded as RGB and, optionally, enhanced with OpenCV:
    CLAHE on the LAB lightness channel, a bilateral filter, then unsharp
    masking. It is then passed through ``get_transforms(train=False)`` and
    through ``model`` without gradient tracking.

    Args:
        model: Trained PyTorch classifier producing one logit per class.
        image_path: Path of the image file to classify.
        device: Device the input tensor is moved to (should match ``model``).
        class_names: Class labels, indexed like the model's output logits.
        apply_preprocessing: When True, run the OpenCV enhancement first. Any
            exception raised there is printed and the unenhanced image is used.

    Returns:
        dict with:
            'predicted_class': label of the highest-probability class,
            'confidence': that class's softmax probability, in percent,
            'probabilities': mapping label -> softmax probability, in percent.
    """

    def enhance(rgb_image):
        # Returns the enhanced image as a new PIL image, or the input unchanged
        # when its array is not H x W x 3.
        pixels = np.array(rgb_image)
        if not (pixels.ndim == 3 and pixels.shape[2] == 3):
            return rgb_image

        # Equalize contrast on the lightness channel only (LAB space)
        lab_pixels = cv2.cvtColor(pixels, cv2.COLOR_RGB2LAB)
        equalizer = cv2.createCLAHE(clipLimit=3.0, tileGridSize=(8, 8))
        lab_pixels[:, :, 0] = equalizer.apply(lab_pixels[:, :, 0])
        result = cv2.cvtColor(lab_pixels, cv2.COLOR_LAB2RGB)

        # Edge-preserving denoise, then unsharp mask (1.5 * img - 0.5 * blur)
        result = cv2.bilateralFilter(result, 9, 75, 75)
        blurred = cv2.GaussianBlur(result, (0, 0), 3.0)
        result = cv2.addWeighted(result, 1.5, blurred, -0.5, 0)
        return Image.fromarray(result)

    model.eval()

    pil_image = Image.open(image_path).convert('RGB')
    eval_transform = get_transforms(train=False)

    if apply_preprocessing:
        try:
            pil_image = enhance(pil_image)
        except Exception as e:
            # Best effort: fall back to the raw image rather than failing the prediction
            print(f"Error during preprocessing: {e}. Using original image.")

    batch = eval_transform(pil_image).unsqueeze(0).to(device)

    with torch.no_grad():
        probs = F.softmax(model(batch), dim=1)
        top_prob, top_idx = torch.max(probs, 1)

    per_class = {
        name: p.item() * 100
        for name, p in zip(class_names, probs[0])
    }
    return {
        'predicted_class': class_names[top_idx.item()],
        'confidence': top_prob.item() * 100,  # percentage
        'probabilities': per_class,
    }

# 16. Training Loop
def train_model(model, train_dataloader, val_dataloader, num_epochs=30, batch_size=16,
                learning_rate=0.0001, weight_decay=0.05, mixup_alpha=0.2,
                checkpoint_dir='checkpoints', use_focal_loss=True, device='cuda'):
    """
    Train ``model`` with a class-weighted loss, optional mixup, a warmup +
    cosine learning-rate schedule and periodic checkpointing.

    Args:
        model: Model to train; expected to already be on ``device``.
        train_dataloader: Yields ``(inputs, targets, filenames)`` training batches.
        val_dataloader: Yields ``(inputs, targets, filenames)`` validation batches.
        num_epochs: Number of training epochs.
        batch_size: Only written to the log; the real batch size is whatever
            the dataloaders were built with.
        learning_rate: Base learning rate, multiplied per parameter group by
            that group's ``lr_mult``.
        weight_decay: Weight decay handed to ``get_optimizer_grouped_parameters``.
        mixup_alpha: Alpha for mixup augmentation (``<= 0`` disables mixup).
        checkpoint_dir: Directory for best-model and every-5-epoch checkpoints.
        use_focal_loss: Use ``FocalLoss`` instead of ``nn.CrossEntropyLoss``.
        device: Device batches (and class weights) are moved to.

    Returns:
        Tuple ``(model, history)``. ``model`` is the model after the final
        epoch, which is not necessarily the best-validation checkpoint.
        ``history`` holds per-epoch lists under ``'train_loss'``,
        ``'train_acc'``, ``'val_loss'`` and ``'val_acc'``. With mixup
        enabled, ``train_acc`` is the mixup-weighted accuracy
        ``lam * acc(targets_a) + (1 - lam) * acc(targets_b)``.

    Raises:
        ValueError: If ``train_dataloader`` or ``val_dataloader`` is empty.
    """
    # Fail fast instead of hitting torch.cat([]) or a division by zero mid-training
    if len(train_dataloader) == 0:
        raise ValueError("train_dataloader is empty; nothing to train on")
    if len(val_dataloader) == 0:
        raise ValueError("val_dataloader is empty; cannot compute validation metrics")

    # Create logger
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    log_filename = f'training_log_{timestamp}.log'
    logger = setup_logger(log_filename)
    logger.info(f"Starting training with batch size: {batch_size}, lr: {learning_rate}")

    # Create checkpoint directory if it doesn't exist
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Per-epoch metric history
    train_losses = []
    train_accs = []
    val_losses = []
    val_accs = []

    # Inverse-frequency class weights, rescaled so they sum to the number of classes.
    # NOTE(review): this is a full extra pass over train_dataloader (loading and
    # augmenting every image) only to read labels; if the dataset exposes its
    # labels directly, counting those would be much cheaper.
    # NOTE(review): bincount's length is (largest label seen + 1); if the
    # highest-index class is absent from the training split, the weight vector
    # is shorter than the model output - confirm every class is present.
    train_labels = torch.cat([y for _, y, _ in train_dataloader])
    class_counts = torch.bincount(train_labels)
    # clamp(min=1): a class index with zero samples would otherwise get an
    # infinite weight and make the loss inf/NaN.
    class_weights = 1.0 / class_counts.float().clamp(min=1)
    class_weights = class_weights / class_weights.sum() * len(class_counts)
    class_weights = class_weights.to(device)

    # Set up loss function
    if use_focal_loss:
        criterion = FocalLoss(alpha=class_weights, gamma=2.0, device=device)
        logger.info("Using Focal Loss with class weights and gamma=2.0")
    else:
        criterion = nn.CrossEntropyLoss(weight=class_weights)
        logger.info("Using Cross Entropy Loss with class weights")

    # Optimizer with per-group learning-rate multipliers and weight decay
    param_groups = get_optimizer_grouped_parameters(model, weight_decay)
    optimizer = optim.AdamW([
        {'params': group['params'], 'lr': learning_rate * group['lr_mult'],
         'weight_decay': group['weight_decay']}
        for group in param_groups
    ])

    # Learning-rate schedule: 3 warmup epochs, then cosine decay to 5% of the base LR
    lr_scheduler = WarmupCosineScheduler(
        optimizer,
        warmup_epochs=3,
        total_epochs=num_epochs,
        min_lr_ratio=0.05
    )

    use_mixup = mixup_alpha > 0
    best_val_acc = 0.0

    for epoch in range(num_epochs):
        # Update learning rate
        lr_scheduler.step(epoch)
        current_lr = optimizer.param_groups[0]['lr']
        logger.info(f"Epoch {epoch+1}/{num_epochs}, Learning Rate: {current_lr:.6f}")

        # Training phase
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for batch_idx, (inputs, targets, _) in enumerate(train_dataloader):
            inputs, targets = inputs.to(device), targets.to(device)

            if use_mixup:
                inputs, targets_a, targets_b, lam = mixup_data(inputs, targets, mixup_alpha)

            optimizer.zero_grad()
            outputs = model(inputs)

            if use_mixup:
                loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam)
            else:
                loss = criterion(outputs, targets)

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            _, predicted = outputs.max(1)
            total += targets.size(0)
            if use_mixup:
                # A mixed input carries two labels; credit each prediction by the
                # mixing weight of the label it matches (lam vs. 1 - lam).
                mix = float(lam)
                correct += (mix * predicted.eq(targets_a).sum().item()
                            + (1.0 - mix) * predicted.eq(targets_b).sum().item())
            else:
                correct += predicted.eq(targets).sum().item()

            # Log progress every 10 batches and on the last batch
            if (batch_idx + 1) % 10 == 0 or (batch_idx + 1) == len(train_dataloader):
                logger.info(f"Epoch [{epoch+1}/{num_epochs}], Batch [{batch_idx+1}/{len(train_dataloader)}], "
                           f"Loss: {loss.item():.4f}")

        # Training metrics
        train_loss = running_loss / len(train_dataloader)
        train_losses.append(train_loss)

        train_acc = 100. * correct / total
        train_accs.append(train_acc)
        if use_mixup:
            logger.info(f"Training Loss: {train_loss:.4f}, Mixup-weighted Accuracy: {train_acc:.2f}%")
        else:
            logger.info(f"Training Loss: {train_loss:.4f}, Accuracy: {train_acc:.2f}%")

        # Validation phase (no mixup, no gradients)
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for inputs, targets, _ in val_dataloader:
                inputs, targets = inputs.to(device), targets.to(device)

                outputs = model(inputs)
                loss = criterion(outputs, targets)

                val_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()

        # Validation metrics
        val_loss = val_loss / len(val_dataloader)
        val_losses.append(val_loss)

        val_acc = 100. * correct / total
        val_accs.append(val_acc)

        logger.info(f"Validation Loss: {val_loss:.4f}, Accuracy: {val_acc:.2f}%")

        # Save checkpoint if this is the best model so far
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            checkpoint_path = os.path.join(checkpoint_dir, f'best_model_epoch_{epoch+1}.pth')
            save_checkpoint(model, optimizer, epoch, train_loss, val_loss, val_acc, checkpoint_path)
            logger.info(f"New best model saved with validation accuracy: {val_acc:.2f}%")

        # Save regular checkpoint every 5 epochs
        if (epoch + 1) % 5 == 0:
            checkpoint_path = os.path.join(checkpoint_dir, f'model_epoch_{epoch+1}.pth')
            save_checkpoint(model, optimizer, epoch, train_loss, val_loss, val_acc, checkpoint_path)

    logger.info("Training completed.")

    history = {
        'train_loss': train_losses,
        'train_acc': train_accs,
        'val_loss': val_losses,
        'val_acc': val_accs
    }

    return model, history

# 17. Evaluation Function
def evaluate_model(model, test_dataloader, device, class_names, visualize=True):
    """
    Evaluate ``model`` on a test set and report accuracy, per-class metrics
    and, for binary problems, sensitivity/specificity.

    Args:
        model: Trained classifier producing one logit per class.
        test_dataloader: Yields ``(inputs, targets, filenames)`` batches.
        device: Device batches are moved to.
        class_names: Class labels, indexed like the model outputs / targets.
        visualize: If True, plot the confusion matrix and sample predictions.

    Returns:
        Dict with:
            'accuracy': overall accuracy in percent,
            'confusion_matrix': rows = true class, columns = predicted class,
                always ``len(class_names)`` x ``len(class_names)``,
            'classification_report': scikit-learn report as a dict.
        Binary problems also get 'sensitivity' and 'specificity'. These use
        ``'malignant'`` as the positive class when that name is in
        ``class_names``, and index 1 otherwise.

    Raises:
        ValueError: If ``test_dataloader`` yields no samples.
    """
    model.eval()

    all_predictions = []
    all_targets = []
    correct = 0
    total = 0

    # Collect predictions
    with torch.no_grad():
        for inputs, targets, _ in test_dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)

            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            all_predictions.extend(predicted.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())

    if total == 0:
        raise ValueError("test_dataloader yielded no samples; cannot evaluate model")

    accuracy = 100. * correct / total

    # Pin the label set so the matrix/report always cover every class, even if a
    # class never occurs in the targets or predictions (otherwise a 1x1 matrix
    # breaks the binary unpacking below and target_names no longer lines up).
    label_indices = list(range(len(class_names)))
    cm = confusion_matrix(all_targets, all_predictions, labels=label_indices)

    report = classification_report(all_targets, all_predictions,
                                   labels=label_indices,
                                   target_names=class_names,
                                   output_dict=True,
                                   zero_division=0)

    if len(class_names) == 2:
        # Use the malignant class as "positive" by name when available; with
        # alphabetically sorted ['benign', 'malignant'] this is index 1.
        positive_idx = class_names.index('malignant') if 'malignant' in class_names else 1
        tp = cm[positive_idx, positive_idx]
        fn = cm[positive_idx, :].sum() - tp
        fp = cm[:, positive_idx].sum() - tp
        tn = cm.sum() - tp - fn - fp
        sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
    else:
        # Multi-class: macro-averaged recall; specificity would need one-vs-rest
        sensitivity = np.mean([report[class_name]['recall'] for class_name in class_names])
        specificity = 0

    print(f"\nModel Evaluation on Test Data:")
    print(f"Overall Accuracy: {accuracy:.2f}%")

    if len(class_names) == 2:
        print(f"Sensitivity (True Positive Rate): {sensitivity:.4f}")
        print(f"Specificity (True Negative Rate): {specificity:.4f}")

    print("\nClassification Report:")
    for class_name in class_names:
        print(f"  {class_name}:")
        print(f"    Precision: {report[class_name]['precision']:.4f}")
        print(f"    Recall: {report[class_name]['recall']:.4f}")
        print(f"    F1-Score: {report[class_name]['f1-score']:.4f}")

    if visualize:
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=class_names, yticklabels=class_names)
        plt.title('Confusion Matrix')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.tight_layout()
        plt.show()

        plot_prediction_samples(model, test_dataloader, device, class_names)

    metrics = {
        'accuracy': accuracy,
        'confusion_matrix': cm,
        'classification_report': report,
    }

    if len(class_names) == 2:
        metrics['sensitivity'] = sensitivity
        metrics['specificity'] = specificity

    return metrics

# 18. Grad-CAM Visualization for Explainability
class GradCAM:
    """
    Grad-CAM (gradient-weighted class activation mapping) for one layer of
    ``model``.

    A forward hook on ``target_layer`` records its output as the activations.
    When that output takes part in autograd, the hook also attaches a tensor
    hook that records the gradient flowing back into it. This replaces the
    deprecated ``Module.register_backward_hook``.

    Hooks are detached at the end of every ``generate_heatmap`` call and
    re-attached on the next call. Creating many ``GradCAM`` objects for the
    same layer therefore does not pile up stale hooks.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        # RemovableHandle objects for hooks currently attached to target_layer
        self._hook_handles = []

        self.register_hooks()

    def register_hooks(self):
        """Attach the activation/gradient capturing hook (no-op if already attached)."""
        if self._hook_handles:
            return

        def forward_hook(module, inputs, output):
            self.activations = output
            # Forwards run under torch.no_grad() produce outputs without grad;
            # register_hook would raise on them, so only hook grad-tracking outputs.
            if output.requires_grad:
                output.register_hook(self._save_gradient)

        self._hook_handles.append(self.target_layer.register_forward_hook(forward_hook))

    def _save_gradient(self, grad):
        # Called by autograd with d(output)/d(target_layer output)
        self.gradients = grad

    def remove_hooks(self):
        """Detach every hook this instance attached to ``target_layer``."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def generate_heatmap(self, input_image, class_idx=None):
        """
        Generate a Grad-CAM heatmap.

        Args:
            input_image: Input tensor of shape (1, C, H, W).
            class_idx: Target class index. If None, the predicted class is used.

        Returns:
            numpy array of shape (H, W) min-max normalized to [0, 1).

        Raises:
            RuntimeError: If no activation/gradient was captured at
                ``target_layer`` (e.g. the layer does not feed the output).
        """
        self.register_hooks()
        try:
            self.model.eval()
            # Clear stale captures so a missed hook cannot silently reuse old data
            self.gradients = None
            self.activations = None

            output = self.model(input_image)

            if class_idx is None:
                _, predicted_idx = torch.max(output, 1)
                class_idx = predicted_idx[0].item()

            self.model.zero_grad()

            # Backpropagate a one-hot signal for the target class of sample 0
            target = torch.zeros_like(output)
            target[0, class_idx] = 1
            output.backward(gradient=target)

            if self.gradients is None or self.activations is None:
                raise RuntimeError(
                    "Grad-CAM captured no gradient at target_layer; "
                    "make sure the layer contributes to the model output"
                )

            gradients = self.gradients.detach().cpu()
            activations = self.activations.detach().cpu()
        finally:
            self.remove_hooks()

        # Channel weights = spatially averaged gradients (assumes 4-D N, C, H, W maps)
        weights = torch.mean(gradients, dim=[2, 3], keepdim=True)

        # Weighted combination of activation maps, keeping positive evidence only
        heatmap = F.relu(torch.sum(weights * activations, dim=1, keepdim=True))

        # Upsample to the input resolution
        heatmap = F.interpolate(heatmap, size=(input_image.size(2), input_image.size(3)),
                                mode='bilinear', align_corners=False)

        # Min-max normalization (epsilon avoids division by zero for flat maps)
        heatmap_min = torch.min(heatmap)
        heatmap_max = torch.max(heatmap)
        normalized_heatmap = (heatmap - heatmap_min) / (heatmap_max - heatmap_min + 1e-10)

        return normalized_heatmap[0, 0].numpy()

def visualize_heatmap(model, image_path, device, class_names, target_layer=None):
    """
    Visualize a Grad-CAM heatmap for a single image.

    Shows three panels: the (denormalized) input, the colored heatmap, and
    the overlay annotated with the prediction.

    Args:
        model: Trained model.
        image_path: Path to image file.
        device: Computation device (CPU/CUDA).
        class_names: List of class names, indexed like the model outputs.
        target_layer: Layer for Grad-CAM. If None, the last ``nn.Conv2d`` of
            ``model.cnn`` is used; when ``model`` has no ``cnn`` attribute,
            the last ``nn.Conv2d`` anywhere in ``model`` is used.

    Returns:
        Dict with 'predicted_class', 'confidence' (percent) and
        'probabilities' (label -> percent), or None if no convolutional
        target layer could be found.
    """
    # NOTE(review): unlike predict_single_image, no CLAHE/filter preprocessing
    # is applied here, so the prediction may differ from the one made on the
    # preprocessed image the model was trained on - confirm this is intended.
    image = Image.open(image_path).convert('RGB')
    transform = get_transforms(train=False)
    image_tensor = transform(image).unsqueeze(0).to(device)

    if target_layer is None:
        # Prefer the CNN branch of the ensemble; fall back to the whole model
        # when it does not expose a ``cnn`` attribute.
        search_root = getattr(model, 'cnn', model)
        conv_layers = [module for module in search_root.modules()
                       if isinstance(module, nn.Conv2d)]
        if not conv_layers:
            print("Could not automatically find a suitable target layer. Please specify one.")
            return None
        target_layer = conv_layers[-1]

    grad_cam = GradCAM(model, target_layer)

    # Prediction without gradient tracking
    model.eval()
    with torch.no_grad():
        outputs = model(image_tensor)
        probabilities = F.softmax(outputs, dim=1)
        confidence, predicted_class = torch.max(probabilities, 1)

    # Heatmap for the predicted class
    heatmap = grad_cam.generate_heatmap(image_tensor, predicted_class.item())

    # Undo normalization for display.
    # NOTE(review): assumes get_transforms normalizes with ImageNet mean/std - confirm.
    img = image_tensor[0].cpu().permute(1, 2, 0).numpy()
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img = std * img + mean
    img = np.clip(img, 0, 1)

    # Colorize heatmap (drop alpha channel) and match the image size
    heatmap_colored = plt.cm.jet(heatmap)[:, :, :3]
    heatmap_resized = cv2.resize(heatmap_colored, (img.shape[1], img.shape[0]))

    overlaid = img * 0.7 + heatmap_resized * 0.3

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    axes[0].imshow(img)
    axes[0].set_title("Original Image")
    axes[0].axis('off')

    axes[1].imshow(heatmap_colored)
    axes[1].set_title("Grad-CAM Heatmap")
    axes[1].axis('off')

    axes[2].imshow(overlaid)
    axes[2].set_title(f"Prediction: {class_names[predicted_class.item()]} ({confidence.item()*100:.2f}%)")
    axes[2].axis('off')

    plt.tight_layout()
    plt.show()

    prediction_results = {
        'predicted_class': class_names[predicted_class.item()],
        'confidence': confidence.item() * 100,  # Convert to percentage
        'probabilities': {
            class_name: prob.item() * 100
            for class_name, prob in zip(class_names, probabilities[0])
        }
    }

    return prediction_results

# 19. Deployment Helper Functions
def load_model_from_checkpoint(checkpoint_path, device, num_classes=2):
    """
    Rebuild a ``ViTCNNEnsemble`` and load its weights from a checkpoint.

    Args:
        checkpoint_path: Path to a checkpoint dict containing at least
            ``'model_state_dict'``. ``'epoch'`` and ``'val_acc'`` are
            reported when present.
        device: Device the checkpoint tensors are mapped to and the model is
            placed on.
        num_classes: Number of output classes the checkpoint was trained with
            (default 2, as before).

    Returns:
        The loaded model on ``device``, switched to eval mode for inference.
    """
    model = ViTCNNEnsemble(num_classes=num_classes)
    model = model.to(device)

    # SECURITY: torch.load unpickles the file - only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])

    # Deployment helper: return an inference-ready model (affects dropout /
    # batch-norm style layers, if the architecture has any).
    model.eval()

    print(f"Model loaded from checkpoint: {checkpoint_path}")
    if 'epoch' in checkpoint:
        print(f"Checkpoint was saved at epoch {checkpoint['epoch']+1}")
    if 'val_acc' in checkpoint:
        print(f"Validation accuracy: {checkpoint['val_acc']:.2f}%")

    return model

def prepare_model_for_export(model, input_shape=(1, 3, 224, 224), device=None):
    """
    Put ``model`` in eval mode and trace it with TorchScript.

    Args:
        model: PyTorch model to trace.
        input_shape: Shape of the random example input used for tracing.
        device: Device for the example input. By default, the device of the
            model's first parameter is used (CPU for parameter-free models).
            This way a GPU-resident model is traced with a GPU input instead
            of failing on a device mismatch.

    Returns:
        Tuple ``(traced_model, dummy_input)``. ``dummy_input`` can be reused
        as the example input for ONNX export.
    """
    model.eval()

    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    dummy_input = torch.randn(input_shape, device=device, requires_grad=True)

    traced_model = torch.jit.trace(model, dummy_input)

    return traced_model, dummy_input

def export_to_onnx(model, dummy_input, onnx_path, opset_version=12):
    """
    Export ``model`` to an ONNX file with a dynamic batch dimension.

    Args:
        model: PyTorch model (or traced TorchScript module) to export.
        dummy_input: Example input tensor; must be on the same device as ``model``.
        onnx_path: Output path of the ONNX file. Missing parent directories
            are created.
        opset_version: ONNX opset to target (default 12, as before).

    Returns:
        ``onnx_path``.
    """
    out_dir = os.path.dirname(onnx_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    torch.onnx.export(
        model,                      # model being run
        dummy_input,                # example input used to record the graph
        onnx_path,                  # where to save the model
        export_params=True,         # store the trained parameter weights inside the model file
        opset_version=opset_version,
        do_constant_folding=True,   # optimization
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={
            'input': {0: 'batch_size'},    # allow any batch size at inference
            'output': {0: 'batch_size'}
        }
    )

    print(f"Model exported to ONNX format: {onnx_path}")
    return onnx_path

# 20. Main Execution Function
def main():
    """
    Run the full pipeline end to end.

    The steps are: build datasets and dataloaders, train, plot the training
    history, evaluate on the test split, and export the trained model to
    ONNX.

    The device and class names are derived inside this function. The
    pipeline therefore also works on a fresh kernel or as a standalone
    script, without relying on variables left over from earlier notebook
    cells.

    Raises:
        ValueError: If fewer than two class folders exist in the training split.
    """
    # Set random seeds for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)

    device = "cuda" if torch.cuda.is_available() else "cpu"

    train_root = 'AI_demos/custom_dataset/train'
    val_root = 'AI_demos/custom_dataset/val'
    test_root = 'AI_demos/custom_dataset/test'
    batch_size = 16
    num_workers = 4

    # Class names = sorted sub-directory names of the training split. This is
    # the rule torchvision's ImageFolder uses, so it matches what the earlier
    # create_dataloaders cell produced for this same directory.
    # NOTE(review): assumes BreastCancerDataset assigns label indices in this
    # same sorted order - confirm.
    class_names = sorted(entry.name for entry in os.scandir(train_root) if entry.is_dir())
    if len(class_names) < 2:
        raise ValueError(f"Expected at least two class folders in {train_root!r}, found {class_names}")

    # Datasets (advanced preprocessing applied to every split)
    train_dataset = BreastCancerDataset(
        data_dir=train_root,
        transform=get_transforms(train=True),
        apply_preprocessing=True
    )
    val_dataset = BreastCancerDataset(
        data_dir=val_root,
        transform=get_transforms(train=False),
        apply_preprocessing=True
    )
    test_dataset = BreastCancerDataset(
        data_dir=test_root,
        transform=get_transforms(train=False),
        apply_preprocessing=True
    )

    def make_loader(dataset, shuffle):
        # All splits share batch size and worker count; only shuffling differs
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                          num_workers=num_workers)

    train_dataloader = make_loader(train_dataset, shuffle=True)
    val_dataloader = make_loader(val_dataset, shuffle=False)
    test_dataloader = make_loader(test_dataset, shuffle=False)

    print("Initializing model...")
    model = ViTCNNEnsemble(num_classes=len(class_names))
    model = model.to(device)

    print("Training model...")
    trained_model, history = train_model(
        model=model,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        num_epochs=30,
        batch_size=batch_size,
        learning_rate=0.0001,
        weight_decay=0.05,
        mixup_alpha=0.2,
        checkpoint_dir='checkpoints',
        device=device
    )

    plot_training_history(
        train_losses=history['train_loss'],
        train_accs=history['train_acc'],
        test_losses=history['val_loss'],
        test_accs=history['val_acc']
    )

    print("Evaluating model...")
    metrics = evaluate_model(
        model=trained_model,
        test_dataloader=test_dataloader,
        device=device,
        class_names=class_names,
        visualize=True
    )

    print("\nFinal Evaluation Results:")
    print(f"Accuracy: {metrics['accuracy']:.2f}%")

    if len(class_names) == 2:
        print(f"Sensitivity: {metrics['sensitivity']:.4f}")
        print(f"Specificity: {metrics['specificity']:.4f}")

    print("Exporting model...")
    # Trace/export on CPU so the model and the CPU example input share a device
    export_model = trained_model.to("cpu")
    traced_model, dummy_input = prepare_model_for_export(export_model)
    export_to_onnx(traced_model, dummy_input, 'breast_cancer_classifier.onnx')

    print("Pipeline completed successfully!")

# Entry point: run the whole pipeline defined in main().
# NOTE(review): in a notebook kernel __name__ is normally "__main__" too, so
# executing this cell starts the full training run immediately.
if __name__ == "__main__":
    main()