# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import time
import os
import math
from PIL import Image
import gc
import warnings
from tqdm import tqdm
from transformers import BeitModel, BeitConfig
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc
from sklearn.model_selection import KFold, StratifiedKFold
import copy
import pandas as pd
from torch.utils.data import Subset, ConcatDataset

# Suppress every Python warning process-wide. Note that this also hides
# deprecation and numerical warnings raised by the libraries imported above.
warnings.filterwarnings('ignore')

# Set seeds for reproducibility
def set_seed(seed=42):
    """Seed every random number generator this script can draw from.

    Args:
        seed: Integer seed applied to Python's ``random`` module, NumPy and
            PyTorch (CPU generator and all CUDA devices).
    """
    # Local import: the file-level import block does not provide ``random``.
    import random

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Ask cuDNN for deterministic kernels and disable its autotuner, which can
    # otherwise select different algorithms from one run to the next.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# Seed once at import time so everything below starts from a known state.
set_seed()

# Improved configuration with k-fold settings
# Global, mutable settings dictionary read by the data loaders, model classes
# and training loop in this file. 'num_classes' is overwritten at load time
# with the number of class folders actually found in the training directory.
CONFIG = {
    # Dataset root (hard-coded Windows path). The loaders look for a 'train'
    # subfolder and optional 'val' / 'test' subfolders underneath it.
    'data_dir': "D:\\Work\\dont_plot_images\\augmented_data",
    'batch_size': 8,  # Batch size of the training DataLoader
    'eval_batch_size': 16,  # Batch size of the val/test DataLoaders
    'num_epochs': 50,  # Reduced since we'll be training multiple times
    'learning_rate': 0.00002,  # Reduced learning rate for more stable fine-tuning
    'image_size': 224,  # Images are resized to image_size x image_size
    'beit_model': "microsoft/beit-base-patch16-224-pt22k-ft22k",  # Using the finetuned model
    'use_mixed_precision': True,  # Enables GradScaler/autocast in the training loop
    'gradient_accumulation_steps': 2,  # Effective batch size of 16
    'weight_decay': 0.04,  # Higher weight decay for better regularization
    'freeze_backbone_epochs': 0,  # Don't freeze backbone to improve recall
    'optimization_metric': 'f1_macro',  # Optimize for macro F1 score across all classes
    'focal_loss_gamma': 3.0,  # Increased focal loss gamma to focus more on hard examples
    # Per-class weights for classes with lower F1 scores. The data loaders also
    # use them as per-sample draw weights for the WeightedRandomSampler, so the
    # list needs one entry per class index.
    'class_weights': [1.0, 0.8, 1.0, 2.0, 1.5, 2.0],
    'augmentation_strength': 'strong',  # Stronger augmentations to improve generalization
    'scheduler': 'cosine_warmup',  # Custom cosine scheduler with warmup
    'use_checkpoint': False,  # Use checkpoint for model saving if available
    'patience': 7,  # Early-stopping patience; reduced for faster training in k-fold
    'warmup_epochs': 3,  # Reduced for faster training
    'num_classes': 6,  # Number of classes (replaced by the detected count when data is loaded)
    'dropout_rate': 0.3,  # Increased dropout for better regularization
    'use_mixup': True,  # Enable mixup augmentation
    'mixup_alpha': 0.2,  # Mixup interpolation factor
    'use_ensemble': True,  # Use model ensemble for final predictions
    'ensemble_epochs': [15, 20, 25, 30],  # Adjusted for shorter training
    'use_class_balanced_loss': True,  # Use class-balanced loss
    # K-fold specific settings
    'k_folds': 5,          # Number of folds
    'stratified': True,    # Whether to use stratified folds (recommended for imbalanced data)
    'save_all_folds': True, # Whether to save models from all folds or just the best
    'run_kfold': True      # Whether to run k-fold or standard training
}

# Device configuration
# Release cached GPU blocks and unreachable Python objects left by earlier runs.
torch.cuda.empty_cache()
gc.collect()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Memory optimization for CUDA
if torch.cuda.is_available():
    # NOTE(review): allocator settings are normally read when the CUDA caching
    # allocator initialises; confirm this assignment happens early enough in
    # the process for the option to take effect.
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
    try:
        torch.cuda.set_per_process_memory_fraction(0.85)
    except Exception as e:
        # Capping memory is best-effort: keep going, but report the failure
        # instead of swallowing it silently (a bare ``except`` would also
        # have hidden KeyboardInterrupt/SystemExit).
        print(f"Warning: could not cap per-process GPU memory: {e}")
    print(f"Total GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

# Enhanced data transforms with stronger augmentations
def get_data_transforms(augmentation_strength='medium'):
    """Build the torchvision pipelines for the train, val and test splits.

    Args:
        augmentation_strength: ``'strong'`` or ``'medium'`` select the
            matching augmentation set; any other value selects the light set.

    Returns:
        Dict with keys ``'train'``, ``'val'`` and ``'test'``. ``'val'`` and
        ``'test'`` share one pipeline object (resize, tensor conversion,
        normalisation) with no random augmentation.
    """
    def resize():
        # Every pipeline starts by resizing to the configured square size.
        return transforms.Resize((CONFIG['image_size'], CONFIG['image_size']))

    def to_normalized_tensor():
        # Tensor conversion followed by per-channel normalisation.
        return [
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]

    if augmentation_strength == 'strong':
        # Stronger augmentations for better generalization
        image_augmentations = [
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.2),  # X-rays might be flipped
            transforms.RandomRotation(20),
            transforms.RandomAffine(
                degrees=20, translate=(0.15, 0.15),
                scale=(0.85, 1.15), shear=10),
            transforms.ColorJitter(
                brightness=0.2, contrast=0.2, saturation=0.1, hue=0.1),
            transforms.RandomPerspective(distortion_scale=0.2, p=0.3),
            transforms.RandomAutocontrast(p=0.3),
            transforms.RandomEqualize(p=0.2),
            transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0)),
        ]
        tensor_augmentations = [
            transforms.RandomErasing(p=0.2, scale=(0.02, 0.15)),
        ]
    elif augmentation_strength == 'medium':
        # Medium augmentations
        image_augmentations = [
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(15),
            transforms.ColorJitter(brightness=0.1, contrast=0.1),
            transforms.RandomAffine(degrees=15, translate=(0.1, 0.1), scale=(0.9, 1.1)),
        ]
        tensor_augmentations = [
            transforms.RandomErasing(p=0.1),
        ]
    else:
        # Light augmentations
        image_augmentations = [
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
        ]
        tensor_augmentations = []

    # Image-space augmentations run before tensor conversion; RandomErasing
    # (when present) runs last, on the normalised tensor.
    train_pipeline = transforms.Compose(
        [resize(), *image_augmentations, *to_normalized_tensor(), *tensor_augmentations]
    )

    # Evaluation pipeline: deterministic preprocessing only.
    eval_pipeline = transforms.Compose([resize(), *to_normalized_tensor()])

    return {
        'train': train_pipeline,
        'val': eval_pipeline,
        'test': eval_pipeline,
    }

# Get data transforms
data_transforms = get_data_transforms(CONFIG.get('augmentation_strength', 'strong'))

