# Location: notebooks/06_special_topics/19_research_tricks.ipynb

## Advanced Research Tricks and Techniques in JAX

This notebook covers advanced techniques commonly used in machine learning research with JAX: gradient manipulation, memory optimization, numerical-stability tricks, debugging and profiling strategies, and research-oriented optimization tricks.

## Advanced Gradient Manipulation

```python
import jax
import jax.numpy as jnp
from jax import grad, vmap, jit, random
from functools import partial
import time

# Gradient clipping and modification
def create_gradient_manipulation_tools():
    """Build a set of gradient post-processing helpers.

    Returns:
        A 4-tuple ``(clip_gradients, gradient_centralization,
        gradient_noise_injection, gradient_standardization)``. Each helper
        accepts a pytree of gradient arrays and returns a pytree with the
        same structure.
    """
    # Use the stable ``jax.tree_util`` entry points; the top-level
    # ``jax.tree_map`` / ``jax.tree_leaves`` aliases were deprecated and
    # later removed from JAX.
    tree_map = jax.tree_util.tree_map
    tree_leaves = jax.tree_util.tree_leaves

    def clip_gradients(grads, max_norm=1.0):
        """Rescale ``grads`` so that their global L2 norm is at most ``max_norm``.

        Args:
            grads: Pytree of gradient arrays.
            max_norm: Upper bound for the global norm.

        Returns:
            ``(clipped_grads, global_norm)`` where ``global_norm`` is the
            norm measured before clipping.
        """
        global_norm = jnp.sqrt(sum(jnp.sum(g**2) for g in tree_leaves(grads)))

        # A factor of 1.0 leaves gradients untouched; the epsilon guards the
        # division when every gradient is zero.
        clip_factor = jnp.minimum(1.0, max_norm / (global_norm + 1e-8))
        clipped_grads = tree_map(lambda g: g * clip_factor, grads)

        return clipped_grads, global_norm

    def gradient_centralization(grads):
        """Subtract the per-slice mean from every gradient of rank >= 2.

        The mean is taken over all axes except axis 0. Arrays of rank 0 or 1
        (e.g. bias vectors) are returned unchanged.
        """
        def centralize_grad(g):
            if g.ndim < 2:
                return g
            reduce_axes = tuple(range(1, g.ndim))
            return g - jnp.mean(g, axis=reduce_axes, keepdims=True)

        return tree_map(centralize_grad, grads)

    def gradient_noise_injection(grads, key, noise_scale=1e-3):
        """Add independent Gaussian noise to every gradient leaf.

        Args:
            grads: Pytree of gradient arrays.
            key: JAX PRNG key; it is split once into one sub-key per leaf.
            noise_scale: Standard deviation of the added noise.

        Returns:
            Pytree with the same structure as ``grads``.
        """
        leaves, treedef = jax.tree_util.tree_flatten(grads)
        if not leaves:
            return grads

        # One distinct sub-key per leaf. Deriving the key from the leaf size
        # would hand identical keys (hence identical noise) to equally sized
        # leaves.
        leaf_keys = random.split(key, len(leaves))
        noisy_leaves = [
            g + random.normal(k, g.shape) * noise_scale
            for g, k in zip(leaves, leaf_keys)
        ]
        return jax.tree_util.tree_unflatten(treedef, noisy_leaves)

    def gradient_standardization(grads, momentum=0.9, eps=1e-8):
        """Standardize every gradient leaf to zero mean and unit variance.

        Simplified demonstration: statistics are computed from the current
        leaf only. ``momentum`` is accepted for interface compatibility but
        is not used, because no running statistics are kept.
        """
        def standardize_grad(g):
            mean = jnp.mean(g)
            var = jnp.var(g)
            return (g - mean) / jnp.sqrt(var + eps)

        return tree_map(standardize_grad, grads)

    return clip_gradients, gradient_centralization, gradient_noise_injection, gradient_standardization

# Test gradient manipulation
clip_grads, center_grads, noise_grads, std_grads = create_gradient_manipulation_tools()

# Create test gradients
test_grads = {
    'w1': jnp.array([[1.5, -2.0, 3.0], [0.5, -1.0, 2.5]]),
    'b1': jnp.array([0.1, -0.3, 0.8]),
    'w2': jnp.array([[10.0, -15.0], [5.0, -8.0], [12.0, -20.0]]),
}

print("Original gradients:")
# The loop variable is deliberately not called ``grad``: that name is the
# ``jax.grad`` function imported at the top of the notebook.
for name, g in test_grads.items():
    print(f"  {name}: norm={jnp.linalg.norm(g):.3f}")

# Test gradient clipping
clipped_grads, global_norm = clip_grads(test_grads, max_norm=5.0)
print("\nGradient clipping (max_norm=5.0):")
print(f"  Global norm before: {global_norm:.3f}")
clipped_norm = jnp.sqrt(
    sum(jnp.sum(g**2) for g in jax.tree_util.tree_leaves(clipped_grads))
)
print(f"  Global norm after: {clipped_norm:.3f}")

# Test gradient centralization
centered_grads = center_grads(test_grads)
print("\nGradient centralization:")
# Look results up by key instead of zipping ``.values()``: a pytree
# round-trip is not guaranteed to keep the input dict's insertion order.
for name, orig in test_grads.items():
    cent = centered_grads[name]
    if orig.ndim >= 2:
        print(f"  {name}: mean before={jnp.mean(orig):.3f}, after={jnp.mean(cent):.6f}")

# Test gradient noise injection
noisy_grads = noise_grads(test_grads, random.PRNGKey(42), noise_scale=0.1)
print("\nGradient noise injection:")
for name, orig in test_grads.items():
    diff = jnp.linalg.norm(noisy_grads[name] - orig)
    print(f"  {name}: noise magnitude={diff:.3f}")
```

