# Comprehensive Evaluation and Analysis of Switchable Precision GPT-2

This notebook provides a complete evaluation pipeline for:
1. Configuration evaluation across different bit-widths
2. Training strategy comparison
3. Adversarial robustness testing
4. Visualization and analysis of results

In [None]:
import sys
import os
# Put the directory two levels above the working directory on the import path
# (presumably the repository root that holds `shared` — confirm for your layout).
# The original line was missing its final closing parenthesis (SyntaxError).
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath('.'))))

import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm
import json
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')

# Set style for better visualizations
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette('husl')

# Import our modules
from shared.models import SwitchableQATGPT2, QATGPT2
from shared.dataset import create_dataloaders
from transformers import GPT2Config, GPT2Tokenizer

from evaluate_configurations import ConfigurationEvaluator
from compare_strategies import compare_training_strategies, save_comparison_results
from adversarial_attacks import AdversarialEvaluator

# Report the runtime environment so the saved cell output records which
# PyTorch build and accelerator were in use.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # Only CUDA device index 0 is queried.
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

## 1. Setup and Model Initialization

In [None]:
# Configuration
class EvaluationConfig:
    """Plain settings holder for the evaluation cells of this notebook.

    Every value is an instance attribute assigned in ``__init__``: the
    checkpoint path, the list of switchable bit-widths, the GPT-2 size
    parameters, and the batch/sequence/sample limits.
    """

    def __init__(self):
        # Checkpoint to evaluate.
        self.model_path = '../part1_switchable_precision/best_model.pth'  # Update with your model path
        # A fresh list per instance, so instances never share it.
        self.bit_widths = [4, 8, 16]
        # Transformer size: layers, embedding width, attention heads.
        self.n_layer, self.n_embd, self.n_head = 6, 768, 12
        # Data limits: batch size, token length, number of test samples.
        self.batch_size, self.max_length, self.test_samples = 4, 128, 100
        
config = EvaluationConfig()

# Initialize GPT-2 configuration
gpt2_config = GPT2Config(
    vocab_size=50257,
    n_positions=256,
    n_embd=config.n_embd,
    n_layer=config.n_layer,
    n_head=config.n_head,
    layer_norm_epsilon=1e-5,
    embd_pdrop=0.1
)

# Initialize model
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = SwitchableQATGPT2(gpt2_config, bit_widths=config.bit_widths).to(device)

# Load pretrained weights if available
if os.path.exists(config.model_path):
    # NOTE(review): torch.load unpickles the file; only point model_path at
    # checkpoints you trust.
    checkpoint = torch.load(config.model_path, map_location=device)
    # Accept both a wrapper dict holding 'model_state_dict' and a bare state
    # dict. Previously a checkpoint without that key was skipped silently while
    # the "Loaded model" message was still printed, so the evaluation ran on
    # random weights without any warning.
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint
    # Raises if the checkpoint does not match the model, instead of failing silently.
    model.load_state_dict(state_dict)
    print(f"Loaded model from {config.model_path}")
else:
    print("Using randomly initialized model")

# Initialize tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
# The 'gpt2' tokenizer is given the end-of-sequence token as its padding token.
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Create data loaders
# Assuming the split strings use slice notation, the test loader takes rows
# 200-400 of the `validation` split, directly after the 200 rows used for
# validation, so the two ranges do not overlap — confirm in shared.dataset.
train_loader, val_loader, test_loader = create_dataloaders(
    tokenizer=tokenizer,
    train_split='train[:1000]',
    val_split='validation[:200]',
    test_split='validation[200:400]',
    batch_size=config.batch_size,
    max_length=config.max_length,
    doc_stride=64  # presumably the stride between overlapping windows — confirm in shared.dataset
)

# Loader lengths are counts of batches, not of individual examples.
print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
print(f"Test batches: {len(test_loader)}")

## 2. Configuration Evaluation (Requirement 4)

In [None]:
# Score a fixed set of per-layer bit-width layouts on the test loader.
evaluator = ConfigurationEvaluator(model, test_loader)

# One bit-width per transformer layer. The hand-written patterns are cut to
# config.n_layer entries; the uniform ones are generated at that length.
test_configs = dict(
    uniform_4=[4] * config.n_layer,
    uniform_8=[8] * config.n_layer,
    uniform_16=[16] * config.n_layer,
    progressive=[4, 4, 8, 8, 16, 16][:config.n_layer],
    hourglass=[16, 8, 4, 4, 8, 16][:config.n_layer],
    edges_high=[16, 8, 8, 8, 8, 16][:config.n_layer],
    alternating=[4, 16, 4, 16, 4, 16][:config.n_layer],
    middle_low=[16, 16, 4, 4, 16, 16][:config.n_layer],
)

# Collect one result dict per layout, keyed by layout name.
results = {}
for name, layer_config in tqdm(test_configs.items(), desc="Evaluating configurations"):
    result = evaluator._evaluate_single_config(layer_config)
    results[name] = result
    print(
        f"{name}: Loss={result['loss']:.4f}, "
        f"Accuracy={result['accuracy']:.4f}, "
        f"Avg Bits={result['effective_bits']:.1f}"
    )

In [None]:
# Visualization: Accuracy vs Efficiency Trade-off
def plot_accuracy_vs_efficiency(results):
    """Scatter accuracy and loss against average bits, side by side.

    Args:
        results: Mapping of configuration name to a dict with the keys
            'accuracy', 'loss' and 'effective_bits'.

    Side effects:
        Saves 'accuracy_efficiency_tradeoff.png' and shows the figure.
    """
    _, axis_pair = plt.subplots(1, 2, figsize=(15, 6))

    labels = list(results.keys())
    accuracy_values = [entry['accuracy'] for entry in results.values()]
    loss_values = [entry['loss'] for entry in results.values()]
    bits_values = [entry['effective_bits'] for entry in results.values()]

    # Both panels share the x data and styling; only the y series, the y label
    # and the title differ.
    panels = (
        (accuracy_values, 'Accuracy', 'Accuracy vs Efficiency Trade-off'),
        (loss_values, 'Loss', 'Loss vs Efficiency Trade-off'),
    )

    for axis, (y_values, y_label, title) in zip(axis_pair, panels):
        axis.scatter(bits_values, y_values, s=100, alpha=0.6,
                     c=range(len(labels)), cmap='viridis')
        # Label every point with its configuration name.
        for label, x_val, y_val in zip(labels, bits_values, y_values):
            axis.annotate(label, (x_val, y_val),
                          xytext=(5, 5), textcoords='offset points', fontsize=8)

        axis.set_xlabel('Average Bits', fontsize=12)
        axis.set_ylabel(y_label, fontsize=12)
        axis.set_title(title, fontsize=14, fontweight='bold')
        axis.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig('accuracy_efficiency_tradeoff.png', dpi=300, bbox_inches='tight')
    plt.show()
    
plot_accuracy_vs_efficiency(results)

In [None]:
# Find Pareto optimal configurations
def find_pareto_frontier(results):
    """Return the configurations that no other configuration dominates.

    A configuration dominates another when it uses no more bits AND reaches
    at least the same accuracy, and is strictly better on at least one of the
    two. The previous check required strictly higher accuracy, so a
    configuration with the same accuracy but fewer bits never eliminated its
    more expensive twin.

    Args:
        results: Mapping of configuration name to a dict with the keys
            'effective_bits' and 'accuracy'.

    Returns:
        A tuple ``(pareto_points, pareto_names)``. ``pareto_points`` is a list
        of ``(effective_bits, accuracy)`` tuples and ``pareto_names`` holds the
        matching names, both in the iteration order of ``results``. Exact
        duplicates (same bits and accuracy) are all kept, since none of them
        is strictly better than the others.
    """
    names = list(results.keys())
    points = [(r['effective_bits'], r['accuracy']) for r in results.values()]

    def dominates(candidate, target):
        """True when candidate is no worse on both axes and better on one."""
        no_worse = candidate[0] <= target[0] and candidate[1] >= target[1]
        strictly_better = candidate[0] < target[0] or candidate[1] > target[1]
        return no_worse and strictly_better

    pareto_points = []
    pareto_names = []
    for name, point in zip(names, points):
        # A point never dominates itself (nothing is strictly better), so the
        # scan can run over all points without skipping the current index.
        if not any(dominates(other, point) for other in points):
            pareto_points.append(point)
            pareto_names.append(name)

    return pareto_points, pareto_names

pareto_points, pareto_names = find_pareto_frontier(results)

# Gather every series first so the drawing calls below read top to bottom.
all_bits = [entry['effective_bits'] for entry in results.values()]
all_accs = [entry['accuracy'] for entry in results.values()]
pareto_bits = [point[0] for point in pareto_points]
pareto_accs = [point[1] for point in pareto_points]

# Order the frontier by bit count so the connecting line runs left to right.
sorted_indices = np.argsort(pareto_bits)
pareto_bits_sorted = [pareto_bits[pos] for pos in sorted_indices]
pareto_accs_sorted = [pareto_accs[pos] for pos in sorted_indices]

# Draw all configurations, then the frontier on top (stars plus dashed line).
plt.figure(figsize=(10, 6))
plt.scatter(all_bits, all_accs, s=100, alpha=0.5, label='All configurations')
plt.scatter(pareto_bits, pareto_accs, s=200, c='red', marker='*', label='Pareto optimal', zorder=5)
plt.plot(pareto_bits_sorted, pareto_accs_sorted, 'r--', alpha=0.5)

# Only frontier points get a name label.
for name, (bits, acc) in zip(pareto_names, pareto_points):
    plt.annotate(name, (bits, acc), xytext=(5, 5), textcoords='offset points', fontweight='bold')

plt.xlabel('Average Bits', fontsize=12)
plt.ylabel('Accuracy', fontsize=12)
plt.title('Pareto Frontier Analysis', fontsize=14, fontweight='bold')
plt.legend()
plt.grid(True, alpha=0.3)
plt.savefig('pareto_frontier.png', dpi=300, bbox_inches='tight')
plt.show()

# Text summary of the same frontier.
print("\nPareto Optimal Configurations:")
for name in pareto_names:
    print(f"  {name}: Accuracy={results[name]['accuracy']:.4f}, Bits={results[name]['effective_bits']:.1f}")

## 3. Layer-wise Precision Analysis

In [None]:
# Heatmap visualization of layer configurations
def plot_layer_configurations_heatmap(test_configs, results):
    """Show the per-layer bit layouts as a heatmap next to their metrics.

    The left panel is a heatmap with one row per configuration and one column
    per layer. The right panel is a twin-axis bar chart of accuracy and
    average bits, sorted by accuracy (highest first).

    The layer count is taken from the configurations themselves. The previous
    version looped with a variable named ``config``, which shadowed the global
    settings object inside this function, so the later ``config.n_layer``
    lookups hit a list and raised AttributeError.

    Args:
        test_configs: Mapping of configuration name to a list of per-layer
            bit-widths. All lists must have the same length.
        results: Mapping of configuration name to a dict with the keys
            'accuracy' and 'effective_bits'; must cover every name in
            ``test_configs``.

    Raises:
        ValueError: If ``test_configs`` is empty or its lists do not form a
            rectangular matrix.

    Side effects:
        Saves 'layer_configurations_analysis.png' and shows the figure.
    """
    config_names = list(test_configs.keys())
    config_matrix = np.array([test_configs[name] for name in config_names])
    if config_matrix.ndim != 2:
        raise ValueError(
            "test_configs must be non-empty and every configuration must "
            "list the same number of layers"
        )
    n_configs, n_layers = config_matrix.shape

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

    # Plot 1: Configuration heatmap
    im1 = ax1.imshow(config_matrix, cmap='YlOrRd', aspect='auto', vmin=4, vmax=16)
    ax1.set_yticks(range(n_configs))
    ax1.set_yticklabels(config_names)
    ax1.set_xticks(range(n_layers))
    ax1.set_xticklabels([f'L{i}' for i in range(n_layers)])
    ax1.set_xlabel('Layer', fontsize=12)
    ax1.set_title('Layer-wise Bit Configurations', fontsize=14, fontweight='bold')

    # Write the bit-width into every cell so the plot is readable without
    # the colorbar.
    for row in range(n_configs):
        for col in range(n_layers):
            ax1.text(col, row, int(config_matrix[row, col]),
                     ha="center", va="center", color="black", fontsize=10)

    plt.colorbar(im1, ax=ax1, label='Bits')

    # Plot 2: Performance metrics
    metrics_df = pd.DataFrame({
        'Configuration': config_names,
        'Accuracy': [results[name]['accuracy'] for name in config_names],
        'Avg Bits': [results[name]['effective_bits'] for name in config_names]
    })
    metrics_df = metrics_df.sort_values('Accuracy', ascending=False)

    x = np.arange(len(metrics_df))
    width = 0.35

    # Accuracy and bit counts live on very different scales, so each gets its
    # own y-axis.
    ax2_twin = ax2.twinx()
    ax2.bar(x - width/2, metrics_df['Accuracy'], width, label='Accuracy', color='skyblue')
    ax2_twin.bar(x + width/2, metrics_df['Avg Bits'], width, label='Avg Bits', color='coral')

    ax2.set_xlabel('Configuration', fontsize=12)
    ax2.set_ylabel('Accuracy', fontsize=12, color='skyblue')
    ax2_twin.set_ylabel('Average Bits', fontsize=12, color='coral')
    ax2.set_title('Performance Comparison', fontsize=14, fontweight='bold')
    ax2.set_xticks(x)
    ax2.set_xticklabels(metrics_df['Configuration'], rotation=45, ha='right')

    ax2.tick_params(axis='y', labelcolor='skyblue')
    ax2_twin.tick_params(axis='y', labelcolor='coral')

    ax2.legend(loc='upper left')
    ax2_twin.legend(loc='upper right')

    plt.tight_layout()
    plt.savefig('layer_configurations_analysis.png', dpi=300, bbox_inches='tight')
    plt.show()

plot_layer_configurations_heatmap(test_configs, results)

## 4. Optimal Configuration Search

In [None]:
# For each average-bit budget, ask the evaluator for its best layout.
bit_budgets = [6.0, 8.0, 10.0, 12.0]
optimal_configs = {}

for budget in bit_budgets:
    print(f"\nSearching optimal configuration for {budget}-bit budget...")
    optimal_configs[budget] = optimal = evaluator.search_optimal_configuration(max_bits=budget)
    print(f"  Config: {optimal['config']}")
    print(f"  Accuracy: {optimal['accuracy']:.4f}")

# One bar chart per budget on a 2x2 grid (one panel for each of the four budgets).
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
axes = axes.ravel()

for idx, (budget, optimal) in enumerate(optimal_configs.items()):
    # A budget without a config leaves its panel empty.
    if optimal['config'] is None:
        continue

    ax = axes[idx]
    ax.bar(range(len(optimal['config'])), optimal['config'], color='steelblue')
    ax.set_xlabel('Layer', fontsize=10)
    ax.set_ylabel('Bits', fontsize=10)
    ax.set_title(
        f'Optimal Config (Budget: {budget} bits)\nAccuracy: {optimal["accuracy"]:.4f}',
        fontsize=11,
        fontweight='bold',
    )
    ax.set_ylim([0, 18])
    ax.set_xticks(range(len(optimal['config'])))
    ax.set_xticklabels([f'L{i}' for i in range(len(optimal['config']))])
    # Dashed reference line at the budget value.
    ax.axhline(y=budget, color='r', linestyle='--', alpha=0.5, label=f'Budget: {budget}')
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('optimal_configurations.png', dpi=300, bbox_inches='tight')
plt.show()

## 5. Training Strategy Comparison (Requirement 5)

In [None]:
# Load training statistics from different strategies
# This assumes you have already trained models with different strategies
import glob

training_stats_files = glob.glob('../*.json')
training_stats = {}

# NOTE(review): every JSON file in the parent directory is treated as a stats
# file, and all of them map onto just two keys ('cyclic' when the path contains
# 'cpt', otherwise 'joint'), so with several files the last one read wins.
for file in training_stats_files:
    with open(file, 'r') as f:
        stats = json.load(f)
        strategy_name = 'cyclic' if 'cpt' in file else 'joint'
        training_stats[strategy_name] = stats

if training_stats:
    print(f"Loaded training statistics for: {list(training_stats.keys())}")
else:
    print("No training statistics found. Creating synthetic data for demonstration...")

    def _synthetic_losses(scale, length):
        """Running mean of exponential noise, returned as a plain list.

        Returning a list (not a NumPy array) matches what json.load produces
        for real statistics and keeps truthiness checks such as ``if losses:``
        valid; on a multi-element array that check raises ValueError.
        """
        running_mean = np.random.exponential(scale, length).cumsum() / np.arange(1, length + 1)
        return running_mean.tolist()

    # Create synthetic data for demonstration (unseeded, so it differs per run)
    training_stats = {
        'joint': {
            'iteration_losses': _synthetic_losses(2, 100),
            'validation_losses': _synthetic_losses(2, 20),
            'bit_width_usage': np.random.choice([4, 8, 16], 100).tolist()
        },
        'cyclic': {
            'iteration_losses': _synthetic_losses(1.8, 100),
            'validation_losses': _synthetic_losses(1.8, 20),
            'bit_width_history': np.tile([4, 8, 16, 8], 25).tolist()
        },
        'curriculum': {
            'iteration_losses': _synthetic_losses(1.6, 100),
            'validation_losses': _synthetic_losses(1.6, 20),
            'bit_width_usage': np.concatenate([np.full(30, 16), np.full(40, 8), np.full(30, 4)]).tolist()
        }
    }

In [None]:
# Plot training curves comparison
def plot_training_comparison(training_stats):
    """Plot a 2x2 comparison of training strategies and save it as a PNG.

    Panels: training loss (first 100 iterations), validation loss, bit-width
    usage counts (first 100 recorded bit-widths) and a simple
    iterations-to-convergence estimate.

    Fixes over the previous version:
      * emptiness is tested with ``len(...)`` so NumPy arrays are accepted
        (``if array:`` raises ValueError for multi-element arrays);
      * the bit-width bars are grouped on the union of all observed
        bit-widths with one slot per strategy, instead of an offset that grew
        with the number of strategies (overlapping neighbouring groups) and
        hard-coded tick positions and labels.

    Args:
        training_stats: Mapping of strategy name to a stats dict. Keys read
            here are 'iteration_losses', 'validation_losses' and either
            'bit_width_usage' or 'bit_width_history'. Each is expected to be
            a sliceable sequence (list or NumPy array); missing keys are
            skipped.

    Side effects:
        Saves 'training_strategy_comparison.png' and shows the figure.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # Plot 1: Training Loss Curves
    ax = axes[0, 0]
    for strategy, stats in training_stats.items():
        losses = stats.get('iteration_losses', [])
        if len(losses) > 0:
            ax.plot(losses[:100], label=strategy.capitalize(), linewidth=2)

    ax.set_xlabel('Iteration', fontsize=11)
    ax.set_ylabel('Loss', fontsize=11)
    ax.set_title('Training Loss Comparison', fontsize=12, fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Plot 2: Validation Loss Curves
    ax = axes[0, 1]
    for strategy, stats in training_stats.items():
        val_losses = stats.get('validation_losses', [])
        if len(val_losses) > 0:
            ax.plot(val_losses, marker='o', label=strategy.capitalize(), linewidth=2)

    ax.set_xlabel('Validation Step', fontsize=11)
    ax.set_ylabel('Validation Loss', fontsize=11)
    ax.set_title('Validation Loss Comparison', fontsize=12, fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Plot 3: Bit-width Distribution
    ax = axes[1, 0]
    # Keep only strategies that recorded a bit-width sequence, truncated to
    # the same 100-step window as the loss plot.
    bit_histories = {}
    for strategy, stats in training_stats.items():
        history = stats.get('bit_width_usage', stats.get('bit_width_history', []))
        if len(history) > 0:
            bit_histories[strategy] = np.asarray(history[:100])

    if bit_histories:
        # Shared x positions: one group per bit-width seen by any strategy, so
        # a strategy that never used some width still lines up with the rest.
        bit_widths = np.unique(np.concatenate(list(bit_histories.values())))
        group_centers = np.arange(len(bit_widths))
        bar_width = 0.8 / len(bit_histories)

        for slot, (strategy, history) in enumerate(bit_histories.items()):
            counts = [int(np.sum(history == value)) for value in bit_widths]
            # Centre the cluster of bars on each group position.
            offset = (slot - (len(bit_histories) - 1) / 2) * bar_width
            ax.bar(group_centers + offset, counts, width=bar_width,
                   label=strategy.capitalize())

        ax.set_xticks(group_centers)
        ax.set_xticklabels([str(value) for value in bit_widths])

    ax.set_xlabel('Bit Width', fontsize=11)
    ax.set_ylabel('Count', fontsize=11)
    ax.set_title('Bit-width Usage Distribution', fontsize=12, fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3, axis='y')

    # Plot 4: Convergence Speed
    ax = axes[1, 1]
    convergence_data = []
    strategies_list = []

    window_size = 10
    for strategy, stats in training_stats.items():
        losses = stats.get('iteration_losses', [])
        if len(losses) > window_size:
            # Convergence point (simplified): first iteration whose preceding
            # window has a standard deviation below 0.1.
            for i in range(window_size, len(losses)):
                if np.std(losses[i - window_size:i]) < 0.1:
                    convergence_data.append(i)
                    break
            else:
                # Never stabilized: report the full run length.
                convergence_data.append(len(losses))
            strategies_list.append(strategy.capitalize())

    if convergence_data:
        colors = plt.cm.Set3(np.linspace(0, 1, len(convergence_data)))
        bars = ax.bar(strategies_list, convergence_data, color=colors)
        ax.set_ylabel('Iterations to Convergence', fontsize=11)
        ax.set_title('Convergence Speed Comparison', fontsize=12, fontweight='bold')
        ax.grid(True, alpha=0.3, axis='y')

        # Add value labels on bars
        for bar, val in zip(bars, convergence_data):
            height = bar.get_height()
            ax.text(bar.get_x() + bar.get_width()/2., height,
                    f'{int(val)}', ha='center', va='bottom')

    plt.tight_layout()
    plt.savefig('training_strategy_comparison.png', dpi=300, bbox_inches='tight')
    plt.show()

plot_training_comparison(training_stats)

## 6. Adversarial Robustness Testing (Requirement 6)

In [None]:
# Initialize adversarial evaluator
from itertools import islice

adv_evaluator = AdversarialEvaluator(model, tokenizer)

# Test adversarial robustness with different defenses
print("Testing adversarial robustness (this may take a while)...")

# Reduce samples for faster execution. islice stops after the first 10 batches
# instead of loading every batch into memory just to slice the list.
test_samples_subset = list(islice(test_loader, 10))

robustness_results = {}

# Test fixed precision baselines, one per bit-width the model was built with
# (config.bit_widths is [4, 8, 16], the same values that were hard-coded here).
for bits in config.bit_widths:
    print(f"\nTesting fixed {bits}-bit precision...")
    model.set_global_precision(bits)
    robustness_results[f'fixed_{bits}'] = adv_evaluator._evaluate_attack_success_rate(
        test_samples_subset, max_samples=5
    )

# NOTE(review): the model is left at the last fixed precision set above; this
# assumes each dynamic defense below manages precision itself — confirm in
# adversarial_attacks.

# Test dynamic defenses
print("\nTesting random switching defense...")
robustness_results['random_switch'] = adv_evaluator._evaluate_random_switching(
    test_samples_subset, max_samples=5
)

print("\nTesting ensemble defense...")
robustness_results['ensemble'] = adv_evaluator._evaluate_ensemble_defense(
    test_samples_subset, max_samples=5
)

print("\nTesting adaptive precision defense...")
robustness_results['adaptive'] = adv_evaluator._evaluate_adaptive_precision(
    test_samples_subset, max_samples=5
)

print("\nAdversarial robustness testing complete!")

In [None]:
# Visualize adversarial robustness results
def plot_adversarial_robustness(robustness_results):
    """Plot four views of attack success rates per defense and save a PNG.

    Panels: per-attack success rate per defense (grouped bars), mean success
    rate per defense, relative improvement over the 'fixed_8' entry, and a
    defense x attack heatmap.

    Changes over the previous version:
      * bars are placed at explicit numeric positions and ticks are fixed
        with ``set_xticks`` before ``set_xticklabels``, as Matplotlib asks;
      * a defense whose result is an empty dict averages to 0 instead of NaN;
      * the per-attack lookups are computed once and shared by panels 1 and 4.

    Args:
        robustness_results: Mapping of defense name to its result. A dict
            result is read with the keys 'fgsm', 'pgd' and 'hotflip' (missing
            keys count as 0); its mean is taken over ALL of its values. Any
            non-dict result is treated as 0 everywhere.

    Side effects:
        Saves 'adversarial_robustness_analysis.png' and shows the figure.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # Prepare data
    defense_methods = list(robustness_results.keys())
    attack_types = ['fgsm', 'pgd', 'hotflip']
    x = np.arange(len(defense_methods))

    def attack_rate(method, attack):
        """Success rate of one attack against one defense (0 if unavailable)."""
        entry = robustness_results[method]
        return entry.get(attack, 0) if isinstance(entry, dict) else 0

    def mean_rate(method):
        """Mean over all values reported for a defense (0 if none)."""
        entry = robustness_results[method]
        if isinstance(entry, dict) and entry:
            return np.mean(list(entry.values()))
        return 0

    # rows = defenses, columns = attacks; reshape keeps it 2-D even when empty.
    matrix = np.array(
        [[attack_rate(method, attack) for attack in attack_types]
         for method in defense_methods]
    ).reshape(len(defense_methods), len(attack_types))
    avg_success = [mean_rate(method) for method in defense_methods]

    # Plot 1: Attack Success Rate by Defense Method
    ax = axes[0, 0]
    width = 0.25
    for i, attack in enumerate(attack_types):
        ax.bar(x + i * width, matrix[:, i], width, label=attack.upper())

    ax.set_xlabel('Defense Method', fontsize=11)
    ax.set_ylabel('Attack Success Rate', fontsize=11)
    ax.set_title('Attack Success Rates by Defense Method', fontsize=12, fontweight='bold')
    # Centre the tick under the middle bar of each group of three.
    ax.set_xticks(x + width)
    ax.set_xticklabels(defense_methods, rotation=45, ha='right')
    ax.legend()
    ax.grid(True, alpha=0.3, axis='y')

    # Plot 2: Average Defense Effectiveness
    ax = axes[0, 1]
    colors = ['red' if 'fixed' in m else 'green' for m in defense_methods]
    bars = ax.bar(x, avg_success, color=colors, alpha=0.7)

    ax.set_xlabel('Defense Method', fontsize=11)
    ax.set_ylabel('Average Attack Success Rate', fontsize=11)
    ax.set_title('Overall Defense Effectiveness', fontsize=12, fontweight='bold')
    ax.set_xticks(x)
    ax.set_xticklabels(defense_methods, rotation=45, ha='right')
    ax.grid(True, alpha=0.3, axis='y')

    # Add value labels
    for bar, val in zip(bars, avg_success):
        height = bar.get_height()
        ax.text(bar.get_x() + bar.get_width()/2., height,
                f'{val:.3f}', ha='center', va='bottom')

    # Plot 3: Improvement over baseline
    ax = axes[1, 0]
    # Baseline is the fixed 8-bit mean; 0.5 is the fallback when it was not run.
    baseline = avg_success[defense_methods.index('fixed_8')] if 'fixed_8' in defense_methods else 0.5
    improvements = [(baseline - s) / baseline * 100 if baseline > 0 else 0 for s in avg_success]

    colors = ['red' if imp < 0 else 'green' for imp in improvements]
    ax.bar(x, improvements, color=colors, alpha=0.7)

    ax.set_xlabel('Defense Method', fontsize=11)
    ax.set_ylabel('Improvement over 8-bit (%)', fontsize=11)
    ax.set_title('Robustness Improvement vs 8-bit Baseline', fontsize=12, fontweight='bold')
    ax.set_xticks(x)
    ax.set_xticklabels(defense_methods, rotation=45, ha='right')
    ax.axhline(y=0, color='black', linestyle='-', linewidth=0.5)
    ax.grid(True, alpha=0.3, axis='y')

    # Plot 4: Attack-specific heatmap
    ax = axes[1, 1]
    # Reversed colormap: high success rates (weak defense) show as red.
    im = ax.imshow(matrix, cmap='RdYlGn_r', aspect='auto', vmin=0, vmax=1)

    ax.set_xticks(range(len(attack_types)))
    ax.set_yticks(range(len(defense_methods)))
    ax.set_xticklabels(attack_types)
    ax.set_yticklabels(defense_methods)
    ax.set_xlabel('Attack Type', fontsize=11)
    ax.set_ylabel('Defense Method', fontsize=11)
    ax.set_title('Defense Effectiveness Heatmap', fontsize=12, fontweight='bold')

    # Add text annotations
    for i in range(len(defense_methods)):
        for j in range(len(attack_types)):
            ax.text(j, i, f'{matrix[i, j]:.2f}',
                    ha="center", va="center", color="black")

    plt.colorbar(im, ax=ax, label='Attack Success Rate')

    plt.tight_layout()
    plt.savefig('adversarial_robustness_analysis.png', dpi=300, bbox_inches='tight')
    plt.show()

plot_adversarial_robustness(robustness_results)

## 7. Comprehensive Summary and Report Generation

In [None]:
# Generate comprehensive report
def generate_final_report(results, robustness_results, pareto_names, optimal_configs):
    """Assemble the evaluation report, write it to JSON and return it.

    The report holds a timestamp, the model settings read from the global
    ``config`` object, the inputs passed in, and a 'key_findings' section with
    the highest-accuracy configuration and, when both 'fixed_8' and
    'random_switch' results are present, the relative drop in mean attack
    success rate of random switching versus fixed 8-bit.

    Changes over the previous version:
      * values exposing ``tolist()`` (NumPy scalars/arrays, tensors) are
        written as JSON numbers/lists instead of being turned into strings by
        ``default=str``; anything else still falls back to ``str``;
      * an empty result dict averages to 0 instead of NaN;
      * the file is written with an explicit UTF-8 encoding.

    Args:
        results: Mapping of configuration name to a dict with at least
            'accuracy' and 'effective_bits'. Must not be empty.
        robustness_results: Mapping of defense name to its result (dict of
            rates, or anything else, which counts as 0).
        pareto_names: Names of the Pareto-optimal configurations.
        optimal_configs: Mapping of bit budget to the search result.

    Returns:
        The report dict (the same object that was serialized).

    Raises:
        ValueError: If ``results`` is empty (no best configuration exists).

    Side effects:
        Writes 'comprehensive_evaluation_report.json' in the working directory.
    """
    def to_jsonable(value):
        """Fallback encoder for objects json cannot serialize natively."""
        if hasattr(value, 'tolist'):
            return value.tolist()
        return str(value)

    def mean_rate(entry):
        """Mean over all values of a result dict; 0 for non-dict or empty."""
        if isinstance(entry, dict) and entry:
            return np.mean(list(entry.values()))
        return 0

    report = {
        'timestamp': datetime.now().isoformat(),
        'model_info': {
            'n_layers': config.n_layer,
            'n_embd': config.n_embd,
            'n_head': config.n_head,
            'bit_widths': config.bit_widths
        },
        'configuration_evaluation': results,
        'pareto_optimal': pareto_names,
        'optimal_under_budget': optimal_configs,
        'adversarial_robustness': robustness_results,
        'key_findings': {}
    }

    # Calculate key findings
    best_name, best_metrics = max(results.items(), key=lambda item: item[1]['accuracy'])
    report['key_findings']['best_configuration'] = {
        'name': best_name,
        'accuracy': best_metrics['accuracy'],
        'avg_bits': best_metrics['effective_bits']
    }

    # Calculate robustness improvement
    if 'fixed_8' in robustness_results and 'random_switch' in robustness_results:
        baseline = mean_rate(robustness_results['fixed_8'])
        dynamic = mean_rate(robustness_results['random_switch'])
        # Positive means random switching lowered the attack success rate.
        improvement = (baseline - dynamic) / baseline * 100 if baseline > 0 else 0
        report['key_findings']['robustness_improvement'] = f"{improvement:.1f}%"

    # Save report
    with open('comprehensive_evaluation_report.json', 'w', encoding='utf-8') as f:
        json.dump(report, f, indent=2, default=to_jsonable)

    return report

final_report = generate_final_report(results, robustness_results, pareto_names, optimal_configs)

# Shorthand for the parts of the report printed below.
key_findings = final_report['key_findings']
best_configuration = key_findings['best_configuration']

# Banner
print("\n" + "="*60)
print("COMPREHENSIVE EVALUATION SUMMARY")
print("="*60)

# When the report was built and for which model settings.
print(f"\nTimestamp: {final_report['timestamp']}")
print(f"\nModel Configuration:")
for key, value in final_report['model_info'].items():
    print(f"  {key}: {value}")

# Headline numbers.
print(f"\nKey Findings:")
print(f"  Best Configuration: {best_configuration['name']}")
print(f"    - Accuracy: {best_configuration['accuracy']:.4f}")
print(f"    - Average Bits: {best_configuration['avg_bits']:.1f}")

print(f"\n  Pareto Optimal Configurations: {', '.join(pareto_names)}")

# Present only when both the fixed_8 and random_switch results exist.
if 'robustness_improvement' in key_findings:
    print(f"\n  Dynamic Quantization Robustness Improvement: {key_findings['robustness_improvement']}")

print(f"\n  Report saved to: comprehensive_evaluation_report.json")
print("="*60)

## 8. Final Visualization Dashboard

In [None]:
# Create a comprehensive dashboard
def create_evaluation_dashboard(results, robustness_results, pareto_names):
    """Draw a one-page dashboard of the evaluation and save it as a PNG.

    Layout (3x3 grid): Pareto scatter, accuracy bars for the first 8
    configurations, layer-wise heatmap for up to the first 6 of those that
    exist in the global ``test_configs``, mean attack success rate for the
    first 5 defenses, and a text box with summary statistics.

    Changes over the previous version:
      * heatmap row labels come from the names actually plotted, so they stay
        aligned when a name is missing from ``test_configs``, and the layer
        ticks come from the matrix width rather than the global config;
      * bar charts use explicit positions with ``set_xticks`` before
        ``set_xticklabels``, as Matplotlib asks;
      * an empty defense result averages to 0 instead of NaN, and an empty
        defense list reports 'N/A' as the best defense instead of raising.

    Args:
        results: Mapping of configuration name to a dict with 'accuracy' and
            'effective_bits'. Must not be empty.
        robustness_results: Mapping of defense name to its result (dict of
            rates, or anything else, which counts as 0).
        pareto_names: Names of the Pareto-optimal configurations.

    Side effects:
        Saves 'evaluation_dashboard.png' and shows the figure.
    """
    fig = plt.figure(figsize=(20, 12))
    gs = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3)

    # 1. Pareto Frontier
    ax1 = fig.add_subplot(gs[0, 0])
    all_bits = [r['effective_bits'] for r in results.values()]
    all_accs = [r['accuracy'] for r in results.values()]
    ax1.scatter(all_bits, all_accs, s=50, alpha=0.5)
    pareto_bits = [r['effective_bits'] for name, r in results.items() if name in pareto_names]
    pareto_accs = [r['accuracy'] for name, r in results.items() if name in pareto_names]
    ax1.scatter(pareto_bits, pareto_accs, s=100, c='red', marker='*')
    ax1.set_xlabel('Average Bits')
    ax1.set_ylabel('Accuracy')
    ax1.set_title('Pareto Frontier', fontweight='bold')
    ax1.grid(True, alpha=0.3)

    # 2. Configuration Performance Bar Chart
    ax2 = fig.add_subplot(gs[0, 1:3])
    config_names = list(results.keys())[:8]
    accuracies = [results[name]['accuracy'] for name in config_names]
    # Pareto-optimal configurations are highlighted in red.
    colors = ['red' if name in pareto_names else 'steelblue' for name in config_names]
    config_positions = np.arange(len(config_names))
    ax2.bar(config_positions, accuracies, color=colors, alpha=0.7)
    ax2.set_xlabel('Configuration')
    ax2.set_ylabel('Accuracy')
    ax2.set_title('Configuration Performance Comparison', fontweight='bold')
    ax2.set_xticks(config_positions)
    ax2.set_xticklabels(config_names, rotation=45, ha='right')
    ax2.grid(True, alpha=0.3, axis='y')

    # 3. Layer Configuration Heatmap
    ax3 = fig.add_subplot(gs[1, :])
    # `test_configs` is the notebook-level dict of layer layouts.
    heatmap_names = [name for name in config_names[:6] if name in test_configs]
    if heatmap_names:
        config_matrix = np.array([test_configs[name] for name in heatmap_names])
        n_layers = config_matrix.shape[1]
        im = ax3.imshow(config_matrix, cmap='YlOrRd', aspect='auto', vmin=4, vmax=16)
        ax3.set_yticks(range(len(heatmap_names)))
        ax3.set_yticklabels(heatmap_names)
        ax3.set_xticks(range(n_layers))
        ax3.set_xticklabels([f'L{i}' for i in range(n_layers)])
        ax3.set_xlabel('Layer')
        ax3.set_title('Layer-wise Bit Configuration', fontweight='bold')
        plt.colorbar(im, ax=ax3, label='Bits', fraction=0.046, pad=0.04)

    # 4. Adversarial Robustness Summary
    ax4 = fig.add_subplot(gs[2, 0:2])
    defense_methods = list(robustness_results.keys())[:5]
    avg_success = []
    for method in defense_methods:
        entry = robustness_results[method]
        if isinstance(entry, dict) and entry:
            avg_success.append(np.mean(list(entry.values())))
        else:
            avg_success.append(0)

    colors = ['red' if 'fixed' in m else 'green' for m in defense_methods]
    defense_positions = np.arange(len(defense_methods))
    ax4.bar(defense_positions, avg_success, color=colors, alpha=0.7)
    ax4.set_xlabel('Defense Method')
    ax4.set_ylabel('Attack Success Rate')
    ax4.set_title('Adversarial Defense Effectiveness', fontweight='bold')
    ax4.set_xticks(defense_positions)
    ax4.set_xticklabels(defense_methods, rotation=45, ha='right')
    ax4.grid(True, alpha=0.3, axis='y')

    # 5. Summary Statistics
    ax5 = fig.add_subplot(gs[2, 2])
    ax5.axis('off')

    # Lowest mean attack success rate wins; the first one wins on ties.
    if defense_methods:
        best_defense = min(zip(avg_success, range(len(defense_methods))))[1]
        best_defense = defense_methods[best_defense]
    else:
        best_defense = 'N/A'

    summary_text = f"""Summary Statistics
    
    Total Configurations: {len(results)}
    Pareto Optimal: {len(pareto_names)}
    
    Best Accuracy: {max(r['accuracy'] for r in results.values()):.4f}
    Most Efficient: {min(r['effective_bits'] for r in results.values()):.1f} bits
    
    Defense Methods: {len(robustness_results)}
    Best Defense: {best_defense}
    """

    ax5.text(0.1, 0.5, summary_text, fontsize=10, verticalalignment='center',
             bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    plt.suptitle('Switchable Precision GPT-2 Evaluation Dashboard', fontsize=16, fontweight='bold', y=1.02)
    plt.savefig('evaluation_dashboard.png', dpi=300, bbox_inches='tight')
    plt.show()

create_evaluation_dashboard(results, robustness_results, pareto_names)

## Conclusion

This comprehensive evaluation has:
1. ✅ Evaluated multiple layer-wise bit configurations
2. ✅ Identified Pareto optimal configurations
3. ✅ Found optimal configurations under bit budgets
4. ✅ Compared different training strategies
5. ✅ Tested adversarial robustness with dynamic quantization
6. ✅ Generated comprehensive visualizations and reports

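When the results above support it, this evaluation shows that switchable precision with dynamic quantization provides both efficiency and robustness benefits compared to fixed-precision approaches. Check the measured numbers before citing this conclusion: the training-strategy plots fall back to synthetic data when no training statistics are found, and the robustness tests run on only a handful of samples.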