In [None]:
import jax
# Use 8 CPU devices
# os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'
# JAX recommends enabling 64-bit mode at startup, before arrays are created,
# so this stays ahead of the other imports.
jax.config.update('jax_enable_x64', True)

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import optax

from data_generator import *
from visualizer import *
from cnn_settings import *

In [None]:
# Load the train/test reference trajectories.
# NOTE(review): the meaning of the argument `10` lives in the helper modules — confirm.
train_ref_data, test_ref_data = read_train_test_dataset(10)

# Std of the training velocity field reduced over the first three axes
# (presumably time, height, width — confirm), i.e. one value per remaining axis entry.
train_velocity = train_ref_data['u']
reduce_axes = (0, 1, 2)
train_ref_data_std = jnp.std(train_velocity, axis=reduce_axes)

In [None]:
# Sanity check of the loaded training trajectory; what `inspect_data` reports
# is defined in the star-imported helper modules, not in this notebook.
inspect_data(train_ref_data)

In [None]:
from tqdm import tqdm

# Optimiser hyper-parameters and the RNG key handed to create_train_state.
learning_rate = 0.1
momentum = 0.1
init_rng = jax.random.key(0)

# Build the U-Net and its train state, then attach that state to a
# low-resolution solver instance.
my_unet = UNet()
state = create_train_state(my_unet, init_rng, learning_rate, momentum)
low_res_lbm_solver = instantiate_simulator(5, True, transfer_output=False, quiet=True)
low_res_lbm_solver.set_state(state)

In [None]:
def loss_fn(params, batch_data_f, high_res_ref_data_u, scale=100):
    """Scaled L2 loss of one low-res solver step against reference velocities.

    Args:
        params: network parameters forwarded to
            ``low_res_lbm_solver.vmapped_run_step``.
        batch_data_f: batch of post-streaming distributions used as the step input.
        high_res_ref_data_u: reference velocity fields the step output is
            compared against.
        scale: constant multiplier on the summed loss (default 100, the value
            that used to be hard-coded).

    Returns:
        ``scale`` times the sum over all elements of ``optax.l2_loss``
        (0.5 * squared error).

    Note:
        Reads the notebook-global ``low_res_lbm_solver``.
    """
    step_output = low_res_lbm_solver.vmapped_run_step(0, batch_data_f, params)
    # NOTE(review): `[0]` takes the first entry of the step's 'u' output; whether
    # that is a batch element or a per-step output slot is not visible here — confirm.
    loss = optax.l2_loss(step_output['u'][0], high_res_ref_data_u).sum()
    return loss * scale

In [None]:
def frame_to_img(frame):
    """Append an all-zero channel to ``frame`` along axis 2.

    For a two-channel ``(H, W, 2)`` velocity frame this yields an ``(H, W, 3)``
    array that ``plt.imshow`` can render as RGB (third channel blank).
    """
    height, width = frame.shape[0], frame.shape[1]
    blank_channel = np.zeros((height, width, 1))
    return np.concatenate((frame, blank_channel), axis=2)

# Preview the network output on a single reference frame. This cell runs
# before the training loop, so it uses the current (untrained) `state.params`.
# NOTE(review): assumes `train_ref_data['u'][0]` is one (H, W, 2) velocity
# frame the network accepts unbatched — confirm.
preview_frame = state.apply_fn({'params': state.params}, train_ref_data['u'][0])
plt.imshow(frame_to_img(preview_frame));

In [None]:
epochs = 20
steps_per_epoch = 1000
batch_size = 64
test_batch_size = 16
# Start from +inf so the first evaluated loss always becomes the best one;
# a finite sentinel can leave `optimal_params` as None if every loss exceeds it.
min_loss = float('inf')
optimal_params = None
# Seeded generator so batch selection is reproducible across kernel restarts.
batch_rng = np.random.default_rng(0)


def _sample_transitions(data, n, rng):
    """Draw ``n`` distinct (f_poststreaming at t, u at t+1) pairs from ``data``.

    Indices stop one short of the last timestep so that ``t + 1`` is valid.
    """
    ts = rng.choice(data['timestep'].shape[0] - 1, n, replace=False)
    return data['f_poststreaming'][ts], data['u'][ts + 1]


# Held-out evaluation batch from the *test* trajectory, drawn once so that
# `min_loss` always compares against the same samples and never contains
# transitions the model was trained on.
# NOTE(review): assumes test_ref_data has the same keys/resolution as
# train_ref_data and more than `test_batch_size` timesteps — confirm.
test_batched_data, test_high_res_ref_data_u = _sample_transitions(
    test_ref_data, test_batch_size, batch_rng)

