In [None]:
import sys
sys.path.append('../')

import yaml
import pickle
from typing import Any, Union, List, Tuple, Sequence

import numpy as np
import jax.numpy as jnp
import jax
import optax
from functools import partial
from jax.tree_util import tree_map

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors

import mesher
import network
import utils
import plot_utils

# Loose type alias used to annotate JAX pytrees (e.g. the gradient trees
# handled by clip_grads).
Pytree = Any
# Shorthand for the extent type passed to mesher.BoundingBox.
_Ext = utils.Extent
# Apply the project's plot settings to every figure created in this notebook.
plt.rcParams.update(plot_utils.high_res_plot_settings)

In [None]:
# Read the notebook settings once; each section below works from its own
# sub-dictionary of `config`.
with open("./settings.yaml", mode="r") as file:
  config = yaml.safe_load(file)

# Stamp domain: mesh resolution and bounding-box limits.
stamp_mesh_cfg, stamp_bbox_cfg = config['STAMP_MESH'], config['STAMP_BBOX']

# Network architecture and training hyper-parameters.
nn_cfg, train_cfg = config['IMPLICIT_NN'], config['VAE_TRAIN']

# Define stamp mesh

In [None]:
# Rectangular domain of a stamp, built from the x/y limits in the settings.
stamp_bbox = mesher.BoundingBox(
    x=_Ext(stamp_bbox_cfg['x_min'], stamp_bbox_cfg['x_max']),
    y=_Ext(stamp_bbox_cfg['y_min'], stamp_bbox_cfg['y_max']),
)

# Mesh of nelx-by-nely elements laid over that bounding box.
stamp_mesh = mesher.Mesher(
    nelx=stamp_mesh_cfg['nelx'],
    nely=stamp_mesh_cfg['nely'],
    bounding_box=stamp_bbox,
)

# Get train images

In [None]:
# Signed-distance images of the training stamps. They are indexed below as
# [stamp, :, :, 0], i.e. a 4-D array whose last axis is a channel axis.
stamp_sdfs = np.load('../data/train_sdf_images.npy')

num_train_stamps = stamp_sdfs.shape[0]

# squeeze=False keeps `ax` two-dimensional even when there is a single
# training stamp, so the ax[i, j] indexing below always works.
fig, ax = plt.subplots(num_train_stamps, 2, figsize=(6, 3*num_train_stamps),
                       squeeze=False)
for i in range(num_train_stamps):
  # Left column: raw SDF values. Right column: mask of where the SDF is
  # negative.
  ax[i,0].imshow(stamp_sdfs[i,:,:,0].T, cmap='coolwarm', origin='lower')
  ax[i,1].imshow(stamp_sdfs[i,:,:,0].T < 0, cmap='coolwarm', origin='lower')

  ax[i,0].set_axis_off(); ax[i,1].set_axis_off()

fig.subplots_adjust(wspace=0.)
plt.show()

# Initialize the NN

In [None]:
# Architecture hyper-parameters of the network, read from the settings file.
latent_dim = nn_cfg['latent_dim']
implicit_hidden_dim = nn_cfg['hidden_dim']
implicit_num_layers = nn_cfg['num_layers']
implicit_siren_freq = nn_cfg['siren_freq']

# Fixed seeds (0) for the NumPy generator and the JAX PRNG key.
np_rng = np.random.default_rng(0)
rand_key = jax.random.PRNGKey(0)

In [None]:
# Instantiate the autoencoder module; its parameters are created later by
# sdf_net.init inside train_autoencoder.
sdf_net = network.ConvoImplicitAutoEncoder(
    latent_dim=latent_dim,
    implicit_hidden_dim=implicit_hidden_dim,
    implicit_num_layers=implicit_num_layers,
    implicit_siren_freq=implicit_siren_freq,
)

In [None]:
# Training hyper-parameters that are forwarded to train_autoencoder.
num_epochs = train_cfg['num_epochs']
lr = train_cfg['lr']
kl_factor = train_cfg['kl_factor']


# Gradient Clipping

