In [1]:
import keras
from keras import layers
import numpy as np
import jax
import jax.numpy as jnp

# Model hyper-parameters consumed by build_network below.
n = 32  # base filter count; build_network uses n, 2n and 4n channels
# Shape of one input sample, excluding the batch axis. The last axis holds the
# single input channel (tas). NOTE(review): meaning of the two spatial axes
# (42, 97) is not established here — presumably a lat/lon grid; confirm.
input_shape = (42, 97, 1)  # Assuming 1 channel (tas)

def build_network(input_shape, base_filters=None):
    """Build a fully-convolutional encoder/decoder network with additive skips.

    Every convolution uses ``strides=1`` and ``padding='same'``, so the spatial
    size given by ``input_shape`` is preserved through the whole network: the
    "encoder" and "decoder" halves change only the channel count and perform
    no down- or up-sampling.

    Layout:
      * Encoder: 31 blocks of Conv2D(3x3) -> BatchNormalization -> ReLU, with
        ``base_filters`` channels for blocks 0-9, ``2 * base_filters`` for
        blocks 10-19 and ``4 * base_filters`` for blocks 20-30. The output of
        every odd-indexed block (1, 3, ..., 29) is kept as a skip connection,
        15 in total.
      * Decoder: 16 blocks of Conv2DTranspose(3x3) -> BatchNormalization ->
        ReLU, with ``4 * base_filters`` channels for blocks 0-5,
        ``2 * base_filters`` for blocks 6-11 and ``base_filters`` for blocks
        12-15. Decoder block ``i`` adds the ``i``-th most recent skip
        connection while any remain. A skip is first projected with a 1x1
        Conv2D when its channel count differs from the decoder's.
      * Head: a linear (no activation) 1x1 Conv2D down to one output channel.

    Args:
        input_shape: Shape of one input sample, excluding the batch axis,
            passed straight to ``keras.Input``, e.g. ``(42, 97, 1)``.
        base_filters: Base channel count for the filter schedule. When
            ``None`` (the default), the module-level ``n`` is used, looked up
            at call time, which matches the original behaviour.

    Returns:
        An uncompiled ``keras.Model`` whose output has the same spatial size
        as the input and a single channel.
    """
    if base_filters is None:
        base_filters = n  # module-level default, read at call time

    inputs = keras.Input(shape=input_shape)
    x = inputs
    skip_connections = []

    # Encoder: strides=1 + 'same' padding keeps the spatial size; only the
    # channel count grows (base -> 2*base -> 4*base).
    for i in range(31):
        if i < 10:
            filters = base_filters
        elif i < 20:
            filters = 2 * base_filters
        else:
            filters = 4 * base_filters
        x = layers.Conv2D(filters, kernel_size=3, strides=1, padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.ReLU()(x)
        if i % 2 == 1:  # keep every second block's output for the decoder
            skip_connections.append(x)

    # Decoder: Conv2DTranspose with strides=1 does NOT upsample. It mirrors the
    # encoder's channel schedule in reverse and consumes the skips
    # deepest-first.
    for i in range(16):
        if i < 6:
            filters = 4 * base_filters
        elif i < 12:
            filters = 2 * base_filters
        else:
            filters = base_filters
        x = layers.Conv2DTranspose(filters, kernel_size=3, strides=1, padding='same')(x)
        x = layers.BatchNormalization()(x)
        x = layers.ReLU()(x)

        if i < len(skip_connections):
            skip = skip_connections[-(i + 1)]
            if skip.shape[-1] != x.shape[-1]:
                # Add() requires equal channel counts; project the skip with a
                # 1x1 convolution to match the decoder's width.
                skip = layers.Conv2D(filters, kernel_size=1, strides=1, padding='same')(skip)
            x = layers.Add()([x, skip])

    # Linear 1x1 head: one output channel, no activation.
    outputs = layers.Conv2D(1, kernel_size=1, strides=1, padding='same')(x)

    return keras.Model(inputs=inputs, outputs=outputs)

# Build the model for the module-level input_shape.
# NOTE(review): the model is NOT compiled here; call model.compile(...) with an
# optimizer and loss before training.
model = build_network(input_shape)


2024-10-24 12:15:02.132452: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-10-24 12:15:02.148996: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-10-24 12:15:02.167423: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-10-24 12:15:02.172875: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-10-24 12:15:02.187935: I tensorflow/core/platform/cpu_feature_guar