# Adjoint Method

In this notebook, we derive, implement, and verify the **adjoint method**, a technique for efficiently computing derivatives for our simulation.

In [1]:
import helper
import numpy as np
import matplotlib.pyplot as plt 
import math


First, we generate the data:

- The "true" data, using the true atmospheric forcing.
- First-guess data, using our noisy estimate of the atmospheric forcing.

In [2]:
nr,nc = 32,32
dt = 0.01
F = 0.1

# C_control is the covariance matrix of the full control vector f

C_control, x0, _, M = helper.generate_world(nr, nc, dt, F)

# C_known is the covariance matrix of the correct part of the control vector f
# C_error is the covariance matrix of the incorrect part of the control vector f

gamma = 2/3

C_known = C_control * gamma
C_error = C_control * (1-gamma)
C_ocean = C_control / 6

f_true, f_guess = helper.generate_true_and_first_guess_field(C_known, C_error, nr, nc)


In [3]:
# Atmosphere forcing coefficient
F = 0.1
# Standard deviation of the noise in the observations
sigma = 0.1 


# Number of timesteps to run the simulation for (repeated each iter)
num_timesteps = 10
# Number of iterations of gradient descent (repeated each run)
num_iters = 5
# Step size for gradient descent
step_size = 0.1

# Number of times to run the whole gradient descent optimization
num_runs = 1

# Run the simulation with the true and guessed control vector
saved_timesteps, real_state_over_time  = helper.compute_affine_time_evolution_simple(x0, M, F*f_true,  num_timesteps)
saved_timesteps, guess_state_over_time = helper.compute_affine_time_evolution_simple(x0, M, F*f_guess, num_timesteps)

Later, we'll need observations of the real ocean state:

In [4]:
num_obs_per_timestep = 50



observed_state_over_time_2d = helper.observe_over_time(real_state_over_time, sigma, 
                                                       num_obs_per_timestep, nr, nc)


observed_state_over_time = [np.reshape(observed_state_2d, (nr*nc, 1))
                            for observed_state_2d in observed_state_over_time_2d]


### Outline

Our goal is to use our ocean simulation to improve our estimate of atmospheric conditions.

Here's the basic outline:

- We have a model of the ocean that depends on atmospheric conditions.

- We have an initial estimate for our atmospheric state. 
  - We use it to simulate our ocean model.

- If our atmosphere estimate is inaccurate, it will likely cause some inaccuracies in our ocean simulation.
  - We can measure this inaccuracy by observing the real ocean state and comparing it to the simulation.
  
- We can use the chain rule to determine how adjusting our atmospheric estimate would improve the ocean simulation.
  - We use the adjoint method to evaluate this chain rule more efficiently.

Hopefully, this adjusted atmospheric estimate is better than our old one!


### Progress So Far


We've already handled the first three parts:

- We have an **ocean model**.
- We've created a "true" atmosphere, and an **estimated** version of that atmosphere.
- We have modelled the process of gathering **observations**.

Now, we want to handle the last part: using those observations to improve our atmospheric estimate.

### Loss of our model

First, we'll use our observations to evaluate the quality of our ocean simulation. We'll need some variables to write this clearly.

We start with the important properties of our system:
- $x^*(t)$ is the **true ocean state** (the temperature of the ocean in each region).
- $f^*(t)$ is the **true atmospheric forcing** (temperature of the atmosphere) over time.
- $z(t)$ is the **observed ocean state**: for the sake of this problem, we'll treat it as if it were the true ocean state.
- $x(0)$ is the **initial ocean state**.
  - We'll assume that this is perfectly accurate: we'll use this to initialize our simulation.
  - We could, however, modify our problem to improve this variable as well.

Next, variables that describe our estimates:
- $x(t)$ is our estimate of the **ocean state**, based on our **simulation**.
  - We run our simulation for $\tau$ timesteps.
- $f(t)$ is our **first guess estimate** of the **atmospheric control**. 
  - This is used when we're simulating $x(t)$.

- $J$ is our **loss function**, representing how bad our estimate is (larger $J$, worse estimate).
  - This loss function will compare $x(t)$ to $z(t)$.

The simplest useful model for our loss is *squared difference*: the larger the squared difference between the **observed** and **simulated** ocean state, the less accurate our simulation is.

In 1D, if we normalize by the variance, we get:

$$J_{1D}(t) = (z(t) - x(t))^2 / \sigma^2$$

In the multivariable case (where $z(t)$ and $x(t)$ are vectors, and $W$ is the covariance of observation error), we get

$$J_{2D}(t) = \Big( z(t) - x(t) \Big)^\top W^{-1}
              \Big( z(t) - x(t) \Big)$$

Finally, summing over all timesteps gives our total weighted squared difference:

$$J = \sum_{t=1}^{\tau}
      \Big( z(t) - x(t) \Big)^\top W^{-1}
      \Big( z(t) - x(t) \Big)$$

Our goal is to *minimize* this function.
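As a concrete reading of the formula above, here is a minimal sketch that evaluates $J$ for lists of 1-D NumPy vectors. The function name `weighted_loss` is hypothetical (it is not part of `helper`), and it solves against $W$ instead of forming $W^{-1}$ explicitly. With $W = \sigma^2 I$, each term reduces to the normalized 1D form $(z(t) - x(t))^2/\sigma^2$.

```python
import numpy as np

def weighted_loss(z_list, x_list, W):
    """Hypothetical helper: J = sum_t (z(t) - x(t))^T W^{-1} (z(t) - x(t))."""
    total = 0.0
    for z_t, x_t in zip(z_list, x_list):
        r = z_t - x_t                               # residual at time t
        total += float(r @ np.linalg.solve(W, r))   # r^T W^{-1} r without inverting W
    return total

# With W = sigma^2 I, each term is the normalized squared difference (z - x)^2 / sigma^2
sigma = 0.1
z = [np.array([1.0]), np.array([2.0])]
x = [np.array([0.9]), np.array([2.2])]
J = weighted_loss(z, x, sigma**2 * np.eye(1))
print(J)  # approximately (0.1**2 + 0.2**2) / 0.1**2 = 5.0
```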


# Deriving the Adjoint Method

### Our goal: computing a derivative

For convenience later, we choose an arbitrary timestep $q$. 

Our goal is to modify our atmospheric forcing $f(q)$ to improve our simulation (in other words, to reduce $J$). This is best captured by asking, "how does modifying $f(q)$ affect $J$?" This question is answered by the derivative,

$$\frac{ d J}{ d f(q)}$$

We can use this to directly compute an adjustment to $f(q)$, to improve our estimate. So, this derivative is our goal.



### Using our model

How does $f(q)$ affect $J$? It doesn't directly show up in the equation for $J$.

- Rather, it *indirectly* affects $J$, by modifying the (simulated) ocean state, $x(t)$.

This effect is represented by our equation for simulating forward in time:

$$x(t+1) = Mx(t) + Ff(t)$$

$f(q)$ influences the next state $x(q+1)$, which contributes to $J$. But, we're forgetting a second way that $f(q)$ can affect $J$: by affecting *future states*.

- While $f(q)$ only directly affects $x(q+1)$, we use $x(q+1)$ to compute $x(q+2)$. We can then use $x(q+2)$ to compute $x(q+3)$, and so on.
- So, $f(q)$ affects all of our future states! 
- By affecting each of these states, $f(q)$ can affect $J$ at $\tau - q$ different timesteps: $q+1$ through $\tau$.
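To see this numerically, the sketch below simulates a small random instance of $x(t+1) = Mx(t) + Ff(t)$ (toy sizes and values, not the notebook's ocean model), nudges $f(q)$, and reports which states changed. States $x(0), \dots, x(q)$ should be untouched, while every state from $x(q+1)$ onward should change.

```python
import numpy as np

rng = np.random.default_rng(0)
n, tau, F, q = 3, 6, 0.1, 2
M = rng.standard_normal((n, n)) / n
x0 = rng.standard_normal(n)
forcings = [rng.standard_normal(n) for _ in range(tau)]  # f(0), ..., f(tau-1)

def simulate(forcings):
    """Return [x(0), ..., x(tau)] for x(t+1) = M x(t) + F f(t)."""
    states = [x0]
    for f_t in forcings:
        states.append(M @ states[-1] + F * f_t)
    return states

base = simulate(forcings)
nudged = [f.copy() for f in forcings]
nudged[q] += 1e-3  # perturb every component of f(q)
perturbed = simulate(nudged)

# True where the perturbation reached x(t)
changed = [not np.array_equal(a, b) for a, b in zip(base, perturbed)]
print(changed)
```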

We can account for all of these terms using the multivariable chain rule:

$$\frac{ d J}{ d f(q)} \quad {\LARGE=}\quad  \sum_{t = q+1}^{\tau} \frac{dx(q+1)}{df(q)} \cdot \frac{dx(t)}{dx(q+1)}  \cdot \frac{\partial J}{ \partial x(t)} $$

The first and third terms follow directly from the model and the loss function, so we'll put them off until later. The middle term, which chains through every intermediate state, is where repeated work will appear.

It's useful to think of this in a second way: the sum above lists every way that $x(q+1)$ can affect $J$, so it is the *total derivative* of $J$ with respect to $x(q+1)$.

$$\frac{ d J}{ d f(q)} \quad {\LARGE=}\quad  \frac{dx(q+1)}{df(q)} 
\Bigg( \sum_{t = q+1}^{\tau} \frac{dx(t)}{dx(q+1)}  \cdot \frac{\partial J}{ \partial x(t)} \Bigg) 
\quad {\LARGE=}\quad 
\frac{dx(q+1)}{df(q)} \Bigg( \frac{dJ}{dx(q+1)} \Bigg)$$
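Evaluated directly, the chain-rule sum above can be sketched as follows. This assumes a scalar $F$, our linear model, and denominator layout, so that $\frac{dx(t)}{dx(q+1)} = (M^{t-q-1})^\top$ and $\frac{dx(q+1)}{df(q)} = F$. The function name `naive_dJ_dfq` and the `grads` list (holding $\partial J/\partial x(t)$ for $t = 0, \dots, \tau$) are hypothetical. Note that every call rebuilds its own matrix powers from scratch.

```python
import numpy as np

def naive_dJ_dfq(M, F, grads, q):
    """
    Hypothetical direct evaluation of
        dJ/df(q) = sum_{t=q+1}^{tau} dx(q+1)/df(q) . dx(t)/dx(q+1) . dJ/dx(t)
    for x(t+1) = M x(t) + F f(t) with scalar F.
    grads[t] holds the partial derivative dJ/dx(t) for t = 0..tau.
    """
    tau = len(grads) - 1
    total = np.zeros_like(grads[q + 1])
    for t in range(q + 1, tau + 1):
        # dx(t)/dx(q+1) in denominator layout: (M^(t-q-1))^T
        jac_T = np.linalg.matrix_power(M, t - q - 1).T
        total += jac_T @ grads[t]
    return F * total  # dx(q+1)/df(q) = F * I

rng = np.random.default_rng(0)
n, tau, F = 3, 4, 0.1
M = rng.standard_normal((n, n))
grads = [rng.standard_normal(n) for _ in range(tau + 1)]
per_step = [naive_dJ_dfq(M, F, grads, q) for q in range(tau)]
```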

### Redundant calculations

This technique gets the job done, but it can be inefficient to use for multiple timesteps: we have a lot of duplicate calculations. Consider an example:

- $f(1)$ and $f(2)$ both affect $x(3)$, which in turn affects $J$. Thus, both equations require $\frac{dJ}{dx(3)}$.

$$\frac{ d J}{ d f(1)} 
\quad{\LARGE=}\quad  
\frac{dx(2)}{df(1)} \Bigg(  \overbrace{\frac{dJ}{dx(2)}}^{\text{Total effect of $x(2)$}}\Bigg) 
\quad{\LARGE=}\quad 
\frac{dx(2)}{df(1)} \Bigg( 
    \overbrace{
        \frac{\partial J}{ \partial x(2)}
     }^{\text{ $x(2)$ effect by itself}}
+ 
\overbrace{
\frac{dx(3)}{dx(2)} \textcolor{red}{\frac{dJ}{dx(3)}}
}^{\text{$x(2)$ effect via future timesteps}} \Bigg) 
$$

$$\frac{ d J}{ d f(2)} 
\quad{\LARGE=}\quad  
\frac{dx(3)}{df(2)} \Bigg(  \textcolor{red}{\frac{dJ}{dx(3)}}\Bigg) 
$$
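A rough operation count makes the redundancy concrete. The sketch below assumes each power $(M^k)^\top g$ is applied as $k$ separate matrix-vector products and that nothing is reused between terms; the function names are hypothetical. It compares that against a backward recursion that needs one $M^\top$ product per step, which is what the following sections derive.

```python
def naive_matvec_count(tau):
    """Products needed to evaluate dJ/df(q) for every q = 0..tau-1 by the direct
    chain-rule sum, applying (M^k)^T g as k products with no reuse."""
    return sum(t - q - 1 for q in range(tau) for t in range(q + 1, tau + 1))

def recursion_matvec_count(tau):
    """One M^T product per backward step, lambda_{tau-1} down to lambda_1."""
    return tau - 1

for tau in (4, 10, 100):
    print(tau, naive_matvec_count(tau), recursion_matvec_count(tau))
```

Under these assumptions the direct count works out to $(\tau^3 - \tau)/6$ products (165 for $\tau = 10$), versus $\tau - 1$ for the recursion.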




### The Adjoint Method: Base Case

It seems that, in the above case, it would make sense to compute $dJ/dx(3)$ first, so we can re-use it for computing $dJ/dx(2)$.

- But if we just showed that $dJ/dx(3)$ is used twice, doesn't it make sense that the same is true for $dJ/dx(4)$?
  - If we use an identical argument to before, we could show that computing $dJ/dx(3)$ involves computing $dJ/dx(4)$.
  - So, we should handle $dJ/dx(4)$ first.

We can use the same logic over and over, going further forward in time: it seems we're reusing a lot of calculations! 

The natural conclusion is for us to start with the very last timestep, $dJ/dx(\tau)$.

- Because there are no future timesteps, $x(\tau)$ can only affect $J$ directly:

$$\frac{d J}{d x(\tau)} = \frac{\partial J}{ \partial x(\tau)}$$





### The Adjoint Method: Recursion

Now, we can move one step **backwards** in time, using the equation we wrote above:

$$ \textcolor{red}{\frac{d J}{d x(\tau-1)}} 
\quad{\LARGE=}\quad
\overbrace{\frac{\partial J}{ \partial x(\tau-1)}}^{\text{ $x(\tau-1)$ effect by itself}} + 
\overbrace{\frac{dx(\tau) }{dx(\tau-1)}\textcolor{red}{\frac{d J}{d x(\tau)}}}^{\text{$x(\tau-1)$ effect via $x(\tau)$}}$$

To make things clearer, we'll rename the variable we're recursively building up:

$$ \lambda_t = \frac{d J}{d x(t)} $$

Rewriting our equation:

$$ \textcolor{red}{\lambda_{\tau-1}} = \frac{\partial J}{ \partial x(\tau-1)} + \frac{dx(\tau) }{dx(\tau-1)} \textcolor{red}{\lambda_{\tau}}$$

We get something that looks like a **recursive** relation: $\lambda_{\tau-1}$ references the next element in the sequence, $\lambda_{\tau}$. As we move further back in time, we find the exact same equation, confirming our suspicions. If we write it in general, we get:

$$ \lambda_{t} = 
\begin{cases}
\frac{\partial J}{ \partial x(t)} + \frac{dx(t+1) }{dx(t)} \lambda_{t+1} & \text{ if } t < \tau \\\\
\frac{\partial J}{ \partial x(t)} & \text{ if } t = \tau
\end{cases}
$$

These are our **adjoint variables**.

### Using the adjoint

We can find our adjoint variables by moving backwards in time: we start by computing $\lambda_{\tau}$, and begin decrementing through $t = \tau-1, \tau-2,..., 2,1$.

Once we've finished, it's easy to compute our final derivatives:

$$
\frac{ d J}{ d f(q)} \quad {\LARGE=}\quad  \frac{dx(q+1)}{df(q)} \lambda_{q+1}  
$$

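If we apply this to our model ($x(t+1) = Mx(t) + Ff(t)$), we find that $\frac{dx(t+1)}{dx(t)} = M^\top$ (throughout, we use the denominator-layout convention, in which gradients are column vectors and the model's Jacobians appear transposed):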

$$ \lambda_{t} = 
\begin{cases}
\frac{\partial J}{ \partial x(t)} + M^\top \lambda_{t+1} & \text{ if } t < \tau \\\\
\frac{\partial J}{ \partial x(t)} & \text{ if } t = \tau
\end{cases}
$$

If we work through the induction, we can simplify this to:

$$ \lambda_{k} = \sum_{i=k}^\tau \Bigg( (M^{i-k})^\top \frac{\partial J}{ \partial x(i)} \Bigg)
$$
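A quick numerical check of this induction: the sketch below builds the $\lambda_t$ with the backward recursion and compares them to the closed-form sum, using a small random $M$ and random vectors standing in for $\partial J / \partial x(t)$ (toy values, not the notebook's model).

```python
import numpy as np

rng = np.random.default_rng(0)
n, tau = 4, 6
M = 0.5 * rng.standard_normal((n, n))
grads = [rng.standard_normal(n) for _ in range(tau + 1)]  # grads[t] stands in for dJ/dx(t)

# Backward recursion: lambda_tau = g_tau, then lambda_t = g_t + M^T lambda_{t+1}
lam = [None] * (tau + 1)
lam[tau] = grads[tau]
for t in range(tau - 1, -1, -1):
    lam[t] = grads[t] + M.T @ lam[t + 1]

def closed_form(k):
    """lambda_k = sum_{i=k}^{tau} (M^(i-k))^T g_i"""
    return sum(np.linalg.matrix_power(M, i - k).T @ grads[i] for i in range(k, tau + 1))

matches = all(np.allclose(lam[k], closed_form(k)) for k in range(tau + 1))
print(matches)
```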

And finally, since $F$ is a scalar coefficient here, $\frac{dx(q+1)}{df(q)} = F$, so:

$$
\frac{ d J}{ d f(q)}
\quad {\LARGE=}\quad
F \lambda_{q+1} 
\quad {\LARGE=}\quad
F \sum_{i=q+1}^\tau \Bigg( (M^{i-q-1})^\top \frac{\partial J}{ \partial x(i)} \Bigg)
$$

Notice that the last forcing, $f(\tau)$, has no effect on our loss: it would only affect a future state $x(\tau+1)$, which doesn't exist.

- In the above equation, this would refer to some non-existent $\lambda_{\tau+1}$.

### Why is the adjoint useful?

Something worth addressing:

**Q:** *Couldn't we have computed the answer in our original form, without invoking the adjoint? We could've just plugged values into the chain rule we started with.*

In this particular case, this is true. However, this is only simple because our model takes such a simple form: we can multiply by $M^\top$ repeatedly to get our answer.

In many situations, our model can be too complex to get an analytical derivative. So, instead, we might use a more demanding approach, like **finite difference approximation**:

- Modify one variable of $f(q)$ and simulate the whole model, seeing how the loss changes.

- We repeat this process for each variable in $f(q)$, to get the overall derivative.

- Then, we have to repeat *all* of that, for every timestep $q$.
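For comparison, here is a sketch of that finite-difference procedure on a toy version of our model (time-varying forcing, $W = I$, small random matrices; all names are hypothetical). It counts full simulations: with $n$ components per forcing and $\tau$ forcing steps, central differences need $2n\tau$ runs. The result is checked against $F\lambda_{q+1}$ from the adjoint recursion derived above.

```python
import numpy as np

def simulate(x0, M, F, forcings):
    """Return [x(0), ..., x(tau)] for x(t+1) = M x(t) + F f(t)."""
    states = [x0]
    for f_t in forcings:
        states.append(M @ states[-1] + F * f_t)
    return states

def loss(states, obs):
    """Unweighted squared difference (W = I) over t = 1..tau; x(0) is fixed."""
    return sum(float(np.sum((z - x) ** 2)) for z, x in zip(obs[1:], states[1:]))

def finite_difference_grads(x0, M, F, forcings, obs, eps=1e-6):
    """Central-difference dJ/df(q) for every q and component; also counts simulations."""
    grads = [np.zeros_like(f) for f in forcings]
    n_sims = 0
    for q in range(len(forcings)):
        for i in range(forcings[q].size):
            plus = [f.copy() for f in forcings]
            minus = [f.copy() for f in forcings]
            plus[q][i] += eps
            minus[q][i] -= eps
            J_plus = loss(simulate(x0, M, F, plus), obs)
            J_minus = loss(simulate(x0, M, F, minus), obs)
            grads[q][i] = (J_plus - J_minus) / (2 * eps)
            n_sims += 2
    return grads, n_sims

rng = np.random.default_rng(2)
n, tau, F = 3, 4, 0.1
M = rng.standard_normal((n, n)) / n
x0 = rng.standard_normal(n)
forcings = [rng.standard_normal(n) for _ in range(tau)]
obs = [rng.standard_normal(n) for _ in range(tau + 1)]

fd_grads, n_sims = finite_difference_grads(x0, M, F, forcings, obs)

# Adjoint version: one forward run, one backward pass; dJ/df(q) = F * lambda_{q+1}
states = simulate(x0, M, F, forcings)
lam = 2 * (states[tau] - obs[tau])  # lambda_tau
adj_grads = [None] * tau
adj_grads[tau - 1] = F * lam
for t in range(tau - 1, 0, -1):
    lam = 2 * (states[t] - obs[t]) + M.T @ lam  # lambda_t
    adj_grads[t - 1] = F * lam
```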

Using the adjoint method, we can significantly cut down on the work we have to do:

- First, we compute the adjoint variables $\lambda_t$: this requires computing our derivatives $\partial J/\partial x(t)$ and $\partial x(t+1)/\partial x(t)$. 

  - $\partial J/\partial x(t)$ can be gotten directly from the loss function.
  - $\partial x(t+1)/\partial x(t)$ only requires simulating one timestep forward, for each variable.

Since we have to simulate between each pair of timesteps $t$ and $t+1$, this costs about as much as running through the whole model once per variable in $x$.

Once we've done that, we don't need to run the whole simulation for each $f(q)$: we only have to run one timestep, to see how it affects $x(q+1)$.

We can think of this as "pre-simulating" the effect that our states have on the loss, so that we only have to see how $f(q)$ affects the first in that chain of timesteps: $x(q+1)$.


# Implementing the Adjoint Method

Now that we understand the adjoint method (for our particular case), let's implement it.

A few more details:

- Our $f$ is time-invariant: we apply the same forcing $f$ at every timestep.
- This isn't much of a problem: we can just add up the derivative contributions from every timestep. Since $x(0)$ is fixed and $f(0)$ already affects $x(1)$, the sum starts at $q = 0$.

$$
\frac{ d J}{ d f}
\quad {\LARGE=}\quad
\sum_{q=0}^{\tau-1}
F \lambda_{q+1} 
$$
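The sketch below checks this sum on a toy version of the model with a single time-invariant forcing $f$ (small random matrices, $W = I$; all names are hypothetical). During the backward pass it accumulates $F\lambda_{q+1}$ for every forcing step that influences the loss, here $\lambda_1$ through $\lambda_\tau$, and compares the total against a central-difference derivative with respect to the shared $f$.

```python
import numpy as np

rng = np.random.default_rng(3)
n, tau, F = 3, 5, 0.1
M = rng.standard_normal((n, n)) / n
x0 = rng.standard_normal(n)
obs = [rng.standard_normal(n) for _ in range(tau + 1)]  # obs[0] unused: x(0) is fixed

def simulate(f):
    """States x(0..tau) with the same forcing f applied at every step."""
    states = [x0]
    for _ in range(tau):
        states.append(M @ states[-1] + F * f)
    return states

def loss(f):
    states = simulate(f)
    return sum(float(np.sum((z - x) ** 2)) for z, x in zip(obs[1:], states[1:]))

def adjoint_grad(f):
    states = simulate(f)
    lam = 2 * (states[tau] - obs[tau])              # lambda_tau
    total = F * lam                                  # contribution of f(tau-1)
    for t in range(tau - 1, 0, -1):
        lam = 2 * (states[t] - obs[t]) + M.T @ lam  # lambda_t
        total = total + F * lam                      # contribution of f(t-1)
    return total

f = rng.standard_normal(n)
eps = 1e-6
fd = np.array([(loss(f + eps * e) - loss(f - eps * e)) / (2 * eps) for e in np.eye(n)])
print(np.max(np.abs(adjoint_grad(f) - fd)))
```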

Two last details:

- We only observe some pixels, so we simply ignore the remaining ones. We leave them as NaNs and are careful to exclude them from our calculations.
- For simplicity, the implementation uses $W = I$: the loss is the plain, unweighted squared difference.
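A small illustration of that NaN handling, on made-up numbers: `np.nansum` skips NaN terms in the loss, and `np.nan_to_num` replaces NaN gradient entries with zero, which is what the functions below rely on. In recent NumPy versions, `np.nansum` of an all-NaN array returns 0 rather than NaN.

```python
import numpy as np

z = np.array([1.0, np.nan, 3.0, np.nan])  # observations; NaN marks an unobserved pixel
x = np.array([1.5, 2.0, 2.0, 0.0])        # simulated state at the same timestep

J_t = np.nansum((z - x) ** 2)                  # only observed pixels: 0.25 + 1.0
grad_t = np.nan_to_num(2 * (x - z), nan=0.0)   # unobserved pixels get zero gradient
print(J_t, grad_t)
```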

### Compute Loss

Let's start by defining the loss we're going to be computing:

In [5]:
def compute_Jt(xt_true, xt_guess): 
    """
    Computes squared loss between two vectors at time t.
    
    Args:
    xt_true (np.ndarray): True state vector at time t
    xt_guess (np.ndarray): Guessed state vector at time t
    
    Returns:
    float: Squared loss, or 0 if no valid terms
    """
    
    # Sum over all valid terms, using numpy to treat nans as zeros
    result = np.nansum((xt_true - xt_guess)**2)
    if np.isnan(result): 
        return 0
    else: 
        return result

def compute_J(x_true, x_guess): 
    """
    Computes total squared loss between two vectors across all timesteps.
    
    Args:
    x_true (list): List of true state vectors at each timestep
    x_guess (list): List of guessed state vectors at each timestep
    
    Returns:
    float: Total squared loss across all timesteps
    """
    return np.sum([
                compute_Jt(x_true[i], x_guess[i]) for i in range(len(x_true))]
                )





### Compute Derivatives

Now, we can begin computing the derivatives:

In [6]:
def compute_DJ_Dxt(xt_true, xt_guess):
    """
    Computes partial derivative of squared loss w.r.t. guessed state at time t.
    
    Args:
    xt_true (np.ndarray): True state vector at time t
    xt_guess (np.ndarray): Guessed state vector at time t
    
    Returns:
    np.ndarray: Partial derivative of loss, with NaNs treated as 0
    """
    return np.nan_to_num( 2*(xt_guess - xt_true), nan = 0 )

In the verification section at the end of this notebook, we check this analytical $\frac{\partial J}{\partial x(t)}$ against a finite-difference approximation; the two agree to within a tiny percent error.

### Compute Adjoint

Using these derivatives, we can build up the adjoint, as described above:

In [7]:
def compute_adjoints(DJ_Dx, dxtp1_dxt):
    """
    Computes adjoint variables for optimization using backwards-time recursion.

    Args:
    DJ_Dx (list): List of partial derivatives of loss w.r.t. state at each timestep
    dxtp1_dxt (list): Jacobians dx(t+1)/dx(t) in denominator layout (M^T for our linear model), for t = 0..tau-2

    Returns:
    list: Adjoint variables for each timestep, in forward time order
    """

    tau = len(DJ_Dx)
    adjoints = [None] * tau  # Placeholder list, filled in backwards below

    adjoints[tau-1] = DJ_Dx[tau-1]
    
    for t in range(tau-2, -1, -1):  # Backwards in time
        adjoint = DJ_Dx[t] + dxtp1_dxt[t] @ adjoints[t+1]

        adjoints[t] = adjoint

    return adjoints


Our goal is to use the adjoint to compute $\frac{dJ}{df}$. So, let's try that now:

In [8]:
def compute_dJ_df(M, F, observed_state_over_time, simulated_state_over_time):
    """
    Computes the gradient of the loss with respect to the forcing field f for the linear model:
    x(t+1) = Mx(t) + Ff

    Args:
    M (np.ndarray): Model matrix
    F (float): Forcing coefficient
    observed_state_over_time (list): List of observed states at each timestep
    simulated_state_over_time (list): List of simulated states at each timestep

    Returns:
    np.ndarray: Gradient of the loss with respect to the forcing field f
    """
    num_timesteps = len(observed_state_over_time)
    vec_length = len(observed_state_over_time[0])

    #Compute adjoints
    DJ_Dx = [compute_DJ_Dxt(observed_state_over_time[i], simulated_state_over_time[i])
             for i in range(num_timesteps)] # partial J / partial x(t)
    
    dxtp1_dxt = [M.T for i in range(num_timesteps-1)] #dx(t+1)/dx(t)

    adjoints = compute_adjoints(DJ_Dx, dxtp1_dxt) # dJ/dx(t) = lambda(t)

    # Compute gradient for each timestep: how f being applied at time t affects J
    dJ_dft = [ F * adjoint for adjoint in adjoints[1:] ] # dJ/df(t) = dx(t+1)/df(t) dJ/dx(t+1)
    dJ_dft.append(np.zeros((vec_length,1))) # dJ/df(tau) = 0 

    #f is applied the same at all timesteps
    dJ_df = np.sum(dJ_dft, axis=0) # dJ/df = sum_t dJ/df(t)
    return dJ_df



Using the adjoint, we now have a complete pipeline for computing $dJ/df$. In the next notebook, we'll use this in our gradient descent implementations.

# Verifying our Adjoint Implementation

It's hard to confirm that this code is (approximately) correct just by reading it.

- Our solution is to compute the derivatives **numerically**, with finite differences, and compare them to the analytical results above.
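The core idea of every check below is a central difference, $\frac{J(x + \epsilon e_i) - J(x - \epsilon e_i)}{2\epsilon}$, compared to the analytical gradient with a relative percent error. A minimal, generic sketch (function names hypothetical), tried on a simple quadratic whose gradient we know exactly:

```python
import numpy as np

def numerical_gradient(fun, x, eps=1e-6):
    """Central-difference approximation to the gradient of a scalar function at x."""
    grad = np.zeros_like(x, dtype=float)
    for i in range(x.size):
        step = np.zeros_like(x, dtype=float)
        step.flat[i] = eps
        grad.flat[i] = (fun(x + step) - fun(x - step)) / (2 * eps)
    return grad

def percent_error(analytic, approx):
    """Same relative-error measure used in the checks below."""
    return 100 * np.linalg.norm(analytic - approx) / np.linalg.norm(analytic)

# Example: J(x) = sum((x - z)^2) has exact gradient 2 (x - z)
z = np.array([1.0, -2.0, 0.5])
x = np.array([0.3, 0.1, 2.0])

def J(v):
    return float(np.sum((v - z) ** 2))

err = percent_error(2 * (x - z), numerical_gradient(J, x))
print(f"{err:.8f}%")
```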

The following cell checks `compute_DJ_Dxt`:

In [9]:
def test_DJ_Dxt(observed_state_over_time, simulated_state_over_time, num_timesteps, epsilon = 1e-6):
    """
    This function tests ( partial J /partial xt ) by comparing it to a finite-difference approximation.
    It checks the gradient of J with respect to x(t) for each timestep.
    """

    percent_errors = []
    vec_length = len(observed_state_over_time[0])

    ### Compute analytic gradient using expression for loss
    DJ_Dx = [compute_DJ_Dxt(observed_state_over_time[t], simulated_state_over_time[t])
                 for t in range(num_timesteps)]

    ### Compute numerical gradient using finite differences
    for t in range(num_timesteps):
        DJ_Dxt_approx = np.zeros((vec_length, 1))

        for i in range(vec_length):
            # Perturb x(t) in each dimension
            state_t_plus = simulated_state_over_time[t].copy()
            state_t_minus = simulated_state_over_time[t].copy()

            state_t_plus[i] += epsilon
            state_t_minus[i] -= epsilon

            # Compute how perturbation affects J
            J_plus  = compute_Jt(observed_state_over_time[t], state_t_plus)
            J_minus = compute_Jt(observed_state_over_time[t], state_t_minus)

            # Compute finite difference approximation
            DJ_Dxt_approx[i] = (J_plus - J_minus) / (2*epsilon)

        DJ_Dxt = DJ_Dx[t]

        # Compute percent error between analytic (loss fn) and numerical (finite-difference) gradients
        percent_error = 100*np.linalg.norm(DJ_Dxt - DJ_Dxt_approx) / np.linalg.norm(DJ_Dxt)
        percent_errors.append(percent_error)

    return percent_errors

num_timesteps = 4
# Call the function with all necessary arguments
percent_errors = test_DJ_Dxt(observed_state_over_time, guess_state_over_time, num_timesteps)

# Print results

for t, percent_error in enumerate(percent_errors):
    print(f"Percent error at timestep {t}: {percent_error:.8f}%")



Percent error at timestep 0: 0.00000001%
Percent error at timestep 1: 0.00000002%
Percent error at timestep 2: 0.00000002%
Percent error at timestep 3: 0.00000002%


The following cell checks `compute_adjoints`:

In [10]:
def test_adjoint(observed_state_over_time, x0, M, F, f_guess, num_timesteps, nr, nc, epsilon = 1e-6):
    """
    This function tests compute_adjoints by comparing it to a numerical approximation of the adjoints.
    Note that the adjoints are dJ/dx(t) = lambda(t).
    """
    percent_errors = []

    vec_length = len(observed_state_over_time[0])

    _, simulated_state_over_time = helper.compute_affine_time_evolution_simple(x0, M, F*f_guess, num_timesteps)

    ### Compute analytical gradient using adjoint method
    DJ_Dx = [compute_DJ_Dxt(observed_state_over_time[i], simulated_state_over_time[i])
             for i in range(num_timesteps)]
    
    dxtp1_dxt = [M.T for _ in range(num_timesteps-1)]

    adjoints = compute_adjoints(DJ_Dx, dxtp1_dxt)[:num_timesteps]

    ### Compute numerical gradient using finite differences
    for t in range(num_timesteps): # For each timestep
        adjoint_approx = np.zeros((vec_length, 1))

        for i in range(vec_length): # For each element in x(t)
            #Perturb state x(t) in each dimension
            state_t       = simulated_state_over_time[t]
            state_t_plus  = state_t.copy()
            state_t_minus = state_t.copy()

            state_t_plus[i] += epsilon 
            state_t_minus[i] -= epsilon


            #Simulate starting from time t through tau=num_timesteps
            _, plus_state_t_tau = helper.compute_affine_time_evolution_simple(state_t_plus,   M, F*f_guess, num_timesteps-t)
            _, minus_state_t_tau = helper.compute_affine_time_evolution_simple(state_t_minus, M, F*f_guess, num_timesteps-t)

            observed_state_t_tau = observed_state_over_time[t:num_timesteps]

            # Compute how perturbation affects J
            J_plus =  compute_J(observed_state_t_tau,  plus_state_t_tau)
            J_minus = compute_J(observed_state_t_tau, minus_state_t_tau)

            # Compute finite difference approximation
            adjoint_approx[i] = (J_plus - J_minus) / (2*epsilon)

        #Compute percent error between analytic (adjoint) and numerical (finite-difference)
        adjoint = adjoints[t]

        percent_error = 100 * np.linalg.norm(adjoint - adjoint_approx) / np.linalg.norm(adjoint)
        percent_errors.append(percent_error)

    return percent_errors


num_timesteps = 4
# Call the function with all necessary arguments
percent_errors = test_adjoint(observed_state_over_time, x0, M, F, f_guess, num_timesteps, nr, nc)

# Print results
for t, percent_error in enumerate(percent_errors):
    print(f"Percent error at timestep {t}: {percent_error:.8f}%")




Percent error at timestep 0: 0.00000028%
Percent error at timestep 1: 0.00000020%
Percent error at timestep 2: 0.00000009%
Percent error at timestep 3: 0.00000002%


The following cell checks `compute_dJ_df`:

In [11]:
def dJ_df_check(observed_state_over_time, x0, M, F, f_guess, num_timesteps, 
               epsilon = 1e-6):
    """
    This function tests compute_dJ_df by comparing it to a numerical approximation of the gradient.
    """

    vec_length = len(observed_state_over_time[0])

    observed_state_over_time     = observed_state_over_time[:num_timesteps]
    _, simulated_state_over_time = helper.compute_affine_time_evolution_simple(x0, M, F*f_guess, num_timesteps)

    ### Compute analytical gradient using adjoint method + chain rule
    dJ_df = compute_dJ_df(M, F, observed_state_over_time, simulated_state_over_time)

    ### Compute numerical gradient using finite differences
    dJ_df_approx = np.zeros((vec_length,1))
    for i in range(vec_length):
        f_plus  = f_guess.copy()
        f_minus = f_guess.copy()

        f_plus[i]  += epsilon
        f_minus[i] -= epsilon

        # Simulate future states
        _, plus_state_over_time  = helper.compute_affine_time_evolution_simple(x0, M, F*f_plus, num_timesteps)
        _, minus_state_over_time = helper.compute_affine_time_evolution_simple(x0, M, F*f_minus, num_timesteps)

        

        # Compute the loss for the modified state
        J_plus  = compute_J(observed_state_over_time, plus_state_over_time)
        J_minus = compute_J(observed_state_over_time, minus_state_over_time)

        # Compute the partial derivative
        dJ_df_approx[i] = (J_plus - J_minus) / (2*epsilon)

    percent_error = 100 * np.linalg.norm(dJ_df - dJ_df_approx) / np.linalg.norm(dJ_df)

    return percent_error

num_timesteps = 4
# Call the function with all necessary arguments
percent_error = dJ_df_check(observed_state_over_time, x0, M, F, f_guess, num_timesteps)

print(f"Percent error: {percent_error:.8f}")


Percent error: 0.00000069


In [13]:
num_timesteps = 10  # Adjust this to match your actual number of timesteps
DJ_Dx = [compute_DJ_Dxt(observed_state_over_time[i], guess_state_over_time[i])
         for i in range(num_timesteps)]
dxtp1_dxt = [M.T for _ in range(num_timesteps-1)]
adjoints = compute_adjoints(DJ_Dx, dxtp1_dxt)

# Check the size of the adjoint
print(f"Number of adjoints: {len(adjoints)}")
print(f"Shape of each adjoint: {adjoints[0].shape}")

print(M.shape)

Number of adjoints: 10
Shape of each adjoint: (1024, 1)
(1024, 1024)
