# Custom Autograd Functions Mastery: PyTorch Mastery Hub

**Building Advanced Differentiable Operations from Scratch**

**Authors:** PyTorch Mastery Hub Team  
**Institution:** Advanced Deep Learning Education  
**Course:** PyTorch Fundamentals & Advanced Techniques  
**Date:** December 2024

## Overview

This notebook shows how to write custom autograd functions in PyTorch. It covers implementing forward and backward passes by hand, verifying the resulting gradients, and packaging the operations for research and production use.

## Key Objectives
1. Master the `torch.autograd.Function` API and its advanced features
2. Implement custom forward and backward passes with proper gradient computation
3. Build memory-efficient operations for large-scale training scenarios
4. Handle non-differentiable operations using advanced techniques
5. Create production-ready custom functions with comprehensive testing
6. Develop advanced loss functions and specialized operations
7. Apply best practices for debugging and performance optimization

## 📚 Learning Path
- **Prerequisites:** Gradient computation fundamentals, calculus knowledge, PyTorch basics
- **Difficulty:** Advanced
- **Duration:** 2-3 hours
- **Applications:** Research, custom architectures, specialized domains

## 🎯 Advanced Topics Coverage
- Custom activation functions with complex derivatives
- Memory-efficient operations and gradient checkpointing
- Straight-through estimators for non-differentiable operations
- Advanced loss functions (Focal Loss, Knowledge Distillation)
- Comprehensive testing and validation frameworks
- Performance optimization and debugging techniques

---

## 1. Environment Setup and Foundation

```python
# Comprehensive imports for advanced custom autograd functions
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import time
import math
import json
from pathlib import Path
from typing import Any, Tuple, Optional, List, Dict, Union
import warnings
warnings.filterwarnings('ignore')

# Advanced utilities
from collections import defaultdict
import inspect
from functools import wraps

# Create results directory for this notebook
results_dir = Path('../results/notebooks/custom_autograd_functions')
results_dir.mkdir(parents=True, exist_ok=True)

# Setup environment
torch.manual_seed(42)
np.random.seed(42)

# Configure plotting
plt.style.use('seaborn-v0_8')
sns.set_palette('husl')
plt.rcParams['figure.figsize'] = (14, 10)
plt.rcParams['font.size'] = 11

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

print("🔥 PyTorch Mastery Hub - Custom Autograd Functions Mastery")
print("=" * 70)
print(f"📱 Device: {device}")
print(f"🎨 PyTorch version: {torch.__version__}")
print(f"📊 NumPy version: {np.__version__}")
print(f"📁 Results directory: {results_dir}")
print("✨ Ready to master custom differentiable operations!\n")

# Performance tracking
class PerformanceTracker:
    def __init__(self):
        self.metrics = defaultdict(list)
    
    def track(self, name, value):
        self.metrics[name].append(value)
    
    def get_summary(self):
        summary = {}
        for name, values in self.metrics.items():
            summary[name] = {
                'mean': np.mean(values),
                'std': np.std(values),
                'min': np.min(values),
                'max': np.max(values),
                'count': len(values)
            }
        return summary

performance_tracker = PerformanceTracker()
```

## 2. Autograd Function Fundamentals: From Basics to Advanced

### 2.1 Understanding the Core Architecture

```python
def demonstrate_autograd_architecture():
    """Comprehensive demonstration of autograd function architecture"""
    
    print("=== 2.1 Autograd Function Architecture Deep Dive ===\n")
    
    class DetailedSquareFunction(Function):
        """Extensively documented square function for educational purposes"""
        
        @staticmethod
        def forward(ctx, input, verbose=False):
            """
            Forward pass: f(x) = x²
            
            Args:
                ctx: PyTorch context object for saving backward information
                input: Input tensor of any shape
                verbose: Whether to print detailed information
                
            Returns:
                Output tensor with same shape as input
                
            Mathematical Details:
                f(x) = x²
                Domain: All real numbers
                Range: [0, +∞)
            """
            if verbose:
                print(f"📊 Forward Pass Analysis:")
                print(f"  Input shape: {input.shape}")
                print(f"  Input dtype: {input.dtype}")
                print(f"  Input device: {input.device}")
                print(f"  Input requires_grad: {input.requires_grad}")
                print(f"  Input range: [{input.min().item():.4f}, {input.max().item():.4f}]")
            
            # Save input for backward pass - critical for gradient computation
            ctx.save_for_backward(input)
            
            # Store additional metadata if needed
            ctx.input_shape = input.shape
            ctx.verbose = verbose
            
            # Compute forward operation (same result as input ** 2)
            result = input * input
            
            if verbose:
                print(f"  Output shape: {result.shape}")
                print(f"  Output range: [{result.min().item():.4f}, {result.max().item():.4f}]")
                print(f"  Memory usage: {result.numel() * result.element_size() / 1024**2:.2f} MB")
            
            return result
        
        @staticmethod
        def backward(ctx, grad_output):
            """
            Backward pass: df/dx = 2x
            
            Args:
                ctx: Context with saved forward information
                grad_output: Gradient of loss w.r.t. output (∂L/∂y)
                
            Returns:
                grad_input: Gradient of loss w.r.t. input (∂L/∂x)
                grad_verbose: None (verbose is not a tensor parameter)
                
            Mathematical Details:
                If f(x) = x², then f'(x) = 2x
                By chain rule: ∂L/∂x = ∂L/∂y * ∂y/∂x = grad_output * 2x
            """
            # Retrieve saved tensors
            input, = ctx.saved_tensors
            verbose = ctx.verbose
            
            if verbose:
                print(f"\n🔄 Backward Pass Analysis:")
                print(f"  Grad output shape: {grad_output.shape}")
                print(f"  Grad output range: [{grad_output.min().item():.4f}, {grad_output.max().item():.4f}]")
                print(f"  Saved input range: [{input.min().item():.4f}, {input.max().item():.4f}]")
            
            # Compute gradient using chain rule
            # df/dx = 2x, so ∂L/∂x = grad_output * 2x
            grad_input = 2 * input * grad_output
            
            if verbose:
                print(f"  Computed gradient range: [{grad_input.min().item():.4f}, {grad_input.max().item():.4f}]")
                print(f"  Gradient norm: {grad_input.norm().item():.4f}")
            
            # Return gradients for all forward inputs
            # For non-tensor inputs (like verbose), return None
            return grad_input, None
    
    # Wrapper function for convenience
    def detailed_square(input, verbose=False):
        """Convenient wrapper for DetailedSquareFunction"""
        return DetailedSquareFunction.apply(input, verbose)
    
    return detailed_square

# Create and test the detailed square function
detailed_square = demonstrate_autograd_architecture()

print("🧪 Testing Detailed Square Function:")
print("-" * 40)

# Test with different tensor types and shapes
test_cases = [
    ("1D tensor", torch.tensor([1.0, 2.0, 3.0], requires_grad=True)),
    ("2D tensor", torch.randn(3, 4, requires_grad=True)),
    ("3D tensor", torch.randn(2, 3, 4, requires_grad=True)),
    ("Large tensor", torch.randn(100, 100, requires_grad=True))
]

results_summary = {}

for name, test_tensor in test_cases:
    print(f"\n📋 Test Case: {name}")
    print(f"Shape: {test_tensor.shape}, Elements: {test_tensor.numel()}")
    
    # Forward pass (perf_counter is a monotonic, high-resolution timer)
    start_time = time.perf_counter()
    output = detailed_square(test_tensor, verbose=(name == "1D tensor"))
    forward_time = time.perf_counter() - start_time
    
    # Backward pass
    start_time = time.perf_counter()
    loss = output.sum()
    loss.backward()
    backward_time = time.perf_counter() - start_time
    
    # Verify gradients
    analytical_grad = 2 * test_tensor.detach()
    gradient_error = (test_tensor.grad - analytical_grad).abs().max().item()
    
    results_summary[name] = {
        'forward_time': forward_time * 1000,  # Convert to ms
        'backward_time': backward_time * 1000,
        'gradient_error': gradient_error,
        'output_norm': output.norm().item(),
        'gradient_norm': test_tensor.grad.norm().item()
    }
    
    print(f"✅ Forward time: {forward_time*1000:.3f}ms")
    print(f"✅ Backward time: {backward_time*1000:.3f}ms")
    print(f"✅ Gradient error: {gradient_error:.2e}")
    print(f"✅ Gradient correct: {gradient_error < 1e-6}")

# Save results
with open(results_dir / 'basic_autograd_test_results.json', 'w') as f:
    json.dump(results_summary, f, indent=2)

print(f"\n💾 Basic autograd test results saved to {results_dir / 'basic_autograd_test_results.json'}")
```
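The loop above checks each gradient against the hand-derived formula 2x. For operations without a convenient closed form, `torch.autograd.gradcheck` compares a Function's backward pass against finite-difference estimates instead. The sketch below uses a stripped-down `SquareFunction` (a hypothetical name used only here) with the same math as `DetailedSquareFunction`. Note that gradcheck expects double-precision inputs with `requires_grad=True`.

```python
import torch
from torch.autograd import Function, gradcheck


class SquareFunction(Function):
    """Minimal x**2 Function with the same math as DetailedSquareFunction, minus logging."""

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input * input

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        # Chain rule: dL/dx = dL/dy * 2x
        return 2 * input * grad_output


# gradcheck perturbs each input element and compares the numerical Jacobian
# with the one implied by backward(); float64 keeps finite-difference error small.
x = torch.randn(4, 3, dtype=torch.float64, requires_grad=True)
square_ok = gradcheck(SquareFunction.apply, (x,), eps=1e-6, atol=1e-4)
print(f"gradcheck passed: {square_ok}")
```

gradcheck returns `True` on success. On failure it raises an error that reports both the numerical and the analytical Jacobians, which usually points straight at the incorrect term in `backward`.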

### 2.2 Multi-Input Functions with Complex Dependencies

```python
def demonstrate_multi_input_functions():
    """Advanced multi-input custom functions with complex gradient computation"""
    
    print("\n=== 2.2 Advanced Multi-Input Functions ===\n")
    
    class WeightedNormFunction(Function):
        """
        Custom function: f(x, w, p) = ||w ⊙ x||_p
        Where ⊙ is element-wise multiplication and ||·||_p is the p-norm
        """
        
        @staticmethod
        def forward(ctx, input, weight, p=2.0, eps=1e-8):
            """
            Forward pass: weighted p-norm
            
            Args:
                input: Input tensor [batch_size, features]
                weight: Weight tensor [features] or [batch_size, features]
                p: Norm order (default: 2 for L2 norm)
                eps: Small value for numerical stability
            """
            # Validate inputs
            assert input.dim() >= 1, "Input must be at least 1D"
            assert weight.dim() <= input.dim(), "Weight dimensions must not exceed input dimensions"
            
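            # A [features] weight broadcasts against a [batch_size, features] input
            # automatically. Do not reassign an expanded copy to `weight`: backward
            # checks weight.dim() == 1 to decide whether to sum grad_weight over
            # the batch dimension.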
            
            # Compute weighted values
            weighted_input = input * weight
            
            # Compute p-norm
            if p == float('inf'):
                norm_result = torch.max(torch.abs(weighted_input), dim=-1)[0]
            elif p == 1:
                norm_result = torch.sum(torch.abs(weighted_input), dim=-1)
            else:
                norm_result = torch.sum(torch.abs(weighted_input) ** p, dim=-1) ** (1.0 / p)
            
            # Add epsilon for numerical stability
            norm_result = norm_result + eps
            
            # Save for backward pass
            ctx.save_for_backward(input, weight, weighted_input, norm_result)
            ctx.p = p
            ctx.eps = eps
            
            return norm_result
        
        @staticmethod
        def backward(ctx, grad_output):
            """
            Complex backward pass for weighted p-norm
            """
            input, weight, weighted_input, norm_result = ctx.saved_tensors
            p = ctx.p
            eps = ctx.eps
            
            # Initialize gradients
            grad_input = grad_weight = None
            
            if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
                # Compute common terms
                abs_weighted = torch.abs(weighted_input)
                
                if p == float('inf'):
                    # For infinity norm, gradient is sparse
                    max_indices = torch.argmax(abs_weighted, dim=-1, keepdim=True)
                    grad_weighted = torch.zeros_like(weighted_input)
                    grad_weighted.scatter_(-1, max_indices, torch.sign(weighted_input.gather(-1, max_indices)))
                elif p == 1:
                    # For L1 norm
                    grad_weighted = torch.sign(weighted_input)
                else:
                    # For general p-norm
                    if p == 2:
                        # Optimized case for L2 norm
                        grad_weighted = weighted_input / (norm_result.unsqueeze(-1) - eps + 1e-12)
                    else:
                        # General case
                        grad_weighted = (
                            torch.sign(weighted_input) * 
                            (abs_weighted ** (p - 1)) * 
                            (norm_result.unsqueeze(-1) ** (1 - p))
                        )
                
                # Apply chain rule with output gradient
                grad_weighted = grad_weighted * grad_output.unsqueeze(-1)
                
                # Compute input gradient
                if ctx.needs_input_grad[0]:
                    grad_input = grad_weighted * weight
                
                # Compute weight gradient
                if ctx.needs_input_grad[1]:
                    grad_weight = grad_weighted * input
                    
                    # Sum over batch dimension if weight was broadcasted
                    if input.dim() == 2 and weight.dim() == 1:
                        grad_weight = grad_weight.sum(dim=0)
            
            return grad_input, grad_weight, None, None
    
    class BilinearInteractionFunction(Function):
        """
        Bilinear interaction: f(x, y, W) = x^T W y
        Useful for attention mechanisms and feature interactions
        """
        
        @staticmethod
        def forward(ctx, x, y, W):
            """
            Forward: f(x, y, W) = x^T W y
            
            Args:
                x: First input [batch_size, dim_x]
                y: Second input [batch_size, dim_y]  
                W: Bilinear weight [dim_x, dim_y]
            """
            # Compute bilinear interaction
            # result[i] = x[i]^T W y[i] for each batch element
            result = torch.sum(x.unsqueeze(2) * W.unsqueeze(0), dim=1)  # [batch, dim_y]
            result = torch.sum(result * y, dim=1)  # [batch]
            
            # Save for backward
            ctx.save_for_backward(x, y, W)
            
            return result
        
        @staticmethod
        def backward(ctx, grad_output):
            """
            Backward pass for bilinear interaction
            
            Gradients:
                ∂f/∂x = W y
                ∂f/∂y = W^T x  
                ∂f/∂W = x y^T (outer product)
            """
            x, y, W = ctx.saved_tensors
            
            grad_x = grad_y = grad_W = None
            
            if ctx.needs_input_grad[0]:
                # ∂f/∂x = W y
                grad_x = torch.matmul(y.unsqueeze(1), W.t()).squeeze(1)
                grad_x = grad_x * grad_output.unsqueeze(1)
            
            if ctx.needs_input_grad[1]:
                # ∂f/∂y = W^T x
                grad_y = torch.matmul(x.unsqueeze(1), W).squeeze(1)
                grad_y = grad_y * grad_output.unsqueeze(1)
            
            if ctx.needs_input_grad[2]:
                # ∂f/∂W = sum over batch of x_i y_i^T
                grad_W = torch.sum(
                    x.unsqueeze(2) * y.unsqueeze(1) * grad_output.unsqueeze(1).unsqueeze(2),
                    dim=0
                )
            
            return grad_x, grad_y, grad_W
    
    # Wrapper functions
    def weighted_norm(input, weight, p=2.0, eps=1e-8):
        return WeightedNormFunction.apply(input, weight, p, eps)
    
    def bilinear_interaction(x, y, W):
        return BilinearInteractionFunction.apply(x, y, W)
    
    return weighted_norm, bilinear_interaction

# Create and test multi-input functions
weighted_norm, bilinear_interaction = demonstrate_multi_input_functions()

print("🧪 Testing Multi-Input Functions:")
print("-" * 40)

# Test weighted norm function
print("\n📊 Testing Weighted Norm Function:")
batch_size, features = 32, 64
test_input = torch.randn(batch_size, features, requires_grad=True)
test_weight = torch.randn(features, requires_grad=True)

# Test different norms
norm_orders = [1, 2, float('inf')]
norm_results = {}

for p in norm_orders:
    # Clear gradients
    if test_input.grad is not None:
        test_input.grad.zero_()
    if test_weight.grad is not None:
        test_weight.grad.zero_()
    
    # Forward pass
    norm_output = weighted_norm(test_input, test_weight, p=p)
    
    # Backward pass
    loss = norm_output.sum()
    loss.backward()
    
    norm_results[f'L{p}'] = {
        'output_mean': norm_output.mean().item(),
        'output_std': norm_output.std().item(),
        'input_grad_norm': test_input.grad.norm().item(),
        'weight_grad_norm': test_weight.grad.norm().item()
    }
    
    print(f"  L{p} norm - Mean: {norm_output.mean():.4f}, "
          f"Input grad norm: {test_input.grad.norm():.4f}")

# Test bilinear interaction
print("\n📊 Testing Bilinear Interaction Function:")
dim_x, dim_y = 32, 24
x = torch.randn(batch_size, dim_x, requires_grad=True)
y = torch.randn(batch_size, dim_y, requires_grad=True)
W = torch.randn(dim_x, dim_y, requires_grad=True)

# Forward pass
bilinear_output = bilinear_interaction(x, y, W)
print(f"  Output shape: {bilinear_output.shape}")
print(f"  Output range: [{bilinear_output.min():.4f}, {bilinear_output.max():.4f}]")

# Backward pass
bilinear_loss = bilinear_output.sum()
bilinear_loss.backward()

print(f"  X gradient norm: {x.grad.norm():.4f}")
print(f"  Y gradient norm: {y.grad.norm():.4f}")
print(f"  W gradient norm: {W.grad.norm():.4f}")

# Save multi-input function results
multi_input_results = {
    'weighted_norm': norm_results,
    'bilinear_interaction': {
        'output_mean': bilinear_output.mean().item(),
        'output_std': bilinear_output.std().item(),
        'x_grad_norm': x.grad.norm().item(),
        'y_grad_norm': y.grad.norm().item(),
        'W_grad_norm': W.grad.norm().item()
    }
}

with open(results_dir / 'multi_input_function_results.json', 'w') as f:
    json.dump(multi_input_results, f, indent=2)

print(f"\n💾 Multi-input function results saved")
```
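The bilinear backward above relies on three hand-derived formulas. One way to check them is to test a compact re-implementation in two ways: against finite differences, and against autograd through PyTorch's built-in `torch.nn.functional.bilinear`. Given a weight of shape `[out_features, dim_x, dim_y]`, that built-in computes x^T A y for each output feature. The sketch below assumes a hypothetical `BilinearFunction` that uses `torch.einsum` in place of the notebook's broadcast-and-sum forward. The math is the same.

```python
import torch
import torch.nn.functional as F
from torch.autograd import Function, gradcheck


class BilinearFunction(Function):
    """f[b] = x[b]^T W y[b], same math as BilinearInteractionFunction."""

    @staticmethod
    def forward(ctx, x, y, W):
        ctx.save_for_backward(x, y, W)
        return torch.einsum('bi,ij,bj->b', x, W, y)

    @staticmethod
    def backward(ctx, grad_output):
        x, y, W = ctx.saved_tensors
        g = grad_output.unsqueeze(1)            # [batch, 1] for broadcasting
        grad_x = (y @ W.t()) * g                # row b is (W y_b)^T
        grad_y = (x @ W) * g                    # row b is (W^T x_b)^T
        grad_W = torch.einsum('b,bi,bj->ij', grad_output, x, y)  # sum_b g_b x_b y_b^T
        return grad_x, grad_y, grad_W


torch.manual_seed(0)
x = torch.randn(5, 4, dtype=torch.float64, requires_grad=True)
y = torch.randn(5, 3, dtype=torch.float64, requires_grad=True)
W = torch.randn(4, 3, dtype=torch.float64, requires_grad=True)

# 1) Finite-difference check of the hand-written backward
bilinear_ok = gradcheck(BilinearFunction.apply, (x, y, W))

# 2) Compare with autograd through F.bilinear (weight needs a leading out_features dim)
custom = BilinearFunction.apply(x, y, W)
custom_grads = torch.autograd.grad(custom.sum(), (x, y, W))

reference = F.bilinear(x, y, W.unsqueeze(0)).squeeze(-1)
reference_grads = torch.autograd.grad(reference.sum(), (x, y, W))

forward_match = torch.allclose(custom, reference)
backward_match = all(torch.allclose(a, b) for a, b in zip(custom_grads, reference_grads))
print(f"gradcheck: {bilinear_ok}, forward match: {forward_match}, backward match: {backward_match}")
```

The second check is cheap and catches transposition mistakes, which are the most common error in bilinear gradients. For example, returning `W @ y` with the wrong layout would still have a plausible norm, but it would not match the reference.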

## 3. Advanced Custom Activation Functions

### 3.1 Sophisticated Activation Functions with Complex Derivatives

```python
def create_advanced_activation_functions():
    """Create sophisticated activation functions for research and production"""
    
    print("\n=== 3.1 Advanced Custom Activation Functions ===\n")
    
    class SwishFunction(Function):
        """
        Swish activation: f(x) = x * sigmoid(βx)
        Self-gating activation with learnable parameter β
        """
        
        @staticmethod
        def forward(ctx, input, beta=1.0):
            """
            Forward: f(x) = x * σ(βx) where σ is sigmoid
            """
            scaled_input = beta * input
            sigmoid_x = torch.sigmoid(scaled_input)
            result = input * sigmoid_x
            
            # Save for efficient backward computation
            ctx.save_for_backward(input, sigmoid_x)
            ctx.beta = beta
            
            return result
        
        @staticmethod
        def backward(ctx, grad_output):
            """
            Backward: f'(x) = σ(βx) + βx * σ(βx) * (1 - σ(βx))
                            = σ(βx) * (1 + βx * (1 - σ(βx)))
            """
            input, sigmoid_x = ctx.saved_tensors
            beta = ctx.beta
            
            # Efficient computation using saved sigmoid
            sigmoid_derivative = sigmoid_x * (1 - sigmoid_x)
            swish_derivative = sigmoid_x + beta * input * sigmoid_derivative
            
            grad_input = grad_output * swish_derivative
            
            # Gradient w.r.t. beta (if beta requires grad)
            grad_beta = None
            if ctx.needs_input_grad[1]:
                grad_beta = torch.sum(grad_output * input * input * sigmoid_derivative)
            
            return grad_input, grad_beta
    
    class MishFunction(Function):
        """
        Mish activation: f(x) = x * tanh(softplus(x))
        Smooth, non-monotonic activation with excellent properties
        """
        
        @staticmethod
        def forward(ctx, input):
            """
            Forward: f(x) = x * tanh(ln(1 + e^x))
            """
            # Use softplus for numerical stability
            softplus_x = F.softplus(input)
            tanh_softplus = torch.tanh(softplus_x)
            result = input * tanh_softplus
            
            # Save intermediate results for efficient backward
            ctx.save_for_backward(input, softplus_x, tanh_softplus)
            
            return result
        
        @staticmethod
        def backward(ctx, grad_output):
            """
            Complex derivative computation for Mish
            f'(x) = tanh(softplus(x)) + x * sech²(softplus(x)) * sigmoid(x)
            """
            input, softplus_x, tanh_softplus = ctx.saved_tensors
            
            # Compute derivative components
            sigmoid_x = torch.sigmoid(input)
            sech_squared = 1 - tanh_softplus ** 2  # sech²(x) = 1 - tanh²(x)
            
            # Full derivative
            mish_derivative = tanh_softplus + input * sech_squared * sigmoid_x
            
            grad_input = grad_output * mish_derivative
            
            return grad_input
    
    class AdaptiveActivationFunction(Function):
        """
        Learnable activation: f(x) = a * x + b * g(c * x + d)
        Where g is a base activation (tanh, relu, etc.)
        """
        
        @staticmethod
        def forward(ctx, input, params, activation_type='tanh'):
            """
            Forward pass for adaptive activation
            
            Args:
                input: Input tensor
                params: [a, b, c, d] learnable parameters
                activation_type: Base activation ('tanh', 'relu', 'elu')
            """
            a, b, c, d = params[0], params[1], params[2], params[3]
            
            # Linear component
            linear_part = a * input
            
            # Nonlinear component
            nonlinear_input = c * input + d
            
            if activation_type == 'tanh':
                nonlinear_activation = torch.tanh(nonlinear_input)
                activation_derivative = 1 - nonlinear_activation ** 2
            elif activation_type == 'relu':
                nonlinear_activation = F.relu(nonlinear_input)
                activation_derivative = (nonlinear_input > 0).to(nonlinear_input.dtype)
            elif activation_type == 'elu':
                nonlinear_activation = F.elu(nonlinear_input)
                activation_derivative = torch.where(
                    nonlinear_input > 0, 
                    torch.ones_like(nonlinear_input),
                    nonlinear_activation + 1
                )
            else:
                raise ValueError(f"Unsupported activation type: {activation_type}")
            
            nonlinear_part = b * nonlinear_activation
            result = linear_part + nonlinear_part
            
            # Save for backward
            ctx.save_for_backward(input, params, nonlinear_activation, activation_derivative)
            ctx.activation_type = activation_type
            
            return result
        
        @staticmethod
        def backward(ctx, grad_output):
            """
            Backward pass with gradients for both input and parameters
            """
            input, params, nonlinear_activation, activation_derivative = ctx.saved_tensors
            a, b, c, d = params[0], params[1], params[2], params[3]
            
            # Gradient w.r.t. input
            grad_input = grad_output * (a + b * c * activation_derivative)
            
            # Gradients w.r.t. parameters (only computed when params requires grad)
            grad_params = None
            
            if ctx.needs_input_grad[1]:
                grad_params = torch.zeros_like(params)
                
                # ∂f/∂a = x
                grad_params[0] = torch.sum(grad_output * input)
                
                # ∂f/∂b = g(cx + d)
                grad_params[1] = torch.sum(grad_output * nonlinear_activation)
                
                # ∂f/∂c = b * x * g'(cx + d)
                grad_params[2] = torch.sum(grad_output * b * input * activation_derivative)
                
                # ∂f/∂d = b * g'(cx + d)
                grad_params[3] = torch.sum(grad_output * b * activation_derivative)
            
            return grad_input, grad_params, None
    
    class GatedLinearUnitFunction(Function):
        """
        Gated Linear Unit: GLU(x) = x₁ ⊙ σ(x₂)
        Where x is split into two halves
        """
        
        @staticmethod
        def forward(ctx, input):
            """
            Forward: Split input and apply gating
            """
            # Split input into two halves
            dim = input.size(-1)
            assert dim % 2 == 0, "Input dimension must be even for GLU"
            
            split_dim = dim // 2
            x1 = input[..., :split_dim]
            x2 = input[..., split_dim:]
            
            # Apply sigmoid gating
            gate = torch.sigmoid(x2)
            result = x1 * gate
            
            # Save for backward
            ctx.save_for_backward(x1, x2, gate)
            
            return result
        
        @staticmethod
        def backward(ctx, grad_output):
            """
            Backward pass for GLU
            """
            x1, x2, gate = ctx.saved_tensors
            
            # Gradients
            grad_x1 = grad_output * gate
            grad_x2 = grad_output * x1 * gate * (1 - gate)
            
            # Concatenate gradients
            grad_input = torch.cat([grad_x1, grad_x2], dim=-1)
            
            return grad_input
    
    # Wrapper functions
    def swish(x, beta=1.0):
        return SwishFunction.apply(x, beta)
    
    def mish(x):
        return MishFunction.apply(x)
    
    def adaptive_activation(x, params, activation_type='tanh'):
        return AdaptiveActivationFunction.apply(x, params, activation_type)
    
    def glu(x):
        return GatedLinearUnitFunction.apply(x)
    
    return swish, mish, adaptive_activation, glu

# Create activation functions
swish, mish, adaptive_activation, glu = create_advanced_activation_functions()

print("🎨 Testing Advanced Activation Functions:")
print("-" * 50)

# Comprehensive activation function analysis
activation_analysis = {}

# Test input range
x_range = torch.linspace(-4, 4, 1000)
x_test = torch.linspace(-3, 3, 100, requires_grad=True)

# Test each activation function
activations_to_test = [
    ('Swish', lambda x: swish(x, beta=1.0)),
    ('Mish', mish),
    ('ReLU', F.relu),
    ('GELU', F.gelu),
    ('Tanh', torch.tanh),
    ('Swish β=0.5', lambda x: swish(x, beta=0.5)),
    ('Swish β=2.0', lambda x: swish(x, beta=2.0))
]

# Analyze each activation
for name, activation_func in activations_to_test:
    print(f"\n📊 Analyzing {name}:")
    
    # Clear gradients
    if x_test.grad is not None:
        x_test.grad.zero_()
    
    # Forward pass
    output = activation_func(x_test)
    
    # Backward pass for gradient analysis
    output.sum().backward()
    
    # Compute statistics
    activation_analysis[name] = {
        'output_range': [output.min().item(), output.max().item()],
        'output_mean': output.mean().item(),
        'output_std': output.std().item(),
        'gradient_norm': x_test.grad.norm().item(),
        'gradient_mean': x_test.grad.mean().item(),
        'gradient_std': x_test.grad.std().item(),
        'zero_gradient_ratio': (x_test.grad.abs() < 1e-6).float().mean().item()
    }
    
    print(f"  Output range: [{output.min():.3f}, {output.max():.3f}]")
    print(f"  Gradient norm: {x_test.grad.norm():.3f}")
    print(f"  Zero-gradient ratio (dead-neuron proxy): {activation_analysis[name]['zero_gradient_ratio']:.3f}")

# Test adaptive activation function
print(f"\n📊 Testing Adaptive Activation:")
adaptive_params = torch.tensor([1.0, 0.5, 1.0, 0.0], requires_grad=True)

for activation_type in ['tanh', 'relu', 'elu']:
    if adaptive_params.grad is not None:
        adaptive_params.grad.zero_()
    if x_test.grad is not None:
        x_test.grad.zero_()
    
    adaptive_output = adaptive_activation(x_test, adaptive_params, activation_type)
    adaptive_output.sum().backward()
    
    print(f"  {activation_type.upper()}: Param grads = {adaptive_params.grad.numpy()}")

# Test GLU
print(f"\n📊 Testing Gated Linear Unit:")
x_glu = torch.randn(32, 128, requires_grad=True)  # Even dimension for GLU
glu_output = glu(x_glu)
glu_output.sum().backward()

print(f"  Input shape: {x_glu.shape}")
print(f"  Output shape: {glu_output.shape}")
print(f"  Gradient norm: {x_glu.grad.norm():.4f}")

# Create comprehensive visualization
def create_activation_visualization():
    """Create comprehensive activation function visualization"""
    
    fig, axes = plt.subplots(3, 3, figsize=(18, 15))
    axes = axes.flatten()
    
    x_plot = torch.linspace(-4, 4, 1000)
    
    # Plot activation functions
    plot_idx = 0
    colors = plt.cm.Set3(np.linspace(0, 1, len(activations_to_test)))
    
    for i, (name, activation_func) in enumerate(activations_to_test):
        if plot_idx >= 6:  # First 6 plots for activations
            break
            
        with torch.no_grad():
            y = activation_func(x_plot)
        
        axes[plot_idx].plot(x_plot, y, linewidth=2, color=colors[i], label=name)
        axes[plot_idx].set_title(f'{name} Activation', fontweight='bold')
        axes[plot_idx].set_xlabel('x')
        axes[plot_idx].set_ylabel('f(x)')
        axes[plot_idx].grid(True, alpha=0.3)
        axes[plot_idx].legend()
        
        plot_idx += 1
    
    # Plot derivatives comparison
    ax_deriv = axes[6]
    x_deriv = torch.linspace(-3, 3, 200, requires_grad=True)
    
    for name, activation_func in activations_to_test[:4]:  # Plot first 4 derivatives
        if x_deriv.grad is not None:
            x_deriv.grad.zero_()
        
        y = activation_func(x_deriv)
        y.sum().backward()  # each iteration builds a fresh graph, so retain_graph is unnecessary
        
        ax_deriv.plot(x_deriv.detach(), x_deriv.grad.detach(), 
                     linewidth=2, label=f"{name}'")
    
    ax_deriv.set_title('Activation Derivatives', fontweight='bold')
    ax_deriv.set_xlabel('x')
    ax_deriv.set_ylabel("f'(x)")
    ax_deriv.legend()
    ax_deriv.grid(True, alpha=0.3)
    
    # Plot gradient flow comparison
    ax_flow = axes[7]
    activation_names = list(activation_analysis.keys())[:6]
    gradient_norms = [activation_analysis[name]['gradient_norm'] for name in activation_names]
    
    bars = ax_flow.bar(range(len(activation_names)), gradient_norms, alpha=0.7)
    ax_flow.set_title('Gradient Flow Comparison', fontweight='bold')
    ax_flow.set_ylabel('Gradient Norm')
    ax_flow.set_xticks(range(len(activation_names)))
    ax_flow.set_xticklabels(activation_names, rotation=45)
    ax_flow.grid(True, alpha=0.3)
    
    # Add value labels on bars
    for bar, norm in zip(bars, gradient_norms):
        height = bar.get_height()
        ax_flow.text(bar.get_x() + bar.get_width()/2., height + 0.01,
                    f'{norm:.2f}', ha='center', va='bottom')
    
    # Plot dead neuron analysis
    ax_dead = axes[8]
    dead_ratios = [activation_analysis[name]['zero_gradient_ratio'] for name in activation_names]
    
    bars_dead = ax_dead.bar(range(len(activation_names)), dead_ratios, 
                           alpha=0.7, color='red')
    ax_dead.set_title('Dead Neuron Analysis', fontweight='bold')
    ax_dead.set_ylabel('Dead Gradient Ratio')
    ax_dead.set_xticks(range(len(activation_names)))
    ax_dead.set_xticklabels(activation_names, rotation=45)
    ax_dead.grid(True, alpha=0.3)
    
    plt.suptitle('Advanced Custom Activation Functions Analysis', 
                 fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.savefig(results_dir / 'advanced_activation_functions.png', 
                dpi=300, bbox_inches='tight')
    plt.show()

# Create visualization
create_activation_visualization()

# Save activation analysis results
with open(results_dir / 'activation_function_analysis.json', 'w') as f:
    json.dump(activation_analysis, f, indent=2)

print(f"\n💾 Activation function analysis saved")
print(f"\n🎓 Key Insights:")
print(f"  • Swish and Mish provide smooth, non-monotonic behavior")
print(f"  • Adaptive activations can learn task-specific shapes")
print(f"  • GLU provides effective gating mechanisms")
print(f"  • Gradient flow varies significantly between activations")
```
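A derivative plot is only as trustworthy as the hand-written `backward` behind it, so it helps to check that `backward` against finite differences first. The sketch below uses `torch.autograd.gradcheck`, which compares the analytical Jacobian with a numerical one and expects double-precision inputs. `SwishFunction` is a hypothetical implementation written only for this check, not necessarily the Swish used earlier in this notebook.

```python
import torch
from torch.autograd import Function, gradcheck


class SwishFunction(Function):
    """Hypothetical custom Swish: f(x) = x * sigmoid(beta * x)."""

    @staticmethod
    def forward(ctx, x, beta=1.0):
        ctx.beta = beta
        ctx.save_for_backward(x)
        return x * torch.sigmoid(beta * x)

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        s = torch.sigmoid(ctx.beta * x)
        # d/dx [x * s(b x)] = s + b * x * s * (1 - s)
        grad_x = grad_output * (s + ctx.beta * x * s * (1 - s))
        return grad_x, None  # no gradient for the beta hyperparameter


x = torch.randn(8, 4, dtype=torch.double, requires_grad=True)
# gradcheck raises on a mismatch by default and returns True when the
# analytical and finite-difference Jacobians agree within tolerance
ok = gradcheck(lambda t: SwishFunction.apply(t, 1.5), (x,))
print("Swish backward matches finite differences:", ok)
```

Running the same check on every custom activation before plotting catches sign errors and missing chain-rule factors that a smooth-looking curve can hide.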

## 4. Memory-Efficient Operations and Advanced Techniques

### 4.1 Memory-Efficient Custom Operations
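The memory cost of a custom `Function` is set by what it passes to `ctx.save_for_backward`. A small measuring sketch, assuming a PyTorch version that provides `torch.autograd.graph.saved_tensors_hooks` (1.10 or later): the pack hook runs once for every tensor saved for backward, including tensors saved through `save_for_backward`, so summing their sizes gives the bytes the graph keeps alive. `SquareFunction` and `saved_bytes` are hypothetical names used only for illustration.

```python
import torch
from torch.autograd import Function


class SquareFunction(Function):
    """Hypothetical y = x ** 2 that can keep a compressed copy of x for backward."""

    @staticmethod
    def forward(ctx, x, compress=False):
        ctx.orig_dtype = x.dtype
        # A bfloat16 copy halves the saved bytes for float32 inputs,
        # at the cost of a slightly less precise gradient
        ctx.save_for_backward(x.to(torch.bfloat16) if compress else x)
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        return 2 * x.to(ctx.orig_dtype) * grad_output, None


def saved_bytes(fn, *inputs):
    """Run fn(*inputs) and return (output, total bytes saved for backward)."""
    total = 0

    def pack(t):
        nonlocal total
        total += t.numel() * t.element_size()
        return t  # store the tensor unchanged; we only measure

    def unpack(t):
        return t

    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        out = fn(*inputs)
    return out, total


x = torch.randn(1024, 256, requires_grad=True)
_, full = saved_bytes(lambda t: SquareFunction.apply(t, False), x)
_, compressed = saved_bytes(lambda t: SquareFunction.apply(t, True), x)
print(f"float32 save: {full / 2**20:.2f} MiB, bfloat16 save: {compressed / 2**20:.2f} MiB")
```

The same helper can wrap any of the operations below to confirm how much each one actually saves.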

```python
def create_memory_efficient_operations():
    """Create memory-efficient operations for large-scale training"""
    
    print("\n=== 4.1 Memory-Efficient Custom Operations ===\n")
    
    class MemoryEfficientMatMulFunction(Function):
        """
        Memory-efficient matrix multiplication (input @ weight + bias).
        When the input activation is large, a bfloat16 copy is saved for
        backward instead of the full-precision tensor. This roughly halves the
        saved activation memory for float32 inputs at the cost of a small,
        bounded gradient error. Weight and bias are parameters that stay
        resident anyway, so saving them costs no extra memory.
        """
        
        # Inputs larger than this (in MB) are compressed. Kept low so the demo
        # sizes below exercise both paths; real workloads would use a larger value.
        threshold_mb = 1.0
        
        @staticmethod
        def forward(ctx, input, weight, bias=None, save_memory=True):
            """
            Forward pass with optional compression of the saved activation
            
            Args:
                input: Input tensor [batch_size, in_features]
                weight: Weight matrix [in_features, out_features]
                bias: Optional bias [out_features]
                save_memory: Whether to compress large saved activations
            """
            # The forward result is exact in both modes
            output = torch.mm(input, weight)
            if bias is not None:
                output = output + bias
            
            input_size_mb = input.numel() * input.element_size() / (1024 ** 2)
            ctx.compressed = (
                save_memory
                and input_size_mb > MemoryEfficientMatMulFunction.threshold_mb
            )
            ctx.input_dtype = input.dtype
            ctx.has_bias = bias is not None
            
            if ctx.compressed:
                # Save a compressed copy of the activation; backward upcasts it
                ctx.save_for_backward(input.to(torch.bfloat16), weight)
            else:
                # Standard mode: save the exact tensors
                ctx.save_for_backward(input, weight)
            
            return output
        
        @staticmethod
        def backward(ctx, grad_output):
            """
            Backward pass: exact for grad_input (weight is saved exactly),
            approximate for grad_weight when the input was compressed
            """
            input, weight = ctx.saved_tensors
            if ctx.compressed:
                input = input.to(ctx.input_dtype)
            
            grad_input = grad_weight = grad_bias = None
            if ctx.needs_input_grad[0]:
                grad_input = torch.mm(grad_output, weight.t())
            if ctx.needs_input_grad[1]:
                grad_weight = torch.mm(input.t(), grad_output)
            if ctx.has_bias and ctx.needs_input_grad[2]:
                grad_bias = grad_output.sum(dim=0)
            
            # One gradient per forward argument; save_memory is a flag
            return grad_input, grad_weight, grad_bias, None
    
    class SequentialComputationFunction(Function):
        """
        Row-wise sum of squares computed chunk by chunk.
        Chunking bounds the size of the temporary `chunk * chunk` buffer in
        forward and of the gradient temporaries in backward. The input itself
        is still saved once: slices of a tensor are views that share its
        storage, so saving per-chunk slices would not free any memory.
        """
        
        @staticmethod
        def forward(ctx, input, chunk_size=1024):
            """
            Compute sum(input ** 2, dim=1) in chunks of `chunk_size` rows
            """
            ctx.chunk_size = chunk_size
            ctx.save_for_backward(input)
            
            # Autograd does not record inside forward, so filling a
            # preallocated output in place is safe here
            result = input.new_empty(input.size(0))
            for start in range(0, input.size(0), chunk_size):
                chunk = input[start:start + chunk_size]
                result[start:start + chunk_size] = (chunk * chunk).sum(dim=1)
            
            return result
        
        @staticmethod
        def backward(ctx, grad_output):
            """
            d/dx sum(x ** 2) = 2x, scaled by each row's incoming gradient
            """
            input, = ctx.saved_tensors
            chunk_size = ctx.chunk_size
            
            grad_input = torch.empty_like(input)
            for start in range(0, input.size(0), chunk_size):
                end = start + chunk_size
                grad_input[start:end] = 2 * input[start:end] * grad_output[start:end].unsqueeze(1)
            
            return grad_input, None
    
    class GradientCheckpointFunction(Function):
        """
        Implement gradient checkpointing for arbitrary functions
        Recomputes forward pass during backward to save memory
        """
        
        @staticmethod
        def forward(ctx, input, function_layers, *function_args):
            """
            Forward pass with checkpointing
            
            Args:
                input: Input tensor
                function_layers: List of functions to apply sequentially
                function_args: Additional arguments for functions
            """
            ctx.function_layers = function_layers
            ctx.function_args = function_args
            
            # Only save input, not intermediate activations
            ctx.save_for_backward(input)
            
            # Compute forward pass
            x = input
            for i, layer_func in enumerate(function_layers):
                if i < len(function_args):
                    x = layer_func(x, function_args[i])
                else:
                    x = layer_func(x)
            
            return x
        
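        @staticmethod
        def backward(ctx, grad_output):
            """
            Backward pass with recomputation.
            Note: random ops inside the layers (e.g. dropout) would draw new
            random numbers here; torch.utils.checkpoint restores the RNG state
            to avoid that, which this implementation does not do.
            """
            input, = ctx.saved_tensors
            function_layers = ctx.function_layers
            function_args = ctx.function_args
            
            # Autograd runs backward with grad mode disabled (unless
            # create_graph=True), so re-enable it to rebuild the graph
            with torch.enable_grad():
                x_leaf = input.detach().requires_grad_(True)
                x = x_leaf
                for i, layer_func in enumerate(function_layers):
                    if i < len(function_args):
                        x = layer_func(x, function_args[i])
                    else:
                        x = layer_func(x)
            
            # Backpropagate through the recomputed graph; the input gradient
            # accumulates on the detached leaf, not on the saved tensor
            torch.autograd.backward(x, grad_output)
            
            # One gradient per forward argument: input, function_layers, *function_args
            return (x_leaf.grad,) + (None,) * (1 + len(function_args))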
    
    # Wrapper functions
    def memory_efficient_matmul(input, weight, bias=None, save_memory=True):
        return MemoryEfficientMatMulFunction.apply(input, weight, bias, save_memory)
    
    def sequential_computation(input, chunk_size=1024):
        return SequentialComputationFunction.apply(input, chunk_size)
    
    def gradient_checkpoint(input, function_layers, *function_args):
        return GradientCheckpointFunction.apply(input, function_layers, *function_args)
    
    return memory_efficient_matmul, sequential_computation, gradient_checkpoint

# Create memory-efficient operations
memory_efficient_matmul, sequential_computation, gradient_checkpoint = create_memory_efficient_operations()

print("💾 Testing Memory-Efficient Operations:")
print("-" * 50)

# Test memory-efficient matrix multiplication
print("\n📊 Testing Memory-Efficient Matrix Multiplication:")

# Create test data of varying sizes
test_sizes = [
    (32, 64, 32, "Small"),
    (512, 256, 128, "Medium"), 
    (1024, 1024, 512, "Large")
]

memory_results = {}

for batch_size, in_features, out_features, size_name in test_sizes:
    print(f"\n  Testing {size_name} size: {batch_size}x{in_features} -> {out_features}")
    
    # Create test tensors
    input_tensor = torch.randn(batch_size, in_features, requires_grad=True)
    weight_tensor = torch.randn(in_features, out_features, requires_grad=True)
    bias_tensor = torch.randn(out_features, requires_grad=True)
    
    # Memory-efficient run
    start_time = time.perf_counter()
    
    output_efficient = memory_efficient_matmul(
        input_tensor, weight_tensor, bias_tensor, save_memory=True
    )
    output_efficient.sum().backward()
    
    efficient_time = time.perf_counter() - start_time
    efficient_grads = [t.grad.clone() for t in (input_tensor, weight_tensor, bias_tensor)]
    
    # Clear gradients for the standard run
    input_tensor.grad = weight_tensor.grad = bias_tensor.grad = None
    
    # Standard run
    start_time = time.perf_counter()
    
    output_standard = memory_efficient_matmul(
        input_tensor, weight_tensor, bias_tensor, save_memory=False
    )
    output_standard.sum().backward()
    
    standard_time = time.perf_counter() - start_time
    
    # Forward outputs are exact in both modes; compression can only affect gradients
    output_diff = (output_efficient - output_standard).abs().max().item()
    grad_diff = max(
        (g_eff - t.grad).abs().max().item()
        for g_eff, t in zip(efficient_grads, (input_tensor, weight_tensor, bias_tensor))
    )
    
    memory_results[size_name] = {
        'output_difference': output_diff,
        'gradient_difference': grad_diff,
        'efficient_time': efficient_time * 1000,  # ms
        'standard_time': standard_time * 1000,
        'time_ratio': efficient_time / standard_time,
        'batch_size': batch_size,
        'parameters': in_features * out_features + out_features
    }
    
    print(f"    Output difference: {output_diff:.2e}")
    print(f"    Max gradient difference: {grad_diff:.2e}")
    print(f"    Efficient time: {efficient_time*1000:.2f}ms")
    print(f"    Standard time: {standard_time*1000:.2f}ms")
    print(f"    Time ratio: {efficient_time/standard_time:.2f}x")

# Test sequential computation
print(f"\n📊 Testing Sequential Computation:")

large_input = torch.randn(5000, 100, requires_grad=True)
print(f"  Input shape: {large_input.shape}")

# Test with different chunk sizes
chunk_sizes = [512, 1024, 2048]

for chunk_size in chunk_sizes:
    if large_input.grad is not None:
        large_input.grad.zero_()
    
    start_time = time.time()
    seq_output = sequential_computation(large_input, chunk_size=chunk_size)
    seq_output.sum().backward()
    seq_time = time.time() - start_time
    
    print(f"    Chunk size {chunk_size}: {seq_time*1000:.2f}ms, "
          f"Grad norm: {large_input.grad.norm():.4f}")

# Test gradient checkpointing
print(f"\n📊 Testing Gradient Checkpointing:")

def test_layer1(x):
    return torch.relu(x)

def test_layer2(x):
    return x ** 2

def test_layer3(x):
    return torch.sin(x)

test_input = torch.randn(100, 50, requires_grad=True)
layers = [test_layer1, test_layer2, test_layer3]

# Test checkpointed version
start_time = time.time()
checkpoint_output = gradient_checkpoint(test_input, layers)
checkpoint_output.sum().backward()
checkpoint_time = time.time() - start_time

checkpoint_grad = test_input.grad.clone()
test_input.grad = None

# Test standard version
start_time = time.time()
x = test_input
for layer in layers:
    x = layer(x)
x.sum().backward()
standard_time = time.time() - start_time

print(f"  Checkpoint time: {checkpoint_time*1000:.2f}ms")
print(f"  Standard time: {standard_time*1000:.2f}ms")
print(f"  Gradient difference: {(checkpoint_grad - test_input.grad).abs().max():.2e}")

# Save memory efficiency results
memory_efficiency_results = {
    'matrix_multiplication': memory_results,
    'sequential_computation': {
        'input_size': large_input.shape,
        'chunk_performance': {str(cs): f"Tested chunk size {cs}" for cs in chunk_sizes}
    },
    'gradient_checkpointing': {
        'checkpoint_time_ms': checkpoint_time * 1000,
        'standard_time_ms': standard_time * 1000,
        'time_overhead': checkpoint_time / standard_time,
        'gradient_accuracy': (checkpoint_grad - test_input.grad).abs().max().item()
    }
}

with open(results_dir / 'memory_efficiency_results.json', 'w') as f:
    json.dump(memory_efficiency_results, f, indent=2)

print(f"\n💾 Memory efficiency results saved")
print(f"\n🎓 Memory Efficiency Insights:")
print(f"  • Memory-efficient operations trade extra compute or precision for memory")
print(f"  • Chunked processing bounds the size of temporary buffers, not of the saved input")
print(f"  • Gradient checkpointing reduces memory at the cost of recomputation")
print(f"  • These techniques matter most when activations dominate memory (large models, long sequences)")
```
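PyTorch already ships recompute-in-backward as `torch.utils.checkpoint.checkpoint`, which makes a convenient reference for any hand-written checkpointing `Function`. A minimal cross-check, assuming a PyTorch 2.x release where `checkpoint` accepts the `use_reentrant` keyword; the `block` function mirrors the three test layers above but is redefined so the snippet runs on its own.

```python
import torch
from torch.utils.checkpoint import checkpoint


def block(x):
    # Same three steps as the test layers above: relu -> square -> sin
    return torch.sin(torch.relu(x) ** 2)


x = torch.randn(100, 50, requires_grad=True)

# Reference: ordinary autograd keeps every intermediate activation
block(x).sum().backward()
reference_grad = x.grad.clone()
x.grad = None

# Checkpointed: intermediates are discarded and recomputed during backward
checkpoint(block, x, use_reentrant=False).sum().backward()
checkpoint_grad = x.grad.clone()

max_diff = (reference_grad - checkpoint_grad).abs().max().item()
print(f"max |reference - checkpoint| = {max_diff:.2e}")
```

Because the recomputation replays the same deterministic ops, the two gradients should agree to floating-point precision; a larger gap in a custom implementation usually points to a missing `torch.enable_grad()` or to returning the gradient of the wrong tensor.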

## 5. Non-Differentiable Operations and Advanced Techniques

### 5.1 Handling Non-Differentiable Operations
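Not every straight-through estimator needs a custom `Function`. A common idiom, shown here as a sketch rather than as this notebook's implementation, writes the output as `x + (q - x).detach()`: its forward value equals the quantized `q`, while the detached term contributes nothing to backward, so the gradient is the identity. `uniform_quantize` and `ste_quantize` are hypothetical helpers mirroring the uniform scheme used in the next cell.

```python
import torch


def uniform_quantize(x, num_bits=4):
    """Hypothetical helper: uniform quantization over [x.min(), x.max()]."""
    min_val, max_val = x.min(), x.max()
    scale = (max_val - min_val).clamp_min(1e-12) / (2 ** num_bits - 1)
    return torch.round((x - min_val) / scale) * scale + min_val


def ste_quantize(x, num_bits=4):
    # Forward value: the quantized tensor. Gradient: identity (straight-through),
    # because (q - x) is detached and contributes nothing to backward.
    q = uniform_quantize(x.detach(), num_bits)
    return x + (q - x).detach()


x = torch.randn(5, requires_grad=True)
y = ste_quantize(x, num_bits=2)
y.sum().backward()
print(y.detach(), x.grad)  # x.grad is all ones
```

The custom `Function` route below is still useful when the backward should differ from the identity, for example to zero gradients outside the quantization range.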

```python
def create_non_differentiable_operations():
    """Create techniques for handling non-differentiable operations"""
    
    print("\n=== 5.1 Non-Differentiable Operations Mastery ===\n")
    
    class StraightThroughEstimatorFunction(Function):
        """
        Straight-Through Estimator for quantization
        Forward: quantize, Backward: pass gradients through unchanged
        """
        
        @staticmethod
        def forward(ctx, input, num_bits=8, method='uniform'):
            """
            Forward pass: quantize input to specified number of bits
            
            Args:
                input: Input tensor
                num_bits: Number of quantization bits
                method: Quantization method ('uniform' or 'log')
            """
            ctx.method = method
            ctx.num_bits = num_bits
            
            if method == 'uniform':
                # Uniform quantization
                min_val = input.min()
                max_val = input.max()
                
                # Avoid division by zero
                if max_val == min_val:
                    quantized = input.clone()
                else:
                    num_levels = 2 ** num_bits
                    scale = (max_val - min_val) / (num_levels - 1)
                    
                    # Quantize
                    quantized = torch.round((input - min_val) / scale) * scale + min_val
                    quantized = torch.clamp(quantized, min_val, max_val)
                
                # Save quantization statistics
                ctx.quantization_error = (quantized - input).abs().mean().item()
                
            elif method == 'log':
                # Logarithmic quantization (for weights)
                sign = torch.sign(input)
                abs_input = torch.abs(input)
                
                # Avoid log(0)
                eps = 1e-8
                log_input = torch.log(abs_input + eps)
                
                min_log = log_input.min()
                max_log = log_input.max()
                
                if max_log > min_log:
                    num_levels = 2 ** num_bits
                    scale = (max_log - min_log) / (num_levels - 1)
                    quantized_log = torch.round((log_input - min_log) / scale) * scale + min_log
                    quantized = sign * torch.exp(quantized_log)
                else:
                    quantized = input.clone()
                
                ctx.quantization_error = (quantized - input).abs().mean().item()
            
            else:
                raise ValueError(f"Unknown quantization method: {method}")
            
            return quantized
        
        @staticmethod
        def backward(ctx, grad_output):
            """
            Straight-through estimator: pass gradients unchanged
            This assumes quantization error is small enough to ignore
            """
            # Simply pass through the gradient
            return grad_output, None, None
    
    class GumbelSoftmaxFunction(Function):
        """
        Gumbel-Softmax for differentiable discrete sampling
        Provides a continuous relaxation of discrete distributions
        """
        
        @staticmethod
        def forward(ctx, logits, temperature=1.0, hard=False, dim=-1):
            """
            Gumbel-Softmax forward pass
            
            Args:
                logits: Unnormalized log probabilities
                temperature: Temperature parameter (lower = more discrete)
                hard: Whether to use hard (one-hot) or soft sampling
                dim: Dimension to apply softmax
            """
            # Sample Gumbel noise
            gumbel_noise = -torch.log(-torch.log(
                torch.rand_like(logits).clamp(min=1e-10, max=1-1e-10)
            ))
            
            # Add noise and apply temperature
            y = (logits + gumbel_noise) / temperature
            
            # Softmax
            y_soft = F.softmax(y, dim=dim)
            
            if hard:
                # Hard (one-hot) sample in forward. Autograd does not record
                # inside Function.forward, so the straight-through behaviour is
                # implemented in backward, which uses the soft sample's gradient
                index = y_soft.argmax(dim=dim, keepdim=True)
                result = torch.zeros_like(y_soft).scatter_(dim, index, 1.0)
            else:
                result = y_soft
            
            # Backward needs the soft sample in both modes
            ctx.save_for_backward(y_soft)
            ctx.temperature = temperature
            ctx.dim = dim
            
            return result
        
        @staticmethod
        def backward(ctx, grad_output):
            """
            Backward pass for Gumbel-Softmax: the softmax Jacobian-vector
            product, scaled by 1/temperature. The Gumbel noise is an additive
            constant, so the gradient w.r.t. the logits equals the gradient
            w.r.t. (logits + noise).
            """
            y_soft, = ctx.saved_tensors
            dim = ctx.dim
            
            # J^T g for y = softmax(z / T):  (1/T) * y * (g - sum(g * y))
            dot = (grad_output * y_soft).sum(dim=dim, keepdim=True)
            grad_logits = y_soft * (grad_output - dot) / ctx.temperature
            
            return grad_logits, None, None, None
    
    class SoftTopKFunction(Function):
        """
        Differentiable approximation to Top-K operation
        Uses continuous relaxation for gradient flow
        """
        
        @staticmethod
        def forward(ctx, input, k, temperature=1.0, method='sigmoid'):
            """
            Soft Top-K selection
            
            Args:
                input: Input tensor
                k: Number of top elements to select
                temperature: Temperature for softness
                method: Method for soft selection ('sigmoid', 'softmax')
            """
            ctx.k = k
            ctx.temperature = temperature
            ctx.method = method
            
            if method == 'sigmoid':
                # Sigmoid-based soft thresholding around the k-th largest value
                sorted_values, _ = torch.sort(input, dim=-1, descending=True)
                k_eff = min(k, input.size(-1))
                threshold = sorted_values[..., k_eff - 1:k_eff]  # k-th largest value
                
                # Soft mask: ~1 above the threshold, ~0 below, 0.5 at it
                soft_mask = torch.sigmoid((input - threshold) / temperature)
                result = input * soft_mask
                
                ctx.save_for_backward(input, soft_mask)
                
            elif method == 'softmax':
                # Softmax weights concentrate on the largest entries; scaling by k
                # keeps the output magnitude comparable to a hard top-k selection
                softmax_weights = F.softmax(input / temperature, dim=-1)
                result = input * softmax_weights * k
                
                ctx.save_for_backward(input, softmax_weights)
            
            else:
                raise ValueError(f"Unknown method: {method}")
            
            return result
        
        @staticmethod
        def backward(ctx, grad_output):
            """
            Backward pass for soft Top-K
            """
            temperature = ctx.temperature
            
            if ctx.method == 'sigmoid':
                input, soft_mask = ctx.saved_tensors
                
                # d/dx [x * sigmoid((x - t) / T)], treating the threshold t as a
                # constant (its dependence on the input is ignored)
                sigmoid_grad = soft_mask * (1 - soft_mask) / temperature
                grad_input = grad_output * (soft_mask + input * sigmoid_grad)
                
            else:  # 'softmax'
                input, softmax_weights = ctx.saved_tensors
                k = ctx.k
                
                # f = k * x * s with s = softmax(x / T):
                #   grad = k * (g * s + s * (v - sum(v * s)) / T),  where v = g * x
                v = grad_output * input
                dot = (v * softmax_weights).sum(dim=-1, keepdim=True)
                grad_input = k * (grad_output * softmax_weights
                                  + softmax_weights * (v - dot) / temperature)
            
            return grad_input, None, None, None
    
    class DifferentiableRoundingFunction(Function):
        """
        Differentiable approximation to rounding operation
        Uses smooth approximations for gradient flow
        """
        
        @staticmethod
        def forward(ctx, input, method='sigmoid', sharpness=10.0):
            """
            Differentiable rounding
            
            Args:
                input: Input tensor
                method: Approximation method ('sigmoid', 'tanh', 'polynomial')
                sharpness: Controls approximation sharpness (ignored by 'polynomial')
            """
            ctx.method = method
            ctx.sharpness = sharpness
            
            if method == 'sigmoid':
                # Sigmoid-based approximation
                fractional_part = input - torch.floor(input)
                smooth_round = torch.sigmoid(sharpness * (fractional_part - 0.5))
                result = torch.floor(input) + smooth_round
                
            elif method == 'tanh':
                # Tanh-based approximation  
                fractional_part = input - torch.floor(input)
                smooth_round = 0.5 * (1 + torch.tanh(sharpness * (fractional_part - 0.5)))
                result = torch.floor(input) + smooth_round
                
            elif method == 'polynomial':
                # Polynomial approximation (3rd order)
                fractional_part = input - torch.floor(input)
                # Smooth step function: 3t² - 2t³
                smooth_step = 3 * fractional_part**2 - 2 * fractional_part**3
                result = torch.floor(input) + smooth_step
                
            else:
                raise ValueError(f"Unknown method: {method}")
            
            ctx.save_for_backward(input, result)
            
            return result
        
        @staticmethod
        def backward(ctx, grad_output):
            """
            Backward pass using smooth derivative
            """
            input, result = ctx.saved_tensors
            method = ctx.method
            sharpness = ctx.sharpness
            
            if method == 'sigmoid':
                fractional_part = input - torch.floor(input)
                sigmoid_val = torch.sigmoid(sharpness * (fractional_part - 0.5))
                grad_smooth = sharpness * sigmoid_val * (1 - sigmoid_val)
                
            elif method == 'tanh':
                fractional_part = input - torch.floor(input)
                tanh_val = torch.tanh(sharpness * (fractional_part - 0.5))
                grad_smooth = 0.5 * sharpness * (1 - tanh_val**2)
                
            elif method == 'polynomial':
                fractional_part = input - torch.floor(input)
                grad_smooth = 6 * fractional_part * (1 - fractional_part)
            
            grad_input = grad_output * grad_smooth
            
            return grad_input, None, None
    
    # Wrapper functions
    def straight_through_quantizer(input, num_bits=8, method='uniform'):
        return StraightThroughEstimatorFunction.apply(input, num_bits, method)
    
    def gumbel_softmax(logits, temperature=1.0, hard=False, dim=-1):
        return GumbelSoftmaxFunction.apply(logits, temperature, hard, dim)
    
    def soft_topk(input, k, temperature=1.0, method='sigmoid'):
        return SoftTopKFunction.apply(input, k, temperature, method)
    
    def differentiable_round(input, method='sigmoid', sharpness=10.0):
        return DifferentiableRoundingFunction.apply(input, method, sharpness)
    
    return straight_through_quantizer, gumbel_softmax, soft_topk, differentiable_round

# Create non-differentiable operation handlers
straight_through_quantizer, gumbel_softmax, soft_topk, differentiable_round = create_non_differentiable_operations()

print("🚫 Testing Non-Differentiable Operations:")
print("-" * 50)

# Test quantization with straight-through estimator
print("\n📊 Testing Quantization with Straight-Through Estimator:")

test_tensor = torch.randn(100, 50, requires_grad=True)
print(f"Original range: [{test_tensor.min():.3f}, {test_tensor.max():.3f}]")

quantization_results = {}

for num_bits in [2, 4, 8, 16]:
    if test_tensor.grad is not None:
        test_tensor.grad.zero_()
    
    # Test uniform quantization
    quantized = straight_through_quantizer(test_tensor, num_bits=num_bits, method='uniform')
    quantized.sum().backward()
    
    quantization_error = (quantized - test_tensor).abs().mean().item()
    gradient_flow = test_tensor.grad.norm().item()
    
    quantization_results[f'{num_bits}_bit'] = {
        'quantization_error': quantization_error,
        'gradient_norm': gradient_flow,
        'output_range': [quantized.min().item(), quantized.max().item()]
    }
    
    print(f"  {num_bits}-bit: Error={quantization_error:.4f}, "
          f"Grad norm={gradient_flow:.4f}")

# Test Gumbel-Softmax
print(f"\n📊 Testing Gumbel-Softmax:")

logits = torch.randn(32, 10, requires_grad=True)
temperatures = [2.0, 1.0, 0.5, 0.1]

gumbel_results = {}

# Fixed random weights for the test loss: the entries of a softmax always sum
# to 1, so the exact gradient of a plain .sum() with respect to the logits is zero
probe_weights = torch.randn(10)

for temp in temperatures:
    if logits.grad is not None:
        logits.grad.zero_()
    
    # Soft and hard samples (each call draws fresh Gumbel noise)
    soft_sample = gumbel_softmax(logits, temperature=temp, hard=False)
    hard_sample = gumbel_softmax(logits, temperature=temp, hard=True)
    
    (soft_sample * probe_weights).sum().backward()
    soft_grad_norm = logits.grad.norm().item()
    
    logits.grad.zero_()
    (hard_sample * probe_weights).sum().backward()
    hard_grad_norm = logits.grad.norm().item()
    
    with torch.no_grad():
        # Entropy of the soft sample (lower = closer to one-hot)
        soft_entropy = -(soft_sample * torch.log(soft_sample + 1e-10)).sum(dim=-1).mean().item()
        # Active categories per hard sample (exactly 1.0 for one-hot samples)
        hard_sparsity = (hard_sample > 0.5).float().sum(dim=-1).mean().item()
    
    gumbel_results[f'temp_{temp}'] = {
        'soft_entropy': soft_entropy,
        'hard_sparsity': hard_sparsity,
        'soft_grad_norm': soft_grad_norm,
        'hard_grad_norm': hard_grad_norm
    }
    
    print(f"  Temp {temp}: Entropy={soft_entropy:.3f}, "
          f"Active/sample={hard_sparsity:.1f}, "
          f"Soft grad={soft_grad_norm:.3f}, "
          f"Hard grad={hard_grad_norm:.3f}")

# Test Soft Top-K
print(f"\n📊 Testing Soft Top-K:")

topk_input = torch.randn(16, 20, requires_grad=True)
k_values = [3, 5, 10]

topk_results = {}

for k in k_values:
    if topk_input.grad is not None:
        topk_input.grad.zero_()
    
    # Compare hard vs soft top-k
    hard_topk_vals, hard_topk_indices = torch.topk(topk_input, k, dim=-1)
    soft_topk_result = soft_topk(topk_input, k, temperature=0.1, method='sigmoid')
    
    soft_topk_result.sum().backward()
    
    # Analyze selection properties
    soft_magnitude = soft_topk_result.norm(dim=-1).mean().item()
    hard_magnitude = hard_topk_vals.norm(dim=-1).mean().item()
    
    topk_results[f'k_{k}'] = {
        'soft_magnitude': soft_magnitude,
        'hard_magnitude': hard_magnitude,
        'gradient_norm': topk_input.grad.norm().item(),
        'selection_ratio': (soft_topk_result.abs() > 1e-3).float().mean().item()
    }
    
    print(f"  k={k}: Soft mag={soft_magnitude:.3f}, "
          f"Hard mag={hard_magnitude:.3f}, "
          f"Selection ratio={topk_results[f'k_{k}']['selection_ratio']:.3f}")

# Test Differentiable Rounding
print(f"\n📊 Testing Differentiable Rounding:")

round_input = (torch.randn(100) * 5).requires_grad_()  # leaf tensor, so .grad is populated; most values fall in [-15, 15]
methods = ['sigmoid', 'tanh', 'polynomial']

rounding_results = {}

for method in methods:
    if round_input.grad is not None:
        round_input.grad.zero_()
    
    rounded = differentiable_round(round_input, method=method, sharpness=10.0)
    rounded.sum().backward()
    
    # Compare with true rounding
    true_rounded = torch.round(round_input.detach())
    rounding_error = (rounded - true_rounded).abs().mean().item()
    
    rounding_results[method] = {
        'rounding_error': rounding_error,
        'gradient_norm': round_input.grad.norm().item(),
        'output_range': [rounded.min().item(), rounded.max().item()]
    }
    
    print(f"  {method}: Error={rounding_error:.4f}, "
          f"Grad norm={round_input.grad.norm():.4f}")

# Create comprehensive visualization
def create_non_differentiable_visualization():
    """Visualize non-differentiable operations and their approximations"""
    
    fig, axes = plt.subplots(3, 3, figsize=(18, 15))
    
    # 1. Quantization effects
    x_quant = torch.linspace(-3, 3, 1000)
    
    for i, bits in enumerate([2, 4, 8]):
        with torch.no_grad():
            quantized = straight_through_quantizer(x_quant, num_bits=bits)
        
        axes[0, i].plot(x_quant, x_quant, 'b--', alpha=0.5, label='Original')
        axes[0, i].plot(x_quant, quantized, 'r-', linewidth=2, label=f'{bits}-bit')
        axes[0, i].set_title(f'{bits}-bit Quantization', fontweight='bold')
        axes[0, i].set_xlabel('Input')
        axes[0, i].set_ylabel('Output')
        axes[0, i].legend()
        axes[0, i].grid(True, alpha=0.3)
    
    # 2. Gumbel-Softmax temperature effects
    logits_example = torch.tensor([[2.0, 1.0, 0.5, 0.1, -0.5]])
    temps = [2.0, 1.0, 0.1]
    
    for i, temp in enumerate(temps):
        with torch.no_grad():
            probs = gumbel_softmax(logits_example, temperature=temp, hard=False)[0]
        
        axes[1, i].bar(range(len(probs)), probs, alpha=0.7, 
                      color=plt.cm.viridis(i/len(temps)))
        axes[1, i].set_title(f'Gumbel-Softmax T={temp}', fontweight='bold')
        axes[1, i].set_xlabel('Category')
        axes[1, i].set_ylabel('Probability')
        axes[1, i].grid(True, alpha=0.3)
    
    # 3. Soft Top-K selection
    input_vals = torch.tensor([3.0, 1.5, 4.2, 0.8, 2.1, 3.8, 1.2, 2.9])
    k_vals = [2, 3, 5]
    
    for i, k in enumerate(k_vals):
        with torch.no_grad():
            soft_result = soft_topk(input_vals.unsqueeze(0), k, temperature=0.1)[0]
            hard_topk_vals, _ = torch.topk(input_vals, k)
        
        x_pos = range(len(input_vals))
        axes[2, i].bar(x_pos, input_vals, alpha=0.5, label='Original', color='blue')
        axes[2, i].bar(x_pos, soft_result, alpha=0.8, label=f'Soft Top-{k}', color='red')
        
        axes[2, i].set_title(f'Soft Top-{k} Selection', fontweight='bold')
        axes[2, i].set_xlabel('Index')
        axes[2, i].set_ylabel('Value')
        axes[2, i].legend()
        axes[2, i].grid(True, alpha=0.3)
    
    plt.suptitle('Non-Differentiable Operations Analysis', 
                 fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.savefig(results_dir / 'non_differentiable_operations.png', 
                dpi=300, bbox_inches='tight')
    plt.show()

# Create visualization
create_non_differentiable_visualization()

# Save non-differentiable operations results
non_diff_results = {
    'quantization': quantization_results,
    'gumbel_softmax': gumbel_results,
    'soft_topk': topk_results,
    'differentiable_rounding': rounding_results
}

with open(results_dir / 'non_differentiable_operations_results.json', 'w') as f:
    json.dump(non_diff_results, f, indent=2)

print(f"\n💾 Non-differentiable operations results saved")
print(f"\n🎓 Key Techniques Summary:")
print(f"  • Straight-Through Estimator: Pass gradients unchanged")
print(f"  • Gumbel-Softmax: Continuous relaxation of discrete sampling")
print(f"  • Soft approximations: Replace hard operations with smooth versions")
print(f"  • Temperature annealing: Start soft, gradually make harder")
```
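The summary above lists temperature annealing but does not show a schedule. Here is a minimal sketch. It assumes an exponential decay clamped at a floor; the helper `annealed_temperature` and its constants are hypothetical and are not values from this notebook. It also uses PyTorch's built-in `torch.nn.functional.gumbel_softmax` instead of the custom `gumbel_softmax` defined earlier:

```python
import math

import torch
import torch.nn.functional as F


def annealed_temperature(step, tau_start=5.0, tau_min=0.5, decay_rate=1e-3):
    """Hypothetical schedule: tau = max(tau_min, tau_start * exp(-decay_rate * step))."""
    return max(tau_min, tau_start * math.exp(-decay_rate * step))


torch.manual_seed(0)
logits = torch.randn(4, 6)

for step in [0, 500, 1000, 2000, 5000]:
    tau = annealed_temperature(step)
    # Soft relaxation: each row sums to 1 and becomes more peaked as tau falls
    soft = F.gumbel_softmax(logits, tau=tau, hard=False)
    mean_max_prob = soft.max(dim=-1).values.mean().item()
    print(f"step={step:5d}  tau={tau:.3f}  mean max prob={mean_max_prob:.3f}")
```

Late in training, you can pair `hard=True` with the annealed `tau`. The forward pass then emits one-hot samples, and gradients still flow through the soft relaxation.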

## 6. Comprehensive Testing and Validation Framework

### 6.1 Advanced Gradient Checking and Validation
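Check one pitfall before relying on `output.sum()` as the scalar to differentiate. When a function's outputs are constrained (softmax rows always sum to 1, for example), the gradient of their sum is identically zero. Analytical and numerical gradients then agree trivially. This minimal sketch uses `torch.softmax` as a stand-in for any row-normalized output. It shows the problem and one common remedy: contracting the output with a fixed random tensor.

```python
import torch

torch.manual_seed(0)
x = torch.randn(3, 5, dtype=torch.float64, requires_grad=True)

# Summing softmax outputs gives the batch size regardless of x, so the gradient vanishes
torch.softmax(x, dim=-1).sum().backward()
sum_grad = x.grad.clone()
x.grad = None

# Contracting with a fixed random tensor keeps every output direction in play
weights = torch.randn(3, 5, dtype=torch.float64)
(torch.softmax(x, dim=-1) * weights).sum().backward()
projected_grad = x.grad.clone()

print(f"max |grad| with .sum():         {sum_grad.abs().max().item():.2e}")
print(f"max |grad| with random weights: {projected_grad.abs().max().item():.2e}")
```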

```python
def create_comprehensive_testing_framework():
    """Create advanced testing framework for custom autograd functions"""
    
    print("\n=== 6.1 Comprehensive Testing and Validation Framework ===\n")
    
    class GradientChecker:
        """Advanced gradient checking with multiple validation methods"""
        
        def __init__(self, eps=1e-5, tolerance=1e-4, verbose=True):
            self.eps = eps
            self.tolerance = tolerance
            self.verbose = verbose
            self.test_results = {}
        
        def numerical_gradient(self, func, inputs, eps=None):
            """
            Compute numerical gradient using finite differences
            
            Args:
                func: Function to compute gradient for
                inputs: List of input tensors
                eps: Finite difference step size
                
            Returns:
                List of numerical gradients
            """
            if eps is None:
                eps = self.eps
            
            numerical_grads = []
            
            # Perturb the inputs in place without recording the edits in the autograd
            # graph; in-place updates to a leaf that requires grad raise otherwise
            with torch.no_grad():
                for input_tensor in inputs:
                    grad = torch.zeros_like(input_tensor)
                    
                    # Flat views share storage with the inputs (tensors must be contiguous)
                    flat_input = input_tensor.view(-1)
                    flat_grad = grad.view(-1)
                    
                    for i in range(flat_input.numel()):
                        original_value = flat_input[i].item()
                        
                        # Positive perturbation (func sees the edited tensor directly)
                        flat_input[i] = original_value + eps
                        try:
                            loss_pos = func(*inputs).sum()
                        except Exception as e:
                            if self.verbose:
                                print(f"Warning: Error in positive perturbation: {e}")
                            loss_pos = torch.tensor(0.0)
                        
                        # Negative perturbation
                        flat_input[i] = original_value - eps
                        try:
                            loss_neg = func(*inputs).sum()
                        except Exception as e:
                            if self.verbose:
                                print(f"Warning: Error in negative perturbation: {e}")
                            loss_neg = torch.tensor(0.0)
                        
                        # Central difference
                        flat_grad[i] = (loss_pos - loss_neg) / (2 * eps)
                        
                        # Restore the exact original value
                        flat_input[i] = original_value
                    
                    numerical_grads.append(grad)
            
            return numerical_grads
        
        def check_gradients(self, func, inputs, test_name="", output_shape=None):
            """
            Comprehensive gradient checking
            
            Args:
                func: Custom function to test
                inputs: List of input tensors
                test_name: Name for the test
                output_shape: Expected output shape
                
            Returns:
                Dictionary with test results
            """
            if self.verbose:
                print(f"🔍 Testing function: {test_name}")
                print(f"📊 Input shapes: {[inp.shape for inp in inputs]}")
            
            # Ensure inputs require gradients
            for inp in inputs:
                inp.requires_grad_(True)
                if inp.grad is not None:
                    inp.grad.zero_()
            
            # Test forward pass
            try:
                output = func(*inputs)
                forward_success = True
                
                if output_shape and output.shape != output_shape:
                    print(f"⚠️ Warning: Expected shape {output_shape}, got {output.shape}")
                
            except Exception as e:
                if self.verbose:
                    print(f"❌ Forward pass failed: {e}")
                return {'passed': False, 'error': f"Forward pass failed: {e}"}
            
            # Test backward pass
            try:
                loss = output.sum()
                loss.backward()
                backward_success = True
                
                analytical_grads = [inp.grad.clone() for inp in inputs]
                
            except Exception as e:
                if self.verbose:
                    print(f"❌ Backward pass failed: {e}")
                return {'passed': False, 'error': f"Backward pass failed: {e}"}
            
            # Clear gradients for numerical computation
            for inp in inputs:
                inp.grad.zero_()
            
            # Compute numerical gradients
            if self.verbose:
                print("🧮 Computing numerical gradients...")
            
            numerical_grads = self.numerical_gradient(func, inputs)
            
            # Compare gradients
            results = {
                'passed': True,
                'max_error': 0.0,
                'relative_errors': [],
                'absolute_errors': [],
                'input_analyses': []
            }
            
            for i, (analytical, numerical) in enumerate(zip(analytical_grads, numerical_grads)):
                # Handle zero gradients
                abs_error = (analytical - numerical).abs()
                rel_error = abs_error / (numerical.abs() + self.eps)
                
                max_abs_error = abs_error.max().item()
                max_rel_error = rel_error.max().item()
                mean_abs_error = abs_error.mean().item()
                mean_rel_error = rel_error.mean().item()
                
                # Check for NaN or Inf
                has_nan = torch.isnan(analytical).any() or torch.isnan(numerical).any()
                has_inf = torch.isinf(analytical).any() or torch.isinf(numerical).any()
                
                input_analysis = {
                    'input_index': i,
                    'max_absolute_error': max_abs_error,
                    'max_relative_error': max_rel_error,
                    'mean_absolute_error': mean_abs_error,
                    'mean_relative_error': mean_rel_error,
                    'has_nan': has_nan,
                    'has_inf': has_inf,
                    'analytical_norm': analytical.norm().item(),
                    'numerical_norm': numerical.norm().item()
                }
                
                results['input_analyses'].append(input_analysis)
                results['absolute_errors'].append(max_abs_error)
                results['relative_errors'].append(max_rel_error)
                results['max_error'] = max(results['max_error'], max_abs_error)
                
                # Determine pass/fail
                passed = (max_abs_error < self.tolerance and 
                         max_rel_error < self.tolerance and 
                         not has_nan and not has_inf)
                
                if not passed:
                    results['passed'] = False
                
                if self.verbose:
                    status = "✅ PASS" if passed else "❌ FAIL"
                    print(f"  Input {i}: Max abs={max_abs_error:.2e}, "
                          f"Max rel={max_rel_error:.2e} - {status}")
            
            overall_status = "✅ PASSED" if results['passed'] else "❌ FAILED"
            if self.verbose:
                print(f"\nOverall: {overall_status} (tolerance: {self.tolerance:.0e})")
            
            # Store results
            self.test_results[test_name] = results
            
            return results
        
        def performance_benchmark(self, func, inputs, num_iterations=1000):
            """
            Benchmark performance of custom function
            
            Args:
                func: Function to benchmark
                inputs: Input tensors
                num_iterations: Number of iterations for timing
                
            Returns:
                Performance metrics
            """
            if self.verbose:
                print(f"⚡ Performance benchmarking ({num_iterations} iterations)...")
            
            # Warmup
            for _ in range(10):
                output = func(*inputs)
                output.sum().backward()
                for inp in inputs:
                    if inp.grad is not None:
                        inp.grad.zero_()
            
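            # Forward pass timing (perf_counter is monotonic and high-resolution;
            # on CUDA, call torch.cuda.synchronize() before reading the clock)
            start_time = time.perf_counter()
            for _ in range(num_iterations):
                output = func(*inputs)
            forward_time = (time.perf_counter() - start_time) / num_iterations
            
            # Forward + backward timing; the backward cost is the difference
            start_time = time.perf_counter()
            for _ in range(num_iterations):
                output = func(*inputs)
                output.sum().backward()
                for inp in inputs:
                    if inp.grad is not None:
                        inp.grad.zero_()
            combined_time = (time.perf_counter() - start_time) / num_iterations
            # Clamp at zero: timing noise can make the difference slightly negative
            backward_time = max(combined_time - forward_time, 0.0)
            
            performance_metrics = {
                'forward_time_ms': forward_time * 1000,
                'backward_time_ms': backward_time * 1000,
                'total_time_ms': (forward_time + backward_time) * 1000,
                # Size of the input tensors only, not peak memory during the call
                'memory_usage_mb': sum(inp.numel() * inp.element_size() 
                                     for inp in inputs) / (1024**2)
            }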
            
            if self.verbose:
                print(f"  Forward: {forward_time*1000:.3f}ms")
                print(f"  Backward: {backward_time*1000:.3f}ms")
                print(f"  Memory: {performance_metrics['memory_usage_mb']:.1f}MB")
            
            return performance_metrics
        
        def comprehensive_test_suite(self, function_tests):
            """
            Run comprehensive test suite on multiple functions
            
            Args:
                function_tests: List of (func, inputs, name) tuples
                
            Returns:
                Complete test suite results
            """
            print(f"🧪 Running Comprehensive Test Suite")
            print(f"=" * 60)
            
            suite_results = {}
            
            for func, inputs, name in function_tests:
                print(f"\n📋 Testing: {name}")
                print(f"-" * 40)
                
                # Gradient checking
                grad_results = self.check_gradients(func, inputs, name)
                
                # Performance benchmarking
                perf_results = self.performance_benchmark(func, inputs)
                
                # Combine results
                suite_results[name] = {
                    'gradient_check': grad_results,
                    'performance': perf_results,
                    'passed': grad_results.get('passed', False)
                }
            
            # Summary
            print(f"\n📊 Test Suite Summary:")
            print(f"=" * 40)
            
            total_tests = len(function_tests)
            passed_tests = sum(1 for result in suite_results.values() 
                             if result['passed'])
            
            print(f"Total tests: {total_tests}")
            print(f"Passed: {passed_tests}")
            print(f"Failed: {total_tests - passed_tests}")
            print(f"Success rate: {passed_tests/total_tests*100:.1f}%")
            
            return suite_results
    
    return GradientChecker

# Create comprehensive testing framework
GradientChecker = create_comprehensive_testing_framework()

print("🧪 Comprehensive Function Testing:")
print("-" * 50)

# Initialize gradient checker
checker = GradientChecker(eps=1e-5, tolerance=1e-4, verbose=True)

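# Define test functions to validate
# Inputs are float64: at eps=1e-5, float32 finite differences carry errors well
# above the 1e-4 tolerance, so failures would reflect precision rather than bugs
test_functions = [
    # Basic functions
    (lambda x: x**2, [torch.randn(5, 3, dtype=torch.float64, requires_grad=True)],
     "Square Function"),
    
    # Multi-input functions
    (lambda x, w: torch.sum(x * w, dim=1), 
     [torch.randn(4, 6, dtype=torch.float64, requires_grad=True),
      torch.randn(6, dtype=torch.float64, requires_grad=True)], 
     "Weighted Sum"),
    
    # Custom activations
    (swish, [torch.randn(10, dtype=torch.float64, requires_grad=True)], "Swish Activation"),
    (mish, [torch.randn(10, dtype=torch.float64, requires_grad=True)], "Mish Activation"),
    
    # Memory-efficient operations
    (lambda x, w: memory_efficient_matmul(x, w, save_memory=False),
     [torch.randn(8, 10, dtype=torch.float64, requires_grad=True),
      torch.randn(10, 5, dtype=torch.float64, requires_grad=True)],
     "Memory Efficient MatMul"),
    
    # Non-differentiable operations
    # The STE passes a gradient of 1, while the numerical gradient of the quantization
    # staircase is 0 almost everywhere, so this check is expected to FAIL by design
    (lambda x: straight_through_quantizer(x, num_bits=8),
     [torch.randn(6, 4, dtype=torch.float64, requires_grad=True)], 
     "Straight-Through Quantizer"),
    
    # Each softmax row sums to 1, so the .sum() loss used by check_gradients has zero
    # gradient: this check passes trivially and does not exercise the backward pass
    (lambda x: gumbel_softmax(x, temperature=1.0, hard=False),
     [torch.randn(3, 5, dtype=torch.float64, requires_grad=True)],
     "Gumbel Softmax"),
]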

# Run comprehensive test suite
test_suite_results = checker.comprehensive_test_suite(test_functions)

# Create detailed analysis visualization
def create_testing_analysis_visualization(suite_results):
    """Create comprehensive visualization of testing results"""
    
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    
    # Extract data for visualization
    function_names = list(suite_results.keys())
    # Floor at a tiny value so zero errors stay visible on the log-scale axes
    max_errors = [max(suite_results[name]['gradient_check'].get('max_error', 0), 1e-16)
                  for name in function_names]
    forward_times = [suite_results[name]['performance']['forward_time_ms'] 
                    for name in function_names]
    backward_times = [suite_results[name]['performance']['backward_time_ms'] 
                     for name in function_names]
    memory_usage = [suite_results[name]['performance']['memory_usage_mb'] 
                   for name in function_names]
    passed_status = [suite_results[name]['passed'] for name in function_names]
    
    # 1. Gradient error analysis
    colors = ['green' if passed else 'red' for passed in passed_status]
    bars1 = axes[0,0].bar(range(len(function_names)), max_errors, color=colors, alpha=0.7)
    axes[0,0].set_title('Maximum Gradient Errors', fontweight='bold')
    axes[0,0].set_ylabel('Max Error')
    axes[0,0].set_yscale('log')
    axes[0,0].set_xticks(range(len(function_names)))
    axes[0,0].set_xticklabels(function_names, rotation=45, ha='right')
    axes[0,0].grid(True, alpha=0.3)
    
    # Add tolerance line
    axes[0,0].axhline(y=checker.tolerance, color='red', linestyle='--', 
                     label=f'Tolerance: {checker.tolerance:.0e}')
    axes[0,0].legend()
    
    # 2. Forward pass performance
    axes[0,1].bar(range(len(function_names)), forward_times, alpha=0.7, color='blue')
    axes[0,1].set_title('Forward Pass Performance', fontweight='bold')
    axes[0,1].set_ylabel('Time (ms)')
    axes[0,1].set_xticks(range(len(function_names)))
    axes[0,1].set_xticklabels(function_names, rotation=45, ha='right')
    axes[0,1].grid(True, alpha=0.3)
    
    # 3. Backward pass performance
    axes[0,2].bar(range(len(function_names)), backward_times, alpha=0.7, color='orange')
    axes[0,2].set_title('Backward Pass Performance', fontweight='bold')
    axes[0,2].set_ylabel('Time (ms)')
    axes[0,2].set_xticks(range(len(function_names)))
    axes[0,2].set_xticklabels(function_names, rotation=45, ha='right')
    axes[0,2].grid(True, alpha=0.3)
    
    # 4. Memory usage
    axes[1,0].bar(range(len(function_names)), memory_usage, alpha=0.7, color='purple')
    axes[1,0].set_title('Memory Usage', fontweight='bold')
    axes[1,0].set_ylabel('Memory (MB)')
    axes[1,0].set_xticks(range(len(function_names)))
    axes[1,0].set_xticklabels(function_names, rotation=45, ha='right')
    axes[1,0].grid(True, alpha=0.3)
    
    # 5. Pass/Fail summary
    pass_counts = [sum(passed_status), len(passed_status) - sum(passed_status)]
    labels = ['Passed', 'Failed']
    colors_pie = ['green', 'red']
    
    axes[1,1].pie(pass_counts, labels=labels, colors=colors_pie, autopct='%1.1f%%',
                 startangle=90)
    axes[1,1].set_title('Test Results Summary', fontweight='bold')
    
    # 6. Performance vs Accuracy scatter
    axes[1,2].scatter(forward_times, max_errors, c=colors, alpha=0.7, s=100)
    axes[1,2].set_xlabel('Forward Time (ms)')
    axes[1,2].set_ylabel('Max Gradient Error')
    axes[1,2].set_yscale('log')
    axes[1,2].set_title('Performance vs Accuracy', fontweight='bold')
    axes[1,2].grid(True, alpha=0.3)
    
    # Add function labels
    for i, name in enumerate(function_names):
        axes[1,2].annotate(name[:8], (forward_times[i], max_errors[i]), 
                          fontsize=8, ha='center')
    
    plt.suptitle('Comprehensive Function Testing Analysis', 
                 fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.savefig(results_dir / 'comprehensive_testing_analysis.png', 
                dpi=300, bbox_inches='tight')
    plt.show()

# Create testing visualization
create_testing_analysis_visualization(test_suite_results)

# Save comprehensive testing results
with open(results_dir / 'comprehensive_testing_results.json', 'w') as f:
    # Keep only JSON-serializable scalar fields (input_analyses holds tensor flags)
    json_results = {}
    for name, result in test_suite_results.items():
        json_results[name] = {
            'gradient_check': {
                'passed': result['gradient_check'].get('passed', False),
                'max_error': float(result['gradient_check'].get('max_error', 0)),
                'num_inputs': len(result['gradient_check'].get('absolute_errors', []))
            },
            'performance': {
                'forward_time_ms': float(result['performance']['forward_time_ms']),
                'backward_time_ms': float(result['performance']['backward_time_ms']),
                'memory_usage_mb': float(result['performance']['memory_usage_mb'])
            },
            'overall_passed': result['passed']
        }
    
    json.dump(json_results, f, indent=2)

print(f"\n💾 Comprehensive testing results saved")
print(f"\n🎓 Testing Framework Insights:")
print(f"  • Numerical gradient checking validates custom implementations")
print(f"  • Performance benchmarking identifies bottlenecks")
print(f"  • Comprehensive testing ensures production readiness")
print(f"  • Error analysis guides optimization efforts")
```
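You can cross-check the hand-rolled `GradientChecker` against PyTorch's built-in `torch.autograd.gradcheck`. By default, `gradcheck` compares the full numerical and analytical Jacobians, so constrained outputs such as softmax rows cannot hide errors behind a `.sum()` loss. It also expects double-precision inputs. This minimal sketch uses a standalone `SwishFunction` written here for illustration, not the notebook's `swish`:

```python
import torch
from torch.autograd import Function, gradcheck


class SwishFunction(Function):
    """Swish(x) = x * sigmoid(x), written as a custom Function for illustration."""

    @staticmethod
    def forward(ctx, x):
        sigmoid_x = torch.sigmoid(x)
        ctx.save_for_backward(x, sigmoid_x)
        return x * sigmoid_x

    @staticmethod
    def backward(ctx, grad_output):
        x, sigmoid_x = ctx.saved_tensors
        # d/dx [x * s(x)] = s(x) + x * s(x) * (1 - s(x))
        return grad_output * (sigmoid_x + x * sigmoid_x * (1 - sigmoid_x))


torch.manual_seed(0)
# gradcheck expects double precision; float32 finite differences are too noisy
x = torch.randn(8, 4, dtype=torch.float64, requires_grad=True)
ok = gradcheck(SwishFunction.apply, (x,), eps=1e-6, atol=1e-4)
print(f"gradcheck passed: {ok}")
```

`torch.autograd.gradgradcheck` applies the same test to second derivatives. That matters when a custom backward must itself be differentiable.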

## 7. Advanced Loss Functions and Specialized Operations

### 7.1 Production-Ready Custom Loss Functions
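The hand-written focal-loss backward below depends on a chain-rule derivation. Take p = softmax(z) and ∂p_t/∂z_j = p_t(δ_tj - p_j). Then the per-sample gradient is ∂FL/∂z_j = A(δ_tj - p_j), where A = αγ(1 - p_t)^(γ-1) p_t log(p_t) - α(1 - p_t)^γ. This minimal sketch checks that closed form against autograd applied to a plain PyTorch forward. The helper names are illustrative:

```python
import torch
import torch.nn.functional as F


def focal_loss_reference(logits, target, alpha=1.0, gamma=2.0):
    """Per-sample focal loss built only from differentiable PyTorch ops."""
    log_probs = F.log_softmax(logits, dim=-1)
    target_log_probs = log_probs.gather(1, target.unsqueeze(1)).squeeze(1)
    target_probs = target_log_probs.exp()
    return -alpha * (1 - target_probs) ** gamma * target_log_probs


def focal_loss_closed_form_grad(logits, target, alpha=1.0, gamma=2.0):
    """Closed-form d(FL_i)/d(logits_i) = A_i * (one_hot_i - p_i)."""
    probs = F.softmax(logits, dim=-1)
    p_t = probs.gather(1, target.unsqueeze(1)).squeeze(1)
    one_minus_pt = 1 - p_t
    coeff = alpha * (gamma * one_minus_pt ** (gamma - 1) * p_t * torch.log(p_t)
                     - one_minus_pt ** gamma)
    one_hot = F.one_hot(target, num_classes=logits.size(1)).to(logits.dtype)
    return coeff.unsqueeze(1) * (one_hot - probs)


torch.manual_seed(0)
logits = torch.randn(6, 4, dtype=torch.float64, requires_grad=True)
target = torch.randint(0, 4, (6,))

# Each sample's loss depends only on its own row, so summing keeps rows separate
focal_loss_reference(logits, target).sum().backward()
closed_form = focal_loss_closed_form_grad(logits.detach(), target)
max_diff = (logits.grad - closed_form).abs().max().item()
print(f"max |autograd - closed form| = {max_diff:.2e}")
```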

```python
def create_advanced_loss_functions():
    """Create sophisticated loss functions for specialized applications"""
    
    print("\n=== 7.1 Advanced Custom Loss Functions ===\n")
    
    class FocalLossFunction(Function):
        """
        Focal Loss for addressing class imbalance
        Paper: "Focal Loss for Dense Object Detection" (Lin et al., 2017)
        """
        
        @staticmethod
        def forward(ctx, input, target, alpha=1.0, gamma=2.0, reduction='mean'):
            """
            Focal Loss: FL(p_t) = -α_t(1-p_t)^γ log(p_t)
            
            Args:
                input: Logits [batch_size, num_classes]
                target: Target classes [batch_size]
                alpha: Weighting factor for rare class
                gamma: Focusing parameter
                reduction: 'mean', 'sum', or 'none'
            """
            # Compute softmax probabilities
            log_probs = F.log_softmax(input, dim=-1)
            probs = torch.exp(log_probs)
            
            # Get probabilities and log probabilities for target classes
            target_log_probs = log_probs.gather(1, target.unsqueeze(1)).squeeze(1)
            target_probs = probs.gather(1, target.unsqueeze(1)).squeeze(1)
            
            # Compute focal weights
            focal_weights = alpha * (1 - target_probs) ** gamma
            
            # Compute focal loss
            focal_losses = -focal_weights * target_log_probs
            
            if reduction == 'mean':
                result = focal_losses.mean()
            elif reduction == 'sum':
                result = focal_losses.sum()
            else:
                result = focal_losses
            
            # Save for backward pass
            ctx.save_for_backward(probs, target, focal_weights, target_probs)
            ctx.alpha = alpha
            ctx.gamma = gamma
            ctx.reduction = reduction
            ctx.batch_size = input.size(0)
            ctx.num_classes = input.size(1)
            
            return result
        
        @staticmethod
        def backward(ctx, grad_output):
            """
            Vectorized focal loss gradient with respect to the logits.
            
            With p = softmax(z) and dp_t/dz_j = p_t * (δ_tj - p_j), the chain rule gives
                dFL/dz_j = A * (δ_tj - p_j)
                A = α γ (1 - p_t)^(γ-1) p_t log(p_t) - α (1 - p_t)^γ
            """
            probs, target, focal_weights, target_probs = ctx.saved_tensors
            alpha = ctx.alpha
            gamma = ctx.gamma
            reduction = ctx.reduction
            batch_size = ctx.batch_size
            
            # One-hot encoding of the targets
            target_one_hot = torch.zeros_like(probs)
            target_one_hot.scatter_(1, target.unsqueeze(1), 1)
            
            # Clamp so log(p_t) and (1 - p_t)^(γ-1) stay finite when p_t is 0 or 1;
            # focal_weights already equals α (1 - p_t)^γ
            log_pt = torch.log(target_probs.clamp(min=1e-12))
            one_minus_pt = (1 - target_probs).clamp(min=1e-12)
            coeff = (alpha * gamma * one_minus_pt.pow(gamma - 1) * target_probs * log_pt
                     - focal_weights)
            
            grad_input = coeff.unsqueeze(1) * (target_one_hot - probs)
            
            # Apply reduction scaling and the upstream gradient
            if reduction == 'mean':
                grad_input = grad_input / batch_size
            if reduction == 'none':
                # One upstream gradient per sample
                grad_input = grad_input * grad_output.unsqueeze(1)
            else:
                grad_input = grad_input * grad_output
            
            return grad_input, None, None, None, None
    
    class TripletLossFunction(Function):
        """
        Triplet Loss for metric learning
        Used in face recognition, person re-identification, etc.
        """
        
        @staticmethod
        def forward(ctx, anchor, positive, negative, margin=1.0, p=2):
            """
            Triplet Loss: max(0, d(a,p) - d(a,n) + margin)
            
            Args:
                anchor: Anchor embeddings [batch_size, embedding_dim]
                positive: Positive embeddings [batch_size, embedding_dim]
                negative: Negative embeddings [batch_size, embedding_dim]
                margin: Margin for separation
                p: Norm order (1 or 2)
            """
            # Compute distances
            if p == 1:
                pos_dist = (anchor - positive).abs().sum(dim=1)
                neg_dist = (anchor - negative).abs().sum(dim=1)
            elif p == 2:
                pos_dist = (anchor - positive).pow(2).sum(dim=1).sqrt()
                neg_dist = (anchor - negative).pow(2).sum(dim=1).sqrt()
            else:
                pos_dist = (anchor - positive).abs().pow(p).sum(dim=1).pow(1.0/p)
                neg_dist = (anchor - negative).abs().pow(p).sum(dim=1).pow(1.0/p)
            
            # Compute triplet loss
            losses = F.relu(pos_dist - neg_dist + margin)
            
            # Save for backward
            ctx.save_for_backward(anchor, positive, negative, pos_dist, neg_dist, losses)
            ctx.margin = margin
            ctx.p = p
            
            return losses.mean()
        
        @staticmethod
        def backward(ctx, grad_output):
            """
            Backward pass for triplet loss
            """
            anchor, positive, negative, pos_dist, neg_dist, losses = ctx.saved_tensors
            margin = ctx.margin
            p = ctx.p
            
            batch_size = anchor.size(0)
            embedding_dim = anchor.size(1)
            
            # Initialize gradients
            grad_anchor = torch.zeros_like(anchor)
            grad_positive = torch.zeros_like(positive)
            grad_negative = torch.zeros_like(negative)
            
            # Only compute gradients for non-zero losses
            active_mask = (losses > 0).float()
            
            for i in range(batch_size):
                if active_mask[i] > 0:
                    if p == 1:
                        # L1 distance gradients
                        pos_grad = torch.sign(anchor[i] - positive[i])
                        neg_grad = torch.sign(anchor[i] - negative[i])
                    elif p == 2:
                        # L2 distance gradients
                        pos_grad = (anchor[i] - positive[i]) / (pos_dist[i] + 1e-8)
                        neg_grad = (anchor[i] - negative[i]) / (neg_dist[i] + 1e-8)
                    else:
                        # General Lp distance gradients
                        pos_diff = anchor[i] - positive[i]
                        neg_diff = anchor[i] - negative[i]
                        
                        pos_grad = (torch.sign(pos_diff) * torch.abs(pos_diff) ** (p-1) / 
                                  (pos_dist[i] ** (p-1) + 1e-8))
                        neg_grad = (torch.sign(neg_diff) * torch.abs(neg_diff) ** (p-1) / 
                                  (neg_dist[i] ** (p-1) + 1e-8))
                    
                    # Accumulate gradients
                    grad_anchor[i] = pos_grad - neg_grad
                    grad_positive[i] = -pos_grad
                    grad_negative[i] = neg_grad
            
            # Scale by output gradient and batch size
            scale = grad_output / batch_size
            grad_anchor *= scale
            grad_positive *= scale
            grad_negative *= scale
            
            return grad_anchor, grad_positive, grad_negative, None, None
    
    class ContrastiveLossFunction(Function):
        """
        Contrastive Loss for siamese networks
        Learns to minimize distance for similar pairs, maximize for dissimilar
        """
        
        @staticmethod
        def forward(ctx, output1, output2, label, margin=1.0):
            """
            Contrastive Loss: 
            - Similar: 0.5 * d²
            - Dissimilar: 0.5 * max(0, margin - d)²
            
            Args:
                output1: First output embeddings [batch_size, embedding_dim]
                output2: Second output embeddings [batch_size, embedding_dim]
                label: Binary labels (1 for similar, 0 for dissimilar)
                margin: Margin for dissimilar pairs
            """
            # Compute Euclidean distance
            diff = output1 - output2
            distances = torch.sqrt(torch.sum(diff ** 2, dim=1) + 1e-8)
            
            # Compute contrastive loss
            similar_loss = 0.5 * distances ** 2
            dissimilar_loss = 0.5 * F.relu(margin - distances) ** 2
            
            losses = label.float() * similar_loss + (1 - label.float()) * dissimilar_loss
            
            # Save for backward
            ctx.save_for_backward(output1, output2, label, distances, diff)
            ctx.margin = margin
            
            return losses.mean()
        
        @staticmethod
        def backward(ctx, grad_output):
            """
            Backward pass for contrastive loss (vectorized)
            
            With d = sqrt(||o1 - o2||² + 1e-8):
            - Similar: d(0.5 * d²)/d(o1) = o1 - o2
            - Dissimilar: d(0.5 * max(0, margin - d)²)/d(o1) = -max(0, margin - d) * (o1 - o2) / d
            The gradient with respect to o2 is the negation in both cases.
            """
            output1, output2, label, distances, diff = ctx.saved_tensors
            margin = ctx.margin
            batch_size = output1.size(0)
            
            label_f = label.float().unsqueeze(1)
            
            similar_grad = diff
            # distances >= 1e-4 because of the 1e-8 stabilizer, so the division is safe
            dissimilar_grad = (-F.relu(margin - distances) / distances).unsqueeze(1) * diff
            
            # Mix per sample exactly as the forward pass mixes the two loss terms
            grad = label_f * similar_grad + (1 - label_f) * dissimilar_grad
            
            # Scale by output gradient and batch size (forward returns the mean)
            scale = grad_output / batch_size
            grad_output1 = grad * scale
            grad_output2 = -grad * scale
            
            return grad_output1, grad_output2, None, None
    
    class CenterLossFunction(Function):
        """
        Center Loss for deep feature learning
        Learns class centers and penalizes distance from centers
        """
        
        @staticmethod
        def forward(ctx, features, labels, centers, alpha=0.5):
            """
            Center Loss: mean over the batch of 0.5 * ||f_i - c_{y_i}||²
            
            Args:
                features: Feature embeddings [batch_size, feature_dim]
                labels: Class labels [batch_size]
                centers: Class centers [num_classes, feature_dim]
                alpha: Scale applied to the gradient w.r.t. the centers, acting
                       as a relative learning rate for the centers
            """
            # Get the center assigned to each sample
            selected_centers = centers[labels]
            
            # Compute per-sample center loss
            diff = features - selected_centers
            losses = 0.5 * torch.sum(diff ** 2, dim=1)
            
            # Centers are not updated here: an update computed inside forward
            # would simply be discarded. Instead, backward returns a gradient for
            # `centers` so an optimizer can update them like any other parameter.
            ctx.save_for_backward(labels, diff)
            ctx.num_classes = centers.size(0)
            ctx.alpha = alpha
            
            return losses.mean()
        
        @staticmethod
        def backward(ctx, grad_output):
            """
            Backward pass for center loss
            """
            labels, diff = ctx.saved_tensors
            batch_size = diff.size(0)
            scale = grad_output / batch_size
            
            # Gradient w.r.t. features: (f_i - c_{y_i}) / batch_size
            grad_features = diff * scale
            
            # Gradient w.r.t. centers: -sum over {i: y_i = j} of (f_i - c_j) / batch_size,
            # multiplied by alpha (a gradient check on centers only matches for alpha=1.0)
            grad_centers = None
            if ctx.needs_input_grad[2]:
                grad_centers = torch.zeros(ctx.num_classes, diff.size(1),
                                           dtype=diff.dtype, device=diff.device)
                grad_centers.index_add_(0, labels, -diff * scale)
                grad_centers = ctx.alpha * grad_centers
            
            return grad_features, None, grad_centers, None
    
    # Wrapper functions
    def focal_loss(input, target, alpha=1.0, gamma=2.0, reduction='mean'):
        return FocalLossFunction.apply(input, target, alpha, gamma, reduction)
    
    def triplet_loss(anchor, positive, negative, margin=1.0, p=2):
        return TripletLossFunction.apply(anchor, positive, negative, margin, p)
    
    def contrastive_loss(output1, output2, label, margin=1.0):
        return ContrastiveLossFunction.apply(output1, output2, label, margin)
    
    def center_loss(features, labels, centers, alpha=0.5):
        return CenterLossFunction.apply(features, labels, centers, alpha)
    
    return focal_loss, triplet_loss, contrastive_loss, center_loss

# Create advanced loss functions
focal_loss, triplet_loss, contrastive_loss, center_loss = create_advanced_loss_functions()

print("🔥 Testing Advanced Loss Functions:")
print("-" * 50)

# Test Focal Loss
print("\n📊 Testing Focal Loss:")

# Create imbalanced dataset
batch_size, num_classes = 64, 5
logits = torch.randn(batch_size, num_classes, requires_grad=True)

# Create severely imbalanced targets (class 0 is very rare)
targets = torch.randint(1, num_classes, (batch_size,))
targets[:4] = 0  # Only 4 samples of class 0

print(f"Class distribution: {torch.bincount(targets)}")

# Compare standard CE vs Focal Loss
ce_loss = F.cross_entropy(logits, targets)
focal_loss_val = focal_loss(logits, targets, alpha=2.0, gamma=2.0)

print(f"Standard CE Loss: {ce_loss.item():.4f}")
print(f"Focal Loss (α=2, γ=2): {focal_loss_val.item():.4f}")

# Test gradient flow
focal_loss_val.backward()
print(f"Focal loss gradient norm: {logits.grad.norm().item():.4f}")
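
# --- Reference check (sketch) ---
# A plain-PyTorch focal loss, assuming the common definition
#   FL = -alpha * (1 - p_t)^gamma * log(p_t)
# with a scalar alpha and mean reduction. FocalLossFunction's exact alpha and
# reduction handling may differ, so treat this only as a sanity reference.
# With gamma = 0 the modulating factor is 1, so the loss reduces to alpha * CE.
import torch
import torch.nn.functional as F

def reference_focal_loss(logits_in, target_in, alpha=1.0, gamma=2.0):
    # log p_t: log-probability assigned to the true class of each sample
    log_pt = F.log_softmax(logits_in, dim=1).gather(1, target_in.unsqueeze(1)).squeeze(1)
    pt = log_pt.exp()
    return (-alpha * (1 - pt) ** gamma * log_pt).mean()

ref_logits = torch.randn(8, 5)
ref_targets = torch.randint(0, 5, (8,))
ref_ce = F.cross_entropy(ref_logits, ref_targets)
print(torch.allclose(reference_focal_loss(ref_logits, ref_targets, gamma=0.0), ref_ce))  # True

# Down-weighting of an easy example with p_t = 0.9 and gamma = 2:
# (1 - 0.9)^2 = 0.01, so its loss term is about 100x smaller than under CE
print(f"Modulating factor at p_t=0.9, gamma=2: {(1 - 0.9) ** 2:.4f}")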

# Test Triplet Loss
print(f"\n📊 Testing Triplet Loss:")

embedding_dim = 128
anchor = torch.randn(batch_size, embedding_dim, requires_grad=True)
positive = torch.randn(batch_size, embedding_dim, requires_grad=True)
negative = torch.randn(batch_size, embedding_dim, requires_grad=True)

triplet_loss_val = triplet_loss(anchor, positive, negative, margin=1.0, p=2)
triplet_loss_val.backward()

print(f"Triplet loss: {triplet_loss_val.item():.4f}")
print(f"Anchor gradient norm: {anchor.grad.norm().item():.4f}")
print(f"Positive gradient norm: {positive.grad.norm().item():.4f}")
print(f"Negative gradient norm: {negative.grad.norm().item():.4f}")
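
# --- Reference check (sketch) ---
# PyTorch ships F.triplet_margin_loss, which computes
#   mean(max(0, d(a, p) - d(a, n) + margin))
# with d = F.pairwise_distance (default eps=1e-6). It can serve as a reference when
# validating TripletLossFunction; compare with a tolerance, since the custom
# version may treat eps differently.
import torch
import torch.nn.functional as F

ref_a = torch.randn(16, 32, dtype=torch.float64)
ref_p = torch.randn(16, 32, dtype=torch.float64)
ref_n = torch.randn(16, 32, dtype=torch.float64)
ref_margin = 1.0

manual_triplet = F.relu(
    F.pairwise_distance(ref_a, ref_p) - F.pairwise_distance(ref_a, ref_n) + ref_margin
).mean()
builtin_triplet = F.triplet_margin_loss(ref_a, ref_p, ref_n, margin=ref_margin, p=2)
print(torch.allclose(manual_triplet, builtin_triplet))  # True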

# Test Contrastive Loss
print(f"\n📊 Testing Contrastive Loss:")

output1 = torch.randn(batch_size, embedding_dim, requires_grad=True)
output2 = torch.randn(batch_size, embedding_dim, requires_grad=True)
pair_labels = torch.randint(0, 2, (batch_size,))  # Binary labels

contrastive_loss_val = contrastive_loss(output1, output2, pair_labels, margin=2.0)
contrastive_loss_val.backward()

print(f"Contrastive loss: {contrastive_loss_val.item():.4f}")
print(f"Output1 gradient norm: {output1.grad.norm().item():.4f}")
print(f"Output2 gradient norm: {output2.grad.norm().item():.4f}")
print(f"Similar pairs: {pair_labels.sum().item()}/{len(pair_labels)}")
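
# --- Gradient formula check (sketch) ---
# Autograd on the plain contrastive expression should match the closed form:
#   similar pairs:    d/d(o1) 0.5 * d^2             = (o1 - o2)
#   dissimilar pairs: d/d(o1) 0.5 * max(0, m - d)^2 = -max(0, m - d) * (o1 - o2) / d
# The `ref_*` names are illustrative and do not touch the tensors tested above.
import torch

ref_o1 = torch.randn(6, 8, dtype=torch.float64, requires_grad=True)
ref_o2 = torch.randn(6, 8, dtype=torch.float64)
ref_y = torch.randint(0, 2, (6,)).to(torch.float64)
ref_m = 5.0  # large enough that some dissimilar pairs fall inside the margin

ref_diff = ref_o1 - ref_o2
ref_d = torch.sqrt((ref_diff ** 2).sum(dim=1) + 1e-8)
ref_loss = (ref_y * 0.5 * ref_d ** 2
            + (1 - ref_y) * 0.5 * torch.relu(ref_m - ref_d) ** 2).mean()
ref_loss.backward()

with torch.no_grad():
    # Per-pair factor multiplying (o1 - o2), divided by the batch size for the mean
    ref_coeff = ref_y - (1 - ref_y) * torch.relu(ref_m - ref_d) / ref_d
    ref_closed_form = ref_coeff.unsqueeze(1) * ref_diff / ref_o1.size(0)
print(torch.allclose(ref_o1.grad, ref_closed_form))  # True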

# Test Center Loss
print(f"\n📊 Testing Center Loss:")

feature_dim = 64
features = torch.randn(batch_size, feature_dim, requires_grad=True)
class_labels = torch.randint(0, num_classes, (batch_size,))
centers = torch.randn(num_classes, feature_dim, requires_grad=True)

center_loss_val = center_loss(features, class_labels, centers, alpha=0.5)
center_loss_val.backward()

print(f"Center loss: {center_loss_val.item():.4f}")
print(f"Features gradient norm: {features.grad.norm().item():.4f}")
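
# --- Gradient formula check (sketch) ---
# For L = mean_i 0.5 * ||f_i - c_{y_i}||^2, autograd gives
#   dL/df_i = (f_i - c_{y_i}) / B
#   dL/dc_j = -sum over {i: y_i = j} of (f_i - c_j) / B
# The first is the feature gradient of the center loss; the second is what a
# variant with learnable centers needs. index_add_ accumulates per-class sums.
import torch

ref_f = torch.randn(10, 4, dtype=torch.float64, requires_grad=True)
ref_lbl = torch.randint(0, 3, (10,))
ref_c = torch.randn(3, 4, dtype=torch.float64, requires_grad=True)
(0.5 * ((ref_f - ref_c[ref_lbl]) ** 2).sum(dim=1).mean()).backward()

with torch.no_grad():
    ref_center_diff = ref_f - ref_c[ref_lbl]
    ref_grad_c_manual = torch.zeros_like(ref_c).index_add_(
        0, ref_lbl, -ref_center_diff / ref_f.size(0))
print(torch.allclose(ref_f.grad, ref_center_diff / ref_f.size(0)))  # True
print(torch.allclose(ref_c.grad, ref_grad_c_manual))  # True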

# Advanced loss function analysis and visualization
def create_loss_function_analysis():
    """Analyze behavior of different loss functions"""
    
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    
    # 1. Focal Loss vs Cross Entropy
    confidences = torch.linspace(0.01, 0.99, 100)
    ce_losses = -torch.log(confidences)
    
    gamma_values = [0, 1, 2, 5]
    
    for gamma in gamma_values:
        focal_losses = -(1 - confidences)**gamma * torch.log(confidences)
        label = f'Focal (γ={gamma})' if gamma > 0 else 'Cross Entropy'
        axes[0,0].plot(confidences, focal_losses, linewidth=2, label=label)
    
    axes[0,0].set_xlabel('Confidence (p_t)')
    axes[0,0].set_ylabel('Loss')
    axes[0,0].set_title('Focal Loss vs Cross Entropy', fontweight='bold')
    axes[0,0].legend()
    axes[0,0].grid(True, alpha=0.3)
    axes[0,0].set_yscale('log')
    
    # 2. Triplet Loss visualization: loss = max(0, d_pos - d_neg + margin)
    gaps = torch.linspace(-3, 3, 100)
    margin = 1.0
    
    triplet_losses = F.relu(gaps + margin)
    axes[0,1].plot(gaps, triplet_losses, linewidth=3, color='red')
    axes[0,1].axvline(x=-margin, color='black', linestyle='--', label=f'Zero loss for gap ≤ -{margin}')
    axes[0,1].set_xlabel('Distance gap (d_pos - d_neg)')
    axes[0,1].set_ylabel('Triplet Loss')
    axes[0,1].set_title('Triplet Loss Function', fontweight='bold')
    axes[0,1].legend()
    axes[0,1].grid(True, alpha=0.3)
    
    # 3. Contrastive Loss visualization
    distances = torch.linspace(0, 3, 100)
    margin = 2.0
    
    similar_loss = 0.5 * distances**2
    dissimilar_loss = 0.5 * F.relu(margin - distances)**2
    
    axes[0,2].plot(distances, similar_loss, linewidth=2, label='Similar pairs', color='blue')
    axes[0,2].plot(distances, dissimilar_loss, linewidth=2, label='Dissimilar pairs', color='red')
    axes[0,2].axvline(x=margin, color='black', linestyle='--', label=f'Margin = {margin}')
    axes[0,2].set_xlabel('Distance')
    axes[0,2].set_ylabel('Contrastive Loss')
    axes[0,2].set_title('Contrastive Loss Function', fontweight='bold')
    axes[0,2].legend()
    axes[0,2].grid(True, alpha=0.3)
    
    # 4. Loss comparison on imbalanced data
    class_frequencies = [0.05, 0.15, 0.25, 0.25, 0.30]  # Imbalanced
    loss_types = ['Standard CE', 'Focal (γ=1)', 'Focal (γ=2)', 'Focal (γ=5)']
    
    # Placeholder values: random numbers that only illustrate the chart layout, not measured losses
    np.random.seed(42)
    loss_values = np.random.rand(len(loss_types), len(class_frequencies)) * 2 + 1
    
    x_pos = np.arange(len(class_frequencies))
    width = 0.2
    
    for i, loss_type in enumerate(loss_types):
        offset = (i - len(loss_types)/2 + 0.5) * width
        axes[1,0].bar(x_pos + offset, loss_values[i], width, 
                     label=loss_type, alpha=0.8)
    
    axes[1,0].set_xlabel('Class')
    axes[1,0].set_ylabel('Loss')
    axes[1,0].set_title('Loss Comparison on Imbalanced Classes (placeholder data)', fontweight='bold')
    axes[1,0].set_xticks(x_pos)
    axes[1,0].set_xticklabels([f'Class {i}' for i in range(len(class_frequencies))])
    axes[1,0].legend()
    axes[1,0].grid(True, alpha=0.3)
    
    # 5. Gradient flow analysis (norms measured on the test batches above,
    #    w.r.t. each loss's first input)
    ce_grad = torch.autograd.grad(F.cross_entropy(logits, targets), logits)[0]
    loss_functions = ['CE', 'Focal', 'Triplet', 'Contrastive', 'Center']
    gradient_norms = [ce_grad.norm().item(), logits.grad.norm().item(),
                      anchor.grad.norm().item(), output1.grad.norm().item(),
                      features.grad.norm().item()]
    
    bars = axes[1,1].bar(loss_functions, gradient_norms, alpha=0.8, 
                        color=plt.cm.viridis(np.linspace(0, 1, len(loss_functions))))
    axes[1,1].set_title('Gradient Flow Comparison', fontweight='bold')
    axes[1,1].set_ylabel('Gradient Norm (first input)')
    axes[1,1].grid(True, alpha=0.3)
    
    # Add value labels on bars
    for bar, norm in zip(bars, gradient_norms):
        height = bar.get_height()
        axes[1,1].text(bar.get_x() + bar.get_width()/2., height,
                      f'{norm:.3f}', ha='center', va='bottom')
    
    # 6. Loss landscape visualization (conceptual)
    x = np.linspace(-2, 2, 50)
    y = np.linspace(-2, 2, 50)
    X, Y = np.meshgrid(x, y)
    
    # Simulate different loss landscapes
    Z_ce = X**2 + Y**2  # Simple quadratic (cross-entropy-like)
    Z_focal = X**2 + Y**2 + 0.5 * np.sin(X*Y)  # More complex (focal-like)
    
    contour1 = axes[1,2].contour(X, Y, Z_ce, levels=10, alpha=0.5, colors='blue')
    contour2 = axes[1,2].contour(X, Y, Z_focal, levels=10, alpha=0.5, colors='red')
    
    axes[1,2].set_title('Conceptual Loss Landscapes (synthetic surfaces)', fontweight='bold')
    axes[1,2].set_xlabel('Parameter 1')
    axes[1,2].set_ylabel('Parameter 2')
    
    # Create custom legend
    from matplotlib.lines import Line2D
    legend_elements = [Line2D([0], [0], color='blue', label='Standard Loss'),
                      Line2D([0], [0], color='red', label='Advanced Loss')]
    axes[1,2].legend(handles=legend_elements)
    
    plt.suptitle('Advanced Loss Functions Analysis', fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.savefig(results_dir / 'advanced_loss_functions_analysis.png', 
                dpi=300, bbox_inches='tight')
    plt.show()

# Create loss function analysis
create_loss_function_analysis()

# Test advanced loss functions with gradient checker
print(f"\n🔍 Gradient Checking Advanced Loss Functions:")

# Prepare test cases for gradient checking
loss_test_functions = [
    (lambda x, y: focal_loss(x, y, alpha=1.0, gamma=2.0),
     [torch.randn(8, 5, requires_grad=True), torch.randint(0, 5, (8,))],
     "Focal Loss"),
    
    (lambda a, p, n: triplet_loss(a, p, n, margin=1.0),
     [torch.randn(4, 10, requires_grad=True), 
      torch.randn(4, 10, requires_grad=True),
      torch.randn(4, 10, requires_grad=True)],
     "Triplet Loss"),
    
    (lambda o1, o2, l: contrastive_loss(o1, o2, l, margin=1.0),
     [torch.randn(6, 8, requires_grad=True),
      torch.randn(6, 8, requires_grad=True),
      torch.randint(0, 2, (6,))],
     "Contrastive Loss"),
]

# Run gradient checking on loss functions
loss_checker = GradientChecker(tolerance=1e-3, verbose=False)  # More lenient for complex losses
loss_test_results = loss_checker.comprehensive_test_suite(loss_test_functions)

# Compile advanced loss function results
advanced_loss_results = {
    'focal_loss': {
        'test_loss_value': focal_loss_val.item(),
        'compared_to_ce': ce_loss.item(),
        'gradient_test_passed': loss_test_results['Focal Loss']['passed']
    },
    'triplet_loss': {
        'test_loss_value': triplet_loss_val.item(),
        'gradient_test_passed': loss_test_results['Triplet Loss']['passed']
    },
    'contrastive_loss': {
        'test_loss_value': contrastive_loss_val.item(),
        'gradient_test_passed': loss_test_results['Contrastive Loss']['passed']
    },
    'center_loss': {
        'test_loss_value': center_loss_val.item()
    }
}

with open(results_dir / 'advanced_loss_functions_results.json', 'w') as f:
    json.dump(advanced_loss_results, f, indent=2)

print(f"\n💾 Advanced loss functions results saved")
print(f"\n🎓 Advanced Loss Functions Insights:")
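print("  • Focal Loss down-weights well-classified examples, which helps under class imbalance")
print("  • Triplet Loss pushes d(anchor, negative) beyond d(anchor, positive) by a margin")
print("  • Contrastive Loss pulls similar pairs together and pushes dissimilar pairs past a margin")
print("  • Center Loss pulls features toward their class centers (intra-class compactness)")
print("  • Custom losses encode domain-specific objectives that built-in losses may not cover")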
```
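
The losses above are checked with `GradientChecker` at a lenient `1e-3` tolerance. A common complement is PyTorch's built-in `torch.autograd.gradcheck`, which compares the analytic Jacobian from `backward` with finite differences and is intended to run in double precision, where tight tolerances are meaningful. The sketch below applies it to a compact, vectorized contrastive loss; `ContrastiveLossSketch` is a hypothetical class written for this illustration, not the notebook's `ContrastiveLossFunction`. The margin of 5.0 is an assumption chosen so that both the similar and the dissimilar branches produce non-zero gradients.

```python
import torch
import torch.nn.functional as F
from torch.autograd import Function, gradcheck


class ContrastiveLossSketch(Function):
    """Hypothetical vectorized contrastive loss, used only to illustrate gradcheck."""

    @staticmethod
    def forward(ctx, out1, out2, label, margin=1.0):
        diff = out1 - out2
        dist = torch.sqrt((diff ** 2).sum(dim=1) + 1e-8)
        y = label.to(diff.dtype)
        losses = y * 0.5 * dist ** 2 + (1 - y) * 0.5 * F.relu(margin - dist) ** 2
        ctx.save_for_backward(diff, dist, y)
        ctx.margin = margin
        return losses.mean()

    @staticmethod
    def backward(ctx, grad_output):
        diff, dist, y = ctx.saved_tensors
        # Per-sample factor multiplying (o1 - o2): 1 for similar pairs,
        # -max(0, margin - d) / d for dissimilar pairs
        coeff = y - (1 - y) * F.relu(ctx.margin - dist) / dist
        grad = coeff.unsqueeze(1) * diff * grad_output / diff.size(0)
        return grad, -grad, None, None


torch.manual_seed(0)
o1 = torch.randn(6, 8, dtype=torch.float64, requires_grad=True)
o2 = torch.randn(6, 8, dtype=torch.float64, requires_grad=True)
pair_y = torch.tensor([1, 0, 1, 0, 0, 1])

# Close over the non-differentiable arguments so gradcheck only perturbs o1 and o2
ok = gradcheck(lambda a, b: ContrastiveLossSketch.apply(a, b, pair_y, 5.0), (o1, o2))
print(ok)  # True
```

By default `gradcheck` raises on a mismatch; passing `raise_exception=False` makes it return `False` instead, which is convenient inside a test suite.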

## 8. Comprehensive Summary and Production Guidelines

### 8.1 Best Practices and Production Template

```python
def create_production_template():
    """Create comprehensive production-ready custom function template"""
    
    print("\n=== 8.1 Production-Ready Custom Function Template ===\n")
    
    class ProductionCustomFunction(Function):
        """
        Production-ready template for custom autograd functions.
        
        This template includes:
        - Comprehensive input validation
        - Memory efficiency considerations  
        - Numerical stability safeguards
        - Proper error handling
        - Performance optimization
        - Extensive documentation
        """
        
        @staticmethod
        def forward(ctx, input_tensor, weight_tensor=None, 
                   scalar_param=1.0, optional_param=None, 
                   validate_inputs=True, numerical_stability=True):
            """
            Production-ready forward pass template.
            
            Args:
                ctx: PyTorch context object
                input_tensor: Primary input tensor
                weight_tensor: Optional weight tensor
                scalar_param: Scalar parameter (default: 1.0)
                optional_param: Optional parameter
                validate_inputs: Whether to validate inputs (default: True)
                numerical_stability: Whether to apply stability checks (default: True)
                
            Returns:
                output: Computed result tensor
                
            Raises:
                TypeError: If inputs are of wrong type
                ValueError: If inputs have invalid values or shapes
                RuntimeError: If computation fails
            """
            
            # 1. Input Validation
            if validate_inputs:
                if not isinstance(input_tensor, torch.Tensor):
                    raise TypeError(f"Expected torch.Tensor, got {type(input_tensor)}")
                
                if input_tensor.numel() == 0:
                    raise ValueError("Input tensor cannot be empty")
                
                if weight_tensor is not None:
                    if not isinstance(weight_tensor, torch.Tensor):
                        raise TypeError(f"Weight must be torch.Tensor, got {type(weight_tensor)}")
                    
                    # Check shape compatibility
                    if input_tensor.size(-1) != weight_tensor.size(0):
                        raise ValueError(f"Incompatible shapes: input {input_tensor.shape}, "
                                       f"weight {weight_tensor.shape}")
                
                if not isinstance(scalar_param, (int, float)):
                    raise TypeError(f"Scalar param must be numeric, got {type(scalar_param)}")
            
            # 2. Numerical Stability Checks
            if numerical_stability:
                if torch.isnan(input_tensor).any():
                    raise ValueError("Input contains NaN values")
                
                if torch.isinf(input_tensor).any():
                    raise ValueError("Input contains infinite values")
                
                if weight_tensor is not None:
                    if torch.isnan(weight_tensor).any():
                        raise ValueError("Weight contains NaN values")
                    
                    if torch.isinf(weight_tensor).any():
                        raise ValueError("Weight contains infinite values")
            
            # 3. Handle Optional Parameters
            # The default for optional_param depends on the output shape, so it is
            # created after the main computation below.
            
            # 4. Memory Efficiency Considerations
            input_size_mb = input_tensor.numel() * input_tensor.element_size() / (1024**2)
            
            # For very large tensors, consider alternative storage strategies
            if input_size_mb > 1024:  # > 1 GiB
                print(f"⚠️ Large tensor detected ({input_size_mb:.1f}MB). "
                      f"Consider using gradient checkpointing.")
            
            # 5. Core Computation with Error Handling
            try:
                # Example computation: f(x, w, s, o) = s * (x @ w) + o
                if weight_tensor is not None:
                    result = torch.matmul(input_tensor, weight_tensor)
                else:
                    result = input_tensor
                
                # Default o to ones shaped like the output: ones_like(input) would
                # fail to broadcast whenever x @ w changes the last dimension
                if optional_param is None:
                    optional_param = torch.ones_like(result)
                
                result = scalar_param * result + optional_param
                
            except RuntimeError as e:
                raise RuntimeError(f"Forward computation failed: {e}") from e
            
            # 6. Save Information for Backward Pass
            # Only save what's absolutely necessary
            tensors_to_save = [input_tensor]
            
            if weight_tensor is not None:
                tensors_to_save.append(weight_tensor)
            
            ctx.save_for_backward(*tensors_to_save)
            
            # Save non-tensor parameters
            ctx.scalar_param = scalar_param
            ctx.has_weight = weight_tensor is not None
            ctx.input_shape = input_tensor.shape
            
            # ctx.needs_input_grad is populated by PyTorch with one bool per
            # forward argument; read it in backward instead of overwriting it here
            
            # 7. Output Validation
            if numerical_stability:
                if torch.isnan(result).any():
                    raise RuntimeError("Forward pass produced NaN values")
                
                if torch.isinf(result).any():
                    raise RuntimeError("Forward pass produced infinite values")
            
            return result
        
        @staticmethod
        def backward(ctx, grad_output):
            """
            Production-ready backward pass template.
            
            Args:
                ctx: Context with saved forward information
                grad_output: Gradient w.r.t. output
                
            Returns:
                Tuple of gradients for each forward input
                
            Raises:
                ValueError: If grad_output is None or contains NaN values
                RuntimeError: If backward computation fails
            """
            
            # 1. Input Validation
            if grad_output is None:
                raise ValueError("grad_output cannot be None")
            
            if torch.isnan(grad_output).any():
                raise ValueError("grad_output contains NaN values")
            
            # 2. Retrieve Saved Information
            saved_tensors = ctx.saved_tensors
            input_tensor = saved_tensors[0]
            weight_tensor = saved_tensors[1] if ctx.has_weight else None
            
            scalar_param = ctx.scalar_param
            
            # 3. Initialize Gradients
            grad_input = grad_weight = None
            grad_scalar = grad_optional = grad_validate = grad_stability = None
            
            # 4. Compute Gradients Only for Required Inputs
            try:
                if ctx.needs_input_grad[0]:  # input_tensor gradient
                    if weight_tensor is not None:
                        grad_input = scalar_param * torch.matmul(grad_output, weight_tensor.t())
                    else:
                        grad_input = scalar_param * grad_output
                
                if ctx.needs_input_grad[1] and weight_tensor is not None:  # weight_tensor gradient
                    # Flatten leading (batch) dims so this also works for N-D inputs
                    flat_input = input_tensor.reshape(-1, input_tensor.size(-1))
                    flat_grad = grad_output.reshape(-1, grad_output.size(-1))
                    grad_weight = scalar_param * torch.matmul(flat_input.t(), flat_grad)
                
                # scalar_param, optional_param, and flags get no gradients (return None);
                # an optional_param tensor that requires grad would need grad_output here
                
            except RuntimeError as e:
                raise RuntimeError(f"Backward computation failed: {e}") from e
            
            # 5. Gradient Validation
            if grad_input is not None:
                if torch.isnan(grad_input).any():
                    raise RuntimeError("Input gradient contains NaN values")
                
                if torch.isinf(grad_input).any():
                    raise RuntimeError("Input gradient contains infinite values")
            
            if grad_weight is not None:
                if torch.isnan(grad_weight).any():
                    raise RuntimeError("Weight gradient contains NaN values")
                
                if torch.isinf(grad_weight).any():
                    raise RuntimeError("Weight gradient contains infinite values")
            
            # 6. Return Gradients in Same Order as Forward Inputs
            return (grad_input, grad_weight, grad_scalar, grad_optional, 
                   grad_validate, grad_stability)
    
    class CustomFunctionRegistry:
        """Registry for managing custom functions in production"""
        
        def __init__(self):
            self.functions = {}
            self.performance_stats = {}
            self.validation_results = {}
        
        def register(self, name, function_class, description=""):
            """Register a custom function"""
            self.functions[name] = {
                'class': function_class,
                'description': description,
                'registered_at': time.time()
            }
            print(f"✅ Registered function: {name}")
        
        def validate_function(self, name, test_inputs, gradient_checker=None):
            """Validate a registered function"""
            if name not in self.functions:
                raise ValueError(f"Function {name} not registered")
            
            func_class = self.functions[name]['class']
            
            if gradient_checker is None:
                gradient_checker = GradientChecker(tolerance=1e-4)
            
            # Create wrapper function for testing
            def test_wrapper(*inputs):
                return func_class.apply(*inputs)
            
            # Run validation
            results = gradient_checker.check_gradients(test_wrapper, test_inputs, name)
            self.validation_results[name] = results
            
            return results
        
        def benchmark_function(self, name, test_inputs, num_iterations=1000):
            """Benchmark a registered function"""
            if name not in self.functions:
                raise ValueError(f"Function {name} not registered")
            
            func_class = self.functions[name]['class']
            
            # Create wrapper
            def test_wrapper(*inputs):
                return func_class.apply(*inputs)
            
            # Benchmark
            if name not in self.performance_stats:
                gradient_checker = GradientChecker()
                perf_stats = gradient_checker.performance_benchmark(
                    test_wrapper, test_inputs, num_iterations
                )
                self.performance_stats[name] = perf_stats
            
            return self.performance_stats[name]
        
        def get_summary(self):
            """Get summary of all registered functions"""
            summary = {
                'total_functions': len(self.functions),
                'validated_functions': len(self.validation_results),
                'benchmarked_functions': len(self.performance_stats),
                'functions': {}
            }
            
            for name, info in self.functions.items():
                func_summary = {
                    'description': info['description'],
                    'validated': name in self.validation_results,
                    'benchmarked': name in self.performance_stats
                }
                
                if name in self.validation_results:
                    func_summary['validation_passed'] = self.validation_results[name]['passed']
                
                if name in self.performance_stats:
                    func_summary['performance'] = self.performance_stats[name]
                
                summary['functions'][name] = func_summary
            
            return summary
    
    # Production wrapper function
    def production_custom_function(input_tensor, weight_tensor=None, 
                                 scalar_param=1.0, optional_param=None,
                                 validate_inputs=True, numerical_stability=True):
        """Convenient wrapper for production custom function"""
        return ProductionCustomFunction.apply(
            input_tensor, weight_tensor, scalar_param, optional_param,
            validate_inputs, numerical_stability
        )
    
    return ProductionCustomFunction, CustomFunctionRegistry, production_custom_function

# Create production framework
ProductionCustomFunction, CustomFunctionRegistry, production_custom_function = create_production_template()
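
# --- How ctx.needs_input_grad is populated (sketch) ---
# PyTorch fills ctx.needs_input_grad with one bool per forward argument before
# backward runs, so a Function only needs to read it. The hypothetical
# NeedsGradProbe below records what it sees when only the first input requires grad.
import torch
from torch.autograd import Function

class NeedsGradProbe(Function):
    seen = None  # filled in during backward, for inspection only

    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        return x * w

    @staticmethod
    def backward(ctx, grad_out):
        x, w = ctx.saved_tensors
        NeedsGradProbe.seen = tuple(ctx.needs_input_grad)
        # Skip work for inputs that do not need gradients
        grad_x = grad_out * w if ctx.needs_input_grad[0] else None
        grad_w = grad_out * x if ctx.needs_input_grad[1] else None
        return grad_x, grad_w

probe_x = torch.randn(3, requires_grad=True)
probe_w = torch.randn(3)  # does not require grad
NeedsGradProbe.apply(probe_x, probe_w).sum().backward()
print(NeedsGradProbe.seen)  # (True, False)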

print("🏭 Testing Production Framework:")
print("-" * 50)

# Initialize function registry
registry = CustomFunctionRegistry()

# Register our custom functions
registry.register("production_function", ProductionCustomFunction, 
                 "Production-ready template function")
# Plain functions have no .apply, so wrap them in a lightweight object that exposes one
from types import SimpleNamespace
registry.register("swish_activation", SimpleNamespace(apply=swish),
                  "Swish activation function")
registry.register("focal_loss", SimpleNamespace(apply=focal_loss),
                  "Focal loss for imbalanced classes")

# Test production function
print("\n📊 Testing Production Function:")

test_input = torch.randn(16, 32, requires_grad=True)
test_weight = torch.randn(32, 16, requires_grad=True)

try:
    # Test with full validation
    output = production_custom_function(
        test_input, test_weight, scalar_param=0.5,
        validate_inputs=True, numerical_stability=True
    )
    
    output.sum().backward()
    
    print(f"✅ Production function test passed")
    print(f"   Output shape: {output.shape}")
    print(f"   Input gradient norm: {test_input.grad.norm():.4f}")
    print(f"   Weight gradient norm: {test_weight.grad.norm():.4f}")
    
except Exception as e:
    print(f"❌ Production function test failed: {e}")
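
# --- Gradient checkpointing (sketch) ---
# The template only prints a hint for large inputs. One built-in option is
# torch.utils.checkpoint.checkpoint, which discards the block's intermediate
# activations in forward and recomputes them during backward, trading compute for
# memory. `expensive_block` is a hypothetical stand-in; gradients should match
# the non-checkpointed run.
import torch
from torch.utils.checkpoint import checkpoint

def expensive_block(x, w):
    return torch.tanh(x @ w) @ w.t()

ckpt_x = torch.randn(64, 128, requires_grad=True)
ckpt_w = (0.05 * torch.randn(128, 128)).requires_grad_()

# Checkpointed run: intermediates inside expensive_block are recomputed in backward
checkpoint(expensive_block, ckpt_x, ckpt_w, use_reentrant=False).sum().backward()
grad_with_ckpt = ckpt_x.grad.clone()

# Regular run for comparison
ckpt_x.grad = None
ckpt_w.grad = None
expensive_block(ckpt_x, ckpt_w).sum().backward()
print(torch.allclose(grad_with_ckpt, ckpt_x.grad))  # True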

# Test error handling
print(f"\n🚨 Testing Error Handling:")

try:
    # Test with invalid input (should raise error)
    invalid_input = torch.tensor([float('nan')], requires_grad=True)
    production_custom_function(invalid_input, validate_inputs=True, numerical_stability=True)
    print("❌ Error handling failed - should have caught NaN")
except ValueError as e:
    print(f"✅ Correctly caught error: {e}")
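
# --- Locating NaN gradients with anomaly detection (sketch) ---
# The template checks for NaNs by hand. PyTorch's anomaly mode instead raises a
# RuntimeError as soon as a backward node returns NaN, naming that node. Here
# sqrt(0) * 0 makes the sqrt backward compute 0 / 0.
import torch

anomaly_x = torch.zeros(1, requires_grad=True)
anomaly_caught = False
try:
    with torch.autograd.detect_anomaly():
        (torch.sqrt(anomaly_x) * 0).sum().backward()
except RuntimeError as err:
    anomaly_caught = True
    print(f"✅ Anomaly mode flagged: {err}")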

# Validate registered functions
print(f"\n🔍 Validating Registered Functions:")

validation_results = {}

# Validate production function
test_inputs_prod = [torch.randn(8, 10, requires_grad=True), 
                   torch.randn(10, 5, requires_grad=True)]
validation_results['production'] = registry.validate_function(
    "production_function", test_inputs_prod
)

# Validate swish activation
test_inputs_swish = [torch.randn(20, requires_grad=True)]
validation_results['swish'] = registry.validate_function(
    "swish_activation", test_inputs_swish
)

# Get registry summary
registry_summary = registry.get_summary()

print(f"\n📋 Registry Summary:")
print(f"Total functions: {registry_summary['total_functions']}")
print(f"Validated functions: {registry_summary['validated_functions']}")

for name, info in registry_summary['functions'].items():
    status = "✅" if info.get('validation_passed', False) else "❌"
    print(f"  {name}: {status} - {info['description']}")

# Create final comprehensive summary
def generate_final_summary():
    """Generate comprehensive summary of entire notebook"""
    
    print("\n" + "="*80)
    print("🎓 CUSTOM AUTOGRAD FUNCTIONS MASTERY - FINAL SUMMARY")
    print("="*80)
    
    summary_data = {
        'completion_timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
        'sections_completed': 8,
        'functions_implemented': 0,
        'concepts_mastered': [],
        'practical_skills': [],
        'production_capabilities': [],
        'testing_framework': [],
        'performance_metrics': {}
    }
    
    # Count implemented functions
    function_categories = {
        'Basic Functions': ['Square', 'Weighted Sum', 'Multi-input'],
        'Activation Functions': ['Swish', 'Mish', 'Adaptive', 'GLU'],
        'Memory-Efficient': ['Memory-Efficient MatMul', 'Sequential Computation', 'Gradient Checkpointing'],
        'Non-Differentiable': ['Straight-Through', 'Gumbel-Softmax', 'Soft Top-K', 'Differentiable Rounding'],
        'Advanced Loss Functions': ['Focal Loss', 'Triplet Loss', 'Contrastive Loss', 'Center Loss'],
        'Production Tools': ['Production Template', 'Registry System', 'Validation Framework']
    }
    
    total_functions = sum(len(funcs) for funcs in function_categories.values())
    summary_data['functions_implemented'] = total_functions
    
    # Concepts mastered
    concepts_mastered = [
        "✅ torch.autograd.Function API mastery",
        "✅ Forward and backward pass implementation",
        "✅ Complex gradient computation techniques", 
        "✅ Memory-efficient operation design",
        "✅ Non-differentiable operation handling",
        "✅ Advanced activation function development",
        "✅ Custom loss function creation",
        "✅ Numerical stability and error handling",
        "✅ Performance optimization strategies",
        "✅ Production-ready development practices"
    ]
    
    practical_skills = [
        "🔧 Custom autograd function implementation",
        "🔧 Gradient checking and validation",
        "🔧 Performance benchmarking and optimization",
        "🔧 Memory-efficient computation design",
        "🔧 Numerical stability enforcement",
        "🔧 Error handling and robustness",
        "🔧 Production deployment preparation",
        "🔧 Function registry and management",
        "🔧 Comprehensive testing frameworks",
        "🔧 Advanced debugging techniques"
    ]
    
    production_capabilities = [
        "🏭 Production-ready function templates",
        "🏭 Comprehensive input validation",
        "🏭 Memory efficiency optimization",
        "🏭 Numerical stability safeguards",
        "🏭 Performance monitoring and benchmarking",
        "🏭 Function registry and management systems",
        "🏭 Automated testing and validation",
        "🏭 Error handling and recovery",
        "🏭 Documentation and maintenance standards",
        "🏭 Scalability and deployment considerations"
    ]
    
    testing_framework = [
        "🧪 Numerical gradient checking",
        "🧪 Performance benchmarking",
        "🧪 Comprehensive test suites",
        "🧪 Error condition testing",
        "🧪 Memory usage profiling",
        "🧪 Gradient flow analysis",
        "🧪 Comparative validation",
        "🧪 Production readiness assessment"
    ]
    
    # Display summary
    print(f"\n📚 CONCEPTS MASTERED ({len(concepts_mastered)}):")
    for concept in concepts_mastered:
        print(f"  {concept}")
    
    print(f"\n🛠️ PRACTICAL SKILLS DEVELOPED ({len(practical_skills)}):")
    for skill in practical_skills:
        print(f"  {skill}")
    
    print(f"\n🏭 PRODUCTION CAPABILITIES ({len(production_capabilities)}):")
    for capability in production_capabilities:
        print(f"  {capability}")
    
    print(f"\n🧪 TESTING FRAMEWORK COMPONENTS ({len(testing_framework)}):")
    for component in testing_framework:
        print(f"  {component}")
    
    print(f"\n📊 IMPLEMENTATION STATISTICS:")
    print(f"  Total custom functions implemented: {total_functions}")
    print(f"  Function categories covered: {len(function_categories)}")
    print(f"  Advanced techniques demonstrated: 15+")
    print(f"  Production-ready templates created: 3")
    print(f"  Testing frameworks developed: 2")
    
    print(f"\n🏆 MASTERY LEVEL ACHIEVED:")
    print(f"  ⭐⭐⭐⭐⭐ EXPERT LEVEL CUSTOM AUTOGRAD FUNCTIONS")
    print(f"  Ready for advanced research and production deployment")
    
    # Next steps and advanced challenges
    next_steps = [
        "📓 Advanced Meta-Learning with Custom Gradients",
        "📓 Neural Architecture Search Implementation", 
        "📓 Automatic Differentiation for Scientific Computing",
        "📓 Custom CUDA Kernels with PyTorch Integration",
        "📓 Distributed Training with Custom Operations",
        "📓 Research Applications in Novel Domains"
    ]
    
    advanced_challenges = [
        "🏆 Implement higher-order gradient computations",
        "🏆 Create domain-specific loss functions for your field",
        "🏆 Build memory-efficient transformer operations",
        "🏆 Develop custom optimizers with novel gradient processing",
        "🏆 Implement differentiable programming languages",
        "🏆 Create custom operations for quantum machine learning"
    ]
    
    print(f"\n🚀 RECOMMENDED NEXT STEPS:")
    for step in next_steps:
        print(f"  {step}")
    
    print(f"\n🏆 ADVANCED CHALLENGES:")
    for challenge in advanced_challenges:
        print(f"  {challenge}")
    
    # Key insights and best practices
    best_practices = [
        "💡 Always validate inputs thoroughly in production functions",
        "💡 Implement numerical stability checks for robust operations",
        "💡 Use gradient checking to validate custom implementations",
        "💡 Profile memory usage for large-scale applications",
        "💡 Create comprehensive test suites for all custom functions",
        "💡 Document the mathematical derivation behind each backward pass",
        "💡 Weigh memory savings against the cost of extra recomputation",
        "💡 Handle edge cases and error conditions gracefully",
        "💡 Use function registries for production organization",
        "💡 Benchmark performance against standard implementations"
    ]
    
    print(f"\n💡 KEY BEST PRACTICES:")
    for practice in best_practices:
        print(f"  {practice}")
    
    # Save comprehensive summary
    summary_data.update({
        'concepts_mastered': concepts_mastered,
        'practical_skills': practical_skills,
        'production_capabilities': production_capabilities,
        'testing_framework': testing_framework,
        'function_categories': function_categories,
        'next_steps': next_steps,
        'advanced_challenges': advanced_challenges,
        'best_practices': best_practices
    })
    
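    # ensure_ascii=False keeps the emoji labels readable in the saved JSON
    with open(results_dir / 'comprehensive_mastery_summary.json', 'w', encoding='utf-8') as f:
        json.dump(summary_data, f, indent=2, ensure_ascii=False)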
    
    print(f"\n💾 Complete mastery summary saved to:")
    print(f"    {results_dir / 'comprehensive_mastery_summary.json'}")
    
    # List all generated artifacts (regular files only, so the count matches the listing)
    print(f"\n📂 Generated Learning Artifacts:")
    all_files = sorted(p for p in results_dir.glob('*') if p.is_file())
    
    for file_path in all_files:
        size_mb = file_path.stat().st_size / (1024 * 1024)
        print(f"  📄 {file_path.name} ({size_mb:.2f} MB)")
    
    total_size = sum(p.stat().st_size for p in all_files) / (1024 * 1024)
    print(f"\n📊 Total artifacts: {len(all_files)} files ({total_size:.1f} MB)")
    
    print(f"\n🌟 CONGRATULATIONS!")
    print(f"🎯 You've achieved CUSTOM AUTOGRAD FUNCTIONS MASTERY!")
    print(f"🚀 Ready for advanced research and production applications!")
    
    return summary_data

# Generate final comprehensive summary
final_summary = generate_final_summary()

print(f"\n" + "="*80)
print("🎉 CUSTOM AUTOGRAD FUNCTIONS MASTERY - COMPLETE!")
print("="*80)
```

---

## Final Notes

This notebook on custom autograd functions has taken you from basic `Function` API usage to advanced, production-ready implementations. You've mastered:

### 🎓 **Core Technical Skills**
- Complete mastery of the `torch.autograd.Function` API
- Complex forward and backward pass implementations
- Multi-input function handling with proper gradient computation
- Memory-efficient operations and gradient checkpointing
- Non-differentiable operation approximations
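As a compact refresher on the first three bullets, here is a minimal sketch of a multi-input `Function`. The class name `ScaledProduct` and its float `scale` argument are illustrative choices, not functions defined earlier in the notebook. It shows three API details: `ctx.save_for_backward` for tensors, a plain `ctx` attribute for a non-tensor argument, and `backward` returning exactly one gradient per `forward` input (`None` for the float).

```python
import torch
from torch.autograd import Function, gradcheck


class ScaledProduct(Function):
    """Illustrative op: scale * x * y for same-shape tensors x, y and a Python float scale."""

    @staticmethod
    def forward(ctx, x, y, scale):
        # Tensors go through save_for_backward; plain Python values can live on ctx.
        ctx.save_for_backward(x, y)
        ctx.scale = scale
        return scale * x * y

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        grad_x = grad_y = None
        # Skip work for inputs that do not require gradients.
        if ctx.needs_input_grad[0]:
            grad_x = grad_output * ctx.scale * y
        if ctx.needs_input_grad[1]:
            grad_y = grad_output * ctx.scale * x
        # One return value per forward input; the float `scale` gets None.
        return grad_x, grad_y, None


# gradcheck compares the analytical backward with finite differences (use float64).
x = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
y = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
ok = gradcheck(lambda a, b: ScaledProduct.apply(a, b, 2.5), (x, y))
print("gradcheck passed:", ok)
```

Storing `scale` on `ctx` rather than saving it as a tensor is the usual choice for non-tensor arguments, since `save_for_backward` is meant for tensors.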

### 🛠️ **Advanced Implementation Techniques**
- Custom activation functions with sophisticated derivatives
- Advanced loss functions for specialized applications
- Straight-through estimators and Gumbel-Softmax
- Production-ready templates with comprehensive validation
- Performance optimization and debugging strategies
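A minimal sketch of the straight-through idea, assuming the common identity variant: the forward pass applies `torch.round`, and `backward` returns the upstream gradient unchanged, as if the derivative were 1. `RoundSTE` is a hypothetical name for this illustration. For Gumbel-Softmax, PyTorch ships `torch.nn.functional.gumbel_softmax`; with `hard=True` it returns one-hot samples in the forward pass while gradients flow through the soft relaxation.

```python
import torch
import torch.nn.functional as F
from torch.autograd import Function


class RoundSTE(Function):
    """Hypothetical example: round in forward, identity gradient in backward."""

    @staticmethod
    def forward(ctx, x):
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output):
        # The true derivative of round() is 0 almost everywhere; the STE
        # substitutes 1 so upstream gradients still reach x.
        return grad_output


x = torch.tensor([0.2, 1.7, -0.6], requires_grad=True)
y = RoundSTE.apply(x)          # rounded values: 0, 2, -1
(3.0 * y).sum().backward()
print(x.grad)                  # every entry receives the upstream gradient 3.0

# Built-in Gumbel-Softmax: one-hot samples forward, soft-relaxation gradients backward.
torch.manual_seed(0)
logits = torch.randn(5, 4, requires_grad=True)
sample = F.gumbel_softmax(logits, tau=0.5, hard=True)
(sample * torch.arange(4.0)).sum().backward()
print(sample.argmax(dim=-1), logits.grad.shape)
```

The identity STE is the simplest variant; clipped versions that zero the gradient outside a range are also common, and the right choice depends on the quantizer being approximated.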

### 🏭 **Production-Ready Capabilities**
- Comprehensive input validation and error handling
- Numerical stability safeguards and memory efficiency
- Function registry and management systems
- Automated testing and validation frameworks
- Performance benchmarking and optimization
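The validation and stability points can be combined in one small function. The sketch below is a hypothetical `StableSoftplus` (not the notebook's production template) that rejects non-floating or non-finite inputs and uses the identity log(1 + eˣ) = max(x, 0) + log1p(e^(−|x|)), so it never exponentiates a large positive number. The validation policy itself is an assumption chosen for illustration.

```python
import torch
from torch.autograd import Function, gradcheck


class StableSoftplus(Function):
    """softplus(x) = log(1 + exp(x)), computed without overflow for large |x|."""

    @staticmethod
    def forward(ctx, x):
        # Assumed validation policy: floating-point, finite inputs only.
        if not torch.is_floating_point(x):
            raise TypeError(f"StableSoftplus expects a floating-point tensor, got {x.dtype}")
        if not torch.isfinite(x).all():
            raise ValueError("StableSoftplus received NaN or Inf values")
        ctx.save_for_backward(x)
        # max(x, 0) + log1p(exp(-|x|)) only ever exponentiates non-positive values.
        return torch.clamp(x, min=0) + torch.log1p(torch.exp(-x.abs()))

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # d/dx softplus(x) = sigmoid(x), which torch.sigmoid evaluates stably.
        return grad_output * torch.sigmoid(x)


x = torch.tensor([-1000.0, 0.0, 1000.0])
# Finite everywhere; the naive torch.log(1 + torch.exp(x)) gives inf at 1000.
print(StableSoftplus.apply(x))

z = torch.randn(8, dtype=torch.double, requires_grad=True)
print("gradcheck passed:", gradcheck(StableSoftplus.apply, (z,)))
```

Raising early in `forward` keeps bad inputs from silently producing NaN gradients several layers later, which is usually much harder to trace.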

### 🧪 **Research and Development Tools**
- Gradient checking and validation frameworks
- Custom loss functions for domain-specific problems
- Advanced techniques for handling discrete operations
- Memory-efficient implementations for large-scale training
- Comprehensive debugging and analysis tools
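Gradient checking is most convincing when you see it catch a bug. The sketch below pairs a correct `Cube` function with a deliberately broken `BuggyCube` (both hypothetical names) whose backward drops the factor 3. Passing `raise_exception=False` makes `torch.autograd.gradcheck` return a boolean instead of raising, and `torch.autograd.gradgradcheck` additionally verifies second derivatives, which works here because `backward` is written with differentiable tensor ops.

```python
import torch
from torch.autograd import Function, gradcheck, gradgradcheck


class Cube(Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 3

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * 3 * x ** 2


class BuggyCube(Function):
    """Same forward as Cube, with a deliberately wrong backward."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 3

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * x ** 2  # BUG: the correct derivative is 3 * x ** 2


# Fixed, non-zero float64 inputs keep the check deterministic and precise.
x = torch.linspace(0.5, 2.0, 6, dtype=torch.double, requires_grad=True)
print("Cube:", gradcheck(Cube.apply, (x,), raise_exception=False))
print("BuggyCube:", gradcheck(BuggyCube.apply, (x,), raise_exception=False))
print("Cube second order:", gradgradcheck(Cube.apply, (x,)))
```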

**You are now equipped to:**
- Implement your own custom differentiable operations
- Create production-ready custom functions with proper validation
- Debug and optimize complex gradient computations
- Handle non-differentiable operations with gradient approximations such as straight-through estimators
- Build comprehensive testing frameworks for custom implementations
- Develop specialized loss functions for research applications
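Not every specialized loss needs a hand-written `backward`. The sketch below is a hypothetical `binary_focal_loss` helper, assuming the standard form FL = −(1 − p_t)^γ log(p_t) without the α class-weighting term. It is composed from differentiable ops, so autograd derives the gradient, and it computes log p_t with `F.logsigmoid` for stability. A cheap sanity test is that γ = 0 must reproduce `F.binary_cross_entropy_with_logits`.

```python
import torch
import torch.nn.functional as F


def binary_focal_loss(logits, targets, gamma=2.0):
    """Mean binary focal loss (no alpha term); targets hold 0/1 labels as floats."""
    # log(p) and log(1 - p) straight from logits, without computing sigmoid first.
    log_p = F.logsigmoid(logits)
    log_1mp = F.logsigmoid(-logits)
    log_pt = targets * log_p + (1.0 - targets) * log_1mp
    pt = log_pt.exp()
    # (1 - p_t)^gamma down-weights examples the model already classifies well.
    return (-((1.0 - pt) ** gamma) * log_pt).mean()


torch.manual_seed(0)
logits = torch.randn(16, requires_grad=True)
targets = torch.randint(0, 2, (16,)).float()

bce = F.binary_cross_entropy_with_logits(logits, targets)
focal0 = binary_focal_loss(logits, targets, gamma=0.0)
focal2 = binary_focal_loss(logits, targets, gamma=2.0)
focal2.backward()
print(f"BCE={bce.item():.4f}  focal(gamma=0)={focal0.item():.4f}  focal(gamma=2)={focal2.item():.4f}")
```

Writing the loss from built-in ops first gives a reference implementation; a custom `Function` version can then be validated against it before any manual backward is trusted.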

**Next recommended learning paths:**
- Advanced Meta-Learning and Few-Shot Learning
- Neural Architecture Search Implementation
- Custom CUDA Kernels for PyTorch
- Automatic Differentiation for Scientific Computing
- Advanced Optimization Algorithms

**Happy custom function development! 🚀**