In [3]:
# Griffin Ablation Study for MNIST (12 Configurations)

import os
import time
import random
import contextlib
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Optimizer
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, random_split
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
import json
from dataclasses import dataclass
from typing import Tuple, Dict, List, Any

# ---------- Global Performance Knobs ----------
USE_AMP = torch.cuda.is_available()
CHANNELS_LAST = False  # MNIST is grayscale, so channels_last not beneficial
AUTO_BENCHMARK = True
MATMUL_PRECISION = "high"

def set_seed(seed=42):
    """Sets the seed for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = False
        torch.backends.cudnn.benchmark = AUTO_BENCHMARK
        try:
            torch.set_float32_matmul_precision(MATMUL_PRECISION)
        except AttributeError:
            print("Warning: torch.set_float32_matmul_precision is not available.")

set_seed(42)

# ---------------------------
# Griffin Optimizer Implementation
# ---------------------------
class Griffin(Optimizer):
    def __init__(self, params, lr=5e-4, betas=(0.95, 0.99), weight_decay=1e-4,
                 eps=1e-8, beta_sigma=0.9, schedule_decay=5e-3, warmup_steps=100):
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta-1 parameter: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta-2 parameter: {betas[1]}")
        if not 0.0 <= beta_sigma < 1.0:
            raise ValueError(f"Invalid beta_sigma parameter: {beta_sigma}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if not 0.0 <= schedule_decay:
            raise ValueError(f"Invalid schedule_decay value: {schedule_decay}")
        if not 0 <= warmup_steps:
            raise ValueError(f"Invalid warmup_steps value: {warmup_steps}")

        defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay,
                        eps=eps, beta_sigma=beta_sigma,
                        schedule_decay=schedule_decay, warmup_steps=warmup_steps)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            wd = group['weight_decay']
            eps = group['eps']
            beta_sigma = group['beta_sigma']
            schedule_decay = group['schedule_decay']
            warmup_steps = group['warmup_steps']
            beta1, beta2 = group['betas']

            for p in group['params']:
                if p.grad is None:
                    continue

                g = p.grad
                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                    state['m_schedule'] = 1.0
                    state['m'] = torch.zeros_like(p)  # 1st moment
                    state['v'] = torch.zeros_like(p)  # 2nd moment
                    state['sigma'] = 0.6  # Scalar sigma per parameter
                    state['g_prev'] = torch.zeros_like(p)  # Previous gradient

                state['step'] += 1
                step = state['step']
                m, v, g_prev = state['m'], state['v'], state['g_prev']
                m_schedule = state['m_schedule']
                sigma = state['sigma']

                # ===== Nadam Momentum Schedule (Optimized) =====
                momentum_t = beta1 * (1. - 0.5 * (0.98 ** (step * schedule_decay)))
                momentum_t1 = beta1 * (1. - 0.5 * (0.98 ** ((step + 1) * schedule_decay)))
                m_schedule = m_schedule * momentum_t
                schedule_next = m_schedule * momentum_t1

                # ===== Efficient Weight Decay =====
                if wd != 0:
                    p.mul_(1 - lr * wd)

                # ===== Low-Cost Moment Updates =====
                # Update momentum (m)
                m.mul_(beta1).add_(g, alpha=1 - beta1)

                # Update adaptive term (v)
                v.mul_(beta2).addcmul_(g, g, value=1 - beta2)

                # ===== Ultra-Efficient Sigma Update =====
                grad_change = (g - g_prev).abs().mean().item()
                stability = 1.0 / (1.0 + 10 * grad_change)
                new_sigma = 0.4 + 0.4 * stability

                # Warmup-controlled sigma update
                warmup_factor = min(1.0, step / warmup_steps)
                sigma = sigma * (beta_sigma * warmup_factor) + new_sigma * (1 - beta_sigma * warmup_factor)
                sigma = max(0.4, min(0.8, sigma))
                state['sigma'] = sigma

                # Store current gradient for next step
                g_prev.copy_(g)

                # ===== Fast Parameter Update =====
                # Bias correction with momentum combination
                bias_corr1 = 1 - beta1 ** step
                bias_corr2 = 1 - beta2 ** step

                m_hat = (beta1 * m) / (1 - beta1**(step+1)) + ((1 - beta1) * g) / (1 - beta1**step)
                v_hat = v / bias_corr2
                denom = v_hat.sqrt().add_(eps)

                # Learning rate with warmup and sigma scaling
                lr_scale = min(1.0, step / warmup_steps)
                step_size = lr * lr_scale * (1. - momentum_t) / (1. - m_schedule) * sigma

                p.addcdiv_(m_hat, denom, value=-step_size)

                # Update momentum schedule
                state['m_schedule'] = schedule_next

        return loss

# ---------------------------
# Enhanced Metrics Tracking
# ---------------------------
class MetricsTracker:
    """Enhanced metrics tracking with statistical analysis"""
    def __init__(self):
        self.metrics = {}
        self.training_history = []
    
    def update_epoch(self, run_name: str, epoch: int, metrics: Dict[str, float]):
        if run_name not in self.metrics:
            self.metrics[run_name] = []
        
        epoch_data = {'epoch': epoch, **metrics}
        self.metrics[run_name].append(epoch_data)
        self.training_history.append({'run': run_name, **epoch_data})
    
    def get_summary_dataframe(self) -> pd.DataFrame:
        """Convert metrics to pandas DataFrame for analysis"""
        df = pd.DataFrame(self.training_history)
        return df
    
    def compute_statistical_significance(self, config1: str, config2: str, metric: str = 'test_acc') -> float:
        """Compute if differences are statistically significant"""
        try:
            import scipy.stats as stats
            
            if config1 not in self.metrics or config2 not in self.metrics:
                return 1.0
            
            metrics1 = [m[metric] for m in self.metrics[config1] if metric in m]
            metrics2 = [m[metric] for m in self.metrics[config2] if metric in m]
            
            if len(metrics1) < 2 or len(metrics2) < 2:
                return 1.0
            
            t_stat, p_value = stats.ttest_ind(metrics1, metrics2)
            return p_value
        except ImportError:
            return 1.0

# ---------------------------
# MNIST Data Loading
# ---------------------------
def get_mnist_loaders(batch_size=128, data_root="./data"):
    cores = os.cpu_count() or 2
    num_workers = min(8, max(2, cores // 2))

    # MNIST specific transformations
    transform_train = transforms.Compose([
        transforms.RandomCrop(28, padding=4),  # MNIST is 28x28
        transforms.RandomHorizontalFlip(),  # Can still be useful for digits
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))  # MNIST mean and std
    ])
    
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])

    # Load MNIST dataset
    train_full = datasets.MNIST(root=data_root, train=True, download=True, transform=transform_train)
    testset = datasets.MNIST(root=data_root, train=False, download=True, transform=transform_test)

    # Split training set into train and validation
    trainset, valset = random_split(train_full, [55000, 5000])  # MNIST has 60k training samples

    loader_kwargs = dict(num_workers=num_workers, pin_memory=True, persistent_workers=True if num_workers > 0 else False)
    train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(valset, batch_size=batch_size, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, **loader_kwargs)
    
    print(f"MNIST Dataset: {len(trainset)} training, {len(valset)} validation, {len(testset)} test samples")
    return train_loader, val_loader, test_loader

# ---------------------------
# MNIST Model Definition
# ---------------------------
class MNIST_CNN(nn.Module):
    """A simple CNN model for MNIST classification"""
    def __init__(self):
        super(MNIST_CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)  # After two pools: 28->14->7
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.dropout1(x)
        x = x.view(-1, 64 * 7 * 7)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        return x

def get_mnist_model():
    """Returns a CNN model for MNIST"""
    return MNIST_CNN()

# ---------------------------
# Training and Evaluation Loop
# ---------------------------
def train_model(model_fn, optimizer_factory, train_loader, val_loader, test_loader,
                num_epochs=10, run_name="default", metrics_tracker=None):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model_fn()
    model = model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optimizer_factory(model.parameters())

    scaler = torch.amp.GradScaler('cuda', enabled=USE_AMP)
    autocast_ctx = lambda: torch.amp.autocast(device_type=device.type, enabled=USE_AMP)

    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)

    metrics = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}
    overall_start_time = time.time()

    print(f"Starting training for '{run_name}' on {device}...")

    for epoch in range(num_epochs):
        epoch_start = time.time()
        model.train()
        train_loss, train_correct, train_total = 0.0, 0, 0
        for data, targets in train_loader:
            data, targets = data.to(device, non_blocking=True), targets.to(device, non_blocking=True)

            optimizer.zero_grad(set_to_none=True)
            with autocast_ctx():
                outputs = model(data)
                loss = criterion(outputs, targets)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            train_loss += loss.item() * data.size(0)
            _, predicted = outputs.max(1)
            train_total += targets.size(0)
            train_correct += predicted.eq(targets).sum().item()

        metrics['train_loss'].append(train_loss / train_total)
        metrics['train_acc'].append(train_correct / train_total)

        # Validation
        model.eval()
        val_loss, val_correct, val_total = 0.0, 0, 0
        with torch.no_grad():
            for data, targets in val_loader:
                data, targets = data.to(device, non_blocking=True), targets.to(device, non_blocking=True)
                with autocast_ctx():
                    outputs = model(data)
                    loss = criterion(outputs, targets)
                val_loss += loss.item() * data.size(0)
                _, predicted = outputs.max(1)
                val_total += targets.size(0)
                val_correct += predicted.eq(targets).sum().item()

        metrics['val_loss'].append(val_loss / val_total)
        metrics['val_acc'].append(val_correct / val_total)

        # Update metrics tracker
        if metrics_tracker is not None:
            epoch_metrics = {
                'train_loss': metrics['train_loss'][-1],
                'val_loss': metrics['val_loss'][-1],
                'train_acc': metrics['train_acc'][-1],
                'val_acc': metrics['val_acc'][-1]
            }
            metrics_tracker.update_epoch(run_name, epoch, epoch_metrics)

        scheduler.step()

        print(f"Epoch {epoch+1}/{num_epochs} | "
              f"Train Loss: {metrics['train_loss'][-1]:.4f}, Acc: {metrics['train_acc'][-1]:.4f} | "
              f"Val Loss: {metrics['val_loss'][-1]:.4f}, Acc: {metrics['val_acc'][-1]:.4f} | "
              f"Time: {time.time() - epoch_start:.2f}s")

    # Final Test Evaluation
    model.eval()
    test_loss, test_correct, test_total = 0.0, 0, 0
    all_probs, all_targets = [], []
    with torch.no_grad():
        for data, targets in test_loader:
            data, targets = data.to(device, non_blocking=True), targets.to(device, non_blocking=True)
            outputs = model(data)
            loss = criterion(outputs, targets)
            test_loss += loss.item() * data.size(0)
            _, predicted = outputs.max(1)
            test_total += targets.size(0)
            test_correct += predicted.eq(targets).sum().item()
            all_probs.extend(torch.softmax(outputs, dim=1).cpu().numpy())
            all_targets.extend(targets.cpu().numpy())

    final_metrics = {
        'test_acc': test_correct / test_total,
        'test_loss': test_loss / test_total,
        'avg_epoch_time': (time.time() - overall_start_time) / num_epochs
    }

    preds = np.argmax(all_probs, axis=1)
    final_metrics['test_precision'] = precision_score(all_targets, preds, average='macro', zero_division=0)
    final_metrics['test_recall'] = recall_score(all_targets, preds, average='macro', zero_division=0)
    final_metrics['test_f1'] = f1_score(all_targets, preds, average='macro', zero_division=0)
    
    try:
        final_metrics['test_auc'] = roc_auc_score(all_targets, all_probs, multi_class='ovr')
    except:
        final_metrics['test_auc'] = 0.0

    print(f"Test Results ({run_name}): Acc: {final_metrics['test_acc']:.4f}, F1: {final_metrics['test_f1']:.4f}")
    return {**metrics, **final_metrics}

# ---------------------------
# Ablation Study Setup for Griffin (12 Configurations)
# ---------------------------
def get_ablated_hyperparams_griffin():
    """Focused hyperparameter search for Griffin optimizer (12 configurations)"""
    default_cfg = {
        'lr': 5e-4,
        'betas': (0.95, 0.99),
        'weight_decay': 1e-4,
        'eps': 1e-8,
        'beta_sigma': 0.9,
        'schedule_decay': 5e-3,
        'warmup_steps': 100
    }
    hyperparams = {}

    # 1. Baseline
    hyperparams["Default"] = default_cfg.copy()
    
    # 2-3. Learning rate variations (most critical parameter)
    for lr in [1e-4, 1e-3]:
        cfg = default_cfg.copy()
        cfg['lr'] = lr
        hyperparams[f"lr={lr}"] = cfg

    # 4-5. Beta variations (momentum and variance)
    for betas in [(0.9, 0.999), (0.98, 0.999)]:
        cfg = default_cfg.copy()
        cfg['betas'] = betas
        hyperparams[f"Œ≤‚ÇÅ={betas[0]},Œ≤‚ÇÇ={betas[1]}"] = cfg

    # 6. Beta_sigma variation (sigma momentum)
    cfg = default_cfg.copy()
    cfg['beta_sigma'] = 0.95
    hyperparams["Œ≤_œÉ=0.95"] = cfg

    # 7-8. Weight decay variations (regularization)
    for wd in [0.0, 1e-3]:
        cfg = default_cfg.copy()
        cfg['weight_decay'] = wd
        hyperparams[f"wd={wd}"] = cfg

    # 9-10. Schedule decay variations (momentum scheduling)
    for schedule_decay in [1e-3, 1e-2]:
        cfg = default_cfg.copy()
        cfg['schedule_decay'] = schedule_decay
        hyperparams[f"schedule_decay={schedule_decay}"] = cfg

    # 11. Warmup steps variations
    cfg = default_cfg.copy()
    cfg['warmup_steps'] = 500
    hyperparams["warmup=500"] = cfg

    # 12. Combined aggressive configuration
    cfg = default_cfg.copy()
    cfg['lr'] = 1e-3
    cfg['betas'] = (0.98, 0.999)
    cfg['schedule_decay'] = 1e-2
    hyperparams["HighLR+HighBeta+HighDecay"] = cfg

    return hyperparams

def validate_griffin_params(params):
    """Validate Griffin optimizer parameters"""
    if not isinstance(params, dict):
        raise TypeError("params must be a dictionary")
    
    required_keys = ['lr', 'betas', 'weight_decay', 'eps', 'beta_sigma', 
                    'schedule_decay', 'warmup_steps']
    
    for key in required_keys:
        if key not in params:
            raise ValueError(f"Missing required parameter: {key}")
    
    beta1, beta2 = params['betas']
    if not (0.0 <= beta1 < 1.0 and 0.0 <= beta2 < 1.0):
        raise ValueError(f"Invalid beta values: {params['betas']}")
    
    if not 0.0 <= params['beta_sigma'] < 1.0:
        raise ValueError(f"Invalid beta_sigma: {params['beta_sigma']}")
    
    if not 0.0 <= params['schedule_decay']:
        raise ValueError(f"Invalid schedule_decay: {params['schedule_decay']}")
    
    if not 0 <= params['warmup_steps']:
        raise ValueError(f"Invalid warmup_steps: {params['warmup_steps']}")
    
    return True

# ---------------------------
# Enhanced Visualization
# ---------------------------
def create_griffin_comprehensive_plots(all_metrics: Dict[str, Dict], output_dir: str = "./griffin_mnist_results"):
    """Create comprehensive visualization of Griffin ablation results"""
    os.makedirs(output_dir, exist_ok=True)
    
    # Performance comparison
    fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(16, 12))
    
    # Accuracy and F1 comparison
    configs = list(all_metrics.keys())
    test_accs = [all_metrics[c]['test_acc'] for c in configs]
    test_f1s = [all_metrics[c]['test_f1'] for c in configs]
    
    x_pos = np.arange(len(configs))
    bars1 = ax1.bar(x_pos - 0.2, test_accs, 0.4, label='Accuracy', alpha=0.8, color='steelblue')
    bars2 = ax1.bar(x_pos + 0.2, test_f1s, 0.4, label='F1-Score', alpha=0.8, color='darkorange')
    
    ax1.set_xlabel('Configurations')
    ax1.set_ylabel('Scores')
    ax1.set_title('Griffin on MNIST: Test Accuracy and F1-Score (12 Configurations)')
    ax1.set_xticks(x_pos)
    ax1.set_xticklabels(configs, rotation=45, ha='right')
    ax1.legend()
    ax1.grid(True, alpha=0.3)
    
    # Add value labels on bars
    for bars in [bars1, bars2]:
        for bar in bars:
            height = bar.get_height()
            ax1.text(bar.get_x() + bar.get_width()/2., height + 0.01,
                    f'{height:.3f}', ha='center', va='bottom', fontsize=8)
    
    # Training curves for top configurations
    top_configs = sorted(configs, key=lambda x: all_metrics[x]['test_acc'], reverse=True)[:3]
    for config in top_configs:
        ax2.plot(all_metrics[config]['train_acc'], label=f'{config} (Train)', linewidth=2)
        ax2.plot(all_metrics[config]['val_acc'], '--', label=f'{config} (Val)', linewidth=2)
    
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Accuracy')
    ax2.set_title('Griffin on MNIST: Training Curves - Top 3 Configurations')
    ax2.legend()
    ax2.grid(True, alpha=0.3)
    
    # Convergence speed
    convergence_data = []
    for config in configs:
        val_accs = all_metrics[config]['val_acc']
        if len(val_accs) > 0:
            max_acc = max(val_accs)
            target_acc = 0.8 * max_acc
            convergence_epoch = next((i for i, acc in enumerate(val_accs) if acc >= target_acc), len(val_accs))
            convergence_data.append(convergence_epoch)
        else:
            convergence_data.append(len(val_accs))
    
    ax3.bar(configs, convergence_data, alpha=0.7, color='mediumpurple')
    ax3.set_xlabel('Configurations')
    ax3.set_ylabel('Epoch to Reach 80% Max Accuracy')
    ax3.set_title('Griffin on MNIST: Convergence Speed')
    ax3.set_xticklabels(configs, rotation=45, ha='right')
    ax3.grid(True, alpha=0.3)
    
    # Performance vs computational efficiency
    times = [all_metrics[c]['avg_epoch_time'] for c in configs]
    accuracies = test_accs
    
    scatter = ax4.scatter(times, accuracies, s=100, alpha=0.7, c=test_f1s, cmap='coolwarm')
    ax4.set_xlabel('Average Epoch Time (s)')
    ax4.set_ylabel('Test Accuracy')
    ax4.set_title('Griffin on MNIST: Accuracy vs Computational Efficiency')
    ax4.grid(True, alpha=0.3)
    
    # Add colorbar for F1 scores
    plt.colorbar(scatter, ax=ax4, label='F1-Score')
    
    # Annotate points
    for i, config in enumerate(configs):
        ax4.annotate(config, (times[i], accuracies[i]), 
                    xytext=(5, 5), textcoords='offset points', fontsize=8)
    
    plt.tight_layout()
    plt.savefig(f'{output_dir}/griffin_mnist_ablation_12_configs.png', dpi=300, bbox_inches='tight')
    plt.close()

    # Additional plot: Griffin-specific parameter analysis
    fig, ax = plt.subplots(figsize=(14, 8))
    
    # Group by parameter type for sensitivity analysis
    param_groups = {
        'Learning Rate': [k for k in configs if k.startswith('lr=')],
        'Beta Parameters': [k for k in configs if k.startswith('Œ≤')],
        'Weight Decay': [k for k in configs if k.startswith('wd=')],
        'Schedule Decay': [k for k in configs if k.startswith('schedule_decay=')],
        'Beta Sigma': [k for k in configs if 'Œ≤_œÉ' in k],
        'Warmup/Combined': [k for k in configs if 'warmup' in k or 'HighLR' in k]
    }
    
    colors = plt.cm.viridis(np.linspace(0, 1, len(param_groups)))
    
    for i, (group_name, group_configs) in enumerate(param_groups.items()):
        group_accs = [all_metrics[c]['test_acc'] for c in group_configs if c in all_metrics]
        if group_accs:
            mean_acc = np.mean(group_accs)
            std_acc = np.std(group_accs)
            ax.bar(group_name, mean_acc, color=colors[i], alpha=0.7, 
                  yerr=std_acc, capsize=5, label=f'{group_name} (n={len(group_accs)})')
    
    ax.set_ylabel('Test Accuracy')
    ax.set_title('Griffin: Parameter Group Sensitivity Analysis on MNIST')
    ax.grid(True, alpha=0.3)
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.savefig(f'{output_dir}/griffin_parameter_sensitivity.png', dpi=300, bbox_inches='tight')
    plt.close()

# ---------------------------
# Main Ablation Study Execution
# ---------------------------
def run_griffin_ablation_study(num_epochs=10, batch_size=256):
    """Run Griffin ablation study with 12 configurations on MNIST"""
    train_loader, val_loader, test_loader = get_mnist_loaders(batch_size=batch_size)
    hyperparams = get_ablated_hyperparams_griffin()
    all_metrics = {}
    
    # Initialize metrics tracker
    metrics_tracker = MetricsTracker()
    
    print(f"üöÄ Starting Griffin Ablation Study on MNIST with {len(hyperparams)} configurations")
    print(f"üìä Epochs: {num_epochs}, Batch Size: {batch_size}")
    print("=" * 100)
    
    for name, cfg in tqdm(hyperparams.items(), desc="Griffin Configurations"):
        print(f"\n‚ñ∂Ô∏è  Running: {name}")
        print(f"   ‚öôÔ∏è  Config: {cfg}")
        
        try:
            validate_griffin_params(cfg)
            
            optimizer_factory = lambda params: Griffin(params, **cfg)
            metrics = train_model(
                get_mnist_model, optimizer_factory,
                train_loader, val_loader, test_loader,
                num_epochs=num_epochs, run_name=name,
                metrics_tracker=metrics_tracker
            )
            all_metrics[name] = metrics
                
        except Exception as e:
            print(f"‚ùå Error in configuration {name}: {e}")
            continue

    # Generate comprehensive analysis
    print("\n" + "üìà" * 20 + " GRIFFIN MNIST ANALYSIS " + "üìà" * 20)
    generate_detailed_griffin_analysis(all_metrics, metrics_tracker)
    
    return all_metrics, metrics_tracker

def generate_detailed_griffin_analysis(all_metrics, metrics_tracker):
    """Generate detailed analysis of Griffin results on MNIST"""
    # Create comprehensive plots
    create_griffin_comprehensive_plots(all_metrics)
    
    # Export to CSV for further analysis
    df = metrics_tracker.get_summary_dataframe()
    df.to_csv('griffin_mnist_ablation_12_configs.csv', index=False)
    
    # Save configurations
    save_griffin_configs({k: v for k, v in get_ablated_hyperparams_griffin().items() 
                       if k in all_metrics})
    
    # Print results table
    print("\n" + "#" * 80)
    print(" " * 20 + "Griffin Ablation Study Results on MNIST (12 Configurations)")
    print("#" * 80)

    best_config_name = max(all_metrics.keys(), key=lambda k: all_metrics[k]['test_acc'])
    best_metrics = all_metrics[best_config_name]
    
    print(f"\nüèÜ Best Configuration: '{best_config_name}'")
    print(f"   - Accuracy:  {best_metrics['test_acc']:.4f}")
    print(f"   - F1-Score:  {best_metrics['test_f1']:.4f}")
    print(f"   - AUC:       {best_metrics['test_auc']:.4f}")
    print(f"   - Precision: {best_metrics['test_precision']:.4f}")
    print(f"   - Time/Epoch: {best_metrics['avg_epoch_time']:.2f}s")

    print("\n" + "-" * 70)
    print("LaTeX Table Summary:")
    print("-" * 70)
    print(r"\begin{tabular}{lcccccc}")
    print(r"\toprule")
    print(r"\textbf{Configuration} & \textbf{Accuracy} & \textbf{F1-Score} & \textbf{AUC} & \textbf{Precision} & \textbf{Recall} & \textbf{Time/Epoch} \\")
    print(r"\midrule")

    sorted_names = sorted(all_metrics.keys(), key=lambda k: all_metrics[k]['test_acc'], reverse=True)
    for name in sorted_names:
        metrics = all_metrics[name]
        is_best = name == best_config_name
        acc_str = r"\textbf{" + f"{metrics['test_acc']:.4f}" + "}" if is_best else f"{metrics['test_acc']:.4f}"
        
        # LaTeX-safe name
        latex_name = name.replace('_', ' ').replace('Œ≤', r'$\beta$').replace('œÉ', r'$\sigma$')
        
        print(
            latex_name + " & " +
            acc_str + " & " +
            f"{metrics['test_f1']:.4f} & " +
            f"{metrics['test_auc']:.4f} & " +
            f"{metrics['test_precision']:.4f} & " +
            f"{metrics['test_recall']:.4f} & " +
            f"{metrics['avg_epoch_time']:.2f}s" + r" \\"
        )

    print(r"\bottomrule")
    print(r"\end{tabular}")

    # Statistical significance analysis
    configs = list(all_metrics.keys())
    if len(configs) >= 2:
        best_config = max(configs, key=lambda x: all_metrics[x]['test_acc'])
        second_best = max([c for c in configs if c != best_config], 
                         key=lambda x: all_metrics[x]['test_acc'])
        
        p_value = metrics_tracker.compute_statistical_significance(
            best_config, second_best
        )
        
        print(f"\nüìä Statistical Significance Analysis:")
        print(f"   Best vs Second Best: p-value = {p_value:.4f}")
        if p_value < 0.05:
            print("   ‚úÖ Difference is statistically significant (p < 0.05)")
        else:
            print("   ‚ö†Ô∏è  Difference is not statistically significant (p ‚â• 0.05)")

    # Griffin-specific parameter importance analysis
    print(f"\nüîç Griffin Parameter Importance Summary:")
    param_groups = {
        'Learning Rate': [k for k in configs if k.startswith('lr=')],
        'Beta Parameters': [k for k in configs if k.startswith('Œ≤')],
        'Weight Decay': [k for k in configs if k.startswith('wd=')],
        'Schedule Decay': [k for k in configs if k.startswith('schedule_decay=')],
        'Beta Sigma': [k for k in configs if 'Œ≤_œÉ' in k],
        'Warmup Steps': [k for k in configs if 'warmup' in k],
        'Combined Configs': [k for k in configs if 'HighLR' in k]
    }
    
    for group_name, group_configs in param_groups.items():
        if group_configs:
            group_accs = [all_metrics[c]['test_acc'] for c in group_configs if c in all_metrics]
            if group_accs:
                mean_acc = np.mean(group_accs)
                std_acc = np.std(group_accs)
                print(f"   {group_name:20}: {mean_acc:.4f} ¬± {std_acc:.4f} (n={len(group_accs)})")

    # Griffin-specific insights
    print(f"\nüéØ Griffin-Specific Insights:")
    print(f"   - Best accuracy: {best_metrics['test_acc']:.4f}")
    print(f"   - Average accuracy across all configs: {np.mean([all_metrics[c]['test_acc'] for c in configs]):.4f}")
    print(f"   - Standard deviation: {np.std([all_metrics[c]['test_acc'] for c in configs]):.4f}")
    
    # Analyze sigma-related configurations
    sigma_configs = [c for c in configs if 'Œ≤_œÉ' in c]
    if sigma_configs:
        sigma_accs = [all_metrics[c]['test_acc'] for c in sigma_configs if c in all_metrics]
        print(f"   - Beta_sigma configurations average: {np.mean(sigma_accs):.4f}")
    
    # Analyze schedule decay effects
    decay_configs = [c for c in configs if 'schedule_decay' in c]
    if decay_configs:
        decay_accs = [all_metrics[c]['test_acc'] for c in decay_configs if c in all_metrics]
        print(f"   - Schedule decay configurations average: {np.mean(decay_accs):.4f}")
    
    # Analyze warmup effects
    warmup_configs = [c for c in configs if 'warmup' in c]
    if warmup_configs:
        warmup_accs = [all_metrics[c]['test_acc'] for c in warmup_configs if c in all_metrics]
        print(f"   - Warmup configurations average: {np.mean(warmup_accs):.4f}")

    # Griffin algorithm specific observations
    print(f"\n‚ö° Griffin Algorithm Observations:")
    print(f"   - Uses Nadam-style momentum scheduling with exponential decay")
    print(f"   - Features adaptive sigma based on gradient stability")
    print(f"   - Implements warmup-controlled learning rate scaling")
    print(f"   - Combines bias-corrected moments with sigma scaling")

def save_griffin_configs(hyperparams, filename="griffin_mnist_ablation_12_configs.json"):
    """Save Griffin ablation configurations for reproducibility"""
    serializable_configs = {}
    for name, cfg in hyperparams.items():
        serializable_configs[name] = {
            'lr': cfg['lr'],
            'betas': list(cfg['betas']),
            'weight_decay': cfg['weight_decay'],
            'eps': cfg['eps'],
            'beta_sigma': cfg['beta_sigma'],
            'schedule_decay': cfg['schedule_decay'],
            'warmup_steps': cfg['warmup_steps']
        }
    
    with open(filename, 'w') as f:
        json.dump(serializable_configs, f, indent=2)

if __name__ == "__main__":
    all_metrics, metrics_tracker = run_griffin_ablation_study(num_epochs=10, batch_size=256)

MNIST Dataset: 55000 training, 5000 validation, 10000 test samples
üöÄ Starting Griffin Ablation Study on MNIST with 12 configurations
üìä Epochs: 10, Batch Size: 256


Griffin Configurations:   0%|          | 0/12 [00:00<?, ?it/s]


‚ñ∂Ô∏è  Running: Default
   ‚öôÔ∏è  Config: {'lr': 0.0005, 'betas': (0.95, 0.99), 'weight_decay': 0.0001, 'eps': 1e-08, 'beta_sigma': 0.9, 'schedule_decay': 0.005, 'warmup_steps': 100}
Starting training for 'Default' on cuda...
Epoch 1/10 | Train Loss: 1.7227, Acc: 0.3916 | Val Loss: 0.9342, Acc: 0.7182 | Time: 10.56s
Epoch 2/10 | Train Loss: 0.8635, Acc: 0.7085 | Val Loss: 0.5524, Acc: 0.8362 | Time: 10.33s
Epoch 3/10 | Train Loss: 0.6463, Acc: 0.7830 | Val Loss: 0.4308, Acc: 0.8678 | Time: 10.25s
Epoch 4/10 | Train Loss: 0.5516, Acc: 0.8159 | Val Loss: 0.3666, Acc: 0.8854 | Time: 10.26s
Epoch 5/10 | Train Loss: 0.5012, Acc: 0.8329 | Val Loss: 0.3422, Acc: 0.8910 | Time: 10.24s
Epoch 6/10 | Train Loss: 0.4641, Acc: 0.8471 | Val Loss: 0.3153, Acc: 0.9000 | Time: 10.46s
Epoch 7/10 | Train Loss: 0.4410, Acc: 0.8537 | Val Loss: 0.3042, Acc: 0.9024 | Time: 10.36s
Epoch 8/10 | Train Loss: 0.4270, Acc: 0.8599 | Val Loss: 0.2926, Acc: 0.9074 | Time: 10.20s
Epoch 9/10 | Train Loss: 0.4223, Ac

Griffin Configurations:   8%|‚ñä         | 1/12 [01:44<19:09, 104.49s/it]

Test Results (Default): Acc: 0.9353, F1: 0.9339

‚ñ∂Ô∏è  Running: lr=0.0001
   ‚öôÔ∏è  Config: {'lr': 0.0001, 'betas': (0.95, 0.99), 'weight_decay': 0.0001, 'eps': 1e-08, 'beta_sigma': 0.9, 'schedule_decay': 0.005, 'warmup_steps': 100}
Starting training for 'lr=0.0001' on cuda...
Epoch 1/10 | Train Loss: 2.1620, Acc: 0.2264 | Val Loss: 1.7992, Acc: 0.4420 | Time: 10.37s
Epoch 2/10 | Train Loss: 1.6614, Acc: 0.4245 | Val Loss: 1.3902, Acc: 0.5418 | Time: 10.35s
Epoch 3/10 | Train Loss: 1.3911, Acc: 0.5223 | Val Loss: 1.1520, Acc: 0.6366 | Time: 10.35s
Epoch 4/10 | Train Loss: 1.2017, Acc: 0.5902 | Val Loss: 0.9838, Acc: 0.7104 | Time: 10.30s
Epoch 5/10 | Train Loss: 1.0845, Acc: 0.6331 | Val Loss: 0.8677, Acc: 0.7426 | Time: 10.33s
Epoch 6/10 | Train Loss: 1.0037, Acc: 0.6624 | Val Loss: 0.8148, Acc: 0.7642 | Time: 10.45s
Epoch 7/10 | Train Loss: 0.9544, Acc: 0.6790 | Val Loss: 0.7619, Acc: 0.7752 | Time: 10.34s
Epoch 8/10 | Train Loss: 0.9346, Acc: 0.6844 | Val Loss: 0.7616, Acc: 0.777

Griffin Configurations:  17%|‚ñà‚ñã        | 2/12 [03:28<17:23, 104.38s/it]

Test Results (lr=0.0001): Acc: 0.8554, F1: 0.8499

‚ñ∂Ô∏è  Running: lr=0.001
   ‚öôÔ∏è  Config: {'lr': 0.001, 'betas': (0.95, 0.99), 'weight_decay': 0.0001, 'eps': 1e-08, 'beta_sigma': 0.9, 'schedule_decay': 0.005, 'warmup_steps': 100}
Starting training for 'lr=0.001' on cuda...
Epoch 1/10 | Train Loss: 1.4642, Acc: 0.4851 | Val Loss: 0.6273, Acc: 0.8024 | Time: 10.23s
Epoch 2/10 | Train Loss: 0.6420, Acc: 0.7829 | Val Loss: 0.3957, Acc: 0.8750 | Time: 10.35s
Epoch 3/10 | Train Loss: 0.4910, Acc: 0.8363 | Val Loss: 0.3291, Acc: 0.8910 | Time: 10.28s
Epoch 4/10 | Train Loss: 0.4261, Acc: 0.8601 | Val Loss: 0.2653, Acc: 0.9212 | Time: 10.36s
Epoch 5/10 | Train Loss: 0.3761, Acc: 0.8786 | Val Loss: 0.2336, Acc: 0.9262 | Time: 10.25s
Epoch 6/10 | Train Loss: 0.3479, Acc: 0.8893 | Val Loss: 0.2128, Acc: 0.9358 | Time: 10.35s
Epoch 7/10 | Train Loss: 0.3251, Acc: 0.8979 | Val Loss: 0.2045, Acc: 0.9356 | Time: 10.21s
Epoch 8/10 | Train Loss: 0.3122, Acc: 0.9001 | Val Loss: 0.2055, Acc: 0.9384

Griffin Configurations:  25%|‚ñà‚ñà‚ñå       | 3/12 [05:12<15:37, 104.21s/it]

Test Results (lr=0.001): Acc: 0.9603, F1: 0.9596

‚ñ∂Ô∏è  Running: Œ≤‚ÇÅ=0.9,Œ≤‚ÇÇ=0.999
   ‚öôÔ∏è  Config: {'lr': 0.0005, 'betas': (0.9, 0.999), 'weight_decay': 0.0001, 'eps': 1e-08, 'beta_sigma': 0.9, 'schedule_decay': 0.005, 'warmup_steps': 100}
Starting training for 'Œ≤‚ÇÅ=0.9,Œ≤‚ÇÇ=0.999' on cuda...
Epoch 1/10 | Train Loss: 1.5620, Acc: 0.4468 | Val Loss: 0.7079, Acc: 0.7896 | Time: 10.38s
Epoch 2/10 | Train Loss: 0.7159, Acc: 0.7636 | Val Loss: 0.4617, Acc: 0.8490 | Time: 10.47s
Epoch 3/10 | Train Loss: 0.5651, Acc: 0.8147 | Val Loss: 0.3714, Acc: 0.8810 | Time: 10.21s
Epoch 4/10 | Train Loss: 0.4939, Acc: 0.8363 | Val Loss: 0.3354, Acc: 0.8918 | Time: 10.70s
Epoch 5/10 | Train Loss: 0.4529, Acc: 0.8528 | Val Loss: 0.3051, Acc: 0.9060 | Time: 10.63s
Epoch 6/10 | Train Loss: 0.4228, Acc: 0.8616 | Val Loss: 0.2886, Acc: 0.9078 | Time: 10.65s
Epoch 7/10 | Train Loss: 0.4047, Acc: 0.8691 | Val Loss: 0.2686, Acc: 0.9188 | Time: 10.39s
Epoch 8/10 | Train Loss: 0.3917, Acc: 0.8735 | Val

Griffin Configurations:  33%|‚ñà‚ñà‚ñà‚ñé      | 4/12 [06:58<13:59, 104.90s/it]

Test Results (Œ≤‚ÇÅ=0.9,Œ≤‚ÇÇ=0.999): Acc: 0.9422, F1: 0.9411

‚ñ∂Ô∏è  Running: Œ≤‚ÇÅ=0.98,Œ≤‚ÇÇ=0.999
   ‚öôÔ∏è  Config: {'lr': 0.0005, 'betas': (0.98, 0.999), 'weight_decay': 0.0001, 'eps': 1e-08, 'beta_sigma': 0.9, 'schedule_decay': 0.005, 'warmup_steps': 100}
Starting training for 'Œ≤‚ÇÅ=0.98,Œ≤‚ÇÇ=0.999' on cuda...
Epoch 1/10 | Train Loss: 1.7904, Acc: 0.3627 | Val Loss: 1.0239, Acc: 0.6748 | Time: 10.44s
Epoch 2/10 | Train Loss: 0.9145, Acc: 0.6881 | Val Loss: 0.6099, Acc: 0.8080 | Time: 10.49s
Epoch 3/10 | Train Loss: 0.6770, Acc: 0.7750 | Val Loss: 0.4529, Acc: 0.8622 | Time: 10.35s
Epoch 4/10 | Train Loss: 0.5658, Acc: 0.8129 | Val Loss: 0.3832, Acc: 0.8758 | Time: 10.61s
Epoch 5/10 | Train Loss: 0.5068, Acc: 0.8344 | Val Loss: 0.3525, Acc: 0.8850 | Time: 10.74s
Epoch 6/10 | Train Loss: 0.4747, Acc: 0.8444 | Val Loss: 0.3346, Acc: 0.8986 | Time: 10.63s
Epoch 7/10 | Train Loss: 0.4568, Acc: 0.8517 | Val Loss: 0.3116, Acc: 0.9024 | Time: 10.62s
Epoch 8/10 | Train Loss: 0.4426, A

Griffin Configurations:  42%|‚ñà‚ñà‚ñà‚ñà‚ñè     | 5/12 [08:45<12:19, 105.67s/it]

Test Results (Œ≤‚ÇÅ=0.98,Œ≤‚ÇÇ=0.999): Acc: 0.9289, F1: 0.9272

‚ñ∂Ô∏è  Running: Œ≤_œÉ=0.95
   ‚öôÔ∏è  Config: {'lr': 0.0005, 'betas': (0.95, 0.99), 'weight_decay': 0.0001, 'eps': 1e-08, 'beta_sigma': 0.95, 'schedule_decay': 0.005, 'warmup_steps': 100}
Starting training for 'Œ≤_œÉ=0.95' on cuda...
Epoch 1/10 | Train Loss: 1.7367, Acc: 0.3881 | Val Loss: 0.8898, Acc: 0.7416 | Time: 10.67s
Epoch 2/10 | Train Loss: 0.8256, Acc: 0.7253 | Val Loss: 0.5111, Acc: 0.8416 | Time: 10.52s
Epoch 3/10 | Train Loss: 0.6215, Acc: 0.7945 | Val Loss: 0.4148, Acc: 0.8678 | Time: 10.77s
Epoch 4/10 | Train Loss: 0.5351, Acc: 0.8238 | Val Loss: 0.3606, Acc: 0.8880 | Time: 10.75s
Epoch 5/10 | Train Loss: 0.4870, Acc: 0.8415 | Val Loss: 0.3216, Acc: 0.8980 | Time: 10.75s
Epoch 6/10 | Train Loss: 0.4558, Acc: 0.8505 | Val Loss: 0.3059, Acc: 0.9090 | Time: 10.66s
Epoch 7/10 | Train Loss: 0.4271, Acc: 0.8613 | Val Loss: 0.2878, Acc: 0.9094 | Time: 10.31s
Epoch 8/10 | Train Loss: 0.4186, Acc: 0.8662 | Val Loss: 

Griffin Configurations:  50%|‚ñà‚ñà‚ñà‚ñà‚ñà     | 6/12 [10:32<10:36, 106.16s/it]

Test Results (Œ≤_œÉ=0.95): Acc: 0.9372, F1: 0.9357

‚ñ∂Ô∏è  Running: wd=0.0
   ‚öôÔ∏è  Config: {'lr': 0.0005, 'betas': (0.95, 0.99), 'weight_decay': 0.0, 'eps': 1e-08, 'beta_sigma': 0.9, 'schedule_decay': 0.005, 'warmup_steps': 100}
Starting training for 'wd=0.0' on cuda...
Epoch 1/10 | Train Loss: 1.7577, Acc: 0.3758 | Val Loss: 0.9853, Acc: 0.7026 | Time: 10.26s
Epoch 2/10 | Train Loss: 0.8810, Acc: 0.7008 | Val Loss: 0.5665, Acc: 0.8208 | Time: 10.42s
Epoch 3/10 | Train Loss: 0.6524, Acc: 0.7825 | Val Loss: 0.4472, Acc: 0.8572 | Time: 10.36s
Epoch 4/10 | Train Loss: 0.5634, Acc: 0.8120 | Val Loss: 0.3830, Acc: 0.8790 | Time: 11.08s
Epoch 5/10 | Train Loss: 0.5089, Acc: 0.8311 | Val Loss: 0.3489, Acc: 0.8914 | Time: 10.72s
Epoch 6/10 | Train Loss: 0.4734, Acc: 0.8440 | Val Loss: 0.3193, Acc: 0.8978 | Time: 10.73s
Epoch 7/10 | Train Loss: 0.4508, Acc: 0.8518 | Val Loss: 0.3037, Acc: 0.9042 | Time: 10.43s
Epoch 8/10 | Train Loss: 0.4397, Acc: 0.8583 | Val Loss: 0.3007, Acc: 0.9052 | Ti

Griffin Configurations:  58%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñä    | 7/12 [12:19<08:51, 106.37s/it]

Test Results (wd=0.0): Acc: 0.9286, F1: 0.9267

‚ñ∂Ô∏è  Running: wd=0.001
   ‚öôÔ∏è  Config: {'lr': 0.0005, 'betas': (0.95, 0.99), 'weight_decay': 0.001, 'eps': 1e-08, 'beta_sigma': 0.9, 'schedule_decay': 0.005, 'warmup_steps': 100}
Starting training for 'wd=0.001' on cuda...
Epoch 1/10 | Train Loss: 1.7253, Acc: 0.3914 | Val Loss: 0.8861, Acc: 0.7396 | Time: 10.64s
Epoch 2/10 | Train Loss: 0.8282, Acc: 0.7240 | Val Loss: 0.5071, Acc: 0.8462 | Time: 10.57s
Epoch 3/10 | Train Loss: 0.6120, Acc: 0.7980 | Val Loss: 0.4028, Acc: 0.8750 | Time: 11.11s
Epoch 4/10 | Train Loss: 0.5229, Acc: 0.8271 | Val Loss: 0.3438, Acc: 0.8954 | Time: 10.70s
Epoch 5/10 | Train Loss: 0.4716, Acc: 0.8434 | Val Loss: 0.3134, Acc: 0.8976 | Time: 10.77s
Epoch 6/10 | Train Loss: 0.4381, Acc: 0.8582 | Val Loss: 0.2892, Acc: 0.9086 | Time: 10.60s
Epoch 7/10 | Train Loss: 0.4213, Acc: 0.8633 | Val Loss: 0.2806, Acc: 0.9106 | Time: 10.59s
Epoch 8/10 | Train Loss: 0.4012, Acc: 0.8711 | Val Loss: 0.2714, Acc: 0.9170 | 

Griffin Configurations:  67%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñã   | 8/12 [14:08<07:08, 107.01s/it]

Test Results (wd=0.001): Acc: 0.9391, F1: 0.9377

‚ñ∂Ô∏è  Running: schedule_decay=0.001
   ‚öôÔ∏è  Config: {'lr': 0.0005, 'betas': (0.95, 0.99), 'weight_decay': 0.0001, 'eps': 1e-08, 'beta_sigma': 0.9, 'schedule_decay': 0.001, 'warmup_steps': 100}
Starting training for 'schedule_decay=0.001' on cuda...
Epoch 1/10 | Train Loss: 1.6851, Acc: 0.4019 | Val Loss: 0.8609, Acc: 0.7238 | Time: 10.68s
Epoch 2/10 | Train Loss: 0.8253, Acc: 0.7209 | Val Loss: 0.5276, Acc: 0.8316 | Time: 10.90s
Epoch 3/10 | Train Loss: 0.6166, Acc: 0.7938 | Val Loss: 0.4050, Acc: 0.8748 | Time: 10.43s
Epoch 4/10 | Train Loss: 0.5201, Acc: 0.8258 | Val Loss: 0.3534, Acc: 0.8884 | Time: 10.52s
Epoch 5/10 | Train Loss: 0.4730, Acc: 0.8449 | Val Loss: 0.3102, Acc: 0.8970 | Time: 10.49s
Epoch 6/10 | Train Loss: 0.4397, Acc: 0.8561 | Val Loss: 0.2846, Acc: 0.9070 | Time: 10.44s
Epoch 7/10 | Train Loss: 0.4209, Acc: 0.8630 | Val Loss: 0.2746, Acc: 0.9134 | Time: 10.37s
Epoch 8/10 | Train Loss: 0.4014, Acc: 0.8695 | Val L

Griffin Configurations:  75%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå  | 9/12 [15:54<05:20, 106.84s/it]

Test Results (schedule_decay=0.001): Acc: 0.9400, F1: 0.9386

‚ñ∂Ô∏è  Running: schedule_decay=0.01
   ‚öôÔ∏è  Config: {'lr': 0.0005, 'betas': (0.95, 0.99), 'weight_decay': 0.0001, 'eps': 1e-08, 'beta_sigma': 0.9, 'schedule_decay': 0.01, 'warmup_steps': 100}
Starting training for 'schedule_decay=0.01' on cuda...
Epoch 1/10 | Train Loss: 1.6930, Acc: 0.4025 | Val Loss: 0.8515, Acc: 0.7440 | Time: 10.57s
Epoch 2/10 | Train Loss: 0.8234, Acc: 0.7263 | Val Loss: 0.5176, Acc: 0.8416 | Time: 10.45s
Epoch 3/10 | Train Loss: 0.6316, Acc: 0.7912 | Val Loss: 0.4202, Acc: 0.8644 | Time: 10.40s
Epoch 4/10 | Train Loss: 0.5403, Acc: 0.8203 | Val Loss: 0.3671, Acc: 0.8796 | Time: 10.55s
Epoch 5/10 | Train Loss: 0.4880, Acc: 0.8397 | Val Loss: 0.3205, Acc: 0.8994 | Time: 10.45s
Epoch 6/10 | Train Loss: 0.4595, Acc: 0.8485 | Val Loss: 0.3077, Acc: 0.9000 | Time: 10.42s
Epoch 7/10 | Train Loss: 0.4354, Acc: 0.8598 | Val Loss: 0.2882, Acc: 0.9088 | Time: 10.39s
Epoch 8/10 | Train Loss: 0.4240, Acc: 0.862

Griffin Configurations:  83%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñé | 10/12 [17:40<03:33, 106.59s/it]

Test Results (schedule_decay=0.01): Acc: 0.9354, F1: 0.9339

‚ñ∂Ô∏è  Running: warmup=500
   ‚öôÔ∏è  Config: {'lr': 0.0005, 'betas': (0.95, 0.99), 'weight_decay': 0.0001, 'eps': 1e-08, 'beta_sigma': 0.9, 'schedule_decay': 0.005, 'warmup_steps': 500}
Starting training for 'warmup=500' on cuda...
Epoch 1/10 | Train Loss: 2.1129, Acc: 0.2385 | Val Loss: 1.5869, Acc: 0.4916 | Time: 10.45s
Epoch 2/10 | Train Loss: 1.3174, Acc: 0.5460 | Val Loss: 0.8012, Acc: 0.7644 | Time: 10.82s
Epoch 3/10 | Train Loss: 0.7991, Acc: 0.7317 | Val Loss: 0.5089, Acc: 0.8428 | Time: 10.73s
Epoch 4/10 | Train Loss: 0.6227, Acc: 0.7903 | Val Loss: 0.4178, Acc: 0.8698 | Time: 10.91s
Epoch 5/10 | Train Loss: 0.5522, Acc: 0.8169 | Val Loss: 0.3768, Acc: 0.8820 | Time: 10.66s
Epoch 6/10 | Train Loss: 0.5023, Acc: 0.8337 | Val Loss: 0.3521, Acc: 0.8890 | Time: 10.49s
Epoch 7/10 | Train Loss: 0.4757, Acc: 0.8457 | Val Loss: 0.3312, Acc: 0.8952 | Time: 10.42s
Epoch 8/10 | Train Loss: 0.4581, Acc: 0.8499 | Val Loss: 0.32

Griffin Configurations:  92%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñè| 11/12 [19:27<01:46, 106.74s/it]

Test Results (warmup=500): Acc: 0.9263, F1: 0.9244

‚ñ∂Ô∏è  Running: HighLR+HighBeta+HighDecay
   ‚öôÔ∏è  Config: {'lr': 0.001, 'betas': (0.98, 0.999), 'weight_decay': 0.0001, 'eps': 1e-08, 'beta_sigma': 0.9, 'schedule_decay': 0.01, 'warmup_steps': 100}
Starting training for 'HighLR+HighBeta+HighDecay' on cuda...
Epoch 1/10 | Train Loss: 1.5806, Acc: 0.4371 | Val Loss: 0.6907, Acc: 0.7910 | Time: 10.58s
Epoch 2/10 | Train Loss: 0.6843, Acc: 0.7721 | Val Loss: 0.4061, Acc: 0.8670 | Time: 10.49s
Epoch 3/10 | Train Loss: 0.5154, Acc: 0.8292 | Val Loss: 0.3379, Acc: 0.8912 | Time: 10.57s
Epoch 4/10 | Train Loss: 0.4471, Acc: 0.8521 | Val Loss: 0.2849, Acc: 0.9082 | Time: 10.52s
Epoch 5/10 | Train Loss: 0.4055, Acc: 0.8671 | Val Loss: 0.2582, Acc: 0.9162 | Time: 10.45s
Epoch 6/10 | Train Loss: 0.3820, Acc: 0.8762 | Val Loss: 0.2377, Acc: 0.9254 | Time: 10.56s
Epoch 7/10 | Train Loss: 0.3641, Acc: 0.8825 | Val Loss: 0.2231, Acc: 0.9276 | Time: 10.49s
Epoch 8/10 | Train Loss: 0.3544, Acc: 0.8

Griffin Configurations: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 12/12 [21:13<00:00, 106.16s/it]

Test Results (HighLR+HighBeta+HighDecay): Acc: 0.9528, F1: 0.9518

üìàüìàüìàüìàüìàüìàüìàüìàüìàüìàüìàüìàüìàüìàüìàüìàüìàüìàüìàüìà GRIFFIN MNIST ANALYSIS üìàüìàüìàüìàüìàüìàüìàüìàüìàüìàüìàüìàüìàüìàüìàüìàüìàüìàüìàüìà



  ax3.set_xticklabels(configs, rotation=45, ha='right')



################################################################################
                    Griffin Ablation Study Results on MNIST (12 Configurations)
################################################################################

üèÜ Best Configuration: 'lr=0.001'
   - Accuracy:  0.9603
   - F1-Score:  0.9596
   - AUC:       0.9989
   - Precision: 0.9597
   - Time/Epoch: 10.39s

----------------------------------------------------------------------
LaTeX Table Summary:
----------------------------------------------------------------------
\begin{tabular}{lcccccc}
\toprule
\textbf{Configuration} & \textbf{Accuracy} & \textbf{F1-Score} & \textbf{AUC} & \textbf{Precision} & \textbf{Recall} & \textbf{Time/Epoch} \\
\midrule
lr=0.001 & \textbf{0.9603} & 0.9596 & 0.9989 & 0.9597 & 0.9596 & 10.39s \\
HighLR+HighBeta+HighDecay & 0.9528 & 0.9518 & 0.9984 & 0.9521 & 0.9516 & 10.62s \\
$\beta$‚ÇÅ=0.9,$\beta$‚ÇÇ=0.999 & 0.9422 & 0.9411 & 0.9975 & 0.9414 & 0.9409 & 10.59s \\
schedule