# Recurrent Closure

$$\begin{align}
c^{t+1} &= c^t + A(c^t) + D(c^t) + R(c^t) + F_1 (c^t, A(c^t), D(c^t), T, h^t)\\
h^{t+1} &= F_2 (c^{t+1}, h^t)
\end{align} $$

In [1]:
import numpy as np
import random
import matplotlib.pyplot as plt

import jax
import optax
import jaxopt
from flax import linen as nn
from jax import numpy as jnp
from flax.training import train_state  # Useful dataclass to keep train state
import flax

from functools import partial
import pickle

from numerical_methods import physics

In [2]:
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

In [3]:
dx, dy = 0.5, 0.5
ny, nx = 26, 49
dt=0.2

In [4]:
TIMES = jnp.load("dataset/times.npy")

TERRAIN = jnp.load("dataset/terrain.npy")

INFLOW_LOCS = jnp.load("dataset/inflow_locations.npy")

rel_loc = [(0, -1), (0, 0), (1, -1), (1, 0)]
INFLOW_COEFFS = jnp.load("dataset/inflow_poly_coeffs.npy")

TRAIN_SET = jnp.load("dataset/train_smoke_grid.npy")

VELOCITY = jnp.load("dataset/velocity_2d.npy")


In [5]:
with open('params_physics.pickle', 'rb') as handle:
    params_physics = pickle.load(handle)

## Neural Networks

In [6]:
class UNet(nn.Module):
    features: int = 16

    def setup(self):
        self.encoder1 = self._conv_block(self.features)
        self.encoder2 = self._conv_block(self.features * 2)
        self.encoder3 = self._conv_block(self.features * 4)
        self.encoder4 = self._conv_block(self.features * 8)
        
        self.bottleneck = self._conv_block(self.features * 16)
        
        self.upconv4 = self._upconv_block(self.features * 8)
        self.upconv3 = self._upconv_block(self.features * 4)
        self.upconv2 = self._upconv_block(self.features * 2)
        self.upconv1 = self._upconv_block(self.features)

        self.final_conv = nn.Conv(
            features=1, kernel_size=(1, 1), kernel_init=nn.initializers.xavier_uniform(), use_bias=True
        )

    def _conv_block(self, features):
        return nn.Sequential([
            nn.Conv(features, kernel_size=(3, 3), padding='SAME', kernel_init=nn.initializers.xavier_uniform()),
            nn.relu,
            nn.Conv(features, kernel_size=(3, 3), padding='SAME', kernel_init=nn.initializers.xavier_uniform()),
            nn.relu
        ])
    
    def _upconv_block(self, features):
        return nn.Sequential([
            nn.ConvTranspose(features, kernel_size=(2, 2), strides=(2, 2), padding='VALID', kernel_init=nn.initializers.xavier_uniform()),
            nn.relu
        ])

    def __call__(self, x):
        # Encoder
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(nn.max_pool(enc1, window_shape=(2, 3), strides=(2, 3)))
        enc3 = self.encoder3(nn.max_pool(enc2, window_shape=(2, 3), strides=(2, 3)))
        enc4 = self.encoder4(nn.max_pool(enc3, window_shape=(2, 2), strides=(2, 2)))
        
        # Bottleneck
        bottleneck = self.bottleneck(nn.max_pool(enc4, window_shape=(2, 2), strides=(2, 2)))
        # Decoder
        dec4 = self.upconv4(bottleneck)
        dec4 = jax.image.resize(dec4, shape=(1, enc4.shape[1], enc4.shape[2], dec4.shape[-1]), method='bilinear')
        dec4 = jnp.concatenate([dec4, enc4], axis=-1)

        dec3 = self.upconv3(dec4)
        dec3 = jax.image.resize(dec3, shape=(1, enc3.shape[1], enc3.shape[2], dec3.shape[-1]), method='bilinear')
        dec3 = jnp.concatenate([dec3, enc3], axis=-1)

        dec2 = self.upconv2(dec3)
        dec2 = jax.image.resize(dec2, shape=(1, enc2.shape[1], enc2.shape[2], dec2.shape[-1]), method='bilinear')
        dec2 = jnp.concatenate([dec2, enc2], axis=-1)

        dec1 = self.upconv1(dec2)
        dec1 = jax.image.resize(dec1, shape=(1, enc1.shape[1], enc1.shape[2], dec1.shape[-1]), method='bilinear')
        dec1 = jnp.concatenate([dec1, enc1], axis=-1)

        output = self.final_conv(dec1)

        return output

