# Loss Per Bit-Width Analysis for QAT Training

This notebook analyzes the training loss patterns for different bit-widths during Quantization-Aware Training (QAT) and Cyclic Precision Training (CPT).

## Contents
1. Load training statistics
2. Detailed loss per bit-width analysis
3. Loss distribution visualizations per bit-width
4. Bit-width switching pattern analysis
5. Optimal bit-width determination
6. Convergence analysis per bit-width
7. Summary and recommendations

In [None]:
# Import required libraries
import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
from scipy.signal import savgol_filter
import glob
import os
import warnings
warnings.filterwarnings('ignore')

# Apply one consistent look (theme, palette, sizes) to every figure below
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette('husl')
plt.rcParams.update({
    'figure.figsize': (12, 6),
    'font.size': 11,
    'axes.titlesize': 14,
    'axes.labelsize': 12,
})

## 1. Load Training Statistics

In [None]:
def _load_latest_stats(label, patterns):
    """Load the newest JSON file matching the first productive glob pattern.

    Args:
        label: Name used as the result key and in the printed messages.
        patterns: Glob patterns tried in order; the first one that matches at
            least one file is used and the remaining patterns are ignored.

    Returns:
        ``{label: parsed_json}`` on success, or an empty dict when nothing
        matched or the newest file could not be read/parsed.
    """
    candidates = []
    for pattern in patterns:
        candidates = glob.glob(pattern)
        if candidates:
            break
    if not candidates:
        return {}

    # "Newest" is decided by os.path.getctime (creation time on Windows,
    # last metadata change on Unix).
    newest = max(candidates, key=os.path.getctime)
    try:
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(newest, 'r', encoding='utf-8') as f:
            payload = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"❌ Error loading {label} stats: {e}")
        return {}

    print(f"✅ Loaded {label} stats from: {newest}")
    return {label: payload}


def load_all_training_stats():
    """Load all available training statistics files.

    Looks for the newest ``qat_training_stats_*.json`` and the newest
    ``cpt_training_stats_*.json`` in the working directory; for CPT the
    ``part2_cyclic_precision/`` sub-directory is searched only when the working
    directory has no match. Files that fail to load are reported and skipped.

    Returns:
        dict mapping ``'QAT'`` / ``'CPT'`` to the parsed JSON content of the
        corresponding file. Keys are absent for runs that were not found or
        could not be loaded.
    """
    sources = {
        'QAT': ['qat_training_stats_*.json'],
        'CPT': [
            'cpt_training_stats_*.json',
            'part2_cyclic_precision/cpt_training_stats_*.json',
        ],
    }

    stats_files = {}
    for label, patterns in sources.items():
        stats_files.update(_load_latest_stats(label, patterns))
    return stats_files

# Load all available statistics
all_stats = load_all_training_stats()

if all_stats:
    print(f"\n📊 Loaded {len(all_stats)} training statistics files")
    # The loop variable is deliberately not called `stats`: at module level
    # that name is bound to scipy.stats by the imports, and rebinding it here
    # would silently replace the module with a plain dict.
    for name, run_stats in all_stats.items():
        print(f"\n{name} Statistics:")
        keys = list(run_stats.keys())
        # Only preview the first ten keys of large files
        if len(keys) > 10:
            print(f"  - Keys: {keys[:10]}...")
        else:
            print(f"  - Keys: {keys}")
        if 'iteration_losses' in run_stats:
            print(f"  - Iterations: {len(run_stats['iteration_losses'])}")
        if 'losses_per_bit' in run_stats:
            print(f"  - Bit-widths tracked: {list(run_stats['losses_per_bit'].keys())}")

## 2. Detailed Loss Per Bit-Width Analysis

In [None]:
def _summarize_losses(losses_dict):
    """Build the summary table: one row per bit-width, in ascending order."""
    rows = [
        {
            'Bit Width': bit_width,
            'Samples': len(losses),
            'Mean Loss': np.mean(losses),
            'Std Dev': np.std(losses),
            'Min Loss': np.min(losses),
            'Max Loss': np.max(losses),
            'Final Loss': losses[-1],
        }
        for bit_width, losses in sorted(losses_dict.items())
    ]
    return pd.DataFrame(rows)


