<a href="https://colab.research.google.com/github/RinorRexhaj/DocuForge/blob/main/Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms
import shutil
from tqdm.notebook import tqdm
import numpy as np
import json
import matplotlib.pyplot as plt
from sklearn.metrics import (
    confusion_matrix, classification_report, roc_curve, auc
)
import seaborn as sns
import os

In [None]:
import timm
from collections import defaultdict
import random
from PIL import Image, ImageFilter
import cv2
import sklearn.metrics

# Focal Loss implementation
class FocalLoss(nn.Module):
    """Binary focal loss: per-element BCE scaled by ``alpha * (1 - p_t) ** gamma``.

    ``p_t = exp(-BCE)`` is the probability the model assigns to the true label, so
    well-classified elements (p_t close to 1) are down-weighted.

    Args:
        alpha: scalar multiplier applied identically to every element
            (it is not a per-class weight).
        gamma: focusing exponent; 0 reduces this to ``alpha * BCE``.
        logits: if True, ``inputs`` are raw logits; otherwise probabilities in [0, 1].
        reduce: if True return the mean over all elements, else the per-element tensor.
    """

    def __init__(self, alpha=1, gamma=2, logits=True, reduce=True):
        super().__init__()
        self.alpha, self.gamma = alpha, gamma
        self.logits, self.reduce = logits, reduce

    def forward(self, inputs, targets):
        bce_fn = (nn.functional.binary_cross_entropy_with_logits if self.logits
                  else nn.functional.binary_cross_entropy)
        bce = bce_fn(inputs, targets, reduction='none')
        p_t = torch.exp(-bce)
        focal = self.alpha * (1 - p_t) ** self.gamma * bce
        return focal.mean() if self.reduce else focal

# Advanced Noise Augmentation
class NoiseAugmentation:
    """Randomly degrade a PIL image with Gaussian noise, salt-and-pepper noise, or blur.

    On each call, with probability ``noise_prob`` one of the three degradations is
    chosen uniformly at random (with a randomly drawn strength) and applied;
    otherwise the image is returned unchanged. The choices use Python's ``random``
    module and the noise itself uses NumPy's global RNG.
    """

    def __init__(self, noise_prob=0.3):
        # Probability of applying any degradation on a given call.
        self.noise_prob = noise_prob

    @staticmethod
    def _to_pil(array):
        """Convert a uint8 array of shape (H, W) or (H, W, C) back into a PIL image.

        Pillow infers the mode from the array's shape and dtype ('L' for a 2-D uint8
        array, 'RGB' for (H, W, 3) uint8). No explicit ``mode`` is passed because that
        argument of ``Image.fromarray`` is deprecated in recent Pillow releases.
        """
        return Image.fromarray(array)

    def add_gaussian_noise(self, image, mean=0, std=0.05):
        """Add Gaussian noise to image.

        ``mean`` and ``std`` are fractions of the 0-255 pixel range (the sampled noise
        is multiplied by 255 before being added). The result is clipped to [0, 255]
        and returned as a uint8 PIL image.
        """
        if isinstance(image, Image.Image):
            image = np.array(image)

        noise = np.random.normal(mean, std, image.shape).astype(np.float32)
        noisy_image = image.astype(np.float32) + noise * 255
        noisy_image = np.clip(noisy_image, 0, 255).astype(np.uint8)
        return self._to_pil(noisy_image)

    def add_salt_pepper_noise(self, image, salt_prob=0.01, pepper_prob=0.01):
        """Add salt and pepper noise.

        Each pixel is independently set to 255 with probability ``salt_prob``, then to 0
        with probability ``pepper_prob`` (pepper wins where both masks hit). Masks cover
        height x width, so every channel of a selected pixel is set.
        """
        if isinstance(image, Image.Image):
            image = np.array(image)

        noisy_image = image.copy()

        # Salt noise
        salt_mask = np.random.random(image.shape[:2]) < salt_prob
        noisy_image[salt_mask] = 255

        # Pepper noise
        pepper_mask = np.random.random(image.shape[:2]) < pepper_prob
        noisy_image[pepper_mask] = 0

        return self._to_pil(noisy_image)

    def add_blur(self, image, blur_radius=1.5):
        """Add Gaussian blur to a PIL image; any non-PIL input is returned unchanged."""
        if isinstance(image, Image.Image):
            return image.filter(ImageFilter.GaussianBlur(radius=blur_radius))
        return image

    def __call__(self, image):
        """Apply one randomly chosen degradation with probability ``noise_prob``."""
        if random.random() < self.noise_prob:
            noise_type = random.choice(['gaussian', 'salt_pepper', 'blur'])
            if noise_type == 'gaussian':
                return self.add_gaussian_noise(image, std=random.uniform(0.02, 0.08))
            elif noise_type == 'salt_pepper':
                return self.add_salt_pepper_noise(image,
                                                  salt_prob=random.uniform(0.005, 0.02),
                                                  pepper_prob=random.uniform(0.005, 0.02))
            elif noise_type == 'blur':
                return self.add_blur(image, blur_radius=random.uniform(0.5, 2.0))
        return image

In [2]:
# Mount Google Drive into the Colab VM (google.colab only exists inside Colab).
from google.colab import drive
drive.mount('/content/drive')

# Dataset location on Drive, and the local VM path that copy_dataset mirrors it to.
# NOTE(review): presumably copied locally because reading thousands of small image
# files through the Drive mount during training is slow — confirm.
drive_dataset_path = '/content/drive/MyDrive/DocuForge/dataset'
local_dataset_path = '/content/dataset'

# Function to copy dataset with progress
def copy_dataset(src, dst):
    """Mirror the directory tree under ``src`` into ``dst``, with one progress bar per folder.

    Files already present at the destination path are skipped, so re-running after a
    completed copy does no work. Only existence is checked, so a partially written file
    left by an interrupted copy would also be skipped.

    Args:
        src: root directory to copy from.
        dst: root directory to copy into (created if missing).
    """
    os.makedirs(dst, exist_ok=True)

    for current_dir, _subdirs, filenames in os.walk(src):
        # Recreate the same relative structure under dst.
        relative = os.path.relpath(current_dir, src)
        target_dir = os.path.join(dst, relative)
        os.makedirs(target_dir, exist_ok=True)

        for filename in tqdm(filenames, desc=f"Copying {relative}", unit="file"):
            target_file = os.path.join(target_dir, filename)
            if os.path.exists(target_file):
                continue
            # copy2 also preserves file metadata such as modification times.
            shutil.copy2(os.path.join(current_dir, filename), target_file)

# Run it: mirror the Drive dataset onto local disk. Files that already exist locally are
# skipped, so re-running this cell is cheap. The success message prints unconditionally.
copy_dataset(drive_dataset_path, local_dataset_path)

print("✅ Dataset copied successfully!")

Mounted at /content/drive


Copying .: 0file [00:00, ?file/s]

Copying test: 0file [00:00, ?file/s]

Copying test/authentic:   0%|          | 0/300 [00:00<?, ?file/s]

Copying test/forged:   0%|          | 0/300 [00:00<?, ?file/s]

Copying train: 0file [00:00, ?file/s]

Copying train/forged:   0%|          | 0/1400 [00:00<?, ?file/s]

Copying train/authentic:   0%|          | 0/1400 [00:00<?, ?file/s]

Copying val: 0file [00:00, ?file/s]

Copying val/authentic:   0%|          | 0/300 [00:00<?, ?file/s]

Copying val/forged:   0%|          | 0/300 [00:00<?, ?file/s]

✅ Dataset copied successfully!


In [None]:
data_path = '/content/dataset/'

IMG_SIZE = 224  # ResNet50 default input size
# Square size used before FiveCrop test-time augmentation. It must be larger than
# IMG_SIZE: resizing straight to IMG_SIZE makes all five crops identical to the full
# image, so the "TTA" would average five copies of the same prediction.
# 256/224 is the usual ImageNet resize-then-crop ratio.
TTA_RESIZE = round(IMG_SIZE * 256 / 224)

# Initialize noise augmentation with higher probability.
# Combined with the RandomApply(p=0.5) wrapper in train_transforms, a training image
# receives one of NoiseAugmentation's degradations with probability 0.5 * 0.5 = 0.25.
noise_aug = NoiseAugmentation(noise_prob=0.5)