## Memory Optimization Techniques

```python
def create_memory_optimization_tools():
    """Build helpers that trade compute or precision for memory.

    Returns:
        A 4-tuple ``(gradient_checkpointing,
        activation_checkpointing_transformer_layer, mixed_precision_wrapper,
        reversible_layer)``.
    """

    def gradient_checkpointing(fn, *args):
        """Wrap ``fn`` so its backward pass recomputes the forward pass.

        Only the inputs are saved as residuals; the VJP of ``fn`` is rebuilt
        from them when gradients are requested.

        Args:
            fn: Function of positional array arguments.
            *args: Unused; kept so existing call sites keep working.

        Returns:
            A ``jax.custom_vjp`` function with the same call signature as
            ``fn``.
        """
        @jax.custom_vjp
        def checkpointed_fn(*inputs):
            return fn(*inputs)

        def checkpointed_fwd(*inputs):
            # Keep only the inputs; intermediates of ``fn`` are not stored
            # as residuals.
            return fn(*inputs), inputs

        def checkpointed_bwd(saved_inputs, output_grad):
            # Re-run the forward pass to obtain the VJP, then pull the
            # cotangent back to every input.
            _, vjp_fn = jax.vjp(fn, *saved_inputs)
            return vjp_fn(output_grad)

        checkpointed_fn.defvjp(checkpointed_fwd, checkpointed_bwd)
        return checkpointed_fn

    def activation_checkpointing_transformer_layer(params, x, layer_fn):
        """Evaluate ``layer_fn(params, x)`` through ``gradient_checkpointing``."""

        @gradient_checkpointing
        def checkpointed_layer(layer_params, layer_input):
            return layer_fn(layer_params, layer_input)

        return checkpointed_layer(params, x)

    def mixed_precision_wrapper(fn):
        """Wrap ``fn`` so float32 inputs are cast to float16 and back.

        Every float32 leaf of the positional arguments is cast to float16
        before ``fn`` runs, and every float16 leaf of the output is cast back
        to float32. Leaves of other dtypes, and leaves without a ``dtype``
        attribute (e.g. Python scalars), pass through unchanged.
        """
        def _recast(tree, source, target):
            def cast_leaf(leaf):
                # getattr guard: Python scalars are valid pytree leaves but
                # have no ``dtype``.
                if getattr(leaf, 'dtype', None) == source:
                    return leaf.astype(target)
                return leaf

            return jax.tree_util.tree_map(cast_leaf, tree)

        def mixed_precision_fn(*args):
            fp16_args = _recast(args, jnp.float32, jnp.float16)
            output = fn(*fp16_args)
            return _recast(output, jnp.float16, jnp.float32)

        return mixed_precision_fn

    def reversible_layer(f, g):
        """Create the forward and inverse maps of a RevNet-style coupling.

        Args:
            f: Function applied to the second stream.
            g: Function applied to the updated first stream.

        Returns:
            ``(reversible_forward, reversible_backward)``. The backward map
            reconstructs ``(x1, x2)`` from ``(y1, y2)`` without stored
            activations.
        """
        def reversible_forward(x1, x2):
            # y1 = x1 + f(x2), y2 = x2 + g(y1)
            y1 = x1 + f(x2)
            y2 = x2 + g(y1)
            return y1, y2

        def reversible_backward(y1, y2):
            # Undo the coupling in reverse order:
            # x2 = y2 - g(y1), x1 = y1 - f(x2)
            x2 = y2 - g(y1)
            x1 = y1 - f(x2)
            return x1, x2

        return reversible_forward, reversible_backward

    return gradient_checkpointing, activation_checkpointing_transformer_layer, mixed_precision_wrapper, reversible_layer

# Instantiate the memory-optimization helpers and bind each to a short alias
(
    grad_checkpoint,
    checkpoint_transformer,
    mixed_precision,
    reversible,
) = create_memory_optimization_tools()

# Example workload: a matrix product whose activation is a large temporary
def large_matmul(A, B):
    """Return the sum of squared ``tanh`` activations of ``A @ B``.

    The product and its ``tanh`` are full-size intermediates, which makes
    this a convenient target for the checkpointing demo.
    """
    activations = jnp.tanh(jnp.matmul(A, B))
    squared = activations ** 2
    return jnp.sum(squared)

# Create large test matrices
key = random.PRNGKey(123)
# Split once and give each matrix its own sub-key; a PRNG key should not be
# consumed both directly and as the parent of a split.
key_a, key_b = random.split(key)
A = random.normal(key_a, (1000, 800))
B = random.normal(key_b, (800, 600))

# Compare memory usage (conceptually)
regular_fn = large_matmul
checkpointed_fn = grad_checkpoint(large_matmul)

print("Memory Optimization Example:")

# Regular computation (jax.grad differentiates w.r.t. the first argument, A)
result1 = regular_fn(A, B)
grad_fn1 = jax.grad(regular_fn)
grads1 = grad_fn1(A, B)

print(f"Regular computation result: {result1:.6f}")
print(f"Regular gradient norm: {jnp.linalg.norm(grads1):.6f}")

# Checkpointed computation
result2 = checkpointed_fn(A, B)
grad_fn2 = jax.grad(checkpointed_fn)
grads2 = grad_fn2(A, B)

print(f"Checkpointed result: {result2:.6f}")
print(f"Checkpointed gradient norm: {jnp.linalg.norm(grads2):.6f}")
print(f"Results match: {jnp.allclose(result1, result2)}")
print(f"Gradients match: {jnp.allclose(grads1, grads2)}")

# Test mixed precision. The factory's wrapper was unpacked under the alias
# ``mixed_precision``; ``mixed_precision_wrapper`` only exists inside the
# factory.
mp_fn = mixed_precision(large_matmul)
result3 = mp_fn(A, B)
# NOTE(review): float16 tops out at 65504 and this sum has 600,000 terms of
# at most 1.0 each, so an ``inf`` here is likely overflow rather than a bug.
print(f"Mixed precision result: {result3:.6f}")
```