for j in range(epochs):
    # Fresh training batch per epoch; it is reused for every step of that epoch.
    batched_data_f, high_res_ref_data_u = _sample_transitions(
        train_ref_data, batch_size, batch_rng)
    train_pbar = tqdm(range(steps_per_epoch))
    for i in train_pbar:
        # One optimisation step: returns the updated train state and the batch loss.
        state, train_loss = train_step(state, batched_data_f, high_res_ref_data_u, low_res_lbm_solver)
        cur_loss = loss_fn(state.params, test_batched_data, test_high_res_ref_data_u)
        if cur_loss < min_loss:
            min_loss = cur_loss
            optimal_params = state.params
        train_pbar.set_description(
            "min loss: {:.5f}, train loss : {:.5f}, test loss : {:.5f}".format(
                float(min_loss), float(train_loss), float(cur_loss)))

In [None]:
# Load the initial frames used to seed the simulations below.
# NOTE(review): the meaning of the positional args (1, 8) is defined in the
# helper modules and is not visible here — confirm.
init_frames = read_data_and_downsample(1, 8, 'init_frames')

In [None]:
# Roll out the plain low-res solver and the CNN-corrected solver from the same
# initial distribution so the two trajectories can be compared frame by frame.
# NOTE(review): the meaning of the positional args (5, 0, 500, 0, 1) is defined
# in the helper modules and is not visible here.
sim_args = (5, 0, 500, 0, 1, init_frames['f_poststreaming'][0])
low_res_ref_data = generate_sim_dataset(*sim_args)
# Use the parameters with the lowest evaluation loss recorded during training,
# falling back to the final parameters if none were recorded.
best_params = optimal_params if optimal_params is not None else state.params
low_res_lbm_solver.set_params(best_params)
cnn_corrected_res_data = generate_sim_dataset(*sim_args, solver=low_res_lbm_solver)

In [None]:
# Value range of frame 10 of the uncorrected low-res rollout.
low_res_frame = low_res_ref_data['u'][10]
(low_res_frame.min(), low_res_frame.max())

In [None]:
# Value range of frame 10 of the training reference trajectory.
ref_frame = train_ref_data['u'][10]
(ref_frame.min(), ref_frame.max())

In [None]:
# Value range of frame 10 of the CNN-corrected rollout.
corrected_frame = cnn_corrected_res_data['u'][10]
(corrected_frame.min(), corrected_frame.max())

In [None]:
# Visualise the reference, the plain low-res rollout and the CNN-corrected
# rollout. Only the last call is left as a bare expression so its return value
# (if any) is still what the cell displays.
for dataset in (train_ref_data, low_res_ref_data):
    visualize_data(dataset)
visualize_data(cnn_corrected_res_data)

In [None]:
import optax
import matplotlib.pyplot as plt

# Per-frame error of the uncorrected low-res rollout against the reference
# trajectory (y, raw scale), and of the normalised low-res frames after adding
# the network output (y_star, normalised scale).
# NOTE(review): `normalized_test_ref_u` is built from the *training* reference
# trajectory despite its name; the name is kept in case later cells use it.
normalized_test_ref_u = vmapped_normalize_frame(train_ref_data['u'])
normalized_low_res_u = vmapped_normalize_frame(low_res_ref_data['u'])

# Compare only frames present in both trajectories, so both the raw and the
# normalised comparisons line up regardless of trajectory lengths.
n_frames = min(train_ref_data['u'].shape[0], low_res_ref_data['u'].shape[0])
frame_axes = tuple(range(1, train_ref_data['u'].ndim))  # sum over all non-time axes
best_params = optimal_params if optimal_params is not None else state.params

y = optax.l2_loss(train_ref_data['u'][:n_frames], low_res_ref_data['u'][:n_frames]).sum(axis=frame_axes)
low_res_window = normalized_low_res_u[:n_frames]
corrected_low_res_u = low_res_window + state.apply_fn({'params': best_params}, low_res_window)
y_star = optax.l2_loss(corrected_low_res_u, normalized_test_ref_u[:n_frames]).sum(axis=frame_axes)
x = np.arange(n_frames)

# y_star lives on the normalised scale, so it is not drawn on the same axes as y.
fig, ax = plt.subplots()
ax.plot(x, y)
ax.set(xlabel='frame', ylabel='summed L2 loss', title='Uncorrected low-res rollout vs reference');

In [None]:
# Summarise the network output on the low-res rollout instead of dumping the
# whole array into the notebook.
# NOTE(review): this feeds the un-normalised velocity, whereas the y_star
# computation feeds normalised frames — confirm which input the network expects.
best_params = optimal_params if optimal_params is not None else state.params
cnn_output_low_res = state.apply_fn({'params': best_params}, low_res_ref_data['u'])
cnn_output_low_res.shape, cnn_output_low_res.min(), cnn_output_low_res.max()