# Optimizing ML Training with Metagradient Descent

**Focus**: pass 1 and pass 2

**References**: 
- Optimizing ML Training with Metagradient Descent (https://arxiv.org/pdf/2503.13751)

**Purpose**: 
- Configuring the training process to maximize model performance is a major challenge in training large-scale ML models (like Neural Architecture Search but expanded?). Ex. finding the best training setup from a vast design space.

**Approach**: 
- Take a gradient-based approach to this problem.
- Introduce an algorithm for efficiently calculating metagradients - gradients through model training - at scale.
- A "smooth model training" framework that enables effective optimization using metagradients.
- Why optimization? Gradients offer a more effective approach to maximizing high-dimensional functions than grid search.

**Result**: 
- with metagradient descent (MGD), they greatly improve existing dataset selection methods (?), outperform accuracy-degrading data poisoning attacks by an order of magnitude, and automatically find competitive learning rate schedules

**Definitions**: 
- metaparameters: training configuration / hyperparameters
- metagradient \nabla_z \phi(A(z)) \in R^n: the gradient of the model output w.r.t. the metaparameters z

- z \in R^n = vector of continuous metaparameters representing the aspects of the training setup we aim to optimize. Ex. if we want to adjust LR and weight decay of SGD, then n=2. Interesting... but their scales are so far off? Discrete metaparameters (ex. choice of training data) are handled by finding a continuous relaxation (e.g. importance weights).

- A = algorithm mapping z to a trained machine learning model

- \phi = output function mapping \theta to a scalar \phi(\theta) \in R. Ex. the output function could be the validation loss of the model \theta. Requires that \phi be differentiable w.r.t. \theta (ex. the loss function must be differentiable, like MSE or NLL).

- training function f := \phi \circ A, mapping the training setup z directly to the output function \phi evaluated on the corresponding model, i.e. f(z) = \phi(A(z)).

- metagradient = gradient of the training function f w.r.t. the metaparameters z, i.e. \nabla_z f(z). Intuitively, the metagradient is the direction of steepest ascent in metaparameter space, so to minimize a loss we step in the opposite direction.
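
Putting the definitions above side by side (my own restatement; the parameter dimension p, the step size \eta, and the iterate index k are notation I'm assuming, not taken from the paper):

```latex
\begin{aligned}
A &: \mathbb{R}^n \to \mathbb{R}^p, \quad z \mapsto \theta = A(z)
  && \text{(training)} \\
\phi &: \mathbb{R}^p \to \mathbb{R}, \quad \theta \mapsto \phi(\theta)
  && \text{(output function, e.g. validation loss)} \\
f &:= \phi \circ A, \quad f(z) = \phi(A(z))
  && \text{(training function)} \\
\nabla_z f(z) &= \nabla_z \phi(A(z)) \in \mathbb{R}^n
  && \text{(metagradient)} \\
z^{(k+1)} &= z^{(k)} - \eta \, \nabla_z f\big(z^{(k)}\big)
  && \text{(one descent step, when } f \text{ is a loss)}
\end{aligned}
```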

**Notes**:
- Deciding the data mix and architecture choice is challenging because there's a large design space. I wonder if they push some parameters to the extreme.
- Note: even hyperparameter tuning is sometimes difficult

Main approach is to take the optimization perspective on model training design. Kind of like NAS. 
- Deciding on a training configuration - a set of metaparameters - is a high-dimensional optimization problem (correct).

Input space: all possible metaparameter choices (Q: how do you even know what this is? Does the human have to provide the input space? Probably right? Can't a model try to search for this? Is that where the key is?)
- Ex: datapoints to train on, what model architecture to use, how to initialize model weights. Generally for these decisions, it looks like one group proves out a configuration, and then everyone just uses that and mainly tweaks one thing (like only the data).

Objective function - takes in a set of metaparameters, trains an ML model according to those metaparameters, and then returns a target metric evaluated on that model (ex. test accuracy).

I think this is indeed a pretty good formulation, just need to check how to do this efficiently. Stopped at paragraph 2

- Why optimization? Gradients offer a more effective approach to maximizing high-dimensional functions than grid search.

How to do this:
1. make the objective differentiable w.r.t. the meta parameters
2. update via gradient steps
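
To convince myself the two steps above are even possible, here is a minimal toy sketch (not the paper's algorithm). Assumptions: the "model" is a single scalar w, training is a few steps of gradient descent on 0.5*(w - target)^2, the only metaparameter is a scalar learning rate z, and the metagradient is carried along by hand with forward-mode differentiation. All function names are made up for illustration.

```python
def train_and_metagrad(z, steps=5, w0=0.0, target=1.0):
    """Train with learning rate z; return (final loss f(z), metagradient df/dz)."""
    w, dw_dz = w0, 0.0
    for _ in range(steps):
        grad = w - target  # d/dw of 0.5 * (w - target)**2
        # Update: w_next = w - z * grad.
        # Product rule: d(w_next)/dz = dw_dz - grad - z * d(grad)/dz,
        # and d(grad)/dz = dw_dz for this quadratic loss.
        w, dw_dz = w - z * grad, dw_dz - grad - z * dw_dz
    loss = 0.5 * (w - target) ** 2
    return loss, (w - target) * dw_dz


def metagradient_descent(z0=0.1, meta_lr=0.1, meta_steps=50):
    """Step 2: repeatedly retrain and move z against the metagradient."""
    z, history = z0, []
    for _ in range(meta_steps):
        loss, g = train_and_metagrad(z)
        history.append(loss)
        z -= meta_lr * g
    return z, history


z_final, history = metagradient_descent()
print("learning rate found:", z_final)
print("final loss before/after:", history[0], history[-1])
```

Each outer step retrains from scratch, which is why the cost of one metagradient matters so much at scale.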

Metagradient
1. Embed a given aspect of the training setup (training dataset, optimizer hyperparameters) into a continuous metaparameter vector z \in R^d.
2. metaparameter z defines a model A(z) (a model with metaparameter z) by way of learning algorithm A
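
A sketch of what "embedding the training dataset" could look like with importance weights, under toy assumptions of my own: the model is a scalar w, each example x_i contributes a loss 0.5*(w - x_i)^2, and z_i scales that example's contribution. The function names are hypothetical.

```python
def weighted_gd_step(w, xs, z, lr=0.1):
    """One gradient step on sum_i z[i] * 0.5 * (w - xs[i])**2."""
    grad = sum(z_i * (w - x_i) for z_i, x_i in zip(z, xs))
    return w - lr * grad


def train(xs, z, steps=200, lr=0.1):
    """The learning algorithm A: maps data weights z to a trained 'model' w."""
    w = 0.0
    for _ in range(steps):
        w = weighted_gd_step(w, xs, z, lr)
    return w


xs = [1.0, 2.0, 9.0]
w_all = train(xs, [1.0, 1.0, 1.0])   # every example counted once
w_drop = train(xs, [1.0, 1.0, 0.0])  # last example removed from training
w_half = train(xs, [1.0, 1.0, 0.5])  # last example counted "half"
print(w_all, w_drop, w_half)
```

The discrete choice "include this example or not" becomes a continuous knob z_i, so the trained model changes smoothly as z_i moves between 0 and 1.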

metagradient \nabla_z \phi(A(z)) \in R^d is the gradient of the model output w.r.t. the metaparameter

What is a metagradient.

A = learning algorithm / SGD.

2 parts to ML training
1. Decide on a training setup
2. Apply the algorithm defined by training setup to train a model

Goal: optimize model behavior as a function of the training setup (metaparameters)


Main item: iterative algorithms to efficiently compute the metagradient, restrict focus to cases where algorithm A is iterative
- Q: isn't algorithm A always just mapping some metaparameters to a training config?

So A(z) := s_T, the model state after T steps, where s_{t+1} := h_t(s_t, z).

s_t is the optimizer state at step t
h_t = update mapping from state t to state t+1
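
Unrolling this recursion with the ordinary chain rule gives the metagradient as a sum over steps. This is the textbook derivation, written under my assumption that s_0 does not depend on z; it is not the paper's efficient algorithm, just the quantity that algorithm has to compute:

```latex
\begin{aligned}
\frac{d s_{t+1}}{d z}
  &= \frac{\partial h_t}{\partial s_t}\,\frac{d s_t}{d z}
   + \frac{\partial h_t}{\partial z},
  \qquad \frac{d s_0}{d z} = 0, \\[4pt]
\frac{d f}{d z}
  &= \frac{\partial \phi}{\partial s_T}\,\frac{d s_T}{d z}
   = \frac{\partial \phi}{\partial s_T}
     \sum_{t=0}^{T-1}
     \left(
       \frac{\partial h_{T-1}}{\partial s_{T-1}} \cdots
       \frac{\partial h_{t+1}}{\partial s_{t+1}}
     \right)
     \frac{\partial h_t}{\partial z}.
\end{aligned}
```

For t = T-1 the product in parentheses is empty (the identity). Every term needs intermediate states s_t, which is why computing this at scale is the hard part.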

optimizer state s_t is a superset that includes the model parameters.

Simplest case, plain gradient descent: s_t = vector of model parameters at step t. For other optimizers (e.g. SGD with momentum, Adam), s_t also has to carry the momentum buffer and moment estimates.

s_0 is the initial state (includes model parameters and optimizer states)

h_t = update mapping from state t to state t+1 (weight updates basically)
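
To make "state is a superset of the parameters" concrete, a small sketch with a scalar parameter and a quadratic loss (both my assumptions). The names `h_gd`, `h_momentum`, and `run` are hypothetical, and the momentum rule is the common heavy-ball form:

```python
def h_gd(state, lr, grad_fn):
    """Plain gradient descent: the state is just the parameters."""
    (w,) = state
    return (w - lr * grad_fn(w),)


def h_momentum(state, lr, grad_fn, beta=0.9):
    """Heavy-ball momentum: the state is (parameters, velocity buffer)."""
    w, v = state
    v_next = beta * v + grad_fn(w)
    return (w - lr * v_next, v_next)


def run(h, s0, lr, grad_fn, steps):
    """A(z) := s_T, obtained by applying the update map h for T steps."""
    s = s0
    for _ in range(steps):
        s = h(s, lr, grad_fn)
    return s


def grad_fn(w):
    return w - 3.0  # gradient of 0.5 * (w - 3)**2


s_gd = run(h_gd, (0.0,), 0.1, grad_fn, 500)
s_mom = run(h_momentum, (0.0, 0.0), 0.1, grad_fn, 500)
print(s_gd, s_mom)
```

Both runs end near w = 3, but the momentum state has an extra entry that any metagradient computation also has to differentiate through.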

Example: z \in R^T is a per-step learning rate (which is pretty crazy because most learning rates are just driven by a fixed decay schedule) and algorithm A is full-batch gradient descent; then each update h_t is

h_t(s_t, z) := s_t - z_t \nabla loss(s_t). Basically update the weights using the gradient of the loss. I'm kind of confused though... again, why are we multiplying by z_t (learning rate at step t)? Like isn't that metaparameter of a learning rate kind of just embedded in the equation?
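
A toy sketch of this example to see why z_t shows up explicitly: because each update multiplies the gradient by z_t, the final state is a differentiable function of every z_t. Assumptions of mine: scalar state, training loss 0.5*(s - train_target)^2, validation loss 0.5*(s - val_target)^2, and a hand-written backward pass through the unrolled steps. Function names are hypothetical.

```python
def unroll(z, s0=0.0, train_target=1.0):
    """Apply s_{t+1} = s_t - z_t * grad loss(s_t) and keep every state."""
    states = [s0]
    for z_t in z:
        s = states[-1]
        states.append(s - z_t * (s - train_target))
    return states


def val_loss(s, val_target=1.5):
    """Output function phi evaluated on the final state."""
    return 0.5 * (s - val_target) ** 2


def metagradient(z, s0=0.0, train_target=1.0, val_target=1.5):
    """Backpropagate d phi / d z_t through the unrolled training steps."""
    states = unroll(z, s0, train_target)
    adjoint = states[-1] - val_target  # d phi / d s_T
    grad = [0.0] * len(z)
    for t in reversed(range(len(z))):
        s_t = states[t]
        grad[t] = adjoint * (-(s_t - train_target))  # uses d s_{t+1} / d z_t
        adjoint *= 1.0 - z[t]                        # uses d s_{t+1} / d s_t
    return grad


z = [0.1, 0.2, 0.3]
print(metagradient(z))
```

Each entry says how the validation loss would change if only that step's learning rate were nudged, which is exactly the signal a learned schedule needs.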

Q: Technically, a metaparameter can also just be the model parameters too right?



**FAQs**:
0. Q: what does it even really mean, to optimize your learning rate using gradient descent? How do you get gradients for that? Is that even differentiable and what does that function even look like?
1. I wonder if they push some parameters to the extreme to better understand how this parameter affects training? Or it's more gradient based?
2. Can they re-create / re-discover the transformer architecture? What's going on here?
3. It's a super big question - how can you possibly do better than a grid over hyperparameters? 
4. How is "optimal" training configuration defined? Probably for the absolute loss? Or rate of loss decrease?
5. How does this compare to AlphaEvolve approach? Right? Is it similar based on the validation loss and just trying different things? Is it faster or just as long because you have to compute?
6. The input space actually itself seems quite large. Shouldn't a model be trained to search and provide for that?
7. How expensive is this training process? Do the scaling laws hold here where going through these steps on a smaller model and demonstrating gains at a small scale can translate to gains at a large scale?
8. Interesting, so it's a one hot encoded continuous metaparameter vector? Do the intermediate dimensional values matter? Or it's a learned embedding I guess?
9. - z \in R^n = vector of continuous metaparameters representing the aspects of the training setup we aim to optimize. Ex. if we want to adjust LR and weight decay of SGD, then n=2. Interesting... but their scales are so far off? Q: How do you initialize these metaparameters?
10. Do you control epochs trained? Because validation loss obviously is affected by epochs trained as a hyperparameter.
11. Q: Isn't \phi computationally intensive? How many steps need to be run?
12. What exactly does the gradient mean? That from the validation loss, we can know if we need a larger or smaller loss function?
13. Why am I guessing what the LR should be if we can compute the derivative in the first place?
14. Q: 3-5x the cost of training. What was it before?
- Presentation will address later
- their optimization is 1.2x training cost. each step is 3-5x the cost
15. Q: How to initialize?
- weight all 1


Notes:
1. has done presentation before

**Action items**:
- Mainly just read the abstract so need to continue pass 1
- The additional readings are going to be good as well: 
    - DataRater: Meta-Learned Dataset Curation (https://arxiv.org/pdf/2505.17895)
    - MAGIC: Near-Optimal Data Attribution for Deep Learning (https://arxiv.org/pdf/2504.16430)
