In [None]:
from google.colab import drive

# Mount Google Drive so the dataset, checkpoints and project modules under
# /content/drive/MyDrive/ARIA are reachable from this runtime.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

ValueError: mount failed

In [None]:
import sys

# Make the project's helper modules (e.g. `numerical_methods`) importable.
# Guarded so re-running this cell does not keep appending duplicate entries to sys.path.
ARIA_ROOT = '/content/drive/MyDrive/ARIA'
if ARIA_ROOT not in sys.path:
    sys.path.append(ARIA_ROOT)

# Combined Training

In [None]:
import numpy as np
import random
import matplotlib.pyplot as plt

import jax
import optax
from flax import linen as nn
from jax import numpy as jnp
from flax.training import train_state  # Useful dataclass to keep train state
import flax

from functools import partial
import pickle

from numerical_methods import physics

In [None]:
import os

# Ask XLA to allocate GPU memory on demand instead of preallocating a large pool.
# NOTE(review): JAX reads this when its backend is first initialised, so it only takes
# effect if no JAX computation has happened yet (jax is imported in an earlier cell);
# consider moving this above the imports to be safe.
os.environ.update(XLA_PYTHON_CLIENT_PREALLOCATE="false")

In [None]:
# Cell spacing, grid size (ny rows indexed by y, nx columns indexed by x) and the
# simulator time step used by every rollout.
dx = dy = 0.5
ny, nx = 26, 49
dt = 0.1

In [None]:
DATASET_DIR = "/content/drive/MyDrive/ARIA/dataset"


def _load_dataset_array(filename):
    """Load ``DATASET_DIR/filename`` (a ``.npy`` file) as a JAX array."""
    return jnp.load(f"{DATASET_DIR}/{filename}")


# Times of the ground-truth frames; training steps from TIMES[k] to TIMES[k+1].
TIMES = _load_dataset_array("times.npy")

# Field multiplied into every predicted frame and fed to the closure network.
TERRAIN = _load_dataset_array("terrain.npy")

# Per dataset, the (y, x) anchor of the inflow cells.
INFLOW_LOCS = _load_dataset_array("inflow_locations.npy")

# Offsets (dy, dx) from the inflow anchor of the cells whose concentration is prescribed.
rel_loc = [(0, -1), (0, 0), (1, -1), (1, 0)]
# Per dataset, one polynomial in time (for jnp.polyval) per offset in rel_loc.
INFLOW_COEFFS = _load_dataset_array("inflow_poly_coeffs.npy")

# Ground-truth smoke frames, indexed [dataset, time index, y, x].
TRAIN_SET = _load_dataset_array("train_smoke_grid.npy")

VELOCITY = _load_dataset_array("velocity_2d.npy")

# Fail fast if the data disagree with what the simulator requires: frames and terrain are
# stacked with a (ny, nx) hidden state, each dataset needs one inflow polynomial per
# offset in rel_loc, and each inflow anchor unpacks into (y, x).
assert TERRAIN.shape == (ny, nx), TERRAIN.shape
assert TRAIN_SET.shape[2:] == (ny, nx), TRAIN_SET.shape
assert TRAIN_SET.shape[1] >= len(TIMES), (TRAIN_SET.shape, TIMES.shape)
assert INFLOW_COEFFS.shape[1] >= len(rel_loc), INFLOW_COEFFS.shape
assert INFLOW_LOCS.shape[1] == 2, INFLOW_LOCS.shape


In [None]:
PHASE2_CKPT_DIR = '/content/drive/MyDrive/ARIA/ckpt/phase_2/unet'


