In [1]:
# Colab set-up: mount Google Drive so the project folder is reachable.
from google.colab import drive
drive.mount('/content/drive')

import os
import sys
import time

# Project root on Drive; appended to sys.path so the local packages
# (`numerical_methods`, `dataset`) imported in the next cell can be found.
# NOTE: this name shadows the built-in `dir`; it is kept because every later
# cell builds its file paths from it.
dir = '/content/drive/MyDrive/ARIA'
sys.path.append(dir)

# Disable JAX's up-front GPU memory preallocation. Set here, before the next
# cell imports jax, so it is in place before any JAX backend is initialised.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

Mounted at /content/drive


# Recurrent Closure

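$$\begin{aligned}
c^{t+1} &= c^t + A(c^t) + D(c^t) + R(c^t) + F_1 (c^t, A(c^t), D(c^t), T, h^t)\\
h^{t+1} &= F_2 (c^{t+1}, h^t)
\end{aligned}$$

Here:
- $c^t$ is the smoke field at step $t$.
- $A$ and $D$ are the advection and diffusion operators.
- $R$ is the inflow (source) term.
- $T$ is the terrain mask.
- $h^t$ is the recurrent hidden state.
- $F_1$ is the learned closure (`ClosureNet`).
- $F_2$ is the hidden-state update (`HiddenNet`).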

In [2]:
import numpy as np
import scipy
import random
import matplotlib.pyplot as plt

import jax
import optax
from flax import linen as nn
from jax import numpy as jnp
from flax.training import train_state  # Useful dataclass to keep train state
import flax

from functools import partial
import pickle

from numerical_methods import physics, ssim
import dataset.read_data_to_grid as rdtg


# Repeats the setting from the first cell so this cell is self-sufficient;
# it only has an effect if JAX has not yet initialised its GPU backend.
os.environ.update(XLA_PYTHON_CLIENT_PREALLOCATE="false")

In [3]:
# Grid geometry and time stepping.
dx = dy = 0.5          # cell size in x and y
ny, nx = 26, 49        # number of rows (y) and columns (x)
dt = 0.4               # length of one simulation step
NSUBSTEPS = 4          # explicit physics sub-steps taken inside each dt

# NOTE: pickle.load can execute arbitrary code — only open trusted files.
with open(f'{dir}/dataset/dataset_v2.pickle', 'rb') as handle:
    dataset = pickle.load(handle)

# Scenarios to use: all of them. An earlier experiment used this subset:
# train_set = [6, 12,  0, 8, 21,  2,  7, 15, 17, 10, 19,  5, 23, 13,  4]
train_set = [*range(len(dataset['SMOKE_FIELD']))]

INFLOW_LOCS = dataset['INFLOW_LOCS'][train_set,]

# Each scenario gets time 0 and an all-zero smoke field prepended, so index 0
# is the empty initial state the rollouts start from.
_time_zero = jnp.zeros((1,))
_field_zero = jnp.zeros((1, ny, nx))
FLOW_TIMES, SMOKE_FIELD, SPLINE_TCK = [], [], []
for scenario in train_set:
    FLOW_TIMES.append(jnp.concatenate((_time_zero, jnp.array(dataset['FLOW_TIMES'][scenario]))))
    SMOKE_FIELD.append(jnp.concatenate((_field_zero, jnp.array(dataset['SMOKE_FIELD'][scenario])), axis=0))
    SPLINE_TCK.append(dataset['SPLINE_TCK'][scenario])

# Fields shared by all scenarios.
VELOCITY = dataset['VELOCITY']
REL_LOC = dataset['rel_loc']
TERRAIN = dataset['TERRAIN']

# Inflow-rate time series sampled from the fitted splines at
# t = dt, 2*dt, ..., N_INFLOW_STEPS*dt, for every scenario and every source cell
# (one spline per REL_LOC offset).
# Resulting shape: (len(SMOKE_FIELD), len(REL_LOC), N_INFLOW_STEPS).
# splev is evaluated on the whole time vector in one call instead of once per
# time point; the sampled values are identical.
# NOTE(review): the simulation step only carries INFLOW_VALS through as
# `inflow_vals` and does not read it.
N_INFLOW_STEPS = 6000
inflow_sample_times = np.arange(1, N_INFLOW_STEPS + 1) * dt
INFLOW_VALS = jnp.array([
    [scipy.interpolate.splev(inflow_sample_times, SPLINE_TCK[set_idx][i])
     for i in range(len(REL_LOC))]
    for set_idx in range(len(SMOKE_FIELD))
])

# FILENAME_PREFIX = f'{dir}/15042025_phase3_cnnfc'

# Face masks derived from TERRAIN. A face between two neighbouring cells is
# open only when both cells are non-zero in TERRAIN.
#   diff_x_*: faces between horizontally adjacent cells ((ny, nx-1) for a (ny, nx) TERRAIN)
#   diff_y_*: faces between vertically adjacent cells   ((ny-1, nx) for a (ny, nx) TERRAIN)
# The *_grad_* masks apply the same rule once more to pairs of adjacent faces.
def _and_along_x(mask):
    return jnp.logical_and(mask[:, 1:], mask[:, :-1])

def _and_along_y(mask):
    return jnp.logical_and(mask[1:, :], mask[:-1, :])

diff_x_mask = _and_along_x(TERRAIN)
diff_y_mask = _and_along_y(TERRAIN)

# Built from the boolean face masks (before the float conversion below).
diff_x_grad_x_mask = _and_along_x(diff_x_mask)
diff_x_grad_y_mask = _and_along_y(diff_x_mask)
diff_y_grad_x_mask = _and_along_x(diff_y_mask)
diff_y_grad_y_mask = _and_along_y(diff_y_mask)

# Store every mask as float32.
(diff_x_mask, diff_y_mask,
 diff_x_grad_x_mask, diff_x_grad_y_mask,
 diff_y_grad_x_mask, diff_y_grad_y_mask) = (
    m.astype(jnp.float32) for m in (diff_x_mask, diff_y_mask,
                                    diff_x_grad_x_mask, diff_x_grad_y_mask,
                                    diff_y_grad_x_mask, diff_y_grad_y_mask))

In [4]:
# Turbulent viscosity at z level 1 on the (y, x) model grid, converted to a
# diffusivity by dividing by 0.7.
# NOTE(review): 0.7 is presumably a turbulent Schmidt number — confirm.
grid = rdtg.read_data(f"{dir}/dataset/turb_vis.txt")
turb_vis = rdtg.extract_xyz_to_array(grid,
                                     x_range=(0, 48),
                                     y_range=(0, 25),
                                     z_range=(1, 1),
                                     yxz=True)
turb_diff = jnp.array(turb_vis / 0.7)

# Face values: mean of the two neighbouring cells, set to 0 on faces where
# either neighbour is zero in TERRAIN (via diff_x_mask / diff_y_mask).
turb_diff_x = jnp.where(diff_x_mask,
                        ((turb_diff[:, 1:] + turb_diff[:, :-1]) / 2).squeeze(),
                        0)
turb_diff_y = jnp.where(diff_y_mask,
                        ((turb_diff[1:, :] + turb_diff[:-1, :]) / 2).squeeze(),
                        0)

792 rows of data read
791 data points used for grid
Checksum: 7.059849280200001


----------

## Neural Networks

In [5]:
class ClosureNet(nn.Module):
    """Learned closure term F_1.

    Decodes the concatenated LSTM hidden states into a 16-channel (26, 49) map
    with a transposed-convolution decoder. It concatenates that map with the
    per-cell input features ``x`` and predicts a single-channel correction.

    Call: ``closure_net.apply({'params': p}, x, h)``
      x: per-cell features of shape (26, 49, C_in); this notebook uses C_in = 10.
         16 + C_in must not exceed conv1's 32 features.
      h: four hidden-state vectors of length 128 (a list or a (4, 128) array).
    Returns an array of shape (26, 49, 1).
    """
    def setup(self):
        # for decoding the hidden state
        self.convTrans1 = nn.ConvTranspose(features=128, kernel_size=(3, 2), strides=(3, 2), padding='SAME')
        self.conv11 = nn.Conv(features=128, kernel_size=(3,3), strides=1, padding='SAME', kernel_init=nn.initializers.xavier_uniform())
        self.convTrans2 = nn.ConvTranspose(features=64, kernel_size=(2, 2), strides=(2, 2), padding='SAME')
        self.conv21 = nn.Conv(features=64, kernel_size=(3,3), strides=1, padding='SAME', kernel_init=nn.initializers.xavier_uniform())
        self.convTrans3 = nn.ConvTranspose(features=32, kernel_size=(2, 3), strides=(2, 3), padding='SAME')
        self.conv31 = nn.Conv(features=32, kernel_size=(3,3), strides=1, padding='SAME', kernel_init=nn.initializers.xavier_uniform())
        self.convTrans4 = nn.ConvTranspose(features=16, kernel_size=(2, 3), strides=(2, 3), padding='SAME')
        self.conv41 = nn.Conv(features=16, kernel_size=(3,3), strides=1, padding='SAME', kernel_init=nn.initializers.xavier_uniform())

        # for transforming the output
        self.conv1 = nn.Conv(features=32, kernel_size=(3,3), strides=1, padding='SAME', kernel_init=nn.initializers.xavier_uniform())
        self.conv2 = nn.Conv(features=32, kernel_size=(3,3), strides=1, padding='SAME', kernel_init=nn.initializers.xavier_uniform())
        self.conv3 = nn.Conv(features=1, kernel_size=(3,3), strides=1, padding='SAME', kernel_init=nn.initializers.xavier_uniform())


    def __call__(self, x, h):
        # Stack the four hidden states into one (512,) vector and treat it as a
        # 1x1 "image" with 512 channels.
        dec = jnp.concatenate((h[0],h[1],h[2],h[3]))
        dec = jnp.expand_dims(dec, axis=(0,1))

        dec = self.convTrans1(dec)                          # (3,2)
        dec = nn.leaky_relu(dec, negative_slope=0.01)       # (3,2)
        dec = nn.leaky_relu(self.conv11(dec), negative_slope=0.01)+dec       # (3,2)

        dec = self.convTrans2(dec)                          # (6,4)
        dec = nn.leaky_relu(dec, negative_slope=0.01)       # (6,4)
        dec = jnp.pad(dec, ((0, 0), (0, 1), (0, 0)), mode='edge')        # (6,5)
        dec = nn.leaky_relu(self.conv21(dec), negative_slope=0.01)+dec       # (6,5)

        dec = self.convTrans3(dec)                                       # (12,15)
        dec = nn.leaky_relu(dec, negative_slope=0.01)                    # (12,15)
        dec = jnp.pad(dec, ((0, 1), (1, 0), (0, 0)), mode='edge')        # (13,16)
        dec = nn.leaky_relu(self.conv31(dec), negative_slope=0.01)+dec   # (13,16)

        # NOTE(review): unlike the earlier stages there is no activation after
        # convTrans4. It is left as-is because adding one would change the
        # outputs of the trained weights loaded later in this notebook.
        dec = self.convTrans4(dec)                                       # (26,48)
        dec = jnp.pad(dec, ((0, 0), (0, 1), (0, 0)), mode='edge')        # (26,49)
        dec = nn.leaky_relu(self.conv41(dec), negative_slope=0.01)+dec   # (26,49)

        output = jnp.concatenate((dec, x), axis=-1)                      # (26,49,16+C_in)

        # Residual around conv1: zero-pad the channel axis (at the front) up to
        # conv1's width. The pad is derived from the actual channel count, not
        # hard-coded for C_in == 10.
        n_pad = self.conv1.features - output.shape[-1]
        if n_pad < 0:
            raise ValueError(
                f"ClosureNet: decoder + input features give {output.shape[-1]} channels, "
                f"more than conv1's {self.conv1.features}")
        residual = jnp.pad(output, [(0, 0)] * (output.ndim - 1) + [(n_pad, 0)])
        output = nn.leaky_relu(self.conv1(output), negative_slope=0.01) + residual
        output = nn.leaky_relu(self.conv2(output), negative_slope=0.01) + output
        output = self.conv3(output)                                      # (26,49,1)
        return output

In [6]:
class HiddenNet(nn.Module):
    """Hidden-state update F_2.

    Encodes a normalised smoke snapshot with four conv stages and max-pooling,
    then advances a stack of four LSTM cells. Each cell feeds its output to the
    next cell.

    Call: ``hidden_net.apply({'params': p}, x, h, c)``
      x: per-cell input features; (26, 49, 3) in this notebook.
      h, c: per-layer hidden and cell states, indexable as h[0]..h[3]
            (each a length-128 vector).
    Returns (h_new, c_new), each stacked into a (4, 128) array.
    """
    def setup(self):
        conv3x3 = partial(nn.Conv, kernel_size=(3, 3), strides=1, padding='SAME',
                          kernel_init=nn.initializers.xavier_uniform())
        # Encoder for the scalar field: two 3x3 convolutions per stage.
        self.conv1 = conv3x3(features=16)
        self.conv11 = conv3x3(features=16)
        self.conv2 = conv3x3(features=32)
        self.conv21 = conv3x3(features=32)
        self.conv3 = conv3x3(features=64)
        self.conv31 = conv3x3(features=64)
        self.conv4 = conv3x3(features=128)
        self.conv41 = conv3x3(features=128)

        # Stacked recurrent cells.
        self.lstm1 = nn.OptimizedLSTMCell(features=128)
        self.lstm2 = nn.OptimizedLSTMCell(features=128)
        self.lstm3 = nn.OptimizedLSTMCell(features=128)
        self.lstm4 = nn.OptimizedLSTMCell(features=128)


    def __call__(self, x, h, c):
        # (first conv, second conv, pooling window) per encoder stage;
        # spatial size (26,49) -> (13,16) -> (6,5) -> (3,2) -> (1,1).
        encoder_stages = ((self.conv1, self.conv11, (2, 3)),
                          (self.conv2, self.conv21, (2, 3)),
                          (self.conv3, self.conv31, (2, 2)),
                          (self.conv4, self.conv41, (2, 2)))
        for first_conv, second_conv, window in encoder_stages:
            x = nn.leaky_relu(first_conv(x), negative_slope=0.01)
            x = nn.leaky_relu(second_conv(x), negative_slope=0.01)
            x = nn.max_pool(x, window_shape=window, strides=window)

        # Flax LSTM cells take their carry as (c, h) and return ((c, h), output).
        layer_input = x
        hidden_out, cell_out = [], []
        for layer, lstm in enumerate((self.lstm1, self.lstm2, self.lstm3, self.lstm4)):
            (c_layer, h_layer), layer_input = lstm((c[layer], h[layer]), layer_input)
            hidden_out.append(h_layer.squeeze())
            cell_out.append(c_layer.squeeze())

        return jnp.array(hidden_out), jnp.array(cell_out)

In [7]:
class InflowNet(nn.Module):
    """Predicts one inflow rate per source cell from gridded input fields.

    Input: a (26, 49, 7) feature map, optionally with a leading batch axis as
    used at initialisation. A conv/max-pool encoder reduces it to (1, 1, 64).
    An MLP head maps that to 4 values. Outputs are 10**(-5 + tanh(.)), so they
    lie within [1e-6, 1e-4].
    """
    def setup(self):
        glorot = nn.initializers.xavier_uniform()
        self.conv1 = nn.Conv(features=8, kernel_size=(3, 3), strides=1, padding='SAME', kernel_init=glorot)
        self.conv2 = nn.Conv(features=16, kernel_size=(3, 3), strides=1, padding='SAME', kernel_init=glorot)
        self.conv3 = nn.Conv(features=32, kernel_size=(3, 3), strides=1, padding='SAME', kernel_init=glorot)
        self.conv4 = nn.Conv(features=64, kernel_size=(3, 3), strides=1, padding='SAME', kernel_init=glorot)

        self.fc1 = nn.Dense(features=64, kernel_init=glorot)
        self.fc2 = nn.Dense(features=64, kernel_init=glorot)
        self.fc3 = nn.Dense(features=4, kernel_init=glorot)

    def __call__(self, x):
        # Encoder.
        feat = nn.leaky_relu(self.conv1(x), negative_slope=0.01)     # (26,49,8)
        for window, conv in (((2, 3), self.conv2),                   # -> (13,16,16)
                             ((2, 3), self.conv3),                   # -> (6,5,32)
                             ((2, 2), self.conv4)):                  # -> (3,2,64)
            feat = nn.max_pool(feat, window, strides=window)
            feat = nn.leaky_relu(conv(feat), negative_slope=0.01)
        feat = nn.max_pool(feat, (3, 2), strides=(3, 2))             # (1,1,64)

        # Fully connected head.
        z = feat.squeeze()
        for dense in (self.fc1, self.fc2):
            z = nn.leaky_relu(dense(z))
        z = self.fc3(z)
        return 10**(-5+jnp.tanh(z))

## Phase 3

## Simulation Function

In [8]:
# Module objects shared by the simulation step and the initialisation cell.
# Their parameters are passed explicitly to .init / .apply.
hidden_net, closure_net, inflow_net = HiddenNet(), ClosureNet(), InflowNet()

In [9]:
@jax.jit
def conv_diff_single_step(params,
                          smoke_initial: jnp.array,
                          hidden_state: jnp.array,
                          cell_state: jnp.array,
                          velocity: jnp.array,
                          time_curr: float,
                          inflow_loc: jnp.array,
                          inflow_vals: jnp.array,
                          terrain: jnp.array,
                          dt: float):
    """Advance the hybrid physics + learned-closure model by one step of `dt`.

    1. InflowNet predicts one inflow rate per source cell, at (y, x) + REL_LOC[i].
       Its inputs are terrain, velocity, padded diffusivities and a source marker.
    2. Advection, diffusion and inflow are integrated with NSUBSTEPS explicit
       sub-steps of dt / NSUBSTEPS.
    3. ClosureNet adds a correction, scaled by 1e-10. Its inputs are the
       normalised *initial* field, the physics fields, the inflow term and the
       LSTM hidden state. The result is multiplied by `terrain`.
    4. HiddenNet updates the LSTM hidden/cell states from the normalised new field.

    Args:
      params: dict with network params under 'inflow', 'closure', 'hidden' and
        arrays 'diffusivity_x' (one column fewer than the grid) and
        'diffusivity_y' (one row fewer than the grid).
      smoke_initial: smoke field at `time_curr`, same shape as `terrain`.
      hidden_state, cell_state: per-layer LSTM states, (4, 128) in this notebook.
      velocity: velocity components stacked on axis 0; components 0-2 are used.
      time_curr: current time.
      inflow_loc: (y, x) grid index the REL_LOC offsets are relative to.
      inflow_vals: carried through unchanged; not read here.
      terrain: cell mask multiplied into the field (0 zeroes a cell).
      dt: step length.

    Returns:
      ((smoke_pred, hidden_pred, cell_pred, time_next, params, inflow_loc,
        inflow_vals), sum of squared closure term) — a (carry, output) pair in
      the form expected by `jax.lax.scan`.
    """
    eps = 1e-8  # floor for the standard deviations used in normalisation
    grid_shape = jnp.shape(terrain)
    y, x = inflow_loc

    def scatter_at_sources(values):
        # Zero field with values[i] written at cell (y + rel_y, x + rel_x).
        field = jnp.zeros(grid_shape)
        for i in range(len(REL_LOC)):
            rel_y, rel_x = REL_LOC[i]
            field = field.at[y + rel_y, x + rel_x].set(values[i])
        return field

    # Diffusivities padded by one leading column/row so they match the grid.
    diffusivity_x_grid = jnp.pad(params['diffusivity_x'], ((0, 0), (1, 0)))
    diffusivity_y_grid = jnp.pad(params['diffusivity_y'], ((1, 0), (0, 0)))

    # 1. Inflow rates at the source cells.
    inflow_marker = scatter_at_sources([1.0] * len(REL_LOC))
    inflow_features = jnp.stack((terrain,
                                 velocity[0, :, :],
                                 velocity[1, :, :],
                                 velocity[2, :, :],
                                 diffusivity_x_grid,
                                 diffusivity_y_grid,
                                 inflow_marker), axis=-1)  # (ny, nx, 7)
    inflow = inflow_net.apply({'params': params['inflow']}, inflow_features)
    inflow_term = scatter_at_sources(inflow)

    # 2. Explicit physics sub-steps.
    smoke_pred = smoke_initial
    for _ in range(NSUBSTEPS):
        advection_step = physics.advect_fvm(field=smoke_pred,
                                            velocity=velocity,
                                            dx=dx,
                                            dy=dy) * dt/NSUBSTEPS

        diffusion_step = physics.diffuse_2d_fvm(field=smoke_pred,
                                                diffusivity_x=params['diffusivity_x'],
                                                diffusivity_y=params['diffusivity_y'],
                                                dx=dx,
                                                dy=dy) * dt/NSUBSTEPS

        inflow_step = inflow_term * dt/NSUBSTEPS
        smoke_pred = smoke_pred + advection_step + diffusion_step + inflow_step

    # 3. Learned closure, computed from the field at the start of the step.
    smoke_mean = jnp.mean(smoke_initial)
    smoke_std = jnp.std(smoke_initial)
    smoke_std_safe = jnp.where(smoke_std > eps, smoke_std, eps)
    closure_features = jnp.stack(((smoke_initial - smoke_mean) / smoke_std_safe,
                                  terrain * smoke_mean,
                                  terrain * smoke_std,
                                  velocity[0, :, :],
                                  velocity[1, :, :],
                                  velocity[2, :, :],
                                  diffusivity_x_grid,
                                  diffusivity_y_grid,
                                  inflow_term,
                                  terrain), axis=-1)  # (ny, nx, 10)
    output = closure_net.apply({'params': params['closure']}, closure_features, hidden_state)  # (ny, nx, 1)
    closure_term = output * 1e-10        # a denormalization value, from previous tests
    # Note: the field is not clipped to be non-negative.
    smoke_pred = (smoke_pred + closure_term.squeeze()) * terrain

    time_next = time_curr + dt

    # 4. Hidden-state update from the new field.
    pred_mean = jnp.mean(smoke_pred)
    pred_std = jnp.std(smoke_pred)
    # Same floor as for the closure input: a uniform field (std == 0) would
    # otherwise give NaNs that are carried forward in the LSTM state.
    pred_std_safe = jnp.where(pred_std > eps, pred_std, eps)
    hidden_features = jnp.stack(((smoke_pred - pred_mean) / pred_std_safe,
                                 terrain * pred_mean,
                                 terrain * pred_std), axis=-1)  # (ny, nx, 3)
    hidden_pred, cell_pred = hidden_net.apply({'params': params['hidden']},
                                              hidden_features, hidden_state, cell_state)

    return (smoke_pred, hidden_pred, cell_pred, time_next, params, inflow_loc, inflow_vals), jnp.sum(closure_term**2)

def step_for_loop(carry, x, _step_fn=conv_diff_single_step):
    """`jax.lax.scan` body used by `conv_diff_nsteps`: advance the model one step.

    carry: (smoke, hidden, cell, time, params, inflow_loc, inflow_vals). The
    returned carry has the same layout. `x` is the per-iteration scan input;
    it is unused because the scan runs with xs=None. Velocity, terrain and dt
    come from the module-level VELOCITY, TERRAIN and dt.

    `_step_fn` is bound when this cell runs. Previously the global name
    `conv_diff_single_step` was looked up at trace time. A later cell in this
    notebook redefines that name with a longer carry, which would then be
    picked up here and break the scan.
    """
    smoke, hidden, cell, time_curr, params, inflow_loc, inflow_vals = carry
    return _step_fn(params=params,
                    smoke_initial=smoke,
                    hidden_state=hidden,
                    cell_state=cell,
                    velocity=VELOCITY,
                    time_curr=time_curr,
                    inflow_loc=inflow_loc,
                    inflow_vals=inflow_vals,
                    terrain=TERRAIN,
                    dt=dt)

@partial(jax.jit, static_argnames=['nsteps'])
def conv_diff_nsteps(params,
                     smoke_initial: jnp.array,
                     hidden_initial: jnp.array,
                     cell_initial: jnp.array,
                     time_initial,
                     inflow_loc,
                     inflow_vals,
                     nsteps):
    """Roll the model forward `nsteps` steps with `jax.lax.scan`.

    `nsteps` is static: each distinct value triggers a separate compilation.
    Returns the final (smoke, hidden, cell) states. The per-step outputs of the
    scan body are discarded.
    """
    initial_carry = (smoke_initial, hidden_initial, cell_initial,
                     time_initial, params, inflow_loc, inflow_vals)
    final_carry, _per_step = jax.lax.scan(step_for_loop, initial_carry,
                                          xs=None, length=nsteps)
    smoke_final, hidden_final, cell_final = final_carry[0], final_carry[1], final_carry[2]
    return smoke_final, hidden_final, cell_final

In [10]:
def create_train_state(params_nn, params_physics, learning_rate):
    """Build a Flax TrainState over all trainable parameters with an Adam optimiser.

    The parameter dict merges the network trees 'closure', 'hidden' and
    'inflow' from `params_nn` with 'diffusivity_x' and 'diffusivity_y' from
    `params_physics`. Applying gradients through this state therefore updates
    the diffusivities together with the network weights.
    `apply_fn` is None: the networks are applied explicitly via their module
    objects in the simulation step.
    """
    network_keys = ('closure', 'hidden', 'inflow')
    physics_keys = ('diffusivity_x', 'diffusivity_y')

    params = {key: params_nn[key] for key in network_keys}
    params.update((key, params_physics[key]) for key in physics_keys)

    return train_state.TrainState.create(apply_fn=None,
                                         params=params,
                                         tx=optax.adam(learning_rate))

In [11]:
# Physics parameters ('diffusivity_x' / 'diffusivity_y') saved by an earlier stage.
# NOTE: pickle.load can execute arbitrary code — only load trusted files.
with open(f'{dir}/dataset/diffusivity.pickle', 'rb') as handle:
    params_physics = pickle.load(handle)

# Initialise each network with dummy inputs of the shapes used in the
# simulation step (10 closure channels, 3 hidden-net channels, 7 inflow
# channels, four 128-wide LSTM states).
init_key = jax.random.PRNGKey(0)
zero_states = [jnp.zeros((128,)) for _ in range(4)]
params_nn = {
    'closure': closure_net.init(init_key, jnp.zeros((ny, nx, 10)), zero_states)['params'],
    'hidden': hidden_net.init(init_key, jnp.zeros((ny, nx, 3)), zero_states, zero_states)['params'],
    'inflow': inflow_net.init(init_key, jnp.zeros((1, ny, nx, 7)))['params'],
}

state = create_train_state(params_nn=params_nn,
                           params_physics=params_physics,
                           learning_rate=1e-4)

  params_physics = pickle.load(handle)


________


In [None]:
@jax.jit
def conv_diff_single_step(params,
                          smoke_initial: jnp.array,
                          hidden_state: jnp.array,
                          cell_state: jnp.array,
                          velocity: jnp.array,
                          time_curr: float,
                          inflow_loc: jnp.array,
                          inflow_vals: jnp.array,
                          terrain: jnp.array,
                          dt: float):
    """One model step that also returns each contribution, for analysis.

    NOTE(review): this redefines the training-time `conv_diff_single_step`
    under the same name, with the same inputs but a longer carry. Any code
    that looks the name up later sees this version.

    The step is:
      1. InflowNet predicts per-source-cell inflow rates.
      2. NSUBSTEPS explicit advection + diffusion + inflow sub-steps.
      3. The ClosureNet correction (scaled by 1e-10), masked by `terrain`.
      4. The HiddenNet LSTM state update.

    Args are the same as in the training version: `params` holds the network
    params ('inflow', 'closure', 'hidden') and 'diffusivity_x' /
    'diffusivity_y'. `inflow_vals` is carried through unchanged and not read.

    Returns:
      ((smoke_pred, hidden_pred, cell_pred, time_next, params, inflow_loc,
        inflow_vals, advection_term, diffusion_term, closure_term, inflow_term),
       None)
      - advection_term / diffusion_term: sum of the sub-step increments over dt.
      - closure_term: unsqueezed closure correction (ny, nx, 1).
      - inflow_term: inflow-rate field (rates at the source cells, 0 elsewhere).
    """
    eps = 1e-8  # floor for the standard deviations used in normalisation
    grid_shape = jnp.shape(terrain)
    y, x = inflow_loc

    def scatter_at_sources(values):
        # Zero field with values[i] written at cell (y + rel_y, x + rel_x).
        field = jnp.zeros(grid_shape)
        for i in range(len(REL_LOC)):
            rel_y, rel_x = REL_LOC[i]
            field = field.at[y + rel_y, x + rel_x].set(values[i])
        return field

    # Diffusivities padded by one leading column/row so they match the grid.
    diffusivity_x_grid = jnp.pad(params['diffusivity_x'], ((0, 0), (1, 0)))
    diffusivity_y_grid = jnp.pad(params['diffusivity_y'], ((1, 0), (0, 0)))

    # 1. Inflow rates at the source cells.
    inflow_marker = scatter_at_sources([1.0] * len(REL_LOC))
    inflow_features = jnp.stack((terrain,
                                 velocity[0, :, :],
                                 velocity[1, :, :],
                                 velocity[2, :, :],
                                 diffusivity_x_grid,
                                 diffusivity_y_grid,
                                 inflow_marker), axis=-1)  # (ny, nx, 7)
    inflow = inflow_net.apply({'params': params['inflow']}, inflow_features)
    inflow_term = scatter_at_sources(inflow)

    # 2. Explicit physics sub-steps, accumulating each contribution.
    smoke_pred = smoke_initial
    advection_term = jnp.zeros(grid_shape)
    diffusion_term = jnp.zeros(grid_shape)

    for _ in range(NSUBSTEPS):
        advection_step = physics.advect_fvm(field=smoke_pred,
                                            velocity=velocity,
                                            dx=dx,
                                            dy=dy) * dt/NSUBSTEPS

        diffusion_step = physics.diffuse_2d_fvm(field=smoke_pred,
                                                diffusivity_x=params['diffusivity_x'],
                                                diffusivity_y=params['diffusivity_y'],
                                                dx=dx,
                                                dy=dy) * dt/NSUBSTEPS

        inflow_step = inflow_term * dt/NSUBSTEPS
        smoke_pred = smoke_pred + advection_step + diffusion_step + inflow_step
        advection_term = advection_term + advection_step
        diffusion_term = diffusion_term + diffusion_step

    # 3. Learned closure, computed from the field at the start of the step.
    smoke_mean = jnp.mean(smoke_initial)
    smoke_std = jnp.std(smoke_initial)
    smoke_std_safe = jnp.where(smoke_std > eps, smoke_std, eps)
    closure_features = jnp.stack(((smoke_initial - smoke_mean) / smoke_std_safe,
                                  terrain * smoke_mean,
                                  terrain * smoke_std,
                                  velocity[0, :, :],
                                  velocity[1, :, :],
                                  velocity[2, :, :],
                                  diffusivity_x_grid,
                                  diffusivity_y_grid,
                                  inflow_term,
                                  terrain), axis=-1)  # (ny, nx, 10)
    output = closure_net.apply({'params': params['closure']}, closure_features, hidden_state)  # (ny, nx, 1)
    closure_term = output * 1e-10        # a denormalization value, from previous tests
    # Note: the field is not clipped to be non-negative.
    smoke_pred = (smoke_pred + closure_term.squeeze()) * terrain

    time_next = time_curr + dt

    # 4. Hidden-state update from the new field.
    pred_mean = jnp.mean(smoke_pred)
    pred_std = jnp.std(smoke_pred)
    # Same floor as for the closure input: a uniform field (std == 0) would
    # otherwise give NaNs that are carried forward in the LSTM state.
    pred_std_safe = jnp.where(pred_std > eps, pred_std, eps)
    hidden_features = jnp.stack(((smoke_pred - pred_mean) / pred_std_safe,
                                 terrain * pred_mean,
                                 terrain * pred_std), axis=-1)  # (ny, nx, 3)
    hidden_pred, cell_pred = hidden_net.apply({'params': params['hidden']},
                                              hidden_features, hidden_state, cell_state)

    return (smoke_pred, hidden_pred, cell_pred, time_next, params, inflow_loc, inflow_vals,
            advection_term, diffusion_term, closure_term, inflow_term), None

In [None]:
# Restore the trained TrainState from a saved state dict, using the freshly
# initialised `state` as the structural template.
# NOTE: pickle.load can execute arbitrary code — only load trusted checkpoints.
checkpoint_path = f'{dir}/trained_models/expt1_baseline_rep1.pickle'
with open(checkpoint_path, 'rb') as checkpoint_file:
    state_dict = pickle.load(checkpoint_file)
state = flax.serialization.from_state_dict(state, state_dict)

In [None]:
# Roll the trained model out over every scenario. At every dt step, record the
# field, each contribution to the update, the relative error and the hidden
# state, then pickle each series per scenario.
VIDEO_PREFIX = f'{dir}/videos/22072025'

for set_idx in range(len(SMOKE_FIELD)):

  smoke_terms = []
  advection_terms = []
  diffusion_terms = []
  closure_terms = []
  error_terms = []      # absolute per-cell L2 error; kept in memory only (not pickled)
  rel_error_terms = []
  inflow_terms = []
  hidden_terms = []

  smoke_curr = SMOKE_FIELD[set_idx][0]
  time_curr = FLOW_TIMES[set_idx][0]
  hidden_curr, cell_curr = jnp.zeros((4,128)), jnp.zeros((4,128))

  for times_idx in range(len(FLOW_TIMES[set_idx])-1):
    # Whole dt steps until the next observation; the 1e-3 keeps int() from
    # truncating a quotient that lands just below an integer.
    nsteps = int((FLOW_TIMES[set_idx][times_idx+1]-FLOW_TIMES[set_idx][times_idx]+1e-3)/dt)
    # Every step in this interval is compared against the next observation.
    target = SMOKE_FIELD[set_idx][times_idx+1]

    for _step in range(nsteps):
      step_carry, _ = conv_diff_single_step(params=state.params,
                                            smoke_initial=smoke_curr,
                                            hidden_state=hidden_curr,
                                            cell_state=cell_curr,
                                            velocity=VELOCITY,
                                            time_curr=time_curr,
                                            inflow_loc=INFLOW_LOCS[set_idx],
                                            inflow_vals=INFLOW_VALS[set_idx],
                                            terrain=TERRAIN,
                                            dt=dt)
      (smoke_curr, hidden_curr, cell_curr, time_curr, _, _, _,
       advection_term, diffusion_term, closure_term, inflow_term) = step_carry

      smoke_terms.append(smoke_curr)
      advection_terms.append(advection_term)
      diffusion_terms.append(diffusion_term)
      closure_terms.append(closure_term)
      inflow_terms.append(inflow_term)
      hidden_terms.append(hidden_curr)

      step_error = optax.l2_loss(smoke_curr, target)
      error_terms.append(step_error)
      # Relative squared error; 1e-9 avoids division by zero where target == 0.
      rel_error_terms.append(step_error/(target+1e-9)**2)

  # Same file names and write order as before:
  # <prefix>_<inflow loc>_<suffix>.pickle
  outputs = {
      'smoke': smoke_terms,
      'advection': advection_terms,
      'diffusion': diffusion_terms,
      'closure': closure_terms,
      'inflow': inflow_terms,
      'relerror': rel_error_terms,
      'hidden': hidden_terms,
  }
  for suffix, series in outputs.items():
    with open(f'{VIDEO_PREFIX}_{INFLOW_LOCS[set_idx]}_{suffix}.pickle', 'wb') as f:
      pickle.dump(series, f)

---