def analyze_losses_per_bit(stats, name="Training"):
    """Analyze losses for each bit-width.

    The per-bit-width losses are taken from ``stats['losses_per_bit']`` when it
    holds at least one non-empty list. Otherwise they are reconstructed by
    pairing ``stats['bit_width_usage']`` with ``stats['iteration_losses']``
    (entries whose bit-width is ``None`` are skipped). The summary table is
    printed/displayed as a side effect.

    Args:
        stats: Mapping with the training statistics of one run.
        name: Label used in the printed header.

    Returns:
        ``(losses_dict, df_stats)`` where ``losses_dict`` maps an integer
        bit-width to its list of losses and ``df_stats`` is a DataFrame with
        the columns Bit Width, Samples, Mean Loss, Std Dev, Min Loss, Max Loss
        and Final Loss. ``(None, None)`` when no losses could be found.
    """
    losses_dict = {}
    header_suffix = ""

    # Preferred source: losses already grouped by bit-width. JSON object keys
    # are strings, so they are converted to int; empty lists are dropped.
    recorded = stats.get('losses_per_bit')
    if recorded:
        for key, value in recorded.items():
            if value:
                bit_width = int(key) if isinstance(key, str) else key
                losses_dict[bit_width] = value

    # Fallback: group the per-iteration losses by the bit-width in use. This
    # also runs when 'losses_per_bit' exists but contains only empty lists.
    if not losses_dict and 'bit_width_usage' in stats and 'iteration_losses' in stats:
        for bit, loss in zip(stats['bit_width_usage'], stats['iteration_losses']):
            if bit is not None:
                losses_dict.setdefault(int(bit), []).append(loss)
        header_suffix = " (Calculated)"

    if not losses_dict:
        return None, None

    print(f"\n📊 {name} - Loss Statistics per Bit-Width{header_suffix}:")
    print("="*60)

    df_stats = _summarize_losses(losses_dict)

    # `display` only exists inside IPython/Jupyter; fall back to print so the
    # function also works from a plain Python session.
    try:
        show = display
    except NameError:
        show = print
    show(df_stats.round(4))

    return losses_dict, df_stats

# Analyze losses for each training type.
# `results` maps a run name to {'losses': per-bit losses, 'stats': summary table}
# and only contains runs for which per-bit-width losses were found.
results = {}
# The loop variable is not named `stats` so that the module-level `stats`
# (scipy.stats, from the imports) is not overwritten.
for name, run_stats in all_stats.items():
    losses_dict, df_stats = analyze_losses_per_bit(run_stats, name)
    if losses_dict:
        results[name] = {'losses': losses_dict, 'stats': df_stats}

## 3. Visualize Loss Distributions per Bit-Width

