In [None]:
import os

import numpy as np

import jax
import jax.numpy as jnp
from jax import random, jit

from jax.experimental import mesh_utils, multihost_utils
from jax.sharding import Mesh, PartitionSpec as P

from function_diffusion.utils.data_utils import create_dataloader, BaseDataset
from function_diffusion.utils.model_utils import (
    create_optimizer,
    create_autoencoder_state,
    create_diffusion_state,
    compute_total_params,
)
from function_diffusion.utils.train_utils import  sample_ode
from function_diffusion.utils.checkpoint_utils import (
    create_checkpoint_manager,
    save_checkpoint,
    restore_checkpoint,
)

from model import Encoder, Decoder, DiT
from model_utils import create_encoder_step, create_decoder_step 
from data_utils import generate_dataset, BaseDataset, fit_damped_sine

from matplotlib import pyplot as plt

In [None]:
from configs import diffusion
# Load the experiment configuration. The cells below read its `autoencoder`,
# `diffusion`, `dataset` and `saving` sections.
# NOTE(review): the comma-separated selector presumably picks the autoencoder
# and DiT sub-configs; confirm against configs/diffusion.py.
config_selector = "fae,dit"
config = diffusion.get_config(config_selector)

In [None]:
def restore_fae_state(config, encoder, decoder):
    """Rebuild the function-autoencoder train state and restore it from disk.

    The checkpoint directory is ``<cwd>/<model_name>_<num_samples>_samples/ckpt``,
    where ``model_name`` comes from ``config.autoencoder`` and ``num_samples``
    from ``config.dataset``.

    Args:
        config: Experiment config; ``autoencoder.model_name``,
            ``dataset.num_samples`` and ``saving`` are read here, and the whole
            object is forwarded to the optimizer/state factories.
        encoder: Encoder module passed to ``create_autoencoder_state``.
        decoder: Decoder module passed to ``create_autoencoder_state``.

    Returns:
        The state returned by ``restore_checkpoint``.
    """
    # Only the optimizer is needed to build the state; the first value
    # returned by create_optimizer is not used here.
    _, tx = create_optimizer(config)

    # Freshly initialised state that is handed to restore_checkpoint.
    template_state = create_autoencoder_state(config, encoder, decoder, tx)

    # Job name and checkpoint location derived from the config.
    fae_job_name = f"{config.autoencoder.model_name}_{config.dataset.num_samples}_samples"
    ckpt_dir = os.path.join(os.getcwd(), fae_job_name, "ckpt")
    manager = create_checkpoint_manager(config.saving, ckpt_dir)

    # Restore and report which training step was loaded.
    restored_state = restore_checkpoint(manager, template_state)
    print(f"Restored model {fae_job_name} from step", restored_state.step)

    return restored_state

In [None]:
# Initialize function autoencoder
# ---- Function autoencoder --------------------------------------------------
# Build the encoder/decoder from the config and restore their trained state.
encoder = Encoder(**config.autoencoder.encoder)
decoder = Decoder(**config.autoencoder.decoder)
fae_state = restore_fae_state(config, encoder, decoder)

# ---- Diffusion model -------------------------------------------------------
dit = DiT(**config.diffusion)
lr, tx = create_optimizer(config)
state = create_diffusion_state(config, dit, tx)

# Report the parameter footprint.
# NOTE(review): the factor 4 presumably assumes 4-byte (float32) parameters.
num_params = compute_total_params(state)
param_megabytes = num_params * 4 / 1024 / 1024
print(f"Model storage cost: {param_megabytes:.2f} MB of parameters")

# ---- Devices ---------------------------------------------------------------
num_local_devices = jax.local_device_count()
num_devices = jax.device_count()
print(f"Number of devices: {num_devices}")
print(f"Number of local devices: {num_local_devices}")

# ---- Restore the diffusion checkpoint --------------------------------------
# Checkpoints live under <cwd>/<model_name>_<num_samples>_samples/ckpt.
job_name = f"{config.diffusion.model_name}_{config.dataset.num_samples}_samples"
ckpt_path = os.path.join(os.getcwd(), job_name, "ckpt")
ckpt_mngr = create_checkpoint_manager(config.saving, ckpt_path)