# Even more aggressive training transforms for final accuracy push.
# NOTE(review): blur, noise and strong colour jitter may also wash out the subtle
# pixel-level traces a forgery detector relies on — worth an ablation.
train_transforms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomRotation(20),  # More rotation
    transforms.RandomResizedCrop(IMG_SIZE, scale=(0.7, 1.0)),  # More aggressive cropping
    transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.4, hue=0.15),  # Stronger color jitter
    transforms.RandomHorizontalFlip(p=0.6),  # Higher flip probability
    transforms.RandomVerticalFlip(p=0.4),
    transforms.RandomApply([transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.5))], p=0.4),
    transforms.RandomApply([transforms.Lambda(lambda x: noise_aug(x))], p=0.5),  # Higher noise probability
    transforms.RandomPerspective(distortion_scale=0.3, p=0.4),  # More perspective distortion
    transforms.RandomAffine(degrees=0, translate=(0.15, 0.15), scale=(0.85, 1.15), shear=8),  # More affine transforms
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225]),
    transforms.RandomErasing(p=0.3, scale=(0.02, 0.2), ratio=(0.3, 3.3))  # More random erasing
])

# Deterministic preprocessing for validation/test: resize + ImageNet normalisation.
val_transforms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])

# Test-time augmentation transforms: FiveCrop yields a tuple of 5 PIL crops (four
# corners + centre) from the larger resized image. The Lambda converts them into one
# normalised tensor of shape (5, C, IMG_SIZE, IMG_SIZE) per image.
_tta_to_tensor = transforms.ToTensor()
_tta_normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
test_transforms_tta = transforms.Compose([
    transforms.Resize((TTA_RESIZE, TTA_RESIZE)),
    transforms.FiveCrop(IMG_SIZE),  # Create 5 distinct crops
    transforms.Lambda(lambda crops: torch.stack([_tta_normalize(_tta_to_tensor(crop)) for crop in crops])),
])

# Split directories inside the local copy of the dataset (data_path), not on Google Drive
train_dir = data_path + 'train'
val_dir = data_path + 'val'
test_dir = data_path + 'test'

# Datasets (one sub-folder per class; class indices follow train_dataset.classes,
# printed below). The test set uses the plain val_transforms, not test_transforms_tta.
train_dataset = datasets.ImageFolder(train_dir, transform=train_transforms)
val_dataset = datasets.ImageFolder(val_dir, transform=val_transforms)
test_dataset = datasets.ImageFolder(test_dir, transform=val_transforms)

# DataLoaders with reduced workers to avoid warnings.
# drop_last=True on the training loader: presumably to avoid a final batch of size 1,
# which the BatchNorm1d layers in the classifier heads reject in training mode.
train_loader = DataLoader(train_dataset, batch_size=28, shuffle=True, pin_memory=True, num_workers=2, drop_last=True)  # Reduced batch size for stability
val_loader = DataLoader(val_dataset, batch_size=56, shuffle=False, pin_memory=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=56, shuffle=False, pin_memory=True, num_workers=2)

print(f"Classes: {train_dataset.classes}")
print(f"Train: {len(train_dataset)} | Val: {len(val_dataset)} | Test: {len(test_dataset)}")

# Calculate class weights for handling imbalance:
# inverse frequency, total / (num_classes * count), which is 1.0 for every class when balanced.
class_counts = defaultdict(int)
for _, label in train_dataset.samples:
    class_counts[label] += 1

total_samples = sum(class_counts.values())
class_weights = {cls: total_samples / (len(class_counts) * count) for cls, count in class_counts.items()}
print(f"Class weights: {class_weights}")

# Convert to tensor for loss function (assumes exactly two classes, indices 0 and 1).
# NOTE(review): weight_tensor is never passed to the loss — FocalLoss has no weight
# argument — so these weights currently have no effect on training.
weight_tensor = torch.tensor([class_weights[0], class_weights[1]], dtype=torch.float32)

print("🔥 Enhanced data augmentation configured for final accuracy push!")

Classes: ['authentic', 'forged']
Train: 2800 | Val: 600 | Test: 600


In [9]:
# Hyperparameter search space (3 values per axis, 3**6 = 729 combinations).
# NOTE(review): not referenced by any training code visible in this notebook; the
# training cells use hard-coded values instead (batch size 28/56, AdamW, fixed dropouts).
param_grid = {
    'learning_rate': [0.001, 0.01, 0.0001],
    'batch_size': [16, 32, 64],
    'optimizer': ['adam', 'sgd', 'adamw'],
    'weight_decay': [1e-4, 1e-3, 0],
    'dropout_rate': [0.3, 0.5, 0.7],
    'hidden_units': [128, 256, 512]
}

In [None]:
# Use the GPU when available. The model and every training/validation batch are moved here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

