In [3]:
from __init__ import PRP; import sys
sys.path.append(PRP + 'veros/')

from datetime import datetime
from functools import partial
from jax import config
config.update("jax_enable_x64", True)

import jax
sys.path.append(PRP)

from scripts.load_runtime import * #Setup parameters for veros 
from setups.acc.acc_learning import ACCSetup

from tqdm import tqdm

Differentiable Veros Experimental version
Importing core modules
 Using computational backend jax on cpu
  Kernels are compiled during first iteration, be patient
 Runtime settings are now locked



# Spin-Up

In [4]:
# Spin-up: build the ACC setup once, then advance it `warmup_steps` steps
# with a jitted pure step so later cells start from a spun-up state.
warmup_steps = 200

acc = ACCSetup()
acc.setup()


def ps(state):
    """Return a copy of `state` advanced by one `acc.step`.

    `acc.step` is called for its in-place effect on the state it receives
    (its return value is ignored), so the copy keeps the caller's state
    untouched and lets `jax.jit` treat this as a state -> state function.
    """
    n_state = state.copy()
    acc.step(n_state)
    return n_state


step_jit = jax.jit(ps)

state = acc.state.copy()
for step in tqdm(range(warmup_steps), desc="spin-up"):
    state = step_jit(state)

Running model setup
Diffusion grid factor delta_iso1 = 0.01942284820457075
Running model setup
Diffusion grid factor delta_iso1 = 0.01942284820457075


100%|███████████████████████████████████████████████████████████████████████████████| 200/200 [00:07<00:00, 28.46it/s]


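# Compute derivatives

Differentiate a scalar loss of the rolled-out model state with respect to a single state variable, using `jax.value_and_grad` through the time steps.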

In [None]:

class autodiff:
    """Differentiate a scalar loss of a rolled-out state w.r.t. one state variable.

    The loss is ``agg_function(final_state)``, where ``final_state`` is obtained
    by setting ``state.variables.<var_name>`` to the value being differentiated
    (on a copy of ``state``) and then applying the step function repeatedly.
    """

    def __init__(self, step_function, agg_function, var_name):
        """
        Args:
            step_function: callable that advances a state by one step by
                modifying it in place. It is wrapped with `autodiff.pure`, so
                `self.step_function(state)` returns the advanced copy instead.
            agg_function: callable reducing the final state to a scalar loss
                (R^space -> R).
            var_name: name of the attribute of ``state.variables`` to
                differentiate with respect to.
        """
        self.agg_function = agg_function
        self.step_function = partial(autodiff.pure, step=step_function)
        self.var_name = var_name

    @staticmethod
    def pure(state, step):
        """Turn an in-place step into a "pure" step: copy, step the copy, return it."""
        n_state = state.copy()
        step(n_state)  # `step` modifies the state object in place
        return n_state

    @staticmethod
    def set_var(var_name, state, var_value):
        """Return ``state.copy()`` with ``variables.<var_name>`` set to `var_value`."""
        n_state = state.copy()
        vs = n_state.variables
        # NOTE(review): variables appear to reject assignment unless unlocked.
        with n_state.variables.unlock():
            setattr(vs, var_name, var_value)
        return n_state

    @staticmethod
    def wrapper(var_value, state, step_fun, var_name, agg_func, iter):
        """Scalar loss as a function of `var_value` (first argument, so
        `jax.grad`/`jax.value_and_grad` differentiate with respect to it).

        Sets `var_name` to `var_value` on a copy of `state`, applies
        `step_fun` `iter` times (each call must return the next state), and
        returns `agg_func` of the final state.
        """
        n_state = autodiff.set_var(var_name, state, var_value)
        # `iter` shadows the builtin; the name is kept for keyword callers.
        for _ in range(iter):
            n_state = step_fun(n_state)
        return agg_func(n_state)

    def g(self, state, var_value, iterations=1, **kwargs):
        """Return ``(loss, dloss/dvar)`` evaluated at `var_value`.

        Args:
            state: initial state; only copies of it are modified.
            var_value: value assigned to `self.var_name` before the rollout.
                The returned gradient has the same structure as this value.
            iterations: number of `self.step_function` applications.
            **kwargs: accepted for call compatibility; ignored.

        Returns:
            Tuple ``(loss, grad)`` from `jax.value_and_grad`.
        """
        def loss_fn(v):
            # Single definition of the rollout lives in `wrapper`.
            return autodiff.wrapper(
                v, state, self.step_function, self.var_name,
                self.agg_function, iterations,
            )

        loss, grad = jax.value_and_grad(loss_fn)(var_value)
        return loss, grad

In [None]:
# Build the differentiator from the `autodiff` class defined in the previous
# cell (presumably `vjp_grad_new` was its earlier name; the interface matches).
# NOTE(review): `agg_function` and `var_dev` are not defined in any cell of this
# notebook — define them (loss reduction and variable name) before running.
vjpm_nr = autodiff(acc.step, agg_function, var_dev)

# JIT-compile the pure single step and the final aggregation.
vjpm_nr.step_function = jax.jit(vjpm_nr.step_function)
vjpm_nr.agg_function = jax.jit(vjpm_nr.agg_function)

# Remat to save memory: recompute each step's intermediates during the
# backward pass instead of storing them for the whole rollout.
vjpm_nr.step_function = jax.checkpoint(vjpm_nr.step_function)


def loss_and_grad_nr(s, v, it):
    """Return ``(loss, grad)`` for state `s`, variable value `v`, and `it` rollout steps."""
    return vjpm_nr.g(s, v, iterations=it)