In [1]:
import os
import numpy as np

# Select the JAX backend for Keras. This is set ahead of the keras imports in
# the next cell so the setting is already in place when keras is first imported.
os.environ["KERAS_BACKEND"] = "jax"

In [2]:
# import tensorflow as tf
from keras.layers import (
    Input, Conv2D, MaxPooling2D, UpSampling2D, Concatenate, BatchNormalization, Activation,
    AveragePooling2D, Reshape, GlobalAveragePooling2D, Lambda
)
from keras.models import Model

from examples.NHRC.nhrc_utils.new_cnn import NEW_INPUT_SHAPE

def segmentation_model(input_shape=NEW_INPUT_SHAPE, num_classes=4):
    """Build a small U-Net-style CNN that emits class probabilities per time step.

    The network is a two-level convolutional encoder, one decoder stage with a
    skip connection, an average over the whole frequency axis, a 15x average
    pooling along the time axis and a final 1x1 softmax convolution.

    Args:
        input_shape: Shape of a single sample without the batch axis, laid out
            as (time, frequency, channels). The spatial sizes must be static,
            because the frequency extent is read from the tensor shape when the
            collapsing pooling layer is built.
        num_classes: Number of output classes (filters of the final softmax
            convolution).

    Returns:
        A ``keras.Model`` mapping ``(batch, time, frequency, channels)`` to
        ``(batch, time_out, num_classes)`` where each ``[b, t, :]`` slice is a
        softmax distribution over the classes.
    """
    inputs = Input(shape=input_shape)

    # Encoder: conv + batch norm, then 2x2 max pooling halves both spatial axes.
    c1 = Conv2D(64, (3, 3), activation='relu', padding='same')(inputs)
    c1 = BatchNormalization()(c1)
    p1 = MaxPooling2D((2, 2))(c1)  # Downsampling

    c2 = Conv2D(128, (3, 3), activation='relu', padding='same')(p1)
    c2 = BatchNormalization()(c2)
    p2 = MaxPooling2D((2, 2))(c2)  # Further downsampling

    # NOTE: a third encoder level (256 filters) and its matching decoder stage
    # were sketched here at one point and are currently disabled.

    # Decoder: upsample back to c2's resolution, merge with c2 through a skip
    # connection, then upsample once more.
    u2 = UpSampling2D((2, 2))(p2)
    u2 = Conv2D(64, (3, 3), activation='relu', padding='same')(u2)
    u2 = BatchNormalization()(u2)
    u2 = Concatenate()([u2, c2])  # Skip connection
    u2 = UpSampling2D((2, 2))(u2)

    # Collapse the frequency axis: one pooling window spanning its full extent
    # leaves a singleton axis, i.e. (batch, time, 1, channels).
    collapse = AveragePooling2D(pool_size=(1, u2.shape[2]))(u2)

    # Downsample time by 15 to match the label resolution
    # (e.g. 15360 input steps -> 1024 output steps).
    final_downsampling = AveragePooling2D(pool_size=(15, 1))(collapse)

    # 1x1 convolution producing per-step class probabilities; with the default
    # channels_last layout the result is (batch, time_out, 1, num_classes).
    outputs = Conv2D(num_classes, (1, 1), activation='softmax')(final_downsampling)

    # Drop the singleton frequency axis while keeping every class probability.
    # Indexing the last axis with `[..., 0]` would instead discard all classes
    # except the first and return (batch, time_out, 1).
    outputs = Reshape((-1, num_classes))(outputs)

    model = Model(inputs, outputs)
    return model

In [3]:
# Build the model with the default input shape and number of classes.
seg = segmentation_model()

Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
I0000 00:00:1732225010.096127 4739230 service.cc:145] XLA service 0x3242f5c30 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1732225010.096136 4739230 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1732225010.097595 4739230 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1732225010.097609 4739230 mps_client.cc:384] XLA backend will use up to 51539132416 bytes on device 0 for SimpleAllocator.


Metal device set to: Apple M3 Max

systemMemory: 64.00 GB
maxCacheSize: 24.00 GB



In [4]:
# Print the model summary to inspect the layers that were built.
seg.summary()

In [5]:
# Random batch of 4 samples for a smoke test; a seeded generator keeps the
# values identical on every Restart & Run All.
inputs = np.random.default_rng(0).normal(size=(4, *NEW_INPUT_SHAPE))

In [6]:
# Sanity check: expect a leading batch axis of 4 followed by NEW_INPUT_SHAPE.
inputs.shape

(4, 15360, 257, 3)

In [10]:
# Call predict repeatedly so the per-step timings reported by the progress bar
# can be compared across calls; keep the last result for inspection.
for _ in range(10):
    predictions = seg.predict(inputs)

# A bare expression inside a loop is never displayed, so show the shape once here.
predictions.shape

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 95ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 32ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 45ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 31ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 34ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 46ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 32ms/step
[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 32ms/step


In [8]:
import jax

In [9]:
# List the devices visible to JAX to confirm which accelerator is in use.
jax.devices()

[METAL(id=0)]