class MultiScaleFeatureExtractor(nn.Module):
    """Classification head on top of a backbone, with optional channel attention + dual pooling.

    Args:
        base_model: backbone module. Its forward output may be 4-D feature maps
            (N, C, H, W) or already-pooled 2-D features (N, C).
        num_classes: number of output logits.
        feature_dim: channel count C of the backbone output. If given, ``base_model`` is
            used exactly as passed, so the caller must already have removed any
            classification head. If None, C is read from ``in_features`` of the first of
            ``classifier`` / ``fc`` / ``head`` that has it, and that layer is replaced by
            ``nn.Identity``. If none qualifies, C is guessed from the class name and the
            backbone is left unchanged.

    Forward:
        * 4-D features: squeeze-and-excitation style channel attention, then global
          average- and max-pooling concatenated -> (N, 2C).
        * 2-D features: duplicated to (N, 2C) so the same head applies; the attention
          and pooling layers are not used in this case.
        The (N, 2C) vector then goes through a 2-hidden-layer MLP -> (N, num_classes).

    Raises:
        ValueError: if the backbone output is neither 2-D nor 4-D.
    """

    def __init__(self, base_model, num_classes=1, feature_dim=None):
        super(MultiScaleFeatureExtractor, self).__init__()
        self.base_model = base_model

        # Determine feature dimensions based on model type
        if feature_dim is not None:
            in_features = feature_dim
        else:
            # Get feature dimensions before replacing the final layer
            if hasattr(self.base_model, 'classifier') and hasattr(self.base_model.classifier, 'in_features'):
                in_features = self.base_model.classifier.in_features
                self.base_model.classifier = nn.Identity()
            elif hasattr(self.base_model, 'fc') and hasattr(self.base_model.fc, 'in_features'):
                in_features = self.base_model.fc.in_features
                self.base_model.fc = nn.Identity()
            elif hasattr(self.base_model, 'head') and hasattr(self.base_model.head, 'in_features'):
                in_features = self.base_model.head.in_features
                self.base_model.head = nn.Identity()
            else:
                # Default feature dimensions for common architectures (head NOT removed)
                model_name = type(self.base_model).__name__.lower()
                if 'resnet' in model_name:
                    in_features = 2048  # ResNet50
                elif 'efficientnet' in model_name:
                    in_features = 1792  # EfficientNet-B4
                elif 'convnext' in model_name:
                    in_features = 1024  # ConvNeXt-Base
                else:
                    in_features = 2048  # Default fallback

        # Store feature dimension for forward pass
        self.feature_dim = in_features

        # Multi-scale processing (parameter-free pooling layers)
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.global_max_pool = nn.AdaptiveMaxPool2d(1)

        # Channel attention: squeeze to C/16, excite back to C, sigmoid gate in [0, 1]
        self.attention_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_features, max(in_features // 16, 1), 1),
            nn.ReLU(),
            nn.Conv2d(max(in_features // 16, 1), in_features, 1),
            nn.Sigmoid()
        )

        # Classification head with dropout and batch norm
        self.classifier = nn.Sequential(
            nn.Linear(in_features * 2, 512),  # *2 for avg and max pooling
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes)
        )

    def forward(self, x):
        features = self.base_model(x)

        if features.dim() == 4:
            # Re-weight channels, then summarise spatially with avg and max pooling
            attention_weights = self.attention_pool(features)
            features = features * attention_weights

            avg_pool = self.global_pool(features).flatten(1)
            max_pool = self.global_max_pool(features).flatten(1)
            features = torch.cat([avg_pool, max_pool], dim=1)
        elif features.dim() == 2:
            # Already pooled by the backbone: duplicate to match the head's 2*C input
            features = torch.cat([features, features], dim=1)
        else:
            # e.g. un-pooled transformer tokens (N, L, C); fail clearly instead of
            # letting the Linear layer raise an opaque shape-mismatch error
            raise ValueError(
                "MultiScaleFeatureExtractor expects 2-D (N, C) or 4-D (N, C, H, W) "
                f"backbone features, got shape {tuple(features.shape)}"
            )

        return self.classifier(features)

class EnsembleModel(nn.Module):
    """Ensemble of four ImageNet-pretrained backbones fused into one prediction.

    Sub-models (each producing ``num_classes`` logits):
        * EfficientNet-B4 (timm) + MultiScaleFeatureExtractor head
        * ConvNeXt-Base (timm) + MultiScaleFeatureExtractor head
        * ResNet50 (torchvision, IMAGENET1K_V2 weights) + MultiScaleFeatureExtractor head
        * ViT-B/16 ``vit_base_patch16_224`` (timm) + MLP head

    The four logit vectors are concatenated and passed through a small MLP
    (``fusion``) that produces the final ``num_classes`` logits.

    Note on registration: each CNN backbone is stored both as ``self.<name>`` and as
    ``self.<name>_classifier.base_model`` (the same module object). ``parameters()``
    yields each tensor once, but ``state_dict()`` lists both key prefixes, so
    checkpoints store these weights twice. The layout is kept so existing checkpoints
    still load.
    """

    def __init__(self, num_classes=1):
        super(EnsembleModel, self).__init__()

        # EfficientNet-B4: num_classes=0 drops timm's classifier; width passed explicitly
        self.efficientnet = timm.create_model('efficientnet_b4', pretrained=True, num_classes=0)
        self.efficientnet_classifier = MultiScaleFeatureExtractor(self.efficientnet, num_classes, feature_dim=1792)

        # ConvNeXt-Base
        self.convnext = timm.create_model('convnext_base', pretrained=True, num_classes=0)
        self.convnext_classifier = MultiScaleFeatureExtractor(self.convnext, num_classes, feature_dim=1024)

        # ResNet50 (keep original); read the feature width before replacing fc
        self.resnet = models.resnet50(weights='IMAGENET1K_V2')
        resnet_features = self.resnet.fc.in_features
        self.resnet.fc = nn.Identity()
        self.resnet_classifier = MultiScaleFeatureExtractor(self.resnet, num_classes, feature_dim=resnet_features)

        # Vision Transformer - simpler approach: MLP on the 768-d pooled embedding
        self.vit = timm.create_model('vit_base_patch16_224', pretrained=True, num_classes=0)
        self.vit_classifier = nn.Sequential(
            nn.Linear(768, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_classes)
        )

        # Ensemble fusion: 4 sub-models x num_classes logits -> num_classes logits
        # (a 4 -> 8 -> 1 MLP for the binary num_classes=1 case).
        self.fusion = nn.Sequential(
            nn.Linear(4 * num_classes, 8),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(8, num_classes)
        )

        # Per-model weights for a weighted-average fusion. Not used by forward(), so it
        # never receives gradients; kept so checkpoints containing `model_weights`
        # still load with strict=True.
        self.model_weights = nn.Parameter(torch.ones(4) / 4)

    def forward(self, x):
        """Return fused logits of shape (N, num_classes) for an image batch ``x``."""
        efficientnet_out = self.efficientnet_classifier(x)
        convnext_out = self.convnext_classifier(x)
        resnet_out = self.resnet_classifier(x)
        vit_out = self.vit_classifier(self.vit(x))

        # (N, 4 * num_classes): per-model logits side by side
        ensemble_input = torch.cat([efficientnet_out, convnext_out, resnet_out, vit_out], dim=1)

        # Neural-network fusion is the model output. (A weighted average using
        # model_weights used to be computed here and then discarded.)
        return self.fusion(ensemble_input)

# Create the ensemble model: one logit per image for the binary authentic/forged task,
# moved to the selected device.
print("🔧 Creating ensemble model...")
model = EnsembleModel(num_classes=1)
model = model.to(device)

# Freeze early layers for transfer learning
def freeze_early_layers(model, freeze_ratio=0.7):
    """Freeze the earliest parameters of each top-level sub-model; keep heads trainable.

    Parameters are grouped by their top-level attribute name (the part of the dotted
    name before the first '.', e.g. ``efficientnet``, ``convnext``, ``resnet``, ``vit``).
    Within each group, the first ``int(freeze_ratio * n)`` of its ``n`` non-head
    parameter tensors (in registration order) are frozen and the rest made trainable.
    Any parameter whose name contains ``'classifier'``, ``'fusion'`` or
    ``'model_weights'`` is always trainable.

    The ratio is applied per group on purpose. A single cut-off over the whole ensemble
    freezes the first-registered backbones entirely and leaves the last one entirely
    trainable, instead of freezing the early layers of every model.

    Args:
        model: an ``nn.Module``; designed for ``EnsembleModel``.
        freeze_ratio: fraction in [0, 1] of each group's non-head tensors to freeze.

    The printed counts are parameter *tensors*, not scalar elements.
    """
    head_markers = ('classifier', 'fusion', 'model_weights')
    named_params = list(model.named_parameters())

    # Heads are always trainable; everything else is bucketed by top-level module.
    groups = defaultdict(list)
    for name, param in named_params:
        if any(marker in name for marker in head_markers):
            param.requires_grad = True
        else:
            groups[name.split('.', 1)[0]].append(param)

    frozen_params = 0
    for params in groups.values():
        n_freeze = int(freeze_ratio * len(params))
        for idx, param in enumerate(params):
            is_frozen = idx < n_freeze
            param.requires_grad = not is_frozen
            if is_frozen:
                frozen_params += 1

    print(f"Frozen {frozen_params}/{len(named_params)} parameter tensors")

# Apply gradual unfreezing strategy: start with most backbone layers frozen
# (freeze_ratio=0.8 here, overriding the function's 0.7 default).
freeze_early_layers(model, freeze_ratio=0.8)

# Scalar parameter counts (numel), unlike freeze_early_layers, which counts tensors.
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable_params:,}/{total_params:,} ({100*trainable_params/total_params:.1f}%)")
# Size estimate assumes 4 bytes (float32) per parameter.
print(f"Model size: {total_params * 4 / (1024**2):.1f} MB")
print("✅ Ensemble model created successfully!")

In [4]:
# Additionally unfreeze ResNet's last two stages (`layer3`, `layer4`) and a top-level
# `fc` head if present. Names are matched on whole dotted components, not substrings:
# `"fc" in name` is also true for MLP layers named e.g. `...mlp.fc1.weight` inside
# transformer/ConvNeXt backbones, which would silently unfreeze most of those models.
for name, param in model.named_parameters():
    if {"layer3", "layer4", "fc"}.intersection(name.split(".")):
        param.requires_grad = True

# Scalar (numel) counts after unfreezing
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
total_params = sum(p.numel() for p in model.parameters())
print(f"Trainable parameters: {trainable_params}/{total_params}")

Trainable parameters: 22587905/24032833


In [None]:
# Advanced loss function - Focal Loss with adjusted parameters for better recall.
# NOTE: FocalLoss multiplies every element by the same scalar `alpha`, so alpha=1.5
# only rescales the loss; it does not weight positives differently from negatives.
criterion = FocalLoss(alpha=1.5, gamma=3.0, logits=True)  # Higher gamma for harder examples

# Epochs driven by warmup_scheduler. The training loop hard-codes the same switch
# (`epoch < 3` / `epoch >= 3`), so keep the two in sync.
WARMUP_EPOCHS = 3

# Separate learning rates for different parts of the ensemble.
# Parameters are routed by their *top-level* submodule name rather than by substring.
# A substring test for 'fc' / 'head' also matches layers inside the pretrained
# backbones (e.g. MLP layers named `fc1`/`fc2`), giving them the 50x larger head LR.
#
# Every parameter is registered, including ones that are frozen right now. AdamW skips
# parameters whose .grad is None, so frozen ones cost nothing. Once unfreeze_layers()
# sets requires_grad=True they are trained with the backbone LR and follow the LR
# schedule. A parameter left out of the optimizer would never be updated, whatever
# its requires_grad says.
backbone_params = []
classifier_params = []
fusion_params = []

for name, param in model.named_parameters():
    top_level = name.split('.', 1)[0]
    if top_level in ('fusion', 'model_weights'):
        fusion_params.append(param)
    elif 'classifier' in top_level or top_level in ('fc', 'head'):
        classifier_params.append(param)
    else:
        backbone_params.append(param)

# More aggressive optimizer settings
optimizer = torch.optim.AdamW([
    {"params": backbone_params, "lr": 2e-5, "weight_decay": 5e-5},      # Slightly higher for backbones
    {"params": classifier_params, "lr": 1e-3, "weight_decay": 1e-3},    # Higher for classification heads
    {"params": fusion_params, "lr": 2e-3, "weight_decay": 5e-4}         # Highest for fusion layers
], eps=1e-8, betas=(0.9, 0.999))