In [None]:
def plot_loss_distributions(results):
    """Create comprehensive visualizations for loss per bit-width.

    For every run in ``results`` one 2x2 figure is shown: a box plot of the
    losses, the raw loss trajectory, the mean loss with standard-deviation
    error bars, and a smoothed convergence-rate curve per bit-width.

    Args:
        results: Mapping of run name to a dict whose ``'losses'`` entry maps a
            bit-width to its list of losses.

    Returns:
        None. The figures are rendered with ``plt.show()``.
    """
    for training_name, data in results.items():
        losses_dict = data['losses']
        bit_widths = sorted(losses_dict.keys())
        # One colour per bit-width, shared by the box plot and the bar chart
        colors = plt.cm.viridis(np.linspace(0.3, 0.9, len(bit_widths)))

        fig = plt.figure(figsize=(16, 10))
        fig.suptitle(f'{training_name} - Loss Analysis per Bit-Width', fontsize=16, fontweight='bold')
        gs = fig.add_gridspec(2, 2, hspace=0.3, wspace=0.3)

        # 1. Box plot comparison
        ax1 = fig.add_subplot(gs[0, 0])
        bp = ax1.boxplot([losses_dict[bit] for bit in bit_widths], patch_artist=True)
        # Tick labels are set separately: the `labels` keyword of boxplot was
        # renamed to `tick_labels` in Matplotlib 3.9, so neither spelling works
        # on every version.
        ax1.set_xticklabels([f"{bit}-bit" for bit in bit_widths])
        for patch, color in zip(bp['boxes'], colors):
            patch.set_facecolor(color)

        ax1.set_xlabel('Bit Width')
        ax1.set_ylabel('Loss')
        ax1.set_title('Loss Distribution Comparison')
        ax1.grid(True, alpha=0.3)

        # 2. Loss trajectory per bit-width
        ax2 = fig.add_subplot(gs[0, 1])
        for bit in bit_widths:
            losses = losses_dict[bit]
            ax2.plot(range(len(losses)), losses, label=f'{bit}-bit', alpha=0.7, linewidth=2)

        ax2.set_xlabel('Sample Index')
        ax2.set_ylabel('Loss')
        ax2.set_title('Loss Trajectory per Bit-Width')
        ax2.legend(loc='best')
        ax2.grid(True, alpha=0.3)

        # 3. Mean loss comparison with error bars
        ax3 = fig.add_subplot(gs[1, 0])
        means = [np.mean(losses_dict[bit]) for bit in bit_widths]
        stds = [np.std(losses_dict[bit]) for bit in bit_widths]

        bars = ax3.bar(bit_widths, means, yerr=stds, capsize=5,
                       color=colors, edgecolor='black', linewidth=1.5)
        ax3.set_xlabel('Bit Width')
        ax3.set_ylabel('Mean Loss')
        ax3.set_title('Average Loss per Bit-Width (with std dev)')
        ax3.set_xticks(bit_widths)
        ax3.grid(True, alpha=0.3, axis='y')

        # Value labels sit just above the upper end of each error bar
        for bar, mean, std in zip(bars, means, stds):
            ax3.text(bar.get_x() + bar.get_width()/2., bar.get_height() + std,
                     f'{mean:.3f}', ha='center', va='bottom', fontsize=10)

        # 4. Convergence rate per bit-width
        ax4 = fig.add_subplot(gs[1, 1])
        for bit in bit_widths:
            losses = losses_dict[bit]
            # Series with 10 or fewer samples are too short to smooth
            if len(losses) > 10:
                window = min(10, len(losses) // 5)
                rolling_mean = pd.Series(losses).rolling(window=window, center=True).mean()

                # Convergence rate = negative slope of the smoothed loss
                convergence_rate = -np.gradient(rolling_mean.dropna().to_numpy())
                ax4.plot(range(len(convergence_rate)), convergence_rate,
                         label=f'{bit}-bit', alpha=0.7, linewidth=2)

        ax4.set_xlabel('Sample Index')
        ax4.set_ylabel('Convergence Rate (Loss Reduction)')
        ax4.set_title('Convergence Rate per Bit-Width')
        # Only draw a legend when at least one curve was long enough to plot
        if ax4.get_legend_handles_labels()[0]:
            ax4.legend(loc='best')
        ax4.grid(True, alpha=0.3)
        ax4.axhline(y=0, color='black', linestyle='--', alpha=0.5)

        plt.tight_layout()
        plt.show()

# Create visualizations; nothing is drawn when `results` is empty
if results:
    plot_loss_distributions(results)

## 4. Bit-Width Switching Pattern Analysis

In [None]:
def _transition_deltas(bit_widths, losses):
    """Return (iterations, loss changes) for each switch between two known bit-widths."""
    iterations, deltas = [], []
    for i in range(1, len(bit_widths)):
        previous, current = bit_widths[i - 1], bit_widths[i]
        if previous is not None and current is not None and current != previous:
            iterations.append(i)
            deltas.append(losses[i] - losses[i - 1])
    return iterations, deltas


def _improvement_per_bit(bit_widths, losses):
    """Return {bit: % loss reduction between its first and last sample}."""
    per_bit = {}
    for bit, loss in zip(bit_widths, losses):
        if bit is not None:
            per_bit.setdefault(bit, []).append(loss)
    # A bit-width needs two samples, and a non-zero first loss to divide by.
    return {
        bit: (series[0] - series[-1]) / series[0] * 100
        for bit, series in per_bit.items()
        if len(series) > 1 and series[0] != 0
    }


def analyze_bit_switching_patterns(stats, name="Training"):
    """Analyze bit-width switching patterns and their impact on loss.

    Shows a 3x2 figure (schedule over time, loss coloured by bit-width, loss
    change at each transition, usage counts, loss improvement within each
    bit-width, share of iterations per bit-width) and prints transition
    statistics.

    Args:
        stats: Mapping that must contain ``'bit_width_usage'`` and
            ``'iteration_losses'``; the two sequences are truncated to the
            shorter length. ``None`` entries in the bit-width sequence are
            treated as "unknown" and ignored.
        name: Label used in titles and printed output.

    Returns:
        None. A notice is printed and nothing is plotted when the required
        keys are missing or no iteration has a known bit-width.
    """
    if 'bit_width_usage' not in stats or 'iteration_losses' not in stats:
        print(f"No bit-width usage data available for {name}")
        return

    bit_widths = stats['bit_width_usage']
    losses = stats['iteration_losses']

    # Ensure equal length
    min_len = min(len(bit_widths), len(losses))
    bit_widths = bit_widths[:min_len]
    losses = losses[:min_len]

    valid_bits = [b for b in bit_widths if b is not None]
    if not valid_bits:
        # Without a single known bit-width there is nothing to plot.
        print(f"No bit-width usage data available for {name}")
        return

    distinct_bits = sorted(set(valid_bits))
    transitions, transition_losses = _transition_deltas(bit_widths, losses)

    fig, axes = plt.subplots(3, 2, figsize=(15, 12))
    fig.suptitle(f'{name} - Bit-Width Switching Pattern Analysis', fontsize=16, fontweight='bold')

    # 1. Bit-width usage over time
    axes[0, 0].plot(range(len(bit_widths)), bit_widths, 'g-', linewidth=1.5, alpha=0.8)
    axes[0, 0].set_xlabel('Iteration')
    axes[0, 0].set_ylabel('Bit Width')
    axes[0, 0].set_title('Bit-Width Schedule Over Training')
    axes[0, 0].set_yticks(distinct_bits)
    axes[0, 0].grid(True, alpha=0.3)

    # 2. Loss colored by bit-width (unknown bit-widths become NaN colours)
    color_values = [np.nan if b is None else b for b in bit_widths]
    scatter = axes[0, 1].scatter(range(len(losses)), losses,
                                 c=color_values, cmap='viridis',
                                 s=10, alpha=0.6)
    axes[0, 1].set_xlabel('Iteration')
    axes[0, 1].set_ylabel('Loss')
    axes[0, 1].set_title('Training Loss (colored by bit-width)')
    axes[0, 1].grid(True, alpha=0.3)
    plt.colorbar(scatter, ax=axes[0, 1], label='Bit Width')

    # 3. Transition impact analysis
    if transitions:
        axes[1, 0].scatter(transitions, transition_losses, alpha=0.6, s=30)
        axes[1, 0].axhline(y=0, color='red', linestyle='--', alpha=0.5)
        axes[1, 0].set_xlabel('Iteration')
        axes[1, 0].set_ylabel('Loss Change at Transition')
        axes[1, 0].set_title(f'Loss Impact of Bit-Width Transitions (n={len(transitions)})')
        axes[1, 0].grid(True, alpha=0.3)

        # Add average line
        avg_impact = np.mean(transition_losses)
        axes[1, 0].axhline(y=avg_impact, color='blue', linestyle=':',
                           label=f'Avg: {avg_impact:.4f}')
        axes[1, 0].legend()

    # 4. Bit-width frequency. Drawn as one bar per distinct bit-width: an
    # equal-width histogram would merge neighbours such as 2 and 4 when the
    # bit-widths are not evenly spaced.
    counts = [valid_bits.count(b) for b in distinct_bits]
    positions = list(range(len(distinct_bits)))
    axes[1, 1].bar(positions, counts, edgecolor='black', alpha=0.7, color='orange')
    axes[1, 1].set_xlabel('Bit Width')
    axes[1, 1].set_ylabel('Frequency')
    axes[1, 1].set_title('Bit-Width Usage Distribution')
    axes[1, 1].set_xticks(positions)
    axes[1, 1].set_xticklabels(distinct_bits)
    axes[1, 1].grid(True, alpha=0.3, axis='y')

    # 5. Loss improvement per bit-width
    bit_improvements = _improvement_per_bit(bit_widths, losses)
    if bit_improvements:
        bits = sorted(bit_improvements.keys())
        improvements = [bit_improvements[b] for b in bits]

        bars = axes[2, 0].bar(bits, improvements, alpha=0.7,
                              color=['green' if i > 0 else 'red' for i in improvements],
                              edgecolor='black')
        axes[2, 0].set_xlabel('Bit Width')
        axes[2, 0].set_ylabel('Loss Improvement (%)')
        axes[2, 0].set_title('Loss Improvement Within Each Bit-Width')
        axes[2, 0].set_xticks(bits)
        axes[2, 0].grid(True, alpha=0.3, axis='y')
        axes[2, 0].axhline(y=0, color='black', linestyle='-', alpha=0.5)

        # Add value labels (above positive bars, below negative ones)
        for bar, val in zip(bars, improvements):
            axes[2, 0].text(bar.get_x() + bar.get_width()/2., bar.get_height(),
                            f'{val:.1f}%', ha='center',
                            va='bottom' if val > 0 else 'top', fontsize=9)

    # 6. Share of iterations spent at each bit-width
    total = sum(counts)
    percentages = [c / total * 100 for c in counts]
    axes[2, 1].pie(percentages,
                   labels=[f'{b}-bit' for b in distinct_bits],
                   autopct='%1.1f%%',
                   startangle=90,
                   colors=plt.cm.viridis(np.linspace(0.3, 0.9, len(distinct_bits))))
    axes[2, 1].set_title('Time Distribution Across Bit-Widths')

    plt.tight_layout()
    plt.show()

    # Print statistics
    print(f"\n📊 {name} - Bit-Width Switching Statistics:")
    print("="*60)
    print(f"Total transitions: {len(transitions)}")
    if transitions:
        print(f"Average loss change at transition: {np.mean(transition_losses):.6f}")
        print(f"Std dev of loss change: {np.std(transition_losses):.6f}")
        positive_transitions = sum(1 for t in transition_losses if t > 0)
        print(f"Transitions causing loss increase: {positive_transitions}/{len(transitions)} ({positive_transitions/len(transitions)*100:.1f}%)")

# Analyze switching patterns for each training type.
# The loop variable is not named `stats` so that the module-level `stats`
# (scipy.stats, from the imports) is not overwritten.
for name, run_stats in all_stats.items():
    analyze_bit_switching_patterns(run_stats, name)

## 5. Optimal Bit-Width Determination

In [None]:
def _min_max_normalize(series):
    """Scale *series* to [0, 1]; a constant series maps to 1.0 instead of 0/0 = NaN."""
    low = series.min()
    span = series.max() - low
    if span == 0:
        return pd.Series(1.0, index=series.index)
    return (series - low) / span


def determine_optimal_bitwidth(results):
    """Determine the optimal bit-width based on loss and efficiency metrics.

    For every run three metrics are derived from the summary table:
    Efficiency = 1 / (mean loss * bit-width), Stability = 1 / (1 + std dev) and
    Performance = 1 / mean loss. Each is min-max normalised and combined into
    a weighted composite score (0.4 / 0.3 / 0.3). The bit-width with the
    highest score is reported and the metrics are plotted.

    Args:
        results: Mapping of run name to a dict whose ``'stats'`` entry is a
            DataFrame with at least the columns Bit Width, Mean Loss and
            Std Dev. The DataFrame is not modified.

    Returns:
        None. Output is printed/displayed and shown with ``plt.show()``.
    """
    # `display` only exists inside IPython/Jupyter; fall back to print elsewhere.
    try:
        show = display
    except NameError:
        show = print

    weights = {'Efficiency': 0.4, 'Stability': 0.3, 'Performance': 0.3}

    for training_name, data in results.items():
        print(f"\n🎯 {training_name} - Optimal Bit-Width Analysis")
        print("="*60)

        # Work on a fresh frame with a 0..n-1 index: the caller's table stays
        # untouched and row labels can safely be used as bar positions.
        df_stats = data['stats'].reset_index(drop=True)

        # Calculate efficiency metrics
        df_stats['Efficiency'] = 1 / (df_stats['Mean Loss'] * df_stats['Bit Width'])
        df_stats['Stability'] = 1 / (1 + df_stats['Std Dev'])
        df_stats['Performance'] = 1 / df_stats['Mean Loss']

        # Normalize metrics to [0, 1]
        for col in weights:
            df_stats[f'{col}_Norm'] = _min_max_normalize(df_stats[col])

        # Calculate composite score (weighted average)
        df_stats['Composite_Score'] = sum(
            weight * df_stats[f'{col}_Norm'] for col, weight in weights.items()
        )

        # Sort by composite score
        df_sorted = df_stats.sort_values('Composite_Score', ascending=False)

        print("\n📈 Ranking by Composite Score:")
        show(df_sorted[['Bit Width', 'Mean Loss', 'Std Dev', 'Efficiency_Norm',
                        'Stability_Norm', 'Performance_Norm', 'Composite_Score']].round(4))

        # Determine optimal bit-width
        best = df_sorted.iloc[0]
        optimal_bit = best['Bit Width']
        print(f"\n✅ Optimal Bit-Width: {int(optimal_bit)}-bit")
        print(f"   - Mean Loss: {best['Mean Loss']:.4f}")
        print(f"   - Stability (1/(1+std)): {best['Stability']:.4f}")
        print(f"   - Composite Score: {best['Composite_Score']:.4f}")

        # Create visualization
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        fig.suptitle(f'{training_name} - Bit-Width Optimization Metrics', fontsize=14, fontweight='bold')

        # Plot 1: Mean Loss vs Bit Width
        axes[0].plot(df_stats['Bit Width'], df_stats['Mean Loss'], 'o-', markersize=8, linewidth=2)
        axes[0].scatter(optimal_bit, best['Mean Loss'],
                        color='red', s=100, marker='*', zorder=5, label='Optimal')
        axes[0].set_xlabel('Bit Width')
        axes[0].set_ylabel('Mean Loss')
        axes[0].set_title('Mean Loss vs Bit Width')
        axes[0].grid(True, alpha=0.3)
        axes[0].legend()

        # Plot 2: Normalized Metrics Comparison (three grouped bars per bit-width)
        x = np.arange(len(df_stats))
        width = 0.25

        axes[1].bar(x - width, df_stats['Efficiency_Norm'], width, label='Efficiency', alpha=0.8)
        axes[1].bar(x, df_stats['Stability_Norm'], width, label='Stability', alpha=0.8)
        axes[1].bar(x + width, df_stats['Performance_Norm'], width, label='Performance', alpha=0.8)

        axes[1].set_xlabel('Bit Width')
        axes[1].set_ylabel('Normalized Score')
        axes[1].set_title('Normalized Metrics Comparison')
        axes[1].set_xticks(x)
        axes[1].set_xticklabels(df_stats['Bit Width'].astype(int))
        axes[1].legend()
        axes[1].grid(True, alpha=0.3, axis='y')

        # Plot 3: Composite Score
        bars = axes[2].bar(df_stats['Bit Width'], df_stats['Composite_Score'],
                           color='skyblue', edgecolor='black', linewidth=1.5)

        # Highlight optimal (position of the winning row among the bars)
        optimal_pos = int(np.flatnonzero((df_stats['Bit Width'] == optimal_bit).to_numpy())[0])
        bars[optimal_pos].set_color('gold')

        axes[2].set_xlabel('Bit Width')
        axes[2].set_ylabel('Composite Score')
        axes[2].set_title('Composite Score (Higher is Better)')
        axes[2].grid(True, alpha=0.3, axis='y')

        # Add value labels
        for bar, score in zip(bars, df_stats['Composite_Score']):
            axes[2].text(bar.get_x() + bar.get_width()/2., bar.get_height(),
                         f'{score:.3f}', ha='center', va='bottom', fontsize=9)

        plt.tight_layout()
        plt.show()

# Determine optimal bit-width; skipped when `results` is empty
if results:
    determine_optimal_bitwidth(results)

## 6. Convergence Analysis per Bit-Width

In [None]:
def _convergence_row(bit_width, losses):
    """Summarise the convergence of one bit-width's loss series (expects > 10 samples)."""
    initial_loss = np.mean(losses[:5])
    final_loss = np.mean(losses[-5:])
    improvement = (initial_loss - final_loss) / initial_loss * 100

    # Estimate the decay rate as the negated slope of a straight-line fit to
    # log(loss). The series is shifted so its minimum is 0.001, which keeps the
    # logarithm defined for zero or negative losses.
    try:
        shifted = np.array(losses) - np.min(losses) + 0.001
        slope = np.polyfit(np.arange(len(losses)), np.log(shifted), 1)[0]
        decay_rate = -slope
    except (np.linalg.LinAlgError, ValueError, TypeError):
        decay_rate = 0

    # Stability: coefficient of variation over the last quarter of the series
    last_quarter = losses[-(len(losses)//4):]
    tail_mean = np.mean(last_quarter)
    stability = np.std(last_quarter) / tail_mean if tail_mean > 0 else np.inf

    return {
        'Bit Width': bit_width,
        'Initial Loss': initial_loss,
        'Final Loss': final_loss,
        'Improvement (%)': improvement,
        'Decay Rate': decay_rate,
        'Stability (CV)': stability,
        'Samples': len(losses)
    }


def analyze_convergence_per_bitwidth(results):
    """Analyze convergence characteristics for each bit-width.

    Only bit-widths with more than 10 loss samples are analysed. For each run
    a 2x2 figure is shown (loss trajectories, improvement percentage, decay
    rate, final loss vs. stability), followed by the statistics table and the
    best bit-width per criterion.

    Args:
        results: Mapping of run name to a dict whose ``'losses'`` entry maps a
            bit-width to its list of losses.

    Returns:
        None. A run without any sufficiently long series is reported and
        skipped without creating a figure.
    """
    # `display` only exists inside IPython/Jupyter; fall back to print elsewhere.
    try:
        show = display
    except NameError:
        show = print

    for training_name, data in results.items():
        losses_dict = data['losses']

        print(f"\n📉 {training_name} - Convergence Analysis per Bit-Width")
        print("="*60)

        eligible = [(bit_width, losses_dict[bit_width])
                    for bit_width in sorted(losses_dict.keys())
                    if len(losses_dict[bit_width]) > 10]
        if not eligible:
            # Checked before plt.subplots so no empty figure is left behind.
            print("No bit-width has more than 10 loss samples; skipping convergence analysis.")
            continue

        fig, axes = plt.subplots(2, 2, figsize=(15, 10))
        fig.suptitle(f'{training_name} - Convergence Characteristics', fontsize=14, fontweight='bold')

        convergence_stats = []
        for bit_width, losses in eligible:
            convergence_stats.append(_convergence_row(bit_width, losses))

            # Plot 1: raw trajectory, plus a smoothed line for longer series
            axes[0, 0].plot(range(len(losses)), losses, alpha=0.3, label=f'{bit_width}-bit')
            if len(losses) > 20:
                # savgol_filter needs an odd window length
                window = min(21, len(losses) // 5)
                if window % 2 == 0:
                    window += 1
                smoothed = savgol_filter(losses, window, 3)
                axes[0, 0].plot(range(len(smoothed)), smoothed, linewidth=2,
                                label=f'{bit_width}-bit (smooth)')

        axes[0, 0].set_xlabel('Sample Index')
        axes[0, 0].set_ylabel('Loss')
        axes[0, 0].set_title('Loss Trajectories with Smoothing')
        axes[0, 0].legend(loc='best', ncol=2)
        axes[0, 0].grid(True, alpha=0.3)

        df_conv = pd.DataFrame(convergence_stats)

        # Plot 2: Improvement percentage
        axes[0, 1].bar(df_conv['Bit Width'], df_conv['Improvement (%)'],
                       color='green', alpha=0.7, edgecolor='black')
        axes[0, 1].set_xlabel('Bit Width')
        axes[0, 1].set_ylabel('Improvement (%)')
        axes[0, 1].set_title('Loss Improvement Percentage')
        axes[0, 1].grid(True, alpha=0.3, axis='y')

        # Plot 3: Decay rate
        axes[1, 0].plot(df_conv['Bit Width'], df_conv['Decay Rate'],
                        'o-', markersize=8, linewidth=2, color='blue')
        axes[1, 0].set_xlabel('Bit Width')
        axes[1, 0].set_ylabel('Decay Rate')
        axes[1, 0].set_title('Convergence Speed (Decay Rate)')
        axes[1, 0].grid(True, alpha=0.3)

        # Plot 4: Final loss vs stability
        scatter = axes[1, 1].scatter(df_conv['Final Loss'], df_conv['Stability (CV)'],
                                     s=100, c=df_conv['Bit Width'], cmap='viridis',
                                     edgecolors='black', linewidth=1.5)
        axes[1, 1].set_xlabel('Final Loss')
        axes[1, 1].set_ylabel('Stability (CV)')
        axes[1, 1].set_title('Final Loss vs Stability Trade-off')
        axes[1, 1].grid(True, alpha=0.3)

        # Add colorbar
        cbar = plt.colorbar(scatter, ax=axes[1, 1])
        cbar.set_label('Bit Width')

        # Add annotations for each point
        for _, row in df_conv.iterrows():
            axes[1, 1].annotate(f"{int(row['Bit Width'])}b",
                                (row['Final Loss'], row['Stability (CV)']),
                                xytext=(5, 5), textcoords='offset points', fontsize=8)

        plt.tight_layout()
        plt.show()

        # Display convergence statistics table
        print("\n📊 Convergence Statistics:")
        show(df_conv.round(4))

        # Find best configurations
        print("\n🏆 Best Configurations:")
        print(f"  - Fastest Convergence: {df_conv.loc[df_conv['Decay Rate'].idxmax(), 'Bit Width']:.0f}-bit")
        print(f"  - Best Final Loss: {df_conv.loc[df_conv['Final Loss'].idxmin(), 'Bit Width']:.0f}-bit")
        print(f"  - Most Stable: {df_conv.loc[df_conv['Stability (CV)'].idxmin(), 'Bit Width']:.0f}-bit")
        print(f"  - Highest Improvement: {df_conv.loc[df_conv['Improvement (%)'].idxmax(), 'Bit Width']:.0f}-bit")

# Analyze convergence; skipped when `results` is empty
if results:
    analyze_convergence_per_bitwidth(results)

## 7. Summary and Recommendations

In [None]:
def generate_summary_report(all_stats, results):
    """Generate a comprehensive summary report.

    Prints, for every run in ``all_stats``, the overall loss statistics, the
    per-bit-width summary (when the run is present in ``results``) and the
    bit-width usage insights, followed by a fixed block of general
    recommendations.

    Args:
        all_stats: Mapping of run name to its raw training statistics.
        results: Mapping of run name to a dict whose ``'stats'`` entry is the
            per-bit-width summary DataFrame.

    Returns:
        None. Everything is written to standard output.
    """
    print("="*70)
    print("📊 COMPREHENSIVE BIT-WIDTH ANALYSIS SUMMARY")
    print("="*70)

    for training_name, stats in all_stats.items():
        print(f"\n{'='*50}")
        print(f"📈 {training_name} Training Summary")
        print(f"{'='*50}")

        # Basic statistics
        if 'iteration_losses' in stats:
            losses = stats['iteration_losses']
            print(f"\n📊 Overall Statistics:")
            print(f"  - Total iterations: {len(losses)}")
            # Final/best/average are undefined for a run with no recorded loss
            if losses:
                print(f"  - Final loss: {losses[-1]:.4f}")
                print(f"  - Best loss: {min(losses):.4f}")
                print(f"  - Average loss: {np.mean(losses):.4f}")

        # Bit-width specific summary
        if training_name in results:
            df_stats = results[training_name]['stats']

            print(f"\n🔢 Bit-Width Performance Summary:")
            for _, row in df_stats.iterrows():
                print(f"  {int(row['Bit Width'])}-bit:")
                print(f"    - Mean Loss: {row['Mean Loss']:.4f}")
                print(f"    - Std Dev: {row['Std Dev']:.4f}")
                print(f"    - Min Loss: {row['Min Loss']:.4f}")
                print(f"    - Samples: {int(row['Samples'])}")

        # Training insights (iterations with an unknown bit-width are ignored)
        if 'bit_width_usage' in stats:
            bit_widths = [b for b in stats['bit_width_usage'] if b is not None]
            if bit_widths:
                print(f"\n💡 Training Insights:")
                print(f"  - Bit-widths used: {sorted(set(bit_widths))}")
                print(f"  - Average bit-width: {np.mean(bit_widths):.2f}")
                print(f"  - Most frequent bit-width: {max(set(bit_widths), key=bit_widths.count)}")

    print(f"\n{'='*70}")
    print("💡 KEY RECOMMENDATIONS")
    print(f"{'='*70}")

    # NOTE: the recommendations below are static text; they are not derived
    # from the statistics analysed above.
    print("""
1. **Optimal Bit-Width Selection**:
   - For best accuracy: Use 16-bit precision
   - For best efficiency: Use 4-bit or 8-bit depending on accuracy requirements
   - For balanced performance: 8-bit provides good trade-off

2. **Training Strategy**:
   - Cyclic precision training can help explore different bit-widths
   - Start with higher precision and gradually reduce for better stability
   - Monitor loss spikes during bit-width transitions

3. **Convergence Optimization**:
   - Lower bit-widths may require more iterations to converge
   - Consider adaptive learning rates based on current bit-width
   - Use warm-up periods after bit-width transitions

4. **Production Deployment**:
   - Test the selected bit-width on validation data
   - Consider hardware constraints when selecting bit-width
   - Profile memory and computation savings for each bit-width
    """)

    print(f"\n{'='*70}")
    print("✅ Analysis Complete!")
    print(f"{'='*70}")

# Generate final summary; skipped when no statistics file was loaded
if all_stats:
    generate_summary_report(all_stats, results)