In [None]:
import jax
# Optional: emulate 8 host CPU devices. Uncommenting this line also requires
# `import os`; NOTE(review): XLA_FLAGS is generally only honoured if it is set
# before JAX initialises its backend, so keep it as early as possible.
# os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
# Enable 64-bit floats. Kept ahead of the project imports below, presumably so
# anything those modules build with JAX at import time is already 64-bit.
jax.config.update('jax_enable_x64', True)
# NOTE(review): these star imports hide where names used later come from
# (SimpleNet, create_train_state, train_step, instantiate_simulator,
# normalize_frame, read_train_test_dataset, visualize_data, and presumably
# np / optax) — consider importing them explicitly.
from data_generator import *
from visualizer import *
from cnn_settings import *

In [None]:
# Load the reference data as a (train, test) pair. Later cells index
# test_ref_data['f_poststreaming'][t], so each split is expected to be a
# mapping keyed by field name — confirm against read_train_test_dataset.
train_ref_data, test_ref_data = read_train_test_dataset()

In [None]:
# Visual sanity check of both reference splits before training
# (visualize_data is defined in the project's visualizer module).
visualize_data(train_ref_data)
visualize_data(test_ref_data)

In [None]:
# Build the network, its training state, and the coarse LBM solver used below.
from tqdm import tqdm

# Fixed PRNG key (seed 0) so parameter initialisation is reproducible.
init_rng = jax.random.key(0)
# Hyperparameters forwarded to create_train_state.
learning_rate, momentum = 0.0001, 0.0009

# NOTE(review): the variable says "unet" but the model is SimpleNet — confirm the architecture.
my_unet = SimpleNet()
state = create_train_state(my_unet, init_rng, learning_rate, momentum)

# NOTE(review): the positional arguments (5, False) are opaque from here —
# confirm their meaning in instantiate_simulator and pass them by keyword.
low_res_lbm_solver = instantiate_simulator(5, False, transfer_output=False, quiet=True)

In [None]:
# NOTE(review): this performs one real training step (select_ts=1) and rebinds
# `state` every time this cell runs, so re-executing the cell changes the model.
# Presumably a warm-up / compilation step — consider moving it to its own cell.
state = train_step(state, train_ref_data, low_res_lbm_solver, 1)
def loss_fn(params, f):
    """Evaluation loss of the learned correction on a single frame.

    The reference velocity is the macroscopic ``u`` that
    ``low_res_lbm_solver.update_macroscopic`` recovers from ``f``, after
    ``normalize_frame``. The prediction is one ``run_step`` of the solver from
    the normalized ``f``, with the network correction added. The network is
    evaluated on ``low_res_lbm_solver.saved_data['u'][0]``, which is presumably
    filled in by ``run_step`` — confirm.

    Reads the notebook globals ``low_res_lbm_solver`` and ``state`` (for
    ``state.apply_fn``) at call time.

    Args:
        params: network parameters passed to ``state.apply_fn``.
        f: one frame of distribution functions (e.g. an entry of
           ``test_ref_data['f_poststreaming']``).

    Returns:
        Summed ``optax.l2_loss`` between corrected prediction and reference.

    NOTE(review): the target is derived from the *input* frame ``f`` itself,
    not from a later frame — confirm this is the intended supervision signal.
    """
    solver = low_res_lbm_solver

    # Reference velocity recovered from the incoming distributions.
    _rho, reference_u = solver.update_macroscopic(f)
    reference_u = normalize_frame(reference_u)

    # Advance the coarse solver one step from the normalized distributions.
    step_output = solver.run_step(0, normalize_frame(f))

    # The network predicts a correction from the solver's saved velocity field.
    predicted_correction = state.apply_fn({'params': params}, solver.saved_data['u'][0])
    corrected_u = normalize_frame(step_output['u'][0]) + predicted_correction

    return optax.l2_loss(corrected_u, reference_u).sum()

In [None]:
# Train for a fixed number of steps and keep the parameters that achieved the
# lowest evaluation loss seen so far (stored in `optimal_params`).
NUM_TRAIN_STEPS = 4000
# NOTE(review): assumes both reference datasets hold at least this many frames — confirm.
NUM_SAMPLED_TIMESTEPS = 500
# Seeded generator so the sampled timesteps (and therefore the run) are reproducible.
ts_rng = np.random.default_rng(0)

train_pbar = tqdm(range(NUM_TRAIN_STEPS))
# Start at +inf so the first finite loss is always recorded and optimal_params
# can never remain None merely because losses are large.
min_loss = float('inf')
optimal_params = None
for _ in train_pbar:
    # Draw a fresh training timestep on every iteration so training covers many
    # frames instead of repeatedly fitting a single one.
    select_ts = int(ts_rng.integers(0, NUM_SAMPLED_TIMESTEPS))
    state = train_step(state, train_ref_data, low_res_lbm_solver, select_ts)

    # NOTE(review): each evaluation uses a different random test frame, so
    # min_loss compares losses measured on different samples; a fixed
    # evaluation set would make the "best" parameters more meaningful.
    eval_ts = int(ts_rng.integers(0, NUM_SAMPLED_TIMESTEPS))
    cur_loss = float(loss_fn(state.params, test_ref_data['f_poststreaming'][eval_ts]))
    if cur_loss < min_loss:
        min_loss = cur_loss
        optimal_params = state.params
    train_pbar.set_description("min loss: {:.5f}, current loss : {:.5f}".format(min_loss, cur_loss))