# Cosine annealing with warm restarts (cycles of 3, 6, 12, ... epochs), stepped once
# per epoch by the training loop after the warmup phase.
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=3, T_mult=2, eta_min=1e-8  # Shorter cycles, lower minimum
)

# Linear warmup from 5% of each group's LR. total_iters equals WARMUP_EPOCHS so the
# ramp reaches the full LR before the loop hands over to `scheduler`. A longer
# total_iters would be cut off mid-ramp when the loop switches schedulers.
warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
    optimizer, start_factor=0.05, total_iters=WARMUP_EPOCHS
)

# Early stopping with more patience for final push
class EarlyStopping:
    """Stop training when validation loss stops improving; optionally restore the best weights.

    Args:
        patience: number of consecutive non-improving calls tolerated before stopping.
        min_delta: minimum decrease in loss that counts as an improvement.
        restore_best_weights: if True, load the best snapshot into the model on stopping.
    """

    def __init__(self, patience=12, min_delta=0.0005, restore_best_weights=True):  # Increased patience
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights
        self.best_loss = None
        self.counter = 0
        self.best_weights = None

    def __call__(self, val_loss, model):
        """Record ``val_loss`` for ``model``; return True when training should stop."""
        if self.best_loss is None:
            self.best_loss = val_loss
            self.save_checkpoint(model)
        elif val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            self.save_checkpoint(model)
        else:
            self.counter += 1

        if self.counter >= self.patience:
            if self.restore_best_weights:
                model.load_state_dict(self.best_weights)
            return True
        return False

    def save_checkpoint(self, model):
        """Snapshot the model's current weights as the best so far.

        ``state_dict()`` returns tensors that share storage with the live parameters,
        and ``dict.copy()`` only copies the mapping. Without cloning, the "best"
        snapshot would keep changing as training continues and restoring it would be a
        no-op. Tensors are cloned to CPU to avoid doubling GPU memory for large models;
        ``load_state_dict`` copies them back onto the model's device.
        """
        self.best_weights = {
            key: tensor.detach().to('cpu', copy=True)
            for key, tensor in model.state_dict().items()
        }

# Early stopping on validation loss (patience/min_delta equal the class defaults).
early_stopping = EarlyStopping(patience=12, min_delta=0.0005)

# Summary of the optimizer groups; these counts are parameter tensors, not scalar elements.
print("✅ Advanced training configuration initialized!")
print(f"Backbone params: {len(backbone_params)}")
print(f"Classifier params: {len(classifier_params)}")
print(f"Fusion params: {len(fusion_params)}")
print("🚀 Ready for final push to 90%+ accuracy!")

In [None]:
EPOCHS = 30  # Increased epochs
# Per-epoch checkpoints, the best checkpoint and the training-curve PNG are written here.
SAVE_DIR = "saved_models"
os.makedirs(SAVE_DIR, exist_ok=True)

# NOTE: despite its name, best_val_acc stores the best validation F1 (the loop selects
# the best checkpoint by F1). best_val_loss is not read by any visible code.
best_val_acc = 0.0
best_val_loss = float('inf')

# Advanced tracking: one entry per completed epoch
train_losses, val_losses = [], []
train_accs, val_accs = [], []
train_f1s, val_f1s = [], []
learning_rates = []

# Gradual unfreezing schedule
def unfreeze_layers(model, epoch):
    """Gradually unfreeze layers during training (epoch is 0-based).

    * epoch 5: parameters whose name contains 'layer4', 'blocks.3' or 'stages.3'
    * epoch 10: parameters whose name contains 'layer3', 'blocks.2' or 'stages.2'
    * epoch 15: every parameter
    Any other epoch is a no-op. Matching is by plain substring, so for example
    'blocks.2' also matches 'blocks.20'.

    NOTE: this only flips ``requires_grad``. A parameter is updated only if it is also
    in one of the optimizer's param groups. An optimizer built solely from parameters
    that were trainable at construction time never updates the ones unfrozen here.
    """
    if epoch == 15:
        for param in model.parameters():
            param.requires_grad = True
        print("🔓 Unfroze all layers for fine-tuning")
        return

    schedule = {
        5: (('layer4', 'blocks.3', 'stages.3'), "🔓 Unfroze layer4/blocks.3/stages.3"),
        10: (('layer3', 'blocks.2', 'stages.2'), "🔓 Unfroze layer3/blocks.2/stages.2"),
    }
    if epoch not in schedule:
        return

    patterns, message = schedule[epoch]
    for name, param in model.named_parameters():
        if any(pattern in name for pattern in patterns):
            param.requires_grad = True
    print(message)

# Mixed precision training with updated API.
# The scaler exists only on CUDA; `scaler is not None` is also used below as the
# switch for running validation under autocast.
scaler = torch.amp.GradScaler('cuda') if device.type == 'cuda' else None

print(f"Starting training for {EPOCHS} epochs...")
print("=" * 60)

for epoch in range(EPOCHS):
    # Gradual unfreezing (acts only at the specific epochs handled by unfreeze_layers)
    unfreeze_layers(model, epoch)
    
    # Training phase
    model.train()
    train_loss, correct, total = 0.0, 0, 0
    all_train_preds, all_train_labels = [], []
    
    # Use warmup scheduler for first 3 epochs.
    # NOTE: current_scheduler is never read; the scheduler that is actually stepped is
    # selected again after validation (see "Learning rate scheduling" below).
    if epoch < 3:
        current_scheduler = warmup_scheduler
    else:
        current_scheduler = scheduler
    
    progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
    
    for batch_idx, (imgs, labels) in enumerate(progress_bar):
        # Integer class labels -> float column (B, 1), matching the single-logit output
        imgs, labels = imgs.to(device), labels.float().unsqueeze(1).to(device)
        
        optimizer.zero_grad()
        
        # Mixed precision forward pass with updated API
        if scaler is not None:
            with torch.amp.autocast('cuda'):
                outputs = model(imgs)
                loss = criterion(outputs, labels)
            
            scaler.scale(loss).backward()
            # Gradient clipping for stability. Gradients are unscaled first so that
            # max_norm applies to their true magnitude.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            loss.backward()
            # Gradient clipping for stability
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
        
        train_loss += loss.item()
        
        # Calculate predictions and accuracy (sigmoid probability thresholded at 0.5).
        # These metrics are measured on augmented training batches.
        with torch.no_grad():
            probs = torch.sigmoid(outputs)
            preds = (probs > 0.5).float()
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            
            # Store for F1 calculation
            all_train_preds.extend(preds.cpu().numpy().flatten())
            all_train_labels.extend(labels.cpu().numpy().flatten())
        
        # Update progress bar ('LR' shows param group 0, i.e. the backbone group)
        current_acc = correct / total
        progress_bar.set_postfix({
            'Loss': f'{loss.item():.4f}',
            'Acc': f'{current_acc:.3f}',
            'LR': f'{optimizer.param_groups[0]["lr"]:.2e}'
        })
    
    # Calculate training metrics (average='binary' scores label 1 as the positive class)
    train_acc = correct / total
    train_f1 = sklearn.metrics.f1_score(all_train_labels, all_train_preds, average='binary')
    train_losses.append(train_loss / len(train_loader))
    train_accs.append(train_acc)
    train_f1s.append(train_f1)
    # Recorded before the scheduler steps: this is the group-0 LR used during this epoch
    learning_rates.append(optimizer.param_groups[0]['lr'])
    
    # Validation phase (val_loader uses the non-augmented val_transforms)
    model.eval()
    val_loss, correct, total = 0.0, 0, 0
    all_val_preds, all_val_labels, all_val_probs = [], [], []
    
    with torch.no_grad():
        for imgs, labels in tqdm(val_loader, desc="Validation"):
            imgs, labels = imgs.to(device), labels.float().unsqueeze(1).to(device)
            
            if scaler is not None:
                with torch.amp.autocast('cuda'):
                    outputs = model(imgs)
                    loss = criterion(outputs, labels)
            else:
                outputs = model(imgs)
                loss = criterion(outputs, labels)
            
            val_loss += loss.item()
            
            probs = torch.sigmoid(outputs)
            preds = (probs > 0.5).float()
            correct += (preds == labels).sum().item()
            total += labels.size(0)
            
            # Store predictions and probabilities (probabilities feed the ROC AUC)
            all_val_preds.extend(preds.cpu().numpy().flatten())
            all_val_labels.extend(labels.cpu().numpy().flatten())
            all_val_probs.extend(probs.cpu().numpy().flatten())
    
    # Calculate validation metrics (positive class = label 1)
    val_acc = correct / total
    val_f1 = sklearn.metrics.f1_score(all_val_labels, all_val_preds, average='binary')
    val_precision = sklearn.metrics.precision_score(all_val_labels, all_val_preds, average='binary')
    val_recall = sklearn.metrics.recall_score(all_val_labels, all_val_preds, average='binary')
    val_auc = sklearn.metrics.roc_auc_score(all_val_labels, all_val_probs)
    
    val_losses.append(val_loss / len(val_loader))
    val_accs.append(val_acc)
    val_f1s.append(val_f1)
    
    # Learning rate scheduling: one step per epoch; linear warmup for epochs 0-2, then
    # cosine annealing with warm restarts.
    if epoch >= 3:
        scheduler.step()
    else:
        warmup_scheduler.step()
    
    # Print epoch results (the LR printed here is the value *after* this epoch's step)
    print(f"\nEpoch {epoch+1}/{EPOCHS} Results:")
    print(f"Train - Loss: {train_loss/len(train_loader):.4f}, Acc: {train_acc:.4f}, F1: {train_f1:.4f}")
    print(f"Val   - Loss: {val_loss/len(val_loader):.4f}, Acc: {val_acc:.4f}, F1: {val_f1:.4f}")
    print(f"Val   - Precision: {val_precision:.4f}, Recall: {val_recall:.4f}, AUC: {val_auc:.4f}")
    print(f"Learning Rate: {optimizer.param_groups[0]['lr']:.2e}")
    
    # Save current epoch model. A full checkpoint (model + optimizer state) is written
    # every epoch, which can use a lot of disk over EPOCHS epochs. Only the cosine
    # `scheduler` state is saved, not `warmup_scheduler`.
    model_path = os.path.join(SAVE_DIR, f"ensemble_epoch_{epoch+1}.pth")
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'val_acc': val_acc,
        'val_loss': val_loss / len(val_loader),
        'train_acc': train_acc,
        'train_loss': train_loss / len(train_loader),
        'val_f1': val_f1
    }, model_path)
    
    # Save best model based on F1 score (better for imbalanced classes).
    # NOTE: `best_val_acc` actually holds the best validation F1.
    if val_f1 > best_val_acc:  # Using F1 as main metric
        best_val_acc = val_f1
        best_model_path = os.path.join(SAVE_DIR, "best_ensemble_model.pth")
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'val_acc': val_acc,
            'val_loss': val_loss / len(val_loader),
            'train_acc': train_acc,
            'train_loss': train_loss / len(train_loader),
            'val_f1': val_f1
        }, best_model_path)
        print(f"🏆 Best model updated! Val F1: {val_f1:.4f}, Val Acc: {val_acc:.4f}")
    
    # Early stopping check with increased patience. It monitors validation *loss*,
    # whereas best_ensemble_model.pth is chosen by F1, so the two criteria can disagree.
    # When it fires, EarlyStopping may load its own loss-based snapshot into `model`.
    current_val_loss = val_loss / len(val_loader)
    if early_stopping(current_val_loss, model):
        print(f"🛑 Early stopping triggered at epoch {epoch+1}")
        break
    
    print("-" * 60)

