In [1]:
from __init__ import PRP; import sys
sys.path.append(PRP)

from scripts.grad_compare import *
from scripts.utils import *

from setups.acc.acc_learning import ACCSetup
from tqdm import tqdm
import gc 
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import seaborn as sn

from jax import lax

/home/emeunier/code/Veros-Autodiff/
Differentiable Veros Experimental version
Importing core modules
 Using computational backend jax on gpu
  Kernels are compiled during first iteration, be patient
 Runtime settings are now locked



2025-10-14 11:47:32.158196: W external/xla/xla/service/gpu/nvptx_compiler.cc:836] The NVIDIA driver's CUDA version is 12.2 which is older than the PTX compiler version (12.9.86). Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [2]:
import sys
print(sys.executable)
!which python

/home/emeunier/data/conda/envs/verosenv/bin/python
/var/lib/oar/.batch_job_bashrc: line 5: /home/emeunier/.bashrc: No such file or directory
/usr/bin/python


In [3]:
def set_var(var_name, state, var_value):
    n_state = state.copy()
    vs = n_state.variables
    with n_state.variables.unlock():
        setattr(vs, var_name, var_value)
    return n_state

def pure(state, step):
    """
    Convert the state function into a "pure step" copying the input state
    """
    n_state = state.copy()
    step(n_state)  # This is a function that modifies state object inplace
    return n_state

def agg_sum(state, key_sum='temp'):
    tau = getattr(state.variables, 'tau')
    return ((getattr(state.variables, key_sum)[..., tau])**2).mean()

# Spin-up

In [4]:
# Spin-up 
warmup_steps = 200
acc = ACCSetup()
acc.setup()


def ps(state) : 
    n_state = state.copy()
    acc.step(n_state)
    return n_state

# Spin-up 
acc = ACCSetup()
acc.setup()

step_jit = jax.jit(ps)

state = acc.state.copy()
for step in tqdm(range(warmup_steps)) :
    state = step_jit(state)

Running model setup
Diffusion grid factor delta_iso1 = 0.01942284820457075
Running model setup
Diffusion grid factor delta_iso1 = 0.01942284820457075


100%|██████████| 200/200 [00:36<00:00,  5.51it/s]


# Test grad implementations

In [16]:
state = acc.state.copy()
var_name = 'temp'
var_value = state.variables.temp
step_function = lambda state : pure(state, acc.step)
agg_function = agg_sum 

In [30]:
def lossgrad_backward(state, var_name, var_value, step_function, agg_function, iterations=1):
    def loss_fn(v):
        n_state = set_var(var_name, state, v)

        def body_fn(carry, _):
            carry = step_function(carry)
            return carry, None

        n_state, _ = lax.scan(body_fn, n_state, xs=None, length=iterations)
        return agg_function(n_state)

    loss, grad = value_and_grad(loss_fn)(var_value)
    return loss, grad

In [19]:
%%time
# No compilation
_, _ = lossgrad_backward(state, var_name, var_value, step_function, agg_function, iterations=1)

CPU times: user 1min 48s, sys: 4.1 s, total: 1min 53s
Wall time: 2min 1s


In [21]:
%%time
step_function = jax.jit(step_function)
agg_function = jax.jit(agg_function)

agg_function(step_function(state))

CPU times: user 14.7 s, sys: 302 ms, total: 15 s
Wall time: 17.5 s


Array(94.41118746, dtype=float64)

In [22]:
%%time
# After compilation of the step and agg
_, _ = lossgrad_backward(state, var_name, var_value, step_function, agg_function, iterations=1)

CPU times: user 1min 54s, sys: 4.83 s, total: 1min 59s
Wall time: 2min 8s


In [24]:
%%time
# Try compiling the full experimental backward

exp_back = lambda field, iterations : lossgrad_backward(state, var_name, field, step_function, agg_function, iterations=iterations)

exp_back_jit = jax.jit(exp_back, static_argnames=['iterations'])
_, _ = exp_back_jit(var_value, iterations=1)