## Numerical Stability Tricks

```python
def create_numerical_stability_tools():
    """Assemble numerically careful building blocks for neural networks.

    Returns:
        A 7-tuple ``(stable_softmax, log_sum_exp, stable_log_softmax,
        gelu_stable, swish_stable, layer_norm_stable,
        attention_weights_stable)``.
    """

    def stable_softmax(logits, axis=-1):
        """Softmax along ``axis``, computed on max-shifted logits."""
        peak = jnp.max(logits, axis=axis, keepdims=True)
        exps = jnp.exp(logits - peak)  # largest exponent is exp(0) == 1
        normalizer = jnp.sum(exps, axis=axis, keepdims=True)
        return exps / normalizer

    def log_sum_exp(logits, axis=-1):
        """log(sum(exp(logits))) along ``axis``; the reduced axis is kept."""
        peak = jnp.max(logits, axis=axis, keepdims=True)
        shifted_sum = jnp.sum(jnp.exp(logits - peak), axis=axis, keepdims=True)
        return peak + jnp.log(shifted_sum)

    def stable_log_softmax(logits, axis=-1):
        """Log-softmax along ``axis`` expressed through ``log_sum_exp``."""
        log_normalizer = log_sum_exp(logits, axis=axis)
        return logits - log_normalizer

    def gelu_stable(x):
        """Tanh approximation of GELU.

        GELU(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
        """
        inner = x + 0.044715 * x**3
        squashed = jax.nn.tanh(jnp.sqrt(2.0 / jnp.pi) * inner)
        return 0.5 * x * (1.0 + squashed)

    def swish_stable(x, beta=1.0):
        """Swish activation: ``x * sigmoid(beta * x)``."""
        gate = jax.nn.sigmoid(beta * x)
        return x * gate

    def layer_norm_stable(x, gamma, beta, eps=1e-6):
        """Layer normalization over the last axis with scale and shift."""
        mean = jnp.mean(x, axis=-1, keepdims=True)
        var = jnp.var(x, axis=-1, keepdims=True)

        # ``eps`` keeps the denominator away from zero for constant rows.
        normalized = (x - mean) / jnp.sqrt(var + eps)
        return gamma * normalized + beta

    def attention_weights_stable(query, key, scale=None):
        """Scaled dot-product attention weights.

        Args:
            query: Array whose last two axes are ``(q, d)``.
            key: Array whose last two axes are ``(k, d)``.
            scale: Multiplier for the raw scores; defaults to
                ``1 / sqrt(d)`` with ``d = query.shape[-1]``.

        Returns:
            ``(weights, scores)``: the softmax over the key axis and the
            scaled scores it was computed from.
        """
        if scale is None:
            scale = 1.0 / jnp.sqrt(query.shape[-1])

        scores = jnp.einsum('...qd,...kd->...qk', query, key) * scale
        return stable_softmax(scores, axis=-1), scores

    return (stable_softmax, log_sum_exp, stable_log_softmax,
            gelu_stable, swish_stable, layer_norm_stable, attention_weights_stable)

# Instantiate the numerical-stability helpers and bind each to a module-level name
(
    stable_softmax,
    log_sum_exp,
    stable_log_softmax,
    gelu_stable,
    swish_stable,
    layer_norm_stable,
    attention_stable,
) = create_numerical_stability_tools()

# Test with extreme values
extreme_logits = jnp.array([100.0, 101.0, 99.0, 102.0])
print("Numerical Stability Tests:")


# Compare regular vs stable softmax
def regular_softmax(x):
    """Naive softmax: exponentiates the raw logits without a max-shift."""
    return jnp.exp(x) / jnp.sum(jnp.exp(x))


# By default JAX does not raise on floating-point overflow: exp() saturates
# to inf and inf / inf becomes nan. Detect that explicitly instead of wrapping
# the call in a try/except that can never fire.
regular_result = regular_softmax(extreme_logits)
print(f"Regular softmax: {regular_result}")
if not bool(jnp.all(jnp.isfinite(regular_result))):
    print("Regular softmax: OVERFLOW ERROR")

stable_result = stable_softmax(extreme_logits)
print(f"Stable softmax: {stable_result}")
print(f"Stable softmax sums to: {jnp.sum(stable_result)}")

# Test log-sum-exp
lse_result = log_sum_exp(extreme_logits)
print(f"Log-sum-exp: {lse_result}")

# Test stable activations with extreme inputs
extreme_inputs = jnp.array([-50.0, -10.0, 0.0, 10.0, 50.0])

gelu_results = gelu_stable(extreme_inputs)
swish_results = swish_stable(extreme_inputs)

print("\nActivation functions on extreme inputs:")
print(f"GELU: {gelu_results}")
print(f"Swish: {swish_results}")

# Test stable layer norm on rows of very different magnitude
layer_input = jnp.array([[1e6, 1e6 + 1, 1e6 + 2], [1e-6, 2e-6, 3e-6]])
gamma = jnp.ones(3)
beta = jnp.zeros(3)

ln_result = layer_norm_stable(layer_input, gamma, beta)
print("\nLayer norm on extreme scale inputs:")
print(f"Input scale: {jnp.std(layer_input, axis=-1)}")
print(f"Output scale: {jnp.std(ln_result, axis=-1)}")
print(f"Output mean: {jnp.mean(ln_result, axis=-1)}")
```

