In [1]:
import os
import numpy as np
import numpy.typing as npt
from pathlib import Path
import sklearn.model_selection
import tensorflow as tf
from neuralspot.tflite.metrics import MultiF1Score
from sleepkit.defines import SKTrainParams, SleepStage
from sleepkit.datasets import Hdf5Dataset
from sleepkit.utils import env_flag, set_random_seed, setup_logger
from sleepkit.metrics import compute_iou, confusion_matrix_plot, multi_f1
from sleepkit.datasets.utils import create_dataset_from_data
from sleepkit.defines import SKTrainParams
from neuralspot.tflite.metrics import get_flops
from neuralspot.tflite.model import get_strategy
from sleepkit.models.unet import UNet, UNetParams, UNetBlockParams


2023-09-21 23:26:55.202712: I tensorflow/core/util/port.cc:110] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2023-09-21 23:26:55.223630: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
# Notebook-wide logger obtained from sleepkit's setup_logger helper; used below
# for the seed message, the model summary and the evaluation results.
logger = setup_logger(__name__)

In [3]:
# Training configuration for this experiment. Paths are relative to the
# notebook's working directory.
params = SKTrainParams(
    # --- locations ---
    job_dir=Path("..") / "results" / "mesa-fs002" / "experiment-001",
    ds_path=Path("..") / "datasets" / "processed" / "mesa-fs002",
    # --- data framing / sampling ---
    sampling_rate=64,
    frame_size=64,
    samples_per_subject=250,
    val_samples_per_subject=200,
    val_subjects=0.2,
    # --- input pipeline ---
    batch_size=256,
    buffer_size=100_000,
    # --- optimisation schedule ---
    epochs=75,
    lr_rate=0.001,
    lr_cycles=3,
    steps_per_epoch=200,
    # --- validation ---
    val_metric="loss",
    val_size=60_000,
)

In [4]:
def get_target_classes(nstages: int):
    """Return the list of target class indices for an ``nstages``-class problem.

    Args:
        nstages: Number of sleep stages to classify; 2, 3, 4 or 5 are supported.

    Returns:
        A new list ``[0, 1, ..., nstages - 1]``.

    Raises:
        ValueError: If ``nstages`` is not one of the supported values.
    """
    for supported in (2, 3, 4, 5):
        if nstages == supported:
            return list(range(supported))
    raise ValueError(f"Invalid number of stages: {nstages}")

def get_class_mapping(nstages: int):
    """Return the mapping from ``SleepStage`` members to target class indices.

    Args:
        nstages: Number of sleep stages to classify; 2, 3, 4 or 5 are supported.

    Returns:
        A new dict keyed by ``SleepStage`` member (wake, stage1..stage4, rem)
        whose values are the class index that stage collapses to.

    Raises:
        ValueError: If ``nstages`` is not one of the supported values.
    """
    # Class index per stage, in the order: wake, stage1, stage2, stage3, stage4, rem.
    codes_by_count = (
        (2, (0, 1, 1, 1, 1, 1)),
        (3, (0, 1, 1, 1, 1, 2)),
        (4, (0, 1, 1, 2, 2, 3)),
        (5, (0, 1, 2, 3, 3, 4)),
    )
    for count, codes in codes_by_count:
        if nstages == count:
            stages = (
                SleepStage.wake,
                SleepStage.stage1,
                SleepStage.stage2,
                SleepStage.stage3,
                SleepStage.stage4,
                SleepStage.rem,
            )
            return dict(zip(stages, codes))
    raise ValueError(f"Invalid number of stages: {nstages}")

def get_class_names(nstages: int):
    """Return the display names of the target classes, ordered by class index.

    Args:
        nstages: Number of sleep stages to classify; 2, 3, 4 or 5 are supported.

    Returns:
        A new list of class-name strings of length ``nstages``.

    Raises:
        ValueError: If ``nstages`` is not one of the supported values.
    """
    names_by_count = (
        (2, ("WAKE", "SLEEP")),
        (3, ("WAKE", "NREM", "REM")),
        (4, ("WAKE", "CORE", "DEEP", "REM")),
        (5, ("WAKE", "N1", "N2", "N3", "REM")),
    )
    for count, names in names_by_count:
        if nstages == count:
            # Hand back a fresh list so callers may mutate it freely.
            return list(names)
    raise ValueError(f"Invalid number of stages: {nstages}")


