# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import timm
import numpy as np
import cv2
import random
import time
import json
import pickle
from datetime import datetime
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from collections import Counter
import warnings
warnings.filterwarnings('ignore')

def save_backup_files(self, completed_count):
    """Dump ``self.all_results`` to a JSON file and a pickle file.

    Both files are written to the current working directory and share the
    stem ``deer_aging_backup_<completed_count>_<YYYYmmdd_HHMMSS>``; the JSON
    copy is written first, then the pickle copy.

    Args:
        completed_count: Number embedded in the file names (the caller passes
            how many architectures have finished so far).

    Returns:
        Tuple ``(json_file, pkl_file)`` holding the two file names.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    stem = f'deer_aging_backup_{completed_count}_{stamp}'
    json_file = f'{stem}.json'
    pkl_file = f'{stem}.pkl'

    # Human-readable copy
    with open(json_file, 'w') as json_out:
        json.dump(self.all_results, json_out, indent=2)

    # Pickle copy keeps full Python objects
    with open(pkl_file, 'wb') as pkl_out:
        pickle.dump(self.all_results, pkl_out)

    return json_file, pkl_file

def load_original_data():
    """Load the original image set and cap every age label at 5.5.

    Images and ages are read with ``buck.analysis.basics.ingest_images`` from
    a hard-coded glob pattern (the original note mentions 357 images). Ages
    of 5.5 or more are replaced by 5.5; the age distribution before and
    after that grouping is printed.

    Returns:
        Tuple ``(images, ages_grouped)`` where ``images`` is whatever
        ``ingest_images`` returned and ``ages_grouped`` is a list.
    """
    print("📂 LOADING ORIGINAL DATA")
    print("="*50)

    # Project-local loader, imported lazily so the module imports without it
    from buck.analysis.basics import ingest_images

    fpath = "C:\\Users\\aaron\\Dropbox\\AI Projects\\buck\\images\\squared\\color\\*.png"
    images, ages = ingest_images(fpath)
    print(f"   ✅ Loaded {len(images)} original images")

    # Everything at or above 5.5 collapses into a single 5.5 bucket
    print("   🔄 Grouping ages: 5.5+ → 5.5")
    ages_grouped = [5.5 if age >= 5.5 else age for age in ages]

    before = dict(Counter(ages))
    after = dict(Counter(ages_grouped))
    print(f"   📊 Original age distribution: {before}")
    print(f"   📊 Grouped age distribution: {after}")

    return images, ages_grouped

def create_train_val_test_split(images, ages, test_size=0.2, val_size=0.15, random_state=42):
    """Split images and ages into train, validation and test sets.

    The test set is carved off first, then the remainder is divided into
    train and validation. Each split is stratified only when every class in
    the data being split has at least two samples; otherwise that split is a
    plain shuffled split. Ages are finally mapped to contiguous class indices
    in ascending age order.

    Args:
        images: Sequence or array of images (converted to ``np.ndarray``).
        ages: Sequence or array of age labels, one per image.
        test_size: Fraction of all samples placed in the test set.
        val_size: Fraction of all samples placed in the validation set.
        random_state: Seed passed to both ``train_test_split`` calls.

    Returns:
        Tuple ``(X_train, y_train_indices, X_val, y_val_indices, X_test,
        y_test_indices, label_mapping, reverse_mapping)`` where
        ``label_mapping`` maps age -> class index and ``reverse_mapping`` maps
        class index -> age.
    """
    print("\n🔀 CREATING TRAIN/VAL/TEST SPLIT")
    print("="*50)

    # Convert to numpy arrays if needed
    if not isinstance(images, np.ndarray):
        images = np.array(images)
    if not isinstance(ages, np.ndarray):
        ages = np.array(ages)

    # Stratification needs at least two samples in every class
    age_counts = Counter(ages)
    min_count = min(age_counts.values())
    can_stratify = min_count >= 2

    print(f"   📊 Age distribution: {dict(age_counts)}")
    print(f"   📊 Minimum class size: {min_count}")
    print(f"   🎯 Can use stratified split: {can_stratify}")

    if not can_stratify:
        print("   ⚠️ Using random split (some classes too small for stratification)")

    # First split: separate the test set
    X_temp, X_test, y_temp, y_test = train_test_split(
        images, ages, test_size=test_size, random_state=random_state,
        shuffle=True, stratify=ages if can_stratify else None,
    )

    # Second split: train vs. validation, with val_size rescaled to the
    # remaining data. The first split can leave a class with a single sample,
    # in which case a stratified split would raise, so re-check here.
    val_size_adjusted = val_size / (1 - test_size)
    stratify_remainder = can_stratify and min(Counter(y_temp).values()) >= 2
    if can_stratify and not stratify_remainder:
        print("   ⚠️ Train/val split is random (a class has one sample left after the test split)")
    X_train, X_val, y_train, y_val = train_test_split(
        X_temp, y_temp, test_size=val_size_adjusted, random_state=random_state,
        shuffle=True, stratify=y_temp if stratify_remainder else None,
    )

    # Create label mapping (ascending age -> 0..k-1)
    unique_ages = sorted(set(ages))
    label_mapping = {age: i for i, age in enumerate(unique_ages)}
    reverse_mapping = {i: age for age, i in label_mapping.items()}

    print(f"   📊 Train: {len(X_train)} samples")
    print(f"   📊 Val: {len(X_val)} samples") 
    print(f"   📊 Test: {len(X_test)} samples")
    print(f"   🏷️ Label mapping: {label_mapping}")
    print(f"   🎯 Number of classes: {len(unique_ages)}")

    # Convert ages to class indices
    y_train_indices = np.array([label_mapping[age] for age in y_train])
    y_val_indices = np.array([label_mapping[age] for age in y_val])
    y_test_indices = np.array([label_mapping[age] for age in y_test])

    print(f"   📈 Train distribution: {Counter(y_train_indices)}")
    print(f"   📈 Val distribution: {Counter(y_val_indices)}")
    print(f"   📈 Test distribution: {Counter(y_test_indices)}")

    return (X_train, y_train_indices, X_val, y_val_indices, X_test, y_test_indices,
            label_mapping, reverse_mapping)

def augment_image(image):
    """Randomly perturb an image with rotation, flip, contrast and noise.

    The image is cast to ``uint8`` when it is not already. Four independent
    random draws then decide, in order, whether to apply:

    * a rotation about the image centre by an angle in [-15, 15] (p = 0.5),
    * a horizontal flip (p = 0.5),
    * a contrast/brightness change, contrast in [0.8, 1.2] and brightness
      offset in [-20, 20] (p = 0.5),
    * additive Gaussian noise with standard deviation 5 (p = 0.3).

    Returns:
        The resulting ``uint8`` image. When no step fires and the input was
        already ``uint8``, this is the input object itself.
    """
    if image.dtype != np.uint8:
        image = image.astype(np.uint8)

    # Rotation
    if random.random() < 0.5:
        degrees = random.uniform(-15, 15)
        height, width = image.shape[:2]
        rotation = cv2.getRotationMatrix2D((width // 2, height // 2), degrees, 1.0)
        image = cv2.warpAffine(image, rotation, (width, height))

    # Horizontal flip
    if random.random() < 0.5:
        image = cv2.flip(image, 1)

    # Contrast (alpha) and brightness (beta)
    if random.random() < 0.5:
        contrast = random.uniform(0.8, 1.2)
        brightness = random.randint(-20, 20)
        image = cv2.convertScaleAbs(image, alpha=contrast, beta=brightness)

    # Gaussian noise; int16 arithmetic so negative noise cannot wrap around
    if random.random() < 0.3:
        jitter = np.random.normal(0, 5, image.shape).astype(np.int16)
        image = np.clip(image.astype(np.int16) + jitter, 0, 255).astype(np.uint8)

    return image

def balance_and_augment_data(X_train, y_train, augment_multiplier=30, num_classes=5):
    """Oversample every class to the same size using random augmentation.

    The target size per class is ``largest class count * augment_multiplier``.
    For each class index in ``range(num_classes)`` the original samples are
    kept and randomly chosen originals are passed through ``augment_image``
    until the target is reached. A class index with no training samples is
    skipped with a warning, because there is nothing to augment from.

    Args:
        X_train: Training images, indexable along the first axis.
        y_train: Integer class index per image.
        augment_multiplier: Factor applied to the largest class count.
        num_classes: Number of class indices to iterate over.

    Returns:
        Tuple ``(X_augmented, y_augmented)`` of numpy arrays, grouped by
        ascending class index with originals before augmented copies.
    """
    print(f"\n🔄 BALANCING AND AUGMENTING DATA")
    print("="*50)
    print(f"   🎯 Target: {augment_multiplier}x augmentation per class")

    # Boolean-mask indexing below requires arrays
    X_train = np.asarray(X_train)
    y_train = np.asarray(y_train)

    # Count samples per class
    class_counts = Counter(y_train)
    print(f"   📊 Original distribution: {dict(class_counts)}")

    # Find target count (based on largest class * multiplier)
    max_count = max(class_counts.values())
    target_count = max_count * augment_multiplier
    print(f"   🎯 Target samples per class: {target_count}")

    X_augmented = []
    y_augmented = []

    for class_idx in range(num_classes):
        class_mask = y_train == class_idx
        class_images = X_train[class_mask]
        class_labels = y_train[class_mask]
        current_count = len(class_images)

        if current_count == 0:
            # Nothing to draw from; random.randint(0, -1) would raise
            print(f"   ⚠️ Class {class_idx}: no training samples, skipping")
            continue

        print(f"   📈 Class {class_idx}: {current_count} → {target_count} samples")

        # Originals first
        X_augmented.extend(class_images)
        y_augmented.extend(class_labels)

        # Then augmented copies of randomly chosen originals
        for _ in range(target_count - current_count):
            original_image = class_images[random.randint(0, current_count - 1)].copy()
            X_augmented.append(augment_image(original_image))
            y_augmented.append(class_idx)

    # Convert to arrays
    X_augmented = np.array(X_augmented)
    y_augmented = np.array(y_augmented)

    print(f"   ✅ Augmentation complete: {len(X_augmented)} total samples")
    print(f"   📊 Final distribution: {Counter(y_augmented)}")

    return X_augmented, y_augmented

class DeerDataset(Dataset):
    """In-memory image dataset yielding ImageNet-normalized 224x224 tensors.

    All images are held as one float tensor and all labels as one long
    tensor. Preprocessing happens per item in ``__getitem__``.
    """

    # ImageNet channel statistics shaped (3, 1, 1) so they broadcast over a
    # channels-first image. Built once here rather than on every item access.
    _IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    _IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

    def __init__(self, X, y, transform=True):
        """Store images and labels as tensors.

        Args:
            X: Images as a numpy array or anything ``np.array`` accepts.
            y: Integer labels as a numpy array or array-like.
            transform: Stored on the instance; not consulted by
                ``__getitem__``, which always preprocesses.
        """
        if isinstance(X, np.ndarray):
            self.X = torch.FloatTensor(X)
        else:
            self.X = torch.FloatTensor(np.array(X))

        if isinstance(y, np.ndarray):
            self.y = torch.LongTensor(y)
        else:
            self.y = torch.LongTensor(np.array(y))

        self.transform = transform

    def __len__(self):
        """Return the number of stored images."""
        return len(self.X)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for one sample.

        The image is scaled to [0, 1] when its maximum exceeds 1, moved to
        channels-first layout when its last axis has length 3, resized to
        224x224 with bilinear interpolation if needed, and standardized with
        the ImageNet mean and standard deviation.
        """
        image = self.X[idx].clone()
        label = self.y[idx].clone()

        # Normalize to [0,1]
        if image.max() > 1.0:
            image = image / 255.0

        # Ensure CHW format (channels first)
        if image.dim() == 3 and image.shape[-1] == 3:
            image = image.permute(2, 0, 1)

        # Resize to 224x224; interpolate expects a batch axis
        if image.shape[-2:] != (224, 224):
            image = F.interpolate(
                image.unsqueeze(0), size=(224, 224), mode='bilinear', align_corners=False
            ).squeeze(0)

        # ImageNet normalization
        image = (image - self._IMAGENET_MEAN) / self._IMAGENET_STD

        return image, label

