# Lab 3.2.1: Data Type Exploration - Solutions

This notebook contains solutions for all exercises in Lab 3.2.1.

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt

# Import utilities
import sys
sys.path.append('..')
from scripts import symmetric_quantize, asymmetric_quantize, dequantize

## Exercise 1 Solution: Custom Precision Analyzer

Create a function that analyzes how different data types affect a neural network layer.

In [None]:
def analyze_layer_precision(weights: np.ndarray, activations: np.ndarray, dtype: str) -> dict:
    """
    Measure how far a linear layer's output drifts when computed in a given dtype.

    The product ``activations @ weights.T`` is computed once on the inputs as
    given (float32 in this lab) and used as the reference. The same product is
    then recomputed in the requested precision and compared element-wise
    against that reference.

    Args:
        weights: Weight matrix (out_features, in_features)
        activations: Input activations (batch, in_features)
        dtype: Target dtype ('fp32', 'fp16', 'bf16', 'int8')

    Returns:
        Dictionary with the keys 'dtype', 'mean_abs_error', 'max_abs_error',
        'mean_rel_error', 'max_rel_error', 'snr_db' and 'output_range'
        (a ``(min, max)`` tuple taken from the low-precision output).

    Raises:
        ValueError: If ``dtype`` is not one of the supported names.
    """
    baseline = activations @ weights.T

    def _low_precision_output() -> np.ndarray:
        """Recompute the layer output in the precision named by ``dtype``."""
        if dtype == 'fp32':
            # Nothing to recompute: the reference itself is the output.
            return baseline

        if dtype == 'fp16':
            half_w = weights.astype(np.float16)
            half_a = activations.astype(np.float16)
            # Matmul on half-precision operands, widened afterwards for comparison.
            return (half_a @ half_w.T).astype(np.float32)

        if dtype == 'bf16':
            # The bfloat16 path goes through torch tensors, then back to float32 NumPy.
            bf_w = torch.from_numpy(weights).to(torch.bfloat16)
            bf_a = torch.from_numpy(activations).to(torch.bfloat16)
            return (bf_a @ bf_w.T).float().numpy()

        if dtype == 'int8':
            # 8-bit symmetric quantization of both operands.
            q_w, scale_w = symmetric_quantize(weights, 8)
            q_a, scale_a = symmetric_quantize(activations, 8)
            # Widen to int32 before the matmul, then rescale by the product
            # of the two scales to dequantize.
            accum = q_a.astype(np.int32) @ q_w.T.astype(np.int32)
            return accum.astype(np.float32) * (scale_a * scale_w)

        raise ValueError(f"Unknown dtype: {dtype}")

    candidate = _low_precision_output()

    # Element-wise deviation from the reference; the epsilon keeps the
    # relative error finite where the reference is exactly zero.
    deviation = np.abs(candidate - baseline)
    relative = deviation / (np.abs(baseline) + 1e-10)

    # SNR compares reference power to error power; the epsilon keeps the
    # ratio finite when the error is exactly zero (e.g. dtype == 'fp32').
    signal_power = np.mean(baseline**2)
    noise_power = np.mean(deviation**2)

    return {
        'dtype': dtype,
        'mean_abs_error': np.mean(deviation),
        'max_abs_error': np.max(deviation),
        'mean_rel_error': np.mean(relative),
        'max_rel_error': np.max(relative),
        'snr_db': 10 * np.log10(signal_power / (noise_power + 1e-10)),
        'output_range': (candidate.min(), candidate.max())
    }


# Demo: run the analyzer on a reproducible random layer
# (256 output features, 512 input features, batch of 32).
np.random.seed(42)
weights = 0.1 * np.random.randn(256, 512).astype(np.float32)
activations = np.random.randn(32, 512).astype(np.float32)

print("Layer Precision Analysis:")
print("="*60)

for dtype in ('fp32', 'fp16', 'bf16', 'int8'):
    result = analyze_layer_precision(weights, activations, dtype)
    # One report block per dtype, emitted line by line.
    report_lines = (
        f"\n{dtype.upper()}:",
        f"  Mean Abs Error: {result['mean_abs_error']:.6f}",
        f"  Max Abs Error:  {result['max_abs_error']:.6f}",
        f"  SNR:            {result['snr_db']:.1f} dB",
    )
    print(*report_lines, sep="\n")

## Exercise 2 Solution: Quantization Error Visualization

Visualize how quantization error varies across the value range.

In [None]:
def visualize_quantization_error(values: np.ndarray, bits: int):
    """
    Plot a four-panel diagnostic of symmetric quantization error and print summary stats.

    The input is quantized with ``symmetric_quantize`` at the given bit-width
    and reconstructed with ``dequantize``; the error is ``values - reconstruction``.

    Panels (2x2 grid):
        * top-left: reconstructed vs original values, with the identity line
        * top-right: histogram of the signed error
        * bottom-left: absolute error vs absolute value
        * bottom-right: relative error (%) vs absolute value on a log y-axis,
          restricted to entries with ``|value| > 0.01``

    Args:
        values: Array of values to quantize
        bits: Bit-width passed to ``symmetric_quantize``

    Returns:
        None. Shows the figure and prints the scale, mean/max absolute error
        and RMSE.
    """
    # Round-trip through the quantizer
    codes, scale = symmetric_quantize(values, bits)
    reconstructed = dequantize(codes, scale)

    residual = values - reconstructed
    abs_residual = np.abs(residual)
    abs_values = np.abs(values)
    # Epsilon keeps the ratio finite for exact zeros in the input.
    relative = abs_residual / (abs_values + 1e-10)

    fig, axes = plt.subplots(2, 2, figsize=(12, 10))
    (identity_ax, hist_ax), (magnitude_ax, relative_ax) = axes

    # Top-left: reconstruction against the original, with y = x as reference
    lo, hi = values.min(), values.max()
    identity_ax.scatter(values, reconstructed, alpha=0.3, s=1)
    identity_ax.plot([lo, hi], [lo, hi], 'r--', label='Perfect')
    identity_ax.set(
        xlabel='Original Values',
        ylabel='Quantized Values',
        title=f'{bits}-bit Quantization: Original vs Reconstructed',
    )
    identity_ax.legend()

    # Top-right: distribution of the signed error around zero
    hist_ax.hist(residual, bins=50, edgecolor='black', alpha=0.7)
    hist_ax.axvline(0, color='r', linestyle='--')
    hist_ax.set(
        xlabel='Quantization Error',
        ylabel='Count',
        title=f'Error Distribution (std={np.std(residual):.6f})',
    )

    # Bottom-left: absolute error as a function of magnitude
    magnitude_ax.scatter(abs_values, abs_residual, alpha=0.3, s=1)
    magnitude_ax.set(
        xlabel='|Original Value|',
        ylabel='|Quantization Error|',
        title='Error vs Value Magnitude',
    )

    # Bottom-right: relative error, skipping near-zero inputs whose ratio
    # would otherwise dominate the plot
    keep = abs_values > 0.01
    relative_ax.scatter(abs_values[keep], relative[keep] * 100, alpha=0.3, s=1)
    relative_ax.set(
        xlabel='|Original Value|',
        ylabel='Relative Error (%)',
        title='Relative Error vs Value Magnitude',
    )
    relative_ax.set_yscale('log')

    plt.suptitle(f'{bits}-bit Symmetric Quantization Analysis', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.show()

    # Text summary below the figure
    print(f"\n{bits}-bit Quantization Statistics:")
    print(f"  Scale: {scale:.6f}")
    print(f"  Mean Error: {np.mean(abs_residual):.6f}")
    print(f"  Max Error: {np.max(abs_residual):.6f}")
    print(f"  RMSE: {np.sqrt(np.mean(residual**2)):.6f}")


# Compare the diagnostics on two contrasting input distributions.
# A single seed covers both draws, so the normal draw must happen before the
# Student-t draw to reproduce the same samples.
np.random.seed(42)

# Gaussian samples (stand-in for typical weights)
print("Normal Distribution:")
normal_values = np.random.randn(10000).astype(np.float32)
visualize_quantization_error(normal_values, 8)

# Student-t samples with 3 degrees of freedom (heavy tails, stand-in for activations)
heavy_tail = np.random.standard_t(df=3, size=10000).astype(np.float32)
print("\nHeavy-Tailed Distribution:")
visualize_quantization_error(heavy_tail, 8)

## Exercise 3 Solution: Optimal Bit-Width Selection

Determine the minimum bit-width needed to maintain acceptable error.

In [None]:
def find_optimal_bitwidth(
    values: np.ndarray,
    max_error_threshold: float = 0.01,
    max_bits: int = 16,
    min_magnitude: float = 0.01
) -> dict:
    """
    Find minimum bits needed to keep max relative error below threshold.

    Every bit-width from 2 up to ``max_bits`` is tried: the values are
    quantized with ``symmetric_quantize``, reconstructed with ``dequantize``,
    and the relative reconstruction error is measured. The smallest bit-width
    whose maximum relative error is strictly below ``max_error_threshold`` wins.

    Relative error is evaluated only on entries with ``|value| > min_magnitude``,
    because the ratio blows up for near-zero values. If no entry clears that
    cut-off (e.g. very small weights), all non-zero entries are evaluated
    instead of failing on an empty selection. If there are no non-zero entries
    either, the relative error is reported as 0.0.

    Args:
        values: Array to quantize
        max_error_threshold: Maximum acceptable relative error
        max_bits: Maximum bits to try
        min_magnitude: Entries with absolute value at or below this are
            excluded from the relative-error statistics (default 0.01)

    Returns:
        Dictionary with optimal bit-width and analysis:
            'optimal_bits': smallest passing bit-width, or ``max_bits`` if none passes
            'threshold': the ``max_error_threshold`` that was used
            'all_results': one dict per tried bit-width with 'bits',
                'max_rel_error', 'mean_rel_error', 'compression' (32 / bits)
                and 'meets_threshold'
            'recommendation': human-readable summary string
    """
    magnitudes = np.abs(values)

    # The evaluation mask does not depend on the bit-width, so build it once.
    eval_mask = magnitudes > min_magnitude
    if not eval_mask.any():
        # Nothing clears the cut-off: fall back to every non-zero entry so the
        # reductions below never run on an empty selection.
        eval_mask = magnitudes > 0
    has_samples = bool(eval_mask.any())

    results = []

    for bits in range(2, max_bits + 1):
        quantized, scale = symmetric_quantize(values, bits)
        dequantized = dequantize(quantized, scale)

        abs_errors = np.abs(values - dequantized)
        # Relative error (avoid division by zero)
        rel_errors = abs_errors / (magnitudes + 1e-10)

        if has_samples:
            selected = rel_errors[eval_mask]
            max_rel_error = np.max(selected)
            mean_rel_error = np.mean(selected)
        else:
            # Empty or all-zero input: there is nothing to measure.
            max_rel_error = mean_rel_error = 0.0

        results.append({
            'bits': bits,
            'max_rel_error': max_rel_error,
            'mean_rel_error': mean_rel_error,
            # Compression ratio relative to 32-bit storage
            'compression': 32 / bits,
            'meets_threshold': bool(max_rel_error < max_error_threshold)
        })

    # Results are in ascending bit order, so the first passing entry is the minimum.
    optimal = next((r for r in results if r['meets_threshold']), None)

    return {
        'optimal_bits': optimal['bits'] if optimal else max_bits,
        'threshold': max_error_threshold,
        'all_results': results,
        'recommendation': f"{optimal['bits']}-bit" if optimal else f"{max_bits}-bit (threshold not met)"
    }


# Synthetic weights at a typical scale (std 0.02), seeded for reproducibility
np.random.seed(42)
weights = 0.02 * np.random.randn(10000).astype(np.float32)

# Sweep the error budget from loose (10%) to strict (0.1%)
for threshold in (0.1, 0.05, 0.01, 0.001):
    result = find_optimal_bitwidth(weights, threshold)
    bits = result['optimal_bits']

    # Recommendation, then the storage impact relative to 32-bit values
    summary_lines = (
        f"\nThreshold {threshold*100:.1f}%: {result['recommendation']}",
        f"  Compression ratio: {32/bits:.1f}x",
        f"  Memory savings: {(1 - bits/32)*100:.0f}%",
    )
    print(*summary_lines, sep="\n")

## Summary

Key findings from the exercises:

1. **FP16 vs INT8**: INT8 has higher quantization error but better throughput
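2. **Error distribution**: With symmetric quantization, the rounding error is roughly uniformly distributed (within about half a quantization step), so relative error is largest for small values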
3. **Optimal bit-width**: For 1% error threshold, typically 8 bits is sufficient for weights
4. **Heavy-tailed distributions**: Require more bits or clipping for good accuracy