In [None]:
# Imports and global JAX configuration.
#
# To emulate several CPU devices, XLA_FLAGS must be set *before* jax is first
# imported (e.g. in the shell, or via
# os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"
# at the very top of the kernel session); setting it after `import jax` has no effect.
import jax

# Enable 64-bit types at startup, before any other module creates JAX arrays.
jax.config.update('jax_enable_x64', True)

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax
from tqdm import tqdm

# Project modules. NOTE(review): names such as UNet, create_train_state,
# train_step, normalize_frame and vmapped_normalize_frame are not defined in
# this notebook and presumably come from these star imports; prefer explicit
# imports once the module APIs settle.
from data_generator import *
from visualizer import *
from cnn_settings import *

In [None]:
# Load the reference dataset, already split into train and test parts.
# NOTE(review): the meaning of the argument 10 is not visible here — confirm
# against read_train_test_dataset and give it a named constant.
train_ref_data, test_ref_data = read_train_test_dataset(10)

In [None]:
# Quick look at the training split (helper provided by the project modules).
inspect_data(train_ref_data)

In [None]:
# Build the UNet, its training state, and the low-resolution solver that the
# loss and training loop run against.
init_rng = jax.random.key(0)  # fixed PRNG key passed to create_train_state
learning_rate = 0.00001
# NOTE(review): if create_train_state feeds this into an SGD/heavy-ball
# momentum term, momentum = 1 means past gradients never decay (0.9 is the
# usual choice) — confirm how it is used.
momentum = 1
from tqdm import tqdm
my_unet = UNet()
state = create_train_state(my_unet, init_rng, learning_rate, momentum)
# NOTE(review): the positional arguments (5, False) are opaque here —
# check instantiate_simulator's signature and name them.
low_res_lbm_solver = instantiate_simulator(5, False, transfer_output=False, quiet=True)

In [None]:
def loss_fn(params, batched_f, correction_scale=0.01):
    """L2 loss of one low-res solver step plus the learned CNN correction.

    The target is the velocity recovered from ``batched_f`` by the solver's
    ``vmapped_update_macroscopic``, normalized with ``vmapped_normalize_frame``.
    The prediction is one ``vmapped_run_step`` of the low-res solver, normalized,
    plus ``correction_scale`` times the network output.

    Args:
        params: parameter pytree for ``state.apply_fn`` (passed as
            ``{'params': params}``).
        batched_f: batch of post-streaming distributions, e.g.
            ``test_ref_data['f_poststreaming'][idx]``.
        correction_scale: weight of the network output before it is added to
            the normalized low-res velocity (default 0.01, the value that was
            previously hard-coded).

    Returns:
        Scalar: ``optax.l2_loss`` summed over all elements.
    """
    # Target velocity from the input distributions.
    _, high_res_u = low_res_lbm_solver.vmapped_update_macroscopic(batched_f)
    high_res_u = vmapped_normalize_frame(high_res_u)

    # Advance the low-res solver one step.
    low_res_step_output = low_res_lbm_solver.vmapped_run_step(0, batched_f)
    # NOTE(review): `['u'][0]` keeps only the leading entry of the vmapped output,
    # and it is normalized with the non-vmapped `normalize_frame`, while the
    # target above is batched — confirm the shapes broadcast as intended.
    low_res_u = low_res_step_output['u'][0]

    # The network receives the raw (unnormalized) low-res velocity.
    correction = state.apply_fn({'params': params}, low_res_u)
    prediction = normalize_frame(low_res_u) + correction_scale * correction
    return optax.l2_loss(prediction, high_res_u).sum()

In [None]:
# Training loop. Each epoch:
#   1. draws one training batch and one test batch;
#   2. runs `steps_per_epoch` optimizer steps on that batch;
#   3. keeps the parameters with the lowest test loss seen so far.
epochs = 20
steps_per_epoch = 1000
batch_size = 32
test_batch_size = 2
# Start at inf (not a large constant) so the first evaluated loss always registers.
min_loss = float('inf')
optimal_params = None
# Seeded generator so batch selection is reproducible across runs.
batch_rng = np.random.default_rng(0)