# Replace the freshly initialised state with the restored one.
state = restore_checkpoint(ckpt_mngr, state)
print(f"Restored model {job_name} from step", state.step)

In [ ]:
# Create sharding for data parallelism
# One-dimensional device mesh over every device, with a single axis named
# "batch" (used for data parallelism).
device_grid = mesh_utils.create_device_mesh((jax.device_count(),))
mesh = Mesh(device_grid, "batch")

# Both train states are turned into global arrays with an empty
# PartitionSpec, i.e. no dimension is split along the "batch" axis.
replicated_spec = P()
state = multihost_utils.host_local_array_to_global_array(state, mesh, replicated_spec)
fae_state = multihost_utils.host_local_array_to_global_array(fae_state, mesh, replicated_spec)

# Encoder / decoder step functions bound to this mesh.
encoder_step = create_encoder_step(encoder, mesh)
decoder_step = create_decoder_step(decoder, mesh)

In [ ]:
#  Generate samples
# Query grid handed to the decoder: 128 evenly spaced points on [0, 1],
# converted to a global array with an empty PartitionSpec.
local_coords = jnp.linspace(0, 1, 128)
coords = multihost_utils.host_local_array_to_global_array(local_coords, mesh, P())

# Test set: 8192 generated signals with 128 sensors each.
x_test, y_test = generate_dataset(num_samples=8192, num_sensors=128)
test_dataset = BaseDataset(x_test, y_test)

# Iterate over the test set in batches of 4096.
loader_options = {
    "batch_size": 4096,
    "num_workers": config.dataset.num_workers,
}
test_loader = create_dataloader(test_dataset, **loader_options)

In [ ]:
# Fixed seed so the latent noise draws are reproducible.
rng = jax.random.PRNGKey(12345)

# Decoded samples, one entry per batch.
u_pred_list = []

for batch in test_loader:
    # NOTE(review): `rng` from this split is overwritten by the second split
    # below, so the chain effectively continues from `keys`. The derivation is
    # kept as-is so previously generated samples remain reproducible.
    rng, keys = random.split(rng, 2)

    batch = jax.tree.map(jnp.array, batch)
    u = batch

    # The encoder step takes a 3-tuple with `u` in the middle and all-ones
    # arrays of the same shape on either side.
    # NOTE(review): the meaning of the two all-ones entries is defined by
    # create_encoder_step; confirm there.
    ones = jnp.ones_like(u)
    u_batch = multihost_utils.host_local_array_to_global_array((ones, u, ones), mesh, P("batch"))

    # The latent code is only used for its shape when drawing the noise.
    z_u = encoder_step(fae_state.params[0], u_batch)

    rng, subkey = random.split(keys)
    z0 = random.normal(subkey, z_u.shape)

    # Integrate from noise with 100 steps, then decode on `coords`.
    z1_new, _ = sample_ode(state, z0=z0, num_steps=100)
    u_pred = decoder_step(fae_state.params[1], z1_new, coords)

    u_pred_list.append(u_pred)

# Stack all batches along the leading axis and write them to disk.
samples = jnp.concatenate(u_pred_list)

samples_path = f"samples_dit_{config.dataset.num_samples}_samples.npy"
np.save(samples_path, samples)

In [None]:
from scipy.optimize import curve_fit
from scipy.signal import find_peaks


def damped_sine(t, A, gamma, omega, shift):
    """Evaluate ``A * exp(-gamma * t) * sin(omega * t) + shift``.

    Args:
        t: Scalar or array of evaluation points.
        A: Amplitude of the envelope at ``t = 0``.
        gamma: Exponential decay rate of the envelope.
        omega: Angular frequency of the oscillation.
        shift: Constant offset added to the signal.

    Returns:
        The model value(s), with the shape NumPy broadcasting gives for ``t``.
    """
    envelope = A * np.exp(-gamma * t)
    oscillation = np.sin(omega * t)
    return envelope * oscillation + shift


