<div style="background-color: #008B8B; padding: 15px; border-radius: 5px; font-size: 20px; color: black; font-weight: bold;">
Debugging the Gradient Flow
</div>

In [None]:
# Dump every registered parameter of the model together with its full value.
for param_name, tensor in model.named_parameters():
    line = f"name : {param_name}, param = {tensor}"
    print(line)

In [None]:
def check_gradient_flow(model):
    """Print a one-line gradient status for each named parameter of ``model``.

    For every ``(name, param)`` pair from ``model.named_parameters()`` one of
    three lines is printed:

    * ``requires_grad = False`` when the parameter does not require grad,
    * ``NO GRADIENT`` when it requires grad but ``param.grad`` is ``None``,
    * ``grad_norm = <value>`` (six decimals) with the norm of ``param.grad``.
    """
    for param_name, tensor in model.named_parameters():
        # Parameters excluded from autograd are reported and skipped.
        if not tensor.requires_grad:
            print(f"{param_name}: requires_grad = False")
            continue

        gradient = tensor.grad
        if gradient is None:
            print(f"{param_name}: NO GRADIENT")
            continue

        norm_value = gradient.norm().item()
        print(f"{param_name}: grad_norm = {norm_value:.6f}")

In [None]:
# Install torchviz if not already installed
# pip install torchviz

# torchviz is optional: assume it is present, and flip the flag (with a hint
# on how to install it) only when the import actually fails.
TORCHVIZ_AVAILABLE = True
try:
    from torchviz import make_dot
except ImportError:
    TORCHVIZ_AVAILABLE = False
    print("torchviz not available. Install with: pip install torchviz")

def _leaf_ids_reachable_from(root):
    """Return the ``id()`` of every leaf tensor reachable from ``root``'s graph.

    Walks ``root.grad_fn`` through ``next_functions``; nodes exposing a
    ``variable`` attribute (AccumulateGrad) mark the leaf tensors that would
    receive a gradient from ``root.backward()``.
    """
    leaf_ids = set()
    visited = set()
    pending = [root.grad_fn]
    while pending:
        node = pending.pop()
        if node is None or node in visited:
            continue
        visited.add(node)
        leaf = getattr(node, 'variable', None)
        if leaf is not None:
            leaf_ids.add(id(leaf))
        pending.extend(fn for fn, _ in node.next_functions)
    return leaf_ids


def visualize_computational_graph(loss, model, save_path="computation_graph"):
    """Render the autograd graph of ``loss`` with torchviz and report which
    model parameters are part of it.

    Args:
        loss: Tensor whose autograd graph is rendered.
        model: Object providing ``named_parameters()``; the parameters are
            passed to ``make_dot`` for labelling and checked for connectivity.
        save_path: Output path without extension; a PNG is rendered to
            ``<save_path>.png``.

    Prints a notice and does nothing else when torchviz could not be imported.
    """
    if not TORCHVIZ_AVAILABLE:
        print("torchviz not available for visualization")
        return

    # Create computational graph
    params = dict(model.named_parameters())
    dot = make_dot(loss, params=params)

    # Save and display
    dot.format = 'png'
    dot.render(save_path)
    print(f"Computational graph saved as {save_path}.png")

    # Also print parameter connections.
    # Parameters are leaf tensors, so ``param.grad_fn`` is always None and
    # cannot indicate connectivity; instead check whether each parameter is
    # reachable from the loss through the autograd graph.
    connected_ids = _leaf_ids_reachable_from(loss)
    print("\nParameter connections in graph:")
    for name, param in params.items():
        if id(param) in connected_ids:
            print(f"{name}: Connected (reachable from loss)")
        else:
            print(f"{name}: NOT connected (not reachable from loss)")

def trace_computation_graph(tensor, depth=0, max_depth=5):
    """Print a depth-limited trace of the autograd graph behind ``tensor``.

    Args:
        tensor: Tensor to inspect; its shape, ``grad_fn`` and
            ``requires_grad`` are printed first.
        depth: Current nesting level (controls indentation); callers normally
            leave this at 0.
        max_depth: Deepest nesting level that is still printed.

    Backward nodes are followed through ``next_functions``. Intermediate
    nodes are descended into directly; nodes exposing a ``variable``
    (leaf accumulators) have that tensor traced as well. Shared sub-graphs
    are printed once per path, which ``max_depth`` keeps bounded.
    """
    if depth > max_depth:
        return

    indent = "  " * depth
    grad_fn = getattr(tensor, 'grad_fn', None)
    print(f"{indent}Tensor: {tensor.shape if hasattr(tensor, 'shape') else 'scalar'}")
    print(f"{indent}grad_fn: {grad_fn}")
    print(f"{indent}requires_grad: {getattr(tensor, 'requires_grad', False)}")

    if grad_fn is not None:
        _trace_backward_node(grad_fn, depth, max_depth)