In [5]:
def load_model(inputs: tf.Tensor, num_classes: int = 2):
    """Build the U-Net model used in this experiment.

    Args:
        inputs: Keras input tensor the network is built on.
        num_classes: Number of output classes.

    Returns:
        Whatever ``UNet`` returns for the given inputs and parameters.
    """
    # (filters, seperable) for each block; all blocks share depth, kernel,
    # strides and skip settings. "seperable" is the keyword's spelling in
    # UNetBlockParams.
    block_specs = ((24, True), (48, True), (64, False))
    unet_params = UNetParams(
        blocks=[
            UNetBlockParams(
                filters=n_filters,
                depth=1,
                kernel=(1, 5),
                strides=(1, 2),
                skip=True,
                seperable=use_seperable,
            )
            for n_filters, use_seperable in block_specs
        ],
        output_kernel_size=(1, 5),
        include_top=True,
        use_logits=False,
        include_rnn=False,
    )
    return UNet(inputs, params=unet_params, num_classes=num_classes)

# def load_model_v1(inputs: tf.Tensor, num_classes: int = 2):
#     y = inputs
#     y = tf.keras.layers.Conv1D(filters=32, kernel_size=5, strides=1, padding="same")(y)
#     y = tf.keras.layers.BatchNormalization()(y)
#     y = tf.keras.layers.Activation(tf.nn.relu6)(y)
#     y = tf.keras.layers.Conv1D(filters=48, kernel_size=5, strides=1, padding="same")(y)
#     y = tf.keras.layers.BatchNormalization()(y)
#     y = tf.keras.layers.Activation(tf.nn.relu6)(y)    
#     # y = tf.keras.layers.Conv1D(filters=48, kernel_size=5, strides=1, padding="same")(y)
#     # y = tf.keras.layers.BatchNormalization()(y)
#     # y = tf.keras.layers.Activation(tf.nn.relu6)(y)
#     # y = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(units=64, return_sequences=True))(y)
#     y = tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(units=64, return_sequences=False))(y)
#     y = tf.keras.layers.Dense(units=96)(y)
#     y = tf.keras.layers.Dense(units=num_classes)(y)
#     y = tf.keras.layers.Softmax()(y)
#     model = tf.keras.Model(inputs, y)
#     return model


In [6]:
def prepare(x, y, num_classes, class_map: dict[int, int]):
    """Remap raw labels through ``class_map`` and one-hot encode them.

    Args:
        x: Feature array; returned unchanged.
        y: Array of raw labels; each element is looked up in ``class_map``.
        num_classes: Depth of the one-hot encoding.
        class_map: Mapping from raw label to target class index.

    Returns:
        Tuple ``(x, one_hot_labels)``.
    """
    remap_labels = np.vectorize(class_map.get)
    target_ids = remap_labels(y)
    one_hot_labels = tf.one_hot(target_ids, num_classes)
    return x, one_hot_labels

