# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import torchvision
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import time
import os
from PIL import Image
import gc
import warnings
from tqdm import tqdm
from transformers import BeitModel, BeitConfig, AutoImageProcessor
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc
import timm
from torch.cuda.amp import autocast, GradScaler
import random
from collections import defaultdict
import torch.nn.functional as F  # Ensure this is imported
from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau


warnings.filterwarnings('ignore')

# Set seeds for reproducibility
def set_seed(seed=42):
    """Seed the Python, NumPy and PyTorch random number generators.

    Also exports ``PYTHONHASHSEED`` and switches cuDNN to deterministic
    mode with auto-tuning (benchmark) disabled.

    Args:
        seed: Integer seed applied to every generator.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    # Same seed for Python, NumPy, torch CPU and every CUDA device.
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


set_seed()

# Central hyper-parameter / feature-flag dictionary for this script.
# The model, transform and training code below look keys up with
# CONFIG[...] or CONFIG.get(...). Not every key is consumed by the code
# visible next to this dict; comments marked NOTE(review) flag keys whose
# use should be confirmed.
CONFIG = {
    'data_dir': "D:/Work/project work/archive_lungs/chest_xray/chest_xray",
    'batch_size': 8,
    'eval_batch_size': 8,
    'num_epochs': 50,
    'learning_rate': 0.00002,  # Lower learning rate for more precise convergence
    'image_size': 384,  # Higher resolution for more detailed features
    'beit_model': "microsoft/beit-large-patch16-384",  # Larger, higher-capacity model (used when 'use_timm' is False)
    'use_mixed_precision': True,  # Enables autocast and GradScaler in the training loop
    'gradient_accumulation_steps': 2,  # Batches accumulated per optimizer step
    'weight_decay': 0.05,  # Higher weight decay to prevent overfitting
    'class_weight_normal': 1.5,
    'class_weight_pneumonia': 1.0,
    'augmentation_strength': 'extreme',  # Very strong augmentations ('extreme', 'strong', anything else = moderate)
    'cutmix_prob': 0.5,  # CutMix probability; the training loop only draws it when MixUp was not applied to the batch
    'mixup_prob': 0.5,  # MixUp augmentation probability
    'mixup_alpha': 0.4,  # Beta(alpha, alpha) parameter; the training loop passes it to both mixup() and cutmix()
    'scheduler': 'cosine_with_warmup',
    'warmup_epochs': 5,  # Epochs of linear learning-rate warmup
    'patience': 10,  # More patience for early stopping
    'use_test_time_augmentation': True,  # Use TTA for better inference
    'tta_transforms': 5,  # Number of test-time augmentations
    'use_ensemble': True,  # Use model ensemble for best results
    'ensemble_weights': [0.7, 0.3],  # Weights for ensemble models
    'use_ema': True,  # Exponential moving average for model weights (NOTE(review): the training loop checks 'exponential_moving_average' instead - confirm this key is read anywhere)
    'ema_decay': 0.999,  # Decay handed to ModelEMA by the training loop
    'clip_grad_norm': 1.0,  # max_norm used for gradient clipping
    'use_cosine_annealing': True,
    'final_div_factor': 1e4,  # Maximum lr reduction factor
    'plateau_patience': 5,  # Patience for learning rate reduction
    'plateau_factor': 0.5,  # Factor by which learning rate will be reduced
    'min_lr': 1e-7,  # Minimum learning rate
    'dropout_rate': 0.3,  # Dropout rate for regularization
    'stochastic_depth_prob': 0.1,  # Stochastic depth probability (drop_path_rate of the timm backbone)
    'label_smoothing': 0.1,  # Label smoothing factor
    'use_advanced_normalization': True,  # Use more advanced normalization techniques
    'use_timm': True,  # Use timm library for model (otherwise transformers' BeitModel)
    'timm_model': 'beit_large_patch16_384',  # Timm model name
    'use_dynamic_loss_scaling': True,  # Dynamic loss scaling for mixed precision
    'exponential_moving_average': True,  # Use EMA for model weights (this is the flag the training loop reads)
    'val_percent': 0.1,  # Percentage of data to use for validation
    'balanced_sampler': True,  # Use balanced sampler for class balance
    'lookahead_optimizer': True,  # Use lookahead optimizer
    'lookahead_k': 5,  # Lookahead parameter k
    'lookahead_alpha': 0.5,  # Lookahead parameter alpha
    'swa_start': 30,  # When to start Stochastic Weight Averaging
    'swa_freq': 5,  # Frequency of SWA model updates
    'swa_lr': 0.001,  # SWA learning rate
    'training_visualization': True,  # Visualize training progress
    'target_accuracy': 0.99,  # Target accuracy
    'precision_weight': 0.5,  # Weight for precision in custom loss
    'recall_weight': 0.5,  # Weight for recall in custom loss
    'use_weighted_ensemble': True,  # Use weighted ensemble
    'pixel_normalization': True,  # Use pixel normalization
    'use_multiscale_inputs': True,  # Use multi-scale inputs
    'use_progressive_resizing': False,  # Use progressive resizing
    'start_size': 384,  # Starting image size
    'progressive_epochs': [10, 20, 30],  # Epochs to increase size
    'progressive_sizes': [256, 320, 384],  # Progressive sizes
    'use_multi_stage_training': True,  # Train model in multiple stages
    'geometric_aug_prob': 0.75,  # Probability of geometric augmentations
    'photometric_aug_prob': 0.75,  # Probability of photometric augmentations
    'normalize_means': [0.5056, 0.5056, 0.5056],  # Custom normalization means (used by AdvancedTransforms)
    'normalize_stds': [0.252, 0.252, 0.252],  # Custom normalization stds (used by AdvancedTransforms)
    'advanced_post_processing': True,  # Apply advanced post-processing
    'use_attention_pooling': True,  # Use attention pooling
    'use_deep_supervision': True,  # Use deep supervision
    'discriminative_lr': True,  # Use discriminative learning rates
    'discriminative_lr_factor': 0.1,  # Factor between layer groups
}

# Device configuration
# Release cached CUDA blocks and run the garbage collector before picking
# the device the rest of the script uses.
torch.cuda.empty_cache()
gc.collect()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Memory optimization for CUDA
if torch.cuda.is_available():
    # NOTE(review): allocator settings are normally read when the CUDA
    # caching allocator initialises - confirm that setting the variable
    # here (after torch has been imported) still takes effect.
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
    try:
        torch.cuda.set_per_process_memory_fraction(0.9)
    except Exception as exc:
        # Still best-effort, but report the failure and no longer swallow
        # KeyboardInterrupt / SystemExit like the bare ``except`` did.
        print(f"Could not cap per-process GPU memory: {exc}")
    print(f"Total GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

# Exponential Moving Average for model weights
class ModelEMA:
    """Keep an exponential moving average (EMA) of a model's trainable parameters.

    Typical use: call ``update()`` after every optimizer step, then wrap
    evaluation in ``apply_shadow()`` / ``restore()`` to temporarily run
    the model with the averaged weights.

    Only parameters with ``requires_grad=True`` are tracked; buffers are
    not averaged.

    Args:
        model: Module whose parameters are averaged.
        decay: EMA decay; each update computes
            ``shadow = decay * shadow + (1 - decay) * param``.
    """

    def __init__(self, model, decay=0.9999):
        self.model = model
        self.decay = decay
        self.shadow = {}  # parameter name -> averaged tensor
        self.backup = {}  # parameter name -> training tensor while the EMA weights are applied
        self.register()

    def _trainable(self):
        """Yield ``(name, parameter)`` for every parameter that currently requires grad."""
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                yield name, param

    def register(self):
        """Initialise the averages from the current parameter values."""
        for name, param in self._trainable():
            self.shadow[name] = param.data.clone()

    def update(self):
        """Fold the current parameter values into the running averages."""
        for name, param in self._trainable():
            shadow = self.shadow.get(name)
            if shadow is None:
                # Parameter became trainable after register() (e.g. it was
                # unfrozen later): start its average from the current value.
                self.shadow[name] = param.data.clone()
                continue
            # The expression already allocates a fresh tensor, so no clone is needed.
            self.shadow[name] = (1 - self.decay) * param.data + self.decay * shadow

    def apply_shadow(self):
        """Swap the averaged weights into the model, remembering the training weights.

        Calling this again before ``restore()`` is a no-op; a second swap
        would overwrite the saved training weights with the EMA weights.
        While applied, each parameter shares storage with its shadow
        tensor, so the model must not be trained until ``restore()``.
        """
        if self.backup:
            return
        for name, param in self._trainable():
            if name not in self.shadow:
                continue
            self.backup[name] = param.data
            param.data = self.shadow[name]

    def restore(self):
        """Put back the training weights saved by ``apply_shadow()``.

        Does nothing when no swap is pending.
        """
        for name, param in self.model.named_parameters():
            if name in self.backup:
                param.data = self.backup[name]
        self.backup = {}

# Lookahead optimizer wrapper
class Lookahead(optim.Optimizer):
    """Lookahead wrapper around another optimizer.

    The wrapped ("fast") optimizer takes ``k`` ordinary steps; after every
    k-th step the slow weights move a fraction ``alpha`` towards the fast
    weights and the fast weights are reset to the slow ones.

    ``Optimizer.__init__`` is not called: the wrapper shares the inner
    optimizer's ``param_groups`` list, so learning-rate changes made
    through either object are seen by both. A ``'counter'`` entry is
    added to every param group to count steps.

    Args:
        optimizer: The inner optimizer that performs the fast updates.
        k: Number of fast steps between two slow-weight synchronisations.
        alpha: Interpolation factor of the slow update.

    Raises:
        ValueError: If ``k`` is smaller than 1.
    """

    def __init__(self, optimizer, k=5, alpha=0.5):
        if k < 1:
            raise ValueError(f"Lookahead k must be >= 1, got {k}")
        self.optimizer = optimizer
        self.k = k
        self.alpha = alpha
        self.param_groups = self.optimizer.param_groups
        self.defaults = self.optimizer.defaults
        self.state = defaultdict(dict)  # param -> {'slow_param': tensor}
        self.fast_state = self.optimizer.state
        for group in self.param_groups:
            group['counter'] = 0

    def zero_grad(self, *args, **kwargs):
        """Clear gradients through the inner optimizer.

        The inherited ``Optimizer.zero_grad`` must not run on this
        wrapper: it patches ``step`` with a hook wrapper that relies on
        attributes only ``Optimizer.__init__`` creates.
        """
        return self.optimizer.zero_grad(*args, **kwargs)

    def _all_params(self):
        """Return all parameters in param-group order (the order used for state indices)."""
        return [param for group in self.param_groups for param in group['params']]

    def update(self, group):
        """Synchronise slow and fast weights for one param group."""
        for fast in group['params']:
            param_state = self.state[fast]
            if 'slow_param' not in param_state:
                # Slow weights are created lazily from the fast weights at
                # the first synchronisation.
                param_state['slow_param'] = torch.zeros_like(fast.data)
                param_state['slow_param'].copy_(fast.data)
            slow = param_state['slow_param']
            slow += self.alpha * (fast.data - slow)
            fast.data.copy_(slow)

    def update_lookahead(self):
        """Force a slow/fast synchronisation for every param group."""
        for group in self.param_groups:
            self.update(group)

    def step(self, closure=None):
        """Run one inner optimizer step and synchronise every ``k`` steps.

        Returns:
            Whatever the inner optimizer's ``step`` returns.
        """
        loss = self.optimizer.step(closure)
        for group in self.param_groups:
            group['counter'] += 1
            if group['counter'] % self.k == 0:
                self.update(group)
        return loss

    def state_dict(self):
        """Return the inner optimizer state plus the slow weights.

        Slow weights are keyed by the parameter's position across the
        param groups so they can be matched to parameters again when the
        checkpoint is loaded in another process.
        """
        fast_state_dict = self.optimizer.state_dict()
        index_of = {id(param): idx for idx, param in enumerate(self._all_params())}
        slow_state = {
            index_of[id(param)]: dict(entry)
            for param, entry in self.state.items()
            if isinstance(param, torch.Tensor) and id(param) in index_of
        }
        return {
            'fast_state': fast_state_dict['state'],
            'slow_state': slow_state,
            'param_groups': fast_state_dict['param_groups'],
        }

    def load_state_dict(self, state_dict):
        """Restore the inner optimizer state and the slow weights.

        ``slow_state`` entries whose key is not a valid parameter index
        (e.g. checkpoints written with ``id()`` keys) are ignored; the
        slow weights are then re-created at the next synchronisation.
        """
        self.optimizer.load_state_dict({
            'state': state_dict['fast_state'],
            'param_groups': state_dict['param_groups'],
        })
        # The inner optimizer replaces its containers while loading;
        # re-bind so the wrapper keeps operating on the live objects.
        self.param_groups = self.optimizer.param_groups
        self.fast_state = self.optimizer.state
        for group in self.param_groups:
            group.setdefault('counter', 0)

        params = self._all_params()
        for key, entry in state_dict['slow_state'].items():
            is_index = isinstance(key, int) and 0 <= key < len(params)
            if not is_index or 'slow_param' not in entry:
                continue
            param = params[key]
            slow = torch.zeros_like(param.data)
            slow.copy_(entry['slow_param'])  # also moves to the parameter's device/dtype
            self.state[param]['slow_param'] = slow

# Advanced transformations for extreme data augmentation
class AdvancedTransforms:
    """Factory for the torchvision preprocessing / augmentation pipelines."""

    @staticmethod
    def get_transforms(mode='train', augmentation_strength='extreme', image_size=384):
        """Build the transform pipeline for one dataset split.

        Args:
            mode: ``'train'`` selects an augmenting pipeline; any other
                value yields the plain resize / tensor / normalize pipeline.
            augmentation_strength: ``'extreme'`` or ``'strong'``; every
                other value falls back to the moderate pipeline.
            image_size: Side length of the square the image is resized to.

        Returns:
            A ``transforms.Compose``. Normalization uses
            ``CONFIG['normalize_means']`` and ``CONFIG['normalize_stds']``.
        """
        resize = transforms.Resize((image_size, image_size))
        to_tensor = transforms.ToTensor()
        normalize = transforms.Normalize(
            mean=CONFIG['normalize_means'],
            std=CONFIG['normalize_stds'],
        )

        def sometimes(op, p):
            # Wrap a single transform so that it only fires with probability p.
            return transforms.RandomApply([op], p=p)

        # test/val transforms
        if mode != 'train':
            return transforms.Compose([resize, to_tensor, normalize])

        if augmentation_strength == 'extreme':
            steps = [
                resize,
                sometimes(transforms.RandomRotation(30), 0.7),
                sometimes(
                    transforms.RandomAffine(
                        degrees=20,
                        translate=(0.2, 0.2),
                        scale=(0.8, 1.2),
                        shear=20,
                    ),
                    0.5,
                ),
                sometimes(transforms.GaussianBlur(kernel_size=5, sigma=(0.1, 2.0)), 0.3),
                sometimes(
                    transforms.ColorJitter(
                        brightness=0.3,
                        contrast=0.3,
                        saturation=0.2,
                        hue=0.1,
                    ),
                    0.7,
                ),
                sometimes(transforms.RandomPerspective(distortion_scale=0.3, p=1.0), 0.3),
                transforms.RandomVerticalFlip(p=0.3),
                transforms.RandomHorizontalFlip(p=0.5),
                sometimes(transforms.RandomAutocontrast(p=1.0), 0.3),
                sometimes(transforms.RandomEqualize(p=1.0), 0.3),
                sometimes(transforms.RandomAdjustSharpness(sharpness_factor=2.0, p=1.0), 0.3),
                to_tensor,
                normalize,
                # Erasing runs last, on the normalized tensor.
                sometimes(
                    transforms.RandomErasing(
                        p=1.0,
                        scale=(0.02, 0.2),
                        ratio=(0.3, 3.3),
                        value=0,
                    ),
                    0.4,
                ),
            ]
        elif augmentation_strength == 'strong':
            steps = [
                resize,
                transforms.RandomRotation(20),
                transforms.RandomAffine(
                    degrees=15,
                    translate=(0.15, 0.15),
                    scale=(0.85, 1.15),
                ),
                transforms.ColorJitter(
                    brightness=0.2,
                    contrast=0.2,
                    saturation=0.1,
                ),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomVerticalFlip(p=0.2),
                to_tensor,
                normalize,
                transforms.RandomErasing(p=0.3, scale=(0.02, 0.15)),
            ]
        else:
            # moderate or default
            steps = [
                resize,
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomRotation(10),
                transforms.ColorJitter(brightness=0.1, contrast=0.1),
                to_tensor,
                normalize,
            ]

        return transforms.Compose(steps)

# CutMix implementation
def cutmix(data, target, alpha=1.0):
    """Paste a random rectangle from a shuffled copy of the batch into ``data``.

    ``data`` is modified in place and also returned.

    Args:
        data: Batch tensor; the rectangle is cut along dimensions 2 and 3.
        target: Labels aligned with ``data`` along dimension 0.
        alpha: Beta(alpha, alpha) parameter for the mixing ratio. Values
            ``<= 0`` disable mixing (empty rectangle, ``lam == 1.0``).

    Returns:
        ``(data, target, shuffled_target, lam)`` where ``lam`` is the
        fraction of each image that still comes from the original sample,
        recomputed from the rectangle actually pasted.
    """
    indices = torch.randperm(data.size(0)).to(data.device)
    shuffled_data = data[indices]  # advanced indexing copies, so the paste below is safe
    shuffled_target = target[indices]

    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    # Keep the original sample dominant: the pasted area never exceeds 50%.
    lam = max(lam, 1 - lam)

    bbx1, bby1, bbx2, bby2 = rand_bbox(data.size(), lam)
    data[:, :, bbx1:bbx2, bby1:bby2] = shuffled_data[:, :, bbx1:bbx2, bby1:bby2]

    # Adjust lambda to exactly match pixel ratio
    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (data.size(-1) * data.size(-2)))

    return data, target, shuffled_target, float(lam)


def rand_bbox(size, lam):
    """Sample a rectangle covering roughly ``1 - lam`` of a ``size[2] x size[3]`` plane.

    The ``bbx*`` values index dimension 2 and the ``bby*`` values index
    dimension 3 of a tensor with shape ``size``. The rectangle is centred
    on a uniformly drawn point and clipped to the plane, so it can be
    smaller than requested.

    Args:
        size: Shape of the batch; only entries 2 and 3 are read.
        lam: Fraction of the plane that should remain untouched.

    Returns:
        ``(bbx1, bby1, bbx2, bby2)`` as plain ints (end-exclusive bounds).
    """
    extent_x = size[2]
    extent_y = size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_x = int(extent_x * cut_rat)
    cut_y = int(extent_y * cut_rat)

    # Uniform
    cx = np.random.randint(extent_x)
    cy = np.random.randint(extent_y)

    bbx1 = int(np.clip(cx - cut_x // 2, 0, extent_x))
    bby1 = int(np.clip(cy - cut_y // 2, 0, extent_y))
    bbx2 = int(np.clip(cx + cut_x // 2, 0, extent_x))
    bby2 = int(np.clip(cy + cut_y // 2, 0, extent_y))

    return bbx1, bby1, bbx2, bby2

# MixUp implementation
def mixup(data, target, alpha=1.0):
    """Blend every sample with another sample of the same batch.

    Args:
        data: Batch tensor, mixed along dimension 0.
        target: Labels aligned with ``data`` along dimension 0.
        alpha: Beta(alpha, alpha) parameter for the mixing ratio. Values
            ``<= 0`` disable mixing (``lam == 1.0``).

    Returns:
        ``(mixed_data, target, shuffled_target, lam)`` with
        ``mixed_data = lam * data + (1 - lam) * data[perm]``. ``lam`` is
        always at least 0.5, so the original sample stays dominant.
    """
    indices = torch.randperm(data.size(0)).to(data.device)
    shuffled_data = data[indices]
    shuffled_target = target[indices]

    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    lam = max(lam, 1 - lam)

    data = lam * data + (1 - lam) * shuffled_data

    return data, target, shuffled_target, lam

# Multi-scale training
def get_multiscale_sizes(base_size=384, scales=(0.75, 1.0, 1.25)):
    """Return ``base_size`` scaled by each factor, truncated to int.

    Args:
        base_size: Reference image size.
        scales: Iterable of multiplicative factors (immutable default so
            the default cannot be mutated between calls).

    Returns:
        List of integer sizes, one per scale, in the order given.
    """
    return [int(base_size * scale) for scale in scales]

# High-performance BEiT model with attention pooling
class HighPerformanceBeitModel(nn.Module):
    """BEiT backbone with a LayerNorm, optional attention pooling and an MLP head.

    The backbone comes from timm (``CONFIG['use_timm']``) or from
    transformers' ``BeitModel``. The feature flags are read from
    ``CONFIG`` once, in the constructor, so ``forward`` always matches
    the modules that were actually built.

    In training mode with deep supervision enabled, ``forward`` returns
    ``(logits, aux_logits)``; otherwise it returns ``logits`` only.

    Args:
        num_classes: Number of output classes.
        pretrained: Load pretrained backbone weights. When False, the
            transformers backbone is built from its config with random
            weights.
        dropout_rate: Dropout used in the classifier head (and passed to
            timm as ``drop_rate``).
    """

    def __init__(self, num_classes=2, pretrained=True, dropout_rate=0.3):
        super(HighPerformanceBeitModel, self).__init__()

        # Snapshot the flags: later edits to CONFIG must not make forward()
        # take a path whose modules were never created.
        self.use_timm = bool(CONFIG.get('use_timm', False))
        self.use_attention_pooling = bool(CONFIG.get('use_attention_pooling', False))
        self.use_deep_supervision = bool(CONFIG.get('use_deep_supervision', False))

        # Use timm for advanced model features
        if self.use_timm:
            self.backbone = timm.create_model(
                CONFIG['timm_model'],
                pretrained=pretrained,
                num_classes=0,
                drop_rate=dropout_rate,
                drop_path_rate=CONFIG.get('stochastic_depth_prob', 0.1)
            )
            print(f"Created timm model: {CONFIG['timm_model']}")
            feature_dim = self.backbone.num_features
        else:
            # Use transformers BEiT model
            model_name = CONFIG['beit_model']
            if pretrained:
                self.backbone = BeitModel.from_pretrained(model_name)
            else:
                # Honour pretrained=False: same architecture, random weights.
                self.backbone = BeitModel(BeitConfig.from_pretrained(model_name))
            print(f"Created transformers model: {model_name}")
            feature_dim = self.backbone.config.hidden_size

        # Layer normalization
        self.norm = nn.LayerNorm(feature_dim)

        # Attention pooling
        if self.use_attention_pooling:
            self.attention_pool = nn.Sequential(
                nn.Linear(feature_dim, feature_dim // 2),
                nn.Tanh(),
                nn.Linear(feature_dim // 2, 1)
            )

        # Classifier head
        self.classifier = nn.Sequential(
            nn.Dropout(dropout_rate),
            nn.Linear(feature_dim, 1024),
            nn.GELU(),
            nn.LayerNorm(1024),
            nn.Dropout(dropout_rate),
            nn.Linear(1024, 512),
            nn.GELU(),
            nn.LayerNorm(512),
            nn.Dropout(dropout_rate),
            nn.Linear(512, num_classes)
        )

        # Deep supervision
        if self.use_deep_supervision:
            self.aux_classifier = nn.Linear(feature_dim, num_classes)

        # Apply weight initialization (new head layers only)
        self._init_weights()

    def _head_modules(self):
        """Return the freshly created (non-backbone) sub-modules that exist on this instance."""
        heads = [self.norm, self.classifier]
        for name in ('attention_pool', 'aux_classifier'):
            module = getattr(self, name, None)
            if module is not None:
                heads.append(module)
        return heads

    def _init_weights(self):
        """Initialise the head layers.

        The backbone is skipped on purpose: iterating ``self.modules()``
        would also reset every Linear/LayerNorm inside it and throw away
        the pretrained weights.
        """
        for head in self._head_modules():
            for m in head.modules():
                if isinstance(m, nn.Linear):
                    torch.nn.init.trunc_normal_(m.weight, std=0.02)
                    if m.bias is not None:
                        torch.nn.init.zeros_(m.bias)
                elif isinstance(m, nn.LayerNorm):
                    torch.nn.init.ones_(m.weight)
                    torch.nn.init.zeros_(m.bias)

    def attention_pooling(self, embeddings):
        """Pool a ``(batch, tokens, dim)`` sequence into ``(batch, dim)``.

        Each token gets a scalar score; the scores are softmax-normalised
        over the token axis and used as weights for a weighted sum.
        """
        attention_weights = self.attention_pool(embeddings)
        attention_weights = F.softmax(attention_weights, dim=1)
        attended_embedding = torch.sum(attention_weights * embeddings, dim=1)
        return attended_embedding

    def forward(self, x):
        """Compute class logits for a batch of images.

        Returns:
            ``logits``, or ``(logits, aux_logits)`` when deep supervision
            is enabled and the module is in training mode.
        """
        if self.use_timm:
            if self.use_attention_pooling:
                # Attention pooling needs the token sequence; the plain
                # forward() of a num_classes=0 timm model is already pooled.
                tokens = self.backbone.forward_features(x)
                if tokens.dim() == 2:
                    # Backbone only exposes pooled features: treat them as
                    # a one-token sequence (pooling becomes the identity).
                    tokens = tokens.unsqueeze(1)
                normalized = self.norm(tokens)
                pooled = self.attention_pooling(normalized)
                # The auxiliary head sees the mean of the normalized tokens.
                aux_features = normalized.mean(dim=1)
            else:
                normalized = self.norm(self.backbone(x))
                pooled = normalized
                aux_features = normalized
        else:
            # Transformers backbone: classify from the CLS token
            outputs = self.backbone(x)
            normalized = self.norm(outputs.last_hidden_state[:, 0])
            pooled = normalized
            aux_features = normalized

        logits = self.classifier(pooled)

        # Apply deep supervision if enabled
        if self.use_deep_supervision and self.training:
            return logits, self.aux_classifier(aux_features)
        return logits

# Combined loss function
class CombinedLoss(nn.Module):
    """Cross-entropy with optional focal modulation, auxiliary head and mixup/cutmix targets.

    Args:
        alpha: Stored on the instance; not used by ``forward``.
        gamma: Focal exponent; ``gamma <= 0`` disables the modulation.
        class_weights: Optional per-class weights for the cross-entropy.
        label_smoothing: Label smoothing passed to the cross-entropy.
    """

    def __init__(self, alpha=0.5, gamma=2.0, class_weights=None, label_smoothing=0.1):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.class_weights = class_weights
        self.label_smoothing = label_smoothing
        # Per-sample losses are needed so the focal factor can be applied
        # before averaging.
        self.ce_loss = nn.CrossEntropyLoss(
            weight=class_weights,
            label_smoothing=label_smoothing,
            reduction='none',
        )

    def _modulated_ce(self, logits, targets):
        """Per-sample cross-entropy, scaled by ``(1 - exp(-ce)) ** gamma`` when gamma > 0."""
        per_sample = self.ce_loss(logits, targets)
        if self.gamma > 0:
            confidence = torch.exp(-per_sample)
            per_sample = (1 - confidence) ** self.gamma * per_sample
        return per_sample

    def forward(self, logits, targets, aux_logits=None, aux_targets=None, lam=None, targets_b=None):
        """Compute the scalar training loss.

        * Without extras: mean modulated cross-entropy.
        * With ``aux_logits`` and ``aux_targets``: 0.8 * main + 0.2 * auxiliary.
        * With ``lam`` and ``targets_b``: every term becomes
          ``lam * loss(targets) + (1 - lam) * loss(targets_b)``; the
          auxiliary term is mixed the same way.
        """
        has_aux = aux_logits is not None and aux_targets is not None
        is_mixed = lam is not None and targets_b is not None

        main_a = self._modulated_ce(logits, targets).mean()
        aux_a = self._modulated_ce(aux_logits, aux_targets).mean() if has_aux else None

        if not is_mixed:
            if has_aux:
                return main_a * 0.8 + aux_a * 0.2
            return main_a

        # Mixup / cutmix: interpolate between the losses of both label sets.
        main_b = self._modulated_ce(logits, targets_b).mean()
        loss = lam * main_a + (1 - lam) * main_b

        if has_aux:
            aux_b = self._modulated_ce(aux_logits, targets_b).mean()
            aux_loss = lam * aux_a + (1 - lam) * aux_b
            # Combine main and auxiliary losses
            loss = loss * 0.8 + aux_loss * 0.2

        return loss

# Test-time augmentation
def test_time_augmentation(model, image, n_augments=5, mean=None, std=None):
    """Average softmax predictions over the image and augmented copies of it.

    The model is switched to eval mode and left that way.

    Args:
        model: Module returning logits, or a sequence whose first element
            holds the logits.
        image: A single, unbatched image tensor; a batch dimension is
            added before each forward pass. Assumed to be already
            normalized with ``mean`` / ``std`` - confirm against the caller.
        n_augments: Total number of views including the un-augmented one.
            At most five augmented views are available.
        mean: Per-channel normalization means applied to ``image``;
            defaults to ``CONFIG['normalize_means']``.
        std: Per-channel normalization stds applied to ``image``;
            defaults to ``CONFIG['normalize_stds']``.

    Returns:
        Tensor of class probabilities with a leading batch dimension of 1.
    """
    model.eval()

    def _predict(batch):
        # Models with deep supervision may return (logits, aux_logits).
        output = model(batch)
        if not isinstance(output, torch.Tensor):
            output = output[0]
        return output.softmax(dim=1)

    # Base prediction
    with torch.no_grad():
        base_pred = _predict(image.unsqueeze(0))

    if n_augments <= 1:
        return base_pred

    if mean is None:
        mean = CONFIG['normalize_means']
    if std is None:
        std = CONFIG['normalize_stds']
    mean_t = torch.as_tensor(mean, dtype=image.dtype, device=image.device).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=image.dtype, device=image.device).view(-1, 1, 1)

    jitter = transforms.ColorJitter(brightness=0.1, contrast=0.1)

    def _color_jitter(img):
        # ColorJitter clamps float tensors to [0, 1]; applied directly to a
        # normalized image it would wipe out every negative value. Jitter
        # in pixel space and normalize again afterwards.
        pixels = (img * std_t + mean_t).clamp(0.0, 1.0)
        return (jitter(pixels) - mean_t) / std_t

    # Define TTA transforms
    tta_transforms = [
        transforms.RandomHorizontalFlip(p=1.0),
        transforms.RandomVerticalFlip(p=1.0),
        transforms.RandomRotation(degrees=10),
        _color_jitter,
        transforms.GaussianBlur(kernel_size=5, sigma=0.5),
    ]

    # Collect one prediction per view, starting with the original image
    all_preds = [base_pred]

    with torch.no_grad():
        for augment in tta_transforms[:max(n_augments - 1, 0)]:
            all_preds.append(_predict(augment(image).unsqueeze(0)))

    # Average predictions
    final_pred = torch.stack(all_preds).mean(dim=0)

    return final_pred

# Continue the train_model function
def train_model(model, dataloaders, dataset_sizes, criterion, optimizer, scheduler=None, num_epochs=30):
    """Enhanced training function with advanced techniques"""
    since = time.time()
    
    # Initialize EMA if enabled
    if CONFIG.get('exponential_moving_average', False):
        ema_model = ModelEMA(model, decay=CONFIG.get('ema_decay', 0.999))
    
    # Set up checkpoint directory
    checkpoint_dir = 'high_acc_checkpoints'
    os.makedirs(checkpoint_dir, exist_ok=True)
    
    # Best model tracking
    best_model_path = os.path.join(checkpoint_dir, 'best_model.pt')
    best_acc = 0.0
    best_f1 = 0.0
    
    # History tracking
    history = {
        'train_loss': [], 'train_acc': [],
        'val_loss': [], 'val_acc': [],
        'val_precision': [], 'val_recall': [], 'val_f1': [],
        'val_specificity': [], 'val_sensitivity': [], 'lr': []
    }
    
    # Gradient scaling for mixed precision
    scaler = GradScaler(enabled=CONFIG.get('use_mixed_precision', True))
    
    # Early stopping
    patience = CONFIG.get('patience', 10)
    counter = 0
    
    # Progressive resizing
    if CONFIG.get('use_progressive_resizing', False):
        progressive_sizes = CONFIG.get('progressive_sizes', [256, 320, 384])
        progressive_epochs = CONFIG.get('progressive_epochs', [10, 20, 30])
        current_size_idx = 0
    
    # Learning rate warmup
    warmup_epochs = CONFIG.get('warmup_epochs', 5)
    
    # Gradient accumulation steps
    grad_accum_steps = CONFIG.get('gradient_accumulation_steps', 1)
    
    # Training loop
    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 30)
        
        # Update image size for progressive resizing
        if CONFIG.get('use_progressive_resizing', False):
            if current_size_idx < len(progressive_epochs) and epoch >= progressive_epochs[current_size_idx]:
                current_size = progressive_sizes[current_size_idx]
                print(f"Increasing image size to {current_size}")
                
                # Update data transforms
                for phase in ['train', 'val']:
                    if phase in dataloaders:
                        dataloaders[phase].dataset.transform = AdvancedTransforms.get_transforms(
                            mode=phase,
                            augmentation_strength=CONFIG.get('augmentation_strength', 'extreme'),
                            image_size=current_size
                        )
                
                current_size_idx = min(current_size_idx + 1, len(progressive_sizes) - 1)
        
        # Learning rate warmup
        if epoch < warmup_epochs:
            # Linear warmup
            warmup_factor = (epoch + 1) / warmup_epochs
            for param_group in optimizer.param_groups:
                param_group['lr'] = param_group['initial_lr'] * warmup_factor
        
        # Store current learning rate
        current_lr = optimizer.param_groups[0]['lr']
        history['lr'].append(current_lr)
        print(f"Current learning rate: {current_lr:.7f}")
        
        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()
            
            running_loss = 0.0
            running_corrects = 0
            
            # Progress bar for batches
            pbar = tqdm(dataloaders[phase])
            
            # Reset gradient accumulation counter
            accum_step = 0
            
            # Track all predictions for metrics calculation
            all_labels = []
            all_preds = []
            all_probs = []
            
            # Process each batch
            for inputs, labels in pbar:
                inputs = inputs.to(device)
                labels = labels.to(device)
                
                # Store labels for metrics
                all_labels.append(labels.cpu().numpy())
                
                # Zero gradients for optimizer
                if phase == 'train' and accum_step % grad_accum_steps == 0:
                    optimizer.zero_grad()
                
                # Forward pass with mixed precision
                with autocast(enabled=CONFIG.get('use_mixed_precision', True)):
                    # Apply mixup or cutmix for training
                    mixup_applied = False
                    cutmix_applied = False
                    lam = None
                    labels_b = None
                    
                    if phase == 'train':
                        # Apply mixup with probability
                        if np.random.rand() < CONFIG.get('mixup_prob', 0.5):
                            inputs, labels_a, labels_b, lam = mixup(
                                inputs, labels, alpha=CONFIG.get('mixup_alpha', 0.4)
                            )
                            mixup_applied = True
                        # Apply cutmix with probability
                        elif np.random.rand() < CONFIG.get('cutmix_prob', 0.5):
                            inputs, labels_a, labels_b, lam = cutmix(
                                inputs, labels, alpha=CONFIG.get('mixup_alpha', 0.4)
                            )
                            cutmix_applied = True
                    
                    # Forward pass through model
                    with torch.set_grad_enabled(phase == 'train'):
                        if CONFIG.get('use_deep_supervision', False) and phase == 'train':
                            # Forward pass with deep supervision
                            outputs, aux_outputs = model(inputs)
                            if mixup_applied or cutmix_applied:
                                loss = criterion(outputs, labels, aux_outputs, labels, lam, labels_b)
                            else:
                                loss = criterion(outputs, labels, aux_outputs, labels)
                        else:
                            # Standard forward pass
                            outputs = model(inputs)
                            if mixup_applied or cutmix_applied:
                                loss = criterion(outputs, labels, lam=lam, targets_b=labels_b)
                            else:
                                loss = criterion(outputs, labels)
                        
                        # Scale loss for gradient accumulation
                        if phase == 'train':
                            loss = loss / grad_accum_steps
                        
                        # Get predictions
                        _, preds = torch.max(outputs, 1)
                        
                        # Store predictions for metrics
                        if phase == 'val':
                            all_preds.append(preds.cpu().numpy())
                            probs = F.softmax(outputs, dim=1)
                            all_probs.append(probs.cpu().numpy())
                        
                        # Backward pass and optimize
                        if phase == 'train':
                            # Use gradient scaling
                            scaler.scale(loss).backward()
                            
                            # Update weights if accumulated enough gradients
                            if (accum_step + 1) % grad_accum_steps == 0:
                                # Unscale for gradient clipping
                                scaler.unscale_(optimizer)
                                
                                # Gradient clipping
                                torch.nn.utils.clip_grad_norm_(
                                    model.parameters(), 
                                    max_norm=CONFIG.get('clip_grad_norm', 1.0)
                                )
                                
                                # Optimizer step
                                scaler.step(optimizer)
                                scaler.update()
                                
                                # Update EMA model if enabled
                                if CONFIG.get('exponential_moving_average', False):
                                    ema_model.update()
                                
                                # Step scheduler if provided and not in warmup
                                if scheduler is not None and epoch >= warmup_epochs:
                                    if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                                        # This will be stepped after the epoch
                                        pass
                                    else:
                                        scheduler.step()
                            
                            # Increment accumulation step
                            accum_step += 1
                
                # Calculate statistics
                if not (mixup_applied or cutmix_applied):
                    batch_loss = loss.item() * inputs.size(0)
                    batch_corrects = torch.sum(preds == labels.data)
                    running_loss += batch_loss
                    running_corrects += batch_corrects
                    
                    # Update progress bar
                    current_loss_avg = running_loss / ((pbar.n + 1) * inputs.size(0))
                    current_acc = running_corrects.double() / ((pbar.n + 1) * inputs.size(0))
                    pbar.set_postfix({
                        'loss': f'{current_loss_avg:.4f}',
                        'acc': f'{current_acc:.4f}'
                    })
                
                # Clean up memory
                del inputs, outputs, loss
                if phase == 'train' and (mixup_applied or cutmix_applied):
                    del labels_b
                
                # Force garbage collection
                torch.cuda.empty_cache()
                gc.collect()
            
            # Calculate epoch metrics
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]
            
            # Record metrics
            if phase == 'train':
                history['train_loss'].append(epoch_loss)
                history['train_acc'].append(epoch_acc.item())
                print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
            else:  # validation phase
                # Process all predictions for metrics
                all_labels = np.concatenate(all_labels)
                all_preds = np.concatenate(all_preds)
                all_probs = np.concatenate(all_probs)
                
                # Calculate detailed metrics
                precision = precision_score(all_labels, all_preds, average='binary', zero_division=0)
                recall = recall_score(all_labels, all_preds, average='binary', zero_division=0)
                f1 = f1_score(all_labels, all_preds, average='binary', zero_division=0)
                
                # Calculate confusion matrix
                cm = confusion_matrix(all_labels, all_preds)
                if cm.shape == (2, 2):
                    tn, fp, fn, tp = cm.ravel()
                    specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
                    sensitivity = recall  # Same as recall
                else:
                    specificity = 0
                    sensitivity = 0
                
                # Record validation metrics
                history['val_loss'].append(epoch_loss)
                history['val_acc'].append(epoch_acc.item())
                history['val_precision'].append(precision)
                history['val_recall'].append(recall)
                history['val_f1'].append(f1)
                history['val_specificity'].append(specificity)
                history['val_sensitivity'].append(sensitivity)
                
                # Print validation metrics
                print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')
                print(f'Precision: {precision:.4f} Recall: {recall:.4f} F1: {f1:.4f}')
                print(f'Specificity: {specificity:.4f}')
                
                if cm.shape == (2, 2):
                    print(f'Confusion Matrix: TN={tn}, FP={fp}, FN={fn}, TP={tp}')
                
                # Step ReduceLROnPlateau scheduler if used
                if scheduler is not None and epoch >= warmup_epochs:
                    if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                        # Using F1 score as metric for scheduler
                        scheduler.step(f1)
                
                # Check if model improved
                improved = False
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    improved = True
                    metric_name = 'accuracy'
                if f1 > best_f1:
                    best_f1 = f1
                    improved = True
                    metric_name = 'F1'
                
                # Save checkpoint if improved
                if improved:
                    counter = 0
                    print(f"Saving model (improved {metric_name})")
                    
                    # Save model state
                    if CONFIG.get('exponential_moving_average', False):
                        # Apply EMA weights for saving
                        ema_model.apply_shadow()
                        model_state = model.state_dict()
                        # Restore original weights
                        ema_model.restore()
                    else:
                        model_state = model.state_dict()
                    
                    # Save checkpoint
                    torch.save({
                        'epoch': epoch,
                        'model_state_dict': model_state,
                        'optimizer_state_dict': optimizer.state_dict(),
                        'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
                        'best_acc': best_acc,
                        'best_f1': best_f1,
                        'history': history,
                        'config': CONFIG,
                    }, best_model_path)
                    
                    # Save epoch-specific model for ensemble
                    if CONFIG.get('use_ensemble', False):
                        torch.save({
                            'epoch': epoch,
                            'model_state_dict': model_state,
                            'acc': epoch_acc.item(),
                            'f1': f1
                        }, os.path.join(checkpoint_dir, f'model_epoch_{epoch}.pt'))
                else:
                    # Increment early stopping counter
                    counter += 1
                    print(f"Early stopping counter: {counter}/{patience}")
                    
                    # Check if we reached target accuracy
                    if epoch_acc >= CONFIG.get('target_accuracy', 0.99):
                        print(f"Reached target accuracy of {CONFIG.get('target_accuracy', 0.99):.4f}!")
                        break
                
                # Early stopping
                if counter >= patience:
                    print(f"Early stopping triggered after {epoch+1} epochs")
                    break
        
        print()
    
    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}, Best F1: {best_f1:.4f}')
    
    # Load best model weights
    checkpoint = torch.load(best_model_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    
    return model, history

def create_dataloaders(data_dir, image_size=384, val_split=0.1):
    """Build the train / val / test dataloaders from an ImageFolder layout.

    ``data_dir`` must contain ``train`` and ``test`` sub-directories readable
    by ``torchvision.datasets.ImageFolder``. When ``val_split`` yields at least
    one sample, that share of the training images is held out for validation;
    otherwise the test loader doubles as the validation loader.

    Args:
        data_dir: Root directory holding the ``train`` and ``test`` folders.
        image_size: Image size handed to the test transform. The train and
            validation transforms use ``CONFIG['start_size']`` when it is set
            and fall back to this value otherwise.
        val_split: Fraction of the training images to hold out for validation.

    Returns:
        Tuple ``(dataloaders, dataset_sizes, class_names)``. The first two are
        dicts keyed by ``'train'``, ``'val'`` and ``'test'``; ``class_names``
        is the class list reported by the training ``ImageFolder``.
    """
    working_size = CONFIG.get('start_size', image_size)

    # Get transformations
    train_transform = AdvancedTransforms.get_transforms(
        mode='train',
        augmentation_strength=CONFIG.get('augmentation_strength', 'extreme'),
        image_size=working_size
    )

    val_transform = AdvancedTransforms.get_transforms(
        mode='val',
        augmentation_strength='none',
        image_size=working_size
    )

    test_transform = AdvancedTransforms.get_transforms(
        mode='val',
        augmentation_strength='none',
        image_size=image_size
    )

    # Load datasets
    train_dir = os.path.join(data_dir, 'train')
    train_dataset = datasets.ImageFolder(train_dir, transform=train_transform)
    test_dataset = datasets.ImageFolder(
        os.path.join(data_dir, 'test'),
        transform=test_transform
    )

    # Create validation set from training data
    train_size = len(train_dataset)
    val_size = int(train_size * val_split) if val_split > 0 else 0

    if val_size > 0:
        train_subset, held_out = torch.utils.data.random_split(
            train_dataset,
            [train_size - val_size, val_size]
        )

        # Both subsets returned by random_split wrap the SAME ImageFolder, so
        # assigning a transform on one of them would also strip the training
        # augmentations. Read the held-out indices through a second
        # ImageFolder that carries the validation transform instead.
        val_dataset = datasets.ImageFolder(train_dir, transform=val_transform)
        val_subset = torch.utils.data.Subset(val_dataset, held_out.indices)
        train_indices = list(train_subset.indices)
    else:
        train_subset = train_dataset
        val_subset = None
        train_indices = list(range(train_size))

    # Create class weights for balanced sampling
    if CONFIG.get('balanced_sampler', False):
        # The sampler emits positions into ``train_subset``, so the weights
        # must be built from (and aligned with) the subset's own labels rather
        # than from every image in the training folder.
        subset_targets = np.asarray(train_dataset.targets)[train_indices]
        class_counts = np.bincount(
            subset_targets, minlength=len(train_dataset.classes)
        )

        # Guard against a class that ended up with no training samples.
        class_weights = 1.0 / np.maximum(class_counts, 1)
        sample_weights = class_weights[subset_targets]

        train_sampler = torch.utils.data.WeightedRandomSampler(
            weights=sample_weights,
            num_samples=len(train_subset),
            replacement=True
        )
    else:
        train_sampler = None

    # Create dataloaders
    dataloaders = {
        'train': torch.utils.data.DataLoader(
            train_subset,
            batch_size=CONFIG['batch_size'],
            sampler=train_sampler,
            # DataLoader rejects shuffle=True together with a sampler.
            shuffle=train_sampler is None,
            num_workers=0,
            pin_memory=False
        ),
        'test': torch.utils.data.DataLoader(
            test_dataset,
            batch_size=CONFIG['batch_size'],
            shuffle=False,
            num_workers=0,
            pin_memory=False
        )
    }

    # Add validation dataloader if we have a validation set
    if val_subset is not None:
        dataloaders['val'] = torch.utils.data.DataLoader(
            val_subset,
            batch_size=CONFIG['batch_size'],
            shuffle=False,
            num_workers=0,
            pin_memory=False
        )
    else:
        # Use test set as validation
        dataloaders['val'] = dataloaders['test']

    # Calculate dataset sizes
    dataset_sizes = {
        'train': len(train_subset),
        'val': len(val_subset) if val_subset is not None else len(test_dataset),
        'test': len(test_dataset)
    }

    # Get class names
    class_names = train_dataset.classes

    return dataloaders, dataset_sizes, class_names

def evaluate_model(model, dataloader, class_names):
    """Evaluate ``model`` on ``dataloader`` and report binary-classification metrics.

    Runs the model in eval mode without gradients, optionally routing every
    image through ``test_time_augmentation`` when
    ``CONFIG['use_test_time_augmentation']`` is set. Prints the metrics and
    saves the confusion-matrix plot to ``final_confusion_matrix.png``.

    Args:
        model: Classifier returning one logit per class for a batch of inputs.
        dataloader: Iterable yielding ``(inputs, labels)`` batches.
        class_names: Class labels in index order; used for the plot ticks and
            to fix the confusion-matrix layout.

    Returns:
        Dict with ``accuracy``, ``precision``, ``recall``, ``specificity``,
        ``f1``, ``auc``, ``confusion_matrix`` and the four confusion-matrix
        cell counts. ``auc`` is 0 when it cannot be computed.
    """
    model.eval()

    all_labels = []
    all_preds = []
    all_probs = []

    # Use test-time augmentation if enabled
    use_tta = CONFIG.get('use_test_time_augmentation', False)
    tta_transforms = CONFIG.get('tta_transforms', 5)

    # Process all batches
    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc="Evaluating"):
            all_labels.extend(labels.numpy())

            if use_tta:
                # Process each image with TTA
                batch_probs = []
                for i in range(inputs.size(0)):
                    img = inputs[i].to(device)
                    probs = test_time_augmentation(model, img, n_augments=tta_transforms)
                    batch_probs.append(probs.cpu().numpy())

                batch_probs = np.array(batch_probs)
                all_probs.extend(batch_probs)

                # Get predictions from probabilities
                all_preds.extend(np.argmax(batch_probs, axis=1))
            else:
                # Standard forward pass
                inputs = inputs.to(device)
                outputs = model(inputs)
                _, preds = torch.max(outputs, 1)
                probs = F.softmax(outputs, dim=1)

                all_preds.extend(preds.cpu().numpy())
                all_probs.extend(probs.cpu().numpy())

            # Clean up memory
            del inputs
            torch.cuda.empty_cache()

    # Convert to arrays
    all_labels = np.array(all_labels)
    all_preds = np.array(all_preds)
    all_probs = np.array(all_probs)

    # Calculate metrics
    accuracy = accuracy_score(all_labels, all_preds)
    precision = precision_score(all_labels, all_preds, average='binary', zero_division=0)
    recall = recall_score(all_labels, all_preds, average='binary', zero_division=0)
    f1 = f1_score(all_labels, all_preds, average='binary', zero_division=0)

    # Pin the label set so the matrix keeps one row/column per class even when
    # a class never appears in the labels or predictions (otherwise it would
    # shrink and the specificity / cell counts below would be reported as 0).
    class_indices = list(range(len(class_names)))
    cm = confusion_matrix(all_labels, all_preds, labels=class_indices)

    # Extract metrics from confusion matrix
    if cm.shape == (2, 2):
        tn, fp, fn, tp = cm.ravel()
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
        sensitivity = recall  # Same as recall
    else:
        specificity = 0
        sensitivity = 0
        tn, fp, fn, tp = 0, 0, 0, 0

    # Calculate AUC. Only the expected failure modes are caught: ValueError for
    # an undefined score (e.g. a single class in the labels) and IndexError
    # when the probability array has no positive-class column.
    try:
        auc_score = roc_auc_score(all_labels, all_probs[:, 1])
    except (ValueError, IndexError) as exc:
        print(f"Could not compute AUC: {exc}")
        auc_score = 0

    # Print metrics
    print("\n===== FINAL MODEL EVALUATION =====")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall (Sensitivity): {recall:.4f}")
    print(f"Specificity: {specificity:.4f}")
    print(f"F1 Score: {f1:.4f}")
    print(f"AUC: {auc_score:.4f}")

    print("\nConfusion Matrix:")
    print(f"  True Negatives: {tn}, False Positives: {fp}")
    print(f"  False Negatives: {fn}, True Positives: {tp}")

    # Plot confusion matrix
    plt.figure(figsize=(8, 6))
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    plt.title('Confusion Matrix')
    plt.colorbar()
    plt.xticks(class_indices, class_names)
    plt.yticks(class_indices, class_names)

    # Add text to cells; switch to white text on the darker half of the scale
    thresh = cm.max() / 2.
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            plt.text(j, i, format(cm[i, j], 'd'),
                   horizontalalignment="center", fontsize=16,
                   color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
    plt.savefig('final_confusion_matrix.png', dpi=300)

    # Return all metrics
    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'specificity': specificity,
        'f1': f1,
        'auc': auc_score,
        'confusion_matrix': cm,
        'true_negatives': tn,
        'false_positives': fp,
        'false_negatives': fn,
        'true_positives': tp
    }

def create_model_ensemble(checkpoint_dir, class_names, top_k=3, weights=None):
    """Build an ensemble from the per-epoch checkpoints in ``checkpoint_dir``.

    Files whose name starts with ``model_epoch_`` are ranked by their stored
    ``'f1'`` value; the best ``top_k`` are loaded into fresh
    ``HighPerformanceBeitModel`` instances, switched to eval mode and moved to
    ``device``.

    Args:
        checkpoint_dir: Directory to scan for ``model_epoch_*`` checkpoints.
        class_names: Class labels; only their count is used to size the models.
        top_k: Maximum number of checkpoints to include.
        weights: Optional per-model weights, in ranking order. When omitted
            (or shorter than the number of selected models) the stored F1
            scores are used. Extra entries are ignored.

    Returns:
        Tuple ``(models, weights)`` where ``weights`` is a NumPy array summing
        to 1, or ``(None, None)`` when no checkpoint could be used, so callers
        that unpack the result keep working.
    """
    # Find all model checkpoints (sorted so ties rank the same way every run)
    checkpoint_files = sorted(
        f for f in os.listdir(checkpoint_dir) if f.startswith('model_epoch_')
    )

    if not checkpoint_files:
        print("No checkpoint files found for ensemble")
        return None, None

    # Load metrics for each checkpoint
    checkpoint_metrics = []
    for file in checkpoint_files:
        try:
            # Only the metadata is needed here, so keep the tensors on the CPU
            # instead of restoring them onto the device they were saved from.
            checkpoint = torch.load(
                os.path.join(checkpoint_dir, file), map_location='cpu'
            )
            checkpoint_metrics.append({
                'file': file,
                'acc': checkpoint.get('acc', 0),
                'f1': checkpoint.get('f1', 0),
                'epoch': checkpoint.get('epoch', 0)
            })
            del checkpoint
        except Exception as exc:
            # Best effort: an unreadable file is reported and skipped.
            print(f"Could not load checkpoint {file}: {exc}")

    if not checkpoint_metrics:
        print("No checkpoint files could be loaded for ensemble")
        return None, None

    # Sort by F1 score and take top K models
    checkpoint_metrics.sort(key=lambda info: info['f1'], reverse=True)
    top_models = checkpoint_metrics[:top_k]

    print(f"Using top {len(top_models)} models for ensemble:")
    for i, model_info in enumerate(top_models):
        print(f"  {i+1}. {model_info['file']} - F1: {model_info['f1']:.4f}, Acc: {model_info['acc']:.4f}")

    # Create models
    models = []
    for model_info in top_models:
        model = HighPerformanceBeitModel(num_classes=len(class_names))

        # Load weights on the CPU first; the model is moved to `device` below.
        checkpoint = torch.load(
            os.path.join(checkpoint_dir, model_info['file']), map_location='cpu'
        )
        model.load_state_dict(checkpoint['model_state_dict'])
        del checkpoint

        model.eval()
        models.append(model.to(device))

    # Too few user weights would leave some models without a weight.
    if weights is not None and len(weights) < len(top_models):
        print(
            f"Received {len(weights)} ensemble weights for {len(top_models)} "
            "models; falling back to F1-based weights"
        )
        weights = None

    if weights is None:
        weights = [model_info['f1'] for model_info in top_models]
    else:
        weights = list(weights)[:len(top_models)]

    # Ensure weights sum to 1; fall back to a uniform average when the total
    # is not positive (e.g. every stored F1 is 0) instead of dividing by zero.
    weights = np.asarray(weights, dtype=float)
    total_weight = weights.sum()
    if total_weight > 0:
        weights = weights / total_weight
    else:
        weights = np.full(len(top_models), 1.0 / len(top_models))

    print(f"Ensemble weights: {weights}")

    return models, weights

def predict_with_ensemble(models, weights, inputs, use_tta=False, tta_transforms=5):
    """Combine the class probabilities of several models into one prediction.

    Each model's probabilities for ``inputs`` are scaled by the weight at the
    same position in ``weights`` and the scaled arrays are summed.

    Args:
        models: Sequence of callables mapping a batch to per-class logits.
        weights: Per-model weights, indexed in step with ``models``.
        inputs: Batch tensor; indexed along dim 0 when TTA is used.
        use_tta: When true, each image is scored individually through
            ``test_time_augmentation`` instead of one batched forward pass.
        tta_transforms: Number of augmentations requested per image under TTA.

    Returns:
        Tuple ``(final_preds, final_probs)``: the argmax class per sample and
        the weighted probability sum as NumPy arrays.
    """
    def _member_probs(member):
        # Probabilities of a single ensemble member for the whole batch.
        if not use_tta:
            with torch.no_grad():
                logits = member(inputs)
                return F.softmax(logits, dim=1).cpu().numpy()

        per_image = []
        for idx in range(inputs.size(0)):
            sample = inputs[idx].to(device)
            sample_probs = test_time_augmentation(member, sample, n_augments=tta_transforms)
            per_image.append(sample_probs.cpu().numpy())
        return np.array(per_image)

    # Scale every member's output by its ensemble weight
    contributions = [
        _member_probs(member) * weights[position]
        for position, member in enumerate(models)
    ]

    # Weighted sum across members, then pick the most likely class per sample
    final_probs = sum(contributions)
    final_preds = np.argmax(final_probs, axis=1)

    return final_preds, final_probs

def plot_training_history(history):
    """Draw the six-panel training summary and save it to ``training_history.png``.

    Args:
        history: Mapping that provides per-epoch sequences under the keys
            ``train_loss``, ``val_loss``, ``train_acc``, ``val_acc``,
            ``val_precision``, ``val_recall``, ``val_f1``,
            ``val_specificity``, ``val_sensitivity`` and ``lr``.
    """
    # One entry per panel, in grid order:
    # (title, y-axis label, [(history key, legend label), ...])
    panels = [
        ('Loss vs. Epoch', 'Loss',
         [('train_loss', 'Train Loss'), ('val_loss', 'Validation Loss')]),
        ('Accuracy vs. Epoch', 'Accuracy',
         [('train_acc', 'Train Accuracy'), ('val_acc', 'Validation Accuracy')]),
        ('Precision & Recall vs. Epoch', 'Score',
         [('val_precision', 'Precision'), ('val_recall', 'Recall')]),
        ('F1 Score vs. Epoch', 'F1 Score',
         [('val_f1', 'F1 Score')]),
        ('Specificity & Sensitivity vs. Epoch', 'Score',
         [('val_specificity', 'Specificity'), ('val_sensitivity', 'Sensitivity')]),
        ('Learning Rate vs. Epoch', 'Learning Rate',
         [('lr', 'Learning Rate')]),
    ]
    lr_panel = len(panels)

    plt.figure(figsize=(20, 15))

    for slot, (heading, y_label, curves) in enumerate(panels, start=1):
        plt.subplot(3, 2, slot)
        for key, legend_label in curves:
            plt.plot(history[key], label=legend_label)
        plt.title(heading, fontsize=14)
        plt.xlabel('Epoch', fontsize=12)
        plt.ylabel(y_label, fontsize=12)
        if slot == lr_panel:
            # The learning-rate panel uses a log axis and carries no legend.
            plt.yscale('log')
        else:
            plt.legend(fontsize=12)
        plt.grid(alpha=0.3)

    plt.tight_layout()
    plt.savefig('training_history.png', dpi=300)
    
def main():
    """Main function to run the training and evaluation"""
    print("=== High Accuracy BEiT Pneumonia Detection ===")
    
    # Set seed for reproducibility
    set_seed(42)
    
    # Create dataloaders
    print("Creating dataloaders with advanced augmentations...")
    dataloaders, dataset_sizes, class_names = create_dataloaders(
        data_dir=CONFIG['data_dir'],
        image_size=CONFIG['image_size'],
        val_split=CONFIG.get('val_percent', 0.1)
    )
    
    print(f"Classes: {class_names}")
    print(f"Dataset sizes: {dataset_sizes}")
    
    # Create model
    print("Creating high-performance BEiT model...")
    model = HighPerformanceBeitModel(
        num_classes=len(class_names),
        pretrained=True,
        dropout_rate=CONFIG.get('dropout_rate', 0.3)
    )
    model = model.to(device)
    
    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    
    # Create loss function
    if CONFIG.get('use_weighted_loss', True):
        # Setup class weights
        weight_normal = CONFIG.get('class_weight_normal', 1.5) 
        weight_pneumonia = CONFIG.get('class_weight_pneumonia', 1.0)
        class_weights = torch.tensor([weight_normal, weight_pneumonia]).float().to(device)
        
        # Create combined loss
        criterion = CombinedLoss(
            alpha=0.5,
            gamma=CONFIG.get('focal_loss_gamma', 2.0),
            class_weights=class_weights,
            label_smoothing=CONFIG.get('label_smoothing', 0.1)
        )
        print(f"Using Combined Loss with class weights [{weight_normal}, {weight_pneumonia}]")
    else:
        # Standard cross entropy
        criterion = nn.CrossEntropyLoss(
            label_smoothing=CONFIG.get('label_smoothing', 0.1)
        )
        print("Using CrossEntropyLoss with label smoothing")
    
    # Create optimizer with discriminative learning rates
    if CONFIG.get('discriminative_lr', False):
        # Group parameters by layers
        backbone_params = list(model.backbone.parameters())
        classifier_params = list(model.classifier.parameters())
        
        # Add other parameters if they exist
        other_params = []
        if hasattr(model, 'norm'):
            other_params.extend(list(model.norm.parameters()))
        if hasattr(model, 'attention_pool'):
            other_params.extend(list(model.attention_pool.parameters()))
        if hasattr(model, 'aux_classifier'):
            other_params.extend(list(model.aux_classifier.parameters()))
        
        # Create parameter groups with different learning rates
        lr = CONFIG['learning_rate']
        factor = CONFIG.get('discriminative_lr_factor', 0.1)
        param_groups = [
            {'params': backbone_params, 'lr': lr * factor, 'initial_lr': lr * factor},
            {'params': other_params, 'lr': lr * (factor * 5), 'initial_lr': lr * (factor * 5)},
            {'params': classifier_params, 'lr': lr, 'initial_lr': lr}
        ]
        print(f"Using discriminative learning rates: {lr * factor}, {lr * (factor * 5)}, {lr}")
    else:
        # Single learning rate for all parameters
        param_groups = model.parameters()
        lr = CONFIG['learning_rate']
        print(f"Using single learning rate: {lr}")
    
    # Create optimizer
    optimizer = optim.AdamW(
        param_groups,
        lr=CONFIG['learning_rate'],
        weight_decay=CONFIG.get('weight_decay', 0.05),
        eps=1e-8
    )
    
    # Apply Lookahead if enabled
    if CONFIG.get('lookahead_optimizer', False):
        optimizer = Lookahead(
            optimizer,
            k=CONFIG.get('lookahead_k', 5),
            alpha=CONFIG.get('lookahead_alpha', 0.5)
        )
        print("Using Lookahead optimizer wrapper")
    
    # Create scheduler
    if CONFIG.get('scheduler', 'cosine_with_warmup') == 'cosine_with_warmup':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=CONFIG['num_epochs'],
            eta_min=CONFIG.get('min_lr', 1e-7)
        )
        print("Using CosineAnnealingLR scheduler")
    else:
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='max',
            factor=CONFIG.get('plateau_factor', 0.5),
            patience=CONFIG.get('plateau_patience', 5),
            verbose=True,
            min_lr=CONFIG.get('min_lr', 1e-7)
        )
        print("Using ReduceLROnPlateau scheduler")
    
    # Train model
    print(f"Starting training for {CONFIG['num_epochs']} epochs...")
    
    # Check if a checkpoint should be loaded
    start_from_checkpoint = False
    best_model_path = os.path.join('high_acc_checkpoints', 'best_model.pt')
    
    if os.path.exists(best_model_path):
        print(f"Found existing checkpoint at {best_model_path}")
        response = input("Continue from checkpoint? (y/n): ")
        if response.lower() == 'y':
            try:
                checkpoint = torch.load(best_model_path)
                model.load_state_dict(checkpoint['model_state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                if scheduler is not None and 'scheduler_state_dict' in checkpoint and checkpoint['scheduler_state_dict']:
                    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
                    
                start_epoch = checkpoint.get('epoch', 0) + 1
                best_acc = checkpoint.get('best_acc', 0)
                best_f1 = checkpoint.get('best_f1', 0)
                history = checkpoint.get('history', {})
                
                print(f"Loaded checkpoint from epoch {start_epoch-1}")
                print(f"Best accuracy so far: {best_acc:.4f}, Best F1: {best_f1:.4f}")
                
                start_from_checkpoint = True
            except Exception as e:
                print(f"Error loading checkpoint: {e}")
                print("Starting fresh training")
                start_from_checkpoint = False
    
    # Train the model
    model, history = train_model(
        model=model,
        dataloaders=dataloaders,
        dataset_sizes=dataset_sizes,
        criterion=criterion,
        optimizer=optimizer,
        scheduler=scheduler,
        num_epochs=CONFIG['num_epochs']
    )
    
    # Plot training history
    if CONFIG.get('training_visualization', True):
        plot_training_history(history)
    
    # Evaluate on test set
    print("\nEvaluating final model on test set...")
    metrics = evaluate_model(model, dataloaders['test'], class_names)
    
    # Create ensemble if enabled
    if CONFIG.get('use_ensemble', True):
        print("\nCreating model ensemble...")
        ensemble_models, ensemble_weights = create_model_ensemble(
            checkpoint_dir='high_acc_checkpoints',
            class_names=class_names,
            top_k=3,
            weights=CONFIG.get('ensemble_weights', None)
        )
        
        if ensemble_models:
            print("Evaluating ensemble model...")
            
            # Process test set with ensemble
            ensemble_all_labels = []
            ensemble_all_preds = []
            
            for inputs, labels in tqdm(dataloaders['test'], desc="Ensemble evaluation"):
                ensemble_all_labels.extend(labels.numpy())
                
                inputs = inputs.to(device)
                preds, _ = predict_with_ensemble(
                    ensemble_models, 
                    ensemble_weights, 
                    inputs,
                    use_tta=CONFIG.get('use_test_time_augmentation', True),
                    tta_transforms=CONFIG.get('tta_transforms', 5)
                )
                
                ensemble_all_preds.extend(preds)
                
                # Clean up memory
                del inputs
                torch.cuda.empty_cache()
            
            # Calculate ensemble metrics
            ensemble_accuracy = accuracy_score(ensemble_all_labels, ensemble_all_preds)
            ensemble_precision = precision_score(ensemble_all_labels, ensemble_all_preds, average='binary', zero_division=0)
            ensemble_recall = recall_score(ensemble_all_labels, ensemble_all_preds, average='binary', zero_division=0)
            ensemble_f1 = f1_score(ensemble_all_labels, ensemble_all_preds, average='binary', zero_division=0)
            
            print(f"\nEnsemble Results:")
            print(f"Accuracy: {ensemble_accuracy:.4f}")
            print(f"Precision: {ensemble_precision:.4f}")
            print(f"Recall: {ensemble_recall:.4f}")
            print(f"F1 Score: {ensemble_f1:.4f}")
            
            # Compare with single model
            print("\nComparison:")
            print(f"Single Model Accuracy: {metrics['accuracy']:.4f}")
            print(f"Ensemble Model Accuracy: {ensemble_accuracy:.4f}")
            print(f"Improvement: {(ensemble_accuracy - metrics['accuracy']) * 100:.2f}%")
    
    # Create HTML report
    report_html = f"""
    <!DOCTYPE html>
    <html lang="en">
    <head>
        <meta charset="UTF-8">
        <meta name="viewport" content="width=device-width, initial-scale=1.0">
        <title>Pneumonia Detection Model Report</title>
        <style>
            body {{ font-family: Arial, sans-serif; line-height: 1.6; margin: 0; padding: 20px; color: #333; }}
            .container {{ max-width: 1200px; margin: 0 auto; }}
            h1 {{ color: #2c3e50; border-bottom: 2px solid #3498db; padding-bottom: 10px; }}
            h2 {{ color: #3498db; margin-top: 30px; }}
            .metric-container {{ display: flex; flex-wrap: wrap; gap: 20px; margin: 20px 0; }}
            .metric-box {{ background-color: #f8f9fa; border-radius: 8px; padding: 20px; box-shadow: 0 2px 10px rgba(0,0,0,0.1); width: 200px; }}
            .metric-value {{ font-size: 28px; font-weight: bold; color: #2980b9; margin: 10px 0; }}
            .metric-name {{ font-weight: bold; font-size: 16px; }}
            .alert {{ background-color: #e74c3c; color: white; padding: 10px; border-radius: 5px; margin: 20px 0; }}
            .success {{ background-color: #2ecc71; color: white; padding: 10px; border-radius: 5px; margin: 20px 0; }}
            table {{ border-collapse: collapse; width: 100%; margin: 20px 0; }}
            th, td {{ border: 1px solid #ddd; padding: 12px; text-align: left; }}
            th {{ background-color: #f2f2f2; }}
            tr:nth-child(even) {{ background-color: #f9f9f9; }}
            tr:hover {{ background-color: #f5f5f5; }}
            img {{ max-width: 100%; height: auto; margin: 20px 0; border: 1px solid #ddd; }}
            .image-container {{ margin: 30px 0; }}
            .footer {{ margin-top: 50px; border-top: 1px solid #ddd; padding-top: 20px; font-size: 12px; color: #777; }}
        </style>
    </head>
    <body>
        <div class="container">
            <h1>Pneumonia Detection Model Report</h1>
            
            <h2>Model Performance</h2>
            <div class="metric-container">
                <div class="metric-box">
                    <div class="metric-name">Accuracy</div>
                    <div class="metric-value">{metrics['accuracy']:.4f}</div>
                </div>
                <div class="metric-box">
                    <div class="metric-name">Precision</div>
                    <div class="metric-value">{metrics['precision']:.4f}</div>
                </div>
                <div class="metric-box">
                    <div class="metric-name">Recall (Sensitivity)</div>
                    <div class="metric-value">{metrics['recall']:.4f}</div>
                </div>
                <div class="metric-box">
                    <div class="metric-name">Specificity</div>
                    <div class="metric-value">{metrics['specificity']:.4f}</div>
                </div>
                <div class="metric-box">
                    <div class="metric-name">F1 Score</div>
                    <div class="metric-value">{metrics['f1']:.4f}</div>
                </div>
                <div class="metric-box">
                    <div class="metric-name">AUC</div>
                    <div class="metric-value">{metrics['auc']:.4f}</div>
                </div>
            </div>
            
            {'<div class="success">✓ Model achieved target accuracy of 99%!</div>' if metrics['accuracy'] >= 0.99 else '<div class="alert">⚠️ Model did not reach target accuracy of 99%.</div>'}
            
            <h2>Confusion Matrix</h2>
            <table>
                <tr>
                    <th></th>
                    <th>Predicted Normal</th>
                    <th>Predicted Pneumonia</th>
                </tr>
                <tr>
                    <th>Actual Normal</th>
                    <td>{metrics['true_negatives']}</td>
                    <td>{metrics['false_positives']}</td>
                </tr>
                <tr>
                    <th>Actual Pneumonia</th>
                    <td>{metrics['false_negatives']}</td>
                    <td>{metrics['true_positives']}</td>
                </tr>
            </table>
            
            <h2>Training History</h2>
            <div class="image-container">
                <img src="training_history.png" alt="Training History" />
            </div>
            
            <h2>Confusion Matrix Visualization</h2>
            <div class="image-container">
                <img src="final_confusion_matrix.png" alt="Confusion Matrix" />
            </div>
            
            <h2>Model Configuration</h2>
            <table>
                <tr>
                    <th>Parameter</th>
                    <th>Value</th>
                </tr>
                <tr>
                    <td>Model</td>
                    <td>{CONFIG.get('timm_model', CONFIG.get('beit_model', 'Unknown'))}</td>
                </tr>
                <tr>
                    <td>Image Size</td>
                    <td>{CONFIG['image_size']}</td>
                </tr>
                <tr>
                    <td>Learning Rate</td>
                    <td>{CONFIG['learning_rate']}</td>
                </tr>
                <tr>
                    <td>Batch Size</td>
                    <td>{CONFIG['batch_size']} (Effective: {CONFIG['batch_size'] * CONFIG['gradient_accumulation_steps']})</td>
                </tr>
                <tr>
                    <td>Augmentation Strength</td>
                    <td>{CONFIG.get('augmentation_strength', 'extreme')}</td>
                </tr>
                <tr>
                    <td>Class Weights</td>
                    <td>Normal: {CONFIG.get('class_weight_normal', 1.5)}, Pneumonia: {CONFIG.get('class_weight_pneumonia', 1.0)}</td>
                </tr>
            </table>
            
            <div class="footer">
                <p>Report generated on {time.strftime('%Y-%m-%d %H:%M:%S')}</p>
                <p>High-Accuracy BEiT Pneumonia Detection Model</p>
            </div>
        </div>
    </body>
    </html>
    """
    
    # Save HTML report
    with open('pneumonia_detection_report.html', 'w') as f:
        f.write(report_html)
    
    print("Generated HTML report: pneumonia_detection_report.html")
    
    # Save the model
    final_model_path = 'high_accuracy_pneumonia_model.pth'
    torch.save({
        'model_state_dict': model.state_dict(),
        'config': CONFIG,
        'class_names': class_names,
        'metrics': metrics,
        'history': history
    }, final_model_path)
    
    print(f"Saved final model to {final_model_path}")
    
    # Create usage example
    usage_code = """
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np

def load_high_accuracy_model(model_path):
    # Load checkpoint
    checkpoint = torch.load(model_path, map_location='cpu')
    
    # Extract configuration
    config = checkpoint.get('config', {})
    class_names = checkpoint.get('class_names', ['NORMAL', 'PNEUMONIA'])
    
    # Create model
    model = create_model(len(class_names), config)
    
    # Load weights
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    
    return model, checkpoint

def create_model(num_classes, config):
    # Check if timm is being used
    if config.get('use_timm', False):
        import timm
        
        # Create timm-based model
        model = timm.create_model(
            config.get('timm_model', 'beit_base_patch16_224'),
            pretrained=False,
            num_classes=0
        )
        
        # Add custom head
        feature_dim = model.num_features
        dropout_rate = config.get('dropout_rate', 0.3)
        
        # Create classifier layers
        class CustomModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.backbone = model
                self.norm = nn.LayerNorm(feature_dim)
                self.classifier = nn.Sequential(
                    nn.Dropout(dropout_rate),
                    nn.Linear(feature_dim, 1024),
                    nn.GELU(),
                    nn.LayerNorm(1024),
                    nn.Dropout(dropout_rate),
                    nn.Linear(1024, 512),
                    nn.GELU(),
                    nn.LayerNorm(512),
                    nn.Dropout(dropout_rate),
                    nn.Linear(512, num_classes)
                )
            
            def forward(self, x):
                x = self.backbone(x)
                x = self.norm(x)
                x = self.classifier(x)
                return x
        
        return CustomModel()
    else:
        # Use transformers BEiT model
        from transformers import BeitModel
        
        # Define model architecture
        class BeitClassifier(nn.Module):
            def __init__(self, num_classes, dropout_rate=0.3):
                super().__init__()
                
                # Load BEiT model
                self.beit = BeitModel.from_pretrained(config.get('beit_model', 'microsoft/beit-base-patch16-224'))
                
                # Get feature dimension
                self.feature_dim = self.beit.config.hidden_size
                
                # Add layer norm
                self.norm = nn.LayerNorm(self.feature_dim)
                
                # Add classifier
                self.classifier = nn.Sequential(
                    nn.Dropout(dropout_rate),
                    nn.Linear(self.feature_dim, 1024),
                    nn.GELU(),
                    nn.LayerNorm(1024),
                    nn.Dropout(dropout_rate),
                    nn.Linear(1024, 512),
                    nn.GELU(),
                    nn.LayerNorm(512), 
                    nn.Dropout(dropout_rate),
                    nn.Linear(512, num_classes)
                )
            
            def forward(self, x):
                # Extract features
                outputs = self.beit(x)
                cls_token = outputs.last_hidden_state[:, 0]
                
                # Apply norm
                normalized = self.norm(cls_token)
                
                # Apply classifier
                logits = self.classifier(normalized)
                
                return logits
        
        return BeitClassifier(num_classes, dropout_rate=config.get('dropout_rate', 0.3))

def predict_pneumonia(model, image_path, checkpoint=None):

    # Get configuration
    config = checkpoint.get('config', {}) if checkpoint else {}
    class_names = checkpoint.get('class_names', ['NORMAL', 'PNEUMONIA']) if checkpoint else ['NORMAL', 'PNEUMONIA']
    
    # Get normalization parameters
    normalize_means = config.get('normalize_means', [0.5056, 0.5056, 0.5056])
    normalize_stds = config.get('normalize_stds', [0.252, 0.252, 0.252])
    
    # Create transform
    transform = transforms.Compose([
        transforms.Resize((config.get('image_size', 384), config.get('image_size', 384))),
        transforms.ToTensor(),
        transforms.Normalize(mean=normalize_means, std=normalize_stds)
    ])
    
    # Load and preprocess image
    image = Image.open(image_path).convert('RGB')
    input_tensor = transform(image).unsqueeze(0)
    
    # Get prediction
    with torch.no_grad():
        outputs = model(input_tensor)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)[0]
        predicted_class = torch.argmax(probabilities, dim=0).item()
    
    # Get class name and probability
    class_name = class_names[predicted_class]
    probability = probabilities[predicted_class].item() * 100
    
    # Get pneumonia probability
    pneumonia_idx = class_names.index('PNEUMONIA') if 'PNEUMONIA' in class_names else 1
    pneumonia_prob = probabilities[pneumonia_idx].item() * 100
    
    # Display results
    print(f"Prediction: {class_name}")
    print(f"Confidence: {probability:.1f}%")
    print(f"Pneumonia Probability: {pneumonia_prob:.1f}%")
    
    # Display image with prediction
    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    plt.title(f"Prediction: {class_name} ({probability:.1f}%)", size=20)
    plt.axis('off')
    
    # Add probability bars
    plt.figure(figsize=(10, 3))
    colors = ['green', 'red'] if pneumonia_idx == 1 else ['red', 'green']
    bars = plt.bar(class_names, [probabilities[i].item() * 100 for i in range(len(class_names))], color=colors)
    
    # Add labels on bars
    for bar in bars:
        height = bar.get_height()
        plt.text(
            bar.get_x() + bar.get_width()/2.,
            height + 1,
            f'{height:.1f}%',
            ha='center',
            va='bottom',
            fontsize=14
        )
    
    plt.title('Prediction Probabilities', size=16)
    plt.ylabel('Probability (%)', size=14)
    plt.ylim(0, 100)
    plt.xticks(fontsize=14)
    plt.yticks(fontsize=12)
    plt.tight_layout()
    plt.show()
    
    return {
        'class': class_name,
        'probability': probability,
        'pneumonia_probability': pneumonia_prob
    }

# Example usage
if __name__ == "__main__":
    # Load the model
    model, checkpoint = load_high_accuracy_model('high_accuracy_pneumonia_model.pth')
    
    # Make prediction
    predict_pneumonia(model, 'path/to/your/xray.jpg', checkpoint)
"""
    
    # Save usage example
    with open('use_high_accuracy_model.py', 'w') as f:
        f.write(usage_code)
    
    print("Saved usage example to 'use_high_accuracy_model.py'")
    
    # Print final summary
    print("\n" + "="*50)
    print("HIGH ACCURACY PNEUMONIA DETECTION MODEL TRAINING COMPLETE")
    print("="*50)
    print(f"Final Accuracy: {metrics['accuracy']:.4f}")
    print(f"Precision: {metrics['precision']:.4f}, Recall: {metrics['recall']:.4f}")
    print(f"F1 Score: {metrics['f1']:.4f}, AUC: {metrics['auc']:.4f}")
    
    if metrics['accuracy'] >= 0.99:
        print("\n✓ Successfully achieved target accuracy of 99%!")
    else:
        print(f"\n Did not reach target accuracy of 99% (got {metrics['accuracy']:.2%}).")
        
    print("\nModel saved, report generated, and usage example created.")
    print("="*50)

# Script entry point: call main() only when this file is executed directly,
# so that importing the module does not invoke main().
if __name__ == "__main__":
    main()