CPU times: user 1min 43s, sys: 5.6 s, total: 1min 49s
Wall time: 1min 50s


In [26]:
%%time
_, _ = exp_back_jit(var_value, iterations=1)

CPU times: user 25.3 ms, sys: 35 μs, total: 25.3 ms
Wall time: 24.1 ms


In [27]:
%%time
_, _ = exp_back_jit(var_value, iterations=2)

CPU times: user 2min 7s, sys: 6.12 s, total: 2min 13s
Wall time: 2min 15s


In [4]:
n_iteration = 4
var_agg = 'temp'
step_function = acc.step

In [7]:
agg_function = lambda state : agg_sum(state, key_sum=var_agg)

vjpm = vjp_grad_new_2(step_function, agg_function, var_agg)


loss_and_grad = lambda s, v: vjpm.g(s, v, iterations=n_iteration)

In [6]:
%%time
state = state.copy()
field = state.variables.temp

output_forward, gradients = loss_and_grad(state, field)

CPU times: user 2min 9s, sys: 4.68 s, total: 2min 14s
Wall time: 2min 19s


In [9]:
%%time
state = state.copy()
field = state.variables.temp

output_forward, gradients = loss_and_grad(state, field)

CPU times: user 2min 24s, sys: 4.63 s, total: 2min 29s
Wall time: 2min 17s


Backward vjp - Directional Gradients

# Using backward unrolling

In [5]:
def iterative_backward(state, var_name, var_value, step_function, agg_function, iterations=1):
   
    n_state = set_var(var_name, state, var_value)

    current_state = n_state
    funs = []
    for i in range(iterations):
        current_state, vjp_fun  = jax.vjp(step_function, current_state)
        funs.append(vjp_fun)

    # Compute final output
    l, vjp_agg = jax.vjp(agg_function, current_state)

    # Backward pass using VJP
    cotangent = jnp.ones_like(l)

    # Backpropagate through agg_function
    ds, = vjp_agg(cotangent)

    # Backpropagate through all steps
    for vjp_fun in reversed(funs):
        ds, = vjp_fun(ds)
    return l, attrgetter(f'variables.{var_name}')(ds)

In [6]:
state = acc.state.copy()
var_name = 'temp'
var_value = state.variables.temp
step_function = lambda state : pure(state, acc.step)
agg_function = agg_sum 

In [7]:
%%time
step_function = jax.jit(step_function)
agg_function = jax.jit(agg_function)

agg_function(step_function(state))

CPU times: user 21.3 s, sys: 331 ms, total: 21.7 s
Wall time: 18.8 s


Array(94.41118746, dtype=float64)

In [9]:
%%time 

# Precache computation of jvp
current_state, vjp_fun  = jax.vjp(step_function, state)
_ = vjp_fun(state)

CPU times: user 160 ms, sys: 12.4 ms, total: 172 ms
Wall time: 184 ms


E1014 12:12:39.891762   15017 pjrt_stream_executor_client.cc:3067] Execution of replica 0 failed: RESOURCE_EXHAUSTED: CUDA driver ran out of memory trying to instantiate CUDA graph with 4304 nodes and 0 conditionals (total of 0 alive CUDA graphs in the process). You can try to (a) Give more memory to CUDA driver by reducing XLA_PYTHON_CLIENT_MEM_FRACTION (b) Disable CUDA graph with 'XLA_FLAGS=--xla_gpu_enable_command_buffer=' (empty set). Original error: Failed to instantiate CUDA graph: CUDA_ERROR_OUT_OF_MEMORY: out of memory


XlaRuntimeError: RESOURCE_EXHAUSTED: CUDA driver ran out of memory trying to instantiate CUDA graph with 4304 nodes and 0 conditionals (total of 0 alive CUDA graphs in the process). You can try to (a) Give more memory to CUDA driver by reducing XLA_PYTHON_CLIENT_MEM_FRACTION (b) Disable CUDA graph with 'XLA_FLAGS=--xla_gpu_enable_command_buffer=' (empty set). Original error: Failed to instantiate CUDA graph: CUDA_ERROR_OUT_OF_MEMORY: out of memory