## Advanced Debugging and Profiling

```python
def create_debugging_tools():
    """Build debugging and profiling helpers for JAX code.

    Returns:
        A 5-tuple ``(debug_callback, gradient_debug_wrapper,
        profile_function, memory_profiler, nan_detector)``.
    """

    def debug_callback(name, x, print_shape=True, print_stats=True):
        """Print information about ``x`` at runtime and return ``x`` unchanged.

        Args:
            name: Label printed in the header line.
            x: Array to inspect.
            print_shape: Print shape and dtype when true.
            print_stats: Print min/max/mean/std and NaN/Inf flags when true.

        Returns:
            ``x`` itself, so the call can be chained as
            ``x = debug_callback("label", x)``.
        """
        def _report(value):
            # Summarize with NumPy on the host so the callback does not
            # dispatch new JAX work while it is running.
            import numpy as np

            host = np.asarray(value)
            print(f"\n=== DEBUG: {name} ===")
            if print_shape:
                print(f"Shape: {host.shape}")
                print(f"Dtype: {host.dtype}")
            if print_stats:
                print(f"Min: {np.min(host):.6f}")
                print(f"Max: {np.max(host):.6f}")
                print(f"Mean: {np.mean(host):.6f}")
                print(f"Std: {np.std(host):.6f}")
                print(f"Has NaN: {np.any(np.isnan(host))}")
                print(f"Has Inf: {np.any(np.isinf(host))}")

        # jax.debug.callback is side-effect only and returns None, so the
        # tensor has to be returned explicitly.
        jax.debug.callback(_report, x)
        return x

    def gradient_debug_wrapper(fn):
        """Wrap ``fn`` so every call also prints a gradient report.

        The wrapped function returns the same value as ``fn``. Gradients are
        taken of ``sum(fn(...))`` (or of the raw output when it has no
        ``shape``) with respect to the first positional argument, which is
        ``jax.grad``'s default.
        """
        def debug_grad_fn(*args, **kwargs):
            result = fn(*args, **kwargs)

            def loss_fn(*inner_args):
                out = fn(*inner_args, **kwargs)
                return jnp.sum(out) if hasattr(out, 'shape') else out

            grads = jax.grad(loss_fn)(*args)

            print("\n=== GRADIENT DEBUG ===")
            flat_grads = jax.tree_util.tree_leaves(grads)
            grad_norms = [jnp.linalg.norm(g) for g in flat_grads]

            print(f"Number of gradient tensors: {len(flat_grads)}")
            print(f"Gradient norms: {grad_norms}")
            print(f"Total gradient norm: {jnp.sqrt(sum(norm**2 for norm in grad_norms))}")

            # Flag gradients that usually indicate a training problem.
            for i, (g, norm) in enumerate(zip(flat_grads, grad_norms)):
                if jnp.any(jnp.isnan(g)):
                    print(f"  WARNING: Gradient {i} contains NaN!")
                if jnp.any(jnp.isinf(g)):
                    print(f"  WARNING: Gradient {i} contains Inf!")
                if norm > 100:
                    print(f"  WARNING: Large gradient norm in tensor {i}: {norm}")

            return result

        return debug_grad_fn

    def profile_function(fn, *args, n_runs=10, warmup=3):
        """Measure the wall-clock execution time of ``fn(*args)``.

        Args:
            fn: Callable to time.
            *args: Positional arguments forwarded to ``fn``.
            n_runs: Number of timed runs.
            warmup: Number of untimed runs executed first (covers
                compilation and caches).

        Returns:
            Dict with ``mean_time``, ``std_time``, ``min_time``,
            ``max_time`` and ``all_times`` (seconds).
        """
        # JAX dispatch is asynchronous: block on the result so warmup work
        # does not leak into the first timed run and each timing covers the
        # actual computation. jax.block_until_ready accepts arbitrary
        # pytrees and ignores non-array leaves.
        for _ in range(warmup):
            jax.block_until_ready(fn(*args))

        times = []
        for _ in range(n_runs):
            # perf_counter is monotonic and has higher resolution than
            # time.time for short intervals.
            start_time = time.perf_counter()
            jax.block_until_ready(fn(*args))
            times.append(time.perf_counter() - start_time)

        times = jnp.array(times)
        return {
            'mean_time': jnp.mean(times),
            'std_time': jnp.std(times),
            'min_time': jnp.min(times),
            'max_time': jnp.max(times),
            'all_times': times
        }

    def memory_profiler():
        """Report the current process's memory usage in MiB.

        Requires the third-party ``psutil`` package. Only process (host)
        memory is reported.

        Returns:
            Dict with ``rss_mb`` (resident set size) and ``vms_mb``
            (virtual memory size).
        """
        import os
        import psutil

        memory_info = psutil.Process(os.getpid()).memory_info()

        return {
            'rss_mb': memory_info.rss / 1024 / 1024,  # Resident set size
            'vms_mb': memory_info.vms / 1024 / 1024,  # Virtual memory size
        }

    def nan_detector(fn):
        """Wrap ``fn`` so a NaN or Inf anywhere in its output raises.

        Raises:
            ValueError: If any array leaf of the result contains NaN or Inf.
        """
        def nan_detect_fn(*args, **kwargs):
            result = fn(*args, **kwargs)

            def check_tensor(x, name="tensor"):
                if hasattr(x, 'shape'):
                    if jnp.any(jnp.isnan(x)):
                        raise ValueError(f"NaN detected in {name}")
                    if jnp.any(jnp.isinf(x)):
                        raise ValueError(f"Inf detected in {name}")
                return x

            # Visit every leaf of the (possibly nested) output.
            jax.tree_util.tree_map(check_tensor, result)
            return result

        return nan_detect_fn

    return debug_callback, gradient_debug_wrapper, profile_function, memory_profiler, nan_detector

# Instantiate the debugging helpers and bind each to a short alias
(
    debug_cb,
    grad_debug,
    profile_fn,
    mem_profile,
    nan_detect,
) = create_debugging_tools()

# Example function to debug
def test_function(x, w):
    """Small forward pass instrumented with debug hooks.

    Computes ``relu(x @ w)``, normalizes ``exp`` of it by the global sum
    (plus a small epsilon), and returns the sum of the normalized values.

    The debug hook is called purely for its printing side effect. Its return
    value is intentionally not assigned back to the tensors, so the
    computation does not depend on what the hook returns.
    """
    debug_cb("input", x)

    # Some computation
    h1 = jax.nn.relu(x @ w)
    debug_cb("after_relu", h1)

    # Potential numerical instability: exp() of unshifted activations can
    # overflow for large inputs.
    exp_h1 = jnp.exp(h1)
    h2 = exp_h1 / (jnp.sum(exp_h1) + 1e-8)
    debug_cb("after_softmax", h2)

    return jnp.sum(h2)

# Test debugging
print("=== DEBUGGING DEMO ===")
test_x = random.normal(random.PRNGKey(789), (32, 64))
test_w = random.normal(random.PRNGKey(790), (64, 128))

# Use debug wrapper
debug_fn = grad_debug(test_function)
result = debug_fn(test_x, test_w)

# Profile the function
print("\n=== PROFILING DEMO ===")
profile_results = profile_fn(test_function, test_x, test_w, n_runs=5)
print(f"Mean execution time: {profile_results['mean_time']:.4f} ± {profile_results['std_time']:.4f} seconds")
print(f"Min/Max time: {profile_results['min_time']:.4f}/{profile_results['max_time']:.4f} seconds")

# Memory profiling
mem_before = mem_profile()
# JAX dispatches work asynchronously; block until the array exists so the
# second measurement is taken after the allocation has actually happened.
# NOTE(review): RSS covers host memory only. On an accelerator backend the
# array presumably lives in device memory and the difference may be near zero.
large_array = jnp.zeros((10000, 10000)).block_until_ready()  # Allocate large array
mem_after = mem_profile()

print("\nMemory usage:")
print(f"Before: {mem_before['rss_mb']:.1f} MB")
print(f"After: {mem_after['rss_mb']:.1f} MB")
print(f"Difference: {mem_after['rss_mb'] - mem_before['rss_mb']:.1f} MB")

# Test NaN detector
safe_fn = nan_detect(test_function)
result_safe = safe_fn(test_x, test_w)
print(f"\nNaN detection passed: {result_safe:.6f}")
```