# Create class-specific transforms for problematic classes
def get_transform_for_class(class_idx, phase='train'):
    """Return the transform pipeline to apply to a sample of a given class.

    Args:
        class_idx: Integer class label of the sample.
        phase: Split name. Anything other than ``'train'`` returns the
            module-level pipeline registered for that split.

    Returns:
        A heavier, freshly built augmentation pipeline for class indices 3
        and 5 during training; otherwise the shared pipeline from
        ``data_transforms``.
    """
    # Evaluation phases never get class-specific augmentation.
    if phase != 'train':
        return data_transforms[phase]

    # Only classes 3 and 5 ('s3' and 'un' per the original author's note)
    # receive the heavier augmentation; everything else uses the standard one.
    if class_idx not in (3, 5):
        return data_transforms['train']

    side = CONFIG['image_size']
    heavy_steps = [
        transforms.Resize((side, side)),
        transforms.RandomHorizontalFlip(p=0.7),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomRotation(30),
        transforms.RandomAffine(
            degrees=30, translate=(0.2, 0.2),
            scale=(0.8, 1.2), shear=15),
        transforms.ColorJitter(
            brightness=0.3, contrast=0.3, saturation=0.2, hue=0.2),
        transforms.RandomAutocontrast(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        transforms.RandomErasing(p=0.3, scale=(0.02, 0.2)),
    ]
    return transforms.Compose(heavy_steps)

# Custom dataset class for class-specific augmentation
class ClassSpecificImageFolder(datasets.ImageFolder):
    """ImageFolder whose image transform is chosen from the sample's class.

    The ``transform`` passed to the constructor (if any) is not used: every
    sample is processed with ``get_transform_for_class(target, 'train')``,
    i.e. always with training-phase augmentation.
    """

    def __getitem__(self, index):
        """Load one sample and apply its class-specific training transform.

        Args:
            index: Position in ``self.samples``.

        Returns:
            Tuple ``(sample, target)`` with the transformed image and the
            class label (after ``target_transform`` when one is configured).
        """
        path, target = self.samples[index]
        sample = self.loader(path)

        # Pick the augmentation from the original class index, before any
        # target_transform can remap the label.
        transform = get_transform_for_class(target, 'train')
        sample = transform(sample)

        # Keep the base-class contract: a configured target_transform was
        # silently dropped by the previous override.
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

# Mixup augmentation for training
def mixup_data(x, y, alpha=0.2):
    """Apply mixup augmentation to a batch.

    Each sample is blended with another sample of the same batch chosen by a
    random permutation.

    Args:
        x: Input batch tensor; the first dimension is the batch dimension.
        y: Labels, indexable along the same batch dimension as ``x``.
        alpha: Beta-distribution parameter. With ``alpha <= 0`` no mixing is
            done (``lam`` is 1).

    Returns:
        Tuple ``(mixed_x, y_a, y_b, lam)``: the blended inputs, the original
        labels, the labels of the blended-in partners and the mixing factor.
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1

    batch_size = x.size(0)
    # Place the permutation on the batch's own device rather than on a global
    # one, so the function works wherever the inputs live.
    index = torch.randperm(batch_size).to(x.device)

    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    """Combine the loss against both mixup targets.

    Args:
        criterion: Callable ``criterion(pred, target)`` returning a loss.
        pred: Model predictions for the mixed batch.
        y_a: Targets of the original samples.
        y_b: Targets of the blended-in samples.
        lam: Mixing factor that weights the ``y_a`` loss; ``1 - lam`` weights
            the ``y_b`` loss.

    Returns:
        ``lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)``.
    """
    loss_a = criterion(pred, y_a)
    loss_b = criterion(pred, y_b)
    return lam * loss_a + (1 - lam) * loss_b

# Focal Loss implementation for handling class imbalance
class FocalLoss(nn.Module):
    """Focal loss for multi-class classification with optional class weights.

    Per-sample loss: ``alpha[target] * (1 - p_t) ** gamma * (-log p_t)``,
    where ``p_t`` is the softmax probability of the true class.

    Args:
        gamma: Focusing exponent; ``0`` reduces to (weighted) cross entropy.
        alpha: Optional per-class weights (tensor or sequence of floats).
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            unreduced per-sample losses.
    """

    def __init__(self, gamma=2.0, alpha=None, reduction='mean'):
        super().__init__()
        self.gamma = gamma
        self.reduction = reduction

        if alpha is not None and not torch.is_tensor(alpha):
            alpha = torch.as_tensor(alpha, dtype=torch.float32)
        # Register as a buffer so .to()/.cuda() moves the weights with the
        # module; non-persistent so state_dict() contents stay unchanged.
        self.register_buffer('alpha', alpha, persistent=False)

    def forward(self, inputs, targets):
        """Compute the focal loss.

        Args:
            inputs: Raw class logits, as accepted by ``F.cross_entropy``.
            targets: Class-index targets, as accepted by ``F.cross_entropy``.

        Returns:
            Scalar loss for ``'mean'``/``'sum'``; otherwise per-sample losses.
        """
        # Unweighted CE gives the true-class probability: p_t = exp(-CE).
        # Deriving p_t from the alpha-weighted CE would distort the focal
        # factor whenever a class weight differs from 1.
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        modulating_factor = (1 - pt) ** self.gamma

        if self.alpha is not None:
            alpha = self.alpha.to(device=inputs.device, dtype=ce_loss.dtype)
            # Class weights scale the loss term only, not the probability.
            ce_loss = F.cross_entropy(
                inputs, targets,
                weight=alpha,
                reduction='none'
            )

        focal_loss = modulating_factor * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

# Class-Balanced Loss for better handling of imbalanced classes
class ClassBalancedLoss(nn.Module):
    """Focal loss re-weighted by the effective number of samples per class.

    Class weights are ``(1 - beta) / (1 - beta ** n_c)`` for a class with
    ``n_c`` samples, rescaled so that they sum to ``num_classes``.

    Args:
        samples_per_class: Sequence with the sample count of each class.
        num_classes: Number of classes; sets the weight normalisation.
        beta: Base of the effective-number formula.
        gamma: Focal-loss focusing exponent.
    """

    def __init__(self, samples_per_class, num_classes=6, beta=0.9999, gamma=2.0):
        super().__init__()
        self.beta = beta
        self.gamma = gamma
        self.num_classes = num_classes

        counts = np.asarray(samples_per_class, dtype=np.float64)
        effective_num = 1.0 - np.power(beta, counts)

        # A class without samples has effective number 0. Give it weight 0
        # instead of dividing by a tiny epsilon: that produced a huge weight
        # which, after normalisation, collapsed every other class weight.
        present = effective_num > 0
        weights = np.zeros_like(effective_num)
        weights[present] = (1.0 - beta) / effective_num[present]
        total = weights.sum()
        if total > 0:
            weights = weights / total * num_classes

        # Buffer (non-persistent, so state_dict is unchanged) rather than a
        # tensor pinned to a global device; it follows the module on .to().
        self.register_buffer('weights', torch.tensor(weights).float(), persistent=False)

    def forward(self, inputs, targets):
        """Compute the class-balanced focal loss.

        Args:
            inputs: Raw class logits, as accepted by ``F.cross_entropy``.
            targets: Class-index targets.

        Returns:
            Mean over the batch of ``weights[target] * focal_loss``.
        """
        # p_t, the softmax probability of the true class, equals exp(-CE).
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = (1 - pt) ** self.gamma * ce_loss

        # Follow the inputs if the module itself was never moved (one-time copy).
        if self.weights.device != inputs.device:
            self.weights = self.weights.to(inputs.device)

        # Apply class balanced weights
        return torch.mean(self.weights[targets] * focal_loss)

# Custom learning rate scheduler with warmup
def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, min_lr_ratio=0.0):
    """Create a LambdaLR with linear warmup followed by cosine decay.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        num_warmup_steps: Steps over which the multiplier rises linearly
            from 0 to 1.
        num_training_steps: Total scheduler steps; the cosine decay ends here.
        min_lr_ratio: Lower bound of the multiplier during the decay phase.

    Returns:
        ``torch.optim.lr_scheduler.LambdaLR`` applying the multiplier.
    """
    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))

    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / warmup_span
        # Clamp progress at 1: without it the cosine rises again when the
        # scheduler is stepped past num_training_steps.
        progress = min(1.0, float(current_step - num_warmup_steps) / decay_span)
        return max(min_lr_ratio, 0.5 * (1.0 + math.cos(math.pi * progress)))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# Enhanced BEiT model with attention mechanism
class EnhancedBeitClassifier(nn.Module):
    """BEiT backbone followed by an attention layer, LayerNorm and an MLP head.

    Args:
        num_classes: Size of the output logit vector.
        dropout_rate: Dropout probability used throughout the classifier head.
        freeze_backbone: When true, backbone parameters are frozen and the
            backbone forward pass runs without gradient tracking.
    """

    def __init__(self, num_classes=6, dropout_rate=0.3, freeze_backbone=False):
        super().__init__()

        # Pre-trained backbone; a loading failure is reported and re-raised.
        try:
            model_name = CONFIG['beit_model']
            self.beit = BeitModel.from_pretrained(model_name)
            print(f"Successfully loaded BEiT model: {model_name}")
        except Exception as e:
            print(f"Error loading BEiT model: {e}")
            raise

        # Width of the backbone's hidden states
        hidden = self.beit.config.hidden_size
        self.feature_dim = hidden
        print(f"BEiT feature dimension: {self.feature_dim}")

        # Attention layer applied to the CLS embedding (see forward)
        self.attention = nn.MultiheadAttention(embed_dim=hidden, num_heads=8, dropout=0.1)

        # Normalisation between the attention output and the head
        self.norm = nn.LayerNorm(hidden)

        # Three-layer MLP head: hidden -> 768 -> 384 -> num_classes
        head_layers = [
            nn.Dropout(dropout_rate),
            nn.Linear(hidden, 768),
            nn.GELU(),
            nn.LayerNorm(768),
            nn.Dropout(dropout_rate),
            nn.Linear(768, 384),
            nn.GELU(),
            nn.LayerNorm(384),
            nn.Dropout(dropout_rate),
            nn.Linear(384, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

        self._init_classifier()

        # Flag is also consulted in forward() to skip gradient tracking
        self.freeze_backbone = freeze_backbone
        if freeze_backbone:
            self._freeze_backbone()

    def _init_classifier(self):
        """Re-initialise the Linear and LayerNorm layers of the head."""
        for layer in self.classifier.modules():
            if isinstance(layer, nn.Linear):
                nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')
                if layer.bias is not None:
                    nn.init.zeros_(layer.bias)
            elif isinstance(layer, nn.LayerNorm):
                nn.init.ones_(layer.weight)
                nn.init.zeros_(layer.bias)

    def _freeze_backbone(self):
        """Disable gradients for every backbone parameter."""
        self.beit.requires_grad_(False)

    def unfreeze_backbone(self):
        """Re-enable gradients for the backbone and clear the freeze flag."""
        self.beit.requires_grad_(True)
        self.freeze_backbone = False

    def forward(self, x):
        """Return class logits for a batch of images.

        Args:
            x: Image batch passed straight to the BEiT backbone.

        Returns:
            Logits tensor produced by the classifier head.
        """
        # A frozen backbone runs without gradient tracking; otherwise the
        # caller's ambient grad mode is left untouched.
        track_grad = torch.is_grad_enabled() and not self.freeze_backbone
        with torch.set_grad_enabled(track_grad):
            backbone_out = self.beit(x)

        # First token of the last hidden state (the CLS position)
        cls_embedding = backbone_out.last_hidden_state[:, 0]

        # The attention module uses the default (seq, batch, embed) layout, so
        # adding a leading axis yields a sequence of length one per sample.
        query = cls_embedding.unsqueeze(0)
        attended, _ = self.attention(query, query, query)

        features = self.norm(attended.squeeze(0))
        return self.classifier(features)

# Ensemble model for improved performance
class EnsembleModel(nn.Module):
    """Ensemble that averages the logits of several saved classifiers.

    Args:
        model_paths: Iterable of checkpoint files. Each checkpoint must be a
            dict holding the weights under the key ``'model_state_dict'``.
        num_classes: Number of output classes of every member model.
    """

    def __init__(self, model_paths, num_classes):
        super(EnsembleModel, self).__init__()
        self.models = nn.ModuleList()

        for path in model_paths:
            # Fresh model to receive the stored weights
            model = EnhancedBeitClassifier(num_classes=num_classes)

            # Map storages to CPU: the freshly built model lives on CPU, and
            # without map_location a checkpoint saved from a GPU cannot be
            # loaded on a machine without that device.
            checkpoint = torch.load(path, map_location='cpu')
            model.load_state_dict(checkpoint['model_state_dict'])

            # Members are used for inference only
            model.eval()

            self.models.append(model)

        print(f"Created ensemble with {len(self.models)} models")

    def forward(self, x):
        """Return the mean of the member models' logits for ``x``."""
        outputs = []
        for model in self.models:
            with torch.no_grad():
                outputs.append(model(x))

        # Average the logits
        return torch.mean(torch.stack(outputs), dim=0)