def sample_batch(ref_data, size, rng):
    """Return `size` entries of `ref_data`, drawn with replacement.

    Indices are drawn from `ref_data`'s own length, so the draw can never
    index past the end of the dataset it is applied to.
    """
    idx = jnp.array(rng.choice(ref_data['timestep'].shape[0], size))
    return {key: ref_data[key][idx] for key in ('f_poststreaming', 'u', 'timestep')}


for epoch in range(epochs):
    # NOTE(review): the same training batch is reused for every step of the
    # epoch; move this inside the step loop if per-step batches were intended.
    batched_data = sample_batch(train_ref_data, batch_size, batch_rng)
    # Previously the test indices were drawn from the *training* set length,
    # which can run past the end of test_ref_data.
    test_batched_data = sample_batch(test_ref_data, test_batch_size, batch_rng)

    train_pbar = tqdm(range(steps_per_epoch))
    for step in train_pbar:
        # One optimizer step; train_step returns the updated state and the batch loss.
        state, train_loss = train_step(state, batched_data, low_res_lbm_solver)
        cur_loss = float(loss_fn(state.params, test_batched_data['f_poststreaming']))
        if cur_loss < min_loss:
            min_loss = cur_loss
            optimal_params = state.params
        # Report the best, current train, and current test loss.
        train_pbar.set_description(
            "min loss: {:.5f}, train loss : {:.5f}, test loss : {:.5f}".format(
                min_loss, float(train_loss), cur_loss))

In [None]:
# Load the initial frames that seed the low-res simulation in the next cell.
# NOTE(review): the meaning of (1, 8) is not visible here (presumably a
# count / downsampling setting) — confirm against read_data_and_downsample.
init_frames = read_data_and_downsample(1, 8, 'init_frames')

In [None]:
# Run an uncorrected low-res simulation starting from the second initial frame.
# NOTE(review): the positional arguments (5, 0, 500, 0, 1) are opaque here
# (500 is presumably a step count, 5 presumably matching the solver setting
# used above) — name them after checking generate_sim_dataset.
low_res_ref_data = generate_sim_dataset(5, 0, 500, 0, 1, init_frames['f_poststreaming'][1])

In [None]:
# Visual comparison: reference test trajectory, then the uncorrected low-res run.
visualize_data(test_ref_data)
visualize_data(low_res_ref_data)

In [None]:
import optax
import matplotlib.pyplot as plt

# Per-frame error against the reference, before and after the learned correction.
# NOTE(review): assumes test_ref_data['u'] and low_res_ref_data['u'] have the
# same number of frames and are time-aligned — confirm.
normalized_test_ref_u = vmapped_normalize_frame(test_ref_data['u'])
normalized_low_res_u = vmapped_normalize_frame(low_res_ref_data['u'])

# Evaluate with the best parameters selected during training. Fall back to
# the final parameters if no test loss was ever recorded.
eval_params = optimal_params if optimal_params is not None else state.params

# Apply the correction the same way loss_fn does: the network sees the raw
# velocity, and its output is scaled by 0.01 before being added to the
# normalized low-res velocity.
eval_correction = state.apply_fn({'params': eval_params}, low_res_ref_data['u'])

y = optax.l2_loss(normalized_low_res_u, normalized_test_ref_u).sum(axis=(1, 2, 3))
y_star = optax.l2_loss(normalized_low_res_u + 0.01 * eval_correction, normalized_test_ref_u).sum(axis=(1, 2, 3))
x = np.arange(y.shape[0])

fig, ax = plt.subplots()
ax.plot(x, y, label='low-res solver')
ax.plot(x, y_star, label='low-res + CNN correction')
ax.set(title='Per-frame error vs. reference', xlabel='frame index', ylabel='summed optax.l2_loss')
ax.legend();

In [None]:
# Raw network output on the low-res trajectory (final training parameters).
# Keep it under a name and show a summary instead of dumping the whole array.
low_res_correction = state.apply_fn({'params': state.params}, low_res_ref_data['u'])
low_res_correction.shape, float(low_res_correction.min()), float(low_res_correction.max())