print("\n✅ Training completed!")
print(f"Best validation F1-score: {best_val_acc:.4f}")

# Plot training curves on a 2x3 grid: three train-vs-val panels (loss, accuracy, F1)
# and the learning-rate schedule; the last two grid cells stay empty.
plt.figure(figsize=(15, 10))

paired_panels = [
    (1, train_losses, val_losses, 'Train Loss', 'Val Loss', 'Loss', 'Training and Validation Loss'),
    (2, train_accs, val_accs, 'Train Acc', 'Val Acc', 'Accuracy', 'Training and Validation Accuracy'),
    (3, train_f1s, val_f1s, 'Train F1', 'Val F1', 'F1 Score', 'Training and Validation F1 Score'),
]
for position, train_series, val_series, train_label, val_label, y_label, panel_title in paired_panels:
    plt.subplot(2, 3, position)
    plt.plot(range(1, len(train_series) + 1), train_series, 'b-', label=train_label)
    plt.plot(range(1, len(val_series) + 1), val_series, 'r-', label=val_label)
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.title(panel_title)
    plt.legend()
    plt.grid(True)

# Learning-rate schedule (param group 0) on a log scale
plt.subplot(2, 3, 4)
plt.plot(range(1, len(learning_rates) + 1), learning_rates)
plt.xlabel('Epoch')
plt.ylabel('Learning Rate')
plt.title('Learning Rate Schedule')
plt.yscale('log')
plt.grid(True)

plt.tight_layout()
curves_path = os.path.join(SAVE_DIR, 'training_curves.png')
plt.savefig(curves_path, dpi=300, bbox_inches='tight')
plt.show()

print(f"📊 Training curves saved to {curves_path}")

Epoch 1/15:   0%|          | 0/44 [00:00<?, ?it/s]

Epoch 1/15 | Train Acc=0.726 | Val Acc=0.767 | Train Loss=0.597 | Val Loss=0.565
🏆 Best model updated (Val Acc=0.767)


Epoch 2/15:   0%|          | 0/44 [00:00<?, ?it/s]

Epoch 2/15 | Train Acc=0.844 | Val Acc=0.868 | Train Loss=0.408 | Val Loss=0.357
🏆 Best model updated (Val Acc=0.868)


Epoch 3/15:   0%|          | 0/44 [00:00<?, ?it/s]

Epoch 3/15 | Train Acc=0.847 | Val Acc=0.877 | Train Loss=0.365 | Val Loss=0.332
🏆 Best model updated (Val Acc=0.877)


Epoch 4/15:   0%|          | 0/44 [00:00<?, ?it/s]

Epoch 4/15 | Train Acc=0.856 | Val Acc=0.870 | Train Loss=0.345 | Val Loss=0.319


Epoch 5/15:   0%|          | 0/44 [00:00<?, ?it/s]

Epoch 5/15 | Train Acc=0.864 | Val Acc=0.882 | Train Loss=0.340 | Val Loss=0.308
🏆 Best model updated (Val Acc=0.882)


Epoch 6/15:   0%|          | 0/44 [00:00<?, ?it/s]

Epoch 6/15 | Train Acc=0.867 | Val Acc=0.882 | Train Loss=0.323 | Val Loss=0.304


Epoch 7/15:   0%|          | 0/44 [00:00<?, ?it/s]

Epoch 7/15 | Train Acc=0.868 | Val Acc=0.880 | Train Loss=0.324 | Val Loss=0.299


Epoch 8/15:   0%|          | 0/44 [00:00<?, ?it/s]

Epoch 8/15 | Train Acc=0.879 | Val Acc=0.878 | Train Loss=0.311 | Val Loss=0.298


Epoch 9/15:   0%|          | 0/44 [00:00<?, ?it/s]

Epoch 9/15 | Train Acc=0.872 | Val Acc=0.880 | Train Loss=0.313 | Val Loss=0.296


Epoch 10/15:   0%|          | 0/44 [00:00<?, ?it/s]

Epoch 10/15 | Train Acc=0.871 | Val Acc=0.883 | Train Loss=0.317 | Val Loss=0.296
🏆 Best model updated (Val Acc=0.883)


Epoch 11/15:   0%|          | 0/44 [00:00<?, ?it/s]

Epoch 11/15 | Train Acc=0.874 | Val Acc=0.878 | Train Loss=0.312 | Val Loss=0.295


Epoch 12/15:   0%|          | 0/44 [00:00<?, ?it/s]

Epoch 12/15 | Train Acc=0.864 | Val Acc=0.883 | Train Loss=0.313 | Val Loss=0.301


Epoch 13/15:   0%|          | 0/44 [00:00<?, ?it/s]

Epoch 13/15 | Train Acc=0.877 | Val Acc=0.885 | Train Loss=0.303 | Val Loss=0.291
🏆 Best model updated (Val Acc=0.885)


Epoch 14/15:   0%|          | 0/44 [00:00<?, ?it/s]

Epoch 14/15 | Train Acc=0.877 | Val Acc=0.890 | Train Loss=0.307 | Val Loss=0.288
🏆 Best model updated (Val Acc=0.890)


Epoch 15/15:   0%|          | 0/44 [00:00<?, ?it/s]

Epoch 15/15 | Train Acc=0.879 | Val Acc=0.890 | Train Loss=0.298 | Val Loss=0.288
✅ Training complete.