# K-Fold Ensemble model
class KFoldEnsemble(nn.Module):
    """Average the outputs of already-instantiated models (one per fold).

    Args:
        models: Iterable of ``nn.Module`` instances to combine.
    """

    def __init__(self, models):
        super().__init__()
        self.models = nn.ModuleList(models)

    def forward(self, x):
        """Return the element-wise mean of all member outputs for ``x``.

        Every member is switched to eval mode right before it is called, and
        all calls run without gradient tracking.
        """
        with torch.no_grad():
            # ``eval()`` returns the module itself, so it can be called directly.
            member_outputs = [member.eval()(x) for member in self.models]

        stacked = torch.stack(member_outputs)
        return stacked.mean(dim=0)

def load_data():
    """Build the datasets and DataLoaders for standard (non k-fold) training.

    The training split uses ``ClassSpecificImageFolder``; ``val`` and ``test``
    are added only when their folders exist. ``CONFIG['num_classes']`` is
    updated with the number of classes found in the training folder.

    Returns:
        Tuple ``(dataloaders, dataset_sizes, class_names)``: dicts keyed by
        split name plus the list of class names.

    Raises:
        Any exception raised while loading is printed and re-raised.
    """
    try:
        root = CONFIG['data_dir']

        # Training data gets class-specific augmentation
        image_datasets = {
            'train': ClassSpecificImageFolder(os.path.join(root, 'train'))
        }

        # Optional evaluation splits use their standard transforms
        for split in ('val', 'test'):
            split_dir = os.path.join(root, split)
            if os.path.exists(split_dir):
                image_datasets[split] = datasets.ImageFolder(split_dir, data_transforms[split])

        class_names = image_datasets['train'].classes
        print(f"Classes: {class_names}")
        CONFIG['num_classes'] = len(class_names)  # Update config with actual number of classes

        # Per-split class counts
        for split in ('train', 'val', 'test'):
            if split not in image_datasets:
                continue
            split_labels = [label for _, label in image_datasets[split].samples]
            found_labels, found_counts = np.unique(split_labels, return_counts=True)
            count_by_label = dict(zip(found_labels, found_counts))
            print(f"{split} dataset:")
            for i, class_name in enumerate(class_names):
                class_count = count_by_label.get(i, 0)
                print(f"  {class_name}: {class_count} images")

        # Training label statistics; the configured weights are used as-is
        train_labels = np.array([label for _, label in image_datasets['train'].samples])
        class_counts = np.bincount(train_labels, minlength=CONFIG['num_classes'])
        print(f"Class distribution: {class_counts}")
        print(f"Using class weights: {CONFIG['class_weights']}")

        # Each training sample is drawn with its class's configured weight
        weights = np.array(CONFIG['class_weights'])[train_labels]
        sampler = torch.utils.data.WeightedRandomSampler(
            weights=weights,
            num_samples=len(weights),
            replacement=True
        )

        dataloaders = {
            'train': torch.utils.data.DataLoader(
                image_datasets['train'],
                batch_size=CONFIG['batch_size'],
                sampler=sampler,
                num_workers=0,
                pin_memory=True,
                drop_last=True
            )
        }

        # Evaluation loaders keep the dataset order
        for split in ('val', 'test'):
            if split not in image_datasets:
                continue
            dataloaders[split] = torch.utils.data.DataLoader(
                image_datasets[split],
                batch_size=CONFIG['eval_batch_size'],
                shuffle=False,
                num_workers=0,
                pin_memory=True
            )

        dataset_sizes = {name: len(ds) for name, ds in image_datasets.items()}
        print(f"Dataset sizes: {dataset_sizes}")

        return dataloaders, dataset_sizes, class_names

    except Exception as e:
        print(f"Error loading data: {e}")
        raise

# New K-Fold Data Loader
def load_data_kfold():
    """Merge the train/val/test folders into one dataset for k-fold CV.

    ``CONFIG['num_classes']`` is updated with the number of classes found in
    the training folder.

    NOTE(review): the ``val`` and ``test`` images are loaded with the training
    transforms and the training folder always applies augmentation, so folds
    used for validation are evaluated on augmented images, and no untouched
    test split remains after merging. Confirm this is intended.

    Returns:
        Tuple ``(combined_dataset, class_names, all_labels)``: a
        ``ConcatDataset`` of all available splits, the class names, and the
        labels as a list in the same order as the combined dataset.

    Raises:
        ValueError: If an optional split has a different class list than the
            training split (its label indices would not line up).
        Any exception is printed with its traceback and re-raised.
    """
    try:
        data_root = CONFIG['data_dir']

        full_dataset = ClassSpecificImageFolder(os.path.join(data_root, 'train'))
        class_names = full_dataset.classes

        # Add val and test data to get a larger dataset for cross-validation
        all_datasets = [full_dataset]
        for split in ('val', 'test'):
            split_dir = os.path.join(data_root, split)
            if not os.path.exists(split_dir):
                continue
            extra_dataset = datasets.ImageFolder(
                split_dir,
                data_transforms['train']  # Use train transforms for all data
            )
            # Label indices come from each folder's own class list; merging
            # splits with different class lists would silently mislabel samples.
            if extra_dataset.classes != class_names:
                raise ValueError(
                    f"Classes in '{split}' ({extra_dataset.classes}) do not match "
                    f"the training classes ({class_names})"
                )
            all_datasets.append(extra_dataset)

        # Merge all datasets for k-fold cross-validation
        combined_dataset = ConcatDataset(all_datasets)

        print(f"Classes: {class_names}")
        CONFIG['num_classes'] = len(class_names)

        # Labels in ConcatDataset order, used for stratification
        all_labels = [label for dataset in all_datasets for _, label in dataset.samples]

        # Overall per-class statistics
        class_counts = np.bincount(all_labels, minlength=len(class_names))
        print(f"Combined dataset:")
        for class_name, class_count in zip(class_names, class_counts):
            print(f"  {class_name}: {class_count} images")

        total_samples = len(combined_dataset)
        print(f"Total dataset size for k-fold: {total_samples} images")

        return combined_dataset, class_names, all_labels

    except Exception as e:
        print(f"Error loading data for k-fold: {e}")
        import traceback
        traceback.print_exc()
        raise

# Function to create dataloaders for a specific fold
def create_fold_dataloaders(dataset, train_idx, val_idx, class_names):
    """Build the train/val DataLoaders for one cross-validation fold.

    Args:
        dataset: An ImageFolder-style dataset (with ``samples``) or a
            ``ConcatDataset`` of such datasets.
        train_idx: Indices of the fold's training samples within ``dataset``.
        val_idx: Indices of the fold's validation samples within ``dataset``.
        class_names: Class names (not used by this function).

    Returns:
        Tuple ``(dataloaders, dataset_sizes)``, both dicts with the keys
        ``'train'`` and ``'val'``.
    """
    def label_of(sample_idx):
        # Look up the label without loading the image.
        if not isinstance(dataset, ConcatDataset):
            return dataset.samples[sample_idx][1]
        # Walk the concatenated parts until the index falls inside one.
        part = 0
        offset = sample_idx
        while offset >= len(dataset.datasets[part]):
            offset -= len(dataset.datasets[part])
            part += 1
        return dataset.datasets[part].samples[offset][1]

    train_subset = Subset(dataset, train_idx)
    val_subset = Subset(dataset, val_idx)

    # Labels of this fold's training samples (needed for weighted sampling)
    train_labels = np.array([label_of(idx) for idx in train_idx])

    fold_class_counts = np.bincount(train_labels, minlength=CONFIG['num_classes'])
    print(f"Fold train class distribution: {fold_class_counts}")

    # Each training sample is drawn with its class's configured weight
    sample_weights = np.array(CONFIG['class_weights'])[train_labels]
    fold_sampler = torch.utils.data.WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True
    )

    dataloaders = {
        'train': torch.utils.data.DataLoader(
            train_subset,
            batch_size=CONFIG['batch_size'],
            sampler=fold_sampler,
            num_workers=0,
            pin_memory=True,
            drop_last=True
        ),
        'val': torch.utils.data.DataLoader(
            val_subset,
            batch_size=CONFIG['eval_batch_size'],
            shuffle=False,
            num_workers=0,
            pin_memory=True
        ),
    }

    dataset_sizes = {'train': len(train_subset), 'val': len(val_subset)}

    return dataloaders, dataset_sizes

