<a href="https://colab.research.google.com/github/Lcocks/DS6050-DeepLearning/blob/main/PyTorch_Decorators.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Essential Decorators for PyTorch Deep Learning

In class we discussed `@dataclass` and how it can simplify our lives. Here is a guide to some other decorators that make your PyTorch code cleaner, faster, and more maintainable.

---

## Table of Contents

1. [Python Built-in Decorators](#1-python-built-in-decorators)
2. [PyTorch-Specific Decorators](#2-pytorch-specific-decorators)
3. [Performance & Caching Decorators](#3-performance--caching-decorators)
4. [Custom Utility Decorators](#4-custom-utility-decorators)

---

## 1. Python Built-in Decorators

### `@dataclass` – Data Containers

Auto-generates `__init__`, `__repr__`, `__eq__` for classes that hold data (configs, hyperparameters).

```python
from dataclasses import dataclass

@dataclass
class TrainConfig:
    batch_size: int = 32
    lr: float = 1e-3
    epochs: int = 10
    device: str = "cuda"

config = TrainConfig(lr=3e-4)
print(config)  # TrainConfig(batch_size=32, lr=0.0003, epochs=10, device='cuda')
```

**Use case:** Model configs, training hyperparameters, dataset settings.
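Configs often need list-valued fields, and that is where `@dataclass` has one rule worth knowing. A mutable default such as a list must be supplied through `field(default_factory=...)`, because `@dataclass` rejects a bare list default with a `ValueError`. The sketch below also uses `asdict` to turn a config into a plain dict, for example for logging. The `ModelConfig` class and its fields are hypothetical.

```python
from dataclasses import asdict, dataclass, field

@dataclass
class ModelConfig:
    # Hypothetical config: a bare `hidden_dims: list = [256, 128]` would raise ValueError
    hidden_dims: list = field(default_factory=lambda: [256, 128])
    dropout: float = 0.1

a = ModelConfig()
b = ModelConfig()
a.hidden_dims.append(64)  # Each instance gets its own fresh list

print(b.hidden_dims)  # [256, 128]
print(asdict(a))      # {'hidden_dims': [256, 128, 64], 'dropout': 0.1}
```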

---

### `@property` – Computed Attributes

Exposes a method as a read-only attribute. The getter runs on every access, which makes it a good fit for derived values that should always reflect the object's current state.

```python
import torch.nn as nn

class ConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 3)
        self.conv2 = nn.Conv2d(64, 128, 3)

    @property
    def num_parameters(self):
        """Compute total parameters on-the-fly."""
        return sum(p.numel() for p in self.parameters())

model = ConvNet()
print(model.num_parameters)  # Access like an attribute, not a method call
```

**Use case:** Model introspection (param count, layer info), dynamic properties.
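Two behaviors of `@property` are easier to see in isolation. First, the getter runs on every access, so the value tracks the current state instead of being cached. Second, without a setter the attribute cannot be assigned. This sketch uses a plain, hypothetical `LayerStack` class instead of an `nn.Module`, so it runs without PyTorch.

```python
class LayerStack:
    """Hypothetical container used only to illustrate @property behavior."""

    def __init__(self, widths):
        self.widths = list(widths)

    @property
    def depth(self):
        # Recomputed on every access, so it follows changes to self.widths
        return len(self.widths)

stack = LayerStack([64, 128])
print(stack.depth)  # 2
stack.widths.append(256)
print(stack.depth)  # 3

read_only = False
try:
    stack.depth = 10  # No setter defined, so assignment fails
except AttributeError:
    read_only = True
    print("depth is read-only")
```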

---

### `@staticmethod` & `@classmethod` – Non-Instance Methods

**`@staticmethod`**: Function that doesn't need `self` or class state.

```python
import torch

class DataProcessor:
    @staticmethod
    def normalize(x):
        """Utility function that doesn't need instance state."""
        return (x - x.mean()) / x.std()

# Call without creating an instance
tensor = torch.randn(4, 10)
normalized = DataProcessor.normalize(tensor)
```

**`@classmethod`**: Function that receives the class (`cls`) instead of instance (`self`).

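```python
import torch
import torch.nn as nn

# Hypothetical registry: replace with real checkpoint URLs for your architecture
MODEL_URLS = {"resnet50": "https://example.com/resnet50.pth"}

class ResNet(nn.Module):
    def __init__(self):
        super().__init__()
        # ... define layers whose names match the checkpoint's state_dict keys ...

    @classmethod
    def from_pretrained(cls, model_name):
        """Factory method: build the architecture, then load pretrained weights."""
        model = cls()
        state_dict = torch.hub.load_state_dict_from_url(MODEL_URLS[model_name])
        model.load_state_dict(state_dict)
        return model

model = ResNet.from_pretrained("resnet50")  # Alternative constructor
```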

**Use case:** Utility functions, factory methods, alternative constructors.
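Factory methods use `cls` instead of a hard-coded class name because `cls` is bound to whichever class the method was called on. A subclass therefore inherits a factory that builds the subclass. The sketch below uses hypothetical `Backbone` and `WideBackbone` classes with a plain dict config, so it runs without PyTorch.

```python
class Backbone:
    """Hypothetical base model used to illustrate classmethod vs staticmethod."""

    def __init__(self, width):
        self.width = width

    @classmethod
    def from_config(cls, config):
        # cls is the class the method was called on, including subclasses
        return cls(width=config["width"])

    @staticmethod
    def describe():
        return "backbone"  # No access to cls or self

class WideBackbone(Backbone):
    pass

model = WideBackbone.from_config({"width": 512})
print(type(model).__name__, model.width)  # WideBackbone 512
print(Backbone.describe())                # backbone
```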

---

## 2. PyTorch-Specific Decorators

### `@torch.no_grad()` – Disable Gradient Computation

Disables autograd for inference/validation, reducing memory and speeding up computation.

```python
import torch

@torch.no_grad()
def evaluate(model, dataloader, loss_fn):
    """Validation loop without gradient tracking."""
    model.eval()  # Switch dropout/batch norm to evaluation behavior
    total_loss = 0.0
    for inputs, labels in dataloader:
        outputs = model(inputs)
        total_loss += loss_fn(outputs, labels).item()  # .item() accumulates a Python float
    return total_loss / len(dataloader)
```

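**Why?** Autograd no longer stores intermediate activations for the backward pass. Memory use drops, often substantially depending on the model, and the forward pass runs faster. **Always use it (or `inference_mode`) for inference and validation.**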

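**Alternative (PyTorch 1.9+):** `@torch.inference_mode()` also turns off view tracking and version-counter updates, so it can be faster still. It is more restrictive: tensors created inside it cannot later be used in computations that autograd records.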

```python
@torch.inference_mode()
def predict(model, x):
    """Inference with maximum optimization."""
    return model(x).argmax(dim=-1)
```
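The difference between the two modes shows up in the tensors they produce. Both return outputs that do not require grad. However, only `inference_mode` marks its outputs as inference tensors, and autograd refuses to save those for a later backward pass. In this sketch, `w` is a stand-in for a trainable parameter.

```python
import torch

w = torch.randn(3, requires_grad=True)  # Stand-in for a trainable parameter

@torch.no_grad()
def predict_no_grad(x):
    return x * w

@torch.inference_mode()
def predict_inference(x):
    return x * w

out1 = predict_no_grad(torch.ones(3))
out2 = predict_inference(torch.ones(3))
print(out1.requires_grad, out1.is_inference())  # False False
print(out2.requires_grad, out2.is_inference())  # False True

# Outside inference mode, an inference tensor can't take part in autograd
raised = False
try:
    (out2 * w).sum().backward()
except RuntimeError as err:
    raised = True
    print("RuntimeError:", err)
```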

---

### `@torch.compile()` – JIT Compilation (PyTorch 2.0+)

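Compiles your model or function for faster execution. TorchDynamo captures the Python code as graphs, and a backend (TorchInductor by default) compiles those graphs.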

```python
import torch
import torch.nn.functional as F
from torchvision.models import resnet50

model = resnet50()
model = torch.compile(model)  # One-line change; compilation happens on the first call

# Or as a decorator on a plain function
@torch.compile
def linear_gelu(x, weight):
    return F.gelu(x @ weight)
```

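**Benefits:** Often a noticeable speedup, but the gain varies with the model, hardware, and backend, so benchmark your own workload. **Trade-off:** The first call is slower because it compiles, and so is any call that triggers recompilation, such as one with new input shapes.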

---

### `@torch.jit.script` – TorchScript for Optimization

Compiles Python code into TorchScript, an optimized intermediate representation that can run without a Python interpreter. This makes it suitable for production deployment.

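```python
import torch
import torch.nn as nn

@torch.jit.script
def fused_gelu(x):
    """Custom activation with TorchScript optimization."""
    # Tanh approximation of GELU; 0.7978845608 is sqrt(2 / pi)
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608 * (x + 0.044715 * x ** 3)))

# Whole modules can be scripted too, then loaded from C++ (LibTorch), mobile, etc.
model = nn.Sequential(nn.Linear(16, 16), nn.ReLU())
scripted = torch.jit.script(model)
scripted.save("model.pt")
```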

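**Use case:** Production deployment, mobile, edge devices, C++ inference. Recent PyTorch releases treat TorchScript as being in maintenance mode and recommend `torch.export` for new deployment work.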
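Before deploying, it helps to confirm two things. First, the scripted activation should match PyTorch's built-in tanh-approximated GELU. Second, a scripted model should survive a save/load round trip. The sketch below saves to an in-memory `io.BytesIO` buffer instead of a file. `TinyMLP` is a hypothetical toy model, and `approximate="tanh"` assumes PyTorch 1.12 or newer.

```python
import io

import torch
import torch.nn as nn
import torch.nn.functional as F

@torch.jit.script
def fused_gelu(x):
    # Same tanh-approximation GELU as above; 0.7978845608 is sqrt(2 / pi)
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608 * (x + 0.044715 * x ** 3)))

class TinyMLP(nn.Module):
    """Hypothetical toy model whose forward calls the scripted function."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return fused_gelu(self.fc(x))

torch.manual_seed(0)
model = TinyMLP().eval()
scripted = torch.jit.script(model)

# Serialize to an in-memory buffer instead of a file, then load it back
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)

x = torch.randn(3, 4)
with torch.no_grad():
    matches_builtin = torch.allclose(fused_gelu(x), F.gelu(x, approximate="tanh"), atol=1e-5)
    round_trip_ok = torch.allclose(model(x), loaded(x), atol=1e-6)
print(matches_builtin, round_trip_ok)  # True True
```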

---

## 3. Performance & Caching Decorators

### `@lru_cache` – Memoization

Caches function results to avoid redundant computation.

```python
import math
import torch
from functools import lru_cache

@lru_cache(maxsize=128)
def get_positional_encoding(seq_len, d_model):
    """Sinusoidal positional encoding, computed once per (seq_len, d_model) pair."""
    position = torch.arange(seq_len).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
    pe = torch.zeros(seq_len, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe  # Every caller shares this tensor: don't modify it in place

# First call computes, subsequent calls with same args return cached result
pe1 = get_positional_encoding(512, 768)  # Computed
pe2 = get_positional_encoding(512, 768)  # Cached: the same object, so pe2 is pe1
```

**Use case:** Expensive computations with repeated inputs (positional encodings, lookup tables).

**Note:** Use `functools.cache` (Python 3.9+) for an unbounded cache. It is equivalent to `lru_cache(maxsize=None)`.
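A cached function hands back the very same object on every hit. Returning something immutable is therefore the safest choice, and `cache_info()` shows whether the cache is actually being used. Arguments must also be hashable. Note that `torch.Tensor` hashes by object identity, so tensors passed as arguments are cached per object, not per value. The `causal_mask` helper below is hypothetical and returns nested tuples, so it runs without PyTorch.

```python
from functools import lru_cache

@lru_cache(maxsize=128)
def causal_mask(seq_len):
    """Hypothetical helper: lower-triangular mask as nested tuples (immutable, safe to share)."""
    return tuple(tuple(j <= i for j in range(seq_len)) for i in range(seq_len))

m1 = causal_mask(4)  # Miss: computed
m2 = causal_mask(4)  # Hit: returned from the cache
info = causal_mask.cache_info()

print(m1 is m2)                # True: the cached object itself is returned
print(info.hits, info.misses)  # 1 1
print(m1[0])                   # (True, False, False, False)

causal_mask.cache_clear()  # Empty the cache, e.g. between experiments
```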

---

### `@functools.wraps` – Preserve Function Metadata

Essential when writing custom decorators: it copies the original function's name, docstring, and other metadata onto the wrapper.

```python
from functools import wraps
import time

def timer(func):
    """Decorator to time function execution."""
    @wraps(func)  # Preserves func.__name__, func.__doc__, etc.
    def wrapper(*args, **kwargs):
        start = time.perf_counter()  # Monotonic clock intended for measuring intervals
        result = func(*args, **kwargs)
        print(f"{func.__name__} took {time.perf_counter() - start:.4f}s")
        return result
    return wrapper

@timer
def train_epoch(model, dataloader):
    """Train for one epoch."""
    # ... training code ...
    pass

train_epoch(model, train_loader)  # Prints something like "train_epoch took 45.2341s"
```
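To see what `@wraps` changes, compare two otherwise identical pass-through decorators. Without it, the decorated function reports the inner wrapper's name and loses its docstring. With it, the metadata is copied over, and `__wrapped__` keeps a reference to the original function. The decorator and function names here are hypothetical.

```python
from functools import wraps

def without_wraps(func):
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper

def with_wraps(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper

@without_wraps
def step_a():
    """Run one step."""

@with_wraps
def step_b():
    """Run one step."""

print(step_a.__name__, step_a.__doc__)  # wrapper None
print(step_b.__name__, step_b.__doc__)  # step_b Run one step.
print(step_b.__wrapped__)               # The original, undecorated function
```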

---

## 4. Custom Utility Decorators

### Timing & Profiling Decorator

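```python
import time
from functools import wraps

import torch

def timeit(func):
    """Measure wall-clock execution time of functions."""
    @wraps(func)
    def wrapper(*args, **kwargs):
        # CUDA kernels run asynchronously: synchronize so the timing covers the GPU work
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        start = time.perf_counter()
        result = func(*args, **kwargs)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - start
        print(f"⏱️  {func.__name__}: {elapsed:.4f}s")
        return result
    return wrapper

@timeit
def forward_pass(model, batch):
    return model(batch)
```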
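Printing on every call gets noisy inside a training loop. A variant can accumulate call counts and total time per function and report them once at the end. This is a minimal sketch: the module-level `TIMINGS` registry and the `profile` and `preprocess` names are hypothetical. It times CPU code only. For CUDA work you would need to synchronize before reading the clock, because kernels launch asynchronously.

```python
import time
from collections import defaultdict
from functools import wraps

# Hypothetical registry: function name -> [call count, total seconds]
TIMINGS = defaultdict(lambda: [0, 0.0])

def profile(func):
    """Accumulate call counts and total wall-clock time instead of printing each call."""
    @wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        try:
            return func(*args, **kwargs)
        finally:
            stats = TIMINGS[func.__name__]
            stats[0] += 1
            stats[1] += time.perf_counter() - start
    return wrapper

@profile
def preprocess(n):
    return sum(i * i for i in range(n))

for _ in range(3):
    preprocess(10_000)

for name, (calls, total) in TIMINGS.items():
    print(f"{name}: {calls} calls, {total / calls * 1e3:.3f} ms/call")
```

The `try`/`finally` records the time even when the wrapped function raises an exception.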

---

### GPU Memory Tracking Decorator

```python
import torch
from functools import wraps

def track_gpu_memory(func):
    """Monitor GPU memory usage before/after function call."""
    @wraps(func)
    def wrapper(*args, **kwargs):
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
            start_mem = torch.cuda.memory_allocated() / 1024**2  # MB
        
        result = func(*args, **kwargs)
        
        if torch.cuda.is_available():
            end_mem = torch.cuda.memory_allocated() / 1024**2
            peak_mem = torch.cuda.max_memory_allocated() / 1024**2
            print(f"📊 {func.__name__}: {start_mem:.1f}MB → {end_mem:.1f}MB (peak: {peak_mem:.1f}MB)")
        
        return result
    return wrapper

@track_gpu_memory
def train_batch(model, batch):
    loss = model(batch).loss
    loss.backward()
    return loss
```

---

### Automatic Mixed Precision (AMP) Decorator

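```python
from functools import wraps

import torch

def mixed_precision(func):
    """Run func under CUDA autocast, using float16 where it is numerically safe."""
    @wraps(func)
    def wrapper(*args, **kwargs):
        # torch.autocast replaces the deprecated torch.cuda.amp.autocast
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            return func(*args, **kwargs)
    return wrapper

@mixed_precision
def forward_pass(model, x):
    return model(x)  # Eligible ops (e.g. matmuls, convolutions) run in float16
```

For float16 training, pair autocast with a `GradScaler` (from `torch.amp`, or `torch.cuda.amp` in older releases) so that small gradients don't underflow.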
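`torch.autocast` instances can also be applied directly as decorators, so a hand-written wrapper is optional. The sketch below assumes a CPU-only setting and uses `bfloat16`, the lower-precision type CPU autocast supports. The small `nn.Linear` model is a stand-in. The output comes back in `bfloat16`, while the model's parameters stay `float32`.

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 4)  # Stand-in model; its weights stay float32

@torch.autocast(device_type="cpu", dtype=torch.bfloat16)
def forward_bf16(model, x):
    return model(x)

x = torch.randn(2, 8)
out = forward_bf16(model, x)
print(out.dtype)           # torch.bfloat16: linear ran in lower precision
print(model.weight.dtype)  # torch.float32: parameters are not converted
```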

---

### Reproducibility Decorator

```python
import random
from functools import wraps

import numpy as np
import torch

def seed_everything(seed=42):
    """Decorator factory: seed all RNGs before every call to the wrapped function."""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            torch.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
            np.random.seed(seed)
            random.seed(seed)
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
            return func(*args, **kwargs)
        return wrapper
    return decorator

@seed_everything(seed=123)
def train_model():
    # RNGs are seeded, but some GPU ops can still be nondeterministic
    pass
```
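Because the wrapper reseeds on every call, calling a decorated function twice replays exactly the same random numbers. That is what makes a run reproducible. It also means the decorator belongs on the top-level entry point (such as `train_model`) and not on something called once per epoch, or every epoch would get the same shuffle. The trimmed-down copy of the decorator below seeds only the CPU RNGs, and `shuffle_indices` is a hypothetical per-epoch helper.

```python
import random
from functools import wraps

import numpy as np
import torch

def seed_everything(seed=42):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            torch.manual_seed(seed)
            np.random.seed(seed)
            random.seed(seed)
            return func(*args, **kwargs)
        return wrapper
    return decorator

@seed_everything(seed=123)
def shuffle_indices(n):
    # Hypothetical per-epoch shuffle
    return torch.randperm(n)

epoch1 = shuffle_indices(8)
epoch2 = shuffle_indices(8)
print(torch.equal(epoch1, epoch2))  # True: reseeded on every call
```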

---

### Gradient Clipping Decorator

```python
import torch
from functools import wraps

def clip_gradients(max_norm=1.0):
    """Call backward() on the returned loss, then clip the total gradient norm.

    Call optimizer.step() after the decorated function returns.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(model, *args, **kwargs):
            loss = func(model, *args, **kwargs)
            if loss.requires_grad:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            return loss
        return wrapper
    return decorator

@clip_gradients(max_norm=1.0)
def compute_loss(model, batch):
    return model(batch).loss  # Assumes the model returns an object with a .loss attribute
```
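A quick check that the decorator does what it claims: after the decorated call, the total gradient norm should be at most `max_norm`. The toy `nn.Linear` model and the deliberately scaled-up loss are assumptions chosen so that the raw gradients are large. The decorator is copied here so the block runs on its own. Because the backward pass happens inside the decorated function, zero the gradients before the call and step the optimizer after it.

```python
from functools import wraps

import torch
import torch.nn as nn

def clip_gradients(max_norm=1.0):
    def decorator(func):
        @wraps(func)
        def wrapper(model, *args, **kwargs):
            loss = func(model, *args, **kwargs)
            if loss.requires_grad:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            return loss
        return wrapper
    return decorator

@clip_gradients(max_norm=1.0)
def compute_loss(model, x):
    # Scaled-up loss so the raw gradients are large (illustrative only)
    return (100 * model(x)).pow(2).mean()

torch.manual_seed(0)
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

optimizer.zero_grad()
loss = compute_loss(model, torch.randn(16, 4))
grad_norm = torch.norm(torch.stack([p.grad.norm() for p in model.parameters()]))
print(f"grad norm after clipping: {grad_norm.item():.4f}")  # At most 1.0
optimizer.step()
```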

---

### Exception Handling for Training

```python
import torch
from functools import wraps

def safe_training(func):
    """Catch and log RuntimeErrors (e.g. CUDA OOM) instead of crashing training."""
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except RuntimeError as e:
            if "out of memory" in str(e):
                print(f"⚠️  OOM Error in {func.__name__}. Try reducing batch size.")
                torch.cuda.empty_cache()
            else:
                print(f"❌ Error in {func.__name__}: {e}")
            return None  # Callers should treat None as a skipped step
    return wrapper

@safe_training
def train_step(model, batch):
    loss = model(batch).loss
    loss.backward()  # backward() returns None, so return the loss itself
    return loss
```
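Swallowing every `RuntimeError` can hide genuine bugs such as shape mismatches, which then show up only as silently skipped steps. A stricter sketch, with the hypothetical name `skip_on_oom`, handles only out-of-memory errors and re-raises everything else. Two simulated failures exercise both paths; their messages are made up for the example.

```python
from functools import wraps

import torch

def skip_on_oom(func):
    """Hypothetical variant: skip the step on out-of-memory, re-raise every other error."""
    @wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except RuntimeError as e:
            if "out of memory" not in str(e):
                raise  # Shape mismatches, dtype errors, etc. should still stop the run
            print(f"⚠️  OOM in {func.__name__}; skipping this step.")
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            return None
    return wrapper

@skip_on_oom
def fake_oom_step():
    raise RuntimeError("CUDA out of memory. (simulated)")

@skip_on_oom
def buggy_step():
    raise RuntimeError("mat1 and mat2 shapes cannot be multiplied (simulated)")

skipped = fake_oom_step()
print(skipped)  # None

reraised = False
try:
    buggy_step()
except RuntimeError as e:
    reraised = True
    print("re-raised:", e)
```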

---

## Quick Reference Table

| Decorator | Primary Use | Key Benefit |
|:----------|:------------|:------------|
| `@dataclass` | Config classes | Auto-generate boilerplate |
| `@property` | Computed attributes | Clean API, always up to date |
| `@staticmethod` | Utility functions | No instance needed |
| `@classmethod` | Factory methods | Alternative constructors |
| `@torch.no_grad()` | Inference/validation | Lower memory, faster forward pass |
| `@torch.inference_mode()` | Pure inference | Fastest autograd-free mode |
| `@torch.compile()` | Model optimization | Often faster (benchmark to confirm) |
| `@torch.jit.script` | Production deployment | C++ export, mobile |
| `@lru_cache` | Expensive computations | Avoid recomputation |
| `@timeit` | Profiling | Track execution time |
| `@track_gpu_memory` | Memory debugging | Find memory leaks |

---

## Best Practices

1. **Always use `@torch.no_grad()` for evaluation** – Default for validation/test loops
2. **Cache expensive computations** – Use `@lru_cache` for positional encodings, masks
3. **Profile before optimizing** – Use `@timeit` to find bottlenecks
4. **Prefer `@torch.compile` over manual optimization** – Easier and often faster (PyTorch 2.0+)
5. **Use `@property` for model introspection** – Makes debugging easier
6. **Write custom decorators for repetitive patterns** – DRY principle for training loops

---

## Further Reading

- [Python Decorators Guide](https://realpython.com/primer-on-python-decorators/)
- [PyTorch Performance Tuning](https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html)
- [TorchScript Documentation](https://pytorch.org/docs/stable/jit.html)
- [torch.compile Guide](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html)

---