In [None]:
def clip_grads(grads: Pytree, max_norm = 0.01) -> Pytree:
  """
  Clips every array in a gradient pytree to a maximum norm.

  Clipping is applied leaf by leaf: each array is rescaled according to its
  own norm (as computed by jnp.linalg.norm). This is not a clip on the global
  norm of the whole pytree.

  Args:
    grads: A pytree of gradients to be clipped.
    max_norm: The maximum allowed norm for each gradient array.

  Returns:
    A pytree with the same structure as grads. Arrays whose norm exceeds
    max_norm are rescaled to have norm max_norm; all others keep their values.
  """
  def clip(grad: jnp.ndarray) -> jnp.ndarray:
    """
    Clip a single gradient array to have a maximum norm of max_norm.

    Args:
      grad: A gradient array.

    Returns:
      A clipped gradient array.
    """
    norm = jnp.linalg.norm(grad)
    needs_clip = norm > max_norm
    # Guard the denominator: dividing by the raw norm would evaluate
    # 0 * (max_norm / 0) = NaN for an all-zero array. jnp.where would discard
    # that value, but it still trips NaN checking (jax_debug_nans).
    safe_norm = jnp.where(needs_clip, norm, 1.0)
    scale = jnp.where(needs_clip, max_norm / safe_norm, 1.0)
    return grad * scale

  return tree_map(clip, grads)



# Train the NN

In [None]:
def train_autoencoder(train_imgs: jnp.ndarray,
                      mesh: mesher.Mesher,
                      sdf_net: network.ConvoImplicitAutoEncoder,
                      num_epochs: int,
                      kl_factor: float,
                      lr: float,
                      key: jax.Array,
                      load_save: Union[str, None] = None,
                      print_interval: int = 10
                      ) -> Tuple[network.ConvoImplicitAutoEncoder, Pytree,
                                 dict, jnp.ndarray, jax.Array]:
  """
  Train a convolutional implicit autoencoder on a set of training images.

  Every epoch is one full-batch Adam step on
  reconstruction loss + kl_factor * KL loss, with the gradients passed through
  clip_grads before the update.

  Args:
    train_imgs: Array of training images. It is compared against predictions
      reshaped to (-1, mesh.nelx, mesh.nely, 1).
    mesh: Mesher object containing mesh information (elem_centers, nelx, nely).
    sdf_net: Convolutional implicit autoencoder network.
    num_epochs: Number of training epochs.
    kl_factor: Weight of the KL divergence term in the loss function.
    lr: Learning rate for the Adam optimizer.
    key: Random key used for initialization and for the per-step sampling.
    load_save: Path to a pickle file of saved parameters (optional). When
      given, the loaded parameters replace the freshly initialized ones.
    print_interval: Interval (in epochs) at which progress is printed and
      appended to the convergence history.

  Returns:
    A tuple (sdf_net, params, convg_history, pred_imgs, key):
      sdf_net: the network that was passed in.
      params: the parameters after the last epoch.
      convg_history: dict with 'recons_loss', 'kl_loss' and 'net_loss' lists,
        one entry per printed epoch.
      pred_imgs: predictions for train_imgs from the final parameters,
        evaluated with is_training=True and reshaped like train_imgs.
      key: the random key left after the last training step.
  """

  mesh_xy = mesh.elem_centers
  solver = optax.adam(lr)
  # NOTE(review): `key` serves here as both the init RNG and the network's key
  # argument, and is split again in train_step. Left as is to keep results
  # reproducible; splitting it first would avoid the reuse.
  params = sdf_net.init(key, train_imgs,  mesh_xy, key)['params']

  if load_save is not None:
    # pickle can execute arbitrary code on load: only pass trusted files.
    with open(load_save, 'rb') as f:
      params = pickle.load(f)

  solver_state = solver.init(params)

  def predict(params, key, is_training: bool):
    return sdf_net.apply({'params': params}, train_imgs, mesh_xy, key, is_training)


  @jax.jit
  def loss_fn(params, key):

    pred_sdf, enc_mu, enc_sigma, _ = predict(params, key, is_training=True)
    pred_sdf = (pred_sdf.reshape(-1, mesh.nelx, mesh.nely, 1))

    # NOTE(review): the error is normalized with the notebook-level
    # `stamp_bbox`, not with anything reachable from the `mesh` argument, so
    # this function only works when that global is defined.
    recons_loss = jnp.mean(((pred_sdf - train_imgs)/stamp_bbox.diag_length)**2)
    # NOTE(review): if enc_sigma is a standard deviation, the closed-form
    # KL(N(mu, sigma^2) || N(0, 1)) is 0.5*sigma^2 + 0.5*mu^2 - log(sigma) - 0.5.
    # The expression below omits the 0.5 factors. It is kept unchanged because
    # kl_factor is presumably tuned against it -- confirm before correcting.
    kl_loss = (enc_sigma**2 + enc_mu**2 - jnp.log(enc_sigma) - 1./2.).sum()
    net_loss = recons_loss +  kl_factor*kl_loss

    return net_loss, (recons_loss, kl_loss, pred_sdf)

  @jax.jit
  def train_step(params, solver_state, key):
    # Consume one subkey for this step and carry the other one forward.
    subkey, key = jax.random.split(key)
    (loss, aux), grad = jax.value_and_grad(loss_fn, has_aux=True)(params, subkey)

    clipped_grads = clip_grads(grad)
    updates, solver_state = solver.update(clipped_grads, solver_state, params)
    params = optax.apply_updates(params, updates)
    return params, solver_state, loss, aux, key

  convg_history = {'recons_loss':[], 'kl_loss':[], 'net_loss':[]}

  for epoch in range(num_epochs):

    params, solver_state, train_loss, train_aux, key = train_step(params,
                                                                  solver_state,
                                                                  key)

    if epoch%print_interval == 0:
      print(f'epoch {epoch:d}, recons_loss {train_aux[0]:.2E}, kl_loss {train_aux[1]:.2E} '
            f' , net_loss {train_loss:.2E}')

      convg_history['recons_loss'].append(train_aux[0])
      convg_history['kl_loss'].append(train_aux[1])
      convg_history['net_loss'].append(train_loss)

  # Only the forward pass is needed for the final predictions, so evaluate the
  # loss directly instead of also computing (and discarding) its gradient.
  _, (_, _, pred_imgs) = loss_fn(params, key)

  return sdf_net, params, convg_history, pred_imgs, key