In [7]:
hidden_net = UNet(features=8)
closure_net = UNet(features=8)

## Simulation Function

In [8]:
@jax.jit
def conv_diff_single_step(params_rcnn,
                          params_physics,
                          smoke_initial: jnp.array,
                          hidden_state: jnp.array,
                          velocity: jnp.array,
                          time_curr: float,
                          inflow_loc: jnp.array,
                          inflow_poly_coeffs: jnp.array,
                          terrain: jnp.array,
                          dt: float):
    
    # set smoke at inflow locations
    y,x = inflow_loc
    for i in range(len(rel_loc)):
        rel_y, rel_x = rel_loc[i]
        c = jnp.polyval(inflow_poly_coeffs[i], time_curr)
        smoke_initial = smoke_initial.at[y+rel_y, x+rel_x].set(c)
    
    advection_term = physics.advect_corr_fvm(field=smoke_initial, 
                                             velocity=velocity,
                                             u_corr=params_physics['u_corr'],
                                             v_corr=params_physics['v_corr'], 
                                             dx=dx, 
                                             dy=dy) * dt
    
    diffusion_term = physics.diffuse_2d_fvm(field=smoke_initial, 
                                            diffusivity_x=params_physics['diffusivity_x'],
                                            diffusivity_y=params_physics['diffusivity_y'],
                                            dx=dx, 
                                            dy=dy) * dt
    
    
    # compute closure term
    input = jnp.stack(((smoke_initial-jnp.mean(smoke_initial))/jnp.std(smoke_initial),
                        (advection_term-jnp.mean(advection_term))/jnp.std(advection_term),
                        (diffusion_term-jnp.mean(diffusion_term))/jnp.std(diffusion_term),
                        TERRAIN,
                        hidden_state), axis=-1)
    input = jnp.expand_dims(input, 0)
    
    output = closure_net.apply({'params': params_rcnn['closure']}, input)
    closure_term = output * 1e-9        # a denormalization value, from previous tests

    smoke_pred = smoke_initial\
                    + advection_term\
                    + diffusion_term\
                    + closure_term.squeeze()
    
    time_next = time_curr + dt
    for i in range(len(rel_loc)):
        rel_y, rel_x = rel_loc[i]
        c = jnp.polyval(inflow_poly_coeffs[i], time_next)
        smoke_pred = smoke_pred.at[y+rel_y, x+rel_x].set(c)
    
    smoke_pred = smoke_pred * terrain
    smoke_pred = jnp.maximum(smoke_pred, 0.0)

    # update hidden state
    mean = jnp.mean(smoke_pred)
    std = jnp.std(smoke_pred)
    input = jnp.stack(((smoke_pred-mean)/std, hidden_state), axis=-1)
    input = jnp.expand_dims(input, 0)
    output = hidden_net.apply({'params': params_rcnn['hidden']}, input)

    hidden_pred = output.squeeze()

    return (smoke_pred, hidden_pred, time_next, params_rcnn, inflow_loc, inflow_poly_coeffs), None


step_for_loop = lambda carry, x: conv_diff_single_step(params_rcnn=carry[3],
                                                       params_physics=params_physics,
                                                        smoke_initial=carry[0],
                                                        hidden_state=carry[1],
                                                        velocity=VELOCITY,
                                                        time_curr=carry[2],
                                                        inflow_loc=carry[4],
                                                        inflow_poly_coeffs=carry[5],
                                                        terrain=TERRAIN,
                                                        dt=dt)