In [None]:
def _collect_eval_outputs(model, test_loader, criterion, device):
    """Run one inference pass over `test_loader` and gather what the metrics need.

    Args:
        model: network returning one logit per image (output shape assumed (batch, 1), matching
            the `labels.unsqueeze(1)` target handed to `criterion` — TODO confirm).
        test_loader: iterable of (images, labels) batches.
        criterion: loss called as `criterion(logits, float_labels)`.
        device: torch.device; CUDA inference runs under autocast.

    Returns:
        (mean_loss, labels, preds, probs): `labels` / `preds` are int numpy arrays, `probs` are the
        sigmoid outputs (probability of class 1, "Forged"); `preds` thresholds `probs` at 0.5.

    Raises:
        ValueError: if the loader yields no samples.
    """
    loss_sum, n_samples = 0.0, 0
    label_chunks, prob_chunks = [], []

    with torch.no_grad():
        for imgs, labels in tqdm(test_loader, desc="Evaluating"):
            imgs, labels = imgs.to(device), labels.float().unsqueeze(1).to(device)

            # Mixed precision inference with updated API
            if device.type == 'cuda':
                with torch.amp.autocast('cuda'):
                    outputs = model(imgs)
                    loss = criterion(outputs, labels)
            else:
                outputs = model(imgs)
                loss = criterion(outputs, labels)

            batch_size = labels.size(0)
            # Weight by batch size so a smaller final batch does not skew the mean
            # (assumes `criterion` uses mean reduction — TODO confirm).
            loss_sum += loss.item() * batch_size
            n_samples += batch_size

            # Cast to float32 first: under autocast the logits may be float16.
            prob_chunks.append(torch.sigmoid(outputs.float()).cpu().numpy().ravel())
            label_chunks.append(labels.cpu().numpy().ravel().astype(int))

    if n_samples == 0:
        raise ValueError("test_loader produced no samples; nothing to evaluate.")

    probs = np.concatenate(prob_chunks)
    labels_np = np.concatenate(label_chunks)
    preds = (probs > 0.5).astype(int)
    return loss_sum / n_samples, labels_np, preds, probs


def _plot_confusion_matrix(cm, class_names, path):
    """Save a confusion-matrix heatmap with counts and row-normalised percentages to `path`."""
    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt="d", cmap="Blues", cbar_kws={'label': 'Count'},
                xticklabels=class_names, yticklabels=class_names, ax=ax)
    ax.set_xlabel("Predicted Label", fontsize=12)
    ax.set_ylabel("True Label", fontsize=12)
    ax.set_title("Confusion Matrix - Document Forgery Detection", fontsize=14)

    # Percent of each true-class row; rows with no samples show 0% instead of NaN.
    row_totals = cm.sum(axis=1, keepdims=True)
    cm_percent = np.divide(cm * 100.0, row_totals,
                           out=np.zeros(cm.shape, dtype=float), where=row_totals > 0)
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            ax.text(j + 0.5, i + 0.7, f'({cm_percent[i, j]:.1f}%)',
                    horizontalalignment='center', fontsize=10, color='darkred')

    fig.savefig(path, dpi=300, bbox_inches="tight")
    plt.close(fig)


def _plot_roc_pr(labels, probs, roc_auc, path):
    """Save side-by-side ROC and precision-recall curves to `path` and return the PR-AUC.

    Both curves need positive and negative samples; when only one class is present the panels
    show a note instead and the returned PR-AUC is NaN.
    """
    fig, (ax_roc, ax_pr) = plt.subplots(1, 2, figsize=(15, 6))
    pr_auc = float("nan")

    if np.unique(labels).size == 2:
        fpr, tpr, _ = sklearn.metrics.roc_curve(labels, probs)
        ax_roc.plot(fpr, tpr, color="darkorange", lw=2, label=f"ROC Curve (AUC = {roc_auc:.3f})")

        precision_curve, recall_curve, _ = sklearn.metrics.precision_recall_curve(labels, probs)
        # Trapezoidal area under the PR curve, kept so "pr_auc" stays comparable with earlier
        # runs; sklearn documents average_precision_score as an alternative PR summary.
        pr_auc = sklearn.metrics.auc(recall_curve, precision_curve)
        ax_pr.plot(recall_curve, precision_curve, color="blue", lw=2,
                   label=f"PR Curve (AUC = {pr_auc:.3f})")
    else:
        for ax in (ax_roc, ax_pr):
            ax.text(0.5, 0.5, "Undefined: only one class in labels",
                    ha="center", va="center", transform=ax.transAxes)

    ax_roc.plot([0, 1], [0, 1], color="gray", lw=1, linestyle="--", label="Random Classifier")
    ax_roc.set_xlabel("False Positive Rate")
    ax_roc.set_ylabel("True Positive Rate")
    ax_roc.set_title("ROC Curve - Forgery Detection")
    ax_roc.legend(loc="lower right")
    ax_roc.grid(True, alpha=0.3)

    # A random classifier's precision equals the positive-class prevalence.
    ax_pr.axhline(y=labels.mean(), color="gray", linestyle="--", label="Random Classifier")
    ax_pr.set_xlabel("Recall")
    ax_pr.set_ylabel("Precision")
    ax_pr.set_title("Precision-Recall Curve")
    ax_pr.legend(loc="lower left")
    ax_pr.grid(True, alpha=0.3)

    fig.savefig(path, dpi=300, bbox_inches="tight")
    plt.close(fig)
    return pr_auc


def _plot_prediction_analysis(labels, preds, probs, per_class_scores, overall_metrics,
                              class_names, path, n_bins=10):
    """Save a 2x2 diagnostic figure to `path`.

    Panels: probability histograms per true class, accuracy per probability bin, per-class
    precision/recall/F1, and an overall-metric summary. `per_class_scores` is
    (precision, recall, f1), each indexed by class id; `overall_metrics` maps a display name
    to its value (non-finite values are drawn as an empty bar labelled "n/a").
    """
    fig, axes = plt.subplots(2, 2, figsize=(12, 8))
    ax_hist, ax_bins, ax_cls, ax_all = axes.ravel()

    # Prediction probability distributions (skip a class with no samples to avoid NaN densities)
    for class_id, (name, color) in enumerate(zip(class_names, ['blue', 'red'])):
        class_probs = probs[labels == class_id]
        if class_probs.size:
            ax_hist.hist(class_probs, bins=50, alpha=0.7, label=name, color=color, density=True)
    ax_hist.set_xlabel('Prediction Probability')
    ax_hist.set_ylabel('Density')
    ax_hist.set_title('Prediction Probability Distribution')
    ax_hist.legend()
    ax_hist.grid(True, alpha=0.3)

    # Accuracy per probability bin. digitize + clip keeps p == 1.0 in the last bin; a half-open
    # [lower, upper) test would silently drop saturated predictions.
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    bin_ids = np.clip(np.digitize(probs, edges) - 1, 0, n_bins - 1)
    is_correct = preds == labels
    bin_accuracies = [
        is_correct[bin_ids == b].mean() if np.any(bin_ids == b) else 0.0
        for b in range(n_bins)
    ]
    ax_bins.bar(range(n_bins), bin_accuracies, alpha=0.7)
    ax_bins.set_xlabel('Confidence Bin')
    ax_bins.set_ylabel('Accuracy')
    ax_bins.set_title('Confidence vs Accuracy')
    ax_bins.set_xticks(range(n_bins))
    ax_bins.set_xticklabels([f'{edges[i]:.1f}-{edges[i + 1]:.1f}' for i in range(n_bins)],
                            rotation=45)
    ax_bins.grid(True, alpha=0.3)

    # Class-wise metrics bar plot
    precision_c, recall_c, f1_c = per_class_scores
    x = np.arange(len(class_names))
    width = 0.25
    ax_cls.bar(x - width, precision_c, width, label='Precision', alpha=0.8)
    ax_cls.bar(x, recall_c, width, label='Recall', alpha=0.8)
    ax_cls.bar(x + width, f1_c, width, label='F1-Score', alpha=0.8)
    ax_cls.set_xlabel('Classes')
    ax_cls.set_ylabel('Score')
    ax_cls.set_title('Per-Class Metrics')
    ax_cls.set_xticks(x)
    ax_cls.set_xticklabels(class_names)
    ax_cls.legend()
    ax_cls.grid(True, alpha=0.3)

    # Overall metrics summary
    names = list(overall_metrics)
    values = [float(v) for v in overall_metrics.values()]
    heights = [v if np.isfinite(v) else 0.0 for v in values]
    bars = ax_all.bar(names, heights,
                      color=['skyblue', 'lightgreen', 'lightcoral', 'lightsalmon', 'lightpink'])
    for bar, value in zip(bars, values):
        text = f'{value:.3f}' if np.isfinite(value) else 'n/a'
        ax_all.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.01,
                    text, ha='center', va='bottom')
    ax_all.set_ylabel('Score')
    ax_all.set_title('Overall Performance Metrics')
    ax_all.set_ylim(0, 1.1)
    ax_all.tick_params(axis='x', labelrotation=45)
    ax_all.grid(True, alpha=0.3)

    fig.tight_layout()
    fig.savefig(path, dpi=300, bbox_inches="tight")
    plt.close(fig)