def train_model(model, dataloaders, dataset_sizes, criterion, optimizer, scheduler=None, num_epochs=30):
    """Train the enhanced BEiT model with improved techniques for multi-class classification"""
    since = time.time()
    
    # Set up checkpoint directory
    checkpoint_dir = 'checkpoints'
    os.makedirs(checkpoint_dir, exist_ok=True)
    
    # Best model tracking
    best_model_path = os.path.join(checkpoint_dir, 'best_model.pt')
    best_acc = 0.0
    best_f1 = 0.0
    best_metric = 0.0  # Track the optimization metric
    
    # History tracking
    history = {
        'train_loss': [], 'train_acc': [],
        'val_loss': [], 'val_acc': [],
        'val_precision': [], 'val_recall': [], 'val_f1': [],
        'val_f1_macro': [], 'val_f1_weighted': []
    }
    
    # Gradient scaling for mixed precision
    scaler = torch.cuda.amp.GradScaler(enabled=CONFIG['use_mixed_precision'])
    
    # Early stopping
    patience = CONFIG.get('patience', 10)
    counter = 0
    
    # Gradient accumulation
    grad_accum_steps = CONFIG['gradient_accumulation_steps']
    
    # Evaluation phase (val or test)
    eval_phase = 'val' if 'val' in dataloaders else 'test'
    
    # Mixup settings
    use_mixup = CONFIG.get('use_mixup', True)
    mixup_alpha = CONFIG.get('mixup_alpha', 0.2)
    
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 30)
        
        # Unfreeze backbone if specified
        if epoch == CONFIG.get('freeze_backbone_epochs', 0) and model.freeze_backbone:
            print("Unfreezing backbone for fine-tuning...")
            model.unfreeze_backbone()
            # Adjust optimizer if needed
            if hasattr(optimizer, 'param_groups'):
                for param_group in optimizer.param_groups:
                    if 'beit' in param_group.get('name', ''):
                        param_group['lr'] = CONFIG['learning_rate'] * 0.1
        
        # Each epoch has training and validation phases
        for phase in ['train', eval_phase]:
            if phase == 'train':
                model.train()
            else:
                model.eval()
            
            running_loss = 0.0
            all_labels = []
            all_preds = []
            all_probs = []
            
            # Reset gradient accumulation counter
            accum_step = 0
            
            # Process data with progress bar
            pbar = tqdm(dataloaders[phase], desc=phase)
            
            for inputs, labels in pbar:
                inputs = inputs.to(device)
                labels = labels.to(device)
                
                # Zero gradients only when accumulation steps reached
                if phase == 'train' and accum_step % grad_accum_steps == 0:
                    optimizer.zero_grad()
                
                # Forward pass with mixed precision in training
                with torch.cuda.amp.autocast(enabled=CONFIG['use_mixed_precision']):
                    with torch.set_grad_enabled(phase == 'train'):
                        # Apply mixup in training phase if enabled
                        if phase == 'train' and use_mixup and np.random.random() < 0.5:
                            # Apply mixup
                            inputs_mixed, targets_a, targets_b, lam = mixup_data(inputs, labels, alpha=mixup_alpha)
                            outputs = model(inputs_mixed)
                            loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam)
                            
                            # For predictions collection (use first target for simplicity)
                            _, preds = torch.max(outputs, 1)
                            all_labels.extend(targets_a.cpu().numpy())
                            all_preds.extend(preds.cpu().numpy())
                        else:
                            # Standard forward pass
                            outputs = model(inputs)
                            loss = criterion(outputs, labels)
                            
                            # Get predictions
                            _, preds = torch.max(outputs, 1)
                            
                            # Collect predictions and labels
                            all_labels.extend(labels.cpu().numpy())
                            all_preds.extend(preds.cpu().numpy())
                            
                            # Get probabilities for metrics in validation phase
                            if phase != 'train':
                                probs = F.softmax(outputs, dim=1)
                                all_probs.extend(probs.cpu().numpy())
                        
                        # Scale loss for gradient accumulation in training
                        if phase == 'train':
                            loss = loss / grad_accum_steps
                        
                        # Backward + optimize only in training phase
                        if phase == 'train':
                            # Use scaler for mixed precision
                            scaler.scale(loss).backward()
                            
                            # Step optimizer if accumulation steps reached
                            if (accum_step + 1) % grad_accum_steps == 0:
                                # Unscale for gradient clipping
                                scaler.unscale_(optimizer)
                                
                                # Clip gradients
                                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                                
                                # Optimizer step with scaler
                                scaler.step(optimizer)
                                scaler.update()
                                
                                # Step scheduler if provided
                                if scheduler is not None:
                                    if CONFIG.get('scheduler', 'cosine_warmup') == 'cosine_warmup':
                                        scheduler.step()
                            
                            accum_step += 1
                
                # Statistics - scale loss back up for reporting
                if phase == 'train':
                    current_loss = loss.item() * inputs.size(0) * grad_accum_steps
                else:
                    current_loss = loss.item() * inputs.size(0)
                
                running_loss += current_loss
                
                # Update progress bar
                current_loss_avg = running_loss / ((pbar.n + 1) * inputs.size(0))
                pbar.set_postfix({'loss': f'{current_loss_avg:.4f}'})
                
                # Clean up memory
                del inputs, outputs, loss
                torch.cuda.empty_cache()
            
            # Calculate epoch metrics
            epoch_loss = running_loss / dataset_sizes[phase]
            
            # Calculate metrics
            all_labels = np.array(all_labels)
            all_preds = np.array(all_preds)
            
            epoch_acc = accuracy_score(all_labels, all_preds)
            
            if phase == 'train':
                history['train_loss'].append(epoch_loss)
                history['train_acc'].append(epoch_acc)
                print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
            else:
                # Calculate additional validation metrics for multi-class
                try:
                    all_probs = np.array(all_probs)
                    
                    # Calculate precision, recall, f1 for multi-class classification
                    precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
                    recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
                    f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
                    f1_macro = f1_score(all_labels, all_preds, average='macro', zero_division=0)
                    
                    # Calculate confusion matrix
                    cm = confusion_matrix(all_labels, all_preds)
                    
                    # Store validation metrics
                    history['val_loss'].append(epoch_loss)
                    history['val_acc'].append(epoch_acc)
                    history['val_precision'].append(precision)
                    history['val_recall'].append(recall)
                    history['val_f1'].append(f1)
                    history['val_f1_macro'].append(f1_macro)
                    
                    # Print detailed metrics
                    print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
                    print(f'Precision: {precision:.4f} Recall: {recall:.4f}')
                    print(f'F1 (weighted): {f1:.4f} F1 (macro): {f1_macro:.4f}')
                    
                    # Choose optimization metric based on configuration
                    optimization_metric = CONFIG.get('optimization_metric', 'f1_macro')
                    if optimization_metric == 'f1_macro':
                        current_metric = f1_macro
                    elif optimization_metric == 'f1_weighted':
                        current_metric = f1
                    elif optimization_metric == 'accuracy':
                        current_metric = epoch_acc
                    else:
                        current_metric = f1_macro
                    
                    # Check if this is the best model
                    improved = False
                    if current_metric > best_metric:
                        best_metric = current_metric
                        improved = True
                        metric_name = optimization_metric
                    
                    # Save best metrics independently as well
                    if epoch_acc > best_acc:
                        best_acc = epoch_acc
                    if f1 > best_f1:
                        best_f1 = f1
                    
                    if improved:
                        counter = 0
                        print(f"Saving best model (improved {metric_name}: {current_metric:.4f})")
                        
                        # Save model checkpoint
                        torch.save({
                            'epoch': epoch,
                            'model_state_dict': model.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict(),
                            'metrics': {
                                'accuracy': epoch_acc,
                                'precision': precision,
                                'recall': recall,
                                'f1_weighted': f1,
                                'f1_macro': f1_macro,
                                'confusion_matrix': cm.tolist()
                            },
                            'history': history
                        }, best_model_path)
                        
                        # Also save epoch-specific checkpoint for potential ensemble
                        torch.save({
                            'epoch': epoch,
                            'model_state_dict': model.state_dict(),
                            'metrics': {
                                'accuracy': epoch_acc,
                                'f1_weighted': f1,
                                'f1_macro': f1_macro
                            }
                        }, os.path.join(checkpoint_dir, f'model_epoch_{epoch}.pt'))
                    else:
                        counter += 1
                        print(f"Early stopping counter: {counter}/{patience}")
                
                except Exception as e:
                    import traceback
                    print(f"Error calculating metrics: {e}")
                    traceback.print_exc()
            
            # Clean memory after each phase
            torch.cuda.empty_cache()
            gc.collect()
        
        # Check for early stopping
        if counter >= patience:
            print(f"Early stopping triggered at epoch {epoch+1}")
            break
        
        print()
    
    # Calculate training time
    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best metrics: Acc: {best_acc:.4f}, F1 (weighted): {best_f1:.4f}, Best {CONFIG["optimization_metric"]}: {best_metric:.4f}')
    
    # Load best model weights
    if os.path.exists(best_model_path):
        checkpoint = torch.load(best_model_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        best_metrics = checkpoint.get('metrics', {})
        print("Loaded best model from checkpoint")
        
        # Print best model metrics
        if best_metrics:
            print("\nBest model metrics:")
            for k, v in best_metrics.items():
                if k != 'confusion_matrix':
                    print(f"  {k}: {v:.4f}")
    
    return model, history

def plot_training_results(history):
    """Draw a 3x2 grid of training diagnostics and save it as a PNG.

    ``history`` is indexed with the keys 'train_acc', 'val_acc',
    'train_loss', 'val_loss', 'val_precision', 'val_recall', 'val_f1' and
    'val_f1_macro'; each value is plotted as one point per recorded entry.
    Panels 1-4 are line charts, panels 5-6 are scatter plots coloured by
    epoch number. The figure is written to
    'enhanced_beit_multiclass_training_results.png' at 300 dpi.
    """

    def line_panel(slot, curves, title, ylabel, legend_loc):
        # One curve per (history key, legend label) pair, x axis = entry index.
        plt.subplot(3, 2, slot)
        for key, label in curves:
            plt.plot(history[key], label=label)
        plt.title(title)
        plt.ylabel(ylabel)
        plt.xlabel('Epoch')
        plt.legend(loc=legend_loc)
        plt.grid(alpha=0.3)

    def scatter_panel(slot, x_key, y_key, epoch_ids, title, xlabel, ylabel):
        # Metric-vs-metric scatter; colour encodes the epoch number.
        plt.subplot(3, 2, slot)
        xs = history[x_key]
        ys = history[y_key]
        plt.scatter(xs, ys, c=epoch_ids, cmap='viridis')
        plt.colorbar(label='Epoch')

        # Label every 5th point plus the final one with its epoch number.
        final_pos = len(epoch_ids) - 1
        for pos, epoch_no in enumerate(epoch_ids):
            if pos % 5 != 0 and pos != final_pos:
                continue
            plt.annotate(
                f'{epoch_no}',
                (xs[pos], ys[pos]),
                textcoords="offset points",
                xytext=(0, 10),
                ha='center'
            )

        plt.title(title)
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.grid(alpha=0.3)

    plt.figure(figsize=(20, 15))

    line_panel(
        1,
        [('train_acc', 'Train Accuracy'), ('val_acc', 'Validation Accuracy')],
        'Model Accuracy', 'Accuracy', 'lower right',
    )
    line_panel(
        2,
        [('train_loss', 'Train Loss'), ('val_loss', 'Validation Loss')],
        'Model Loss', 'Loss', 'upper right',
    )
    line_panel(
        3,
        [
            ('val_precision', 'Precision (weighted)'),
            ('val_recall', 'Recall (weighted)'),
            ('val_f1', 'F1 Score (weighted)'),
        ],
        'Precision, Recall, and F1 Score (Weighted)', 'Score', 'lower right',
    )
    line_panel(
        4,
        [('val_f1', 'F1 (weighted)'), ('val_f1_macro', 'F1 (macro)')],
        'F1 Scores (Macro vs Weighted)', 'Score', 'lower right',
    )

    # 1-based epoch numbers, one per recorded validation entry.
    epochs = range(1, len(history['val_precision']) + 1)

    scatter_panel(
        5, 'val_recall', 'val_precision', epochs,
        'Precision vs Recall (Weighted)', 'Recall', 'Precision',
    )
    scatter_panel(
        6, 'val_acc', 'val_f1_macro', epochs,
        'F1 Macro vs Accuracy', 'Accuracy', 'F1 Macro',
    )

    plt.tight_layout()
    plt.savefig('enhanced_beit_multiclass_training_results.png', dpi=300)

def evaluate_model(model, dataloader, class_names):
    """Evaluate a classifier on ``dataloader`` and report multi-class metrics.

    The model is switched to eval mode (and left there), predictions are
    collected under ``torch.no_grad()`` with autocast controlled by
    ``CONFIG['use_mixed_precision']``, and accuracy plus weighted, macro and
    per-class precision/recall/F1 are printed.

    Two figures are saved at 300 dpi: 'enhanced_beit_multiclass_cm.png'
    (row-normalised confusion matrix) and
    'enhanced_beit_multiclass_per_class_metrics.png' (per-class bar chart).

    Per-class scores and the confusion matrix are always computed over the
    label ids ``0 .. len(class_names) - 1`` so that entry ``i`` belongs to
    ``class_names[i]`` even when a class is missing from the evaluated split.
    The weighted/macro aggregates keep scikit-learn's default label handling.

    Args:
        model: Callable module mapping an input batch to per-class logits.
        dataloader: Iterable yielding ``(inputs, labels)`` batches; labels
            are read with ``labels.numpy()``.
        class_names: Sequence of class names, indexed by label id.

    Returns:
        dict with keys 'accuracy', 'precision_weighted', 'recall_weighted',
        'f1_weighted', 'precision_macro', 'recall_macro', 'f1_macro',
        'precision_per_class', 'recall_per_class', 'f1_per_class' (lists),
        'confusion_matrix' (nested list of counts) and 'class_names'.
    """
    model.eval()

    all_labels = []
    all_preds = []

    # Collect predictions
    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc="Evaluating"):
            inputs = inputs.to(device)

            # Use mixed precision for inference
            with torch.cuda.amp.autocast(enabled=CONFIG['use_mixed_precision']):
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)

            all_labels.extend(labels.numpy())
            all_preds.extend(preds.cpu().numpy())

            # Release per-batch tensors before the next batch is loaded
            del inputs, outputs, preds
            torch.cuda.empty_cache()

    all_labels = np.array(all_labels)
    all_preds = np.array(all_preds)

    # Aggregate metrics
    accuracy = accuracy_score(all_labels, all_preds)

    precision_weighted = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
    recall_weighted = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
    f1_weighted = f1_score(all_labels, all_preds, average='weighted', zero_division=0)

    precision_macro = precision_score(all_labels, all_preds, average='macro', zero_division=0)
    recall_macro = recall_score(all_labels, all_preds, average='macro', zero_division=0)
    f1_macro = f1_score(all_labels, all_preds, average='macro', zero_division=0)

    # Per-class metrics and confusion matrix over a fixed label set: without
    # ``labels=`` a class missing from this split would shrink these arrays
    # and shift every later class onto the wrong name / bar / tick.
    label_ids = np.arange(len(class_names))
    precision_per_class = precision_score(
        all_labels, all_preds, labels=label_ids, average=None, zero_division=0
    )
    recall_per_class = recall_score(
        all_labels, all_preds, labels=label_ids, average=None, zero_division=0
    )
    f1_per_class = f1_score(
        all_labels, all_preds, labels=label_ids, average=None, zero_division=0
    )
    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)

    # Print metrics
    print("\n===== MULTI-CLASS MODEL EVALUATION =====")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Weighted Metrics:")
    print(f"  Precision: {precision_weighted:.4f}")
    print(f"  Recall: {recall_weighted:.4f}")
    print(f"  F1 Score: {f1_weighted:.4f}")
    print(f"Macro Metrics:")
    print(f"  Precision: {precision_macro:.4f}")
    print(f"  Recall: {recall_macro:.4f}")
    print(f"  F1 Score: {f1_macro:.4f}")

    print("\nPer-Class Metrics:")
    per_class_rows = zip(class_names, precision_per_class, recall_per_class, f1_per_class)
    for class_name, class_precision, class_recall, class_f1 in per_class_rows:
        print(f"  {class_name}:")
        print(f"    Precision: {class_precision:.4f}")
        print(f"    Recall: {class_recall:.4f}")
        print(f"    F1 Score: {class_f1:.4f}")

    # Plot confusion matrix
    plt.figure(figsize=(10, 8))

    # Row-normalise; rows with no true samples stay at 0 instead of NaN.
    row_totals = cm.sum(axis=1, keepdims=True)
    cm_normalized = np.divide(
        cm.astype('float'),
        row_totals,
        out=np.zeros(cm.shape, dtype=float),
        where=row_totals != 0,
    )

    im = plt.imshow(cm_normalized, interpolation='nearest', cmap=plt.cm.Blues, vmin=0, vmax=1)
    plt.title('Normalized Confusion Matrix')
    plt.colorbar(im)

    # Add class labels
    tick_marks = np.arange(len(class_names))
    plt.xticks(tick_marks, class_names, rotation=45, ha='right')
    plt.yticks(tick_marks, class_names)

    # Annotate each cell with its raw count and normalised value; switch to
    # white text on dark cells so the numbers stay readable.
    thresh = cm_normalized.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            plt.text(j, i, f"{cm[i, j]}\n({cm_normalized[i, j]:.2f})",
                   horizontalalignment="center", fontsize=9,
                   color="white" if cm_normalized[i, j] > thresh else "black")

    # Set the axis labels before tight_layout so they are included in the fit.
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
    plt.savefig('enhanced_beit_multiclass_cm.png', dpi=300)

    # Plot per-class metrics as grouped bars
    plt.figure(figsize=(12, 6))
    x = np.arange(len(class_names))
    width = 0.25

    plt.bar(x - width, precision_per_class, width, label='Precision')
    plt.bar(x, recall_per_class, width, label='Recall')
    plt.bar(x + width, f1_per_class, width, label='F1 Score')

    plt.xlabel('Class')
    plt.ylabel('Score')
    plt.title('Per-Class Performance Metrics')
    plt.xticks(x, class_names, rotation=45, ha='right')
    plt.legend()
    plt.tight_layout()
    plt.savefig('enhanced_beit_multiclass_per_class_metrics.png', dpi=300)

    # Return comprehensive metrics
    return {
        'accuracy': accuracy,
        'precision_weighted': precision_weighted,
        'recall_weighted': recall_weighted,
        'f1_weighted': f1_weighted,
        'precision_macro': precision_macro,
        'recall_macro': recall_macro,
        'f1_macro': f1_macro,
        'precision_per_class': precision_per_class.tolist(),
        'recall_per_class': recall_per_class.tolist(),
        'f1_per_class': f1_per_class.tolist(),
        'confusion_matrix': cm.tolist(),
        'class_names': class_names
    }