def load_train_datasets(params: SKTrainParams, feat_shape, class_shape, class_map, feat_cols=None):
    """Build the training and validation ``tf.data`` pipelines.

    Subjects from the dataset's training pool are split into train and
    validation groups. Training data is streamed from
    ``params.data_parallelism`` interleaved generators, shuffled and batched.
    Validation data is drawn once (up to ``params.val_size`` samples), cached
    in memory and batched.

    Args:
        params: Training parameters (dataset path, frame size, sampling and
            batching settings).
        feat_shape: Shape of one feature sample, used in the output signature.
        class_shape: Shape of one label sample; its last entry is used as the
            one-hot depth.
        class_map: Mapping from raw label to target class index.
        feat_cols: Optional feature column selection forwarded to ``Hdf5Dataset``.

    Returns:
        Tuple ``(train_ds, val_ds)`` of batched ``tf.data.Dataset`` objects.
    """
    def preprocess(x: npt.NDArray[np.float32]):
        # Augmentation: additive Gaussian noise (mean 0, std 0.1).
        return x + np.random.normal(0, 0.1, size=x.shape)

    output_signature = (
        tf.TensorSpec(shape=feat_shape, dtype=tf.float32),
        tf.TensorSpec(shape=class_shape, dtype=tf.int32),
    )

    ds = Hdf5Dataset(
        ds_path=params.ds_path,
        frame_size=params.frame_size,
        mask_key="mask",
        feat_cols=feat_cols,
    )
    train_subject_ids, val_subject_ids = sklearn.model_selection.train_test_split(
        ds.train_subject_ids, test_size=params.val_subjects
    )

    def train_generator(subject_ids):
        """Wrap a generator over ``subject_ids`` in a ``tf.data.Dataset``."""
        def ds_gen():
            train_subj_gen = ds.uniform_subject_generator(subject_ids)
            return map(
                lambda x_y: prepare(preprocess(x_y[0]), x_y[1], class_shape[-1], class_map),
                ds.signal_generator(train_subj_gen, samples_per_subject=params.samples_per_subject)
            )
        return tf.data.Dataset.from_generator(
            ds_gen,
            output_signature=output_signature,
        )

    # One generator-backed dataset per parallel shard of training subjects.
    # NOTE(review): integer division drops up to data_parallelism - 1 trailing
    # subjects, and yields empty shards if there are fewer subjects than shards.
    split = len(train_subject_ids) // params.data_parallelism
    train_datasets = [train_generator(
        train_subject_ids[i * split : (i + 1) * split]
    ) for i in range(params.data_parallelism)]

    # Create TF datasets
    train_ds = tf.data.Dataset.from_tensor_slices(
        train_datasets
    ).interleave(
        lambda x: x,
        cycle_length=params.data_parallelism,
        deterministic=False,
        num_parallel_calls=tf.data.AUTOTUNE,
    ).shuffle(
        buffer_size=params.buffer_size,
        reshuffle_each_iteration=True,
    ).batch(
        batch_size=params.batch_size,
        drop_remainder=False,
    ).prefetch(
        buffer_size=tf.data.AUTOTUNE
    )

    def val_generator():
        val_subj_gen = ds.uniform_subject_generator(val_subject_ids)
        # NOTE(review): the noise augmentation is also applied to validation
        # samples (kept as-is) — confirm this is intended.
        return map(
            lambda x_y: prepare(preprocess(x_y[0]), x_y[1], class_shape[-1], class_map),
            # Use the validation-specific sampling rate; the training value was
            # previously used here, leaving val_samples_per_subject unused.
            ds.signal_generator(val_subj_gen, samples_per_subject=params.val_samples_per_subject)
        )

    val_ds = tf.data.Dataset.from_generator(
        generator=val_generator,
        output_signature=output_signature
    )
    # Materialize a fixed validation set once so every epoch is scored on the
    # same samples.
    val_x, val_y = next(val_ds.batch(params.val_size).as_numpy_iterator())
    val_ds = create_dataset_from_data(
        val_x, val_y, output_signature=output_signature
    ).batch(
        batch_size=params.batch_size,
        drop_remainder=False,
    )

    return train_ds, val_ds


In [7]:
# Seed the run and store the value returned by set_random_seed back on params
# so the logged seed is the one recorded for this experiment.
# NOTE(review): presumably set_random_seed picks a seed when given None — confirm.
params.seed = set_random_seed(params.seed)
logger.info(f"Random seed {params.seed}")

In [24]:
# Number of sleep-stage classes to train for (see get_class_mapping).
num_sleep_stages = 3

# Display name of every feature column, in column order.
feat_names = [
    # SpO2 statistics
    "SPO2-mu", "SPO2-std", "SPO2-med", "SPO2-iqr",
    # Movement statistics
    "MOV-mu", "MOV-std", "MOV-med", "MOV-iqr",
    # RR-interval statistics
    "RRI-mu", "RRI-std", "RRI-med", "RRI-iqr", "RRI-sd-rms", "RRI-sd-std",
    # Rates and HRV bands
    "HR-bpm", "RSP-bpm", "HRV-lf", "HRV-hf", "HRV-lfhf",
]
# Feature columns fed to the model: columns 0-18 except column 15.
feat_cols = [col for col in range(19) if col != 15]

num_feats = len(feat_cols)
target_classes = get_target_classes(num_sleep_stages)
class_names = get_class_names(num_sleep_stages)
class_mapping = get_class_mapping(num_sleep_stages)
num_classes = len(target_classes)