def advanced_evaluate_and_save(model, test_loader, criterion, device, save_dir="evaluation_results"):
    """Evaluate a binary (0 = Authentic, 1 = Forged) classifier and save reports and plots.

    Files written into `save_dir`: classification_report.txt, confusion_matrix.png,
    roc_pr_curves.png, prediction_analysis.png and comprehensive_metrics.json.

    Args:
        model: trained model; switched to eval mode here.
        test_loader: iterable of (images, labels) batches with labels in {0, 1}.
        criterion: loss applied to (logits, float labels of shape (batch, 1)).
        device: torch.device used for inference; CUDA runs under autocast.
        save_dir: output directory, created if missing.

    Returns:
        dict of metrics (also dumped to comprehensive_metrics.json). "roc_auc" and "pr_auc"
        are NaN when only one class is present in the labels.

    Raises:
        ValueError: if `test_loader` yields no samples.
    """
    class_names = ["Authentic", "Forged"]
    class_ids = [0, 1]
    os.makedirs(save_dir, exist_ok=True)

    model.eval()
    print("🧪 Running advanced evaluation...")
    test_loss, all_labels, all_preds, all_probs = _collect_eval_outputs(
        model, test_loader, criterion, device
    )
    test_acc = float((all_preds == all_labels).mean())

    print(f"\n🧪 Test Results:")
    print(f"Loss: {test_loss:.4f}")
    print(f"Accuracy: {test_acc:.4f}")

    # Overall binary metrics (class 1 = Forged is the positive class)
    precision = sklearn.metrics.precision_score(all_labels, all_preds, average='binary', zero_division=0)
    recall = sklearn.metrics.recall_score(all_labels, all_preds, average='binary', zero_division=0)
    f1 = sklearn.metrics.f1_score(all_labels, all_preds, average='binary', zero_division=0)
    if np.unique(all_labels).size == 2:
        roc_auc = sklearn.metrics.roc_auc_score(all_labels, all_probs)
    else:
        # roc_auc_score raises when only one class is present; report NaN instead of crashing.
        roc_auc = float("nan")
        print("⚠️ Only one class present in the test labels: ROC-AUC / PR-AUC are undefined (NaN).")

    # Per-class metrics; labels=class_ids keeps both entries even if a class never occurs.
    per_class_precision, per_class_recall, per_class_f1, _ = (
        sklearn.metrics.precision_recall_fscore_support(
            all_labels, all_preds, labels=class_ids, zero_division=0
        )
    )

    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1-Score: {f1:.4f}")
    print(f"ROC-AUC: {roc_auc:.4f}")

    # Confusion matrix rows = true class, columns = predicted class (always 2x2).
    cm = sklearn.metrics.confusion_matrix(all_labels, all_preds, labels=class_ids)
    per_class_correct = np.diag(cm)
    per_class_total = cm.sum(axis=1)
    per_class_acc = [
        float(per_class_correct[i] / per_class_total[i]) if per_class_total[i] else 0.0
        for i in class_ids
    ]

    print(f"\nPer-class Accuracy:")
    for class_idx, class_name in zip(class_ids, class_names):
        if per_class_total[class_idx] > 0:
            print(f"{class_name}: {per_class_acc[class_idx]:.4f} "
                  f"({per_class_correct[class_idx]}/{per_class_total[class_idx]})")

    # -----------------------------------
    # Classification report (computed once, printed and saved)
    # -----------------------------------
    report_text = sklearn.metrics.classification_report(
        all_labels, all_preds, labels=class_ids, target_names=class_names, zero_division=0
    )
    print("\n📊 Detailed Classification Report:")
    print(report_text)

    report_path = os.path.join(save_dir, "classification_report.txt")
    with open(report_path, "w") as f:
        f.write(report_text)
        f.write(f"\n\nOverall Metrics:\n")
        f.write(f"Test Accuracy: {test_acc:.4f}\n")
        f.write(f"Test Loss: {test_loss:.4f}\n")
        f.write(f"ROC-AUC: {roc_auc:.4f}\n")

    # -----------------------------------
    # Figures
    # -----------------------------------
    cm_path = os.path.join(save_dir, "confusion_matrix.png")
    _plot_confusion_matrix(cm, class_names, cm_path)

    curves_path = os.path.join(save_dir, "roc_pr_curves.png")
    pr_auc = _plot_roc_pr(all_labels, all_probs, roc_auc, curves_path)

    analysis_path = os.path.join(save_dir, "prediction_analysis.png")
    _plot_prediction_analysis(
        all_labels, all_preds, all_probs,
        (per_class_precision, per_class_recall, per_class_f1),
        {'Accuracy': test_acc, 'Precision': precision, 'Recall': recall,
         'F1-Score': f1, 'ROC-AUC': roc_auc},
        class_names, analysis_path,
    )

    # -----------------------------------
    # Save comprehensive results
    # -----------------------------------
    results = {
        "test_loss": float(test_loss),
        "test_accuracy": float(test_acc),
        "precision": float(precision),
        "recall": float(recall),
        "f1_score": float(f1),
        "roc_auc": float(roc_auc),
        "pr_auc": float(pr_auc),
        "per_class_metrics": {
            "authentic": {
                "precision": float(per_class_precision[0]),
                "recall": float(per_class_recall[0]),
                "f1_score": float(per_class_f1[0]),
                "accuracy": per_class_acc[0],
            },
            "forged": {
                "precision": float(per_class_precision[1]),
                "recall": float(per_class_recall[1]),
                "f1_score": float(per_class_f1[1]),
                "accuracy": per_class_acc[1],
            },
        },
        "confusion_matrix": cm.tolist(),
        "model_info": {
            "architecture": "EnsembleModel (EfficientNet-B4 + ConvNeXt + ResNet50 + ViT)",
            "total_parameters": sum(p.numel() for p in model.parameters()),
            "trainable_parameters": sum(p.numel() for p in model.parameters() if p.requires_grad),
        },
    }

    results_path = os.path.join(save_dir, "comprehensive_metrics.json")
    with open(results_path, "w") as f:
        json.dump(results, f, indent=4)

    print(f"\n✅ Advanced evaluation complete!")
    print(f"📊 Results saved in: {os.path.abspath(save_dir)}")
    print(f"📈 Confusion Matrix: {cm_path}")
    print(f"📉 ROC & PR Curves: {curves_path}")
    print(f"🔍 Analysis Plots: {analysis_path}")
    print(f"📦 Comprehensive Metrics: {results_path}")

    return results

In [None]:
# Load best ensemble model
checkpoint = torch.load("saved_models/best_ensemble_model.pth", map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)

print(f"Loaded best model from epoch {checkpoint['epoch']+1}")
# 'val_f1' may be missing from the checkpoint. Formatting a string fallback such as 'N/A'
# with ':.4f' raises ValueError, so apply the numeric format only to a real value.
best_val_f1 = checkpoint.get('val_f1')
print(f"Best validation F1-score: {best_val_f1:.4f}" if best_val_f1 is not None
      else "Best validation F1-score: N/A")
print(f"Best validation accuracy: {checkpoint['val_acc']:.4f}")
print(f"Best validation loss: {checkpoint['val_loss']:.4f}")

# Run comprehensive evaluation
results = advanced_evaluate_and_save(model, test_loader, criterion, device, save_dir="advanced_evaluation_results")

# Print summary
print(f"\n🎯 FINAL RESULTS SUMMARY:")
print(f"=" * 50)
print(f"Test Accuracy: {results['test_accuracy']:.4f} ({results['test_accuracy']*100:.2f}%)")
print(f"Test F1-Score: {results['f1_score']:.4f}")
print(f"Test ROC-AUC: {results['roc_auc']:.4f}")
print(f"Test PR-AUC: {results['pr_auc']:.4f}")
print(f"=" * 50)

if results['test_accuracy'] >= 0.90:
    print("🎉 SUCCESS! Achieved >90% accuracy target!")
else:
    print(f"📈 Current accuracy: {results['test_accuracy']*100:.2f}% - Pushing for 90% with TTA!")