@partial(jax.jit, static_argnames=['nsteps'])
def conv_diff_nsteps(params, 
                     smoke_initial: jnp.array, 
                     hidden_initial: jnp.array, 
                     time_initial, 
                     inflow_loc, 
                     inflow_poly_coeffs, 
                     nsteps):
    (smoke_pred, hidden_pred, _, params, _, _), _ = jax.lax.scan(step_for_loop, (smoke_initial, hidden_initial, time_initial, params, inflow_loc, inflow_poly_coeffs), xs=None, length=nsteps)
    return smoke_pred, hidden_pred

## Training

In [9]:
def create_train_state(learning_rate):
    params = {}
    params['closure'] = closure_net.init(jax.random.PRNGKey(0), jnp.zeros((1, ny, nx, 5)))['params']
    params['hidden'] = hidden_net.init(jax.random.PRNGKey(0), jnp.zeros((1, ny, nx, 2)))['params']
  
    tx = optax.adam(learning_rate)
    return train_state.TrainState.create(apply_fn=conv_diff_nsteps,
                                         params=params,
                                         tx=tx)

# @partial(jax.jit, static_argnames=['nsteps'])
def train_step(state, 
                smoke_initial,
                hidden_initial,
                time_initial,
                inflow_loc,
                inflow_poly_coeffs,
                nsteps, 
                smoke_target):
  
    def loss_fn(params):
        smoke_pred, hidden_pred = state.apply_fn(params, 
                                         smoke_initial=smoke_initial, 
                                         time_initial=time_initial, 
                                         inflow_loc=inflow_loc, 
                                         inflow_poly_coeffs=inflow_poly_coeffs,
                                         nsteps=nsteps)
        
        # L2 loss
        # loss = jnp.sum(optax.l2_loss(smoke_predicted, smoke_target))
        # cosine loss
        # loss = 1 - jnp.sum(smoke_predicted * smoke_target)/jnp.sqrt(jnp.sum(smoke_predicted**2)*jnp.sum(smoke_target**2))
        # relative loss
        denom = smoke_target
        denom = jnp.where(denom<1e-8, 1.0, denom)
        loss = jnp.mean(optax.l2_loss(smoke_pred, smoke_target)/denom**2)
        return loss, (smoke_pred, hidden_pred)

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, (smoke_pred, hidden_pred)), grads = grad_fn(state.params)

    ## gradient clipping
    clipper = optax.clip_by_global_norm(1.)
    clip_state = clipper.init(grads)
    new_grads, clip_state = clipper.update(grads, clip_state)

    state = state.apply_gradients(grads=new_grads)
    return state, loss, smoke_pred, hidden_pred

In [10]:
####################
#Using LR scheduler#
####################

max_epoch = 1000
itr_per_epoch = len(TIMES)-1
max_iter = itr_per_epoch * max_epoch

# optimizer: ADAM learning rate scheduler
init_lr = 1e-4
peak_lr = 1e-3
lr_scheduler = optax.warmup_cosine_decay_schedule(init_value=init_lr, 
                                                  peak_value=peak_lr, 
                                                  warmup_steps=int(max_iter*.5),
                                                  decay_steps=max_iter, 
                                                  end_value=1e-5)
####################
#Using LR scheduler#
####################

state = create_train_state(learning_rate=1e-4)

In [11]:
EPOCHS = 1000

losses = []
min_loss = float('inf')

order = list(range(4))
np.random.seed(0)