## Research-Specific Optimization Tricks

```python
def create_research_optimization_tricks():
    """Build optimization helpers commonly used in research code.

    Returns:
        A 7-tuple ``(cosine_annealing_schedule, warmup_cosine_schedule,
        lookahead_optimizer, spectral_normalization,
        orthogonal_regularization, feature_matching_loss,
        progressive_training_schedule)``.
    """
    tree_map = jax.tree_util.tree_map

    def cosine_annealing_schedule(step, total_steps, lr_max, lr_min=0.0):
        """Cosine decay from ``lr_max`` (step 0) to ``lr_min`` (``total_steps``)."""
        cosine_decay = 0.5 * (1 + jnp.cos(jnp.pi * step / total_steps))
        return lr_min + (lr_max - lr_min) * cosine_decay

    def warmup_cosine_schedule(step, warmup_steps, total_steps, lr_max, lr_min=0.0):
        """Linear warmup to ``lr_max`` followed by cosine annealing.

        For ``step < warmup_steps`` the rate is ``lr_max * step /
        warmup_steps``; afterwards it follows a cosine decay over the
        remaining ``total_steps - warmup_steps`` steps.
        """
        # jnp.where evaluates both branches, so each denominator is clamped
        # to at least 1. Otherwise warmup_steps == 0 (or total_steps ==
        # warmup_steps) divides by zero in a branch that is not selected.
        warmup_lr = lr_max * step / jnp.maximum(warmup_steps, 1)

        cosine_steps = jnp.maximum(total_steps - warmup_steps, 1)
        cosine_step = step - warmup_steps
        cosine_decay = 0.5 * (1 + jnp.cos(jnp.pi * cosine_step / cosine_steps))
        cosine_lr = lr_min + (lr_max - lr_min) * cosine_decay

        return jnp.where(step < warmup_steps, warmup_lr, cosine_lr)

    def lookahead_optimizer(base_optimizer_state, params, grads, k=5, alpha=0.5):
        """One step of a simplified Lookahead update.

        The inner ("fast") update is plain SGD with a fixed learning rate of
        0.01. On every ``k``-th step the slow weights move a fraction
        ``alpha`` towards the fast weights and are returned as the new
        parameters.

        Args:
            base_optimizer_state: Dict that may hold ``'step'`` (default 0)
                and ``'slow_weights'`` (default ``params``).
            params: Pytree of current parameters.
            grads: Pytree of gradients with the same structure.
            k: Synchronization period in steps.
            alpha: Interpolation factor for the slow weights.

        Returns:
            ``(new_params, new_state)`` where ``new_state`` is a dict with
            keys ``'slow_weights'`` and ``'step'``.
        """
        fast_weights = tree_map(lambda p, g: p - 0.01 * g, params, grads)  # Simplified SGD

        step = base_optimizer_state.get('step', 0) + 1
        slow_weights = base_optimizer_state.get('slow_weights', params)

        # Slow weights after moving towards the fast weights.
        synced_slow = tree_map(
            lambda slow, fast: slow + alpha * (fast - slow),
            slow_weights, fast_weights
        )

        # jnp.where works on arrays, not on (pytree, dict) tuples, so the
        # choice is made leaf by leaf.
        is_sync_step = step % k == 0

        def pick(on_sync, otherwise):
            return tree_map(lambda a, b: jnp.where(is_sync_step, a, b), on_sync, otherwise)

        new_params = pick(synced_slow, fast_weights)
        new_slow_weights = pick(synced_slow, slow_weights)
        return new_params, {'slow_weights': new_slow_weights, 'step': step}

    def spectral_normalization(w, u=None, power_iterations=1):
        """Divide ``w`` by a power-iteration estimate of its largest singular value.

        Args:
            w: 2-D weight matrix.
            u: Optional starting vector of length ``w.shape[0]``; a fixed-seed
                random unit vector is used when omitted.
            power_iterations: Number of power-iteration rounds (>= 1).

        Returns:
            ``(w_sn, u, sigma)``: normalized weights, the final left vector
            (reusable as the next call's ``u``), and the estimate.

        Raises:
            ValueError: If ``power_iterations`` is smaller than 1.
        """
        if power_iterations < 1:
            # Without at least one round the right vector ``v`` is never
            # computed.
            raise ValueError("power_iterations must be at least 1")

        if u is None:
            u = random.normal(random.PRNGKey(42), (w.shape[0],))
            u = u / jnp.linalg.norm(u)

        # Power iteration to find the top singular value
        for _ in range(power_iterations):
            v = w.T @ u
            v = v / jnp.linalg.norm(v)
            u = w @ v
            u = u / jnp.linalg.norm(u)

        sigma = jnp.dot(u, w @ v)
        w_sn = w / sigma

        return w_sn, u, sigma

    def orthogonal_regularization(params, reg_strength=1e-4):
        """Penalty ``reg_strength * sum ||W^T W - I||_F^2`` over matrix leaves.

        Only 2-D leaves whose smaller dimension exceeds 1 contribute.
        """
        reg_loss = 0.0

        for param in jax.tree_util.tree_leaves(params):
            if param.ndim == 2 and min(param.shape) > 1:  # Matrix
                wtw = param.T @ param
                identity = jnp.eye(wtw.shape[0])
                reg_loss += jnp.sum((wtw - identity) ** 2)

        return reg_strength * reg_loss

    def feature_matching_loss(real_features, fake_features):
        """Mean squared difference between the axis-0 means of both inputs."""
        return jnp.mean((jnp.mean(real_features, axis=0) - jnp.mean(fake_features, axis=0)) ** 2)

    def progressive_training_schedule(step, stages, stage_lengths):
        """Return the entry of ``stages`` that is active at ``step``.

        Stage ``i`` lasts ``stage_lengths[i]`` steps; steps beyond the total
        length map to the last stage.
        """
        cumulative_steps = jnp.cumsum(jnp.array(stage_lengths))
        current_stage = jnp.sum(step >= cumulative_steps)
        current_stage = jnp.clip(current_stage, 0, len(stages) - 1)

        # Convert explicitly: ``stages`` is a Python sequence.
        return stages[int(current_stage)]

    return (cosine_annealing_schedule, warmup_cosine_schedule, lookahead_optimizer,
            spectral_normalization, orthogonal_regularization, feature_matching_loss,
            progressive_training_schedule)

# Instantiate the research optimization helpers and bind each to a short alias
(
    cosine_schedule,
    warmup_cosine,
    lookahead_opt,
    spectral_norm,
    ortho_reg,
    feature_match,
    progressive_schedule,
) = create_research_optimization_tricks()

print("=== RESEARCH OPTIMIZATION TRICKS ===")

# Test learning rate schedules. Both schedules are built from jnp operations,
# so a whole array of steps is evaluated in one call instead of a Python loop
# over individual steps.
steps = jnp.arange(0, 1000, 10)
cosine_lrs = cosine_schedule(steps, 1000, 0.1, 0.001)
warmup_lrs = warmup_cosine(steps, 100, 1000, 0.1, 0.001)

# Convert to rounded Python floats so the printed lists show plain numbers
# rather than array reprs.
sample_steps = [0, 250, 500, 750, 1000]
cosine_samples = [round(float(cosine_schedule(s, 1000, 0.1)), 6) for s in sample_steps]
warmup_samples = [round(float(warmup_cosine(s, 100, 1000, 0.1)), 6) for s in sample_steps]

print("Learning rate schedules:")
print(f"Cosine at steps [0, 250, 500, 750, 1000]: {cosine_samples}")
print(f"Warmup+cosine at same steps: {warmup_samples}")

# Test spectral normalization
test_weight = random.normal(random.PRNGKey(999), (64, 32))
w_sn, u_vec, sigma = spectral_norm(test_weight)

print("\nSpectral normalization:")
# ``sigma`` is a power-iteration estimate (one round by default), so the
# exact norm printed on the second line need not be exactly 1.
print(f"Original spectral norm: {sigma:.4f}")
print(f"Normalized weight spectral norm: {jnp.linalg.norm(w_sn, ord=2):.4f}")

# Test orthogonal regularization
test_params = {
    'w1': random.normal(random.PRNGKey(1001), (50, 50)),
    'w2': random.normal(random.PRNGKey(1002), (30, 40)),
    'b1': jnp.zeros(50),
}

ortho_loss = ortho_reg(test_params)
print(f"Orthogonal regularization loss: {ortho_loss:.6f}")

# Test feature matching
real_features = random.normal(random.PRNGKey(1003), (100, 256))
fake_features = random.normal(random.PRNGKey(1004), (100, 256))
fm_loss = feature_match(real_features, fake_features)
print(f"Feature matching loss: {fm_loss:.6f}")

# Test progressive training
stages = ['warmup', 'normal', 'fine_tune']
stage_lengths = [100, 500, 400]

print("\nProgressive training stages:")
for step in [50, 300, 800]:
    stage = progressive_schedule(step, stages, stage_lengths)
    print(f"Step {step}: Stage {stage}")
```