def create_ensemble(checkpoint_dir='checkpoints', ensemble_epochs=None):
    """Build an EnsembleModel from per-epoch checkpoints that exist on disk.

    Checkpoints are looked up as ``<checkpoint_dir>/model_epoch_<epoch>.pt``.
    When at least two are found, an ``EnsembleModel`` is built from them,
    moved to ``device``, and a small descriptor (model type, checkpoint paths
    and CONFIG) is saved to 'ensemble_model.pth'.

    Args:
        checkpoint_dir: Directory holding the per-epoch checkpoint files.
        ensemble_epochs: Epoch numbers to look for. ``None`` (the default)
            reads ``CONFIG['ensemble_epochs']``, falling back to
            ``[25, 35, 45, 55]``.

    Returns:
        The ensemble model, or ``None`` when fewer than two of the requested
        checkpoints exist.
    """
    if ensemble_epochs is None:
        ensemble_epochs = CONFIG.get('ensemble_epochs', [25, 35, 45, 55])

    # Build each candidate path once and keep only those present on disk.
    candidates = [
        (epoch, os.path.join(checkpoint_dir, f'model_epoch_{epoch}.pt'))
        for epoch in ensemble_epochs
    ]
    available = [(epoch, path) for epoch, path in candidates if os.path.exists(path)]

    if len(available) < 2:
        print("Not enough models for ensemble")
        return None

    found_epochs = [epoch for epoch, _ in available]
    model_paths = [path for _, path in available]

    # Report the epochs actually used, not every epoch that was requested.
    print(f"Creating ensemble from {len(model_paths)} checkpoints: {found_epochs}")
    ensemble = EnsembleModel(model_paths, CONFIG['num_classes'])
    ensemble = ensemble.to(device)

    # Save the ensemble descriptor (paths only; member weights stay in their
    # own checkpoint files)
    torch.save({
        'model_type': 'ensemble',
        'model_paths': model_paths,
        'config': CONFIG
    }, 'ensemble_model.pth')

    return ensemble