In [11]:
# Ensure the output directory (and any missing parents) exists.
params.job_dir.mkdir(parents=True, exist_ok=True)

In [12]:
# Per-sample shapes: one row per frame step, with feature columns for the
# input and one-hot class columns for the target.
feat_shape, class_shape = (
    (params.frame_size, num_feats),
    (params.frame_size, num_classes),
)
inputs = tf.keras.Input(shape=feat_shape, batch_size=None, dtype=tf.float32)

In [13]:
strategy = get_strategy()
with strategy.scope():
    # Data pipelines and model are created inside the distribution scope.
    train_ds, val_ds = load_train_datasets(params, feat_shape, class_shape, class_mapping, feat_cols=feat_cols)
    model = load_model(inputs, num_classes=len(target_classes))
    flops = get_flops(model, batch_size=1)

    # Grab optional LR parameters
    lr_rate: float = getattr(params, "lr_rate", 1e-3)
    lr_cycles: int = getattr(params, "lr_cycles", 1)
    steps_per_epoch = params.steps_per_epoch or 1000
    # Single cosine decay for one (or fewer) cycles; this also avoids the
    # division by zero in t_mul below when lr_cycles is 0 or 1.
    if lr_cycles <= 1:
        scheduler = tf.keras.optimizers.schedules.CosineDecay(
            initial_learning_rate=lr_rate,
            decay_steps=int(steps_per_epoch * params.epochs),
        )
    else:
        scheduler = tf.keras.optimizers.schedules.CosineDecayRestarts(
            initial_learning_rate=lr_rate,
            first_decay_steps=int(0.1 * steps_per_epoch * params.epochs),
            t_mul=1.65 / (0.1 * lr_cycles * (lr_cycles - 1)),
            m_mul=0.4,
        )
    optimizer = tf.keras.optimizers.Adam(scheduler)
    # Model outputs probabilities (use_logits=False in load_model), hence
    # from_logits=False.
    loss = tf.keras.losses.CategoricalFocalCrossentropy(
        from_logits=False,
        label_smoothing=getattr(params, "label_smoothing", 0.1),
    )
    metrics = [
        tf.keras.metrics.CategoricalAccuracy(name="acc"),
        MultiF1Score(name="f1", dtype=tf.float32, average="macro"),
        tf.keras.metrics.OneHotIoU(
            num_classes=len(target_classes),
            target_class_ids=target_classes,
            name="iou",
        ),
    ]
    model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
    model(inputs)

    model.summary(print_fn=logger.info)
    logger.info(f"Model requires {flops/1e6:0.2f} MFLOPS")

    params.weights_file = str(params.job_dir / "model.weights")

    # Validation quantity tracked by early stopping and checkpointing.
    val_monitor = f"val_{params.val_metric}"
    # All compiled metrics (acc, f1, iou) are higher-is-better. Keras' "auto"
    # mode only recognises a few names and would minimise "val_iou" or
    # "val_f1", so maximise them explicitly; anything else (e.g. loss) stays
    # on "auto".
    val_monitor_mode = "max" if params.val_metric in ("acc", "f1", "iou") else "auto"

    model_callbacks = [
        tf.keras.callbacks.EarlyStopping(
            monitor=val_monitor,
            patience=max(int(0.25 * params.epochs), 1),
            mode=val_monitor_mode,
            restore_best_weights=True,
        ),
        tf.keras.callbacks.ModelCheckpoint(
            filepath=params.weights_file,
            monitor=val_monitor,
            save_best_only=True,
            save_weights_only=True,
            mode=val_monitor_mode,
            verbose=1,
        ),
        tf.keras.callbacks.CSVLogger(str(params.job_dir / "history.csv")),
        tf.keras.callbacks.TensorBoard(
            log_dir=str(params.job_dir / "logs"),
            write_steps_per_second=True
        ),
    ]


INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)


2023-09-21 23:23:25.887959: E tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:268] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
2023-09-21 23:23:25.887998: I tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:168] retrieving CUDA diagnostic information for host: 70412016908e
2023-09-21 23:23:25.888001: I tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:175] hostname: 70412016908e
2023-09-21 23:23:25.888069: I tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:199] libcuda reported version is: 525.125.6
2023-09-21 23:23:25.888075: I tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:203] kernel reported version is: 525.125.6
2023-09-21 23:23:25.888077: I tensorflow/compiler/xla/stream_executor/cuda/cuda_diagnostics.cc:309] kernel version seems to match DSO: 525.125.6