def _unpickle(path):
    """Return the object pickled at ``path``.

    NOTE: pickle can execute arbitrary code on load; only use it on checkpoints you created.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Phase-2 checkpoints: physics parameters (diffusivities, velocity corrections) and the
# 'closure' / 'hidden' network parameters.
params_physics = _unpickle(PHASE2_CKPT_DIR + '/params_physics.pickle')
params_rcnn = _unpickle(PHASE2_CKPT_DIR + '/params_rcnn.pickle')

## Neural Networks

In [None]:
class UNet(nn.Module):
    """Four-level 2-D U-Net mapping an NHWC input to a single-channel NHWC output.

    The output has the same spatial size as the input. Each decoder stage upsamples with
    a stride-2 transposed convolution, then bilinearly resizes to the matching encoder
    stage's spatial size before concatenating the skip connection. This lets grids whose
    sizes are not divisible by the pooling windows still line up.

    Attributes:
        features: channel width of the first encoder stage; deeper stages use 2x, 4x, 8x
            and (bottleneck) 16x this value.
    """
    features: int = 16

    def setup(self):
        # NOTE: these attribute names and the Sequential layouts built below define the
        # parameter tree; keep them unchanged so previously pickled parameters still load.
        self.encoder1 = self._conv_block(self.features)
        self.encoder2 = self._conv_block(self.features * 2)
        self.encoder3 = self._conv_block(self.features * 4)
        self.encoder4 = self._conv_block(self.features * 8)

        self.bottleneck = self._conv_block(self.features * 16)

        self.upconv4 = self._upconv_block(self.features * 8)
        self.upconv3 = self._upconv_block(self.features * 4)
        self.upconv2 = self._upconv_block(self.features * 2)
        self.upconv1 = self._upconv_block(self.features)

        self.final_conv = nn.Conv(
            features=1, kernel_size=(1, 1), kernel_init=nn.initializers.xavier_uniform(), use_bias=True
        )

    def _conv_block(self, features):
        """Two 3x3 SAME convolutions, each followed by ReLU."""
        return nn.Sequential([
            nn.Conv(features, kernel_size=(3, 3), padding='SAME', kernel_init=nn.initializers.xavier_uniform()),
            nn.relu,
            nn.Conv(features, kernel_size=(3, 3), padding='SAME', kernel_init=nn.initializers.xavier_uniform()),
            nn.relu
        ])

    def _upconv_block(self, features):
        """A 2x2 stride-2 transposed convolution followed by ReLU."""
        return nn.Sequential([
            nn.ConvTranspose(features, kernel_size=(2, 2), strides=(2, 2), padding='VALID', kernel_init=nn.initializers.xavier_uniform()),
            nn.relu
        ])

    def _up_and_merge(self, upconv, x, skip):
        """Upsample ``x`` with ``upconv``, resize it to ``skip``'s H x W and concatenate channels.

        The batch size is read from the tensor rather than fixed, so batches larger than
        one are supported.
        """
        up = upconv(x)
        up = jax.image.resize(up,
                              shape=(up.shape[0], skip.shape[1], skip.shape[2], up.shape[-1]),
                              method='bilinear')
        return jnp.concatenate([up, skip], axis=-1)

    def __call__(self, x):
        """Apply the network to ``x`` of shape (batch, H, W, C); returns (batch, H, W, 1)."""
        # Encoder: the first two poolings shrink H by 2 and W by 3, the last two halve both.
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(nn.max_pool(enc1, window_shape=(2, 3), strides=(2, 3)))
        enc3 = self.encoder3(nn.max_pool(enc2, window_shape=(2, 3), strides=(2, 3)))
        enc4 = self.encoder4(nn.max_pool(enc3, window_shape=(2, 2), strides=(2, 2)))

        # Bottleneck
        bottleneck = self.bottleneck(nn.max_pool(enc4, window_shape=(2, 2), strides=(2, 2)))

        # Decoder with skip connections back to each encoder stage.
        dec4 = self._up_and_merge(self.upconv4, bottleneck, enc4)
        dec3 = self._up_and_merge(self.upconv3, dec4, enc3)
        dec2 = self._up_and_merge(self.upconv2, dec3, enc2)
        dec1 = self._up_and_merge(self.upconv1, dec2, enc1)

        return self.final_conv(dec1)

In [None]:
# hidden_net produces the next hidden state from (smoke, hidden); closure_net produces the
# learned correction added to each physics step. The closure net is half as wide.
hidden_net, closure_net = UNet(features=16), UNet(features=8)

## Simulation Function

In [None]:
@jax.jit
def conv_diff_single_step(params,
                          smoke_initial: jnp.array,
                          hidden_state: jnp.array,
                          velocity: jnp.array,
                          time_curr: float,
                          inflow_loc: jnp.array,
                          inflow_poly_coeffs: jnp.array,
                          terrain: jnp.array,
                          dt: float):

    # set smoke at inflow locations
    y,x = inflow_loc
    for i in range(len(rel_loc)):
        rel_y, rel_x = rel_loc[i]
        c = jnp.polyval(inflow_poly_coeffs[i], time_curr)
        smoke_initial = smoke_initial.at[y+rel_y, x+rel_x].set(c)

    advection_term = physics.advect_corr_fvm(field=smoke_initial,
                                             velocity=velocity,
                                             u_corr=params['u_corr'],
                                             v_corr=params['v_corr'],
                                             dx=dx,
                                             dy=dy) * dt

    diffusion_term = physics.diffuse_2d_fvm(field=smoke_initial,
                                            diffusivity_x=params['diffusivity_x'],
                                            diffusivity_y=params['diffusivity_y'],
                                            dx=dx,
                                            dy=dy) * dt


    # compute closure term
    input = jnp.stack(((smoke_initial-jnp.mean(smoke_initial))/jnp.std(smoke_initial),
                        (advection_term-jnp.mean(advection_term))/jnp.std(advection_term),
                        (diffusion_term-jnp.mean(diffusion_term))/jnp.std(diffusion_term),
                        TERRAIN,
                        hidden_state), axis=-1)
    input = jnp.expand_dims(input, 0)

    output = closure_net.apply({'params': params['closure']}, input)
    closure_term = output * 1e-9        # a denormalization value, from previous tests

    smoke_pred = smoke_initial\
                    + advection_term\
                    + diffusion_term\
                    + closure_term.squeeze()

    time_next = time_curr + dt
    for i in range(len(rel_loc)):
        rel_y, rel_x = rel_loc[i]
        c = jnp.polyval(inflow_poly_coeffs[i], time_next)
        smoke_pred = smoke_pred.at[y+rel_y, x+rel_x].set(c)

    smoke_pred = smoke_pred * terrain
    smoke_pred = jnp.maximum(smoke_pred, 0.0)

    # update hidden state
    mean = jnp.mean(smoke_pred)
    std = jnp.std(smoke_pred)
    input = jnp.stack(((smoke_pred-mean)/std, hidden_state), axis=-1)
    input = jnp.expand_dims(input, 0)
    output = hidden_net.apply({'params': params['hidden']}, input)

    hidden_pred = output.squeeze()

    return (smoke_pred, hidden_pred, time_next, params, inflow_loc, inflow_poly_coeffs), None


def step_for_loop(carry, x):
    """``jax.lax.scan`` body that wraps ``conv_diff_single_step``.

    ``carry`` is ``(smoke, hidden, time, params, inflow_loc, inflow_poly_coeffs)``. The
    per-iteration input ``x`` is unused because the scan is driven by ``length`` with
    ``xs=None``. Velocity, terrain and the step size come from the module-level
    ``VELOCITY``, ``TERRAIN`` and ``dt``. Returns the new carry and ``None``.
    """
    del x  # unused: no per-step inputs
    smoke, hidden, time_curr, params, inflow_loc, inflow_poly_coeffs = carry
    return conv_diff_single_step(params=params,
                                 smoke_initial=smoke,
                                 hidden_state=hidden,
                                 velocity=VELOCITY,
                                 time_curr=time_curr,
                                 inflow_loc=inflow_loc,
                                 inflow_poly_coeffs=inflow_poly_coeffs,
                                 terrain=TERRAIN,
                                 dt=dt)

@partial(jax.jit, static_argnames=['nsteps'])
def conv_diff_nsteps(params,
                     smoke_initial: jnp.array,
                     hidden_initial: jnp.array,
                     time_initial,
                     inflow_loc,
                     inflow_poly_coeffs,
                     nsteps):
    """Roll the simulator forward ``nsteps`` steps from the given initial state.

    ``nsteps`` is a static argument, so every distinct value is compiled separately.

    Returns:
        ``(smoke, hidden)`` after the last step.
    """
    init_carry = (smoke_initial, hidden_initial, time_initial,
                  params, inflow_loc, inflow_poly_coeffs)
    final_carry, _ = jax.lax.scan(step_for_loop, init_carry, xs=None, length=nsteps)
    smoke_final, hidden_final = final_carry[0], final_carry[1]
    return smoke_final, hidden_final

## Training

In [None]:
def create_train_state(params_rcnn, params_physics, learning_rate):
    """Build an Adam ``TrainState`` over the merged network and physics parameters.

    Args:
        params_rcnn: mapping providing the ``'closure'`` and ``'hidden'`` network parameters.
        params_physics: mapping providing ``'diffusivity_x'``, ``'diffusivity_y'``,
            ``'u_corr'`` and ``'v_corr'``.
        learning_rate: passed straight to ``optax.adam`` (a float or an optax schedule).

    Returns:
        A ``train_state.TrainState`` whose ``apply_fn`` is ``conv_diff_nsteps``.
    """
    network_keys = ('closure', 'hidden')
    physics_keys = ('diffusivity_x', 'diffusivity_y', 'u_corr', 'v_corr')
    merged = {name: params_rcnn[name] for name in network_keys}
    merged.update((name, params_physics[name]) for name in physics_keys)

    return train_state.TrainState.create(apply_fn=conv_diff_nsteps,
                                         params=merged,
                                         tx=optax.adam(learning_rate))

# @partial(jax.jit, static_argnames=['nsteps'])
def train_step(state,
                smoke_initial,
                hidden_initial,
                time_initial,
                inflow_loc,
                inflow_poly_coeffs,
                nsteps,
                smoke_target,
                max_grad_norm=0.5,
                denom_floor=1e-8):
    """Run one optimisation step: roll out ``nsteps`` steps, compare with ``smoke_target``, update.

    The loss is a relative squared error, ``mean(0.5 * (pred - target)**2 / target**2)``.
    Target cells below ``denom_floor`` use a denominator of 1 (plain squared error) to
    avoid dividing by ~0. Gradients are clipped to global norm ``max_grad_norm`` before the
    optimiser update. Earlier experiments used L2, cosine, and max-normalised relative
    losses instead.

    Args:
        state: ``TrainState`` whose ``apply_fn`` has the ``conv_diff_nsteps`` signature.
        smoke_initial, hidden_initial, time_initial, inflow_loc, inflow_poly_coeffs, nsteps:
            forwarded to ``state.apply_fn``.
        smoke_target: ground-truth smoke field after ``nsteps`` steps.
        max_grad_norm: global-norm clipping threshold (default 0.5).
        denom_floor: targets below this value use a denominator of 1 (default 1e-8).

    Returns:
        ``(new_state, loss, smoke_pred, hidden_pred)``. The predictions come from the
        parameters *before* this update.
    """
    def loss_fn(params):
        smoke_pred, hidden_pred = state.apply_fn(params,
                                                 smoke_initial=smoke_initial,
                                                 hidden_initial=hidden_initial,
                                                 time_initial=time_initial,
                                                 inflow_loc=inflow_loc,
                                                 inflow_poly_coeffs=inflow_poly_coeffs,
                                                 nsteps=nsteps)
        denom = jnp.where(smoke_target < denom_floor, 1.0, smoke_target)
        loss = jnp.mean(optax.l2_loss(smoke_pred, smoke_target) / denom ** 2)
        return loss, (smoke_pred, hidden_pred)

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, (smoke_pred, hidden_pred)), grads = grad_fn(state.params)

    # clip_by_global_norm keeps no state; init() is called only to satisfy the
    # GradientTransformation API, and the returned state is not needed.
    clipper = optax.clip_by_global_norm(max_grad_norm)
    clipped_grads, _ = clipper.update(grads, clipper.init(grads))

    new_state = state.apply_gradients(grads=clipped_grads)
    return new_state, loss, smoke_pred, hidden_pred

In [None]:
# Learning-rate configuration.
# The warmup + cosine-decay schedule below is only used when USE_LR_SCHEDULE is True;
# the v1_1 run trained with a constant Adam learning rate of 1e-4.

max_epoch = 100
N_TRAIN_SETS = 4                                   # datasets visited per epoch by the training loop
itr_per_epoch = N_TRAIN_SETS * (len(TIMES) - 1)    # one train_step per (dataset, time interval)
max_iter = itr_per_epoch * max_epoch

# optimizer: Adam with linear warmup to peak_lr over the first half, then cosine decay
init_lr = 1e-4
peak_lr = 1e-3
lr_scheduler = optax.warmup_cosine_decay_schedule(init_value=init_lr,
                                                  peak_value=peak_lr,
                                                  warmup_steps=int(max_iter * .5),
                                                  decay_steps=max_iter,
                                                  end_value=1e-5)

USE_LR_SCHEDULE = False   # False reproduces the constant-rate v1_1 run
CONSTANT_LR = 1e-4

state = create_train_state(params_rcnn=params_rcnn,
                           params_physics=params_physics,
                           learning_rate=lr_scheduler if USE_LR_SCHEDULE else CONSTANT_LR)

In [13]:
EPOCHS = 100
N_TRAIN_SETS = 4   # datasets (leading axis of TRAIN_SET) visited every epoch
SEED = 0

losses = []                # summed training loss per epoch
min_loss = float('inf')

# The dataset order is shuffled with Python's `random`, so that generator must be seeded
# for a reproducible run (seeding NumPy alone does not affect random.shuffle).
random.seed(SEED)
np.random.seed(SEED)
order = list(range(N_TRAIN_SETS))

for epoch in range(EPOCHS):
    random.shuffle(order)

    epoch_loss = 0.0
    for set_idx in order:
        # Each dataset starts from its ground-truth first frame and a zero hidden state.
        # Every later interval starts from the previous prediction, and parameters are
        # updated once per interval.
        hidden_curr = jnp.zeros((ny, nx))
        smoke_curr = TRAIN_SET[set_idx][0]
        for times_idx in range(len(TIMES) - 1):
            # Number of dt-steps between consecutive frames; the +1e-3 keeps float error
            # from truncating the count down by one.
            nsteps = int((TIMES[times_idx + 1] - TIMES[times_idx] + 1e-3) / dt)
            state, loss, smoke_curr, hidden_curr = train_step(state,
                                                              smoke_initial=smoke_curr,
                                                              hidden_initial=hidden_curr,
                                                              time_initial=TIMES[times_idx],
                                                              inflow_loc=INFLOW_LOCS[set_idx],
                                                              inflow_poly_coeffs=INFLOW_COEFFS[set_idx],
                                                              nsteps=nsteps,
                                                              smoke_target=TRAIN_SET[set_idx][times_idx + 1])
            loss = float(loss)   # host scalar, so epoch_loss / losses hold plain floats

            print(f"EPOCH {epoch}, starting from {set_idx, times_idx} loss={loss}")
            epoch_loss += loss

    # Record this epoch first so the loss history saved with a new best checkpoint includes it.
    losses.append(epoch_loss)

    # v1_1: Adam (constant lr 1e-4), gradient clipping at global norm 0.5, all time intervals
    if epoch_loss < min_loss:
        min_loss = epoch_loss
        state_dict = flax.serialization.to_state_dict(state)
        with open('/content/drive/MyDrive/ARIA/v1_1_phase3_unet_min.pickle', 'wb') as handle:
            pickle.dump(state_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
        np.save("/content/drive/MyDrive/ARIA/v1_1_phase3_unet_min", losses)
        print("New minimum achieved")

    print(f"EPOCH {epoch}: loss={epoch_loss}")

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
EPOCH 46, starting from (2, 1) loss=0.0015443362285962286
EPOCH 46, starting from (2, 2) loss=0.0017194597539849246
EPOCH 46, starting from (2, 3) loss=0.007986229659822205
EPOCH 46, starting from (2, 4) loss=0.6651948807744463
EPOCH 46, starting from (2, 5) loss=0.10797150324254322
EPOCH 46, starting from (2, 6) loss=0.019372209041167692
EPOCH 46, starting from (2, 7) loss=0.014305823386165982
EPOCH 46, starting from (2, 8) loss=0.016258692847266998
EPOCH 46, starting from (2, 9) loss=0.016476425997996213
EPOCH 46, starting from (2, 10) loss=0.00756306781555201
EPOCH 46, starting from (2, 11) loss=0.006521091687348004
EPOCH 46, starting from (2, 12) loss=0.01043118009801319
EPOCH 46, starting from (2, 13) loss=0.014379570273085097
EPOCH 46, starting from (2, 14) loss=0.018320753260896012
EPOCH 46, starting from (2, 15) loss=0.029700134924685605
EPOCH 46, starting from (2, 16) loss=0.025939401431398456
EPOCH 46, starting 

In [14]:
# Persist the last-epoch training state and the full per-epoch loss history.
final_ckpt_path = '/content/drive/MyDrive/ARIA/v1_1_phase3_unet_final.pickle'
final_losses_path = "/content/drive/MyDrive/ARIA/v1_1_phase3_unet_final"

state_dict = flax.serialization.to_state_dict(state)
with open(final_ckpt_path, 'wb') as handle:
    pickle.dump(state_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
np.save(final_losses_path, losses)  # np.save appends the ".npy" suffix