# Source of the stand-alone inference script that main() writes to disk.
# It is a raw string on purpose: the "\n" inside the generated print() call
# must reach the file as a backslash escape. As a normal string it became a
# literal line break inside a quoted string and the generated file failed to
# parse.
_USAGE_CODE = r"""
import torch
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
import torch.nn as nn
from transformers import BeitModel
import numpy as np

# Load the model
def load_enhanced_model(model_path):
    # Define the model class
    class EnhancedBeitClassifier(nn.Module):
        def __init__(self, num_classes=6, dropout_rate=0.3):
            super(EnhancedBeitClassifier, self).__init__()
            # Load pre-trained BEiT model
            self.beit = BeitModel.from_pretrained("microsoft/beit-base-patch16-224-pt22k-ft22k")
            self.feature_dim = self.beit.config.hidden_size
            
            # Add attention mechanism
            self.attention = nn.MultiheadAttention(embed_dim=self.feature_dim, num_heads=8, dropout=0.1)
            
            # Layer norm
            self.norm = nn.LayerNorm(self.feature_dim)
            
            # Classifier
            self.classifier = nn.Sequential(
                nn.Dropout(dropout_rate),
                nn.Linear(self.feature_dim, 768),
                nn.GELU(),
                nn.LayerNorm(768),
                nn.Dropout(dropout_rate),
                nn.Linear(768, 384),
                nn.GELU(),
                nn.LayerNorm(384),
                nn.Dropout(dropout_rate),
                nn.Linear(384, num_classes)
            )
        
        def forward(self, x):
            # Extract features
            outputs = self.beit(x)
            cls_token = outputs.last_hidden_state[:, 0]
            
            # Apply attention
            cls_token_unsqueezed = cls_token.unsqueeze(0)
            attn_output, _ = self.attention(cls_token_unsqueezed, cls_token_unsqueezed, cls_token_unsqueezed)
            cls_token = attn_output.squeeze(0)
            
            # Apply layer norm
            normalized = self.norm(cls_token)
            
            # Apply classifier
            logits = self.classifier(normalized)
            return logits
    
    # Load the checkpoint
    checkpoint = torch.load(model_path, map_location='cpu')
    
    # Create model and load weights (fall back to 6 classes if names are absent)
    saved_class_names = checkpoint.get('class_names')
    num_classes = len(saved_class_names) if saved_class_names else 6
    model = EnhancedBeitClassifier(num_classes=num_classes)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    
    return model, checkpoint

# Function to predict on a single image
def predict_multiclass(model, image_path, checkpoint):
    # Set up transforms
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    
    # Load and preprocess the image
    # Load and preprocess the image
    image = Image.open(image_path).convert('RGB')
    input_tensor = transform(image).unsqueeze(0)
    
    # Make prediction
    with torch.no_grad():
        output = model(input_tensor)
        probabilities = torch.nn.functional.softmax(output, dim=1)[0]
        
        # Get predicted class
        _, predicted_class = torch.max(probabilities, 0)
        predicted_class = predicted_class.item()
    
    # Get class name and probability
    class_names = checkpoint.get('class_names', ['he', 's1', 's2', 's3', 's4', 'un'])
    class_name = class_names[predicted_class]
    probability = probabilities[predicted_class].item() * 100  # Convert to percentage
    
    # Display results
    plt.figure(figsize=(10, 6))
    
    # Show the image
    plt.subplot(1, 2, 1)
    plt.imshow(image)
    plt.title(f"Prediction: {class_name} ({probability:.1f}%)")
    plt.axis('off')
    
    # Show probability bars for all classes
    plt.subplot(1, 2, 2)
    class_probs = probabilities.cpu().numpy() * 100
    
    # Sort by probability
    sorted_idx = np.argsort(class_probs)[::-1]
    sorted_classes = [class_names[i] for i in sorted_idx]
    sorted_probs = class_probs[sorted_idx]
    
    # Plot horizontal bars
    y_pos = np.arange(len(sorted_classes))
    plt.barh(y_pos, sorted_probs, align='center')
    plt.yticks(y_pos, sorted_classes)
    plt.xlabel('Probability (%)')
    plt.title('Class Probabilities')
    
    # Add probability values
    for i, v in enumerate(sorted_probs):
        plt.text(v + 1, i, f"{v:.1f}%", va='center')
    
    plt.tight_layout()
    plt.show()
    
    # Return detailed results
    class_probabilities = {class_names[i]: probabilities[i].item() * 100 for i in range(len(class_names))}
    
    return {
        'predicted_class': class_name,
        'confidence': probability,
        'all_probabilities': class_probabilities
    }

# Example usage
model, checkpoint = load_enhanced_model('enhanced_beit_multiclass_model.pth')
result = predict_multiclass(model, 'path/to/your/image.jpg', checkpoint)
print(f"Prediction: {result['predicted_class']} with {result['confidence']:.1f}% confidence")
print("\nAll class probabilities:")
sorted_probs = dict(sorted(result['all_probabilities'].items(), key=lambda x: x[1], reverse=True))
for cls, prob in sorted_probs.items():
    print(f"  {cls}: {prob:.2f}%")
        """


def _build_classifier():
    """Instantiate EnhancedBeitClassifier from CONFIG and move it to ``device``."""
    model = EnhancedBeitClassifier(
        num_classes=CONFIG['num_classes'],
        dropout_rate=CONFIG.get('dropout_rate', 0.3),
        freeze_backbone=CONFIG.get('freeze_backbone_epochs', 0) > 0
    )
    return model.to(device)


def _build_optimizer(model):
    """Create the AdamW optimizer for ``model`` using CONFIG's lr and weight decay."""
    return optim.AdamW(
        model.parameters(),
        lr=CONFIG['learning_rate'],
        weight_decay=CONFIG.get('weight_decay', 0.04),
        eps=1e-8
    )


def _build_scheduler(optimizer, steps_per_epoch):
    """Create the LR scheduler selected by CONFIG['scheduler'] for ``optimizer``."""
    scheduler_name = CONFIG.get('scheduler', 'cosine_warmup')

    if scheduler_name == 'cosine_warmup':
        # NOTE(review): the visible part of train_model appears to call
        # scheduler.step() only alongside the optimizer step; if that happens
        # once per `gradient_accumulation_steps` batches, these counts
        # over-estimate the real number of scheduler steps - confirm.
        total_steps = CONFIG['num_epochs'] * steps_per_epoch
        warmup_steps = CONFIG['warmup_epochs'] * steps_per_epoch

        scheduler = get_cosine_schedule_with_warmup(
            optimizer,
            warmup_steps,
            total_steps,
            min_lr_ratio=0.01
        )
        print(f"Using custom cosine scheduler with {CONFIG['warmup_epochs']} epochs of warmup")
    elif scheduler_name == 'cosine':
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=CONFIG['num_epochs'],
            eta_min=CONFIG['learning_rate'] * 0.001
        )
        print("Using Cosine Annealing scheduler")
    else:
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='max',
            factor=0.5,
            patience=5,
            verbose=True
        )
        print("Using ReduceLROnPlateau scheduler")

    return scheduler