# Enhanced Test Time Augmentation for final boost
def _tta_views(imgs, num_views, mean, std):
    """Yield up to `num_views` deterministic views of a normalised image batch.

    View 0 is the batch itself. The other views un-normalise to pixel space, apply one fixed
    transform to the whole batch on its current device, and re-normalise. This avoids
    round-tripping normalised (possibly CUDA) tensors through ToPILImage, which cannot convert
    CUDA tensors and would normalise twice.

    Assumes `imgs` is (batch, channels, height, width), normalised with `mean` / `std` —
    TODO confirm against the test transform.
    """
    from torchvision.transforms import functional as TF

    mean_t = torch.as_tensor(mean, dtype=imgs.dtype, device=imgs.device).view(1, -1, 1, 1)
    std_t = torch.as_tensor(std, dtype=imgs.dtype, device=imgs.device).view(1, -1, 1, 1)
    height, width = imgs.shape[-2:]
    # Fixed 5% shift, matching the magnitude bound of the former RandomAffine(translate=(0.05, 0.05)).
    shift = [int(round(0.05 * width)), int(round(0.05 * height))]

    augmentations = [
        TF.hflip,
        TF.vflip,
        lambda x: TF.rotate(x, 5.0),
        lambda x: TF.rotate(x, -5.0),
        lambda x: TF.adjust_brightness(x, 1.1),
        lambda x: TF.adjust_contrast(x, 1.1),
        lambda x: TF.affine(x, angle=0.0, translate=shift, scale=1.0, shear=[0.0, 0.0]),
    ]

    yield imgs
    pixels = imgs * std_t + mean_t
    for augment in augmentations[:num_views - 1]:
        yield (augment(pixels) - mean_t) / std_t


def _collect_tta_probs(model, loader, device, num_tta, mean, std, desc):
    """Return (probs, labels) numpy arrays; each probability is averaged over the TTA views.

    Raises:
        ValueError: if `loader` yields no samples.
    """
    prob_chunks, label_chunks = [], []
    with torch.no_grad():
        for imgs, labels in tqdm(loader, desc=desc):
            imgs = imgs.to(device)
            view_probs = []
            for view in _tta_views(imgs, num_tta, mean, std):
                if device.type == 'cuda':
                    with torch.amp.autocast('cuda'):
                        outputs = model(view)
                else:
                    outputs = model(view)
                # float32 before sigmoid/averaging: autocast logits may be float16
                view_probs.append(torch.sigmoid(outputs.float()))
            avg_probs = torch.stack(view_probs).mean(dim=0)
            prob_chunks.append(avg_probs.cpu().numpy().ravel())
            label_chunks.append(labels.cpu().numpy().ravel())

    if not prob_chunks:
        raise ValueError(f"{desc}: loader produced no samples.")
    return np.concatenate(prob_chunks), np.concatenate(label_chunks).astype(int)


def _best_threshold(probs, labels, candidates):
    """Return the accuracy-maximising threshold; 0.5 is kept unless a candidate is strictly better."""
    best_threshold = 0.5
    best_acc = ((probs > best_threshold).astype(int) == labels).mean()
    for threshold in candidates:
        acc = ((probs > threshold).astype(int) == labels).mean()
        if acc > best_acc:
            best_acc, best_threshold = acc, float(threshold)
    return best_threshold


def enhanced_test_time_augmentation(model, test_loader, device, num_tta=8, threshold_loader=None,
                                    mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Apply enhanced test-time augmentation (TTA) and report accuracy, F1 and ROC-AUC.

    Each image is scored on up to `num_tta` deterministic views and the sigmoid probabilities
    are averaged. The views are identity, horizontal flip, vertical flip, +5 / -5 degree
    rotation, brightness x1.1, contrast x1.1 and a 5% shift.

    Args:
        model: binary classifier returning one logit per image.
        test_loader: (images, labels) batches. Images are assumed normalised with `mean` / `std`
            (ImageNet defaults) — TODO confirm against the test transform.
        device: torch.device; CUDA inference runs under autocast.
        num_tta: views per image (1 = no augmentation); at most 8 are available.
        threshold_loader: optional held-out loader (e.g. validation) used to pick the decision
            threshold. If omitted, the threshold is tuned on the test set itself, so the
            "optimized" numbers are optimistically biased.
        mean, std: per-channel normalisation used by the loaders.

    Returns:
        (final_acc, final_f1, tta_auc, best_threshold).
        - final_acc / final_f1: accuracy and F1 at the chosen threshold.
        - tta_auc: ROC-AUC of the averaged probabilities (NaN if only one class is present).
        - best_threshold: the chosen decision threshold.

    Raises:
        ValueError: if `num_tta` < 1 or a loader yields no samples.
    """
    if num_tta < 1:
        raise ValueError("num_tta must be >= 1 (view 0 is the un-augmented image).")

    model.eval()
    print(f"🔄 Applying Enhanced Test-Time Augmentation (TTA) with {num_tta} augmentations...")

    all_tta_probs, all_labels = _collect_tta_probs(
        model, test_loader, device, num_tta, mean, std, "Enhanced TTA Evaluation"
    )

    # Calculate TTA metrics at the default 0.5 threshold
    tta_preds = (all_tta_probs > 0.5).astype(int)
    tta_acc = (tta_preds == all_labels).mean()
    tta_f1 = sklearn.metrics.f1_score(all_labels, tta_preds, average='binary', zero_division=0)
    if np.unique(all_labels).size == 2:
        tta_auc = sklearn.metrics.roc_auc_score(all_labels, all_tta_probs)
    else:
        tta_auc = float("nan")  # roc_auc_score raises when only one class is present

    # Candidate thresholds 0.30, 0.35, ..., 0.75 (linspace avoids arange float-step drift)
    candidates = np.linspace(0.30, 0.75, 10)
    if threshold_loader is not None:
        tune_probs, tune_labels = _collect_tta_probs(
            model, threshold_loader, device, num_tta, mean, std, "TTA threshold tuning"
        )
        best_threshold = _best_threshold(tune_probs, tune_labels, candidates)
    else:
        print("⚠️ No threshold_loader given: tuning the threshold on the test set itself, "
              "so the optimized metrics below are optimistically biased.")
        best_threshold = _best_threshold(all_tta_probs, all_labels, candidates)

    # Final predictions with the chosen threshold
    final_preds = (all_tta_probs > best_threshold).astype(int)
    final_acc = (final_preds == all_labels).mean()
    final_f1 = sklearn.metrics.f1_score(all_labels, final_preds, average='binary', zero_division=0)

    print(f"\n🚀 ENHANCED TEST-TIME AUGMENTATION RESULTS:")
    print(f"TTA Accuracy (0.5 threshold): {tta_acc:.4f} ({tta_acc*100:.2f}%)")
    print(f"TTA F1-Score (0.5 threshold): {tta_f1:.4f}")
    print(f"TTA ROC-AUC: {tta_auc:.4f}")
    print(f"\n🎯 OPTIMIZED THRESHOLD RESULTS:")
    print(f"Best threshold: {best_threshold:.3f}")
    print(f"Optimized Accuracy: {final_acc:.4f} ({final_acc*100:.2f}%)")
    print(f"Optimized F1-Score: {final_f1:.4f}")

    if final_acc >= 0.90:
        print("🎉🎉 SUCCESS! Achieved >90% accuracy with TTA! 🎉🎉")

    return final_acc, final_f1, tta_auc, best_threshold

# Final TTA pass; the call returns (accuracy, F1, ROC-AUC, decision threshold)
tta_acc, tta_f1, tta_auc, best_thresh = enhanced_test_time_augmentation(
    model, test_loader, device
)

print(
    "\n✅ Enhanced evaluation completed! "
    "Check the 'advanced_evaluation_results' folder for detailed analysis."
)
print(f"💡 For production use, apply threshold {best_thresh:.3f} for optimal results!")

Evaluating:   0%|          | 0/10 [00:00<?, ?it/s]


🧪 Test Loss: 0.3662 | Test Accuracy: 0.8533

📊 Classification Report:
              precision    recall  f1-score   support

   Authentic       0.77      1.00      0.87       300
      Forged       1.00      0.71      0.83       300

    accuracy                           0.85       600
   macro avg       0.88      0.85      0.85       600
weighted avg       0.88      0.85      0.85       600

📝 Classification report saved to evaluation_results/classification_report.txt
🖼️ Confusion matrix saved to evaluation_results/confusion_matrix.png
📉 ROC curve saved to evaluation_results/roc_curve.png
📦 Metrics saved to evaluation_results/metrics.json

✅ Evaluation complete. All results saved in: /content/evaluation_results