class CompleteDeerAgeTrainer:
    """Complete deer age trainer starting from EfficientNet-B5"""
    
    def __init__(self, num_classes=5):
        self.num_classes = num_classes
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.all_results = []
        print(f"🔥 COMPLETE DEER AGE TRAINER")
        print(f"   Device: {self.device}")
        print(f"   Classes: {num_classes}")
    
    def get_all_architectures_with_fallback(self):
        """Get ALL architectures with graceful fallback for missing pretrained weights"""
        
        architectures = {
            # EfficientNet COMPLETE SERIES (B5-B7 + variants)
            #'EfficientNet-B5': {'model_name': 'efficientnet_b5', 'family': 'EfficientNet'},
            #'EfficientNet-B6': {'model_name': 'efficientnet_b6', 'family': 'EfficientNet'},
            #'EfficientNet-B7': {'model_name': 'efficientnet_b7', 'family': 'EfficientNet'},
            
            # EfficientNetV2 (try multiple naming conventions)
            #'EfficientNetV2-S': {'model_name': 'efficientnetv2_s', 'family': 'EfficientNetV2', 'alternatives': ['tf_efficientnetv2_s_in21ft1k', 'efficientnetv2_rw_s']},
            #EfficientNetV2-M': {'model_name': 'efficientnetv2_m', 'family': 'EfficientNetV2', 'alternatives': ['tf_efficientnetv2_m_in21ft1k', 'efficientnetv2_rw_m']},
            #'EfficientNetV2-L': {'model_name': 'efficientnetv2_l', 'family': 'EfficientNetV2', 'alternatives': ['tf_efficientnetv2_l_in21ft1k']},
            
            # DenseNet Family
            #'DenseNet-121': {'model_name': 'densenet121', 'family': 'DenseNet'},
            #'DenseNet-161': {'model_name': 'densenet161', 'family': 'DenseNet'},
            #'DenseNet-169': {'model_name': 'densenet169', 'family': 'DenseNet'},
            'DenseNet-201': {'model_name': 'densenet201', 'family': 'DenseNet'},
            
            # ResNet Family
            'ResNet-18': {'model_name': 'resnet18', 'family': 'ResNet'},
            'ResNet-34': {'model_name': 'resnet34', 'family': 'ResNet'},
            'ResNet-50': {'model_name': 'resnet50', 'family': 'ResNet'},
            'ResNet-101': {'model_name': 'resnet101', 'family': 'ResNet'},
            'ResNet-152': {'model_name': 'resnet152', 'family': 'ResNet'},
            
            # ResNeXt
            'ResNeXt-50': {'model_name': 'resnext50_32x4d', 'family': 'ResNeXt'},
            'ResNeXt-101': {'model_name': 'resnext101_32x8d', 'family': 'ResNeXt'},
            
            # Wide ResNet
            'Wide-ResNet-50': {'model_name': 'wide_resnet50_2', 'family': 'Wide-ResNet'},
            'Wide-ResNet-101': {'model_name': 'wide_resnet101_2', 'family': 'Wide-ResNet'},
            
            # MobileNet Family
            'MobileNetV2': {'model_name': 'mobilenetv2_100', 'family': 'MobileNet'},
            'MobileNetV3-Small': {'model_name': 'mobilenetv3_small_100', 'family': 'MobileNet'},
            'MobileNetV3-Large': {'model_name': 'mobilenetv3_large_100', 'family': 'MobileNet'},
            
            # RegNet Family (try multiple naming conventions)
            'RegNetX-400MF': {'model_name': 'regnetx_400mf', 'family': 'RegNet', 'alternatives': ['regnetx_002', 'regnetx_004']},
            'RegNetX-800MF': {'model_name': 'regnetx_800mf', 'family': 'RegNet', 'alternatives': ['regnetx_004', 'regnetx_006']},
            'RegNetY-400MF': {'model_name': 'regnety_400mf', 'family': 'RegNet', 'alternatives': ['regnety_002', 'regnety_004']},
            'RegNetY-800MF': {'model_name': 'regnety_800mf', 'family': 'RegNet', 'alternatives': ['regnety_004', 'regnety_006']},
            
            # ConvNeXt Family
            'ConvNeXt-Tiny': {'model_name': 'convnext_tiny', 'family': 'ConvNeXt', 'alternatives': ['convnext_tiny_in22ft1k']},
            'ConvNeXt-Small': {'model_name': 'convnext_small', 'family': 'ConvNeXt', 'alternatives': ['convnext_small_in22ft1k']},
            'ConvNeXt-Base': {'model_name': 'convnext_base', 'family': 'ConvNeXt', 'alternatives': ['convnext_base_in22ft1k']},
            
            # Vision Transformer variants
            'Swin-Tiny': {'model_name': 'swin_tiny_patch4_window7_224', 'family': 'Swin', 'alternatives': ['swin_tiny_patch4_window7_224_in22k']},
            'Swin-Small': {'model_name': 'swin_small_patch4_window7_224', 'family': 'Swin', 'alternatives': ['swin_small_patch4_window7_224_in22k']},
            
            # VGG (classic)
            'VGG-16': {'model_name': 'vgg16', 'family': 'VGG'},
            'VGG-19': {'model_name': 'vgg19', 'family': 'VGG'},
            
            # Vision Transformers
            'DeiT-Tiny': {'model_name': 'deit_tiny_patch16_224', 'family': 'DeiT'},
            'DeiT-Small': {'model_name': 'deit_small_patch16_224', 'family': 'DeiT'},
            'DeiT-Base': {'model_name': 'deit_base_patch16_224', 'family': 'DeiT'},
            
            # Additional EfficientNet variants
            'EfficientNet-ES': {'model_name': 'efficientnet_es', 'family': 'EfficientNet'},
            'EfficientNet-EM': {'model_name': 'efficientnet_em', 'family': 'EfficientNet'},
            'EfficientNet-EL': {'model_name': 'efficientnet_el', 'family': 'EfficientNet'},
            
            # Additional strong architectures
            'ResNet-26': {'model_name': 'resnet26', 'family': 'ResNet'},
            'ResNet-26d': {'model_name': 'resnet26d', 'family': 'ResNet'},
            'SEResNet-50': {'model_name': 'seresnet50', 'family': 'SEResNet'},
            'SEResNeXt-50': {'model_name': 'seresnext50_32x4d', 'family': 'SEResNeXt'},
        }
        
        print(f"\n🏗️ COMPLETE ARCHITECTURE ARSENAL ({len(architectures)} models)")
        print("="*80)
        print("🎯 FALLBACK STRATEGY: Pretrained → Alternative Names → Random Init")
        print("📊 ALL models will be tested regardless of pretrained weight availability")
        
        # Group by family and show counts
        families = {}
        for arch_name, arch_info in architectures.items():
            family = arch_info['family']
            if family not in families:
                families[family] = []
            families[family].append(arch_name)
        
        for family, models in families.items():
            print(f"📁 {family} ({len(models)} models): {', '.join(models)}")
        
        return architectures
    
    def create_model_with_fallback(self, arch_name, arch_info, freeze_strategy='none'):
        """Create model with graceful fallback for missing pretrained weights"""
        model_name = arch_info['model_name']
        alternatives = arch_info.get('alternatives', [])
        
        print(f"      🔧 Creating {arch_name}...")
        
        # Strategy 1: Try primary model name with pretrained=True
        try:
            print(f"         🎯 Trying pretrained: {model_name}")
            model = timm.create_model(model_name, pretrained=True, num_classes=self.num_classes)
            initialization_type = "pretrained"
            final_model_name = model_name
            print(f"         ✅ SUCCESS with pretrained weights!")
        except Exception as e1:
            print(f"         ❌ Pretrained failed: {str(e1)[:50]}...")
            
            # Strategy 2: Try alternative names with pretrained=True
            model = None
            for alt_name in alternatives:
                try:
                    print(f"         🎯 Trying alternative pretrained: {alt_name}")
                    model = timm.create_model(alt_name, pretrained=True, num_classes=self.num_classes)
                    initialization_type = "pretrained_alt"
                    final_model_name = alt_name
                    print(f"         ✅ SUCCESS with alternative pretrained weights!")
                    break
                except Exception as e2:
                    print(f"         ❌ Alternative {alt_name} failed: {str(e2)[:30]}...")
                    continue
            
            # Strategy 3: Fall back to random initialization
            if model is None:
                try:
                    print(f"         🎲 Falling back to random initialization: {model_name}")
                    model = timm.create_model(model_name, pretrained=False, num_classes=self.num_classes)
                    initialization_type = "random"
                    final_model_name = model_name
                    print(f"         ✅ SUCCESS with random initialization!")
                except Exception as e3:
                    # Try alternatives with random initialization
                    for alt_name in alternatives:
                        try:
                            print(f"         🎲 Trying alternative random: {alt_name}")
                            model = timm.create_model(alt_name, pretrained=False, num_classes=self.num_classes)
                            initialization_type = "random_alt"
                            final_model_name = alt_name
                            print(f"         ✅ SUCCESS with alternative random initialization!")
                            break
                        except Exception as e4:
                            continue
                    
                    # If still failed, return None
                    if model is None:
                        print(f"         ❌ COMPLETE FAILURE: All strategies failed")
                        return None, None, None
        
        # Apply freezing strategy
        if freeze_strategy == 'backbone':
            print(f"         🧊 Freezing backbone layers...")
            for name, param in model.named_parameters():
                if 'classifier' not in name and 'head' not in name and 'fc' not in name:
                    param.requires_grad = False
            
            trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
            total_params = sum(p.numel() for p in model.parameters())
            print(f"         📊 Loaded: {total_params:,} total, {trainable_params:,} trainable ({initialization_type})")
        
        elif freeze_strategy == 'partial':
            print(f"         ❄️ Partial freeze (last 30% unfrozen)...")
            all_params = list(model.named_parameters())
            total_layers = len(all_params)
            freeze_until = int(total_layers * 0.7)
            
            for i, (name, param) in enumerate(all_params):
                if i < freeze_until and 'classifier' not in name and 'head' not in name and 'fc' not in name:
                    param.requires_grad = False
            
            trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
            total_params = sum(p.numel() for p in model.parameters())
            print(f"         📊 Loaded: {total_params:,} total, {trainable_params:,} trainable ({initialization_type})")
        
        else:  # no freezing
            total_params = sum(p.numel() for p in model.parameters())
            print(f"         📊 Loaded: {total_params:,} parameters (all trainable, {initialization_type})")
        
        model = model.to(self.device)
        return model, initialization_type, final_model_name
    
    def ultra_aggressive_training(self, model, arch_name, train_loader, val_loader, test_loader, strategy='unfrozen'):
        """Ultra aggressive training with minimal early stopping"""
        print(f"      🔥 ULTRA AGGRESSIVE TRAINING: {arch_name} ({strategy})...")
        
        # More aggressive setup
        criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
        
        # Strategy-specific hyperparameters
        if strategy == 'frozen':
            lr = 0.01
            max_epochs = 100
            patience = 50
        elif strategy == 'partial':
            lr = 0.005
            max_epochs = 100
            patience = 50
        else:  # unfrozen
            lr = 0.001
            max_epochs = 100
            patience = 50
        
        optimizer = optim.AdamW(
            model.parameters(),
            lr=lr,
            weight_decay=0.01,
            betas=(0.9, 0.999)
        )
        
        # Simple step scheduler
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
        
        best_val_acc = 0.0
        patience_counter = 0
        
        print(f"         📊 ULTRA SETUP: {max_epochs} epochs, LR={lr}, patience={patience}")
        
        for epoch in range(max_epochs):
            # Training phase
            model.train()
            train_correct = 0
            train_total = 0
            train_loss = 0.0
            
            for batch_idx, (images, labels) in enumerate(train_loader):
                images, labels = images.to(self.device), labels.to(self.device)
                
                optimizer.zero_grad()
                outputs = model(images)
                loss = criterion(outputs, labels)
                loss.backward()
                
                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                
                optimizer.step()
                
                train_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                train_total += labels.size(0)
                train_correct += (predicted == labels).sum().item()
            
            train_acc = 100 * train_correct / train_total
            
            # Validation phase
            model.eval()
            val_correct = 0
            val_total = 0
            val_loss = 0.0
            
            with torch.no_grad():
                for images, labels in val_loader:
                    images, labels = images.to(self.device), labels.to(self.device)
                    outputs = model(images)
                    loss = criterion(outputs, labels)
                    val_loss += loss.item()
                    
                    _, predicted = torch.max(outputs.data, 1)
                    val_total += labels.size(0)
                    val_correct += (predicted == labels).sum().item()
            
            val_acc = 100 * val_correct / val_total
            scheduler.step()
            current_lr = scheduler.get_last_lr()[0]
            
            # Very lenient early stopping
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                patience_counter = 0
                best_model_state = model.state_dict().copy()
                improvement = "🔥"
            else:
                patience_counter += 1
                improvement = ""
            
            # More frequent progress updates
            if epoch % 5 == 0 or epoch < 10 or improvement or epoch > max_epochs - 10:
                gap = train_acc - val_acc
                print(f"         Epoch {epoch:3d}: Train {train_acc:.1f}%, Val {val_acc:.1f}% (gap: {gap:+.1f}%), LR: {current_lr:.2e} {improvement}")
        
        # Restore best model
        model.load_state_dict(best_model_state)
        
        # Test evaluation
        model.eval()
        test_correct = 0
        test_total = 0
        
        with torch.no_grad():
            for images, labels in test_loader:
                images, labels = images.to(self.device), labels.to(self.device)
                outputs = model(images)
                _, predicted = torch.max(outputs.data, 1)
                test_total += labels.size(0)
                test_correct += (predicted == labels).sum().item()
        
        test_acc = 100 * test_correct / test_total
        
        print(f"         🎯 {arch_name} ({strategy}) FINAL: Val {best_val_acc:.1f}%, Test {test_acc:.1f}%")
        
        return best_val_acc, test_acc
    
    def test_architecture_with_multiple_strategies(self, arch_name, arch_info, train_loader, val_loader, test_loader):
        """Test architecture with multiple training strategies and fallback support"""
        results = []
        
        # Strategy 1: Frozen backbone (fast warmup)
        print(f"      🧊 FROZEN BACKBONE STRATEGY:")
        model_frozen, init_type_frozen, final_name_frozen = self.create_model_with_fallback(arch_name, arch_info, freeze_strategy='backbone')
        if model_frozen is not None:
            try:
                val_acc_frozen, test_acc_frozen = self.ultra_aggressive_training(
                    model_frozen, arch_name, train_loader, val_loader, test_loader, strategy='frozen'
                )
                results.append({
                    'name': f"{arch_name}-Frozen",
                    'strategy': 'frozen',
                    'val_accuracy': val_acc_frozen,
                    'test_accuracy': test_acc_frozen,
                    'family': arch_info['family'],
                    'initialization': init_type_frozen,
                    'final_model_name': final_name_frozen,
                    'original_model_name': arch_info['model_name']
                })
            except Exception as e:
                print(f"         ❌ Frozen strategy failed: {str(e)[:50]}...")
        
        # Strategy 2: Partial freeze (if frozen worked reasonably)
        if results and results[-1]['val_accuracy'] > 35:
            print(f"      ❄️ PARTIAL FREEZE STRATEGY:")
            model_partial, init_type_partial, final_name_partial = self.create_model_with_fallback(arch_name, arch_info, freeze_strategy='partial')
            if model_partial is not None:
                try:
                    val_acc_partial, test_acc_partial = self.ultra_aggressive_training(
                        model_partial, arch_name, train_loader, val_loader, test_loader, strategy='partial'
                    )
                    results.append({
                        'name': f"{arch_name}-Partial",
                        'strategy': 'partial',
                        'val_accuracy': val_acc_partial,
                        'test_accuracy': test_acc_partial,
                        'family': arch_info['family'],
                        'initialization': init_type_partial,
                        'final_model_name': final_name_partial,
                        'original_model_name': arch_info['model_name']
                    })
                except Exception as e:
                    print(f"         ❌ Partial strategy failed: {str(e)[:50]}...")
        
        # Strategy 3: Full unfrozen (if partial worked well)
        if results and max(r['val_accuracy'] for r in results) > 45:
            print(f"      🔥 FULL UNFROZEN STRATEGY:")
            model_unfrozen, init_type_unfrozen, final_name_unfrozen = self.create_model_with_fallback(arch_name, arch_info, freeze_strategy='none')
            if model_unfrozen is not None:
                try:
                    val_acc_unfrozen, test_acc_unfrozen = self.ultra_aggressive_training(
                        model_unfrozen, arch_name, train_loader, val_loader, test_loader, strategy='unfrozen'
                    )
                    results.append({
                        'name': f"{arch_name}-Unfrozen",
                        'strategy': 'unfrozen',
                        'val_accuracy': val_acc_unfrozen,
                        'test_accuracy': test_acc_unfrozen,
                        'family': arch_info['family'],
                        'initialization': init_type_unfrozen,
                        'final_model_name': final_name_unfrozen,
                        'original_model_name': arch_info['model_name']
                    })
                except Exception as e:
                    print(f"         ❌ Unfrozen strategy failed: {str(e)[:50]}...")
        
        return results
    
    def run_complete_pipeline(self, X_train, y_train, X_val, y_val, X_test, y_test):
        """Train and evaluate every catalogued architecture, saving as it goes.

        Wraps the three splits in ``DeerDataset``/``DataLoader`` (batch size
        16, only the training loader shuffled), then for each architecture
        runs the multi-strategy test, appends the annotated results to
        ``self.all_results``, and writes a progress text file plus JSON/pickle
        backups. An intermediate leaderboard is shown after every third
        architecture; final results and the final leaderboard follow the loop.

        Returns:
            ``self.all_results``.
        """
        rule = "="*80
        print("🔥 COMPLETE DEER AGING PIPELINE")
        print(rule)
        print("🎯 Starting from EfficientNet-B5 onwards")
        print("🎯 All results will be saved automatically")
        print(rule)

        # Datasets and loaders
        train_dataset = DeerDataset(X_train, y_train)
        val_dataset = DeerDataset(X_val, y_val)
        test_dataset = DeerDataset(X_test, y_test)

        train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=0)
        val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=0)
        test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=0)

        print(f"📊 Data ready: {len(train_dataset)} train, {len(val_dataset)} val, {len(test_dataset)} test")

        architectures = self.get_all_architectures_with_fallback()
        arch_total = len(architectures)

        pipeline_started = time.time()

        print(f"\n🔥 ULTRA AGGRESSIVE TESTING: {arch_total} ARCHITECTURES")
        print(rule)

        for position, (arch_name, arch_info) in enumerate(architectures.items(), 1):
            print(f"\n[{position}/{arch_total}] 🔥 ULTRA AGGRESSIVE {arch_name}")
            print("-" * 70)

            arch_started = time.time()

            arch_results = self.test_architecture_with_multiple_strategies(
                arch_name, arch_info, train_loader, val_loader, test_loader
            )

            # Annotate each result and add it to the running list
            for entry in arch_results:
                entry['architecture_family'] = arch_info['family']
                entry['training_time'] = time.time() - arch_started
                entry['timestamp'] = datetime.now().isoformat()
                self.all_results.append(entry)

            if arch_results:
                winner = max(arch_results, key=lambda item: item['test_accuracy'])
                print(f"      🏆 Best {arch_name}: {winner['name']} ({winner['test_accuracy']:.1f}%)")

            print(f"      ⏱️ Total time for {arch_name}: {time.time() - arch_started:.1f}s")

            # Progress report and backups after every architecture
            elapsed_time = time.time() - pipeline_started
            self.save_progress_text_file(position, arch_total, arch_name, arch_results, elapsed_time)
            json_file, pkl_file = self.save_backup_files(position)

            print(f"      💾 Backups: {json_file}, {pkl_file}")

            # Intermediate leaderboard every 3 architectures
            if position % 3 == 0:
                self.show_intermediate_leaderboard(position)

        total_time = time.time() - pipeline_started

        self.save_final_results(total_time)
        self.show_final_leaderboard(total_time)

        return self.all_results
    
    def save_progress_text_file(self, completed_count, total_count, arch_name, arch_results, total_time_so_far):
        """Write a human-readable progress report after an architecture finishes.

        The report goes to ``deer_aging_progress_<YYYYMMDD_HHMMSS>.txt`` in the
        current working directory and contains, in order: a header (timestamp,
        runtime, progress, device, class count), the results of the
        architecture that just completed, the current top-10 leaderboard over
        ``self.all_results``, summary statistics with a per-initialization
        breakdown, and an estimate of the remaining time.

        Illustrative layout (values are examples only):

            🔥 DEER AGING PIPELINE - PROGRESS REPORT
            📅 Timestamp: 2025-06-11 14:23:45
            ⏱️  Runtime so far: 2.34 hours
            📊 Progress: 5/32 architectures completed

            🏁 JUST COMPLETED: EfficientNet-B5
               EfficientNet-B5-Partial | partial | pretrained | Val:  58.7% | Test:  55.2%
               🏆 Best: EfficientNet-B5-Partial (55.2%)

            🏆 CURRENT TOP 10 LEADERBOARD
            Rank Model                     Strategy   Init        Val%   Test%
            1    EfficientNet-B4-Unfrozen  unfrozen   Pretrained  61.2   58.9

        The file is opened explicitly as UTF-8. The report contains emoji, and
        a platform default encoding such as cp1252 on Windows cannot encode
        them, which would raise ``UnicodeEncodeError`` in the middle of a run.

        Args:
            completed_count: Number of architectures finished so far.
            total_count: Total number of architectures scheduled.
            arch_name: Name of the architecture that just completed.
            arch_results: Result dicts for that architecture. Each must have
                ``name``, ``strategy``, ``val_accuracy`` and ``test_accuracy``
                and may have ``initialization``. May be empty.
            total_time_so_far: Elapsed seconds since the whole run started.

        Returns:
            The name of the report file that was written.
        """
        # Display labels for the leaderboard's "Init" column; anything else
        # is shown as "Unknown".
        init_labels = {
            'pretrained': "Pretrained",
            'pretrained_alt': "Alt-Pre",
            'random': "Random",
            'random_alt': "Alt-Rand",
        }

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        filename = f'deer_aging_progress_{timestamp}.txt'

        # Explicit UTF-8: the locale default encoding may not support emoji.
        with open(filename, 'w', encoding='utf-8') as f:
            f.write("🔥 DEER AGING PIPELINE - PROGRESS REPORT\n")
            f.write("="*80 + "\n")
            f.write(f"📅 Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
            f.write(f"⏱️  Runtime so far: {total_time_so_far/3600:.2f} hours\n")
            f.write(f"📊 Progress: {completed_count}/{total_count} architectures completed\n")
            f.write(f"🔧 Device: {self.device}\n")
            f.write(f"🎯 Classes: {self.num_classes}\n")
            f.write("="*80 + "\n\n")

            # Just completed architecture results
            f.write(f"🏁 JUST COMPLETED: {arch_name}\n")
            f.write("-"*50 + "\n")
            if arch_results:
                for result in arch_results:
                    init_type = result.get('initialization', 'unknown')
                    f.write(f"   {result['name']:30} | {result['strategy']:8} | {init_type:12} | Val: {result['val_accuracy']:5.1f}% | Test: {result['test_accuracy']:5.1f}%\n")

                best_arch = max(arch_results, key=lambda x: x['test_accuracy'])
                f.write(f"   🏆 Best: {best_arch['name']} ({best_arch['test_accuracy']:.1f}%)\n")
            else:
                f.write("   ❌ No successful results for this architecture\n")
            f.write("\n")

            # Current overall leaderboard (top 10)
            sorted_results = sorted(self.all_results, key=lambda x: x['test_accuracy'], reverse=True)
            f.write(f"🏆 CURRENT TOP 10 LEADERBOARD\n")
            f.write("-"*80 + "\n")
            f.write(f"{'Rank':<4} {'Model':<30} {'Strategy':<10} {'Init':<12} {'Val%':<8} {'Test%':<8}\n")
            f.write("-"*80 + "\n")

            for rank, result in enumerate(sorted_results[:10], 1):
                init_display = init_labels.get(result.get('initialization', 'unknown'), "Unknown")
                f.write(f"{rank:<4} {result['name']:<30} {result['strategy']:<10} {init_display:<12} {result['val_accuracy']:<7.1f} {result['test_accuracy']:<7.1f}\n")

            f.write("\n")

            # Summary statistics
            if sorted_results:
                best_overall = sorted_results[0]
                breakthrough_count = sum(1 for r in sorted_results if r['test_accuracy'] > 54.2)
                excellent_count = sum(1 for r in sorted_results if r['test_accuracy'] >= 65.0)

                f.write(f"📊 SUMMARY STATISTICS\n")
                f.write("-"*50 + "\n")
                f.write(f"🏆 Current Best: {best_overall['name']} ({best_overall['test_accuracy']:.1f}%)\n")
                f.write(f"🚀 Models beating 54.2% baseline: {breakthrough_count}/{len(sorted_results)}\n")
                f.write(f"🎉 Models achieving 65%+: {excellent_count}\n")
                f.write(f"📈 Total models tested: {len(sorted_results)}\n")

                # Analysis by initialization type (groups keep first-seen order)
                f.write(f"\n📊 BY INITIALIZATION TYPE:\n")
                init_groups = {}
                for result in sorted_results:
                    init_groups.setdefault(result.get('initialization', 'unknown'), []).append(result)

                for init_type, group in init_groups.items():
                    avg_test = sum(r['test_accuracy'] for r in group) / len(group)
                    best_test = max(r['test_accuracy'] for r in group)
                    f.write(f"   {init_type:15}: {len(group):2d} models, avg: {avg_test:.1f}%, best: {best_test:.1f}%\n")

            f.write("\n")
            f.write(f"🔄 REMAINING: {total_count - completed_count} architectures to test\n")
            # Linear extrapolation from the average time per completed architecture.
            estimated_time_remaining = (total_time_so_far / completed_count) * (total_count - completed_count) if completed_count > 0 else 0
            f.write(f"⏰ Estimated time remaining: {estimated_time_remaining/3600:.1f} hours\n")
            f.write("="*80 + "\n")
            f.write(f"📁 This file: {filename}\n")
            f.write(f"💾 Auto-saved at: {datetime.now().strftime('%H:%M:%S')}\n")

        print(f"      📝 Progress saved: {filename}")
        return filename
    
    def save_final_results(self, total_time):
        """Persist the complete experiment outcome as JSON and as a pickle.

        Both files share the stem ``deer_aging_final_results_<YYYYMMDD_HHMMSS>``
        and hold the same payload: run metadata (``experiment_info``), the
        results in recorded order (``results``) and the same results ranked by
        ``test_accuracy`` descending (``leaderboard``).

        Args:
            total_time: Total runtime in seconds; stored in hours.
        """
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        json_name = f'deer_aging_final_results_{stamp}.json'
        pickle_name = f'deer_aging_final_results_{stamp}.pkl'

        # Run-level metadata first, then the raw and ranked result lists.
        experiment_info = dict(
            timestamp=stamp,
            total_runtime_hours=total_time / 3600,
            total_models_tested=len(self.all_results),
            device=str(self.device),
            num_classes=self.num_classes,
        )
        payload = dict(
            experiment_info=experiment_info,
            results=self.all_results,
            leaderboard=sorted(
                self.all_results,
                key=lambda entry: entry['test_accuracy'],
                reverse=True,
            ),
        )

        # Human-readable copy
        with open(json_name, 'w') as json_handle:
            json.dump(payload, json_handle, indent=2)

        # Full-fidelity Python copy
        with open(pickle_name, 'wb') as pickle_handle:
            pickle.dump(payload, pickle_handle)

        print(f"\n💾 RESULTS SAVED:")
        for saved_name in (json_name, pickle_name):
            print(f"   📁 {saved_name}")
    
    def show_intermediate_leaderboard(self, completed_count):
        """Print the five highest-``test_accuracy`` results recorded so far.

        Args:
            completed_count: Number of architectures finished; shown in the
                heading only.
        """
        ranked = list(self.all_results)
        ranked.sort(key=lambda entry: entry['test_accuracy'], reverse=True)
        top_five = ranked[:5]

        print(f"\n📊 CURRENT TOP 5 (after {completed_count} architectures):")
        position = 0
        for entry in top_five:
            position += 1
            print(f"   {position}. {entry['name']}: {entry['test_accuracy']:.1f}%")
        print()
    
    def show_final_leaderboard(self, total_time):
        """Print the final ranked table of every tested model, plus analysis.

        Output, in order: a header with the total runtime (``total_time``
        seconds shown as hours) and the model count; one row per entry of
        ``self.all_results`` sorted by ``test_accuracy`` descending, tagged
        with an initialization label and a status tier; count/average/best
        per initialization type; and, when at least one result exists, a
        summary that compares the best model against the 54.2% baseline and
        the 65% / 75% targets and reports best-pretrained vs. best-random when
        both are present.

        Args:
            total_time: Total testing time in seconds.
        """
        heavy_rule = "=" * 80
        # Display label per initialization type; anything else is "Unknown".
        init_badges = {
            'pretrained': "🎯 Pretrained",
            'pretrained_alt': "🎯 Alt-Pre",
            'random': "🎲 Random",
            'random_alt': "🎲 Alt-Rand",
        }

        ranked = sorted(self.all_results, key=lambda entry: entry['test_accuracy'], reverse=True)

        print(f"\n🏆 FINAL COMPREHENSIVE RESULTS")
        print(heavy_rule)
        print(f"⏰ Total testing time: {total_time/3600:.1f} hours")
        print(f"🎯 Models tested: {len(self.all_results)}")
        print(heavy_rule)
        print(f"{'Rank':<4} {'Model':<30} {'Strategy':<10} {'Init':<12} {'Val%':<8} {'Test%':<8} {'Status'}")
        print("-" * 80)

        for rank, entry in enumerate(ranked, 1):
            val_acc = entry['val_accuracy']
            test_acc = entry['test_accuracy']
            strategy = entry['strategy']
            init_display = init_badges.get(entry.get('initialization', 'unknown'), "❓ Unknown")

            # Tiers are tested from the highest threshold downwards.
            status = (
                "🎉 BREAKTHROUGH!" if test_acc >= 75.0
                else "🔥 EXCELLENT!" if test_acc >= 65.0
                else "🚀 NEW BEST!" if test_acc > 54.2
                else "📈 Good" if test_acc > 45.0
                else "📉 Weak"
            )

            print(f"{rank:<4} {entry['name']:<30} {strategy:<10} {init_display:<12} {val_acc:<7.1f} {test_acc:<7.1f} {status}")

        # Per-initialization breakdown; groups appear in leaderboard order.
        print(f"\n📊 ANALYSIS BY INITIALIZATION:")
        by_init = {}
        for entry in ranked:
            by_init.setdefault(entry.get('initialization', 'unknown'), []).append(entry)

        for init_type, members in by_init.items():
            accuracies = [member['test_accuracy'] for member in members]
            avg_test = sum(accuracies) / len(members)
            best_test = max(accuracies)
            print(f"   {init_type:15}: {len(members):2d} models, avg: {avg_test:.1f}%, best: {best_test:.1f}%")

        # Nothing to summarise without results; just close the report.
        if not ranked:
            print(heavy_rule)
            return

        champion = ranked[0]
        breakthrough_count = len([entry for entry in ranked if entry['test_accuracy'] > 54.2])
        excellent_count = len([entry for entry in ranked if entry['test_accuracy'] >= 65.0])

        print(f"\n🎊 FINAL SUMMARY:")
        print(f"   🏆 ULTIMATE CHAMPION: {champion['name']} ({champion['test_accuracy']:.1f}%)")
        print(f"   🚀 Beat 54.2% baseline: {breakthrough_count}/{len(ranked)} models")
        print(f"   🎉 Achieved 65%+: {excellent_count} models")

        if champion['test_accuracy'] >= 75.0:
            print(f"   🎉 MISSION ACCOMPLISHED! Achieved 75%+ accuracy!")
        elif champion['test_accuracy'] >= 65.0:
            print(f"   🎊 EXCELLENT! Found 65%+ architecture!")
        elif champion['test_accuracy'] > 54.2:
            gain = champion['test_accuracy'] - 54.2
            print(f"   🚀 SUCCESS! Improved by +{gain:.1f}% over baseline!")

        def best_with_prefix(prefix):
            # Highest-accuracy entry whose initialization starts with `prefix`, else None.
            return max(
                (entry for entry in ranked if entry.get('initialization', '').startswith(prefix)),
                key=lambda entry: entry['test_accuracy'],
                default=None,
            )

        best_pretrained = best_with_prefix('pretrained')
        best_random = best_with_prefix('random')

        if best_pretrained and best_random:
            print(f"   🎯 Best Pretrained: {best_pretrained['name']} ({best_pretrained['test_accuracy']:.1f}%)")
            print(f"   🎲 Best Random Init: {best_random['name']} ({best_random['test_accuracy']:.1f}%)")
            if best_random['test_accuracy'] > best_pretrained['test_accuracy']:
                print(f"   🔥 SURPRISE! Random initialization outperformed pretrained!")

        print(heavy_rule)

def run_complete_deer_aging_pipeline():
    """Drive the whole experiment: load, split, augment, train, report.

    Prints a banner describing the steps and the crash-recovery files, then
    runs the four stages in sequence. Any exception (including Ctrl-C) is
    re-raised after printing a hint about where the progress and backup files
    can be found.

    Returns:
        A ``(results, label_mapping, reverse_mapping)`` tuple, where
        ``results`` is whatever the trainer's ``run_complete_pipeline``
        returns and the two mappings come from
        ``create_train_val_test_split``.
    """
    rule = "=" * 80
    banner = (
        "🚀 LAUNCHING COMPLETE DEER AGING PIPELINE",
        rule,
        "📋 PIPELINE STEPS:",
        "   1. Load original 357 images",
        "   2. Create train/val/test splits",
        "   3. Balance and augment training data",
        "   4. Test all architectures (starting from EfficientNet-B5)",
        "   5. Save results and create leaderboard",
        rule,
        "💾 CRASH RECOVERY: Progress saved after each model!",
        "📝 Look for 'deer_aging_progress_*.txt' files for latest results",
        "💼 JSON/pickle backups: 'deer_aging_backup_*.json/.pkl'",
        rule,
    )
    for banner_line in banner:
        print(banner_line)

    try:
        # Stage 1: raw images and (grouped) age labels
        images, ages = load_original_data()

        # Stage 2: train / validation / test partitions plus label encodings
        split = create_train_val_test_split(images, ages)
        X_train, y_train, X_val, y_val, X_test, y_test, label_mapping, reverse_mapping = split

        # Stage 3: class balancing and augmentation of the training set only
        X_train_aug, y_train_aug = balance_and_augment_data(
            X_train, y_train, augment_multiplier=30, num_classes=len(label_mapping)
        )

        # Stage 4: train and evaluate every architecture
        trainer = CompleteDeerAgeTrainer(num_classes=len(label_mapping))
        results = trainer.run_complete_pipeline(
            X_train_aug, y_train_aug, X_val, y_val, X_test, y_test
        )

        for closing_line in (
            "\n🎉 PIPELINE COMPLETE!",
            "📁 All results saved with timestamps",
            "🏆 Check the final leaderboard above",
        ):
            print(closing_line)

        return results, label_mapping, reverse_mapping

    except KeyboardInterrupt:
        print("\n⚠️  INTERRUPTED BY USER")
        print("📝 Check latest 'deer_aging_progress_*.txt' file for current results")
        print("💾 Backup files saved as 'deer_aging_backup_*.json/.pkl'")
        raise
    except Exception as exc:
        print(f"\n❌ PIPELINE CRASHED: {exc!s}")
        print("📝 Check latest 'deer_aging_progress_*.txt' file for results up to crash")
        print("💾 Backup files saved as 'deer_aging_backup_*.json/.pkl'")
        print("🔄 You can manually load the pickle files to recover results")
        raise

# 🔄 CRASH RECOVERY FUNCTION
def load_latest_results():
    """Load the most recent readable backup of results (for crash recovery).

    Looks in the current working directory for ``deer_aging_backup_*.pkl``
    files and tries them newest-first by modification time. A crash while a
    backup is being written can leave the newest file empty or truncated, so
    unreadable files are reported and skipped in favour of the next older
    one instead of aborting the recovery.

    After a successful load, prints the number of results, the best model by
    ``test_accuracy`` and the name of the newest progress report, if any.

    Returns:
        The unpickled results object (presumably the list of result dicts the
        pipeline backs up; confirm against ``save_backup_files``), or
        ``None`` when no backup file exists or none could be read.
    """
    import glob
    import os

    print("🔄 CRASH RECOVERY MODE")
    print("="*50)

    # Find backup files
    pickle_files = glob.glob('deer_aging_backup_*.pkl')
    if not pickle_files:
        print("❌ No backup files found!")
        return None

    # Errors the pickle docs list for damaged or incomplete data, plus I/O errors.
    unreadable_errors = (
        pickle.UnpicklingError,
        EOFError,
        AttributeError,
        ImportError,
        IndexError,
        OSError,
    )

    # NOTE(security): unpickling can execute arbitrary code. Only load backup
    # files that this pipeline wrote itself.
    # Modification time is used because ctime is not creation time on POSIX.
    for candidate in sorted(pickle_files, key=os.path.getmtime, reverse=True):
        print(f"📁 Loading latest backup: {candidate}")
        try:
            with open(candidate, 'rb') as f:
                results = pickle.load(f)
        except unreadable_errors as exc:
            print(f"⚠️  Could not read {candidate} ({type(exc).__name__}: {exc}); trying an older backup")
            continue
        break
    else:
        print("❌ No readable backup files found!")
        return None

    print(f"✅ Loaded {len(results)} results")

    # Show quick summary
    if results:
        sorted_results = sorted(results, key=lambda x: x['test_accuracy'], reverse=True)
        print(f"🏆 Best model: {sorted_results[0]['name']} ({sorted_results[0]['test_accuracy']:.1f}%)")
        print(f"📊 Models tested: {len(results)}")

        # Show latest progress file
        progress_files = glob.glob('deer_aging_progress_*.txt')
        if progress_files:
            latest_progress = max(progress_files, key=os.path.getmtime)
            print(f"📝 Latest progress report: {latest_progress}")

    return results

# 🔥 RUN COMPLETE PIPELINE
if __name__ == "__main__":
    # Script entry point: announce what the run does and how to recover
    # after a crash, then launch the pipeline. The banner is emitted with a
    # single print call, one argument per output line.
    print(
        "🔥 LAUNCHING COMPLETE DEER AGING PIPELINE...",
        "⚠️  Starting from EfficientNet-B5 (as requested)",
        "🎯 TESTING ALL MODELS: Pretrained → Alternatives → Random Init",
        "💾 All results will be automatically saved",
        "📊 Will show which models used pretrained vs random initialization",
        "="*80,
        "💡 CRASH RECOVERY TIPS:",
        "   📝 Progress files: 'deer_aging_progress_YYYYMMDD_HHMMSS.txt'",
        "   💾 Backup files: 'deer_aging_backup_*.pkl'",
        "   🔄 To recover: results = load_latest_results()",
        "="*80,
        sep="\n",
    )

    # Results and label mappings stay bound at module level for interactive use.
    final_results, final_label_mapping, final_reverse_mapping = run_complete_deer_aging_pipeline()