## Summary

In this notebook, we explored advanced research tricks and techniques in JAX:

**Gradient Manipulation:**
- **Gradient Clipping**: Global norm clipping for training stability
- **Gradient Centralization**: Centering gradients for better optimization
- **Gradient Noise**: Adding noise for generalization
- **Gradient Standardization**: Normalizing gradient distributions

**Memory Optimization:**
- **Gradient Checkpointing**: Trading computation for memory
- **Mixed Precision**: Using float16 for memory efficiency
- **Reversible Layers**: Invertible architectures
- **Activation Checkpointing**: Strategic intermediate storage

**Numerical Stability:**
- **Stable Softmax**: Avoiding overflow in attention mechanisms  
- **Log-Sum-Exp**: Numerically stable logarithmic computations
- **Stable Activations**: GELU (tanh approximation) and Swish, built on the `jax.nn` `tanh` and `sigmoid` primitives
- **Stable Normalization**: Layer norm with epsilon protection

**Debugging and Profiling:**
- **Debug Callbacks**: Runtime tensor inspection
- **Gradient Debugging**: Automatic gradient analysis
- **Performance Profiling**: Execution time measurement
- **NaN/Inf Detection**: Automatic numerical error catching

**Advanced Optimization:**
- **Learning Rate Schedules**: Cosine annealing with warmup
- **Lookahead Optimizer**: Slow-fast weight interpolation
- **Spectral Normalization**: Lipschitz constraint enforcement
- **Orthogonal Regularization**: Weight matrix orthogonality
- **Progressive Training**: Multi-stage training protocols

**Research Applications:**
- **GAN Training**: Feature matching and spectral normalization
- **Large Model Training**: Memory optimization and numerical stability
- **Transformer Training**: Gradient clipping and learning rate schedules
- **Experimental Debugging**: Comprehensive diagnostic tools

These techniques are essential for pushing the boundaries of what's possible in machine learning research while maintaining numerical stability and computational efficiency in JAX.