# PyTorch DAGs, Gradients, and Detachment

This notebook covers PyTorch's automatic differentiation system (autograd) and how to manage the computation graph it builds.

## 📚 **What You'll Learn**

- **Computation Graph (DAG)**: Understanding PyTorch's directed acyclic graph
- **Gradient Tracking**: Enabling and disabling automatic differentiation
- **Tensor Detachment**: Breaking gradient flow and memory management
- **Autograd System**: How PyTorch computes gradients automatically

## 🎯 **Learning Objectives**

By the end of this notebook, you'll understand:
- How PyTorch builds computation graphs for backpropagation
- When and how to enable/disable gradient tracking
- The difference between leaf and non-leaf tensors
- How to use `.detach()` for memory management and gradient control

This is the foundation of how neural networks learn! 🧠

In [None]:
import torch
import numpy as np

print(f"PyTorch version: {torch.__version__}")
torch.manual_seed(42)  # For reproducibility

## Inspecting the Computation Graph (DAG)

PyTorch builds a Directed Acyclic Graph (DAG) to track operations for automatic differentiation. Let's explore how to inspect it:

In [None]:
# Create tensors that will be part of computation graph
print("=== Building a Computation Graph ===")

x = torch.tensor([[1.0, 2.0]], requires_grad=True)
y = torch.tensor([[3.0, 4.0]], requires_grad=True)

print(f"x: {x}")
print(f"y: {y}")
print(f"x.requires_grad: {x.requires_grad}")
print(f"y.requires_grad: {y.requires_grad}")

# Check initial gradient functions (should be None for leaf tensors)
print(f"x.grad_fn: {x.grad_fn}")  # None - this is a leaf tensor
print(f"y.grad_fn: {y.grad_fn}")  # None - this is a leaf tensor
print(f"x.is_leaf: {x.is_leaf}")  # True - created by user
print(f"y.is_leaf: {y.is_leaf}")  # True - created by user

print("\n=== Operations Create Graph Nodes ===")

# Perform operations - each creates a node in the computation graph
z = x + y  # grad_fn: AddBackward0
print(f"z = x + y: {z}")
print(f"z.grad_fn: {z.grad_fn}")
print(f"z.requires_grad: {z.requires_grad}")
print(f"z.is_leaf: {z.is_leaf}")  # False - computed from other tensors

w = z * 2  # grad_fn: MulBackward0
print(f"w = z * 2: {w}")
print(f"w.grad_fn: {w.grad_fn}")

result = w.sum()  # grad_fn: SumBackward0
print(f"result = w.sum(): {result}")
print(f"result.grad_fn: {result.grad_fn}")
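
One practical consequence of the leaf / non-leaf distinction is where gradients end up after `backward()`. The sketch below rebuilds the same small graph and assumes you want to inspect the gradient of the intermediate tensor `z`; by default only leaf tensors keep a `.grad`, so the non-leaf tensor has to opt in with `retain_grad()`.

```python
import torch

x = torch.tensor([[1.0, 2.0]], requires_grad=True)  # leaf tensor
y = torch.tensor([[3.0, 4.0]], requires_grad=True)  # leaf tensor

z = x + y          # non-leaf: produced by an operation
z.retain_grad()    # ask autograd to keep z.grad after backward()
w = z * 2          # non-leaf, no retain_grad() call
result = w.sum()

result.backward()

print(f"x.grad: {x.grad}")  # leaf tensor: populated by backward()
print(f"z.grad: {z.grad}")  # non-leaf with retain_grad(): also populated
# w.grad is not kept; reading it returns None and emits a UserWarning
```

Here `result = sum(2 * (x + y))`, so every element of `x.grad` and `z.grad` is 2.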

In [None]:
print("\n=== Exploring the Graph Structure ===")

# You can access the next functions in the graph
print(f"result.grad_fn.next_functions: {result.grad_fn.next_functions}")

# Each entry is a (function, input_nr) tuple; input_nr says which input of that backward function this edge feeds
for i, (func, input_nr) in enumerate(result.grad_fn.next_functions):
    print(f"  Function {i}: {func}, Input number: {input_nr}")

print("\n=== More Complex Example ===")

# Reset and create a more complex computation
a = torch.tensor([2.0, 3.0], requires_grad=True)
b = torch.tensor([4.0, 5.0], requires_grad=True)

# Build computation: result2 = (a * b).sum() + (a ** 2).mean()
c = a * b              # grad_fn: MulBackward0
d = a ** 2             # grad_fn: PowBackward0
e = c.sum()            # grad_fn: SumBackward0
f = d.mean()           # grad_fn: MeanBackward0
result2 = e + f        # grad_fn: AddBackward0

print(f"a: {a}")
print(f"b: {b}")
print(f"c = a * b: {c}, grad_fn: {c.grad_fn}")
print(f"d = a ** 2: {d}, grad_fn: {d.grad_fn}")
print(f"e = c.sum(): {e}, grad_fn: {e.grad_fn}")
print(f"f = d.mean(): {f}, grad_fn: {f.grad_fn}")
print(f"result2 = e + f: {result2}, grad_fn: {result2.grad_fn}")
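
Following `next_functions` by hand gets tedious for anything beyond two or three operations. As a rough sketch, a small recursive helper can print the whole backward graph. The name `walk_graph` is hypothetical (it is not a PyTorch API), and the helper treats the graph as a tree, so a node reachable along two paths is listed twice.

```python
import torch

def walk_graph(fn, depth=0, lines=None):
    """Collect one indented line per autograd node reachable from fn."""
    if lines is None:
        lines = []
    if fn is None:  # inputs that do not require gradients appear as None
        return lines
    lines.append("  " * depth + type(fn).__name__)
    for next_fn, _input_nr in fn.next_functions:
        walk_graph(next_fn, depth + 1, lines)
    return lines

a = torch.tensor([2.0, 3.0], requires_grad=True)
b = torch.tensor([4.0, 5.0], requires_grad=True)
result2 = (a * b).sum() + (a ** 2).mean()

graph_lines = walk_graph(result2.grad_fn)
print("\n".join(graph_lines))
```

The `AccumulateGrad` nodes at the bottom of the printout are where gradients are written into the `.grad` attribute of the leaf tensors `a` and `b`.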

## Gradient Tracking and Computation

Let's explore how PyTorch tracks and computes gradients:

In [None]:
# Creating tensors with and without gradient tracking
print("=== Creating Tensors with Gradient Tracking ===")

# Default: no gradient tracking
tensor_no_grad = torch.randn(2, 2)
print(f"Default tensor requires_grad: {tensor_no_grad.requires_grad}")

# Explicitly enable gradient tracking
tensor_with_grad = torch.randn(2, 2, requires_grad=True)
print(f"Explicit requires_grad tensor: {tensor_with_grad.requires_grad}")

# Enable gradient tracking on existing tensor
tensor_no_grad.requires_grad_(True)  # In-place modification
print(f"Modified tensor requires_grad: {tensor_no_grad.requires_grad}")

print("\n=== Disabling Gradients Temporarily ===")

# Using torch.no_grad() context manager
x = torch.randn(3, 3, requires_grad=True)
print(f"x.requires_grad: {x.requires_grad}")

with torch.no_grad():
    y = x * 2
    print(f"y.requires_grad (inside no_grad): {y.requires_grad}")

# Outside the context, gradients work normally
z = x * 3
print(f"z.requires_grad (outside no_grad): {z.requires_grad}")

print("\n=== Computing Gradients ===")

# Simple gradient computation example
x = torch.tensor([2.0, 3.0], requires_grad=True)
y = x.pow(2).sum()  # y = x₁² + x₂²

print(f"x: {x}")
print(f"y: {y}")
print(f"x.grad before backward(): {x.grad}")

# Compute gradients
y.backward()

print(f"x.grad after backward(): {x.grad}")
print("Expected: [2*x₁, 2*x₂] = [4.0, 6.0]")
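
One detail that is easy to miss: `backward()` adds to `.grad` rather than overwriting it. The sketch below reuses the same `x` and the same function and runs the backward pass more than once; the variable names `first`, `accumulated`, and `after_reset` are only for illustration.

```python
import torch

x = torch.tensor([2.0, 3.0], requires_grad=True)

# First backward pass: x.grad = 2 * x = [4., 6.]
x.pow(2).sum().backward()
first = x.grad.clone()

# Second backward pass on a freshly built graph: the new gradient is ADDED
x.pow(2).sum().backward()
accumulated = x.grad.clone()

# Reset the gradient in place, then run one more pass
x.grad.zero_()
x.pow(2).sum().backward()
after_reset = x.grad.clone()

print(f"first:       {first}")        # tensor([4., 6.])
print(f"accumulated: {accumulated}")  # tensor([ 8., 12.])
print(f"after_reset: {after_reset}")  # tensor([4., 6.])
```

This accumulation is the reason training loops zero the gradients between steps.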

## Tensor Detachment and Memory Management

The `.detach()` method is crucial for controlling gradient flow and memory management:

In [None]:
print("=== What is .detach()? ===")
print("detach() creates a new tensor that shares the same data but is removed from the computation graph")
print("- Same data (shares memory)")
print("- No gradient tracking (requires_grad=False)")
print("- Not connected to computation graph")

# Example 1: Basic detachment
print("\n=== Basic Detachment ===")

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
y = x * 2
z = y + 1

print(f"x: {x}")
print(f"y: {y}")
print(f"z: {z}")
print(f"z.requires_grad: {z.requires_grad}")
print(f"z.grad_fn: {z.grad_fn}")

# Detach z
z_detached = z.detach()
print("\nAfter detachment:")
print(f"z_detached: {z_detached}")
print(f"z_detached.requires_grad: {z_detached.requires_grad}")
print(f"z_detached.grad_fn: {z_detached.grad_fn}")

# Verify they share the same data
print("\nShared memory check:")
print(f"Same data: {torch.equal(z, z_detached)}")
print(f"z.data_ptr(): {z.data_ptr()}")
print(f"z_detached.data_ptr(): {z_detached.data_ptr()}")
print(f"Same memory location: {z.data_ptr() == z_detached.data_ptr()}")

print("\n=== Detach vs Clone vs Copy ===")

original = torch.tensor([1.0, 2.0], requires_grad=True)
print(f"original: {original}, requires_grad: {original.requires_grad}")

# Method 1: detach() - shares memory, no gradients
detached = original.detach()
print(f"detached: {detached}, requires_grad: {detached.requires_grad}")

# Method 2: clone() - new memory, stays in the graph (non-leaf; gradients flow back to original)
cloned = original.clone()
print(f"cloned: {cloned}, requires_grad: {cloned.requires_grad}")

# Method 3: detach().clone() - new memory, no gradients
detached_cloned = original.detach().clone()
print(f"detached_cloned: {detached_cloned}, requires_grad: {detached_cloned.requires_grad}")

# Demonstrate memory sharing
print("\nMemory sharing:")
print(f"original ptr: {original.data_ptr()}")
print(f"detached ptr: {detached.data_ptr()} (same: {original.data_ptr() == detached.data_ptr()})")
print(f"cloned ptr: {cloned.data_ptr()} (same: {original.data_ptr() == cloned.data_ptr()})")
print(f"detached_cloned ptr: {detached_cloned.data_ptr()} (same: {original.data_ptr() == detached_cloned.data_ptr()})")

print("\n=== When to Use .detach() ===")
use_cases = [
    "1. Converting tensor to NumPy: tensor.detach().numpy()",
    "2. Stopping gradient flow in part of a computation",
    "3. Creating inputs for inference without gradients",
    "4. Memory optimization: keeping a detached value (e.g. for logging) lets the graph be freed",
    "5. Debugging: isolating parts of the computation graph"
]

for use_case in use_cases:
    print(use_case)
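
To make use case 2 concrete, here is a minimal sketch that computes the same loss twice, once through the full graph and once with an intermediate tensor detached. The names `grad_full` and `grad_detached` are illustrative only.

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# Full graph: loss = sum(x * 2x) = sum(2 * x^2), so d(loss)/dx = 4x
y = x * 2
(x * y).sum().backward()
grad_full = x.grad.clone()

x.grad.zero_()  # reset before the second experiment

# Detached: y is treated as a constant, so d(loss)/dx = y = 2x
y = x * 2
(x * y.detach()).sum().backward()
grad_detached = x.grad.clone()

print(f"grad through full graph: {grad_full}")      # tensor([ 4.,  8., 12.])
print(f"grad with y detached:    {grad_detached}")  # tensor([2., 4., 6.])
```

The forward value of the loss is identical in both cases; only the path along which gradients flow back to `x` changes.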

## Practical Examples: Real-world Usage

Let's see some practical examples of DAGs, gradients, and detachment in deep learning contexts:

In [None]:
# Example 1: Simple neural network forward pass with gradient tracking
print("=== Example 1: Neural Network Forward Pass ===")

# Simulate a simple linear layer: y = Wx + b
W = torch.randn(3, 2, requires_grad=True)  # Weight matrix
b = torch.randn(3, requires_grad=True)     # Bias vector
x = torch.randn(5, 2)                      # Input batch (5 samples, 2 features)

print(f"W shape: {W.shape}, requires_grad: {W.requires_grad}")
print(f"b shape: {b.shape}, requires_grad: {b.requires_grad}")
print(f"x shape: {x.shape}, requires_grad: {x.requires_grad}")

# Forward pass
y = torch.matmul(x, W.T) + b  # Linear transformation
print(f"y shape: {y.shape}, requires_grad: {y.requires_grad}")
print(f"y.grad_fn: {y.grad_fn}")

# Apply activation function
z = torch.relu(y)
print(f"z (after ReLU): {z.shape}, requires_grad: {z.requires_grad}")
print(f"z.grad_fn: {z.grad_fn}")

# Compute loss (mean squared error with target)
target = torch.randn(5, 3)
loss = torch.mean((z - target) ** 2)
print(f"loss: {loss.item():.4f}, requires_grad: {loss.requires_grad}")
print(f"loss.grad_fn: {loss.grad_fn}")

print("\n=== Example 2: Training Step with Gradients ===")

# Compute gradients
loss.backward()

print(f"W.grad shape: {W.grad.shape}")
print(f"W.grad:\n{W.grad}")
print(f"b.grad shape: {b.grad.shape}")
print(f"b.grad: {b.grad}")

# Simulate optimizer step (gradient descent)
learning_rate = 0.01
with torch.no_grad():  # Disable gradient tracking for parameter updates
    W -= learning_rate * W.grad
    b -= learning_rate * b.grad

print("Parameters updated with gradient descent")

# Clear gradients for next iteration
W.grad.zero_()
b.grad.zero_()
print("Gradients cleared for next iteration")
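
The manual update above is roughly what PyTorch's built-in modules do for you. As a hedged sketch, the same training step can be written with `torch.nn.Linear` and `torch.optim.SGD`. Those two classes, the seed, and the three-step loop are assumptions added here for illustration; the original example does not use them.

```python
import torch

torch.manual_seed(0)

# nn.Linear holds a weight of shape (3, 2) and a bias of shape (3,),
# both created with requires_grad=True
layer = torch.nn.Linear(in_features=2, out_features=3)
optimizer = torch.optim.SGD(layer.parameters(), lr=0.01)

x = torch.randn(5, 2)        # input batch: 5 samples, 2 features
target = torch.randn(5, 3)

losses = []
for step in range(3):
    optimizer.zero_grad()                  # clear gradients from the previous step
    z = torch.relu(layer(x))               # forward pass builds a fresh graph
    loss = torch.mean((z - target) ** 2)   # mean squared error
    loss.backward()                        # populate .grad on weight and bias
    optimizer.step()                       # update the parameters in place
    losses.append(loss.item())             # .item() returns a plain Python float

print(losses)
```

Storing `loss.item()` rather than `loss` keeps only a number, so each step's graph can be freed once the step is done.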

## Summary

### Key Concepts Covered:

#### Computation Graph (DAG):
- **DAG Structure**: PyTorch builds a Directed Acyclic Graph to track operations
- **Node Types**: Leaf tensors (created directly by the user, or not requiring gradients) vs intermediate tensors (computed from tensors that require gradients)
- **Graph Functions**: Each operation creates a `grad_fn` for backpropagation
- **Graph Inspection**: Use `.grad_fn`, `.is_leaf`, and `.next_functions` to explore

#### Gradient Tracking:
- **requires_grad**: Controls whether gradients are computed for a tensor
- **Gradient Flow**: An operation's output requires gradients if any of its inputs does; after `.backward()`, `.grad` is populated on leaf tensors with `requires_grad=True`
- **Context Managers**: Use `torch.no_grad()` to temporarily disable gradients
- **Backward Pass**: Call `.backward()` to compute gradients through the graph
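
As a worked check of the backward pass, take the expression used earlier, `result2 = (a * b).sum() + (a ** 2).mean()`. With `n` elements, differentiating by hand gives `d(result2)/da = b + 2a/n` and `d(result2)/db = a`. The sketch below compares those formulas with what autograd computes; `expected_a` and `expected_b` are illustrative names.

```python
import torch

a = torch.tensor([2.0, 3.0], requires_grad=True)
b = torch.tensor([4.0, 5.0], requires_grad=True)

result2 = (a * b).sum() + (a ** 2).mean()
result2.backward()

n = a.numel()
expected_a = (b + 2 * a / n).detach()  # hand-derived gradient w.r.t. a
expected_b = a.detach()                # hand-derived gradient w.r.t. b

print(f"a.grad: {a.grad}, expected: {expected_a}")  # both tensor([6., 8.])
print(f"b.grad: {b.grad}, expected: {expected_b}")  # both tensor([2., 3.])
```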

#### Tensor Detachment:
- **Memory Sharing**: `.detach()` creates a new tensor that shares the same underlying data
- **Gradient Breaking**: Detached tensors have `requires_grad=False`
- **Use Cases**: Converting to NumPy, stopping gradient flow, memory optimization
- **Alternatives**: `.clone()` for new memory that stays in the graph, `.detach().clone()` for new memory outside the graph
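
Memory sharing cuts both ways: an in-place write through a detached tensor also changes the original. A minimal sketch, with `view` and `copy` as illustrative names:

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

view = x.detach()          # shares storage with x
copy = x.detach().clone()  # independent storage

view[0] = 100.0            # in-place write through the detached tensor

print(f"x:    {x}")     # first element is now 100.0 as well
print(f"copy: {copy}")  # unchanged: tensor([1., 2., 3.])
```

If the detached values are going to be modified, `.detach().clone()` is the safer choice.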

### Best Practices:
1. **Enable gradients** only for parameters that need training
2. **Use `torch.no_grad()`** during inference to save memory
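3. **Clear gradients** with `.grad.zero_()` (or `optimizer.zero_grad()`) between training steps, because `.backward()` accumulates into `.grad`
4. **Detach tensors** when passing values that should not carry gradients, such as logged metrics or arrays converted to NumPy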
5. **Monitor memory usage** in complex graphs with many intermediate tensors

Understanding these concepts is essential for debugging PyTorch models and optimizing training performance! 🚀