In [None]:
# Train on the stamp SDF images, reporting progress every 50 epochs. The
# returned params and key are reused by the cells below.
sdf_net, params, convg_history, pred_imgs, key = train_autoencoder(
    train_imgs=stamp_sdfs,
    mesh=stamp_mesh,
    sdf_net=sdf_net,
    num_epochs=num_epochs,
    kl_factor=kl_factor,
    lr=lr,
    key=rand_key,
    print_interval=50,
)


# Save the weights

In [None]:
def save_weights(file_name, weights=None):
  """
  Pickle network weights to disk.

  Args:
    file_name: Path of the file to write (opened in binary write mode).
    weights: Object to serialize. When omitted, the notebook-level `params`
      variable is saved, so existing single-argument calls keep working.
  """
  if weights is None:
    # Fall back to the trained parameters held in the notebook namespace.
    weights = params
  with open(file_name, 'wb') as f:
    pickle.dump(weights, f)
# Persist the trained `params` returned by train_autoencoder.
save_weights('../data/sdf_vae_net_weights.pkl')

# Visualize the training results and latent space

In [None]:
# Reload the weights from disk so the plots reflect exactly what was saved.
# pickle can execute arbitrary code on load: only open files you trust.
with open('../data/sdf_vae_net_weights.pkl', 'rb') as f:
  sdf_net_params_loaded = pickle.load(f)

# train_autoencoder calls the network positionally as
# (images, coordinates, key, is_training). The key must therefore be passed
# here too, otherwise the trailing False lands in the key slot instead of
# switching off training mode.
dec_stamp_sdf, _, _, _ = sdf_net.apply({'params': sdf_net_params_loaded},
                                       stamp_sdfs,
                                       stamp_mesh.elem_centers,
                                       key,
                                       False)
dec_stamp_sdf = dec_stamp_sdf.reshape(-1, stamp_mesh.nelx, stamp_mesh.nely, 1)
print(dec_stamp_sdf.shape)

for i in range(dec_stamp_sdf.shape[0]):
  fig, ax = plt.subplots(1, 4)

  # Panels, left to right: decoded mask (SDF < 0), training mask,
  # decoded SDF values, training SDF values.
  panels = (dec_stamp_sdf[i,:,:,0].T < 0,
            stamp_sdfs[i,:,:,0].T < 0,
            dec_stamp_sdf[i,:,:,0].T,
            stamp_sdfs[i,:,:,0].T)
  for j, panel in enumerate(panels):
    img = ax[j].imshow(panel, cmap='coolwarm', origin='lower')
    ax[j].set_axis_off()

  plt.show()
  # Release the figure so a long list of stamps does not accumulate open
  # figures in memory.
  plt.close(fig)