# Gradient Descent Variations

In this notebook, we explore different variants of gradient descent for optimizing our estimate of the atmospheric forcing.

- These variants start simple and become more complex, in order to represent the atmosphere more accurately.

In [None]:
import numpy as np
import matplotlib.pyplot as plt 
import math
import copy
from sklearn.decomposition import PCA
from matplotlib.colors import LinearSegmentedColormap

import helper

run_all_code = False # Runs code blocks below, one by one, with plots

First, we generate data:

- The real data, using the real atmospheric forcing
- Fake data, using our noisy atmospheric forcing

In [None]:
nr,nc = 32,32
dt = 0.01
F = 0.1

#C_control is the covariance matrix of the full control vector f

C_control, x0, _, M = helper.generate_world(nr, nc, dt, F)

#C_known is the covariance matrix of the correct part of the control vector f
#C_error is the covariance matrix of the incorrect part of the control vector f

gamma = 2/3
C_known = C_control * gamma
C_error = C_control * (1-gamma)
C_ocean = C_control / 6

f_true, f_guess = helper.generate_true_and_first_guess_field(C_known, C_error, nr, nc)

In [None]:
# Atmosphere forcing coefficient
F = 0.1
# Standard deviation of the noise in the observations
sigma = 0.1 


# Number of timesteps to run the simulation for (repeated each iter)
num_timesteps = 4
# Number of iterations of gradient descent (repeated each run)
num_iters = 5
# Step size for gradient descent
step_size = 0.1

# Number of times to run the whole gradient descent optimization
num_runs = 1

# Run the simulation with the true and guessed control vector
saved_timesteps, real_state_over_time  = helper.compute_affine_time_evolution_simple(x0, M, F*f_true,  num_timesteps)
saved_timesteps, guess_state_over_time = helper.compute_affine_time_evolution_simple(x0, M, F*f_guess, num_timesteps)

Later, we'll need observations of the real ocean state:

In [None]:
num_obs_per_timestep = 50



observed_state_over_time_2d = helper.observe_over_time(real_state_over_time, sigma, 
                                                       num_obs_per_timestep, nr, nc)


observed_state_over_time =     [np.reshape(observed_state_2d, (nr*nc, 1)) 
                                for observed_state_2d in observed_state_over_time_2d]


# Basic Concept for this notebook: Gradient Descent

Our goal is to improve the accuracy of our atmosphere $f$ by modifying it, so that it creates a more accurate ocean simulation over time, $x(t)$.

- We measure the accuracy of our ocean simulation using the above observations, and plugging them into a loss function, $J(x)$.

  - $J(x)$ tells us how *bad* the fit between our observations and our simulation is.

- So, we want to choose an updated $f$ that minimizes $J$.

To accomplish this, we'll use **gradient descent**: we repeatedly take small steps, moving $f$ in the direction of the negative gradient, $-dJ/df$.
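
As a minimal sketch of this idea, the block below runs gradient descent on a toy quadratic loss. Everything in it (`A_toy`, `f_star`, `toy_loss`, `toy_gradient`, the step size choice) is a hypothetical stand-in, not part of `helper` or of the ocean model used in this notebook.

```python
import numpy as np

def toy_loss(f, A, y):
    """Quadratic misfit J(f) = ||A f - y||^2 (a stand-in for the ocean loss)."""
    r = A @ f - y
    return float((r.T @ r)[0, 0])

def toy_gradient(f, A, y):
    """Gradient dJ/df = 2 A^T (A f - y) of the quadratic misfit."""
    return 2 * A.T @ (A @ f - y)

rng = np.random.default_rng(0)
A_toy = rng.normal(size=(8, 4))        # hypothetical linear "model"
f_star = rng.normal(size=(4, 1))       # hypothetical "true" control
y_toy = A_toy @ f_star                 # noise-free observations

# Step size chosen from the largest singular value so every step reduces J
toy_step = 0.5 / np.linalg.norm(A_toy, 2) ** 2

f_toy = np.zeros((4, 1))               # first guess
toy_losses = []
for i in range(200):
    toy_losses.append(toy_loss(f_toy, A_toy, y_toy))
    f_toy = f_toy - toy_step * toy_gradient(f_toy, A_toy, y_toy)  # step along -dJ/df
```

Plotting `toy_losses` should show the loss shrinking at every iteration, which is the behaviour we hope to see for $J$ below.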

# Conventions for discussing gradient descent


In the following sections, our work can get pretty messy. So, to ease the process, we introduce some symbolic conventions for discussing each component of our algorithm:

- $i$: The iteration of gradient descent we're referencing, starting from $i=0$.

- $f_i$: The total atmospheric control vector at the start of iteration $i$.

- $u_i$: The update we apply to $f_i$ during iteration $i$:

$$f_{i+1} = f_i + u_i$$

- $a_i$: The total atmospheric adjustment we've made to $f_0$ before iteration $i$.

$$f_i = f_0 + a_i$$
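
These two definitions together imply how $a_i$ evolves. Subtracting $f_0$ from both sides of the update equation, and noting that no adjustment has been made before the first iteration:

$$a_0 = 0 \qquad \qquad a_{i+1} = a_i + u_i \qquad \qquad a_i = \sum_{j=0}^{i-1} u_j$$

This appears to be what the code below does when it starts `f_adjust` at zero and adds the update to it at the end of each iteration.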

We'll generally refer to $f_{true}$ as the true atmospheric control vector.


# Simple Gradient Descent + Template

First, let's try implementing the most basic version of gradient descent:

- $\eta$ is our "step size".

$$ u_i = - \eta \frac{dJ}{df}$$

### Loss functions

While we're at it, we'll keep track of three types of "loss":

- "Ocean loss": this is the misfit between the ocean simulation $x(t)$, and our observed state $Hy(t)$. This is $J$.

  - We mask/ignore cells that aren't observed.

$$ J_{ocean} = J = \sum_t \Big(x(t) - Hy(t) \Big)^\top \Big( x(t) - Hy(t) \Big)$$

- "Atmosphere loss": this is the misfit between the current guess for the atmospheric control $f_i$, and what we know is the true atmospheric control $f_{true}$.

  - This measure is not used for optimization. Instead, we use it to check how well gradient descent recovers information it never sees ($f_{true}$). Thus, we can use it as a proxy for overfitting.

  $$J_{atm} = \Big(f_{true} - f_i \Big)^\top \Big( f_{true} - f_i \Big)$$

- "Mahalanobis loss": this measures how far the covariance of the control adjustment $a_i$ is from a given covariance matrix $C$ (the "true" covariance).

  - It becomes larger if the covariance of $a_i$ moves further away from $C$, or if the magnitude $|a_i|$ increases.

  - Later, we want to optimize $a_i$ to have covariance $C$, so this will be a useful metric.

  - It also gives an idea of how "statistically accurate" $a_i$ is.

$$J_{mhl} = a_i^\top C^{-1} a_i$$
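
To make the Mahalanobis loss concrete, here is a small sketch on a hypothetical $2 \times 2$ covariance. The function name `mahalanobis_loss` and the toy vectors are assumptions for illustration only. It uses `np.linalg.solve` rather than forming $C^{-1}$ explicitly, which is a common numerical choice and not necessarily what the rest of this notebook does.

```python
import numpy as np

def mahalanobis_loss(a, C):
    """Return a^T C^{-1} a, using a linear solve instead of an explicit inverse."""
    return float((a.T @ np.linalg.solve(C, a))[0, 0])

# Hypothetical covariance: variance 4 along the first axis, 1 along the second
C_toy = np.array([[4.0, 0.0],
                  [0.0, 1.0]])

a_likely = np.array([[2.0], [0.0]])     # points along the high-variance direction
a_unlikely = np.array([[0.0], [2.0]])   # same length, low-variance direction

loss_likely = mahalanobis_loss(a_likely, C_toy)        # 4 / 4 = 1.0
loss_unlikely = mahalanobis_loss(a_unlikely, C_toy)    # 4 / 1 = 4.0
loss_doubled = mahalanobis_loss(2 * a_likely, C_toy)   # 16 / 4 = 4.0
```

Two vectors of the same length get different losses depending on their direction, and doubling a vector quadruples its loss. Both effects match the bullets above.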
  

### Simple Gradient Descent

We'll use a helper function to keep calculation of our losses separate.

In [None]:
losses_template = { #Losses at each iteration
        "ocean_misfit": [],
        "atmosphere_misfit": [],
        "mahalanobis(covariance similarity)": [],
    }


def update_losses(losses, ocean_states_observed, ocean_states_simulated, f_guess, f_adjust, f_true, C_error):
    """
    Updates the losses dictionary with new loss values for ocean, atmosphere, and control adjustment.

    Args:
    losses (dict): Dictionary containing lists of loss values
    ocean_states_observed (list): List of observed ocean states
    ocean_states_simulated (list): List of simulated ocean states
    f_guess (np.ndarray): Initial guess for the atmospheric forcing field
    f_adjust (np.ndarray): Adjustment to the atmospheric forcing field
    f_true (np.ndarray): True atmospheric forcing field
    C_error (np.ndarray): Covariance matrix for the control error

    Returns:
    dict: Updated losses dictionary with new loss values appended
    """
    f_i = f_guess + f_adjust

    # Compute losses
    ocean_loss_i = helper.compute_J(ocean_states_observed, ocean_states_simulated) # J_{ocean} = J
    atmos_loss_i = helper.compute_Jt(f_true, f_i)                                  # J_{atm} = misfit of atm
    
    mahal = f_adjust.T @ np.linalg.inv(C_error) @ f_adjust # C_error should be the covariance of our adjustment
    mahal_loss_i = np.linalg.norm(mahal)

    #Store losses
    losses["ocean_misfit"].append(ocean_loss_i)
    losses["atmosphere_misfit"].append(atmos_loss_i)
    losses["mahalanobis(covariance similarity)"].append(mahal_loss_i)

    return losses

Now, we're ready to implement gradient descent.

In [None]:
def gradient_descent_simple(M, F, f_true, f_guess, C_known, C_error,         # World parameters
                              x0, num_timesteps,                                    # Simulation parameters
                              ocean_states_observed, num_iters, step_size,      # Optimization parameters
                              disp=False):
    """
    Perform gradient descent to optimize the atmospheric forcing field. 
    
    The step we take at each iteration is computed using -step_size*dJ/df.

    Args:
        M: Model matrix
        F: Scalar constant for forcing
        f_true: True atmospheric forcing field
        f_guess: Initial guess for the atmospheric forcing field
        C_known: Covariance matrix for the known portion of the control
        C_error: Covariance matrix for the control error
        x0: Initial ocean state
        num_timesteps: Number of timesteps
        ocean_states_observed: Observed ocean states
        num_iters: Number of iterations
        step_size: Step size for gradient descent
        disp: Flag to print information

    Returns:
        f_adjust: Final adjustment to the atmospheric forcing field (the optimized field is f_guess + f_adjust)
        losses: Dictionary of losses
            ocean_misfit: Ocean loss at each iteration
            atmosphere_misfit: Atmospheric loss at each iteration
            mahalanobis(covariance similarity): Mahalanobis distance for the control adjustment at each iteration
    """

    f_adjust = np.zeros((f_guess.shape[0], 1))

    losses = copy.deepcopy(losses_template)  # Fresh copy, so separate runs don't share loss lists

    for i in range(num_iters):
        if i%10==0 and disp:
            print("Iteration", i)

        #Compute results of previous update rule

        f_i = f_guess + f_adjust #f_i = f_0 + a_i
        _, ocean_states_simulated = helper.compute_affine_time_evolution_simple(x0, M, F*f_i, num_timesteps)

        # Compute and store losses
        losses = update_losses(losses, ocean_states_observed, ocean_states_simulated, f_guess, f_adjust, f_true, C_error)

        # Apply update rule (gradient descent step)
        s = helper.compute_dJ_df(M, F, ocean_states_observed, ocean_states_simulated)
        u_i = -step_size * s #Update rule

        f_adjust = f_adjust + u_i

    return f_adjust, losses

    

### Gradient Descent Template

This is the most basic form of our update rule for $u_i$. However, we want to explore other possible update rules. So, we'll create a boilerplate version of gradient descent, that allows us to choose our update rule, without re-writing the rest of our code:

- First, we'll need to take care of some book-keeping: debug variables that we may find useful.

In [None]:
possible_debug_vars = { #Debug variables to compute at each iteration
    "Norm of s_i": [],
    "Expected Delta J w simple gd": [],
    "Expected Delta J w update rule": [],
    "Norm of simple gd ui": [],
    "Norm of update rule ui": [],
    "Normalized dot product $a_i$ and $u_i$": [],
    "$s_i^T Cs_i$": [],
    "$a_i$": [],
    "Normalized $a_i^T s_i$": [],
    r"Normalized $-Cs_i \cdot a_i$": [],
    "Norm of $a_i$": [],
    "Actual Delta J": []
}

def update_debug_vars(debug_vars, x0, M, F, f_guess, f_adjust, 
                      C_known, C_error, 
                      s, step_size, ui, 
                      ocean_states_simulated, ocean_states_observed):
    """
    Updates the debug variables dictionary with various metrics for gradient descent analysis.

    Args:
    debug_vars (dict): Dictionary containing lists of debug variable values
    x0 (np.ndarray): Initial ocean state
    M (np.ndarray): Model matrix
    F (float): Forcing coefficient
    f_guess (np.ndarray): Initial guess for the atmospheric forcing field
    f_adjust (np.ndarray): Adjustment to the atmospheric forcing field
    C_known (np.ndarray): Covariance matrix for the known portion of the control
    C_error (np.ndarray): Covariance matrix for the control error
    s (np.ndarray): Gradient of the loss with respect to the forcing field
    step_size (float): Step size for gradient descent
    ui (np.ndarray): Update vector for the current iteration
    ocean_states_simulated (list): List of simulated ocean states
    ocean_states_observed (list): List of observed ocean states

    Returns:
    dict: Updated debug_vars dictionary with new values appended to each metric
    """
    #Initialize useful variables
    num_timesteps = len(ocean_states_observed)
    ui_simple_gd = -step_size * s
    delta = s.T @(ui_simple_gd)

    #Compute debug vars
    norm_s = np.linalg.norm(s)
    exp_delta_J_simple_gd = (s.T @ ui_simple_gd)[0,0]
    exp_delta_J_update_rule = (s.T @ ui)[0,0]
    norm_simple_ui = np.linalg.norm(ui_simple_gd)
    norm_ui = np.linalg.norm(ui)
    norm_dot_product = (f_adjust.T @ ui)[0,0] / (np.linalg.norm(f_adjust) * np.linalg.norm(ui))
    sTCs = (s.T @ C_error @ s)[0,0]
    ai = f_adjust
    normalized_aiTs = (f_adjust.T @ s)[0,0] / (np.linalg.norm(f_adjust) * np.linalg.norm(s))
    normalized_Csdotai = (f_adjust.T @ (C_error @ s))[0,0] / (np.linalg.norm(f_adjust) * np.linalg.norm(C_error @ s))
    norm_ai = np.linalg.norm(f_adjust)

    #Store debug vars
    debug_vars["Norm of s_i"].append(norm_s)
    debug_vars["Expected Delta J w simple gd"].append(exp_delta_J_simple_gd)
    debug_vars["Expected Delta J w update rule"].append(exp_delta_J_update_rule)
    debug_vars["Norm of simple gd ui"].append(norm_simple_ui)
    debug_vars["Norm of update rule ui"].append(norm_ui)
    debug_vars["Normalized dot product $a_i$ and $u_i$"].append(norm_dot_product)
    debug_vars["$s_i^T Cs_i$"].append(sTCs)
    debug_vars["$a_i$"].append(ai)
    debug_vars["Normalized $a_i^T s_i$"].append(normalized_aiTs)
    debug_vars[r"Normalized $-Cs_i \cdot a_i$"].append(-normalized_Csdotai)
    debug_vars["Norm of $a_i$"].append(norm_ai)

    #Handle last debug var: Actual Delta J
    f_new = f_guess + f_adjust + ui
    _, new_ocean_states_simulated = helper.compute_affine_time_evolution_simple(x0, M, F*f_new, num_timesteps)

    new_J = helper.compute_J(ocean_states_observed, new_ocean_states_simulated)
    old_J = helper.compute_J(ocean_states_observed, ocean_states_simulated)

    actual_delta_J = new_J - old_J
    
    debug_vars["Actual Delta J"].append(actual_delta_J)
    
    return debug_vars
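
The "Expected Delta J" entries above are first-order predictions, $s^\top u_i$, while "Actual Delta J" re-runs the simulation. As a rough illustration of why the two can differ, here is a sketch on a toy quadratic loss (all names with a `_toy` suffix are hypothetical and unrelated to `helper`). For a quadratic loss $J(f) = \|Af - y\|^2$, the gap is exactly the second-order term $u^\top A^\top A u$.

```python
import numpy as np

rng = np.random.default_rng(1)
A_toy = rng.normal(size=(6, 3))      # hypothetical linear model
y_toy = rng.normal(size=(6, 1))      # hypothetical observations

def J_toy(f):
    """Quadratic misfit ||A f - y||^2."""
    r = A_toy @ f - y_toy
    return float((r.T @ r)[0, 0])

f_toy = np.zeros((3, 1))
s_toy = 2 * A_toy.T @ (A_toy @ f_toy - y_toy)   # gradient dJ/df at f_toy
u_toy = -1e-3 * s_toy                           # simple gradient step

expected_delta_J = float((s_toy.T @ u_toy)[0, 0])       # first-order prediction
actual_delta_J = J_toy(f_toy + u_toy) - J_toy(f_toy)    # change after taking the step
second_order = float((u_toy.T @ A_toy.T @ A_toy @ u_toy)[0, 0])
```

Here `actual_delta_J` equals `expected_delta_J + second_order`, so the first-order estimate is accurate only while the step is small.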

Now, we can actually implement our template.

In [None]:
def gradient_descent_template(M, F, f_true, f_guess, C_known, C_error,         # World parameters
                              x0, num_timesteps,                                    # Simulation parameters
                              ocean_states_observed, num_iters, step_size,      # Optimization parameters
                              update_rule, update_params, disp=False):
    """
    Perform gradient descent to optimize the atmospheric forcing field. 
    
    The step we take at each iteration is computed using a modified update rule, and extra parameters as necessary.

    Args:
        M: Model matrix
        F: Scalar constant for forcing
        f_true: True atmospheric forcing field
        f_guess: Initial guess for the atmospheric forcing field
        x0: Initial ocean state
        num_timesteps: Number of timesteps
        ocean_states_observed: Observed ocean states
        num_iters: Number of iterations
        step_size: Step size for gradient descent
        C_known: Covariance matrix for the known portion of the control
        C_error: Covariance matrix for the control error
        update_rule: Function to compute the update to take at each iteration
        update_params: Extra parameters to pass to the update rule
        disp: Flag to print information

    Returns:
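        f_adjust: Final adjustment to the atmospheric forcing field (the optimized field is f_guess + f_adjust)
        losses: Dictionary of losses
            ocean_misfit: Ocean loss at each iteration
            atmosphere_misfit: Atmospheric loss at each iteration
            mahalanobis(covariance similarity): Mahalanobis distance for the control adjustment at each iteration
        debug_vars: Dictionary of debug variables at each iteration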
    """
    size = f_guess.shape[0]
    f_adjust = np.zeros((size,1))

    losses = copy.deepcopy(losses_template)
    debug_vars = copy.deepcopy(possible_debug_vars)

    for i in range(num_iters):
        if i%10==0 and disp:
            print("Iteration", i)
        
        #Compute results of previous update rule
        f_i = f_guess + f_adjust #f_i = f_0 + a_i
        _, ocean_states_simulated = helper.compute_affine_time_evolution_simple(x0, M, F*f_i, num_timesteps)
        

        # Compute and store losses
        losses = update_losses(losses, ocean_states_observed, ocean_states_simulated, f_guess, f_adjust, f_true, C_error)
        
        # Compute the gradient and this iteration's update
        s = helper.compute_dJ_df(M, F, ocean_states_observed, ocean_states_simulated)
        ui = update_rule(i, s, step_size, f_adjust, *update_params) #Update rule

        # Compute and store debug variables
        debug_vars = update_debug_vars(debug_vars, x0, M, F, f_guess, f_adjust, 
                                       C_known, C_error, 
                                       s, step_size, ui, 
                                       ocean_states_simulated, ocean_states_observed)

        # Apply update rule to f_adjust

        f_adjust = f_adjust + ui

        

    return f_adjust, losses, debug_vars

If we want to reproduce simple gradient descent using this template:

In [None]:
def simple_gradient_update_rule(curr_iter, s, step_size, f_adjust):
    """
    Computes the update step for simple gradient descent.

    Args:
    curr_iter (int): Current iteration number (unused in this function)
    s (np.ndarray): Gradient of the loss with respect to the forcing field
    step_size (float): Step size for gradient descent
    f_adjust (np.ndarray): Current adjustment to the forcing field (unused in this function)

    Returns:
    np.ndarray: Update step for the forcing field adjustment
    """
    return -step_size * s  # Just use the gradient of the loss

def simple_gradient_descent(M, F, f_true, f_guess, C_known, C_error,       # World parameters
                            x0, timesteps,                                   # Simulation parameters
                            ocean_states_observed, num_iters, step_size,      # Optimization parameters
                            disp=False):                                      # Optimization method
    """
    Implements simple gradient descent for optimizing the atmospheric forcing field.

    Args:
    M (np.ndarray): Model matrix
    F (float): Forcing coefficient
    f_true (np.ndarray): True atmospheric forcing field
    f_guess (np.ndarray): Initial guess for the atmospheric forcing field
    C_known (np.ndarray): Covariance matrix for the known portion of the control
    C_error (np.ndarray): Covariance matrix for the control error
    x0 (np.ndarray): Initial ocean state
    timesteps (int): Number of timesteps for simulation
    ocean_states_observed (list): List of observed ocean states
    num_iters (int): Number of iterations for gradient descent
    step_size (float): Step size for gradient descent
    disp (bool, optional): If True, display progress. Defaults to False.

    Returns:
    tuple: (f_adjust, losses, debug_vars)
        f_adjust (np.ndarray): Final adjustment to the atmospheric forcing field
        losses (dict): Dictionary of loss values over iterations
        debug_vars (dict): Dictionary of debug variables over iterations
    """
    return gradient_descent_template(M, F, f_true, f_guess, C_known, C_error, 
                                     x0, timesteps, 
                                     ocean_states_observed, num_iters, step_size,
                                     simple_gradient_update_rule, [], disp)

# Improving our Adjustments Using Covariance Constraints

### Our Motivation

Our current approach works under the assumption that, if our atmosphere adjustment creates a more realistic ocean simulation, then it creates a more realistic atmosphere.

- But it's possible for the atmosphere to induce the expected ocean currents, while being unrealistic.
- In fact, there may be many ways to create an atmosphere that gives the desired effect on ocean state.

### Our Solution

So, we want to enforce another layer of realism on our adjustment: in addition to improving our ocean simulation, we also want the adjustment to have similar *structure* to what we expect from the atmosphere.

- We'll do this by modifying our adjustment, so that the **covariance** is similar to what we expect.
- If our adjustment has a similar covariance to the overall atmosphere, we hope it'll be a more realistic adjustment. 

We want to experiment with four different ways we can implement this covariance enforcement:

## 1. Cholesky Approach

1. Try to force our update $u_i$ to have covariance $C$, using the Cholesky decomposition.
      - This is the way we generated our random vector with covariance $C$, previously.

$$ LL^\top = C \qquad \qquad z \sim \mathcal{N}(0,I) \qquad \qquad Lz \sim \mathcal{N}(0,C)$$

So, our result would be the new update rule $u_i$.

$$ z = -\eta\frac{dJ}{df} \qquad \qquad \qquad u_i = Lz $$

The problem with this approach is that the Cholesky construction assumes $z$ is a random vector with identity covariance. In our case, we apply $L$ to $\frac{dJ}{df}$, which is not such a random vector. It may have a similar smoothing effect, but it lacks the same theoretical grounding.
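
To see the random-vector case in action, the sketch below draws many samples of $z \sim \mathcal{N}(0, I)$ and checks that the sample covariance of $Lz$ is close to $C$. The $3 \times 3$ matrix `C_toy` is a hypothetical example, not the covariance used in this notebook.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical 3x3 covariance matrix (symmetric positive definite)
C_toy = np.array([[2.0, 0.8, 0.2],
                  [0.8, 1.0, 0.3],
                  [0.2, 0.3, 0.5]])

L_toy = np.linalg.cholesky(C_toy)          # lower-triangular, L @ L.T == C
z = rng.standard_normal((3, 200_000))      # each column is one sample of z ~ N(0, I)
samples = L_toy @ z                        # each column is one sample of Lz

sample_cov = np.cov(samples)               # rows are variables, columns are samples
max_err = np.max(np.abs(sample_cov - C_toy))
```

With this many samples, `max_err` should be small. The guarantee is statistical: it describes the distribution of $Lz$ over many draws, and says nothing about a single deterministic vector such as the gradient.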



## 2. Include Covariance Constraint in $J$

2. Add a covariance constraint on $a_i$ (a penalty on its Mahalanobis distance) to our loss function, and take the gradient of the combined loss.

- We'll include a weight $\alpha$ to determine the relative importance of the second term.

$$J'(x, f) = J(x) +
            \alpha \Big( a_i^\top C^{-1} a_i \Big) $$

This approach should, hypothetically, encourage our total control adjustment $a_i$ to have covariance $C$.

If we use this $J'$ and then use gradient descent, we get:

$$u_i = -\eta \cdot \Bigg( \frac{dJ}{df} + 2\alpha C^{-1} a_i  \Bigg)$$
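
One way this update could be written as an update rule, following the `update_rule(curr_iter, s, step_size, f_adjust, *update_params)` calling convention of the template above, is sketched below. The function name `penalized_update_rule` and the parameter `alpha` are assumptions for illustration, and the linear solve stands in for multiplying by $C^{-1}$.

```python
import numpy as np

def penalized_update_rule(curr_iter, s, step_size, f_adjust, C, alpha):
    """
    Hypothetical update rule for the penalized loss J' = J + alpha * a^T C^{-1} a.

    s is dJ/df for the unpenalized loss; the penalty adds 2 * alpha * C^{-1} a.
    """
    penalty_grad = 2 * alpha * np.linalg.solve(C, f_adjust)
    return -step_size * (s + penalty_grad)

# Toy check: with C = I and alpha = 0.5, the penalty gradient is just a
s_toy = np.array([[1.0], [-2.0]])
a_toy = np.array([[0.5], [0.5]])
u_toy = penalized_update_rule(0, s_toy, 0.1, a_toy, np.eye(2), alpha=0.5)
# u_toy = -0.1 * (s_toy + a_toy) = [[-0.15], [0.15]]
```

Under these assumptions, it could be passed to the template with `update_params = [C_error, alpha]`.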

## 3. Dan's Method

Compute $dJ/df$ first, *then* constrain your update rule $u_i$ to do what we want:

- $u_i$ should reduce $J$, in order to make a better fit to ocean data. We'll say that $J$ needs to change by $\delta$ (with $\delta < 0$ for a decrease). To first order, we can approximate this as:

$$ \delta \approx s^\top u_i$$

- We also want to keep the covariance of the update rule closer to $C$: we'll penalize the Mahalanobis distance.

$$u_i^\top C^{-1} u_i $$

Together, these can be encoded by the Lagrangian:

$$ \mathcal{L} = \lambda (\delta - s^\top u_i) + u_i^\top C^{-1} u_i$$

If we solve for the minimizing conditions $d\mathcal{L}/du_i=d\mathcal{L}/d\lambda = 0$, we find a solution:

$$ u_i = \delta \Big( \frac{Cs}{s^\top C s} \Big)$$

This is the proposal by Dan Amrhein, in his informal paper, "Ocean state estimation with atmospheric adjustments constrained by prior covariance and observations".
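
A possible implementation of this update, again following the template's calling convention, is sketched below. The function name `dan_update_rule` is hypothetical. The choice of $\delta$ is also an assumption: here it is set to the first-order change in $J$ that a simple gradient step would produce, $\delta = s^\top(-\eta s)$.

```python
import numpy as np

def dan_update_rule(curr_iter, s, step_size, f_adjust, C):
    """
    Hypothetical update rule u_i = delta * C s / (s^T C s).

    delta is assumed to be the first-order change in J from a simple
    gradient step, delta = s^T (-step_size * s), which is negative.
    """
    delta = float((s.T @ (-step_size * s))[0, 0])
    Cs = C @ s
    return delta * Cs / float((s.T @ Cs)[0, 0])

# Toy check on a hypothetical symmetric positive definite covariance
rng = np.random.default_rng(0)
B_toy = rng.normal(size=(4, 4))
C_toy = B_toy @ B_toy.T + 4 * np.eye(4)
s_toy = rng.normal(size=(4, 1))

u_toy = dan_update_rule(0, s_toy, 0.1, np.zeros((4, 1)), C_toy)
delta_toy = -0.1 * float((s_toy.T @ s_toy)[0, 0])
predicted_change = float((s_toy.T @ u_toy)[0, 0])   # should equal delta_toy
```

By construction, $s^\top u_i = \delta$ exactly, so this update has the same first-order effect on $J$ as the simple gradient step, but points along $Cs$ instead of $s$.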

## 4. Dan's Method Modified

One problem: originally, the goal was to pressure our total control adjustment $a_i$ to have covariance $C$. But Dan's method, when applied to gradient descent, instead pressures each update $u_i$ to have covariance $C$.

- If we keep adding updates of covariance $C$ to our control adjustment, the covariance will keep increasing.

So, we'll modify Dan's approach:

- Compute $dJ/df$ first, then constrain your **total control adjustment** $u_i + a_i$ to have covariance $C$, while making sure your update $u_i$ reduces $J$.

$$ \mathcal{L} = \lambda (\delta - s^\top u_i) + (a_i+u_i)^\top C^{-1} (a_i+u_i)$$

If we solve for $d\mathcal{L}/du_i=d\mathcal{L}/d\lambda = 0$, we find a solution:

$$u_i = -a_i + \Big(\delta + a_i^\top s\Big) \Big( \frac{ Cs }{s^\top C s} \Big)$$
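
The modified update can be sketched in the same style. As before, the function name `dan_modified_update_rule` and the choice $\delta = s^\top(-\eta s)$ are assumptions for illustration.

```python
import numpy as np

def dan_modified_update_rule(curr_iter, s, step_size, f_adjust, C):
    """
    Hypothetical update rule u_i = -a_i + (delta + a_i^T s) * C s / (s^T C s).

    delta is assumed to be the first-order change in J from a simple
    gradient step, delta = s^T (-step_size * s).
    """
    delta = float((s.T @ (-step_size * s))[0, 0])
    Cs = C @ s
    scale = (delta + float((f_adjust.T @ s)[0, 0])) / float((s.T @ Cs)[0, 0])
    return -f_adjust + scale * Cs

# Toy check on a hypothetical symmetric positive definite covariance
rng = np.random.default_rng(0)
B_toy = rng.normal(size=(4, 4))
C_toy = B_toy @ B_toy.T + 4 * np.eye(4)
s_toy = rng.normal(size=(4, 1))
a_toy = rng.normal(size=(4, 1))

u_toy = dan_modified_update_rule(0, s_toy, 0.1, a_toy, C_toy)
delta_toy = -0.1 * float((s_toy.T @ s_toy)[0, 0])
predicted_change = float((s_toy.T @ u_toy)[0, 0])   # should equal delta_toy
new_total = a_toy + u_toy                           # a multiple of C s
```

One consequence of the formula is visible in `new_total`: the new total adjustment $a_i + u_i$ is always a multiple of $Cs$, so earlier adjustments only influence it through the scalar $a_i^\top s$.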

# Implementing and testing covariance-enforcement methods

## 1. Cholesky Approach

First, we'll take the simplest approach: compute the derivative, and then apply the covariance adjustment $z \to Lz$, as if we had a random variable.

In [None]:
def cholesky_update_rule(curr_iter, s, step_size, f_adjust, C_error):
    """
    Computes the update step using Cholesky decomposition of the error covariance matrix.

    Args:
    curr_iter (int): Current iteration number (unused in this function)
    s (np.ndarray): Gradient of the loss with respect to the forcing field
    step_size (float): Step size for gradient descent
    f_adjust (np.ndarray): Current adjustment to the forcing field (unused in this function)
    C_error (np.ndarray): Covariance matrix for the control error

    Returns:
    np.ndarray: Update step for the forcing field adjustment

    Notes:
    Applies the Cholesky decomposition L of C_error to s, then rescales the result 
    to match the original gradient's magnitude.
    """
    # Apply cholesky decomposition L : C_error = L @ L.T
    L = np.linalg.cholesky(C_error)
    cholesky_s = L @ s

    # Rescale so magnitude is the same
    rescaled_cholesky_s = cholesky_s * (np.linalg.norm(s) / np.linalg.norm(cholesky_s))
    step = - step_size * rescaled_cholesky_s

    return step

def cholesky_gradient_descent(M, F, f_true, f_guess, C_known, C_error,       # World parameters
                              x0, timesteps,                                 # Simulation parameters
                              ocean_states_observed, num_iters, step_size,   # Optimization parameters
                              disp=False):                                   # Optimization method
    """
    Implements gradient descent using Cholesky decomposition for optimizing the atmospheric forcing field.

    Args:
    M (np.ndarray): Model matrix
    F (float): Forcing coefficient
    f_true (np.ndarray): True atmospheric forcing field
    f_guess (np.ndarray): Initial guess for the atmospheric forcing field
    C_known (np.ndarray): Covariance matrix for the known portion of the control
    C_error (np.ndarray): Covariance matrix for the control error
    x0 (np.ndarray): Initial ocean state
    timesteps (int): Number of timesteps for simulation
    ocean_states_observed (list): List of observed ocean states
    num_iters (int): Number of iterations for gradient descent
    step_size (float): Step size for gradient descent
    disp (bool, optional): If True, display progress. Defaults to False.

    Returns:
    tuple: (f_adjust, losses, debug_vars)
        f_adjust (np.ndarray): Final adjustment to the atmospheric forcing field
        losses (dict): Dictionary of loss values over iterations
        debug_vars (dict): Dictionary of debug variables over iterations
    """
    return gradient_descent_template(M, F, f_true, f_guess, C_known, C_error,
                                     x0, timesteps, 
                                     ocean_states_observed, num_iters, step_size,
                                     cholesky_update_rule, [C_error], disp)


In [None]:

if run_all_code:

    f_adjust_cholesky, losses_cholesky, debug_vars_cholesky = cholesky_gradient_descent(M, F, f_true, f_guess, C_known, C_error,
                                                                                        x0, num_timesteps, 
                                                                                        observed_state_over_time, num_iters, step_size, 
                                                                                        disp=True)


    # Plot the true field, the first guess, and the improved field
    f_optimized_cholesky = f_guess + f_adjust_cholesky

    many_states = [f_true, f_guess, f_optimized_cholesky]
    titles = ['True field', 'First-guess field', 'Cholesky field']
    big_title = 'Atmospheric Field: Gradient Descent with Cholesky Decomposition'

    helper.plot_multi_heatmap(many_states, nr, nc, titles, big_title, vmin=None, vmax=None)

Now, we want to be able to directly compare the two approaches.
- In order to make an apples-to-apples comparison, both gradient descent methods have to be using the same initialization, model, etc.

In [None]:
def compare_gd_methods_once(M, F, f_true, f_guess, C_known, C_error, 
                            x0, timesteps, 
                            ocean_states_observed, num_iters, step_size, 
                            methods, disp=False):
    """
    Runs multiple methods of gradient descent on the same dataset and compares their performance.

    Args:
    M (np.ndarray): Model matrix
    F (float): Forcing coefficient
    f_true (np.ndarray): True atmospheric forcing field
    f_guess (np.ndarray): Initial guess for the atmospheric forcing field
    C_known (np.ndarray): Covariance matrix for the known portion of the control
    C_error (np.ndarray): Covariance matrix for the control error
    x0 (np.ndarray): Initial ocean state
    timesteps (int): Number of timesteps for simulation
    ocean_states_observed (list): List of observed ocean states
    num_iters (int): Number of iterations for gradient descent
    step_size (float): Step size for gradient descent
    methods (list): List of gradient descent methods to compare. Each method is a list of the form
                    ["Method Name", method_func, extra_params]
    disp (bool, optional): If True, display progress. Defaults to False.

    Returns:
    dict: A dictionary where keys are method names and values are tuples containing:
          - f_adjust (np.ndarray): Final adjustment to the atmospheric forcing field
          - losses (dict): Dictionary of loss values over iterations
          - debug_vars (dict): Dictionary of debug variables over iterations

    Notes:
    This function applies each specified gradient descent method to the same initial conditions
    and dataset, allowing for direct comparison of their performance.
    """
    results = {}

    for method_name, method_func, extra_params in methods:
        if disp:
            print(f"Running method {method_name}")

        f_adjust, losses, debug_vars = gradient_descent_template(M, F, f_true, f_guess, C_known, C_error, 
                                                     x0, timesteps, 
                                                     ocean_states_observed, num_iters, step_size,
                                                     method_func, extra_params, disp)
        results[method_name] = (f_adjust, losses, debug_vars)

    return results


    

In [None]:
if run_all_code:

    methods = [
        ["Simple Gradient Descent", simple_gradient_update_rule, []],
        ["Cholesky Gradient Descent", cholesky_update_rule, [C_error]],
    ]

    # Run multiple methods of gradient descent
    results = compare_gd_methods_once(M, F, f_true, f_guess, C_known, C_error, 
                                    x0, num_timesteps, 
                                    observed_state_over_time, num_iters, step_size, 
                                    methods, disp=True)

    # Plot the results 
    many_states = [f_true] + [f_guess + f_adjust for f_adjust, _, _ in results.values()]
    titles = ['True field'] + [f"{method[0]} field" for method in methods]
    big_title = 'Atmospheric Field: Gradient Descent with Different Methods'

    helper.plot_multi_heatmap(many_states, nr, nc, titles, big_title, vmin=None, vmax=None)


A single run of each method may not be representative. So, we'll create a function that runs them multiple times:

In [None]:
def compare_gd_methods_many_times(nr, nc, dt, F, gamma, sigma, num_obs_per_timestep, 
                                  num_timesteps, num_iters, step_size, 
                                  C_known, C_error, methods, num_runs, disp=False):
    """
    Create many different sets of data. 
    For each one, we will run each of our gradient descent methods.
    Once we finish, we average losses across all runs.

    Args:
    nr (int): Number of rows in the grid
    nc (int): Number of columns in the grid
    dt (float): Time step
    F (float): Forcing parameter
    gamma (float): Proportion of the control vector that is correct
    sigma (float): Standard deviation of observation noise
    num_obs_per_timestep (int): Number of observations per timestep
    num_timesteps (int): Number of timesteps
    num_iters (int): Number of iterations of gradient descent
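    step_size (float): Step size for gradient descent
    C_known (np.ndarray): Covariance of the known portion of the control (recomputed from gamma in each run)
    C_error (np.ndarray): Covariance of the control error (recomputed from gamma in each run)
    methods (list): List of gradient descent methods to compare
    num_runs (int): Number of times to run the whole optimization process
    disp (bool): If True, print progress
    
    Returns:
    tuple: (losses, debug_vars), each a dict keyed by method name, holding values averaged across runs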
    """

    # Initialize results dictionary
    losses =     {method[0]: copy.deepcopy(losses_template) 
                for method in methods}
    debug_vars = {method[0]: copy.deepcopy(possible_debug_vars) 
                  for method in methods}

    for run in range(num_runs):
        if disp:
            print(f"Run {run + 1}/{num_runs}")

        # Generate world
        C_control, x0, _, M = helper.generate_world(nr, nc, dt, F)
        C_known, C_error = C_control * gamma, C_control * (1-gamma)
        
        f_true, f_guess = helper.generate_true_and_first_guess_field(C_known, C_error, nr, nc)

        # Run the simulation with the true and guessed control vector
        _, real_state_over_time  = helper.compute_affine_time_evolution_simple(x0, M, F*f_true,  num_timesteps)

        observed_state_over_time_2d = helper.observe_over_time(real_state_over_time, sigma, 
                                                               num_obs_per_timestep, nr, nc)

        observed_state_over_time = [np.reshape(observed_state_2d, (nr*nc, 1)) 
                                    for observed_state_2d in observed_state_over_time_2d]

        # Run each method
        for method_name, method_func, extra_params in methods:
            if disp:
                print(f"  Method: {method_name}")

            results = compare_gd_methods_once(M, F, f_true, f_guess, C_known, C_error,
                                                x0, num_timesteps, 
                                                observed_state_over_time, num_iters, step_size, 
                                                [[method_name, method_func, extra_params]], disp)
            
            # Include new losses
            for method_name, (_, method_losses, method_debug_vars) in results.items():
                for loss_name, loss_list in losses[method_name].items():
                    loss_list.append(method_losses[loss_name])

                for debug_name, debug_list in debug_vars[method_name].items():
                    debug_list.append(method_debug_vars[debug_name])

    # Average losses

    for method_name, method_losses in losses.items():
        for loss_name, loss_list in method_losses.items():
            losses[method_name][loss_name] = np.mean(loss_list, axis=0)

    for method_name, method_debug_vars in debug_vars.items():
        for debug_name, debug_list in method_debug_vars.items():
            debug_vars[method_name][debug_name] = np.mean(debug_list, axis=0)

    return losses, debug_vars

                


In [None]:
if run_all_code:

    methods = [
        ["Simple GD", simple_gradient_update_rule, []],
        ["Cholesky GD", cholesky_update_rule, [C_error]],
    ]

    losses, debug_vars = compare_gd_methods_many_times(nr, nc, dt, F, gamma, sigma, num_obs_per_timestep, 
                                    num_timesteps, num_iters, step_size, 
                                    C_known, C_error, methods, num_runs)


Some code to allow us to directly compare the losses for each method:

In [None]:
def plot_losses(losses_many, num_obs_per_timestep, step_size, num_timesteps, num_iters, min_iter=None, max_iter=None):
    """
    Plots the losses for multiple gradient descent methods over iterations.

    Args:
    losses_many (dict): Dictionary of losses for each method. 
                        Keys are method names, values are dictionaries containing losses.
    num_obs_per_timestep (int): Number of observations per timestep
    step_size (float): Step size used in gradient descent
    num_timesteps (int): Number of timesteps in the simulation
    num_iters (int): Number of iterations of gradient descent
    min_iter (int): Lowest plotted iter (default: None, plots from the beginning)
    max_iter (int): Highest plotted iter (default: None, plots until the end)

    Returns:
    None: This function displays the plot using matplotlib.pyplot.show()

    Notes:
    Creates a 2x2 grid of plots:
    1. Ocean misfit
    2. Atmosphere loss
    3. Control adjust Mahalanobis distance
    4. J' (combined loss for covariance constraint method, ocean loss for others)

    Each plot shows the evolution of the respective loss over iterations for all methods.
    """
    fig, axs = plt.subplots(2, 2, figsize=(10, 10))

    # Raw strings keep the LaTeX backslashes from being read as Python escape sequences
    loss_funcs = [r"$\sum_t (Ex(t)-y(t))^{T} (Ex(t)-y(t))$", 
                  r"$\sum_t (f_i(t)-f_{true}(t) )^{T} ( f_i(t)-f_{true}(t) )$", 
                  r"$a_i^T C^{-1} a_i$"]

    # Determine the range of iterations to plot
    min_iter = 0 if min_iter is None else max(0, min_iter)
    max_iter = num_iters if max_iter is None else min(num_iters, max_iter)
    plot_range = slice(min_iter, max_iter)

    for i, (loss_name, ax, func) in enumerate(zip(["ocean_misfit", "atmosphere_misfit", "mahalanobis(covariance similarity)"], axs.flatten(), loss_funcs)):
        for method_name, losses_dict in losses_many.items():
            ax.plot(range(min_iter, max_iter), losses_dict[loss_name][plot_range], label=method_name)
        ax.set_xlabel("Iteration $i$")
        ax.set_ylabel(loss_name+" loss:    "+func)
        ax.legend()
        ax.set_title(f"{loss_name}: "+func)
        
        # Set integer ticks on x-axis
        ax.xaxis.set_major_locator(plt.MaxNLocator(integer=True))

    # Fourth plot: ocean_misfit + mahalanobis if using covariance control adjust, just ocean otherwise
    ax = axs[1, 1]
    for method_name, losses_dict in losses_many.items():
        if method_name == r"Covariance Constraint J Gradient Descent":
            combined_loss = [o + c for o, c in zip(losses_dict["ocean_misfit"], losses_dict["mahalanobis(covariance similarity)"])]
            ax.plot(range(min_iter, max_iter), combined_loss[plot_range], label=method_name)
        else:
            ax.plot(range(min_iter, max_iter), losses_dict["ocean_misfit"][plot_range], label=method_name)

    ax.set_xlabel("Iteration $i$")
    ax.set_ylabel("J'")
    ax.legend()
    ax.set_title("J'")
    
    # Set integer ticks on x-axis for the fourth plot
    ax.xaxis.set_major_locator(plt.MaxNLocator(integer=True))

    fig.suptitle(f"Gradient Descent Variants: num_obs={num_obs_per_timestep}, step_size={step_size}, num_timesteps={num_timesteps}, num_iters={num_iters}")

    plt.tight_layout()
    plt.show()

In [None]:
if run_all_code:
    plot_losses(losses, num_obs_per_timestep, step_size, num_timesteps, num_iters)

The Cholesky variant looks like a dead end: it doesn't increase the covariance (Mahalanobis) loss, but it performs worse by every other measure.

Now that we have the tools we need to directly compare our plots, we can proceed with our other methods.

## 2. Include Covariance Constraint in $J$

Here, we use our augmented loss function $J'$, which appends a weighted covariance term to $J$. Its gradient gives the update:

$$u_i = -\eta \cdot \Bigg( \frac{dJ}{df} + \alpha 2C^{-1} a  \Bigg)$$

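Our procedure is almost identical to simple gradient descent: we just add the covariance term to the gradient. The implementation below also rescales the combined gradient so that its norm matches that of the original gradient.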
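
As a quick sanity check on the added term, the sketch below verifies numerically that $2C^{-1}a$ is the gradient of $a^\top C^{-1} a$ when $C$ is symmetric. Everything here (`C_demo`, `a_demo`, `mahalanobis_sq`) is made up for illustration and is not part of this notebook's `helper` module.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 5

# Hypothetical symmetric positive-definite covariance (not the notebook's C_error)
A = rng.standard_normal((n, n))
C_demo = A @ A.T + n * np.eye(n)
a_demo = rng.standard_normal((n, 1))

def mahalanobis_sq(a, C):
    """Squared Mahalanobis norm a^T C^{-1} a, returned as a Python float."""
    return (a.T @ np.linalg.solve(C, a)).item()

# Analytic gradient: 2 C^{-1} a (valid because C is symmetric)
grad_analytic = 2 * np.linalg.solve(C_demo, a_demo)

# Central finite differences, one coordinate at a time
eps = 1e-6
grad_numeric = np.zeros_like(a_demo)
for k in range(n):
    e = np.zeros_like(a_demo)
    e[k, 0] = eps
    grad_numeric[k, 0] = (mahalanobis_sq(a_demo + e, C_demo)
                          - mahalanobis_sq(a_demo - e, C_demo)) / (2 * eps)

max_abs_error = float(np.max(np.abs(grad_analytic - grad_numeric)))
print(max_abs_error)
```

The sketch uses `np.linalg.solve` instead of forming the inverse explicitly; both give the same product, but solving is usually the more numerically robust choice.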

In [None]:
def cov_constraint_J_update_rule(curr_iter, s, step_size, f_adjust, C_error, weight_cov_term):
    """
    Computes the update step using a covariance constraint on the loss function.

    Args:
    curr_iter (int): Current iteration number (unused in this function)
    s (np.ndarray): Gradient of the loss with respect to the forcing field
    step_size (float): Step size for gradient descent
    f_adjust (np.ndarray): Current adjustment to the forcing field
    C_error (np.ndarray): Covariance matrix for the control error
    weight_cov_term (float): Weight for the covariance constraint term

    Returns:
    np.ndarray: Update step for the forcing field adjustment

    Notes:
    Adds a weighted covariance constraint term to the original gradient,
    then rescales the result to match the original gradient's magnitude.
    """
    # Gradient of the covariance term, 2 C^{-1} a; solve() avoids forming the explicit inverse
    cov_term_grad = 2 * np.linalg.solve(C_error, f_adjust)

    s_prime = s + weight_cov_term * cov_term_grad  # Gradient of J' with respect to the forcing field

    #print(np.linalg.norm(s), np.linalg.norm(s_prime))

    norm_s_prime = s_prime * (np.linalg.norm(s) / np.linalg.norm(s_prime))  # Rescale magnitude to match original
    
    return -step_size * norm_s_prime

def cov_constraint_J_gradient_descent(M, F, f_true, f_guess, C_known, C_error,       # World parameters
                                      x0, timesteps,                                 # Simulation parameters
                                      ocean_states_observed, num_iters, step_size,   # Optimization parameters
                                      weight_cov_term, disp=False):                  # Optimization method
    """
    Implements gradient descent with a covariance constraint for optimizing the atmospheric forcing field.

    Args:
    M (np.ndarray): Model matrix
    F (float): Forcing coefficient
    f_true (np.ndarray): True atmospheric forcing field
    f_guess (np.ndarray): Initial guess for the atmospheric forcing field
    C_known (np.ndarray): Covariance matrix for the known portion of the control
    C_error (np.ndarray): Covariance matrix for the control error
    x0 (np.ndarray): Initial ocean state
    timesteps (int): Number of timesteps for simulation
    ocean_states_observed (list): List of observed ocean states
    num_iters (int): Number of iterations for gradient descent
    step_size (float): Step size for gradient descent
    weight_cov_term (float): Weight for the covariance constraint term
    disp (bool, optional): If True, display progress. Defaults to False.

    Returns:
    tuple: (f_adjust, losses, debug_vars)
        f_adjust (np.ndarray): Final adjustment to the atmospheric forcing field
        losses (dict): Dictionary of loss values over iterations
        debug_vars (dict): Dictionary of debug variables over iterations
    """
    return gradient_descent_template(M, F, f_true, f_guess, C_known, C_error,
                                     x0, timesteps, 
                                     ocean_states_observed, num_iters, step_size,
                                     cov_constraint_J_update_rule, [C_error, weight_cov_term], disp)



In [None]:
if run_all_code:
    weight_cov_term = 0.0001
    f_adjust_cov_constraint_J, losses_cov_constraint_J, debug_vars_cov_constraint_J = cov_constraint_J_gradient_descent(M, F, f_true, f_guess, C_known, C_error,
                                                                                        x0, num_timesteps, 
                                                                                        observed_state_over_time, num_iters, step_size, 
                                                                                        weight_cov_term, disp=True)

In [None]:
if run_all_code:
    f_optimized_cov_constraint_J = f_guess + f_adjust_cov_constraint_J

    many_states = [f_true, f_guess, f_optimized_cov_constraint_J]
    titles = ['True field', 'First-guess field', 'Covariance-constrained field']
    big_title = 'Atmospheric Field: Gradient Descent with Covariance Constraint'

    helper.plot_multi_heatmap(many_states, nr, nc, titles, big_title, vmin=None, vmax=None)

In [None]:
if run_all_code:
    methods = [
        ["Simple Gradient Descent", simple_gradient_update_rule, []],
        ["Cholesky Gradient Descent", cholesky_update_rule, [C_error]],
        ["Covariance Constraint J Gradient Descent", cov_constraint_J_update_rule, [C_error, weight_cov_term]],
    ]

    num_iters_local = 10
    step_size_local = step_size / 100
    losses, debug_vars = compare_gd_methods_many_times(nr, nc, dt, F, gamma, sigma, num_obs_per_timestep, 
                                    num_timesteps, num_iters_local, step_size_local, 
                                    C_known, C_error, methods, num_runs)

    plot_losses(losses, num_obs_per_timestep, step_size_local, num_timesteps, num_iters_local)

## 3. Dan's Method

We've already derived this method above:

$$ u_i = \delta \Big( \frac{Cs}{s^\top C s} \Big)$$
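
To make the formula concrete, here is a minimal sketch on a small made-up problem, assuming $C$ is symmetric positive definite. It checks two properties of the step: it produces the same first-order change in $J$ as the plain gradient step ($s^\top u_i = \delta$), and among steps with that property it has the smallest $u^\top C^{-1} u$. The names `C_demo`, `s_demo`, and `step_size_demo` are hypothetical stand-ins, not the notebook's variables.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 6

# Hypothetical SPD covariance and gradient, for illustration only
A = rng.standard_normal((n, n))
C_demo = A @ A.T + n * np.eye(n)
s_demo = rng.standard_normal((n, 1))
step_size_demo = 0.1

def mahalanobis_sq(u, C):
    """Squared Mahalanobis norm u^T C^{-1} u, returned as a Python float."""
    return (u.T @ np.linalg.solve(C, u)).item()

# Plain gradient step and the first-order change in J that it produces
u_simple = -step_size_demo * s_demo
delta = (s_demo.T @ u_simple).item()

# Dan's step: delta * C s / (s^T C s)
u_dan = delta * (C_demo @ s_demo) / (s_demo.T @ C_demo @ s_demo).item()

same_improvement = np.isclose((s_demo.T @ u_dan).item(), delta)
no_worse_than_simple = (mahalanobis_sq(u_dan, C_demo)
                        <= mahalanobis_sq(u_simple, C_demo) + 1e-12)

# Any other step with s^T u = delta can be written as u_dan + w with s^T w = 0
w = rng.standard_normal((n, 1))
w -= s_demo * (s_demo.T @ w).item() / (s_demo.T @ s_demo).item()
u_other = u_dan + w
no_worse_than_other = mahalanobis_sq(u_dan, C_demo) <= mahalanobis_sq(u_other, C_demo)

print(same_improvement, no_worse_than_simple, no_worse_than_other)
```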

In [None]:
def dan_update_rule(curr_iter, s, step_size, f_adjust, C_error):
    """
    Computes the update step using Dan's method for improving the Mahalanobis distance.

    Args:
    curr_iter (int): Current iteration number (unused in this function)
    s (np.ndarray): Gradient of the loss with respect to the forcing field
    step_size (float): Step size for gradient descent
    f_adjust (np.ndarray): Current adjustment to the forcing field (unused in this function)
    C_error (np.ndarray): Covariance matrix for the control error

    Returns:
    np.ndarray: Update step for the forcing field adjustment

    Notes:
    Modifies the gradient direction to improve the Mahalanobis distance while
    maintaining the desired improvement in the loss function J.
    """
    ui_simple_gd = -step_size * s  # Pre-dan step
    delta = s.T @ (ui_simple_gd)   # Compute desired improvement of J
    
    new_vec = (C_error @ s) / (s.T @ C_error @ s)  # Direction modified to improve Mahalanobis distance

    return delta * new_vec

def dan_gradient_descent(M, F, f_true, f_guess, C_known, C_error,       # World parameters
                         x0, timesteps,                                 # Simulation parameters
                         ocean_states_observed, num_iters, step_size,   # Optimization parameters
                         disp=False):                                   # Optimization method
    """
    Implements gradient descent using Dan's method for optimizing the atmospheric forcing field.

    Args:
    M (np.ndarray): Model matrix
    F (float): Forcing coefficient
    f_true (np.ndarray): True atmospheric forcing field
    f_guess (np.ndarray): Initial guess for the atmospheric forcing field
    C_known (np.ndarray): Covariance matrix for the known portion of the control
    C_error (np.ndarray): Covariance matrix for the control error
    x0 (np.ndarray): Initial ocean state
    timesteps (int): Number of timesteps for simulation
    ocean_states_observed (list): List of observed ocean states
    num_iters (int): Number of iterations for gradient descent
    step_size (float): Step size for gradient descent
    disp (bool, optional): If True, display progress. Defaults to False.

    Returns:
    tuple: (f_adjust, losses, debug_vars)
        f_adjust (np.ndarray): Final adjustment to the atmospheric forcing field
        losses (dict): Dictionary of loss values over iterations
        debug_vars (dict): Dictionary of debug variables over iterations
    """
    return gradient_descent_template(M, F, f_true, f_guess, C_known, C_error,
                                     x0, timesteps, 
                                     ocean_states_observed, num_iters, step_size,
                                     dan_update_rule, [C_error], disp)


In [None]:
dan_field_view = False

if run_all_code or dan_field_view:

    f_adjust_dan, losses_dan, debug_vars_dan = dan_gradient_descent(M, F, f_true, f_guess, C_known, C_error,
                                                                    x0, num_timesteps, 
                                                                    observed_state_over_time, num_iters, step_size, 
                                                                    disp=True)

    # Plot the true field, the first guess, and the improved field

    f_optimized_dan = f_guess + f_adjust_dan

    many_states = [f_true, f_guess, f_optimized_dan]
    titles = ['True field', 'First-guess field', 'Dan field']
    big_title = 'Atmospheric Field: Gradient Descent with Dan'

    helper.plot_multi_heatmap(many_states, nr, nc, titles, big_title, vmin=None, vmax=None)


In [None]:
# Compare all methods
dan_methods_view = False

if run_all_code or dan_methods_view:
        
    weight_cov_term = 0.0001

    methods = [
        #["Simple Gradient Descent", simple_gradient_update_rule, []],
        #["Cholesky Gradient Descent", cholesky_update_rule, [C_error]],
        #["Covariance Constraint J Gradient Descent", cov_constraint_J_update_rule, [C_error, weight_cov_term]],
        ["Dan Gradient Descent", dan_update_rule, [C_error]],
    ]

    num_iters_local = 250
    step_size_local = step_size
    losses, debug_vars = compare_gd_methods_many_times(nr, nc, dt, F, gamma, sigma, num_obs_per_timestep, 
                                    num_timesteps, num_iters_local, step_size_local, 
                                    C_known, C_error, methods, num_runs)

    plot_losses(losses, num_obs_per_timestep, step_size_local, num_timesteps, num_iters_local)

So far, Dan's method appears to be the most successful by a wide margin.

## 4. Dan's Method Modified

Our modified update rule is given by:

$$u_i = -a_i + \Big(\delta + a_i^\top s\Big) \Big( \frac{ Cs }{s^\top C s} \Big)$$
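
A small sketch of what this rule does algebraically, again on made-up inputs (`C_demo`, `s_demo`, `a_demo`, and `step_size_demo` are hypothetical). It checks that the step still satisfies $s^\top u_i = \delta$, and that the new total adjustment $a_i + u_i$ is a scalar multiple of $Cs$.

```python
import numpy as np

rng = np.random.default_rng(2)
n = 6

# Hypothetical SPD covariance, gradient, and current adjustment
A = rng.standard_normal((n, n))
C_demo = A @ A.T + n * np.eye(n)
s_demo = rng.standard_normal((n, 1))
a_demo = rng.standard_normal((n, 1))   # stands in for the current adjustment a_i
step_size_demo = 0.1

# Desired first-order change in J, taken from the plain gradient step
delta = (s_demo.T @ (-step_size_demo * s_demo)).item()

# Direction C s / (s^T C s) and its scale (delta + a_i^T s)
direction = (C_demo @ s_demo) / (s_demo.T @ C_demo @ s_demo).item()
scale = delta + (s_demo.T @ a_demo).item()

u_mod = -a_demo + scale * direction
a_next = a_demo + u_mod

keeps_delta = np.isclose((s_demo.T @ u_mod).item(), delta)
parallel_to_Cs = np.allclose(a_next, scale * direction)

print(keeps_delta, parallel_to_Cs)
```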

In [None]:
def dan_modified_update_rule(curr_iter, s, step_size, f_adjust, C_error):
    """
    Computes the update step using a modified version of Dan's method for improving the Mahalanobis distance.

    Args:
    curr_iter (int): Current iteration number (unused in this function)
    s (np.ndarray): Gradient of the loss with respect to the forcing field
    step_size (float): Step size for gradient descent
    f_adjust (np.ndarray): Current adjustment to the forcing field
    C_error (np.ndarray): Covariance matrix for the control error

    Returns:
    np.ndarray: Update step for the forcing field adjustment

    Notes:
    Modifies the gradient direction to improve the Mahalanobis distance while
    maintaining the desired improvement in the loss function J. This version
    intends to improve Mahalanobis distance of a_i+u_i (f_adjust+gradient update), instead of just u_i.
    """
    ui_simple_gd = -step_size * s  # Pre-dan step
    delta = s.T @ (ui_simple_gd)   # Compute desired improvement of J
    
    new_vec = (C_error @ s) / (s.T @ C_error @ s)  # Direction modified to improve Mahalanobis distance
    vec_scale = delta + s.T @ f_adjust

    return -f_adjust + vec_scale * new_vec

def dan_modified_gradient_descent(M, F, f_true, f_guess, C_known, C_error,       # World parameters
                                  x0, timesteps,                                 # Simulation parameters
                                  ocean_states_observed, num_iters, step_size,   # Optimization parameters
                                  disp=False):                                   # Optimization method
    """
    Implements gradient descent using a modified version of Dan's method for optimizing the atmospheric forcing field.

    Args:
    M (np.ndarray): Model matrix
    F (float): Forcing coefficient
    f_true (np.ndarray): True atmospheric forcing field
    f_guess (np.ndarray): Initial guess for the atmospheric forcing field
    C_known (np.ndarray): Covariance matrix for the known portion of the control
    C_error (np.ndarray): Covariance matrix for the control error
    x0 (np.ndarray): Initial ocean state
    timesteps (int): Number of timesteps for simulation
    ocean_states_observed (list): List of observed ocean states
    num_iters (int): Number of iterations for gradient descent
    step_size (float): Step size for gradient descent
    disp (bool, optional): If True, display progress. Defaults to False.

    Returns:
    tuple: (f_adjust, losses, debug_vars)
        f_adjust (np.ndarray): Final adjustment to the atmospheric forcing field
        losses (dict): Dictionary of loss values over iterations
        debug_vars (dict): Dictionary of debug variables over iterations
    """
    return gradient_descent_template(M, F, f_true, f_guess, C_known, C_error,
                                     x0, timesteps, 
                                     ocean_states_observed, num_iters, step_size,
                                     dan_modified_update_rule, [C_error], disp)

In [None]:
dan_modified_field_view = False

if run_all_code or dan_modified_field_view:
    
    f_adjust_dan_modified, losses_dan_modified, debug_vars_dan_modified = dan_modified_gradient_descent(
        M, F, f_true, f_guess, C_known, C_error,
        x0, num_timesteps,
        observed_state_over_time, num_iters, step_size,
        disp=True)

    # Plot the true field, the first guess, and the improved field

    f_optimized_dan_modified = f_guess + f_adjust_dan_modified

    many_states = [f_true, f_guess, f_optimized_dan_modified]
    titles = ['True field', 'First-guess field', 'Dan modified field']
    big_title = 'Atmospheric Field: Gradient Descent with Dan Modified'

    helper.plot_multi_heatmap(many_states, nr, nc, titles, big_title, vmin=None, vmax=None)

In [None]:
dan_modified_methods_view = True

if run_all_code or dan_modified_methods_view:

    weight_cov_term = 1e-6

    methods = [
        #["Simple Gradient Descent", simple_gradient_update_rule, []],
        #["Cholesky Gradient Descent", cholesky_update_rule, [C_error]],
        #["Modified J Gradient Descent $\\alpha = 1e-6$", cov_constraint_J_update_rule, [C_error, 1e-6]],
        #["Modified J Gradient Descent $\\alpha = 5e-6$", cov_constraint_J_update_rule, [C_error, 5e-6]],
        #["Modified J Gradient Descent $\\alpha = 5e-7$", cov_constraint_J_update_rule, [C_error, 5e-7]],
        
        #["Dan Gradient Descent", dan_update_rule, [C_error]],
        ["Dan Modified Gradient Descent", dan_modified_update_rule, [C_error]],
    ]

    num_iters_local = 250
    step_size_local = step_size*10
    losses, debug_vars = compare_gd_methods_many_times(nr, nc, dt, F, gamma, sigma, num_obs_per_timestep, 
                                    num_timesteps, num_iters_local, step_size_local, 
                                    C_known, C_error, methods, num_runs, disp=True)

    plot_losses(losses, num_obs_per_timestep, step_size_local, num_timesteps, num_iters_local)

This method turns out to be really unstable: the losses oscillate aggressively from one iteration to the next. But why? We'll explore that in the next notebook.