for epoch in range(EPOCHS):
    random.shuffle(order)
    
    epoch_loss = 0.0
    for set_idx in order:
        hidden_curr = jnp.zeros((ny, nx))
        smoke_curr = TRAIN_SET[set_idx][0]
        for times_idx in range(len(TIMES)-1):
            nsteps = int((TIMES[times_idx+1]-TIMES[times_idx]+1e-3)/dt)
            state, loss, smoke_curr, hidden_curr = train_step(state, 
                                                smoke_initial=smoke_curr,
                                                hidden_initial=hidden_curr,
                                                time_initial=TIMES[times_idx],
                                                inflow_loc=INFLOW_LOCS[set_idx],
                                                inflow_poly_coeffs=INFLOW_COEFFS[set_idx],
                                                nsteps=nsteps, 
                                                smoke_target=TRAIN_SET[set_idx][times_idx+1])
        
            print(f"EPOCH {epoch}, starting from {set_idx, times_idx} loss={loss}")
            epoch_loss += loss
    
    if epoch_loss < min_loss:
        state_dict = flax.serialization.to_state_dict(state)
        with open('v1_1_phase2_rcnn_min.pickle', 'wb') as handle:
            pickle.dump(state_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
        np.save("v1_1_phase2_rcnn_min", losses)
        min_loss = epoch_loss
        print("New minimum achieved")
        
    losses.append(epoch_loss)
    print(f"EPOCH {epoch}: loss={epoch_loss}")

EPOCH 0, starting from (2, 0) loss=-11.32186878405473
EPOCH 0, starting from (2, 1) loss=-11.161923180128047
EPOCH 0, starting from (2, 2) loss=-11.516220463119117
EPOCH 0, starting from (2, 3) loss=-12.27718618835444
EPOCH 0, starting from (2, 4) loss=-10.53818932991176
EPOCH 0, starting from (2, 5) loss=-10.121893576600588
EPOCH 0, starting from (2, 6) loss=-9.893944107505618
EPOCH 0, starting from (2, 7) loss=-9.69031223047551
EPOCH 0, starting from (2, 8) loss=-9.528970373543684
EPOCH 0, starting from (2, 9) loss=-9.399112913583167


2024-07-28 17:33:06.887728: W external/xla/xla/service/hlo_rematerialization.cc:3005] Can't reduce memory use below 2.46GiB (2643239374 bytes) by rematerialization; only reduced to 3.79GiB (4067553838 bytes), down from 3.79GiB (4067967480 bytes) originally


EPOCH 0, starting from (2, 10) loss=-8.87038191721049
EPOCH 0, starting from (2, 11) loss=-8.550583215285243
EPOCH 0, starting from (2, 12) loss=-8.320788084347113
EPOCH 0, starting from (2, 13) loss=-8.144486094850928
EPOCH 0, starting from (2, 14) loss=-8.000772638151386
EPOCH 0, starting from (2, 15) loss=-7.881430339233438
EPOCH 0, starting from (2, 16) loss=-7.7787704982815775
EPOCH 0, starting from (2, 17) loss=-7.690126992593431
EPOCH 0, starting from (2, 18) loss=-7.611505573733784
EPOCH 0, starting from (2, 19) loss=-7.542030192845465
EPOCH 0, starting from (2, 20) loss=-7.479177842282074
EPOCH 0, starting from (2, 21) loss=-7.42277978083916
EPOCH 0, starting from (2, 22) loss=-7.371052314110673
EPOCH 0, starting from (0, 0) loss=-11.129229682012825
EPOCH 0, starting from (0, 1) loss=-11.015823407432858
EPOCH 0, starting from (0, 2) loss=-11.325766276869981
EPOCH 0, starting from (0, 3) loss=-11.484819086277207
EPOCH 0, starting from (0, 4) loss=-10.060646303256274
EPOCH 0, st

Exception ignored in: <bound method IPythonKernel._clean_thread_parent_frames of <ipykernel.ipkernel.IPythonKernel object at 0x7fcf04035dc0>>
Traceback (most recent call last):
  File "/home/tingkai/miniconda3/lib/python3.12/site-packages/ipykernel/ipkernel.py", line 775, in _clean_thread_parent_frames
    def _clean_thread_parent_frames(

KeyboardInterrupt: 


EPOCH 8, starting from (3, 11) loss=-8.10820530004748
EPOCH 8, starting from (3, 12) loss=-7.985076678494951
EPOCH 8, starting from (3, 13) loss=-7.886866765629875
EPOCH 8, starting from (3, 14) loss=-7.808137128475026
EPOCH 8, starting from (3, 15) loss=-7.742165956425141
EPOCH 8, starting from (3, 16) loss=-7.684539619138645
EPOCH 8, starting from (3, 17) loss=-7.636270684163582