def main():
    """Run the single-split pipeline: load data, train, evaluate, save.

    Steps, in order: load the dataloaders via ``load_data()``; build the
    model, loss, optimizer and scheduler from CONFIG; optionally load
    'checkpoints/best_model.pt' (asking on stdin whether to continue from
    it); train with ``train_model``; plot the history; evaluate on the
    'test' split if present, else 'val'; optionally build and compare an
    ensemble; save the final model to 'enhanced_beit_multiclass_model.pth';
    and write an example inference script to
    'use_enhanced_multiclass_model.py'.

    Any exception is caught at this top level, reported with a traceback and
    not re-raised. Returns ``None``.
    """
    try:
        print("=== Enhanced BEiT Multi-Class Classification ===")

        # Load dataset with enhanced augmentations
        print("Loading multi-class dataset with enhanced augmentations...")
        dataloaders, dataset_sizes, class_names = load_data()

        # Create enhanced BEiT model
        print("Creating enhanced BEiT model with attention...")
        model = _build_classifier()

        # Count parameters
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(f"Total parameters: {total_params:,}")
        print(f"Trainable parameters: {trainable_params:,}")

        # Get class counts for class-balanced loss
        train_labels = np.array([label for _, label in dataloaders['train'].dataset.samples])
        class_counts = np.bincount(train_labels, minlength=CONFIG['num_classes'])

        # Create loss function
        focal_gamma = CONFIG.get('focal_loss_gamma', 3.0)
        if CONFIG.get('use_class_balanced_loss', True):
            criterion = ClassBalancedLoss(
                samples_per_class=class_counts,
                num_classes=CONFIG['num_classes'],
                beta=0.9999,
                gamma=focal_gamma
            )
            print(f"Using Class-Balanced Loss with gamma={focal_gamma}")
        else:
            # Standard focal loss with the manually configured class weights
            weights = torch.tensor(CONFIG['class_weights']).float().to(device)
            criterion = FocalLoss(
                gamma=focal_gamma,
                alpha=weights
            )
            print(f"Using Focal Loss (gamma={focal_gamma}) with class weights")

        # Create optimizer and scheduler
        steps_per_epoch = len(dataloaders['train'])
        optimizer = _build_optimizer(model)
        scheduler = _build_scheduler(optimizer, steps_per_epoch)

        # Check if we should try to load a previous checkpoint
        if CONFIG.get('use_checkpoint', False):
            checkpoint_path = 'checkpoints/best_model.pt'
            if os.path.exists(checkpoint_path):
                checkpoint = torch.load(checkpoint_path)
                model.load_state_dict(checkpoint['model_state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                loaded_epoch = checkpoint.get('epoch', 0)
                print(f"Loaded checkpoint from epoch {loaded_epoch}")

                # Print metrics from checkpoint
                if 'metrics' in checkpoint:
                    print("Checkpoint metrics:")
                    for k, v in checkpoint['metrics'].items():
                        if k != 'confusion_matrix':
                            print(f"  {k}: {v:.4f}")

                # Ask for confirmation before continuing training
                confirmation = input("Continue training from checkpoint? (y/n): ")
                if confirmation.lower() != 'y':
                    print("Starting fresh training...")
                    model = _build_classifier()
                    optimizer = _build_optimizer(model)
                    # The scheduler must be rebuilt too: the previous one is
                    # bound to the discarded optimizer and would never adjust
                    # the learning rate of the fresh one.
                    scheduler = _build_scheduler(optimizer, steps_per_epoch)

        # Train model
        print("Starting model training with enhanced techniques...")
        model, history = train_model(
            model, dataloaders, dataset_sizes, criterion, optimizer, scheduler,
            num_epochs=CONFIG['num_epochs']
        )

        # Plot results
        plot_training_results(history)

        # Comprehensive evaluation
        print("Performing comprehensive evaluation...")
        test_phase = 'test' if 'test' in dataloaders else 'val'
        metrics = evaluate_model(model, dataloaders[test_phase], class_names)

        # Try to create ensemble if enabled
        if CONFIG.get('use_ensemble', True):
            print("Creating and evaluating ensemble model...")
            ensemble = create_ensemble()
            if ensemble is not None:
                ensemble_metrics = evaluate_model(ensemble, dataloaders[test_phase], class_names)
                print("\nEnsemble vs Single Model Comparison:")
                print(f"Single Model F1 Macro: {metrics['f1_macro']:.4f}")
                print(f"Ensemble Model F1 Macro: {ensemble_metrics['f1_macro']:.4f}")

                # Use ensemble for final model if it's better
                if ensemble_metrics['f1_macro'] > metrics['f1_macro']:
                    print("Ensemble model performs better - using it as final model")
                    model = ensemble
                    metrics = ensemble_metrics

        # Save model
        # NOTE(review): when the ensemble wins, the state dict saved here is
        # the ensemble's; the generated usage script builds a single
        # EnhancedBeitClassifier and presumably cannot load it - confirm.
        print("Saving enhanced model...")
        model = model.to('cpu')  # Move to CPU for saving
        torch.save({
            'model_state_dict': model.state_dict(),
            'class_names': class_names,
            'config': CONFIG,
            'metrics': metrics,
            'history': history
        }, 'enhanced_beit_multiclass_model.pth')

        print("Enhanced model saved to 'enhanced_beit_multiclass_model.pth'")

        # Write the example usage script
        with open('use_enhanced_multiclass_model.py', 'w', encoding='utf-8') as f:
            f.write(_USAGE_CODE)

        print("Usage code saved to use_enhanced_multiclass_model.py")

    except Exception as e:
        # Top-level boundary: report the failure with a traceback instead of
        # propagating it to the caller.
        import traceback
        print(f"Error in main: {e}")
        traceback.print_exc()

# K-Fold Cross-Validation implementation
# Standalone inference script that main_kfold() writes to 'use_kfold_ensemble_model.py'.
# It is a raw string on purpose: escape sequences such as "\n" must reach the generated
# file as source text. In a normal string the "\n" would be expanded here, splitting the
# print() string literal over two lines and making the generated script a SyntaxError.
KFOLD_ENSEMBLE_USAGE_CODE = r"""
import torch
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
import torch.nn as nn
from transformers import BeitModel
import numpy as np
import os

# Load the k-fold ensemble model
def load_kfold_ensemble(ensemble_info_path='kfold_ensemble_info.pth', models_dir='ensemble_models'):
    # Define the model class
    class EnhancedBeitClassifier(nn.Module):
        def __init__(self, num_classes=6, dropout_rate=0.3):
            super(EnhancedBeitClassifier, self).__init__()
            # Load pre-trained BEiT model
            self.beit = BeitModel.from_pretrained("microsoft/beit-base-patch16-224-pt22k-ft22k")
            self.feature_dim = self.beit.config.hidden_size
            
            # Add attention mechanism
            self.attention = nn.MultiheadAttention(embed_dim=self.feature_dim, num_heads=8, dropout=0.1)
            
            # Layer norm
            self.norm = nn.LayerNorm(self.feature_dim)
            
            # Classifier
            self.classifier = nn.Sequential(
                nn.Dropout(dropout_rate),
                nn.Linear(self.feature_dim, 768),
                nn.GELU(),
                nn.LayerNorm(768),
                nn.Dropout(dropout_rate),
                nn.Linear(768, 384),
                nn.GELU(),
                nn.LayerNorm(384),
                nn.Dropout(dropout_rate),
                nn.Linear(384, num_classes)
            )
        
        def forward(self, x):
            # Extract features
            outputs = self.beit(x)
            cls_token = outputs.last_hidden_state[:, 0]
            
            # Apply attention
            cls_token_unsqueezed = cls_token.unsqueeze(0)
            attn_output, _ = self.attention(cls_token_unsqueezed, cls_token_unsqueezed, cls_token_unsqueezed)
            cls_token = attn_output.squeeze(0)
            
            # Apply layer norm
            normalized = self.norm(cls_token)
            
            # Apply classifier
            logits = self.classifier(normalized)
            return logits
    
    # Define the ensemble model
    class KFoldEnsemble(nn.Module):
        def __init__(self, models):
            super(KFoldEnsemble, self).__init__()
            self.models = nn.ModuleList(models)
            
        def forward(self, x):
            outputs = []
            for model in self.models:
                model.eval()
                with torch.no_grad():
                    outputs.append(model(x))
            
            # Average the outputs
            return torch.mean(torch.stack(outputs), dim=0)
    
    # Load the ensemble info
    ensemble_info = torch.load(ensemble_info_path, map_location='cpu')
    class_names = ensemble_info.get('class_names')
    num_classes = len(class_names)
    
    # Create individual models
    models = []
    fold_paths = [f for f in os.listdir(models_dir) if f.startswith('fold_') and f.endswith('_model.pth')]
    
    for model_file in sorted(fold_paths):
        # Create model
        model = EnhancedBeitClassifier(num_classes=num_classes)
        
        # Load weights
        checkpoint = torch.load(os.path.join(models_dir, model_file), map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        
        # Add to ensemble
        models.append(model)
    
    # Create the ensemble
    ensemble = KFoldEnsemble(models)
    ensemble.eval()
    
    return ensemble, ensemble_info

# Function to predict with the ensemble
def predict_with_ensemble(ensemble, image_path, ensemble_info):
    # Set up transforms
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    
    # Load and preprocess the image
    image = Image.open(image_path).convert('RGB')
    input_tensor = transform(image).unsqueeze(0)
    
    # Make prediction
    with torch.no_grad():
        output = ensemble(input_tensor)
        probabilities = torch.nn.functional.softmax(output, dim=1)[0]
        
        # Get predicted class
        _, predicted_class = torch.max(probabilities, 0)
        predicted_class = predicted_class.item()
    
    # Get class name and probability
    class_names = ensemble_info.get('class_names', ['he', 's1', 's2', 's3', 's4', 'un'])
    class_name = class_names[predicted_class]
    probability = probabilities[predicted_class].item() * 100  # Convert to percentage
    
    # Display results
    plt.figure(figsize=(10, 6))
    
    # Show the image
    plt.subplot(1, 2, 1)
    plt.imshow(image)
    plt.title(f"Ensemble Prediction: {class_name} ({probability:.1f}%)")
    plt.axis('off')
    
    # Show probability bars for all classes
    plt.subplot(1, 2, 2)
    class_probs = probabilities.cpu().numpy() * 100
    
    # Sort by probability
    sorted_idx = np.argsort(class_probs)[::-1]
    sorted_classes = [class_names[i] for i in sorted_idx]
    sorted_probs = class_probs[sorted_idx]
    
    # Plot horizontal bars
    y_pos = np.arange(len(sorted_classes))
    plt.barh(y_pos, sorted_probs, align='center')
    plt.yticks(y_pos, sorted_classes)
    plt.xlabel('Probability (%)')
    plt.title('Class Probabilities from K-Fold Ensemble')
    
    # Add probability values
    for i, v in enumerate(sorted_probs):
        plt.text(v + 1, i, f"{v:.1f}%", va='center')
    
    plt.tight_layout()
    plt.show()
    
    # Return detailed results
    class_probabilities = {class_names[i]: probabilities[i].item() * 100 for i in range(len(class_names))}
    
    return {
        'predicted_class': class_name,
        'confidence': probability,
        'all_probabilities': class_probabilities
    }

# Example usage
ensemble, ensemble_info = load_kfold_ensemble()
result = predict_with_ensemble(ensemble, 'path/to/your/image.jpg', ensemble_info)
print(f"Ensemble Prediction: {result['predicted_class']} with {result['confidence']:.1f}% confidence")
print("\nAll class probabilities:")
sorted_probs = dict(sorted(result['all_probabilities'].items(), key=lambda x: x[1], reverse=True))
for cls, prob in sorted_probs.items():
    print(f"  {cls}: {prob:.2f}%")
"""


def _collect_fold_labels(dataset, indices):
    """Return the class labels of ``dataset`` at ``indices`` as an integer array.

    Labels are read from ``samples[i][1]`` of the dataset, or of the owning
    sub-dataset when ``dataset`` is a ``ConcatDataset``.
    NOTE(review): assumes every underlying dataset exposes an ImageFolder-style
    ``samples`` list of ``(path, label)`` pairs -- confirm against load_data_kfold().
    """
    if not isinstance(dataset, ConcatDataset):
        return np.array([dataset.samples[idx][1] for idx in indices], dtype=np.int64)

    labels = []
    for idx in indices:
        # Walk the concatenated datasets until the index falls inside one of them
        dataset_idx = 0
        local_idx = idx
        while local_idx >= len(dataset.datasets[dataset_idx]):
            local_idx -= len(dataset.datasets[dataset_idx])
            dataset_idx += 1
        labels.append(dataset.datasets[dataset_idx].samples[local_idx][1])
    return np.array(labels, dtype=np.int64)


def _plot_fold_training_curves(history, fold_number):
    """Save a 2x2 grid of training curves to 'fold_<fold_number>_training_curves.png'."""
    # (title suffix, y-axis label, legend location, ((history key, legend label), ...))
    panels = (
        ('Model Accuracy', 'Accuracy', 'lower right',
         (('train_acc', 'Train Accuracy'), ('val_acc', 'Validation Accuracy'))),
        ('Model Loss', 'Loss', 'upper right',
         (('train_loss', 'Train Loss'), ('val_loss', 'Validation Loss'))),
        ('Precision, Recall, and F1 Score', 'Score', 'lower right',
         (('val_precision', 'Precision (weighted)'),
          ('val_recall', 'Recall (weighted)'),
          ('val_f1', 'F1 Score (weighted)'))),
        ('F1 Scores', 'Score', 'lower right',
         (('val_f1', 'F1 (weighted)'), ('val_f1_macro', 'F1 (macro)'))),
    )

    plt.figure(figsize=(15, 10))
    try:
        for position, (title, ylabel, legend_loc, curves) in enumerate(panels, start=1):
            plt.subplot(2, 2, position)
            for key, label in curves:
                plt.plot(history[key], label=label)
            plt.title(f'Fold {fold_number} - {title}')
            plt.ylabel(ylabel)
            plt.xlabel('Epoch')
            plt.legend(loc=legend_loc)
            plt.grid(alpha=0.3)

        plt.tight_layout()
        plt.savefig(f'fold_{fold_number}_training_curves.png', dpi=200)
    finally:
        # Always release the figure, even if a history key is missing or saving fails
        plt.close()


def main_kfold():
    """Train and evaluate one model per fold, then persist results and an ensemble.

    For every fold a fresh ``EnhancedBeitClassifier`` is trained with
    ``train_model`` and scored with ``evaluate_model``. Files written:

    * ``best_kfold_model.pth`` -- checkpoint of the fold with the highest
      ``CONFIG['optimization_metric']`` (default ``'f1_macro'``).
    * ``fold_models/fold_<n>_model.pth`` -- per-fold checkpoints, when
      ``CONFIG['save_all_folds']`` is true (default).
    * ``fold_<n>_training_curves.png`` -- training curves for each fold.
    * ``kfold_validation_results.csv`` -- per-fold metrics table.
    * ``kfold_ensemble_info.pth``, ``ensemble_models/fold_<n>_model.pth`` and
      ``use_kfold_ensemble_model.py`` -- when ``CONFIG['use_ensemble']`` is
      true (default) and more than one fold model exists.

    Returns:
        None. Any exception is caught, reported with its traceback, and not
        re-raised.
    """
    try:
        print("=== Enhanced BEiT Multi-Class Classification with K-Fold Cross-Validation ===")

        # Load all data for k-fold cross validation
        print("Loading and preparing dataset for k-fold cross validation...")
        combined_dataset, class_names, all_labels = load_data_kfold()

        # Initialize k-fold cross validation
        k_folds = CONFIG.get('k_folds', 5)
        sample_indices = np.arange(len(combined_dataset))

        # Choose between standard or stratified k-fold
        if CONFIG.get('stratified', True):
            print(f"Using Stratified {k_folds}-Fold Cross Validation")
            kfold = StratifiedKFold(n_splits=k_folds, shuffle=True, random_state=42)
            splits = kfold.split(sample_indices, all_labels)
        else:
            print(f"Using standard {k_folds}-Fold Cross Validation")
            kfold = KFold(n_splits=k_folds, shuffle=True, random_state=42)
            splits = kfold.split(sample_indices)

        # Settings that do not change from fold to fold
        num_classes = CONFIG['num_classes']
        focal_gamma = CONFIG.get('focal_loss_gamma', 3.0)
        optimization_metric = CONFIG.get('optimization_metric', 'f1_macro')

        # Store results for each fold
        fold_results = []
        best_models = []
        # Start below any attainable score so the first fold always produces a
        # best-model checkpoint, even when its metric is exactly 0
        best_fold_score = float('-inf')
        best_fold_idx = -1

        # Run k-fold cross validation
        for fold, (train_idx, val_idx) in enumerate(splits):
            fold_number = fold + 1
            print(f"\n{'='*40}")
            print(f"FOLD {fold_number}/{k_folds}")
            print(f"{'='*40}")

            # Create dataloaders for this fold
            dataloaders, dataset_sizes = create_fold_dataloaders(
                combined_dataset, train_idx, val_idx, class_names
            )

            # Create model for this fold
            print(f"Creating model for fold {fold_number}...")
            model = EnhancedBeitClassifier(
                num_classes=num_classes,
                dropout_rate=CONFIG.get('dropout_rate', 0.3),
                freeze_backbone=CONFIG.get('freeze_backbone_epochs', 0) > 0
            )
            model = model.to(device)

            # Get class counts for class-balanced loss
            train_labels = _collect_fold_labels(combined_dataset, train_idx)
            class_counts = np.bincount(train_labels, minlength=num_classes)

            # Create loss function
            if CONFIG.get('use_class_balanced_loss', True):
                # Create class-balanced loss
                criterion = ClassBalancedLoss(
                    samples_per_class=class_counts,
                    num_classes=num_classes,
                    beta=0.9999,
                    gamma=focal_gamma
                )
                print(f"Using Class-Balanced Loss with gamma={focal_gamma}")
            else:
                # Create standard focal loss with class weights
                weights = torch.tensor(CONFIG['class_weights']).float().to(device)

                criterion = FocalLoss(
                    gamma=focal_gamma,
                    alpha=weights
                )
                print(f"Using Focal Loss (gamma={focal_gamma}) with class weights")

            # Create optimizer with weight decay
            optimizer = optim.AdamW(
                model.parameters(),
                lr=CONFIG['learning_rate'],
                weight_decay=CONFIG.get('weight_decay', 0.04),
                eps=1e-8
            )

            # Create scheduler
            if CONFIG.get('scheduler', 'cosine_warmup') == 'cosine_warmup':
                # The schedule is expressed in steps: one step per training batch
                steps_per_epoch = len(dataloaders['train'])
                total_steps = CONFIG['num_epochs'] * steps_per_epoch
                warmup_steps = CONFIG['warmup_epochs'] * steps_per_epoch

                # Create custom scheduler with warmup
                scheduler = get_cosine_schedule_with_warmup(
                    optimizer,
                    warmup_steps,
                    total_steps,
                    min_lr_ratio=0.01
                )
                print(f"Using custom cosine scheduler with {CONFIG['warmup_epochs']} epochs of warmup")
            else:
                scheduler = None

            # Train model for this fold
            print(f"Training model for fold {fold_number}...")
            model, history = train_model(
                model, dataloaders, dataset_sizes, criterion, optimizer, scheduler,
                num_epochs=CONFIG['num_epochs']
            )

            # Evaluate model on validation set
            print(f"Evaluating model for fold {fold_number}...")
            metrics = evaluate_model(model, dataloaders['val'], class_names)

            # Store results for this fold
            fold_metrics = {
                'fold': fold_number,
                'accuracy': metrics['accuracy'],
                'precision_weighted': metrics['precision_weighted'],
                'recall_weighted': metrics['recall_weighted'],
                'f1_weighted': metrics['f1_weighted'],
                'f1_macro': metrics['f1_macro'],
                'history': history
            }
            fold_results.append(fold_metrics)

            # The best-model and per-fold files share the same checkpoint payload
            checkpoint = {
                'fold': fold_number,
                'model_state_dict': model.state_dict(),
                'metrics': metrics,
                'history': history,
                'class_names': class_names
            }

            # Check if this is the best fold
            current_metric = fold_metrics[optimization_metric]
            if current_metric > best_fold_score:
                best_fold_score = current_metric
                best_fold_idx = fold

                # Save the best model overall
                print(f"New best fold! Saving model from fold {fold_number}")
                torch.save(checkpoint, 'best_kfold_model.pth')

            # Save this fold's model if requested
            if CONFIG.get('save_all_folds', True):
                # Create folder for fold models if it doesn't exist
                os.makedirs('fold_models', exist_ok=True)
                torch.save(checkpoint, f'fold_models/fold_{fold_number}_model.pth')

            # Keep model in memory for potential ensemble
            model = model.to('cpu')  # Move to CPU to save GPU memory
            best_models.append(model)

            # Plot training curves for this fold
            _plot_fold_training_curves(history, fold_number)

            # Clear some memory
            torch.cuda.empty_cache()
            gc.collect()

        # Create a summary DataFrame of fold results
        results_df = pd.DataFrame(fold_results)
        print("\n=== K-Fold Cross-Validation Results ===")
        print(results_df[['fold', 'accuracy', 'f1_weighted', 'f1_macro']])

        # Calculate mean and std of metrics across folds
        mean_accuracy = results_df['accuracy'].mean()
        std_accuracy = results_df['accuracy'].std()
        mean_f1_weighted = results_df['f1_weighted'].mean()
        std_f1_weighted = results_df['f1_weighted'].std()
        mean_f1_macro = results_df['f1_macro'].mean()
        std_f1_macro = results_df['f1_macro'].std()

        print("\n=== Overall Cross-Validation Performance ===")
        print(f"Accuracy: {mean_accuracy:.4f} ± {std_accuracy:.4f}")
        print(f"F1 (weighted): {mean_f1_weighted:.4f} ± {std_f1_weighted:.4f}")
        print(f"F1 (macro): {mean_f1_macro:.4f} ± {std_f1_macro:.4f}")
        if best_fold_idx >= 0:
            print(f"Best fold: {best_fold_idx + 1} ({optimization_metric}={best_fold_score:.4f})")

        # Save the validation results
        results_df.to_csv('kfold_validation_results.csv', index=False)

        # Create an ensemble from all folds if requested
        if CONFIG.get('use_ensemble', True) and len(best_models) > 1:
            print("\n=== Creating Ensemble from All Folds ===")

            # Built only to confirm the fold models compose into an ensemble; what is
            # persisted below is each fold's state dict. It is deliberately left on
            # the CPU: the fold models were moved there to free GPU memory, and
            # (assuming KFoldEnsemble registers them as submodules) a .to(device)
            # here would move every one of them back at once for no benefit.
            ensemble_model = KFoldEnsemble(best_models)

            # Save the ensemble model info
            torch.save({
                'model_type': 'kfold_ensemble',
                'num_folds': k_folds,
                'class_names': class_names,
            }, 'kfold_ensemble_info.pth')

            # Save individual fold models for the ensemble
            os.makedirs('ensemble_models', exist_ok=True)
            for fold_number, fold_model in enumerate(best_models, start=1):
                torch.save({
                    'fold': fold_number,
                    'model_state_dict': fold_model.state_dict(),
                }, f'ensemble_models/fold_{fold_number}_model.pth')

            print(f"Ensemble model created from {len(best_models)} folds")
            print("Ensemble model information saved to 'kfold_ensemble_info.pth'")
            print("Individual fold models saved in 'ensemble_models/' directory")

            # Write the standalone usage script for the ensemble
            with open('use_kfold_ensemble_model.py', 'w', encoding='utf-8') as f:
                f.write(KFOLD_ENSEMBLE_USAGE_CODE)

            print("K-fold ensemble usage code saved to use_kfold_ensemble_model.py")

        print("\nK-fold cross-validation complete!")

    except Exception as e:
        # Top-level boundary for the script: report the failure instead of crashing
        import traceback
        print(f"Error in k-fold cross-validation: {e}")
        traceback.print_exc()

if __name__ == "__main__":
    # Script entry point: run the k-fold pipeline unless CONFIG['run_kfold']
    # is explicitly falsy (it defaults to True), in which case run main().
    entry_point = main_kfold if CONFIG.get('run_kfold', True) else main
    entry_point()