In [2]:
import keras
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import tqdm.notebook as tqdm
from IPython.display import clear_output

2025-02-10 00:02:19.320278: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-02-10 00:02:19.372315: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
batch_size = 16
H = 500.0            
num_points = 256     
dz = H / (num_points - 1)
num_modes = 16         
f_value = 50        

In [4]:
def generate_dataset(batch_size, num_points=num_points, H=H):
    
    z = np.linspace(0, H, num_points)
    dataset = []
    z_borders = []
    for _ in range(batch_size):
        z_border = np.random.uniform(0, 300)
        z_borders.append(z_border)
        c = np.where(z < z_border, 1700.0, 1900.0)
        rho = np.where(z < z_border, 1.0, 2.0)
        data = np.stack([z, c, rho], axis=-1) 
        dataset.append(data)
        
    dataset = np.array(dataset, dtype=np.float32)
    
    return dataset, z_borders

In [5]:
def get_name(prefix: str | None = None, suffix: str | None = None, sep: str = "."):
    return prefix and suffix and prefix + sep + suffix

In [6]:
def Unet1D(
    features: int,
    n_levels: int,
    num_modes: int,
    num_points: int,
    name: str = "Unet1D"
):

    envir_input = keras.layers.Input(shape=(num_points, 3), name=get_name(name, "environment_input"))
    freq_input = keras.layers.Input(shape=(1,), name=get_name(name, "frequency_input"))
    
    x = envir_input
    level_outputs = []
    

    for i in range(n_levels):
        level_name = get_name(name, f"down_level{i}")
        x = keras.layers.Conv1D(filters=features, kernel_size=3, padding='same', name=get_name(level_name, "conv1"))(x)
        x = keras.layers.BatchNormalization(name=get_name(level_name, "bn1"))(x)
        x = keras.layers.ReLU(name=get_name(level_name, "relu1"))(x)
        
        x = keras.layers.Conv1D(filters=features, kernel_size=3, padding='same', name=get_name(level_name, "conv2"))(x)
        x = keras.layers.BatchNormalization(name=get_name(level_name, "bn2"))(x)
        x = keras.layers.ReLU(name=get_name(level_name, "relu2"))(x)
        
        level_outputs.append(x)
        
        x = keras.layers.MaxPooling1D(pool_size=2, name=get_name(level_name, "pool"))(x)
        features *= 2

    x = keras.layers.Conv1D(filters=features, kernel_size=3, padding='same', name=get_name(name, "down_conv1"))(x)
    x = keras.layers.BatchNormalization(name=get_name(name, "down_bn1"))(x)
    x = keras.layers.ReLU(name=get_name(name, "relu_down1"))(x)
    
    x = keras.layers.Conv1D(filters=features, kernel_size=3, padding='same', name=get_name(name, "down_conv2"))(x)
    x = keras.layers.BatchNormalization(name=get_name(name, "down_bn2"))(x)
    x = keras.layers.ReLU(name=get_name(name, "relu_down2"))(x)
    

    freq_add = keras.layers.Lambda(
        lambda inputs: tf.repeat(tf.expand_dims(inputs[0], axis=1), tf.shape(inputs[1])[1], axis=1),
        name=get_name(name, "freq_add")
    )([freq_input, x])
    
    x = keras.layers.Concatenate(axis=-1, name=get_name(name, "concat_freq"))([x, freq_add])
    
    for j, skip in enumerate(level_outputs[::-1]):
        level_name_up = get_name(name, f"up_level{j}")
        features //= 2
        x = keras.layers.UpSampling1D(size=2, name=get_name(level_name_up, "upsample"))(x)
        x = keras.layers.Concatenate(name=get_name(level_name_up, "concat"))([x, skip])
        
        x = keras.layers.Conv1D(filters=features, kernel_size=3, padding='same', name=get_name(level_name_up, "conv1"))(x)
        x = keras.layers.BatchNormalization(name=get_name(level_name_up, "bn1"))(x)
        x = keras.layers.ReLU(name=get_name(level_name_up, "relu1"))(x)
        
        x = keras.layers.Conv1D(filters=features, kernel_size=3, padding='same', name=get_name(level_name_up, "conv2"))(x)
        x = keras.layers.BatchNormalization(name=get_name(level_name_up, "bn2"))(x)
        x = keras.layers.ReLU(name=get_name(level_name_up, "relu2"))(x)
        
    modes = keras.layers.Conv1D(filters=num_modes, kernel_size=1, padding='same', activation='linear', name=get_name(name, "modes_conv"))(x)

    wave_numbers = keras.layers.GlobalAveragePooling1D(name=get_name(name, "global_pool"))(modes)
    wave_numbers = keras.layers.Dense(num_modes, activation='linear', name=get_name(name, "wave_numbers_dense"))(wave_numbers)
    
    mod_features = keras.layers.Reshape((num_modes, 256, 1), name=get_name(name, "reshape"))(modes)
    
    model = keras.Model(
        inputs=[envir_input, freq_input],
        outputs=[mod_features, wave_numbers],
        name=name
    )
    
    return model

In [8]:
model = Unet1D(features=64, n_levels=4, num_modes=16, num_points=256)
model.summary()

In [None]:
keras.utils.plot_model(model)

In [9]:
dataset, z_borders = generate_dataset(batch_size)
freq = np.full((batch_size, 1), f_value, dtype=np.float32)

In [10]:
dataset.shape, freq.shape

((16, 256, 3), (16, 1))

In [11]:
dataset

array([[[0.0000000e+00, 1.7000000e+03, 1.0000000e+00],
        [1.9607843e+00, 1.7000000e+03, 1.0000000e+00],
        [3.9215686e+00, 1.7000000e+03, 1.0000000e+00],
        ...,
        [4.9607843e+02, 1.9000000e+03, 2.0000000e+00],
        [4.9803922e+02, 1.9000000e+03, 2.0000000e+00],
        [5.0000000e+02, 1.9000000e+03, 2.0000000e+00]],

       [[0.0000000e+00, 1.7000000e+03, 1.0000000e+00],
        [1.9607843e+00, 1.7000000e+03, 1.0000000e+00],
        [3.9215686e+00, 1.7000000e+03, 1.0000000e+00],
        ...,
        [4.9607843e+02, 1.9000000e+03, 2.0000000e+00],
        [4.9803922e+02, 1.9000000e+03, 2.0000000e+00],
        [5.0000000e+02, 1.9000000e+03, 2.0000000e+00]],

       [[0.0000000e+00, 1.7000000e+03, 1.0000000e+00],
        [1.9607843e+00, 1.7000000e+03, 1.0000000e+00],
        [3.9215686e+00, 1.7000000e+03, 1.0000000e+00],
        ...,
        [4.9607843e+02, 1.9000000e+03, 2.0000000e+00],
        [4.9803922e+02, 1.9000000e+03, 2.0000000e+00],
        [5.0000000e+02