# Python & PyTorch Coding Review — Synopsys ML Internship

**Interview:** Monday, February 24, 2026, with Xin Xu (Principal R&D Engineer)

Redesigned for PRACTICAL assessment of Python + PyTorch fluency.
A senior R&D engineer will likely test:
1. Can you write clean, correct Python? (data structures, OOP, generators)
2. Can you manipulate tensors and data? (NumPy/PyTorch operations)
3. Can you build a proper ML training pipeline? (Dataset, training loop, eval)
4. Can you debug common issues? (shapes, devices, gradients, memory)

**NOT likely:** "Implement a Fourier Neural Operator from scratch"

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
from typing import List, Dict, Tuple, Optional
from collections import defaultdict
import os
import time
import math
from functools import wraps

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

---
## Part 1: Python Fundamentals

These test clean Python — the kind of code you'd write daily at Synopsys.

### Python Data Structures Cheat Sheet

| Type | Mutable? | Ordered? | Example | Use Case |
|------|----------|----------|---------|----------|
| `list` | Yes | Yes | `[1, 2, 3]` | Dynamic collections, sequences |
| `tuple` | No | Yes | `(1, 2, 3)` | Fixed collections, dict keys, return values |
| `dict` | Yes | Yes (3.7+) | `{"a": 1}` | Key-value lookup, configs, JSON-like data |
| `set` | Yes | No | `{1, 2, 3}` | Membership tests, deduplication |
| `defaultdict` | Yes | Yes | `defaultdict(list)` | Grouping without KeyError |

### Dictionary Operations You Must Know

```python
d = {"lr": 0.001, "epochs": 100}

d["lr"]                    # Access → 0.001 (KeyError if missing)
d.get("batch", 32)         # Safe access → 32 (default if missing)
d["new_key"] = "value"     # Insert/update
del d["epochs"]            # Delete a key
"lr" in d                  # Membership check → True

# Iteration patterns
for key in d:              ...   # keys only
for key, val in d.items(): ...   # key-value pairs
for val in d.values():     ...   # values only

# Comprehension
squared = {k: v**2 for k, v in d.items() if isinstance(v, (int, float))}
```

### List Comprehension Patterns

```python
# Basic: [expression for item in iterable]
squares = [x**2 for x in range(10)]

# With filter: [expression for item in iterable if condition]
even_squares = [x**2 for x in range(10) if x % 2 == 0]

# Nested: [expr for outer in iterable1 for inner in iterable2]
pairs = [(i, j) for i in range(3) for j in range(3) if i != j]

# Dict comprehension
name_to_len = {name: len(name) for name in ["dipole", "patch", "horn"]}
```

### Exercise 1: Data Processing with Dictionaries & Comprehensions

**Prompt:** "We have simulation results stored as a list of dictionaries. Write a function to group them by geometry type and compute the average error per group."

**Key concepts tested:**
- `defaultdict(list)` — auto-creates empty list for new keys, avoiding `KeyError` or `if key not in dict` checks
- Dict comprehension — `{k: expr for k, v in dict.items()}` to build the result in one line
- `sum(v) / len(v)` — computing the mean without importing numpy

**`defaultdict` vs regular `dict`:**
```python
# Without defaultdict — verbose and error-prone
groups = {}
for r in results:
    key = r["geometry"]
    if key not in groups:
        groups[key] = []       # Must check every time!
    groups[key].append(r["error"])

# With defaultdict — clean and Pythonic
groups = defaultdict(list)     # Missing keys auto-create empty list
for r in results:
    groups[r["geometry"]].append(r["error"])  # Just works
```
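
One caveat with `defaultdict`: indexing a missing key inserts it, while `.get()` and `in` leave the dict untouched. A minimal sketch, where `errors_by_geometry` and its values are illustrative:

```python
from collections import defaultdict

errors_by_geometry = defaultdict(list)
errors_by_geometry["dipole"].append(0.05)

# Indexing a missing key calls the factory (list) and stores the result.
_ = errors_by_geometry["horn"]
keys_after_index = sorted(errors_by_geometry)

# .get() and `in` do not trigger the factory, so "patch" is never created.
missing = errors_by_geometry.get("patch")
has_patch = "patch" in errors_by_geometry

print(keys_after_index)     # ['dipole', 'horn']
print(missing, has_patch)   # None False
```

Use `in` or `.get()` when you only want to look, so that lookups do not create empty groups that later show up in averages.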

In [None]:
def group_and_average(
    results: List[Dict],
    group_key: str = "geometry",
    value_key: str = "error",
) -> Dict[str, float]:
    """
    Group simulation results by a key and compute the mean of a value field.

    >>> results = [
    ...     {"geometry": "dipole", "error": 0.05, "freq": 2.4e9},
    ...     {"geometry": "patch",  "error": 0.12, "freq": 5.0e9},
    ...     {"geometry": "dipole", "error": 0.03, "freq": 2.4e9},
    ...     {"geometry": "patch",  "error": 0.08, "freq": 5.0e9},
    ...     {"geometry": "horn",   "error": 0.02, "freq": 10e9},
    ... ]
    >>> group_and_average(results)
    {'dipole': 0.04, 'patch': 0.1, 'horn': 0.02}
    """
    # Step 1: Group values by key
    groups = defaultdict(list)
    for r in results:
        key = r[group_key]        # e.g., "dipole"
        value = r[value_key]      # e.g., 0.05
        groups[key].append(value)

    # Step 2: Compute average for each group
    averages = {}
    for key, values in groups.items():
        averages[key] = sum(values) / len(values)

    return averages

In [None]:
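# Compact version of the same function: the averaging loop becomes a dict comprehension.
def group_and_average(results: list[dict], group_key: str = "geometry", value_key: str = "error") -> dict[str, float]:
    groups = defaultdict(list)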
    for r in results:
        groups[r[group_key]].append(r[value_key])
    return {k: sum(v) / len(v) for k, v in groups.items()}

In [None]:
# Test Exercise 1
results = [
    {"geometry": "dipole", "error": 0.05},
    {"geometry": "patch",  "error": 0.12},
    {"geometry": "dipole", "error": 0.03},
    {"geometry": "patch",  "error": 0.08},
]
avg = group_and_average(results)
assert abs(avg["dipole"] - 0.04) < 1e-10
assert abs(avg["patch"] - 0.10) < 1e-10
print(f"Averages: {avg}")
print("[PASS] Exercise 1: group_and_average")

### Exercise 2: Generator for Large File Processing

**Prompt:** "We have a huge CSV of simulation parameters. Write a generator that yields batches of N lines without loading the entire file into memory."

**Key concepts tested:**
- **`yield` vs `return`:** `return` sends one result and the function is done. `yield` produces a value, *pauses* the function, and resumes where it left off when the next value is requested (the next `next()` call or loop iteration). Any function containing `yield` is a **generator** function.
- **Lazy evaluation:** items are produced one-at-a-time, so only one batch is in memory at any point.
- **`with` statement:** ensures the file is properly closed even if an error occurs (context manager pattern).
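
The pause-and-resume behaviour can be observed directly with `next()`. This is a small sketch; `countdown` and `log` are illustrative names:

```python
log = []

def countdown(n):
    """Yield n, n-1, ..., 1 and record when the body actually runs."""
    log.append("started")
    while n > 0:
        yield n          # pause here until the next value is requested
        n -= 1
    log.append("finished")

gen = countdown(2)
log_before_first_next = list(log)   # calling countdown() has not run the body yet

first = next(gen)    # runs up to the first yield
second = next(gen)   # resumes after the yield, loops, yields again

try:
    next(gen)        # body runs to the end, which raises StopIteration
    exhausted = False
except StopIteration:
    exhausted = True

print(log_before_first_next)   # []
print(first, second)           # 2 1
print(exhausted, log)          # True ['started', 'finished']
```

A `for` loop does the same thing internally: it calls `next()` until `StopIteration` is raised.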

**Generator vs List — why it matters for large data:**
```python
# LIST: loads ALL 10GB into memory at once → OOM crash
all_lines = [line for line in open("huge_file.csv")]  # 10GB in RAM!

# GENERATOR: processes one batch at a time → constant memory
def batch_reader(path, batch_size=32):
    batch = []
    with open(path) as f:
        for line in f:
            batch.append(line.strip())
            if len(batch) == batch_size:
                yield batch       # Pause here, return batch
                batch = []        # Resume here on next call
    if batch:
        yield batch               # Don't forget the last partial batch!

# Usage:
for batch in batch_reader("huge_file.csv", batch_size=64):
    process(batch)  # Only 64 lines in memory at a time
```

**Generator expression (one-liner version):**
```python
# List comprehension → creates full list in memory
total = sum([x**2 for x in range(1_000_000)])  # 1M-element list

# Generator expression → computes lazily, no list created
total = sum(x**2 for x in range(1_000_000))    # Nearly zero memory
```

**`DataLoader` follows the same principle:** it builds batches lazily as you iterate, and with `num_workers=N` the worker processes prefetch a few batches ahead, so only a handful are in memory at once.

In [None]:
def batch_reader(filepath: str, batch_size: int = 32):
    """
    Yield batches of lines from a file. Memory-efficient for large files.

    Why a generator?
    - Simulation datasets can be GBs. Loading all into RAM is wasteful.
    - Generators produce items lazily — only one batch in memory at a time.
    - PyTorch's DataLoader applies the same idea: batches are built lazily as you iterate.
    """
    batch = []
    with open(filepath, "r") as f:
        for line in f:
            batch.append(line.strip())
            if len(batch) == batch_size:
                yield batch
                batch = []
    if batch:  # Don't forget the last incomplete batch!
        yield batch

# Note: Can't easily test without a file, but the pattern is what matters.
# Key points: yield (not return), handles last incomplete batch, uses 'with' for cleanup.
print("[INFO] Exercise 2: batch_reader — review the generator pattern above")
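
The pattern can still be exercised with a throwaway file from `tempfile`. The sketch below uses a hypothetical variant, `batch_reader_islice`, built on `itertools.islice`. It is an alternative way to write the same generator, not the function defined above:

```python
import os
import tempfile
from itertools import islice

def batch_reader_islice(filepath, batch_size=32):
    """Yield lists of up to batch_size stripped lines from filepath."""
    with open(filepath, "r") as f:
        while True:
            # islice pulls at most batch_size lines from the file iterator.
            batch = [line.strip() for line in islice(f, batch_size)]
            if not batch:
                return           # file exhausted: end the generator
            yield batch

# Write 7 rows to a temporary file, then read them back in batches of 3.
with tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False) as tmp:
    tmp.write("\n".join(f"row{i}" for i in range(7)))
    tmp_path = tmp.name

try:
    batches = list(batch_reader_islice(tmp_path, batch_size=3))
finally:
    os.remove(tmp_path)

print([len(b) for b in batches])   # [3, 3, 1]
print(batches[-1])                 # ['row6']
```

The last batch is partial, which is the edge case the exercise asks you not to forget.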

### Exercise 3: Class Design — Simulation Result Container

**Prompt:** "Design a class to hold simulation results with proper validation."

**OOP concepts tested:**

| Concept | Syntax | Purpose |
|---------|--------|---------|
| `__init__` | `def __init__(self, ...)` | Constructor — validate inputs, store attributes |
| `@property` | `@property` above a method | Makes a method look like an attribute — `obj.x` not `obj.x()`. Use for derived/computed values |
| `@classmethod` | `@classmethod` above a method | Alternative constructor — gets `cls` (the class) as first arg. E.g., `SimResult.from_file(path)` |
| `@staticmethod` | `@staticmethod` above a method | No access to `self` or `cls` — just a function in the class namespace |
| `__repr__` | `def __repr__(self)` | Developer-friendly string: shown when you type `obj` in the REPL, and by `print(obj)` when no `__str__` is defined |
| `__eq__` | `def __eq__(self, other)` | Equality check — `obj1 == obj2`. Return `NotImplemented` for wrong types |

**`@property` in detail — why not just use a regular attribute?**
```python
class Result:
    def __init__(self, s_params):
        self.s_params = s_params

    @property
    def n_ports(self):
        return self.s_params.shape[1]  # Computed from data, always in sync

    @property
    def s11_db(self):
        return 20 * np.log10(np.abs(self.s_params[:, 0, 0]))

# Usage — no parentheses! Looks like an attribute:
r = Result(some_data)
print(r.n_ports)    # Not r.n_ports() — cleaner API
print(r.s11_db)     # Recomputed on access, always correct
```

**`@classmethod` — alternative constructors:**
```python
class Result:
    def __init__(self, name, data):
        self.name = name
        self.data = data

    @classmethod
    def from_file(cls, path):           # cls = the class itself (Result)
        data = np.load(path)
        return cls(name=path, data=data)  # Calls __init__

    @classmethod
    def from_dict(cls, d):
        return cls(name=d["name"], data=np.array(d["values"]))

# Usage:
r1 = Result("test", some_array)             # Normal constructor
r2 = Result.from_file("sim_001.npy")        # Alternative from file (.npy loads as one array)
r3 = Result.from_dict({"name": "x", ...})   # Alternative from dict
```
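
The table also lists `@staticmethod`, `__repr__`, and `__eq__`, which the two snippets above do not show. A minimal sketch on a hypothetical `Band` class (the class and its fields are illustrative, not part of the exercise):

```python
class Band:
    """Hypothetical frequency band, used only to illustrate the table rows."""

    def __init__(self, low_ghz, high_ghz):
        self.low_ghz = low_ghz
        self.high_ghz = high_ghz

    @staticmethod
    def ghz_to_hz(freq_ghz):
        # No self or cls: a plain helper that lives in the class namespace.
        return freq_ghz * 1e9

    def __repr__(self):
        return f"Band(low_ghz={self.low_ghz!r}, high_ghz={self.high_ghz!r})"

    def __eq__(self, other):
        if not isinstance(other, Band):
            return NotImplemented   # lets Python try the other operand
        return (self.low_ghz, self.high_ghz) == (other.low_ghz, other.high_ghz)


band = Band(2.0, 3.0)
print(repr(band))               # Band(low_ghz=2.0, high_ghz=3.0)
print(band == Band(2.0, 3.0))   # True
print(band == "2-3 GHz")        # False
print(Band.ghz_to_hz(2.5))      # 2500000000.0
```

Note that a class defining `__eq__` without `__hash__` becomes unhashable, so its instances cannot be used as dict keys or set members unless you also define `__hash__`.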

In [None]:
class SimulationResult:
    """
    Container for a single simulation result with validation.

    Demonstrates:
    - __init__ with validation
    - __repr__ for debugging
    - __eq__ for comparison
    - Property for derived quantity
    - Class method as alternative constructor
    """

    def __init__(self, name: str, s_params: np.ndarray, freq_ghz: np.ndarray):
        """
        Args:
            name: Design identifier (e.g., "antenna_v3")
            s_params: S-parameter matrix, shape (n_freq, n_ports, n_ports), complex
            freq_ghz: Frequency points in GHz, shape (n_freq,)
        """
        if s_params.shape[0] != freq_ghz.shape[0]:
            raise ValueError(
                f"Frequency dimension mismatch: s_params has {s_params.shape[0]} "
                f"points but freq_ghz has {freq_ghz.shape[0]}"
            )
        self.name = name
        self.s_params = s_params
        self.freq_ghz = freq_ghz

    @property
    def n_ports(self) -> int:
        return self.s_params.shape[1]

    @property
    def s11_db(self) -> np.ndarray:
        """Return S11 in dB: 20 * log10(|S11|)"""
        return 20.0 * np.log10(np.abs(self.s_params[:, 0, 0]) + 1e-12)

    @property
    def resonant_freq_ghz(self) -> float:
        """Frequency where |S11| is minimized (resonance)."""
        return float(self.freq_ghz[np.argmin(self.s11_db)])

    @classmethod
    def from_touchstone(cls, filepath: str) -> "SimulationResult":
        """Alternative constructor: load from a .s2p touchstone file."""
        raise NotImplementedError("Touchstone parsing not implemented for demo")

    def __repr__(self):
        return (
            f"SimulationResult(name='{self.name}', "
            f"ports={self.n_ports}, "
            f"freq=[{self.freq_ghz[0]:.1f}-{self.freq_ghz[-1]:.1f}] GHz, "
            f"resonance={self.resonant_freq_ghz:.2f} GHz)"
        )

    def __eq__(self, other):
        if not isinstance(other, SimulationResult):
            return NotImplemented
        return (
            self.name == other.name
            and np.allclose(self.s_params, other.s_params)
            and np.allclose(self.freq_ghz, other.freq_ghz)
        )

In [None]:
# Test Exercise 3
s_params = np.random.randn(10, 2, 2) + 1j * np.random.randn(10, 2, 2)
freqs = np.linspace(1.0, 10.0, 10)
result = SimulationResult("test_antenna", s_params, freqs)

assert result.n_ports == 2
assert 1.0 <= result.resonant_freq_ghz <= 10.0
print(f"Result: {result}")
print(f"S11 (dB): {result.s11_db}")
print(f"Resonant freq: {result.resonant_freq_ghz:.2f} GHz")
print("[PASS] Exercise 3: SimulationResult")

### Exercise 4: Decorator for Timing Functions

**Prompt:** "Write a decorator to time any function. We use this to profile simulation preprocessing, training, and inference."

**What is a decorator?**

A decorator is a function that takes another function as input, wraps it with extra behavior, and returns the wrapped version. The `@decorator` syntax is just syntactic sugar:

```python
# These two are identical:
@timer
def train():
    ...

# is equivalent to:
def train():
    ...
train = timer(train)
```

**Anatomy of a decorator:**
```python
def timer(func):                    # Takes a function as input
    @wraps(func)                    # Preserves func.__name__, func.__doc__
    def wrapper(*args, **kwargs):   # Accepts ANY arguments
        start = time.perf_counter()
        result = func(*args, **kwargs)   # Call the original function
        elapsed = time.perf_counter() - start
        print(f"{func.__name__}: {elapsed:.4f}s")
        return result               # Return the original result
    return wrapper                  # Return the wrapped function
```

**Why `@wraps(func)`?**
Without it, the wrapped function loses its identity:
```python
# Without @wraps:
print(train.__name__)  # "wrapper" — BAD for debugging!

# With @wraps:
print(train.__name__)  # "train" — the real name is preserved
```
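
The difference is easy to verify side by side. In this sketch `plain_decorator` and `wrapped_decorator` are illustrative pass-through decorators that differ only in their use of `@wraps`:

```python
from functools import wraps

def plain_decorator(func):
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper

def wrapped_decorator(func):
    @wraps(func)   # copies __name__, __doc__, etc. onto wrapper
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper

@plain_decorator
def train_plain():
    """Train the model."""

@wrapped_decorator
def train_wrapped():
    """Train the model."""

print(train_plain.__name__, train_plain.__doc__)       # wrapper None
print(train_wrapped.__name__, train_wrapped.__doc__)   # train_wrapped Train the model.
print(train_wrapped.__wrapped__.__name__)              # train_wrapped
```

`@wraps` also sets `__wrapped__`, which points at the undecorated function and is handy when a test needs to bypass the decorator.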

**`*args` and `**kwargs` explained:**
```python
def flexible(*args, **kwargs):
    # args = tuple of positional arguments
    # kwargs = dict of keyword arguments
    print(f"Positional: {args}")    # (1, 2, 3)
    print(f"Keyword: {kwargs}")     # {"lr": 0.01, "epochs": 50}

flexible(1, 2, 3, lr=0.01, epochs=50)
```

In [None]:
def timer(func):
    """
    Decorator that prints execution time of the wrapped function.

    Why @wraps(func)?
    - Preserves the original function's __name__, __doc__, etc.
    - Without it, debugging shows "wrapper" instead of the actual function name.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        start = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start
        print(f"[TIMER] {func.__name__}: {elapsed:.4f}s")
        return result
    return wrapper


# Test it
@timer
def slow_function(n):
    """Example function to time."""
    return sum(i**2 for i in range(n))

result = slow_function(1_000_000)
print(f"Result: {result}")
print(f"Function name preserved: {slow_function.__name__}")
print("[PASS] Exercise 4: timer decorator")

### Exercise 5: Error Handling & Defensive Programming

**Prompt:** "Write a function that loads and validates simulation config from a dict. Handle missing keys, wrong types, and invalid values gracefully."

**Key concepts tested:**
- **Fail early with clear messages** — don't let bad config silently propagate into a training run that crashes hours later
- **`isinstance()` for type checking** — `isinstance(x, (int, float))` checks multiple types at once
- **`.get(key, default)` for optional fields** — returns default if key is missing, never raises `KeyError`
- **`try/except` vs validation**: validating upfront (this exercise) is better than catching errors later
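
One subtlety with `isinstance(x, (int, float))`: `bool` is a subclass of `int`, so `True` passes the check. If a config value such as a learning rate must be a real number, a stricter test may be wanted. A minimal sketch, with `is_real_number` as a hypothetical helper name:

```python
def is_real_number(value):
    """Return True for int/float values but not for bool."""
    return isinstance(value, (int, float)) and not isinstance(value, bool)

print(isinstance(True, (int, float)))   # True
print(is_real_number(True))             # False
print(is_real_number(1e-3))             # True
print(is_real_number("1e-3"))           # False
```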

**Error handling patterns:**
```python
# Pattern 1: Validate upfront (preferred — fail early)
def train(config):
    if "lr" not in config:
        raise ValueError("Missing 'lr' in config")  # Clear message
    if config["lr"] <= 0:
        raise ValueError(f"lr must be positive, got {config['lr']}")

# Pattern 2: try/except for external operations (file I/O, network)
def load_array(filepath):  # load_array is an illustrative name
    try:
        return np.load(filepath)
    except FileNotFoundError:
        print(f"Warning: {filepath} not found")
        return None
    except Exception as e:
        print(f"Unexpected error: {e}")
        raise  # Re-raise unknown errors; don't silently swallow them!

# Pattern 3: .get() with defaults for optional config
batch_size = config.get("batch_size", 32)      # 32 if missing
device = config.get("device", "cuda")          # "cuda" if missing
```

**`raise` vs `return None` vs silent default:**
```python
# RAISE — for required values that must be correct
if lr <= 0:
    raise ValueError(f"Bad lr: {lr}")   # Crashes immediately — good!

# RETURN NONE — for optional operations that can be skipped
def load_checkpoint(path):
    if not os.path.exists(path):
        return None   # Caller can check: if ckpt is not None: ...

# SILENT DEFAULT — for optional config with sensible defaults
batch = config.get("batch", 32)   # Silent fallback is fine here
```

In [None]:
def validate_config(config: Dict) -> Dict:
    """
    Validate and normalize a training configuration dictionary.

    Returns a clean config dict with defaults filled in.
    Raises ValueError with a clear message if something is wrong.
    """
    required_keys = ["model_type", "learning_rate", "num_epochs"]
    missing = [k for k in required_keys if k not in config]
    if missing:
        raise ValueError(f"Missing required config keys: {missing}")

    # Type checking with clear messages
    if not isinstance(config["learning_rate"], (int, float)):
        raise ValueError(
            f"learning_rate must be numeric, got {type(config['learning_rate']).__name__}"
        )

    lr = float(config["learning_rate"])
    if not (1e-8 <= lr <= 1.0):
        raise ValueError(f"learning_rate={lr} is outside valid range [1e-8, 1.0]")

    valid_models = {"mlp", "gnn", "fno", "siren"}
    if config["model_type"] not in valid_models:
        raise ValueError(
            f"model_type='{config['model_type']}' not in {valid_models}"
        )

    # Return clean config with defaults
    return {
        "model_type": config["model_type"],
        "learning_rate": lr,
        "num_epochs": int(config["num_epochs"]),
        "batch_size": config.get("batch_size", 32),
        "hidden_dim": config.get("hidden_dim", 128),
        "weight_decay": float(config.get("weight_decay", 1e-4)),
        "device": config.get("device", "cuda" if torch.cuda.is_available() else "cpu"),
    }

In [None]:
# Test Exercise 5 — valid config
config = {"model_type": "gnn", "learning_rate": 1e-3, "num_epochs": 50}
clean = validate_config(config)
assert clean["batch_size"] == 32  # Default filled in
print(f"Clean config: {clean}")

# Test invalid model type
try:
    validate_config({"model_type": "transformer", "learning_rate": 1e-3, "num_epochs": 10})
    assert False, "Should have raised ValueError"
except ValueError as e:
    print(f"Caught expected error: {e}")

# Test missing key
try:
    validate_config({"model_type": "gnn"})
    assert False, "Should have raised ValueError"
except ValueError as e:
    print(f"Caught expected error: {e}")

print("[PASS] Exercise 5: validate_config")

---
## Part 2: NumPy / Tensor Operations

Core data manipulation — the daily bread of ML research engineering.

### Tensor Fundamentals Cheat Sheet

| Concept | NumPy | PyTorch | Notes |
|---------|-------|---------|-------|
| Create | `np.array([1,2])` | `torch.tensor([1,2])` | PyTorch tensors can track gradients (`requires_grad=True`) |
| Zeros | `np.zeros((3,4))` | `torch.zeros(3,4)` | Note: tuple vs args |
| Shape | `x.shape` | `x.shape` or `x.size()` | Identical |
| Reshape | `x.reshape(2,3)` | `x.reshape(2,3)` or `x.view(2,3)` | `.view()` requires contiguous memory |
| Transpose | `x.T` | `x.T` or `x.permute(1,0)` | `.permute()` for >2D |
| Matrix multiply | `A @ B` | `A @ B` or `torch.matmul(A,B)` | Same operator |
| Element-wise | `A * B` | `A * B` | Broadcasting applies |
| Sum/Mean | `x.sum(axis=0)` | `x.sum(dim=0)` | `axis` vs `dim` keyword |
| Concatenate | `np.concatenate` | `torch.cat` | Along existing dim |
| Stack | `np.stack` | `torch.stack` | Creates new dim |

### Shape Manipulation — The Most Common Source of Bugs

```python
x = torch.randn(8, 3, 64, 64)   # (batch, channels, H, W)

# Flatten spatial dims
x.reshape(8, 3, -1)              # (8, 3, 4096)  — -1 = infer

# Add a dimension (for broadcasting)
x.unsqueeze(0)                   # (1, 8, 3, 64, 64)
x.unsqueeze(-1)                  # (8, 3, 64, 64, 1)

# Remove size-1 dimensions
x.unsqueeze(0).squeeze(0)        # Back to (8, 3, 64, 64)

# keepdim=True — critical for broadcasting after reduction
mean = x.mean(dim=(-2,-1))              # (8, 3) — dims gone
mean = x.mean(dim=(-2,-1), keepdim=True) # (8, 3, 1, 1) — dims kept
# Now x - mean works via broadcasting!
```
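
As a worked example of why `keepdim=True` matters, the sketch below normalizes each channel of each sample over its spatial dims. It then shows that the same subtraction fails without `keepdim`. The variable names and the `1e-8` epsilon are illustrative choices:

```python
import torch

x = torch.randn(8, 3, 64, 64)   # (batch, channels, H, W)

# Per-sample, per-channel statistics over the spatial dims.
mean = x.mean(dim=(-2, -1), keepdim=True)   # (8, 3, 1, 1)
std = x.std(dim=(-2, -1), keepdim=True)     # (8, 3, 1, 1)
x_norm = (x - mean) / (std + 1e-8)          # broadcasts back to (8, 3, 64, 64)

# Without keepdim the shapes no longer align from the right:
# (8, 3, 64, 64) vs (8, 3) compares 64 with 3 and fails.
flat_mean = x.mean(dim=(-2, -1))            # (8, 3)
try:
    x - flat_mean
    broadcast_ok = True
except RuntimeError:
    broadcast_ok = False

print(mean.shape, x_norm.shape)
print(broadcast_ok)   # False
```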

### Broadcasting Rules (must memorize)

Two tensors are broadcastable if, reading dimensions **right to left**:
1. Dimensions are equal, OR
2. One of them is 1, OR
3. One of them doesn't exist (treated as 1)

```python
# Example: (8, 3, 64, 64) - (3, 1, 1) → works!
# Reading right to left:  64 vs 1 ✓,  64 vs 1 ✓,  3 vs 3 ✓,  8 vs (none) ✓

# Example: (8, 3) - (8, 1) → (8, 3) ✓
# Example: (8, 3) - (4, 3) → ERROR! 8 ≠ 4 and neither is 1
```
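
The three rules translate almost line for line into code. The sketch below is a pure-Python illustration (`broadcast_shape` is a hypothetical helper, not a library function). PyTorch's own `torch.broadcast_shapes` performs the same check.

```python
from itertools import zip_longest

def broadcast_shape(a, b):
    """Return the broadcast shape of two shape tuples, or raise ValueError."""
    out = []
    # Walk both shapes right to left; a missing dimension is treated as 1.
    for da, db in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if da == db or db == 1:
            out.append(da)
        elif da == 1:
            out.append(db)
        else:
            raise ValueError(f"cannot broadcast {a} with {b}: {da} vs {db}")
    return tuple(reversed(out))

print(broadcast_shape((8, 3, 64, 64), (3, 1, 1)))   # (8, 3, 64, 64)
print(broadcast_shape((8, 3), (8, 1)))              # (8, 3)
print(broadcast_shape((8, 1), (1, 5)))              # (8, 5)

try:
    broadcast_shape((8, 3), (4, 3))
except ValueError as err:
    print(err)   # cannot broadcast (8, 3) with (4, 3): 8 vs 4
```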

### Exercise 6: Broadcasting & Vectorized Operations

**Prompt:** "Compute pairwise Euclidean distances between two sets of points WITHOUT loops."

**Key concepts tested:**
- **Vectorization**: replacing Python loops with tensor operations, which is often orders of magnitude faster because the work moves out of the interpreter
- **Broadcasting** — automatic expansion of shapes so operations work on different-sized tensors
- **The expansion trick:** `||a - b||² = ||a||² + ||b||² - 2·a·b` avoids creating the full `(N, M, D)` difference tensor

**Why not just use a loop?**
```python
# SLOW: O(N×M) Python iterations, each calling into C++
for i in range(N):
    for j in range(M):
        dist[i,j] = ((x[i] - y[j])**2).sum().sqrt()

# FAST: One BLAS call for the matrix multiply, vectorized norms
xx = (x*x).sum(dim=1, keepdim=True)  # (N,1) — squared norms
yy = (y*y).sum(dim=1, keepdim=True).T # (1,M) — transposed
xy = x @ y.T                          # (N,M) — dot products
dist = (xx + yy - 2*xy).clamp(min=0).sqrt()
```

**Broadcasting in this exercise:**
```
xx shape: (N, 1)  ─┐
yy shape: (1, M)  ─┤→ xx + yy broadcasts to (N, M)
xy shape: (N, M)  ─┘
```

**Built-in alternative:** `torch.cdist(x, y)` — use this in production, know the manual version for interviews.
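
For comparison, a direct broadcasting version is sketched below. It materializes the `(N, M, D)` difference tensor that the expansion trick avoids, so it trades memory for simplicity. `pairwise_distances_naive` is an illustrative name, not part of the exercise:

```python
import torch

def pairwise_distances_naive(x, y):
    """Broadcasting version: builds the full (N, M, D) difference tensor."""
    diff = x.unsqueeze(1) - y.unsqueeze(0)   # (N, 1, D) - (1, M, D) -> (N, M, D)
    return diff.pow(2).sum(dim=-1).sqrt()    # (N, M)

x = torch.randn(5, 3)
y = torch.randn(7, 3)

dist_naive = pairwise_distances_naive(x, y)
dist_builtin = torch.cdist(x, y)

print(dist_naive.shape)   # torch.Size([5, 7])
print(torch.allclose(dist_naive, dist_builtin, atol=1e-4))
```

This form is numerically well behaved because it never subtracts two large, nearly equal numbers, but its intermediate tensor grows as N·M·D.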

In [None]:
def pairwise_distances(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """
    Compute pairwise Euclidean distances between two point sets.

    Args:
        x: shape (N, D) — N points in D dimensions
        y: shape (M, D) — M points in D dimensions

    Returns:
        dist: shape (N, M) — dist[i,j] = ||x[i] - y[j]||

    Key concept: ||a - b||^2 = ||a||^2 + ||b||^2 - 2*a·b
    This avoids materializing the (N, M, D) difference tensor: memory stays
    O(NM + ND + MD), and the O(NMD) arithmetic runs inside one BLAS matmul.
    """
    # Method 1: Using the expansion trick (numerically less stable but fast)
    xx = (x * x).sum(dim=1, keepdim=True)   # (N, 1)
    yy = (y * y).sum(dim=1, keepdim=True).T  # (1, M)
    xy = x @ y.T                              # (N, M)
    dist_sq = xx + yy - 2 * xy
    # Clamp to avoid negative values from numerical errors
    return torch.sqrt(dist_sq.clamp(min=0.0))

    # Method 2 (built-in, preferred in production): torch.cdist(x, y)

In [None]:
# Test Exercise 6
x = torch.randn(5, 3)
y = torch.randn(7, 3)
dist = pairwise_distances(x, y)

assert dist.shape == (5, 7)
assert (dist >= 0).all()

# Verify against torch.cdist
expected = torch.cdist(x, y)
assert torch.allclose(dist, expected, atol=1e-5), f"Max diff: {(dist - expected).abs().max()}"

print(f"Distance matrix shape: {dist.shape}")
print(f"Sample distances: {dist[0, :3]}")
print("[PASS] Exercise 6: pairwise_distances")

### Exercise 7: Advanced Indexing — Gather and Scatter

**Prompt:** "Given node features and an edge list, gather source/target node features for all edges, then scatter messages back to nodes."

This is the **fundamental operation behind ALL graph neural networks** — including MeshGraphNets (directly relevant to HFSS mesh simulation).

**Key concepts tested:**
- **Fancy indexing:** `tensor[index_tensor]` selects rows by index — this is how we "gather" features
- **`index_add_`:** The reverse operation — "scatter" values back, accumulating at specified indices
- **Edge list format:** `edge_index` shape `(2, E)` where row 0 = source nodes, row 1 = target nodes

**Step-by-step walkthrough:**
```python
node_features = torch.tensor([       # 4 nodes, 3 features each
    [1.0, 0.0, 0.0],  # node 0
    [0.0, 1.0, 0.0],  # node 1
    [0.0, 0.0, 1.0],  # node 2
    [1.0, 1.0, 1.0],  # node 3
])

edge_index = torch.tensor([
    [0, 1, 2],   # source: 0→1, 1→2, 2→3
    [1, 2, 3],   # target
])

# GATHER: get features for each edge's endpoints
src_feat = node_features[edge_index[0]]  # [[1,0,0], [0,1,0], [0,0,1]]
tgt_feat = node_features[edge_index[1]]  # [[0,1,0], [0,0,1], [1,1,1]]

# COMPUTE MESSAGES (e.g., difference)
messages = src_feat - tgt_feat  # [[1,-1,0], [0,1,-1], [-1,-1,0]]

# SCATTER: aggregate messages at target nodes
aggregated = torch.zeros_like(node_features)
aggregated.index_add_(0, edge_index[1], messages)
# Node 1 gets message from edge 0: [1,-1,0]
# Node 2 gets message from edge 1: [0,1,-1]
# Node 3 gets message from edge 2: [-1,-1,0]
```

**Scatter operations comparison:**
| Method | Operation | Notes |
|--------|-----------|-------|
| `index_add_(dim, idx, src)` | `out[idx[i]] += src[i]` | Sum aggregation (most common) |
| `index_copy_(dim, idx, src)` | `out[idx[i]] = src[i]` | Overwrites; with repeated indices the surviving value is nondeterministic |
| `scatter_add(dim, idx, src)` | Similar to index_add_ | Different index shape convention |
| `scatter_reduce(dim, idx, src, "mean")` | Mean aggregation | PyTorch 2.0+; set `include_self=False` to leave the destination's initial values out of the mean |
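
Mean aggregation can also be built from two `index_add_` calls: one for the sums and one for the per-node message counts. A minimal sketch with made-up messages (`targets`, `num_nodes`, and the values are illustrative):

```python
import torch

messages = torch.tensor([[2.0, 0.0], [4.0, 2.0], [1.0, 1.0]])   # one row per edge
targets = torch.tensor([1, 1, 2])                                # target node per edge
num_nodes = 3

# Sum the messages arriving at each node.
summed = torch.zeros(num_nodes, messages.shape[1])
summed.index_add_(0, targets, messages)

# Count how many messages each node received.
counts = torch.zeros(num_nodes)
counts.index_add_(0, targets, torch.ones(len(targets)))

# Divide, clamping the count so isolated nodes (count 0) stay at zero.
mean_agg = summed / counts.clamp(min=1).unsqueeze(1)

print(counts.tolist())     # [0.0, 2.0, 1.0]
print(mean_agg.tolist())   # [[0.0, 0.0], [3.0, 1.0], [1.0, 1.0]]
```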

In [None]:
def gather_scatter_demo(
    node_features: torch.Tensor,  # (num_nodes, feature_dim)
    edge_index: torch.Tensor,      # (2, num_edges) — [src; tgt]
) -> torch.Tensor:
    """
    GNN-style gather-scatter: collect neighbor info, aggregate per node.

    This is the fundamental operation behind ALL graph neural networks.
    Understanding indexing here is critical for mesh-based simulation ML.
    """
    src, tgt = edge_index[0], edge_index[1]

    # GATHER: get features for each edge's source and target
    src_feat = node_features[src]  # (num_edges, feature_dim)
    tgt_feat = node_features[tgt]  # (num_edges, feature_dim)

    # Compute messages (simple example: difference of features)
    messages = src_feat - tgt_feat  # (num_edges, feature_dim)

    # SCATTER: aggregate messages back to target nodes (sum)
    aggregated = torch.zeros_like(node_features)
    aggregated.index_add_(0, tgt, messages)

    return aggregated

In [None]:
# Test Exercise 7
node_feat = torch.randn(4, 8)
edges = torch.tensor([[0, 1, 2, 3],   # source nodes
                       [1, 2, 3, 0]])  # target nodes (a ring graph)

agg = gather_scatter_demo(node_feat, edges)
assert agg.shape == (4, 8)

# Verify: node 1 receives message from node 0 → msg = feat[0] - feat[1]
expected_msg_to_1 = node_feat[0] - node_feat[1]
assert torch.allclose(agg[1], expected_msg_to_1)

print(f"Aggregated shape: {agg.shape}")
print("[PASS] Exercise 7: gather_scatter")

### Exercise 8: Masking & Boolean Indexing

**Prompt:** "Filter simulation data: keep only samples where the error is below a threshold and the frequency is within a range."

**Key concepts tested:**
- **Boolean masks** — comparison operators return boolean tensors: `x > 5` → `tensor([False, True, True, ...])`
- **Boolean indexing** — `tensor[bool_mask]` keeps only elements where mask is `True`
- **Combining masks** — use `&` (and), `|` (or), `~` (not). **NOT** Python `and`/`or` (those don't work on tensors!)

**Boolean mask operations:**
```python
errors = torch.tensor([0.05, 0.15, 0.03, 0.20, 0.08])

# Create masks (each returns a boolean tensor)
low_error = errors < 0.1           # [True, False, True, False, True]
high_error = errors > 0.15         # [False, False, False, True, False]

# Combine with bitwise operators (NOT Python and/or!)
combined = low_error & ~high_error  # AND + NOT
either = low_error | high_error     # OR

# Apply mask to select elements
filtered = errors[low_error]        # tensor([0.05, 0.03, 0.08])

# Count matching elements
n_good = low_error.sum().item()     # 3

# Apply same mask to multiple tensors (keep them aligned!)
# Assumes frequencies and fields share errors' first dimension (N)
filtered_errors = errors[low_error]
filtered_freqs = frequencies[low_error]  # Same mask → same indices
filtered_fields = fields[low_error]
```

**Common mistake — using Python `and` instead of `&`:**
```python
# WRONG: Python 'and' tries to convert entire tensor to bool
mask = (errors < 0.1) and (freqs > 1e9)  # RuntimeError!

# RIGHT: Bitwise '&' operates element-wise
mask = (errors < 0.1) & (freqs > 1e9)    # Works correctly

# NOTE: Parentheses are required because & has higher precedence than <
```
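
The failure is easy to reproduce. This sketch uses small illustrative tensors and catches the error so the cell keeps running:

```python
import torch

errors = torch.tensor([0.05, 0.15, 0.03])
freqs = torch.tensor([2e9, 5e9, 0.5e9])

# `and` calls bool() on a multi-element tensor, which is ambiguous.
try:
    mask = (errors < 0.1) and (freqs > 1e9)
    and_failed = False
except RuntimeError:
    and_failed = True

# `&` combines the two boolean tensors element by element.
mask = (errors < 0.1) & (freqs > 1e9)

print(and_failed)      # True
print(mask.tolist())   # [True, False, False]
```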

**Why masking matters for simulation data:**
- Filter out diverged simulations (error > threshold)
- Select frequency bands of interest
- Remove padding in variable-size batches (see Exercise 11's `mask` field)

In [None]:
def filter_simulation_data(
    errors: torch.Tensor,       # (N,)
    frequencies: torch.Tensor,  # (N,)
    fields: torch.Tensor,       # (N, H, W) — field data
    max_error: float = 0.1,
    freq_range: Tuple[float, float] = (1e9, 10e9),
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Filter data using boolean masks. Returns filtered errors, freqs, fields.

    Key concepts:
    - Boolean indexing: tensor[mask] returns elements where mask is True
    - Combining masks with & (and), | (or), ~ (not)
    - This is vectorized — no Python loops needed
    """
    # Build boolean masks
    error_mask = errors < max_error
    freq_mask = (frequencies >= freq_range[0]) & (frequencies <= freq_range[1])

    # Combine masks
    valid_mask = error_mask & freq_mask

    # Apply mask to all tensors consistently
    return errors[valid_mask], frequencies[valid_mask], fields[valid_mask]

In [None]:
# Test Exercise 8
errors = torch.tensor([0.05, 0.15, 0.03, 0.20, 0.08])
freqs = torch.tensor([2e9, 5e9, 3e9, 15e9, 8e9])
fields = torch.randn(5, 4, 4)

e_filt, f_filt, fd_filt = filter_simulation_data(errors, freqs, fields)

print(f"Original: {len(errors)} samples")
print(f"After filtering: {len(e_filt)} samples")
print(f"Kept errors: {e_filt}")
print(f"Kept freqs: {f_filt}")

# error < 0.1: indices 0,2,4 (0.05, 0.03, 0.08)
# freq in [1e9, 10e9]: indices 0,1,2,4 (2e9, 5e9, 3e9, 8e9)
# Both: indices 0,2,4 → 3 samples
assert len(e_filt) == 3
print("[PASS] Exercise 8: filter_simulation_data")

### Exercise 9: Reshaping & Per-Sample Metrics

**Prompt:** "Given a batch of 2D field predictions, compute the relative L2 error per sample (not averaged across the batch)."

**Key concepts tested:**
- **`.reshape(batch, -1)`** — flatten spatial dims while keeping batch dim intact
- **Reducing over the right dimension** — `dim=1` reduces spatial, keeps batch
- **`keepdim=True`** — keeps the reduced dimension as size 1 for broadcasting

**Reshape vs View vs Flatten:**
```python
x = torch.randn(4, 16, 16)   # (batch, H, W)

# Method 1: reshape — always works, may copy data
x.reshape(4, -1)              # (4, 256)  — -1 infers 16*16=256

# Method 2: view — only works if tensor is contiguous in memory
x.view(4, -1)                 # (4, 256)  — fails after .permute()

# Method 3: flatten — explicit and readable
x.flatten(start_dim=1)        # (4, 256)  — flatten from dim 1 onward

# RULE: Use .reshape() unless you specifically need .view()'s error
# on non-contiguous tensors (rare). .flatten() is most readable.
```
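
To make the contiguity rule concrete, here is a minimal sketch (the shapes are arbitrary and chosen only for illustration) in which `.view()` fails after `.permute()`, while `.reshape()` and `.contiguous().view()` both succeed:

```python
import torch

x = torch.randn(4, 16, 8)        # (batch, H, W), illustrative shape
xt = x.permute(0, 2, 1)          # (batch, W, H): same storage, new strides

view_failed = False
try:
    xt.view(4, -1)               # strides no longer describe a flat layout
except RuntimeError:
    view_failed = True

flat_a = xt.reshape(4, -1)             # copies only when it has to
flat_b = xt.contiguous().view(4, -1)   # explicit copy, then a cheap view

print(xt.is_contiguous(), view_failed, flat_a.shape)
```

Both working variants produce the same values; `.reshape()` simply hides the copy.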

**Relative L2 error — the standard metric for physics surrogates:**
```
relative_L2 = ||pred - target||₂ / ||target||₂

# Per-sample (this exercise): shape (batch,)
# Global (less useful):       single scalar
```

**Why per-sample, not batch-averaged?**
- You need to know WHICH samples have high error
- Batch-average hides outliers — one bad prediction gets buried
- In production: flag samples with error > threshold for full simulation
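
The points above can be sketched with synthetic data. This is an illustration only: the tensors, the 1.5x corruption of sample 3, and the `threshold` value are made up for the example.

```python
import torch

torch.manual_seed(0)
target = torch.randn(8, 16, 16) + 5.0          # synthetic "fields"
pred = target + 0.01 * torch.randn(8, 16, 16)  # seven good predictions...
pred[3] = target[3] * 1.5                      # ...and one bad one

diff = (pred - target).reshape(8, -1)
rel_err = diff.norm(dim=1) / target.reshape(8, -1).norm(dim=1)  # (8,)

threshold = 0.1                                # illustrative cutoff
flagged = torch.nonzero(rel_err > threshold).flatten()

print(f"batch mean: {rel_err.mean():.3f}, worst sample: {rel_err.max():.3f}")
print(f"flag for full simulation: {flagged.tolist()}")
```

The batch mean stays below the cutoff even though one sample is far above it, which is exactly the outlier the averaged metric hides.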

**Dimension reduction pattern:**
```python
x = torch.randn(4, 256)      # (batch, features)

# Reduce features, keep batch:
norm = torch.norm(x, dim=1)   # (4,)  — one value per sample

# Reduce batch, keep features:
mean = x.mean(dim=0)          # (256,) — one value per feature

# Reduce both:
total = x.sum()               # scalar
```

In [None]:
def per_sample_relative_l2(
    pred: torch.Tensor,   # (batch, H, W)
    target: torch.Tensor,  # (batch, H, W)
) -> torch.Tensor:
    """
    Compute relative L2 error for each sample: ||pred - target|| / ||target||

    Returns shape (batch,) — one error value per sample.

    Key concept: Use .reshape(batch, -1) to flatten spatial dims,
    then reduce over the flattened dim only (not the batch dim).
    """
    batch = pred.shape[0]
    # Flatten spatial dimensions
    p = pred.reshape(batch, -1)   # (batch, H*W)
    t = target.reshape(batch, -1)  # (batch, H*W)

    # L2 norm along the spatial dimension (dim=1), keep batch dim
    diff_norm = torch.norm(p - t, dim=1)    # (batch,)
    target_norm = torch.norm(t, dim=1)       # (batch,)

    # Avoid division by zero
    return diff_norm / target_norm.clamp(min=1e-8)

In [None]:
# Test Exercise 9
pred = torch.randn(4, 16, 16)
target = torch.randn(4, 16, 16) + 5  # Offset so target_norm > 0
rel_err = per_sample_relative_l2(pred, target)

assert rel_err.shape == (4,)
assert (rel_err >= 0).all()

# Perfect prediction should give 0 error
perfect_err = per_sample_relative_l2(target, target)
assert torch.allclose(perfect_err, torch.zeros(4), atol=1e-6)

print(f"Per-sample relative L2: {rel_err}")
print(f"Perfect prediction error: {perfect_err}")
print("[PASS] Exercise 9: per_sample_relative_l2")

### Exercise 10: Per-Sample Normalization

**Prompt:** "Normalize each sample in a batch independently (zero mean, unit variance) along the spatial dimensions."

**Key concepts tested:**
- **`keepdim=True`** — essential for broadcasting the mean/std back to original shape
- **`.clamp(min=1e-8)`** — avoid division by zero for constant-valued regions
- **Understanding normalization types** — which dims to reduce over

**Normalization types compared (critical interview topic):**

| Type | Reduces over | Shape kept | Use case |
|------|-------------|------------|----------|
| **BatchNorm** | Batch + Spatial `(N,H,W)` | `(C,)` | Standard for images, needs large batches |
| **LayerNorm** | Channel + Spatial `(C,H,W)` | `(N,)` | Transformers, small-batch friendly |
| **InstanceNorm** | Spatial `(H,W)` | `(N,C)` | Style transfer, per-sample normalization |
| **GroupNorm** | Group of channels + Spatial | `(N, G)` | Compromise between BN and LN |

```python
x = torch.randn(2, 3, 8, 8)  # (batch=2, channels=3, H=8, W=8)

# BatchNorm: stats across batch+spatial, one mean/var per channel
nn.BatchNorm2d(3)      # learns γ,β per channel; statistics get noisy with tiny batches

# LayerNorm: stats per sample across channels+spatial
nn.LayerNorm([3,8,8])  # normalizes each sample independently

# InstanceNorm: stats per sample per channel across spatial
nn.InstanceNorm2d(3)   # each (sample, channel) normalized separately

# This exercise: manual InstanceNorm (no learnable params)
mean = x.mean(dim=(-2,-1), keepdim=True)   # (2, 3, 1, 1)
std  = x.std(dim=(-2,-1), keepdim=True)    # (2, 3, 1, 1)
normed = (x - mean) / std.clamp(min=1e-8)  # Broadcasting: (2,3,8,8)
```
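
As a sanity check, the manual version can be compared against PyTorch's built-in instance norm. One detail to be aware of: `x.std()` applies Bessel's correction by default, whereas the built-in uses the biased variance and adds `eps` inside the square root. The sketch below matches the built-in convention; the input scale is arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 3, 8, 8) * 100 + 50

eps = 1e-5  # default eps of F.instance_norm / nn.InstanceNorm2d
mean = x.mean(dim=(-2, -1), keepdim=True)
var = x.var(dim=(-2, -1), unbiased=False, keepdim=True)   # biased variance
manual = (x - mean) / torch.sqrt(var + eps)

reference = F.instance_norm(x)   # no affine params, per-instance statistics
max_diff = (manual - reference).abs().max().item()
print(f"max abs difference: {max_diff:.2e}")
```

With 64 spatial values per channel the two conventions differ by under one percent, so either is usually acceptable as long as it is applied consistently.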

**Why `keepdim=True` matters:**
```python
x = torch.randn(2, 3, 8, 8)

# WITHOUT keepdim: mean has shape (2, 3) → can't subtract from (2, 3, 8, 8)
mean = x.mean(dim=(-2,-1))                  # (2, 3) — WRONG shape!
# x - mean → RuntimeError or silent broadcasting bug!

# WITH keepdim: mean has shape (2, 3, 1, 1) → broadcasts correctly
mean = x.mean(dim=(-2,-1), keepdim=True)     # (2, 3, 1, 1) — RIGHT!
# x - mean → (2, 3, 8, 8) — works via broadcasting
```
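
The "silent" case is the dangerous one. A small sketch, assuming a contrived shape where `(H, W)` happens to equal `(N, C)`: the subtraction runs without error but centers the data with the wrong values.

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 3, 2, 3)                    # H, W happen to equal N, C

mean_bad = x.mean(dim=(-2, -1))                # (2, 3)
mean_ok = x.mean(dim=(-2, -1), keepdim=True)   # (2, 3, 1, 1)

centered_bad = x - mean_bad    # no error: (2, 3) lines up with (H, W)
centered_ok = x - mean_ok

same_shape = centered_bad.shape == centered_ok.shape
same_values = torch.allclose(centered_bad, centered_ok)
print(same_shape, same_values)
```

Because the output shape is identical in both cases, a shape assertion alone would not catch this; `keepdim=True` is the reliable habit.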

**For simulation data at Synopsys:** each simulation has different field magnitudes (e.g., E-field for a 2 GHz antenna vs 10 GHz). Per-sample normalization ensures the model sees consistent scales.

In [None]:
def per_sample_normalize(x: torch.Tensor) -> torch.Tensor:
    """
    Normalize each sample to zero mean and unit variance.

    Input:  (batch, channels, H, W)
    Output: (batch, channels, H, W) — each sample independently normalized

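    This matches InstanceNorm without learnable params, except that x.std()
    applies Bessel's correction (N-1) while nn.InstanceNorm2d uses the biased
    variance plus eps. Useful for simulation data where each sample has a
    different magnitude.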
    """
    # Compute mean and std over spatial dims (H, W), keep batch and channel
    mean = x.mean(dim=(-2, -1), keepdim=True)  # (B, C, 1, 1)
    std = x.std(dim=(-2, -1), keepdim=True)    # (B, C, 1, 1)
    return (x - mean) / std.clamp(min=1e-8)

In [None]:
# Test Exercise 10
x = torch.randn(2, 3, 8, 8) * 100 + 50  # Arbitrary scale
normed = per_sample_normalize(x)

# Check zero mean and unit std per sample/channel
means = normed.mean(dim=(-2, -1))
stds = normed.std(dim=(-2, -1))
assert torch.allclose(means, torch.zeros_like(means), atol=1e-5)
assert torch.allclose(stds, torch.ones_like(stds), atol=1e-4)

print(f"Input mean range: [{x.mean(dim=(-2,-1)).min():.1f}, {x.mean(dim=(-2,-1)).max():.1f}]")
print(f"Normalized mean range: [{means.min():.6f}, {means.max():.6f}]")
print(f"Normalized std range: [{stds.min():.4f}, {stds.max():.4f}]")
print("[PASS] Exercise 10: per_sample_normalize")

---
## Part 3: PyTorch Practical Patterns

Building real ML pipelines — what you'd actually do on the job.

### PyTorch Pipeline Overview

```
Raw Data (.npz, .csv, .h5)
    ↓
Dataset  ──→ __getitem__(idx) returns one sample
    ↓
DataLoader ──→ batches, shuffles, multi-process loading
    ↓
Model (nn.Module) ──→ forward(x) returns predictions
    ↓
Loss Function ──→ scalar loss value
    ↓
loss.backward() ──→ computes gradients
    ↓
optimizer.step() ──→ updates weights
    ↓
scheduler.step() ──→ adjusts learning rate
    ↓
Evaluation (model.eval() + torch.no_grad())
    ↓
Save/Load (torch.save / torch.load state_dict)
```

### Key PyTorch Classes to Know

| Class | Purpose | You implement |
|-------|---------|---------------|
| `Dataset` | Wraps your data | `__len__`, `__getitem__` |
| `DataLoader` | Batching + shuffling | (use as-is, maybe custom `collate_fn`) |
| `nn.Module` | Neural network base | `__init__`, `forward` |
| `optim.Adam` | Optimizer | (use as-is) |
| `lr_scheduler` | LR schedule | (use as-is) |
| `nn.MSELoss` | Loss function | (or write custom) |

### Exercise 11: Custom Dataset for Variable-Size Simulation Data

**Prompt:** "Write a PyTorch Dataset for loading simulation field data. Each sample has different spatial resolution (variable-size meshes)."

**Key concepts tested:**
- **`Dataset` interface** — only two methods: `__len__()` and `__getitem__(idx)`
- **Lazy loading** — don't load all data in `__init__`, load per-sample in `__getitem__`
- **Custom `collate_fn`** — how to batch variable-size data

**`Dataset` vs `IterableDataset`:**
```python
# Map-style Dataset (this exercise) — most common
class MyDataset(Dataset):
    def __len__(self):        return N         # Total number of samples
    def __getitem__(self, i): return data[i]   # Access by index
# Supports: shuffling, random access, len(), sampler

# IterableDataset — for streaming data (huge files, network)
class MyStream(IterableDataset):
    def __iter__(self):
        with open("huge.csv") as f:      # File is closed when iteration ends
            for line in f:
                yield process(line)
# No len(), no random access, no DataLoader shuffle=True
```
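
A runnable version of the streaming pattern, as a minimal sketch: `CountingStream` is a hypothetical stand-in that generates numbers instead of parsing a file.

```python
import torch
from torch.utils.data import DataLoader, IterableDataset

class CountingStream(IterableDataset):
    """Hypothetical stream: yields samples one at a time, no __len__ or indexing."""

    def __init__(self, n_samples: int):
        self.n_samples = n_samples

    def __iter__(self):
        for i in range(self.n_samples):
            yield torch.tensor([float(i)])   # stand-in for parsing one record

loader = DataLoader(CountingStream(10), batch_size=4)  # shuffle=True would raise
batch_shapes = [tuple(b.shape) for b in loader]
print(batch_shapes)
```

Note that with `num_workers > 0` every worker iterates its own copy of the dataset, so the stream has to be sharded (for example with `torch.utils.data.get_worker_info()`) to avoid duplicated samples.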

**Why lazy loading?**
```python
from glob import glob  # Assumption: one .npy file per sample in data_dir

# BAD: loads entire dataset into RAM at init, may OOM
class BadDataset(Dataset):
    def __init__(self, data_dir):
        paths = sorted(glob(os.path.join(data_dir, "*.npy")))
        self.data = [np.load(p) for p in paths]  # 100GB in RAM!

# GOOD: only stores file paths at init, loads one sample at a time
class GoodDataset(Dataset):
    def __init__(self, data_dir):
        self.paths = sorted(glob(os.path.join(data_dir, "*.npy")))  # Just paths, kilobytes
    def __len__(self):
        return len(self.paths)
    def __getitem__(self, i):
        return np.load(self.paths[i])         # Load on demand
```

**Custom `collate_fn` — when default stacking doesn't work:**
```python
# Default collate: torch.stack(samples) → requires same shape!
# Variable-size data (meshes, sequences) needs custom handling:

# Option 1: Pad to max size + mask (this exercise)
# Option 2: Concatenate + batch index (PyG style)
# Option 3: Nested tensors (PyTorch 2.0+)

# DataLoader usage:
loader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=True,               # Shuffle for training
    num_workers=4,               # Parallel data loading
    collate_fn=my_collate,       # Custom batching
    pin_memory=True,             # Faster CPU→GPU transfer
)
```

**Key design decisions in this exercise:**
1. **Lazy loading** — `__init__` only stores file paths
2. **Return dict** (not tuple) — `{"params": ..., "field": ..., "coords": ...}` is self-documenting
3. **Padding + mask** — mask is CRITICAL: loss must ignore padded values!
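
To show why the mask matters, here is a minimal sketch of a masked MSE. The helper name `masked_mse` and the toy tensors are assumptions for illustration; the shapes follow the padded `(B, N, F)` layout with a `(B, N)` boolean mask described above.

```python
import torch

def masked_mse(pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """MSE over valid nodes only. pred/target: (B, N, F); mask: (B, N) bool."""
    sq_err = (pred - target) ** 2                     # (B, N, F)
    valid = mask.unsqueeze(-1).to(sq_err.dtype)       # (B, N, 1), broadcasts over F
    n_valid = valid.sum() * pred.shape[-1]            # number of real entries
    return (sq_err * valid).sum() / n_valid.clamp(min=1)

# Two samples padded to 4 nodes; the second has only 2 real nodes.
target = torch.zeros(2, 4, 3)
pred = torch.ones(2, 4, 3)
pred[1, 2:] = 100.0                                   # garbage in the padded slots
mask = torch.tensor([[True, True, True, True],
                     [True, True, False, False]])

loss_masked = masked_mse(pred, target, mask)
loss_naive = ((pred - target) ** 2).mean()
print(f"masked: {loss_masked:.2f}, naive: {loss_naive:.2f}")
```

The naive mean is dominated by whatever the model emits in the padded positions, while the masked version only sees real nodes.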

In [None]:
class SimulationDataset(Dataset):
    """
    Custom Dataset for simulation field data stored as .npz files.

    Each file contains:
      - 'params': input parameters (e.g., geometry, frequency), shape (P,)
      - 'field':  output field values, shape (N_i, F) — N_i varies per sample
      - 'coords': node coordinates, shape (N_i, D)
    """

    def __init__(self, data_dir: str, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        # Only store file paths at init — lazy loading
        self.file_paths = sorted(
            [os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith(".npz")]
        )

    def __len__(self) -> int:
        return len(self.file_paths)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        # NpzFile keeps the file handle open; the context manager closes it after reading
        with np.load(self.file_paths[idx]) as data:
            sample = {
                "params": torch.tensor(data["params"], dtype=torch.float32),
                "coords": torch.tensor(data["coords"], dtype=torch.float32),
                "field": torch.tensor(data["field"], dtype=torch.float32),
            }

        if self.transform is not None:
            sample = self.transform(sample)

        return sample


def collate_variable_size(batch: List[Dict]) -> Dict[str, torch.Tensor]:
    """
    Custom collate function for variable-size simulation data.

    Since each sample has different number of nodes N_i, we can't just stack.
    Options:
      1. Pad to max size (shown here) — simple, works with standard PyTorch
      2. Concatenate with batch index (PyG-style) — more memory efficient
      3. Use nested tensors (PyTorch 2.0+)
    """
    # Fixed-size: stack normally
    params = torch.stack([s["params"] for s in batch])  # (B, P)

    # Variable-size: pad to max
    max_nodes = max(s["coords"].shape[0] for s in batch)
    coord_dim = batch[0]["coords"].shape[1]
    field_dim = batch[0]["field"].shape[1]

    padded_coords = torch.zeros(len(batch), max_nodes, coord_dim)
    padded_fields = torch.zeros(len(batch), max_nodes, field_dim)
    masks = torch.zeros(len(batch), max_nodes, dtype=torch.bool)

    for i, s in enumerate(batch):
        n = s["coords"].shape[0]
        padded_coords[i, :n] = s["coords"]
        padded_fields[i, :n] = s["field"]
        masks[i, :n] = True  # True = valid node, False = padding

    return {
        "params": params,
        "coords": padded_coords,
        "field": padded_fields,
        "mask": masks,  # CRITICAL: loss/metrics must use this mask!
    }

# Usage:
# loader = DataLoader(dataset, batch_size=8, collate_fn=collate_variable_size)
print("[INFO] Exercise 11: SimulationDataset + collate_variable_size defined")
print("       (Requires .npz files on disk to test — review the pattern)")

### Exercise 12: Complete Training Loop with Best Practices

**Prompt:** "Write a training function with validation, early stopping, checkpointing, and proper device management."

**Training loop components explained:**

| Component | What it does | Why it matters |
|-----------|-------------|----------------|
| `optimizer.zero_grad()` | Clears old gradients | Gradients accumulate by default! |
| `loss.backward()` | Computes gradients via backprop | Populates `.grad` for all parameters |
| `clip_grad_norm_` | Caps gradient magnitude | Prevents exploding gradients |
| `optimizer.step()` | Updates weights using gradients | The actual learning step |
| `scheduler.step()` | Adjusts learning rate | Cosine annealing, step decay, etc. |
| `model.train()` | Enables dropout + BN training mode | Must call before training |
| `model.eval()` | Disables dropout, uses running BN stats | Must call before validation |
| `torch.no_grad()` | Disables gradient tracking | Saves memory during validation |

**Mixed Precision Training (AMP) — why and how:**
```python
# WHY: float16 math can be much faster on Tensor Core GPUs and roughly halves activation memory
# HOW: autocast converts eligible ops to float16 automatically
scaler = torch.amp.GradScaler("cuda")

with torch.amp.autocast("cuda"):         # Forward pass in float16
    pred = model(x)
    loss = criterion(pred, y)

scaler.scale(loss).backward()             # Scale loss to prevent underflow
scaler.unscale_(optimizer)                # Unscale before clipping
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
scaler.step(optimizer)                    # Step with unscaled gradients
scaler.update()                           # Adjust scale factor
```

**Early stopping — prevent overfitting:**
```python
best_val_loss = float("inf")
patience_counter = 0

if val_loss < best_val_loss:
    best_val_loss = val_loss
    patience_counter = 0
    torch.save(model.state_dict(), "best.pt")  # Save best
else:
    patience_counter += 1
    if patience_counter >= patience:            # No improvement for N epochs
        print("Early stopping!")
        break
```
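
The same logic is often wrapped in a small helper so the training loop stays readable. The class below is a hypothetical sketch (the name `EarlyStopping`, the `min_delta` parameter, and the loss values are all made up for illustration):

```python
class EarlyStopping:
    """Hypothetical helper: tracks the best loss and counts epochs without improvement."""

    def __init__(self, patience: int = 3, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta      # required improvement to reset the counter
        self.best = float("inf")
        self.counter = 0

    def step(self, val_loss: float) -> bool:
        """Record one epoch; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.counter = 0
        else:
            self.counter += 1
        return self.counter >= self.patience

stopper = EarlyStopping(patience=3)
val_losses = [1.0, 0.8, 0.7, 0.75, 0.72, 0.71, 0.69]   # made-up curve
stopped_at = None
for epoch, val_loss in enumerate(val_losses):
    if stopper.step(val_loss):
        stopped_at = epoch
        break

print(stopped_at, stopper.best)
```

In this made-up curve the loop stops before reaching the final, better value of 0.69, which illustrates the trade-off in choosing `patience`.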

**Checkpointing — save everything needed to resume:**
```python
# SAVE (not just model — also optimizer, epoch, loss for resuming)
torch.save({
    "epoch": epoch,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "val_loss": val_loss,
}, "checkpoint.pt")

# LOAD
ckpt = torch.load("checkpoint.pt", weights_only=True)
model.load_state_dict(ckpt["model_state_dict"])
optimizer.load_state_dict(ckpt["optimizer_state_dict"])
start_epoch = ckpt["epoch"] + 1
```
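
If a learning-rate scheduler is in use, its state arguably belongs in the checkpoint too, otherwise the schedule restarts on resume. The round trip below is a sketch with a toy model and a temporary file; the `"scheduler_state_dict"` key is an addition that is not part of the snippet above.

```python
import os
import tempfile

import torch
import torch.nn as nn

def make_training_state():
    model = nn.Linear(4, 2)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
    return model, optimizer, scheduler

model, optimizer, scheduler = make_training_state()
for _ in range(3):                                   # three fake "epochs"
    optimizer.zero_grad()
    model(torch.randn(8, 4)).pow(2).mean().backward()
    optimizer.step()
    scheduler.step()

path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")
torch.save({
    "epoch": 2,
    "model_state_dict": model.state_dict(),
    "optimizer_state_dict": optimizer.state_dict(),
    "scheduler_state_dict": scheduler.state_dict(),  # extra key, see lead-in
}, path)

# Resume into freshly constructed objects
model2, optimizer2, scheduler2 = make_training_state()
ckpt = torch.load(path, map_location="cpu", weights_only=True)
model2.load_state_dict(ckpt["model_state_dict"])
optimizer2.load_state_dict(ckpt["optimizer_state_dict"])
scheduler2.load_state_dict(ckpt["scheduler_state_dict"])
start_epoch = ckpt["epoch"] + 1
```

`map_location="cpu"` makes the load work regardless of which device the checkpoint was written from; the model is moved to the target device afterwards.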

In [None]:
def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    num_epochs: int = 100,
    lr: float = 1e-3,
    patience: int = 10,
    save_path: str = "best_model.pt",
    device: str = "cuda",
):
    """
    Production-quality training loop.
    """
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    scaler = torch.amp.GradScaler("cuda")  # For mixed precision

    best_val_loss = float("inf")
    patience_counter = 0

    for epoch in range(num_epochs):
        # ---- TRAINING ----
        model.train()
        train_loss = 0.0
        num_batches = 0

        for batch in train_loader:
            x = batch["input"].to(device)
            y = batch["target"].to(device)

            optimizer.zero_grad(set_to_none=True)  # Slightly faster

            # Mixed precision forward pass
            with torch.amp.autocast("cuda"):
                pred = model(x)
                loss = F.mse_loss(pred, y)

            # Mixed precision backward pass
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            scaler.step(optimizer)
            scaler.update()

            train_loss += loss.item()  # .item() to avoid memory leak!
            num_batches += 1

        train_loss /= num_batches
        scheduler.step()

        # ---- VALIDATION ----
        model.eval()
        val_loss = 0.0
        num_val_batches = 0

        with torch.no_grad():  # CRITICAL: save memory
            for batch in val_loader:
                x = batch["input"].to(device)
                y = batch["target"].to(device)

                with torch.amp.autocast("cuda"):
                    pred = model(x)
                    loss = F.mse_loss(pred, y)

                val_loss += loss.item()
                num_val_batches += 1

        val_loss /= num_val_batches

        # ---- EARLY STOPPING & CHECKPOINTING ----
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            torch.save({
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "val_loss": val_loss,
            }, save_path)
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch}")
                break

        if epoch % 10 == 0:
            current_lr = optimizer.param_groups[0]["lr"]
            print(
                f"Epoch {epoch:3d} | "
                f"Train: {train_loss:.6f} | "
                f"Val: {val_loss:.6f} | "
                f"LR: {current_lr:.2e} | "
                f"Best: {best_val_loss:.6f}"
            )

    # Load best model
    checkpoint = torch.load(save_path, weights_only=True)
    model.load_state_dict(checkpoint["model_state_dict"])
    return model

print("[INFO] Exercise 12: train_model defined")
print("       (Requires DataLoaders to run — review the pattern)")

### Exercise 13: Model Definition with Flexible Architecture

**Prompt:** "Write a simple but flexible MLP that could serve as a surrogate model."

**Key `nn.Module` patterns tested:**
- **`__init__`** — define layers (they get registered for `.parameters()`, `.to(device)`, `.state_dict()`)
- **`forward`** — define computation graph (called via `model(x)`, NOT `model.forward(x)`)
- **`nn.Sequential`** — chain layers into a single callable module
- **Weight initialization** — Kaiming for ReLU/GELU, Xavier for sigmoid/tanh

**Why `nn.Module` and not just functions?**
```python
# nn.Module gives you for FREE:
model.parameters()        # All learnable params (for optimizer)
model.to("cuda")          # Move all params to GPU
model.state_dict()        # Serializable dict of params
model.train() / .eval()   # Toggle dropout/batchnorm
model.children()          # Iterate sub-modules
print(model)              # Human-readable architecture summary
```
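
Registration only happens for attributes that are themselves modules or parameters. A common slip, sketched below with two hypothetical toy classes, is storing layers in a plain Python list, which hides them from `.parameters()`, `.to(device)`, and `.state_dict()`:

```python
import torch.nn as nn

class PlainListNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = [nn.Linear(4, 4), nn.Linear(4, 2)]   # plain list: NOT registered

class ModuleListNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(4, 4), nn.Linear(4, 2)])  # registered

n_plain = sum(p.numel() for p in PlainListNet().parameters())
n_registered = sum(p.numel() for p in ModuleListNet().parameters())
print(n_plain, n_registered)
```

An optimizer built from `PlainListNet().parameters()` would receive an empty parameter list and refuse to construct, which is the usual symptom of this bug.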

**Dynamic layer construction pattern (this exercise):**
```python
layers = []
in_dim = input_dim
for h_dim in hidden_dims:        # e.g., [256, 256, 256]
    layers.extend([
        nn.Linear(in_dim, h_dim),
        nn.LayerNorm(h_dim),     # Normalization
        nn.GELU(),               # Activation
        nn.Dropout(0.1),         # Regularization
    ])
    in_dim = h_dim               # Output of this layer → input of next

self.backbone = nn.Sequential(*layers)  # Unpack list into Sequential
self.head = nn.Linear(in_dim, output_dim)
```

**Weight initialization — why it matters:**
```python
# Default PyTorch init (Kaiming uniform) is usually fine
# But explicit init shows you understand the theory:

for m in self.modules():
    if isinstance(m, nn.Linear):
        # Kaiming: designed for ReLU-family activations
        nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
        # Zero bias is standard
        nn.init.zeros_(m.bias)

# Other options:
# nn.init.xavier_normal_  → for sigmoid/tanh
# nn.init.orthogonal_     → for RNNs
# nn.init.constant_       → for specific values
```

**Activation function choices:**
| Function | Formula | Use case |
|----------|---------|----------|
| ReLU | `max(0, x)` | Default, fast, can "die" (output always 0) |
| GELU | `x · Φ(x)` | Transformers, smooth, no dead neurons |
| SiLU/Swish | `x · σ(x)` | Modern CNNs, smooth |
| Tanh | `(eˣ-e⁻ˣ)/(eˣ+e⁻ˣ)` | Output in [-1,1], can saturate |
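
The formulas in the table can be checked directly against PyTorch's implementations. This is a small verification sketch; the sample points are arbitrary.

```python
import math

import torch
import torch.nn.functional as F

x = torch.linspace(-3.0, 3.0, steps=13)

relu_manual = torch.clamp(x, min=0.0)                            # max(0, x)
gelu_manual = x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))    # x * Phi(x)
silu_manual = x * torch.sigmoid(x)                               # x * sigma(x)
tanh_manual = (torch.exp(x) - torch.exp(-x)) / (torch.exp(x) + torch.exp(-x))

checks = {
    "relu": torch.allclose(relu_manual, F.relu(x)),
    "gelu": torch.allclose(gelu_manual, F.gelu(x), atol=1e-6),
    "silu": torch.allclose(silu_manual, F.silu(x), atol=1e-6),
    "tanh": torch.allclose(tanh_manual, torch.tanh(x), atol=1e-6),
}
print(checks)
```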

In [None]:
class SurrogateMLP(nn.Module):
    """
    Flexible MLP for surrogate modeling: params -> field values.

    Demonstrates:
    - Dynamic layer construction from config
    - Multiple normalization options
    - Proper weight initialization
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
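        hidden_dims: Optional[List[int]] = None,
        activation: str = "gelu",
        norm: str = "layer",
        dropout: float = 0.0,
    ):
        super().__init__()

        # Avoid a mutable default argument; fall back to three 256-wide layers
        if hidden_dims is None:
            hidden_dims = [256, 256, 256]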

        act_fn = {"relu": nn.ReLU, "gelu": nn.GELU, "silu": nn.SiLU}[activation]
        norm_fn = {
            "layer": nn.LayerNorm,
            "batch": nn.BatchNorm1d,
            "none": lambda d: nn.Identity(),
        }[norm]

        layers = []
        in_dim = input_dim
        for h_dim in hidden_dims:
            layers.extend([
                nn.Linear(in_dim, h_dim),
                norm_fn(h_dim),
                act_fn(),
                nn.Dropout(dropout),
            ])
            in_dim = h_dim

        self.backbone = nn.Sequential(*layers)
        self.head = nn.Linear(in_dim, output_dim)

        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: (batch, input_dim) -> (batch, output_dim)"""
        return self.head(self.backbone(x))

In [None]:
# Test Exercise 13
model = SurrogateMLP(input_dim=10, output_dim=5, hidden_dims=[64, 64])
x = torch.randn(4, 10)
y = model(x)

assert y.shape == (4, 5)
y.sum().backward()

# Count parameters
total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
mb = total * 4 / (1024 ** 2)

print(f"Model: {model}")
print(f"Parameters: {total:,} total, {trainable:,} trainable ({mb:.2f} MB)")
print(f"Output shape: {y.shape}")
print("[PASS] Exercise 13: SurrogateMLP")

### Exercise 14: Parameter Counting Utility

**Prompt:** "Write a function to count total and trainable parameters in a model."

**Why this matters:**
- **Paper reporting** — "Our model has 1.2M parameters" is standard in every ML paper
- **Memory estimation** — `params × 4 bytes (float32) = model size in memory` (×2 for gradients during training)
- **Architecture comparison** — helps decide between model variants
- **Deployment constraints** — Synopsys may have memory/latency budgets for surrogate models

**The one-liner you must memorize:**
```python
total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)

# Memory estimation:
# Training: params × 4 (weights) + params × 4 (gradients) + optimizer states
# Inference: params × 4 (weights only) or params × 2 (float16)
```
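
The memory comments above can be turned into a rough calculator. This is a back-of-the-envelope sketch: the helper name is hypothetical, it assumes an Adam-style optimizer holding two extra buffers per parameter, and it ignores activations, which often dominate in practice.

```python
def estimate_training_memory_mb(n_params: int, bytes_per_param: int = 4,
                                optimizer_states: int = 2) -> dict:
    """Rough parameter-related memory in MB; activations are NOT included."""
    mb = 1024 ** 2
    weights = n_params * bytes_per_param
    grads = n_params * bytes_per_param
    opt = n_params * bytes_per_param * optimizer_states  # Adam/AdamW: exp_avg + exp_avg_sq
    return {
        "weights_MB": weights / mb,
        "grads_MB": grads / mb,
        "optimizer_MB": opt / mb,
        "total_MB": (weights + grads + opt) / mb,
    }

est = estimate_training_memory_mb(1_000_000)   # a hypothetical 1M-parameter model
print({k: round(v, 2) for k, v in est.items()})
```

Under these assumptions training costs about four times the inference footprint for the parameters alone (16 bytes per parameter versus 4).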

**Frozen vs trainable parameters:**
```python
# Freeze a layer (transfer learning, fine-tuning):
for param in model.backbone.parameters():
    param.requires_grad = False

# Now only model.head is trainable
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
# This counts only head parameters
```

**Useful for comparing architectures:**
```
# SurrogateMLP(input_dim=10, output_dim=5), default LayerNorm, float32
Hidden [64, 64]:          5,445 params (0.02 MB)  -> too small?
Hidden [256, 256, 256]: 137,221 params (0.52 MB)  -> sweet spot
Hidden [512, 512, 512]: 536,581 params (2.05 MB)  -> maybe overkill
```

In [None]:
@torch.no_grad()
def count_parameters(model: nn.Module) -> Dict[str, int]:
    """Count total and trainable parameters in a model."""
    total = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    return {
        "total": total,
        "trainable": trainable,
        "frozen": total - trainable,
        "total_MB": total * 4 / (1024 ** 2),  # Assuming float32
    }


# Test with different model configs
for hidden in [[64, 64], [256, 256, 256], [512, 512, 512, 512]]:
    m = SurrogateMLP(input_dim=10, output_dim=5, hidden_dims=hidden)
    p = count_parameters(m)
    print(f"Hidden {hidden}: {p['total']:>8,} params ({p['total_MB']:.2f} MB)")

print("[PASS] Exercise 14: count_parameters")

### Exercise 15: MC-Dropout for Prediction Uncertainty

**Prompt:** "Add prediction uncertainty using MC-Dropout. This is critical for simulation surrogates — we need to know WHEN the model is uncertain."

**Why uncertainty matters for Synopsys:**
- A surrogate model that says "I don't know" is safer than one that silently gives wrong answers
- High uncertainty → fall back to full HFSS simulation (expensive but accurate)
- This is the same idea as your JESTIE paper's design-space coverage metric

**Types of uncertainty:**

| Type | Source | Can reduce with more data? | Method |
|------|--------|---------------------------|--------|
| **Aleatoric** | Noisy data, inherent randomness | No | Predict mean + variance |
| **Epistemic** | Model doesn't know (insufficient data) | Yes | MC-Dropout, ensembles |

**MC-Dropout — the simplest uncertainty method:**
```python
# Normal inference: dropout OFF → deterministic output
model.eval()
pred = model(x)  # Same input → same output every time

# MC-Dropout: dropout ON at inference → stochastic output
model.train()  # Keep dropout active! (the key trick)
preds = []
for _ in range(30):
    with torch.no_grad():        # Still no gradients needed
        preds.append(model(x))   # Each call drops different neurons

preds = torch.stack(preds)       # (30, batch, output_dim)
mean = preds.mean(dim=0)         # Best prediction
std = preds.std(dim=0)           # Uncertainty estimate
```

**Interpreting the results:**
```python
# Low std  → model is confident → trust the prediction
# High std → model is uncertain → run full simulation

# Decision rule:
threshold = 0.05  # Application-dependent
uncertain_mask = std.mean(dim=-1) > threshold
n_uncertain = uncertain_mask.sum()
print(f"{n_uncertain} samples need full HFSS simulation")
```

**Alternatives to MC-Dropout:**
| Method | Pros | Cons |
|--------|------|------|
| **MC-Dropout** | Simple, no extra training | Requires dropout layers, approximate |
| **Deep Ensembles** | Best quality uncertainty | N× training cost, N× memory |
| **Heteroscedastic** | Predicts per-point aleatoric | Doesn't capture epistemic |
| **Bayesian NN** | Principled, full posterior | Slow, hard to scale |

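**For the interview:** MC-Dropout is the go-to answer because it is simple and needs no change to the training procedure, provided the model was trained with dropout (`dropout > 0`). At inference you keep the dropout layers active, most simply with `model.train()`. Caveat: `model.train()` also switches BatchNorm to batch statistics and updates its running averages, so if the model contains BatchNorm, enable only the `nn.Dropout` modules.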
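
One refinement worth knowing: `model.train()` flips every submodule, including BatchNorm. A minimal sketch of a more targeted switch follows; the helper name `enable_dropout_only` and the toy network are assumptions for illustration, not part of the exercise solution.

```python
import torch
import torch.nn as nn

def enable_dropout_only(model: nn.Module) -> None:
    """Put the model in eval mode, then re-enable just the nn.Dropout layers."""
    model.eval()
    for module in model.modules():
        if isinstance(module, nn.Dropout):   # extend for Dropout2d etc. if used
            module.train()

torch.manual_seed(0)
net = nn.Sequential(                         # toy model with BatchNorm + Dropout
    nn.Linear(10, 32), nn.BatchNorm1d(32), nn.ReLU(), nn.Dropout(0.5), nn.Linear(32, 1),
)
enable_dropout_only(net)

x = torch.randn(4, 10)
with torch.no_grad():
    samples = torch.stack([net(x) for _ in range(20)])   # (20, 4, 1)

bn_training = net[1].training        # BatchNorm stays on running statistics
dropout_training = net[3].training   # Dropout stays stochastic
mc_std = samples.std(dim=0)          # (4, 1)
print(bn_training, dropout_training, mc_std.shape)
```

With this switch the BatchNorm running statistics are neither used in batch mode nor updated by the repeated forward passes.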

In [None]:
def predict_with_uncertainty(
    model: nn.Module,
    x: torch.Tensor,
    n_samples: int = 30,
    device: str = "cpu",
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Monte Carlo Dropout for uncertainty estimation.

    Runs N forward passes with dropout ENABLED at inference time.
    The variance across predictions estimates epistemic uncertainty.

    Args:
        model: Model with Dropout layers (must have dropout > 0)
        x: Input tensor, shape (batch, ...)
        n_samples: Number of MC samples (30 is usually sufficient)

    Returns:
        mean: Mean prediction, shape same as model output
        std:  Standard deviation (uncertainty), same shape
    """
    model.train()  # Keep dropout ON (this is the key trick!)
    x = x.to(device)

    predictions = []
    for _ in range(n_samples):
        with torch.no_grad():  # Still no gradients needed
            pred = model(x)
        predictions.append(pred)

    predictions = torch.stack(predictions)  # (n_samples, batch, ...)
    mean = predictions.mean(dim=0)
    std = predictions.std(dim=0)

    model.eval()  # Restore to eval mode
    return mean, std

In [None]:
# Test Exercise 15
model_with_dropout = SurrogateMLP(
    input_dim=10, output_dim=5, hidden_dims=[64, 64], dropout=0.1
)
x = torch.randn(4, 10)

mean, std = predict_with_uncertainty(model_with_dropout, x, n_samples=50)

assert mean.shape == (4, 5)
assert std.shape == (4, 5)
assert (std >= 0).all()

print(f"Mean prediction shape: {mean.shape}")
print(f"Uncertainty (std) shape: {std.shape}")
print(f"Mean uncertainty per output: {std.mean(dim=0)}")
print(f"Max uncertainty: {std.max():.4f}")
print("[PASS] Exercise 15: MC-Dropout uncertainty")

---
## Part 4: PyTorch Gotchas — "Spot the Bug"

Quick-fire questions an interviewer might ask. For each one: what's wrong and why?

### Gotcha 1: `model.eval()` does NOT disable gradients

```python
# BUG:
model.eval()
output = model(x)  # Still tracking gradients! Wastes memory.

# FIX:
model.eval()
with torch.no_grad():
    output = model(x)
```

**WHY:** `model.eval()` only changes Dropout and BatchNorm behavior. `torch.no_grad()` disables autograd for memory/speed savings.
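
This is easy to verify on a toy model (a sketch; the layer sizes are arbitrary): the output in eval mode still carries a `grad_fn`, while the output under `torch.no_grad()` does not.

```python
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 10), nn.Dropout(0.5))
x = torch.randn(4, 10)

model.eval()
out_eval = model(x)                 # dropout is off, but autograd still records

with torch.no_grad():
    out_no_grad = model(x)          # nothing recorded

print(out_eval.requires_grad, out_eval.grad_fn is not None)
print(out_no_grad.requires_grad, out_no_grad.grad_fn)
```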

### Gotcha 2: In-place operations break autograd

```python
# BUG:
x = torch.randn(3, requires_grad=True)
x += 1            # RuntimeError here: in-place op on a leaf tensor that requires grad
loss = x.sum()    # (never reached)
loss.backward()

# FIX:
y = x + 1         # New tensor — graph intact.
```

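**RULE:** Any operation ending in `_` (`add_`, `mul_`, `zero_`) is in-place, and so are augmented assignments such as `x += 1`. Avoid them on leaf tensors that require gradients and on intermediate tensors whose values autograd needs for the backward pass.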
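
There are two distinct failure modes, sketched below with toy tensors: an in-place op on a leaf fails immediately, while an in-place op on an intermediate result fails later, at `backward()`, if autograd had saved that value.

```python
import torch

# Case 1: in-place on a leaf that requires grad fails immediately.
leaf = torch.randn(3, requires_grad=True)
leaf_error = None
try:
    leaf += 1
except RuntimeError as e:
    leaf_error = str(e)

# Case 2: in-place on an intermediate whose value backward needs fails at backward().
a = torch.randn(3, requires_grad=True)
b = torch.sigmoid(a)        # sigmoid's backward reuses its own output
b.add_(1)                   # overwrites that saved output
backward_error = None
try:
    b.sum().backward()
except RuntimeError as e:
    backward_error = str(e)

# Fix: out-of-place ops leave the saved values intact.
c = torch.sigmoid(a) + 1
c.sum().backward()
print(leaf_error is not None, backward_error is not None, a.grad is not None)
```

The second case is the harder one to debug because the error surfaces far from the line that caused it.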

### Gotcha 3: Device mismatch

```python
# BUG:
model = model.cuda()
x = torch.randn(4, 10)   # CPU!
output = model(x)         # RuntimeError

# FIX:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
x = x.to(device)
```

**TIP:** Inside a model, create new tensors on the same device as input: `mask = torch.ones(n, device=x.device)`

### Gotcha 4: Forgetting `model.train()` after validation

```python
# BUG:
model.eval()
val_loss = validate(model)
# ... continue training with Dropout disabled, BatchNorm frozen ...

# FIX:
model.eval()
with torch.no_grad():
    val_loss = validate(model)
model.train()  # ALWAYS switch back!
```

### Gotcha 5: Gradient accumulation (forgetting `zero_grad`)

```python
# BUG:
for batch in dataloader:
    loss = model(batch).sum()
    loss.backward()       # Gradients ACCUMULATE from prior steps!
    optimizer.step()

# FIX:
for batch in dataloader:
    optimizer.zero_grad()
    loss = model(batch).sum()
    loss.backward()
    optimizer.step()
```

**NOTE:** Accumulation is actually *useful* for simulating larger batch sizes:
```python
for i, batch in enumerate(dataloader):
    loss = model(batch).sum() / accumulation_steps
    loss.backward()
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```
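
The division by `accumulation_steps` is what makes the accumulated gradient match a single large batch. A quick check, assuming a mean-reduced loss and equally sized micro-batches (toy model and data):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(5, 1)
x, y = torch.randn(8, 5), torch.randn(8, 1)
loss_fn = nn.MSELoss()                     # mean reduction

# Reference: one backward pass over the full batch of 8
model.zero_grad()
loss_fn(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Accumulated: two micro-batches of 4, each loss divided by the step count
accumulation_steps = 2
model.zero_grad()
for xb, yb in zip(x.chunk(accumulation_steps), y.chunk(accumulation_steps)):
    (loss_fn(model(xb), yb) / accumulation_steps).backward()
accum_grad = model.weight.grad.clone()

grads_match = torch.allclose(full_grad, accum_grad, atol=1e-6)
print(grads_match)
```

The equivalence does not hold exactly for layers whose behavior depends on the batch itself, such as BatchNorm.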

### Gotcha 6: Silent broadcasting bugs

```python
# BUG:
pred = model(x)       # (batch, 1)
target = get_target() # (batch,)
loss = (pred - target) ** 2  # Broadcasts to (batch, batch): WRONG!

# FIX:
target = target.unsqueeze(-1)  # (batch, 1), matches pred exactly
loss = (pred - target) ** 2    # (batch, 1)
```

**DEFENSE:** Always `assert pred.shape == target.shape`
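
A common way this bites is a `(batch, 1)` regression output compared against a `(batch,)` label tensor. The sketch below shows the resulting shape and one way to enforce the defense; `checked_mse` is a hypothetical wrapper, not a PyTorch function.

```python
import torch
import torch.nn.functional as F

pred = torch.randn(4, 1)            # typical regression head output: (batch, 1)
target = torch.randn(4)             # typical label tensor: (batch,)

sq_err_bad = (pred - target) ** 2                  # (4, 1) - (4,) -> (4, 4)
sq_err_ok = (pred - target.unsqueeze(-1)) ** 2     # (4, 1)

def checked_mse(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Hypothetical wrapper that refuses mismatched shapes instead of broadcasting."""
    assert pred.shape == target.shape, f"shape mismatch: {pred.shape} vs {target.shape}"
    return F.mse_loss(pred, target)

loss = checked_mse(pred, target.unsqueeze(-1))
print(sq_err_bad.shape, sq_err_ok.shape)
```

The broadcast version still reduces to a scalar under `.mean()`, so training runs without error and simply learns the wrong thing.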

### Gotcha 7: Memory leak from storing loss tensors

```python
# BUG:
losses = []
for batch in dataloader:
    loss = model(batch).sum()
    losses.append(loss)  # Entire computation graph stays in memory!

# FIX:
losses.append(loss.item())   # .item() -> Python float, graph freed
# or
losses.append(loss.detach()) # Tensor without grad history
```

### Gotcha 8: BatchNorm with `batch_size=1`

```python
# BUG:
model = nn.Sequential(nn.Linear(10, 10), nn.BatchNorm1d(10))
x = torch.randn(1, 10)  # Single sample!
model(x)                 # ValueError (train mode): Expected more than 1 value per channel when training
```

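**FIX:** Use `LayerNorm` (normalizes across features) or `InstanceNorm`. For simulation data with small batches, `LayerNorm` is usually better. The error only occurs in training mode: after `model.eval()`, BatchNorm uses its running statistics and accepts a single sample.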
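
The behavior with a single sample can be checked directly. A minimal sketch comparing BatchNorm in train mode, BatchNorm in eval mode, and LayerNorm:

```python
import torch
import torch.nn as nn

x = torch.randn(1, 10)  # single sample

bn = nn.BatchNorm1d(10)
bn.train()
try:
    bn(x)
    train_error = None
except ValueError as exc:
    # Batch statistics cannot be estimated from one value per channel.
    train_error = str(exc)

bn.eval()
eval_out = bn(x)  # uses running statistics, so batch size 1 works

ln = nn.LayerNorm(10)
ln_out = ln(x)    # normalizes over the 10 features of each sample

print(train_error)
print(eval_out.shape, ln_out.shape)
```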

### Gotcha 9: DDP `state_dict` key mismatch

```python
# BUG: Saved with DDP → keys = "module.layer.weight"
#       Loaded without DDP → expects "layer.weight"

# FIX (save correctly):
torch.save(ddp_model.module.state_dict(), "model.pt")

# FIX (load flexibly):
state = torch.load("model.pt", map_location="cpu")
# Strip only a leading "module." (str.removeprefix needs Python 3.9+)
state = {k.removeprefix("module."): v for k, v in state.items()}
model.load_state_dict(state)
```
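
The key mismatch can be reproduced without a distributed setup. In the sketch below, `Wrapper` is a stand-in for DDP (an assumption for illustration only): like DDP, it stores the real model under `.module`, which is what adds the prefix. `strip_module_prefix` is a hypothetical helper name.

```python
import torch
import torch.nn as nn


class Wrapper(nn.Module):
    """Stand-in for DDP (assumption): keeps the real model under `.module`."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module


def strip_module_prefix(state: dict) -> dict:
    # Remove only a leading "module." so inner names are left untouched.
    return {k.removeprefix("module."): v for k, v in state.items()}


net = nn.Sequential(nn.Linear(4, 2))
wrapped_state = Wrapper(net).state_dict()  # keys like "module.0.weight"

fresh = nn.Sequential(nn.Linear(4, 2))
fresh.load_state_dict(strip_module_prefix(wrapped_state))

print(list(wrapped_state))
print(torch.equal(fresh[0].weight, net[0].weight))
```

Using `removeprefix` rather than `replace` avoids rewriting keys that merely contain `module.` somewhere in the middle.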

### Gotcha 10: `torch.Tensor` vs `torch.tensor`

```python
# BUG:
x = torch.Tensor(3)     # Allocates UNINITIALIZED tensor of size 3!
y = torch.Tensor([1,2]) # Always the default dtype (float32), ignores input dtype

# FIX:
x = torch.tensor(3)     # Scalar tensor with value 3
y = torch.tensor([1,2]) # Infers dtype (int64 here)
```

**RULE:** Always use lowercase `torch.tensor()` for creating from data.
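
A quick check of the dtype and shape differences, assuming the default dtype has not been changed from `float32`:

```python
import torch

a = torch.Tensor([1, 2])      # constructor: default dtype
b = torch.tensor([1, 2])      # factory: infers int64 from Python ints
c = torch.tensor([1.0, 2.0])  # factory: infers float32 from Python floats
d = torch.Tensor(3)           # uninitialized, shape (3,)
e = torch.tensor(3)           # scalar with value 3, shape ()

print(a.dtype, b.dtype, c.dtype)
print(d.shape, e.shape)
```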

---
## Bonus: Essential Patterns Quick Reference

In [None]:
# =====================================================
# Pattern 1: Proper training loop skeleton
# =====================================================
# model.train()
# for epoch in range(num_epochs):
#     for x, target in dataloader:
#         optimizer.zero_grad()
#         loss = criterion(model(x.to(device)), target.to(device))
#         loss.backward()
#         torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
#         optimizer.step()

# =====================================================
# Pattern 2: Proper evaluation
# =====================================================
# model.eval()
# with torch.no_grad():
#     for batch in val_loader:
#         pred = model(batch.to(device))
# model.train()  # Switch back!

# =====================================================
# Pattern 3: Proper checkpointing
# =====================================================
# torch.save({
#     'epoch': epoch,
#     'model_state_dict': model.state_dict(),
#     'optimizer_state_dict': optimizer.state_dict(),
#     'loss': best_loss,
# }, 'checkpoint.pt')
#
# ckpt = torch.load('checkpoint.pt', weights_only=True)
# model.load_state_dict(ckpt['model_state_dict'])
# optimizer.load_state_dict(ckpt['optimizer_state_dict'])

# =====================================================
# Pattern 4: Mixed precision (AMP)
# =====================================================
# scaler = torch.amp.GradScaler("cuda")
# with torch.amp.autocast("cuda"):
#     loss = model(x)
# scaler.scale(loss).backward()
# scaler.step(optimizer)
# scaler.update()

# =====================================================
# Pattern 5: Gradient accumulation
# =====================================================
# for i, batch in enumerate(loader):
#     loss = model(batch) / accum_steps
#     loss.backward()
#     if (i + 1) % accum_steps == 0:
#         optimizer.step()
#         optimizer.zero_grad()

print("Review the 5 essential patterns above!")
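
Patterns 1 to 3 can be combined into one small runnable script. This is a sketch on a made-up toy problem (fitting `y = 2x + 1` with a single `nn.Linear`); the data, learning rate, and the in-memory buffer used in place of a checkpoint file are all assumptions for illustration.

```python
import io

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Toy regression data (assumption): y = 2x + 1, no noise.
x = torch.linspace(-1, 1, 64).unsqueeze(-1)
y = 2 * x + 1
loader = DataLoader(TensorDataset(x, y), batch_size=16, shuffle=True)

model = nn.Linear(1, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.MSELoss()


def evaluate() -> float:
    """Pattern 2: eval mode, no_grad, then switch back to train mode."""
    model.eval()
    total = 0.0
    with torch.no_grad():
        for xb, yb in loader:
            batch_loss = criterion(model(xb.to(device)), yb.to(device))
            total += batch_loss.item() * len(xb)
    model.train()
    return total / len(loader.dataset)


initial_loss = evaluate()

# Pattern 1: training loop with zero_grad, backward, clipping, step.
model.train()
for epoch in range(50):
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        loss = criterion(model(xb), yb)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

final_loss = evaluate()

# Pattern 3: checkpoint round trip (BytesIO stands in for a file path).
buffer = io.BytesIO()
torch.save({"model_state_dict": model.state_dict(), "loss": final_loss}, buffer)
buffer.seek(0)
ckpt = torch.load(buffer, weights_only=True)
restored = nn.Linear(1, 1).to(device)
restored.load_state_dict(ckpt["model_state_dict"])

print(f"loss before: {initial_loss:.4f}, after: {final_loss:.6f}")
```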

---
## Summary: Run All Tests

In [None]:
print("=" * 70)
print("RUNNING ALL TESTS")
print("=" * 70)
print()

# Part 1: Python Fundamentals
print("--- Part 1: Python Fundamentals ---")

# Ex 1
results = [
    {"geometry": "dipole", "error": 0.05},
    {"geometry": "patch",  "error": 0.12},
    {"geometry": "dipole", "error": 0.03},
    {"geometry": "patch",  "error": 0.08},
]
avg = group_and_average(results)
assert abs(avg["dipole"] - 0.04) < 1e-10
print("  [PASS] Exercise 1: group_and_average")

# Ex 3
s_params = np.random.randn(10, 2, 2) + 1j * np.random.randn(10, 2, 2)
freqs = np.linspace(1.0, 10.0, 10)
result = SimulationResult("test_antenna", s_params, freqs)
assert result.n_ports == 2
print("  [PASS] Exercise 3: SimulationResult")

# Ex 5
config = {"model_type": "gnn", "learning_rate": 1e-3, "num_epochs": 50}
clean = validate_config(config)
assert clean["batch_size"] == 32
print("  [PASS] Exercise 5: validate_config")

print()

# Part 2: Tensor Operations
print("--- Part 2: Tensor Operations ---")

# Ex 6
x = torch.randn(5, 3)
y = torch.randn(7, 3)
dist = pairwise_distances(x, y)
assert torch.allclose(dist, torch.cdist(x, y), atol=1e-5)
print("  [PASS] Exercise 6: pairwise_distances")

# Ex 7
agg = gather_scatter_demo(torch.randn(4, 8), torch.tensor([[0,1,2,3],[1,2,3,0]]))
assert agg.shape == (4, 8)
print("  [PASS] Exercise 7: gather_scatter")

# Ex 8
e, f, fd = filter_simulation_data(
    torch.tensor([0.05, 0.15, 0.03, 0.20, 0.08]),
    torch.tensor([2e9, 5e9, 3e9, 15e9, 8e9]),
    torch.randn(5, 4, 4)
)
assert len(e) == 3
print("  [PASS] Exercise 8: filter_simulation_data")

# Ex 9
rel = per_sample_relative_l2(torch.randn(4,16,16), torch.randn(4,16,16)+5)
assert rel.shape == (4,) and (rel >= 0).all()
print("  [PASS] Exercise 9: per_sample_relative_l2")

# Ex 10
normed = per_sample_normalize(torch.randn(2,3,8,8)*100+50)
assert torch.allclose(normed.mean(dim=(-2,-1)), torch.zeros(2,3), atol=1e-5)
print("  [PASS] Exercise 10: per_sample_normalize")

print()

# Part 3: PyTorch Patterns
print("--- Part 3: PyTorch Patterns ---")

# Ex 13
m = SurrogateMLP(10, 5, [64,64])
y = m(torch.randn(4,10))
y.sum().backward()  # Checks that gradients flow through the model
print(f"  [PASS] Exercise 13: SurrogateMLP ({sum(p.numel() for p in m.parameters()):,} params)")

# Ex 15
md = SurrogateMLP(10, 5, [64,64], dropout=0.1)
mean, std = predict_with_uncertainty(md, torch.randn(4,10), n_samples=10)
assert mean.shape == (4,5) and (std >= 0).all()
print("  [PASS] Exercise 15: MC-Dropout uncertainty")

print()
print("=" * 70)
print("ALL TESTS PASSED")
print("=" * 70)