In [None]:
# Train under the distribution strategy. A keyboard interrupt stops training
# with a warning instead of a traceback so the following cells can still run.
with strategy.scope():
    fit_kwargs = dict(
        steps_per_epoch=steps_per_epoch,
        verbose=2,
        epochs=params.epochs,
        validation_data=val_ds,
        callbacks=model_callbacks,
    )
    try:
        model.fit(train_ds, **fit_kwargs)
    except KeyboardInterrupt:
        logger.warning("Stopping training due to keyboard interrupt")


In [None]:
# Reload the weights written to params.weights_file by the ModelCheckpoint
# callback (configured with save_best_only=True).
model.load_weights(params.weights_file)

In [None]:
# Collect ground-truth labels and model predictions as class indices.
# Note: these are computed on val_ds, although the log line below is labelled
# "TEST SET".
test_labels = [y.numpy() for _, y in val_ds]
y_true = np.argmax(np.concatenate(test_labels).squeeze(), axis=-1)
y_pred = np.argmax(model.predict(val_ds).squeeze(), axis=-1)

# Summarize results
test_acc = np.sum(y_pred == y_true) / y_true.size
test_iou = compute_iou(y_true, y_pred, average="weighted")
# Report the computed macro F1; it was previously overwritten with 0.
test_f1 = multi_f1(y_true, y_pred, average="macro")
logger.info(f"[TEST SET] ACC={test_acc:.2%}, IoU={test_iou:.2%} F1={test_f1:.2%}")

# Save a confusion matrix over all flattened predictions (normalize="true").
cm_path = str(params.job_dir / f"confusion_matrix_test{num_sleep_stages}_unet.png")
confusion_matrix_plot(
    y_true.flatten(),
    y_pred.flatten(),
    labels=class_names,
    save_path=cm_path,
    normalize="true",
)

In [12]:
import h5py
import plotly.graph_objects as go

In [15]:
# Open one processed subject file ("0002.h5") read-only for inspection.
# The handle stays open until it is explicitly closed.
h5 = h5py.File(str(params.ds_path / "0002.h5"), "r")

In [16]:
# List the top-level entries stored in the file.
h5.keys()

<KeysViewHDF5 ['features', 'labels', 'mask']>

In [21]:
# Copy the datasets into memory, then release the file handle so the HDF5 file
# is not left open for the rest of the kernel session. Re-run the cell that
# opens `h5` before re-running this one.
try:
    features = h5["features"][:]
    sleep_stages = h5["labels"][:]
    mask = h5["mask"][:]
finally:
    h5.close()


In [27]:
# Plot every feature column (hidden where the mask is falsy) plus the mask,
# with the background shaded by sleep stage.
x = np.arange(mask.size)
fig = go.Figure()
for i in range(features.shape[1]):
    # Fall back to a generic name if the file has more columns than names.
    trace_name = feat_names[i] if i < len(feat_names) else f"feat-{i}"
    fig.add_trace(go.Scatter(x=x, y=np.where(mask, features[:, i], np.nan), mode="lines", name=trace_name))
fig.add_trace(go.Scatter(x=x, y=mask, mode="lines", name="mask"))

# Background colour per raw stage label; unknown labels get no fill colour.
color_map = {0: "white", 1: "yellow", 2: "orange", 3: "red", 4: "purple", 5: "blue"}

# Segment boundaries: start of recording, every index where the stage changes,
# and the end of the recording so the final segment is shaded too.
stage_changes = np.diff(sleep_stages).nonzero()[0] + 1
sleep_boundaries = np.concatenate(([0], stage_changes, [sleep_stages.size]))
for start, stop in zip(sleep_boundaries[:-1], sleep_boundaries[1:]):
    if stop <= start:
        # Empty segment (only possible for an empty recording).
        continue
    stage = sleep_stages[start]
    fig.add_vrect(x0=start, x1=stop, fillcolor=color_map.get(stage, None), opacity=0.4, layer="below", line_width=0)
fig.show()