In [None]:
%%time
state = state.copy()
field = state.variables.temp

output_forward, gradients = iterative_backward(state, var_name, var_value, step_function, agg_function, iterations=1)

In [None]:
%%time
output_forward, gradients = iterative_backward(state, var_name, var_value, step_function, agg_function, iterations=3)

In [11]:
%%time
output_forward, gradients = iterative_backward(state, var_name, var_value, step_function, agg_function, iterations=20)

CPU times: user 12.2 s, sys: 135 ms, total: 12.3 s
Wall time: 12.2 s


In [12]:
%%time
output_forward, gradients = iterative_backward(state, var_name, var_value, step_function, agg_function, iterations=30)

CPU times: user 18.8 s, sys: 160 ms, total: 18.9 s
Wall time: 18.7 s


# Time measurement 

In [13]:
import time, jax

In [14]:
def measure_backward(state, var_name, var_value, step_function, agg_function, iterations=3):
    """
    Measures execution time for iterative_backward.
    Returns (output_forward, gradients, exec_time_s, peak_memory_MiB)
    """
    # Start timing
    start = time.time()
    output_forward, gradients = iterative_backward(state, var_name, var_value, step_function, agg_function, iterations=iterations)
    
    # Force computation to finish
    output_forward.block_until_ready() # Wait for the computation to finish
    gradients.block_until_ready()
    end = time.time()
    exec_time = end - start
    
    return output_forward, gradients, exec_time

In [15]:
stats = []
for i in range(100) : 
    output_forward, gradients, exec_time = measure_backward(state, var_name, var_value, step_function, agg_function, iterations=i)
    s = {'iterations' : i, 'time' : exec_time}
    print(s)
    stats.append(s)

{'iterations': 0, 'time': 0.05199480056762695}
{'iterations': 1, 'time': 0.6601417064666748}
{'iterations': 2, 'time': 1.2049367427825928}
{'iterations': 3, 'time': 2.162071704864502}
{'iterations': 4, 'time': 2.3600287437438965}
{'iterations': 5, 'time': 3.2862493991851807}
{'iterations': 6, 'time': 3.919020175933838}
{'iterations': 7, 'time': 4.479687213897705}
{'iterations': 8, 'time': 5.001291990280151}
{'iterations': 9, 'time': 5.62661600112915}
{'iterations': 10, 'time': 6.616050481796265}
{'iterations': 11, 'time': 6.630909442901611}
{'iterations': 12, 'time': 7.8326334953308105}
{'iterations': 13, 'time': 7.742902517318726}
{'iterations': 14, 'time': 8.763912439346313}
{'iterations': 15, 'time': 9.374128341674805}
{'iterations': 16, 'time': 10.074889183044434}
{'iterations': 17, 'time': 10.453115224838257}
{'iterations': 18, 'time': 11.549152135848999}
{'iterations': 19, 'time': 11.565499305725098}
{'iterations': 20, 'time': 12.874145269393921}
{'iterations': 21, 'time': 12.790

KeyboardInterrupt: 

In [17]:
df = pd.DataFrame(stats)

In [19]:
df.to_csv('stats_grads_tensor_experimental_101025.csv')

In [20]:
df

Unnamed: 0,iterations,time
0,0,0.051995
1,1,0.660142
2,2,1.204937
3,3,2.162072
4,4,2.360029
...,...,...
64,64,39.499899
65,65,40.561125
66,66,40.761792
67,67,41.941572


In [23]:
jax.ge

AttributeError: module 'jax' has no attribute 'device'

In [25]:
!nvcc -v

/var/lib/oar/.batch_job_bashrc: line 5: /home/emeunier/.bashrc: No such file or directory
nvcc fatal   : No input files specified; use option --help for more information