def _trace_backward_node(node, depth, max_depth):
    """Print the inputs of one backward node and recurse into them."""
    indent = "  " * depth
    for i, (next_fn, _) in enumerate(getattr(node, 'next_functions', ())):
        if next_fn is None:
            continue
        print(f"{indent}  -> next_function[{i}]: {next_fn}")
        variable = getattr(next_fn, 'variable', None)
        if variable is not None:
            # Leaf accumulator: show the tensor that receives the gradient.
            print(f"{indent}     variable shape: {variable.shape}")
            trace_computation_graph(variable, depth + 1, max_depth)
        elif depth < max_depth:
            # Intermediate op: keep walking, otherwise the trace would stop
            # after the first level of the graph.
            _trace_backward_node(next_fn, depth + 1, max_depth)

In [None]:
def debug_parameter_usage(model, loss):
    """Debug which parameters are actually used in the computation of ``loss``.

    Args:
        model: Object providing ``named_parameters()``.
        loss: Tensor whose autograd graph is inspected.

    Prints the autograd status of ``loss``, then for each parameter its
    ``requires_grad``, ``grad_fn``, ``is_leaf`` and value plus a verdict on
    whether it is reachable from ``loss``, and finally a trace of the loss
    graph via ``trace_computation_graph``.
    """

    print("=== DEBUGGING PARAMETER USAGE ===")

    # Check if loss requires grad
    print(f"Loss requires_grad: {loss.requires_grad}")
    print(f"Loss grad_fn: {loss.grad_fn}")

    # Collect the leaf tensors that feed into the loss. Parameters are leaf
    # tensors, so their own ``grad_fn`` is always None and cannot be used to
    # decide connectivity; reachability from the loss graph can.
    connected_ids = set()
    visited = set()
    pending = [loss.grad_fn]
    while pending:
        node = pending.pop()
        if node is None or node in visited:
            continue
        visited.add(node)
        leaf = getattr(node, 'variable', None)
        if leaf is not None:
            connected_ids.add(id(leaf))
        pending.extend(fn for fn, _ in node.next_functions)

    # Check each parameter
    print("\nParameter status:")
    for name, param in model.named_parameters():
        print(f"\n{name}:")
        print(f"  requires_grad: {param.requires_grad}")
        print(f"  grad_fn: {param.grad_fn}")
        print(f"  is_leaf: {param.is_leaf}")
        print(f"  value: {param.data}")

        # Check if parameter is used in computation
        if id(param) in connected_ids:
            print("  ✓ Connected to computation graph")
        else:
            print("  ✗ NOT connected to computation graph")

    print("\n=== TRACING LOSS COMPUTATION ===")
    trace_computation_graph(loss)

def check_sde_computation_flow(model, times, z0s):
    """Check if the SDE initial state ``z0s`` connects to the model parameters.

    Args:
        model: Object exposing ``qz0_mean`` and ``qz0_logvar`` tensors.
        times: Currently unused; kept so existing callers keep working.
        z0s: Tensor whose autograd graph is searched for the two tensors.

    Prints the autograd status of ``z0s`` and, for ``qz0_mean`` and
    ``qz0_logvar``, whether each one is reachable from ``z0s`` through the
    autograd graph.
    """

    print("=== CHECKING SDE COMPUTATION FLOW ===")

    # Check if z0s connects to qz0_mean and qz0_logvar
    print(f"z0s requires_grad: {z0s.requires_grad}")
    print(f"z0s grad_fn: {z0s.grad_fn}")

    # Walk the graph behind z0s once, remembering every backward node and
    # every leaf tensor it accumulates gradients into.
    graph_nodes = set()
    leaf_ids = set()
    pending = [z0s.grad_fn]
    while pending:
        node = pending.pop()
        if node is None or node in graph_nodes:
            continue
        graph_nodes.add(node)
        leaf = getattr(node, 'variable', None)
        if leaf is not None:
            leaf_ids.add(id(leaf))
        pending.extend(fn for fn, _ in node.next_functions)

    def feeds_z0s(tensor):
        # ``tensor.grad_fn is not None`` only says the tensor is a non-leaf;
        # it says nothing about whether z0s was computed from it. Leaves are
        # matched by identity, non-leaves by their producing node.
        if tensor is z0s:
            return True
        if tensor.grad_fn is None:
            return id(tensor) in leaf_ids
        return tensor.grad_fn in graph_nodes

    # Check if qz0_mean / qz0_logvar are used in the z0s computation
    print(f"\nqz0_mean in z0s computation: {feeds_z0s(model.qz0_mean)}")
    print(f"qz0_logvar in z0s computation: {feeds_z0s(model.qz0_logvar)}")

    # Check the SDE integration
    print("\nChecking torchsde.sdeint connection...")
    # The issue might be here - torchsde.sdeint might break the gradient flow