def fit_damped_sine_simple(t, y):
    """
    Fit a damped sine using first two peaks and y[0] for shift.

    This is a closed-form estimate of the ``damped_sine`` parameters, not a
    least-squares optimisation.

    Args:
        t: 1-D array of sample locations. The shift estimate relies on the
            model's sine term vanishing at the first sample, i.e. it assumes
            ``t[0] == 0``.
        y: 1-D array of signal values at ``t``.

    Returns:
        dict with keys ``"A"``, ``"gamma"``, ``"omega"``, ``"shift"`` and
        ``"mse"`` (mean squared error between ``y`` and the fitted model
        evaluated on ``t``).

    Raises:
        ValueError: If fewer than two peaks are found in the signal.
    """
    # Get shift directly from y[0]: the model equals `shift` where sin(...) = 0.
    shift = y[0]
    y_no_shift = y - shift

    # Find peaks
    peaks, _ = find_peaks(y_no_shift)
    if len(peaks) < 2:
        raise ValueError("Need at least two peaks to fit")

    # Get first two peaks
    t1, t2 = t[peaks[0]], t[peaks[1]]
    y1, y2 = y_no_shift[peaks[0]], y_no_shift[peaks[1]]

    # Calculate parameters:
    #   - consecutive peaks are one period apart, giving omega;
    #   - the ratio of their heights gives the decay rate gamma;
    #   - A undoes the decay at the first peak, treating the sine factor as 1 there.
    omega = 2 * np.pi / (t2 - t1)
    gamma = -np.log(abs(y2/y1)) / (t2 - t1)
    A = y1 / np.exp(-gamma * t1)

    # Goodness of fit. Evaluate the model on this function's own `t`; the
    # previous code used `x`, which is not a parameter and only resolved if a
    # notebook-level `x` happened to exist.
    y_fit = damped_sine(t, A, gamma, omega, shift)
    mse = np.mean((y - y_fit)**2)

    return {
        "A": A,
        "gamma": gamma,
        "omega": omega,
        "shift": shift,
        "mse": mse
    }

In [None]:
 # Fit the data
# Grid the generated signals live on: 128 evenly spaced points on [0, 1],
# the same linspace the sampling cell uses for the decoder.
coords = jnp.linspace(0, 1, 128)

# Load the samples written by the sampling cell. The path is built exactly as
# in that cell's np.save call; the previous "dit_samples_<n>.npy" name is not
# written anywhere in this notebook, so a fresh run could not find its input.
num_samples = config.dataset.num_samples
generated_samples = np.load(f"samples_dit_{num_samples}_samples.npy")

# Visualize some samples
k = 1
x = coords
y = generated_samples[k]

# The steps below mirror fit_damped_sine_simple, but keep the peak
# coordinates around so they can be marked on the plot.

# Get shift directly from y[0]
shift = y[0]
y_no_shift = y - shift

# Find peaks
peaks, _ = find_peaks(y_no_shift)
if len(peaks) < 2:
    raise ValueError("Need at least two peaks to fit")

# Get first two peaks
x1, x2 = x[peaks[0]], x[peaks[1]]
y1, y2 = y_no_shift[peaks[0]], y_no_shift[peaks[1]]

# Calculate parameters: peak spacing gives omega, peak-height ratio gives
# gamma, and A undoes the decay at the first peak.
omega = 2 * np.pi / (x2 - x1)
gamma = -np.log(abs(y2/y1)) / (x2 - x1)
A = y1 / np.exp(-gamma * x1)

print("A:", A, "gamma:", gamma, "omega:", omega, "shift:", shift)


# Evaluate the fitted model and overlay it on the generated sample; the two
# peaks used for the estimate are marked in red (shift added back).
y_pred = damped_sine(x, A, gamma, omega, shift)
plt.plot(x1, y1 + shift, 'ro')
plt.plot(x2, y2 + shift, 'ro')
plt.plot(x, y_pred, label='Fitted curve', linestyle='--')
plt.plot(x, y, label='Generated data')
plt.legend()