# **CHAPTER 11: DEEP LEARNING FRAMEWORKS**

*The Engines of Deep Learning*

## **Chapter Overview**

Building neural networks from scratch taught you the mathematics; now you need industrial-strength tools. PyTorch is the most widely used framework in research, TensorFlow offers a mature production and deployment ecosystem, and JAX targets high-performance, compiler-driven computing. This chapter aims to make you fluent in all three, with an emphasis on writing efficient, scalable, and maintainable deep learning code.

**Estimated Time:** 50-60 hours (3-4 weeks)  
**Prerequisites:** Chapter 10 (Neural Network fundamentals), strong Python skills

---

## **11.0 Learning Objectives**

By the end of this chapter, you will be able to:
1. Build and train models in PyTorch using both imperative and object-oriented patterns
2. Utilize TensorFlow's ecosystem (Keras, tf.data, TF Serving) for production pipelines
3. Leverage JAX for high-performance research code with functional transformations (jit, vmap, grad)
4. Implement distributed training across multiple GPUs (DDP, FSDP, data/model parallelism)
5. Export models for production deployment (ONNX, TorchScript, SavedModel)
6. Profile and optimize training throughput (mixed precision, gradient accumulation, efficient data loading)

---

## **11.1 PyTorch: The Research Standard**

PyTorch's dynamic computation graphs and Pythonic design make it the preferred choice for research and rapid prototyping.

#### **11.1.1 Tensors and Autograd**

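Tensors are n-dimensional arrays that can live in CPU or accelerator memory. When a tensor has `requires_grad=True`, autograd records the operations applied to it so that gradients can be computed automatically.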

```python
import torch
import torch.nn as nn

# Creation
x = torch.tensor([[1., 2.], [3., 4.]], requires_grad=True)

# Device placement (each line needs the matching hardware, so they are commented out)
# x_gpu = torch.randn(3, 3, device='cuda', dtype=torch.float32)  # NVIDIA GPU
# x_mps = torch.zeros(2, 3).to('mps')  # Apple Silicon

# Operations (automatically tracked because x has requires_grad=True)
y = x ** 2 + 3 * x + 1
z = y.mean()

# Backward pass
z.backward()  # Computes dz/dx for every leaf tensor with requires_grad=True
print(x.grad)  # (2x + 3) / 4 = [[1.25, 1.75], [2.25, 2.75]]

# Disabling gradient tracking (e.g., for evaluation)
model = nn.Linear(2, 2)  # Stand-in model for illustration
with torch.no_grad():
    evaluation = model(x)  # No graph built, saves memory

# Or: x.detach() returns a tensor that shares data but has no grad history
```

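**In-place Operations:** Operations ending with `_` (e.g., `add_()`, `zero_()`) modify a tensor in place. Avoid them on tensors that take part in autograd: if an in-place operation overwrites a value that the backward pass needs, PyTorch raises a `RuntimeError`.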
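As a minimal sketch of the in-place problem, the snippet below modifies the output of `exp()` in place. `exp` saves its output for the backward pass, so autograd detects the change and refuses to compute gradients. The tensor values and variable names are illustrative.

```python
import torch

a = torch.tensor([1.0, 2.0], requires_grad=True)

# Out-of-place: the tensor saved by exp() is left untouched, so backward works.
b = a.exp() + 1
b.sum().backward()
grad_ok = a.grad.clone()  # d/da exp(a) = exp(a)

# In-place: add_() overwrites the tensor that exp() saved for backward.
a.grad = None
c = a.exp()
c.add_(1)
try:
    c.sum().backward()
    inplace_failed = False
except RuntimeError as err:
    inplace_failed = True
    print(f"autograd rejected the in-place edit: {type(err).__name__}")
```

Not every in-place operation fails this way; it depends on whether the overwritten tensor is needed by a backward function.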

#### **11.1.2 nn.Module and Model Definition**

```python
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        
        # Shortcut connection
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
    
    def forward(self, x):
        identity = self.shortcut(x)
        
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        
        out = self.conv2(out)
        out = self.bn2(out)
        
        out += identity  # Residual connection
        out = self.relu(out)
        return out

# Sequential API for simple models
simple_model = nn.Sequential(
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(256, 10)
)
```

#### **11.1.3 Training Loop (The PyTorch Way)**

```python
def train_epoch(model, dataloader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    
    for batch_idx, (inputs, targets) in enumerate(dataloader):
        inputs, targets = inputs.to(device), targets.to(device)
        
        # Zero gradients
        optimizer.zero_grad()  # or optimizer.zero_grad(set_to_none=True) for speed
        
        # Forward
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        
        # Backward
        loss.backward()
        
        # Gradient clipping (optional)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        
        # Update
        optimizer.step()
        
        # Metrics
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()
        
        if batch_idx % 100 == 0:
            print(f'Batch {batch_idx}, Loss: {loss.item():.4f}')
    
    return running_loss / len(dataloader), 100. * correct / total

# Learning rate scheduling
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)
# Or: ReduceLROnPlateau, StepLR, OneCycleLR

# Training loop
for epoch in range(num_epochs):
    train_loss, train_acc = train_epoch(model, trainloader, criterion, optimizer, device)
    val_loss, val_acc = validate(model, valloader, criterion, device)
    scheduler.step()
    
    # Save checkpoint
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': train_loss,
    }, 'checkpoint.pth')
```
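The loop above calls a `validate` helper that is not shown. One possible sketch, assuming a classification model and the same `(loss, accuracy)` return convention as `train_epoch`, is given below; the synthetic data at the end is only there to exercise the function.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def validate(model, dataloader, criterion, device):
    """Return (mean loss per batch, accuracy in percent) without updating weights."""
    model.eval()  # Disable dropout, use running BatchNorm statistics
    running_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():  # No graph is built during evaluation
        for inputs, targets in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            running_loss += criterion(outputs, targets).item()
            correct += outputs.argmax(dim=1).eq(targets).sum().item()
            total += targets.size(0)
    return running_loss / len(dataloader), 100.0 * correct / total

# Tiny synthetic check (random data, hypothetical shapes)
device = torch.device("cpu")
model = nn.Linear(8, 3).to(device)
data = TensorDataset(torch.randn(32, 8), torch.randint(0, 3, (32,)))
val_loss, val_acc = validate(model, DataLoader(data, batch_size=16),
                             nn.CrossEntropyLoss(), device)
print(f"val_loss={val_loss:.4f}, val_acc={val_acc:.1f}%")
```

Because `validate` leaves the model in eval mode, `train_epoch` must call `model.train()` again at the start of the next epoch, which the version above already does.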

#### **11.1.4 Data Loading and Transforms**

```python
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

class CustomDataset(Dataset):
    """Expects a CSV with 'path' and 'label' columns (integer class labels assumed)."""
    def __init__(self, csv_file, transform=None):
        self.data = pd.read_csv(csv_file)
        self.transform = transform
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        row = self.data.iloc[idx]
        image = Image.open(row['path']).convert('RGB')  # Load as a 3-channel PIL image
        label = int(row['label'])
        
        if self.transform:
            image = self.transform(image)
            
        return image, label

# Transforms pipeline
transform = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),  # Converts PIL [0,255] to Tensor [0,1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                        std=[0.229, 0.224, 0.225])
])

# DataLoader ('train.csv' is a placeholder path)
dataset = CustomDataset('train.csv', transform=transform)
trainloader = DataLoader(
    dataset, 
    batch_size=64, 
    shuffle=True, 
    num_workers=4,      # Parallel data loading
    pin_memory=True,    # Speeds up CPU->GPU transfer
    persistent_workers=True,  # Keep workers alive between epochs
    prefetch_factor=2   # Batches to prefetch per worker
)
```

#### **11.1.5 Mixed Precision Training**

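Mixed precision runs most of the forward and backward pass in FP16 (half precision) while keeping FP32 master weights. On GPUs with Tensor Cores this can give a substantial speedup (2-3x is often quoted, though the gain depends on the model and hardware) and reduces activation memory.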

```python
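# Note: recent PyTorch releases deprecate torch.cuda.amp in favor of torch.amp
# (torch.amp.autocast('cuda'), torch.amp.GradScaler('cuda')); this import still works with a warning.
from torch.cuda.amp import autocast, GradScaler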

scaler = GradScaler()

for inputs, targets in dataloader:
    inputs, targets = inputs.cuda(), targets.cuda()
    optimizer.zero_grad()
    
    # Automatic Mixed Precision
    with autocast():
        outputs = model(inputs)
        loss = criterion(outputs, targets)
    
    # Scale loss, backward, unscale, step
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```
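To see why the loss is scaled before `backward()`, the sketch below shows a small value underflowing to zero in FP16 and surviving once it is multiplied by a scale factor first. The value `1e-8` is illustrative; the scale of 65536 (2^16) matches the default initial scale of `GradScaler`.

```python
import torch

tiny_grad = torch.tensor(1e-8)            # A small FP32 gradient value
as_half = tiny_grad.half()                # Below FP16's smallest subnormal (~6e-8)

scale = 65536.0                           # 2**16
scaled_half = (tiny_grad * scale).half()  # Now well inside FP16's range
recovered = scaled_half.float() / scale   # Unscale in FP32, as the scaler does

print(as_half.item())    # Underflows to zero
print(recovered.item())  # Close to the original 1e-8
```

`GradScaler` applies the same idea to the whole backward pass and skips the optimizer step when the scaled gradients contain infs or NaNs.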

---

## **11.2 TensorFlow and Keras: The Production Ecosystem**

TensorFlow excels in production deployment, mobile optimization (TFLite), and scalable data pipelines (tf.data).

#### **11.2.1 Keras Functional API**

More flexible than Sequential API for multi-input/multi-output models.

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Functional API
inputs = keras.Input(shape=(784,), name='img')
x = layers.Dense(64, activation='relu')(inputs)
x = layers.Dense(64, activation='relu')(x)
outputs = layers.Dense(10, activation='softmax', name='predictions')(x)

model = keras.Model(inputs=inputs, outputs=outputs, name='mnist_model')

# Multi-input example
text_input = keras.Input(shape=(None,), dtype='int32', name='text')
image_input = keras.Input(shape=(224, 224, 3), name='image')

# Process each
x1 = layers.Embedding(10000, 128)(text_input)
x1 = layers.LSTM(64)(x1)

x2 = layers.Conv2D(32, 3, activation='relu')(image_input)
x2 = layers.GlobalMaxPooling2D()(x2)

# Combine
combined = layers.concatenate([x1, x2])
output = layers.Dense(1, activation='sigmoid', name='output')(combined)

model = keras.Model(inputs=[text_input, image_input], outputs=output)
```

#### **11.2.2 Custom Training with tf.GradientTape**

PyTorch-style imperative training in TensorFlow.

```python
optimizer = keras.optimizers.Adam(learning_rate=1e-3)
# from_logits=True assumes the model's last layer has no softmax activation
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
train_acc = keras.metrics.SparseCategoricalAccuracy()

@tf.function  # Compiles to graph (speeds up execution)
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss_value = loss_fn(y, logits)
    
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    
    train_acc.update_state(y, logits)
    return loss_value

# Training loop
for epoch in range(epochs):
    for step, (x_batch, y_batch) in enumerate(train_dataset):
        loss_value = train_step(x_batch, y_batch)
        if step % 200 == 0:
            print(f'Epoch {epoch}, Step {step}, Loss: {loss_value:.4f}')
```
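To isolate what the tape does in `train_step`, here is a minimal sketch that differentiates a scalar function of one variable. The function and the value of `x` are chosen only for illustration.

```python
import tensorflow as tf

x = tf.Variable(3.0)

with tf.GradientTape() as tape:
    y = x ** 2 + 2.0 * x       # Operations on variables are recorded on the tape

dy_dx = tape.gradient(y, x)    # dy/dx = 2x + 2
print(float(dy_dx))            # 8.0 at x = 3
```

A tape is released after one `gradient()` call by default; pass `persistent=True` to `tf.GradientTape` if several gradients are needed from the same recording.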

#### **11.2.3 tf.data for High-Performance Input Pipelines**

```python
def preprocess(image, label):
    image = tf.image.random_flip_left_right(image)
    image = tf.image.resize(image, [224, 224])
    image = image / 255.0
    return image, label

# Create dataset (load_image and get_label are user-supplied helpers, not shown)
list_ds = tf.data.Dataset.list_files('path/to/data/*.jpg')
labeled_ds = list_ds.map(lambda x: (load_image(x), get_label(x)),
                         num_parallel_calls=tf.data.AUTOTUNE)

# Optimization pipeline
dataset = labeled_ds.cache()  # Cache decoded images in memory after first epoch
# For datasets too large for memory, cache to a file at this same point instead:
# dataset = labeled_ds.cache(filename='./cache.tf-data')
dataset = dataset.shuffle(buffer_size=1000)
dataset = dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)  # Random augmentation stays after the cache
dataset = dataset.batch(32)
dataset = dataset.prefetch(tf.data.AUTOTUNE)  # Overlap preprocessing and training
```

---

## **11.3 JAX: The Future of High-Performance ML**

JAX combines NumPy's syntax with XLA (Accelerated Linear Algebra) compilation for TPUs/GPUs, plus automatic differentiation and vectorization.

#### **11.3.1 Functional Approach**

```python
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from flax import linen as nn  # Neural network library
import optax  # Optimization library

# Pure functions (no side effects)
def predict(params, inputs):
    for w, b in params:
        outputs = jnp.dot(inputs, w) + b
        inputs = jnp.maximum(0, outputs)  # ReLU
    return outputs

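# Gradients
def loss_fn(params, x, y):
    pred = predict(params, x)
    return jnp.mean((pred - y) ** 2)

# Example parameters and data (shapes chosen for illustration)
k1, k2, kx = jax.random.split(jax.random.PRNGKey(0), 3)
params = [(jax.random.normal(k1, (4, 8)) * 0.1, jnp.zeros(8)),
          (jax.random.normal(k2, (8, 1)) * 0.1, jnp.zeros(1))]
x, y = jax.random.normal(kx, (16, 4)), jnp.zeros((16, 1))

grad_fn = jit(grad(loss_fn))  # JIT compiles to optimized XLA
grads = grad_fn(params, x, y)  # Same nested structure as params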
```

#### **11.3.2 Vectorization with vmap**

Automatically batch operations without manual loops.

```python
# Without vmap: loop over batch
def apply_model_single(params, x):
    return predict(params, x)

batch_predictions = [apply_model_single(params, x) for x in batch]

# With vmap: automatic batching
batch_predict = vmap(apply_model_single, in_axes=(None, 0))
predictions = batch_predict(params, batch)  # Shape (batch_size, output_dim)
```
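Because `vmap` composes with `grad`, it can also produce per-example gradients, which are awkward to get from a batched loss. The sketch below uses a hypothetical one-layer squared-error loss with hand-picked values so the result can be checked by hand.

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Squared error for a single example: (x . w - y)^2
    return (jnp.dot(x, w) - y) ** 2

w = jnp.array([1.0, -1.0])
xs = jnp.array([[1.0, 2.0], [3.0, 4.0], [0.0, 1.0]])
ys = jnp.array([0.0, 1.0, 2.0])

# grad differentiates w.r.t. the first argument (w);
# vmap maps over axis 0 of xs and ys while broadcasting w (in_axes=None).
per_example_grads = jax.vmap(jax.grad(loss), in_axes=(None, 0, 0))(w, xs, ys)

# Analytic gradient for comparison: 2 * (x . w - y) * x
expected = 2.0 * (xs @ w - ys)[:, None] * xs
print(per_example_grads.shape)  # (3, 2): one gradient row per example
```

Wrapping the composed function in `jit` is the usual next step once the shapes are fixed.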

#### **11.3.3 Neural Networks with Flax**

```python
class MLP(nn.Module):
    features: tuple = (256, 256, 10)
    
    @nn.compact
    def __call__(self, x):
        for feat in self.features[:-1]:
            x = nn.Dense(feat)(x)
            x = nn.relu(x)
        x = nn.Dense(self.features[-1])(x)
        return x

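# Initialize
model = MLP(features=(256, 256, 10))  # Tuple matches the annotation and stays hashable
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 784)))
optimizer = optax.adam(learning_rate=1e-3)  # Optimizer choice is illustrative
opt_state = optimizer.init(params)

# Training step
@jit
def train_step(params, opt_state, x, y):
    def loss_fn(params):
        pred = model.apply(params, x)
        # softmax_cross_entropy expects one-hot labels; use
        # softmax_cross_entropy_with_integer_labels for class indices
        return jnp.mean(optax.softmax_cross_entropy(pred, y))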
    
    loss, grads = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss
```

---

## **11.4 Distributed Training**

#### **11.4.1 Data Parallelism (DDP - DistributedDataParallel)**

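Each GPU holds a full copy of the model and processes a different slice of each batch. Gradients are averaged across GPUs (all-reduce) during the backward pass, so every replica applies the same update.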

**PyTorch DDP:**
```python
import os
import torch
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
from torch.utils.data import DataLoader, DistributedSampler

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)  # Bind this process to its own GPU

def train(rank, world_size):
    setup(rank, world_size)
    model = MyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])
    
    # DataLoader with DistributedSampler (shuffling is done by the sampler, not the DataLoader)
    sampler = DistributedSampler(dataset)
    dataloader = DataLoader(dataset, sampler=sampler)  # plus batch_size, num_workers, ...
    
    for epoch in range(num_epochs):
        sampler.set_epoch(epoch)  # Important for proper shuffling
        for batch in dataloader:
            # Training loop (call ddp_model, not model, so gradients are synchronized)
            pass
    
    destroy_process_group()

# Launch one process per GPU
if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(train, args=(world_size,), nprocs=world_size, join=True)
```
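The role of `DistributedSampler` can be inspected without any GPUs. In the sketch below, `num_replicas` and `rank` are passed explicitly, so no process group is needed; the 10-element dataset and world size of 2 are illustrative.

```python
import torch
from torch.utils.data import DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(10))
world_size = 2  # Illustrative; normally taken from the process group

# Each rank receives a disjoint, strided slice of the indices.
shards = [
    list(DistributedSampler(dataset, num_replicas=world_size, rank=r, shuffle=False))
    for r in range(world_size)
]
print(shards)  # [[0, 2, 4, 6, 8], [1, 3, 5, 7, 9]]

# With shuffle=True the permutation is seeded by seed + epoch,
# so the order only changes when set_epoch() is called with a new value.
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=0,
                             shuffle=True, seed=0)
sampler.set_epoch(3)
first = list(sampler)
sampler.set_epoch(3)
again = list(sampler)
print(first == again)  # True: same epoch, same order
```

This is why forgetting `sampler.set_epoch(epoch)` makes every epoch see the data in the same order.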

#### **11.4.2 Fully Sharded Data Parallel (FSDP)**

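FSDP shards parameters, gradients, and optimizer state across GPUs, gathering a layer's full parameters only while that layer is being computed. It is useful when a model (often in the billions of parameters) and its optimizer state do not fit in a single GPU's memory.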

```python
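import functools
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

# The policy needs the block class to wrap; nn.TransformerEncoderLayer is an assumption.
wrap_policy = functools.partial(
    transformer_auto_wrap_policy, transformer_layer_cls={nn.TransformerEncoderLayer})

model = FSDP(  # Requires an initialized process group, as in 11.4.1
    model,
    auto_wrap_policy=wrap_policy,
    device_id=torch.cuda.current_device(),
    limit_all_gathers=True  # Reduce memory pressure
)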
```

#### **11.4.3 Model Parallelism**

Model parallelism splits a model's layers across GPUs so that each device holds only part of the network (used for very large models such as GPT-3).

```python
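# Naive model parallelism: layers live on different GPUs and run one after another.
# (Pipeline parallelism additionally splits each batch into micro-batches so the GPUs overlap work.)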
device_0 = torch.device('cuda:0')
device_1 = torch.device('cuda:1')

class ModelParallel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(784, 256).to(device_0)
        self.layer2 = nn.Linear(256, 10).to(device_1)
    
    def forward(self, x):
        x = x.to(device_0)
        x = self.layer1(x)
        x = x.to(device_1)
        x = self.layer2(x)
        return x
```

---

## **11.5 Model Export and Deployment**

#### **11.5.1 PyTorch: TorchScript and ONNX**

**TorchScript (for C++ deployment):**
```python
# Tracing
model.eval()
example_input = torch.rand(1, 3, 224, 224)
traced_script_module = torch.jit.trace(model, example_input)
traced_script_module.save("model.pt")

# Scripting (for control flow)
scripted_model = torch.jit.script(model)
```
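The difference between tracing and scripting shows up as soon as the model has a data-dependent branch. In the sketch below, `gate` is a hypothetical function: tracing records only the path taken by the example input, while scripting compiles both branches.

```python
import warnings
import torch

def gate(x):
    # Data-dependent branch: the path taken depends on tensor values.
    if x.sum() > 0:
        return x * 2
    return -x

pos = torch.tensor([1.0, 2.0])
neg = torch.tensor([-1.0, -2.0])

with warnings.catch_warnings():
    warnings.simplefilter("ignore")       # Tracing warns about the tensor-to-bool branch
    traced = torch.jit.trace(gate, pos)   # Records only the x * 2 path
    scripted = torch.jit.script(gate)     # Compiles both branches

print(traced(neg))    # x * 2 is applied even though the sum is negative
print(scripted(neg))  # Follows the else branch and returns -x
```

The tracer's warning is the signal to watch for: if it appears for a real model, the traced module may silently give wrong results on inputs that take a different path.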

**ONNX (framework-agnostic):**
```python
torch.onnx.export(
    model,
    example_input,
    "model.onnx",
    export_params=True,
    opset_version=11,
    input_names=['input'],
    output_names=['output'],
    dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}
)
```

#### **11.5.2 TensorFlow: SavedModel**

```python
# Save the complete model in SavedModel format
# (Keras 3 / TF 2.16+: use model.export('saved_model/my_model') instead;
#  model.save() there expects a .keras or .h5 filename)
model.save('saved_model/my_model')

# Load for inference
loaded_model = keras.models.load_model('saved_model/my_model')

# Convert to TensorFlow Lite (mobile/edge)
converter = tf.lite.TFLiteConverter.from_saved_model('saved_model/my_model')
tflite_model = converter.convert()
with open('model.tflite', 'wb') as f:
    f.write(tflite_model)
```

---

## **11.6 Workbook Labs**

### **Lab 1: PyTorch CIFAR-10 Classifier**
Build ResNet-18 from scratch (no torchvision.models):
1. Implement BasicBlock with residual connections
2. Train with mixed precision
3. Implement custom learning rate scheduler (warmup + cosine decay)
4. Achieve >85% accuracy

**Deliverable:** Training script with tensorboard logging.
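As one possible starting point for step 3, the schedule can be written as a pure function that returns a multiplier for the base learning rate. The linear warmup shape and the step counts are assumptions, not requirements of the lab.

```python
import math

def warmup_cosine(step, warmup_steps, total_steps):
    """Multiplier in [0, 1] applied to the base learning rate."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)  # Linear ramp from 0 to 1
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))  # Cosine decay to 0

# Illustrative numbers: 5 warmup steps out of 20 total
schedule = [warmup_cosine(s, warmup_steps=5, total_steps=20) for s in range(21)]
print([round(m, 3) for m in schedule])
```

A function like this can be plugged into `torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=...)`, which multiplies the optimizer's initial learning rate by the returned value at each scheduler step.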

### **Lab 2: TensorFlow Production Pipeline**
Build tf.data pipeline for large image dataset (>100GB):
1. TFRecord creation and reading
2. Data augmentation in graph (tf.image)
3. Train with distribution strategy (MultiWorkerMirroredStrategy)
4. Export SavedModel with serving signatures

**Deliverable:** End-to-end pipeline with throughput benchmarks (images/second).

### **Lab 3: JAX vs PyTorch Comparison**
Implement same transformer block in both frameworks:
1. Measure forward and backward pass speed
2. Implement custom gradient (stop gradient for certain ops)
3. Use vmap for batched inference comparison

**Deliverable:** Benchmark report showing JAX XLA advantages.

### **Lab 4: Distributed Training from Scratch**
Train ResNet-50 on ImageNet subset using PyTorch DDP:
1. Multi-GPU on single node
2. Gradient accumulation to simulate large batch
3. Checkpoint saving/loading for fault tolerance
4. Measure scaling efficiency (1 GPU time vs N GPU time)

**Deliverable:** Distributed training script with 80%+ scaling efficiency (i.e., 4 GPUs should be ~3.2x faster than 1).
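The efficiency figure in the deliverable is the measured speedup divided by the number of GPUs. A small sketch of the arithmetic, using hypothetical timings:

```python
def scaling_efficiency(t_single, t_multi, n_gpus):
    """Speedup divided by GPU count; 1.0 means perfect linear scaling."""
    speedup = t_single / t_multi
    return speedup / n_gpus

# Hypothetical timings: 100 s per epoch on 1 GPU, 31.25 s on 4 GPUs.
eff = scaling_efficiency(100.0, 31.25, 4)
print(f"speedup={100.0 / 31.25:.2f}x, efficiency={eff:.0%}")  # speedup=3.20x, efficiency=80%
```

Measure both timings with the same per-GPU batch size and after a few warmup iterations, otherwise startup costs distort the ratio.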

---

## **11.7 Common Pitfalls**

1. **CUDA OOM (Out of Memory):**
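   - Calling `torch.cuda.empty_cache()` between epochs is rarely the fix: it only releases cached blocks, and needing it is usually a symptom of a bug
   - Reduce batch size or use gradient accumulation
   - Check for a retained computation graph: accumulate metrics with `total_loss += loss.item()`, not `total_loss += loss` (the latter keeps every batch's graph alive)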

2. **DDP Hanging:**
   - Ensure all processes reach same number of backward calls
   - Check `set_epoch()` on DistributedSampler
   - Ensure same batch size on all GPUs (drop_last=True if needed)

3. **TF Data Pipeline Bottlenecks:**
   - Use `tf.data.AUTOTUNE` for num_parallel_calls and prefetch
   - Profile with the TensorFlow Profiler (its input-pipeline analyzer and tf.data bottleneck analysis tools)

4. **JAX Random Number Gotcha:**
   - JAX requires explicit PRNGKey management (not stateful like NumPy). Split keys properly!

5. **Mixed Precision Underflow:**
   - Some layers (BatchNorm, Softmax) must stay in FP32. `autocast` handles this mostly, but check losses don't become NaN.
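The JAX random-number pitfall (item 4) is easiest to see in a few lines. The sketch below reuses a key, which reproduces the same draw, and then splits the key to get fresh randomness; the seed and shapes are illustrative.

```python
import jax

key = jax.random.PRNGKey(0)

# Reusing a key reproduces the same draw: usually a bug in training code.
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))

# Splitting derives new keys; carry `key` forward and consume `subkey` once.
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, (3,))

print(bool((a == b).all()))  # True: identical samples
print(bool((a == c).all()))  # False: a different key gives different samples
```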

---

## **11.8 Interview Questions**

**Q1:** What is the difference between PyTorch's DataParallel and DistributedDataParallel (DDP)?
*A: DataParallel is single-process and multi-threaded: it replicates the model on each GPU every forward pass and scatters/gathers data through GPU 0, so it suffers from GIL contention and a GPU 0 bottleneck. DDP runs one process per GPU, synchronizes gradients with all-reduce (typically via NCCL) overlapped with the backward pass, and is significantly faster. DDP is the recommended approach, even on a single machine.*

**Q2:** Explain XLA and why JAX uses it. What are its advantages?
*A: XLA (Accelerated Linear Algebra) is a domain-specific compiler for linear algebra that optimizes computations by fusing operations, eliminating intermediate allocations, and generating efficient machine code. JAX uses XLA to compile functional programs to GPU/TPU, enabling aggressive optimization across operations (fusion) that aren't possible with eager mode. Benefits: better memory usage, kernel fusion, and TPU support.*

**Q3:** What is gradient accumulation and when would you use it?
*A: Gradient accumulation splits a large logical batch into smaller micro-batches that fit in GPU memory. Forward/backward is run on each micro-batch, gradients are accumulated (summed, with each loss divided by N so the result matches the full-batch mean), and the optimizer step happens after N micro-batches. It is used when you want the convergence behavior of a large batch but don't have enough GPU memory for it. Note that BatchNorm statistics are still computed per micro-batch, so accumulation does not reproduce large-batch BatchNorm behavior.*

**Q4:** How do you debug a distributed training job that hangs?
*A: 1) Check all ranks are initialized (print rank at start). 2) Ensure same number of backward calls on all ranks (uneven data causes deadlock). 3) Use `torch.distributed.barrier()` to identify where hang occurs. 4) Check NCCL environment variables (timeouts). 5) Ensure no process crashed leaving others waiting. 6) Use single GPU mode to verify logic works first.*

**Q5:** When would you choose TorchScript over ONNX for model deployment?
*A: TorchScript is better for PyTorch-specific features (custom ops, complex control flow) and when deploying to PyTorch Mobile/C++ environments. ONNX is better for cross-framework deployment (deploying to TensorRT, OpenVINO, or non-PyTorch runtimes) and when you need optimization tools specific to ONNX ecosystem. TorchScript preserves PyTorch semantics; ONNX is a generic exchange format.*
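The gradient accumulation pattern from Q3 can be checked numerically. The sketch below, using a hypothetical linear model and random data, compares the gradient from one full batch with the gradient accumulated over four micro-batches.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 1)
criterion = nn.MSELoss()
x, y = torch.randn(8, 4), torch.randn(8, 1)

# Reference: one backward pass over the full batch of 8.
model.zero_grad()
criterion(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Accumulation: 4 micro-batches of 2; each backward() adds into .grad.
accum_steps = 4
model.zero_grad()
for xb, yb in zip(x.chunk(accum_steps), y.chunk(accum_steps)):
    loss = criterion(model(xb), yb) / accum_steps  # Rescale so the sum equals the full-batch mean
    loss.backward()
accum_grad = model.weight.grad.clone()
# optimizer.step() and zero_grad() would run here, once per accum_steps micro-batches

print(torch.allclose(full_grad, accum_grad, atol=1e-6))
```

The match is exact only for models without batch-dependent layers; with BatchNorm, each micro-batch uses its own statistics.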

---

## **11.9 Further Reading**

**Documentation:**
- PyTorch Performance Tuning Guide: https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html
- JAX The Sharp Bits: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html
- TensorFlow Performance Guide: https://www.tensorflow.org/guide/profiler

**Papers:**
- "PyTorch Distributed: Experiences on Accelerating Data Parallel Training" (Li et al., 2020)
- "JAX: Composable transformations of Python+NumPy programs" (Bradbury et al.)

---

## **11.10 Checkpoint Project: Multi-Framework Model Zoo**

Implement identical ResNet-50 architecture in all three frameworks and benchmark.

**Requirements:**

1. **Model Specification:**
   - ResNet-50 v1.5 architecture
   - ImageNet preprocessing (same in all frameworks)
   - Synchronized BatchNorm (for distributed)

2. **Implementations:**
   - **PyTorch:** Native nn.Module, DDP training
   - **TensorFlow:** Keras Model with custom training loop or fit()
   - **JAX:** Flax linen module, pmap for multi-GPU

3. **Benchmarking:**
   - Throughput (images/sec) single GPU and 4-GPU
   - Memory usage
   - Lines of code (complexity metric)
   - Ease of debugging (profile one iteration)

4. **Interoperability:**
   - Export PyTorch model to ONNX
   - Run ONNX in TensorFlow (onnx-tf)
   - Compare outputs (numerical accuracy)
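For the output comparison in the interoperability step, a small helper along these lines may be enough. The tolerances are assumptions to be tuned per model, and the arrays below are stand-ins for logits produced by two runtimes.

```python
import numpy as np

def compare_outputs(reference, candidate, rtol=1e-3, atol=1e-5):
    """Summarize elementwise agreement between two output arrays."""
    reference = np.asarray(reference, dtype=np.float64)
    candidate = np.asarray(candidate, dtype=np.float64)
    return {
        "max_abs_diff": float(np.max(np.abs(reference - candidate))),
        "allclose": bool(np.allclose(candidate, reference, rtol=rtol, atol=atol)),
    }

# Stand-in arrays; in the project these would come from PyTorch and the ONNX runtime.
ref = np.array([0.10, 0.70, 0.20])
out = ref + np.array([1e-6, -2e-6, 0.0])
report = compare_outputs(ref, out)
print(report)
```

Reporting the maximum absolute difference alongside the pass/fail flag makes it easier to tell a precision difference from a genuine conversion bug.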

**Deliverable:**
- `benchmarks/` directory with three implementations
- `report.md` with performance tables and analysis
- Docker Compose setup to run all three with fixed dependencies

**Success Criteria:**
- Within 5% accuracy of each other on ImageNet validation (top-1)
- JAX achieves highest throughput (XLA optimization)
- PyTorch easiest to debug (dynamic nature)
- TensorFlow best for production deployment features

---

**End of Chapter 11**

*You now command the tools of modern deep learning. Chapter 12 will cover Convolutional Neural Networks (CNNs) for computer vision.*

