# ECG Foundation Model

__Date created:__ 2024/07/17 

__Last Modified:__ 2024/07/17 

__Description:__ Train, evaluate, and export 4-stage ECG arrhythmia classifier

## Overview 

This notebook demonstrates creating a foundational model for raw ECG signals. By creating a foundational model, we can create small, down-stream classification models.

In [3]:
import os
os.environ["KMP_AFFINITY"] = "noverbose"
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # 3
os.environ['AUTOGRAPH_VERBOSITY'] = '2' # 5

import functools
import random
from typing import Generator
from pathlib import Path
import tempfile
import tensorflow as tf
from tqdm import tqdm
import sklearn.model_selection
import keras
import numpy as np
import numpy.typing as npt
import heartkit as hk
import physiokit as pk
import neuralspot_edge as nse
from neuralspot_edge.trainers.simclr import SimCLRTrainer
import matplotlib as mpl
import matplotlib.pyplot as plt
import plotly.io as pio

hk.silence_tensorflow()
logger = hk.setup_logger('heartkit', level=2)

## Constants

Here we provide the constants that we will use throughout the guide. For better performance, adjust parameters such as `BATCH_SIZE`, `EPOCHS`, and `LEARNING_RATE`.

In [4]:
# Seed for reproducibility
seed = 42

# File paths
datasets_dir = Path("../../datasets")
job_dir = Path(tempfile.gettempdir()) / "hk-foundation"
model_file = job_dir / "model.keras"
val_file = job_dir / "val.pkl"

os.makedirs(job_dir, exist_ok=True)

# Data settings
sampling_rate = 100 # 100 Hz
input_size = 1000 # 10 seconds
frame_size = 800 # 8 seconds

# Training settings
batch_size = 1024        # Batch size for training
buffer_size = 10000      # How many samples are shuffled each epoch
epochs = 100             # Increase this to 100+
steps_per_epoch = 25     # # Steps per epoch (must set since ds has unknown size)
samples_per_patient = 1  # Number of samples per patient
val_size = 1000         # Number of samples used for validation
test_size = 1000        # Number of samples used for validation
val_percentage = 0.2     # Percentage of samples used for validation
verbose = 1              # Verbosity level
learning_rate = 1e-3     # Learning rate for Adam optimizer

# Model settings
projection_width = 128
temperature = 0.1

# Plotting settings
bg_rgba_color = "rgba(38,42,50,1.0)"
bg_color = "#262a32"
primary_color = "#11acd5"
secondary_color = "#ce6cff"
tertiary_color = "#ea3424"
quaternary_color = "#5cc99a"
colors = [primary_color, secondary_color, tertiary_color, quaternary_color]
plotly_template = "plotly_dark"
pio.renderers.default = "notebook"
plt.style.use('dark_background')
mpl.rcParams['axes.facecolor'] = bg_color
mpl.rcParams['figure.facecolor'] = bg_color

## Configure datasets

We are going to train our model using two large datasets: the PTB-XL dataset and the large-scale arrhythmia dataset. 

In [5]:
datasets = [
    hk.DatasetParams(
        name="lsad",
        path=datasets_dir / "lsad",
        params={}
    ),
    hk.DatasetParams(
        name="ptbxl",
        path=datasets_dir / "ptbxl",
        params={}
    ),
]

### Download datasets

In [5]:
hk.datasets.download_datasets(hk.HKDownloadParams(
    datasets=datasets,
    force=False,
    progress=True
))

## Lets load all subjects data and split into train and test

In [6]:
dsets = [hk.DatasetFactory.get(dataset.name)(
    ds_path=dataset.path,
) for dataset in datasets]

num_pts = sum((len(ds.get_train_patient_ids()) for ds in dsets))

train_data = np.zeros((
    num_pts,
    input_size,
    1
))
pt_idx = 0
for ds in dsets:
    train_pt_ids = ds.get_train_patient_ids()
    for pt_id in tqdm(train_pt_ids):
        with ds.patient_data(pt_id) as h5:
            data = h5["data"][0:1, :].T
        # END WITH
        data = pk.signal.resample_signal(data, sample_rate=ds.sampling_rate, target_rate=sampling_rate, axis=0)
        data = np.expand_dims(data, axis=0)
        train_data[pt_idx] = data
        pt_idx += 1
    # END FOR
# END FOR

100%|██████████| 36120/36120 [00:13<00:00, 2764.34it/s]
100%|██████████| 18500/18500 [00:07<00:00, 2611.86it/s]


In [7]:
train_data, val_data = sklearn.model_selection.train_test_split(
    train_data,
    test_size=val_percentage,
    random_state=seed
)

## Create TF train and validation datasets

In [8]:
train_ds = tf.data.Dataset.from_tensor_slices(train_data)
train_ds = train_ds.shuffle(
    buffer_size,
).batch(
    batch_size
)

val_ds = tf.data.Dataset.from_tensor_slices(val_data)
val_ds = val_ds.batch(
    batch_size
)

2024-07-25 21:33:53.890257: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:266] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected


In [12]:
x = next(iter(train_ds))
print(x.shape)

(1024, 1000, 1)


In [10]:
nstdb = hk.datasets.nstdb.NstdbNoise(target_rate=sample_rate)
noises = np.hstack((nstdb.get_noise(noise_type="bw"), nstdb.get_noise(noise_type="ma"), nstdb.get_noise(noise_type="em")))

augmentation_pipeline = nse.layers.preprocessing.ts.augmentation_pipeline.AugmentationPipeline(
    layers=[
        nse.layers.preprocessing.ts.random_crop.RandomCrop(
            duration=frame_size,
        ),
        nse.layers.preprocessing.ts.gaussian_noise.GaussianNoise(
            stddev=0.05
        ),
        nse.layers.preprocessing.ts.random_cutout.RandomCutout(
            factor=(0.05, 0.1),
            cutouts=(1, 3)
            fill_mode="constant",
            fill_value=0.0
        ),
        nse.layers.preprocessing.ts.random_background_noises.RandomBackgroundNoises(
            noises=noises
        )
    ]
)

In [11]:
test_ds = train_ds.map(augmentation_pipeline)

In [14]:
x = next(iter(test_ds))
print(x.shape)

(1024, 800, 1)


In [5]:
dsets = []
for dset in datasets:
    if hk.DatasetFactory.has(dset.name):
        dsets.append(hk.DatasetFactory.get(dset.name)(ds_path=dset.path, **dset.params))
    # END IF
# END FOR

## Preprocess pipeline

We will preprocess the ECG signals by applying the following steps:
* Apply Z-score normalization w/ epsilon to avoid division by zero

The task accepts a list of preprocessing functions that will be applied to the input data. 

__NOTE:__ We dont apply any filtering as the model is expected to learn the filtering mechanism.

In [6]:
preprocesses = [
    hk.PreprocessParams(name="znorm", params=dict(eps=0.01, axis=None))
]

## Augmentation pipeline

We will apply the following augmentations to the ECG signals:
* Baseline wander: Simulate baseline wander by adding a random frequency sinusoidal signal to the ECG signal
* Powerline noise: Simulate powerline noise by adding a 50 Hz sinusoidal signal to the ECG signal
* Burst noise: Simulate burst noise by randomly injecting burst of high frequency noise to the ECG signal
* Noise sources: Apply several noises at given frequencies to the ECG signal
* Lead noise: Simulate lead noise by adding a random frequency sinusoidal signal to the ECG signal
* NSTDB: Add real noise captured from NSTDB dataset to the ECG signal. 


In [7]:
augmentations = [
    hk.AugmentationParams(name="baseline_wander", params=dict(amplitude=[0.0, 0.5], frequency=[0.5, 1.5])),
    hk.AugmentationParams(name="powerline_noise", params=dict(amplitude=[0.05, 0.15], frequency=[45, 50])),
    hk.AugmentationParams(name="burst_noise", params=dict(burst_number=[0, 4], amplitude=[0.05, 0.1], frequency=[20, 49])),
    hk.AugmentationParams(name="noise_sources", params=dict(num_sources=[1, 2], amplitude=[0.05, 0.1], frequency=[10, 40])),
    hk.AugmentationParams(name="lead_noise", params=dict(scale=[0.05, 0.1])),
    hk.AugmentationParams(name="nstdb", params=dict(noise_level=[0.1, 0.3]))
]

In [8]:
def data_generator(
    patient_generator: hk.datasets.defines.PatientGenerator,
    ds: hk.datasets.HKDataset,
    frame_size: int,
    samples_per_patient: int | list[int] = 1,
    target_rate: int | None = None,
) -> Generator[tuple[npt.NDArray, npt.NDArray], None, None]:
    """Generate frames using patient generator.

    Args:
        patient_generator (PatientGenerator): Patient Generator
        ds: PtbxlDataset
        frame_size (int): Frame size
        samples_per_patient (int | list[int], optional): # samples per patient. Defaults to 1.
        target_rate (int|None, optional): Target rate. Defaults to None.

    Returns:
        Generator[tuple[npt.NDArray, npt.NDArray], None, None]: Sample generator

    """
    input_size = int(np.round((ds.sampling_rate / target_rate) * frame_size))
    data_cache = {}
    for pt in patient_generator:
        if pt not in data_cache:
            with ds.patient_data(pt) as h5:
                data_cache[pt] = h5["data"][:]
        data = data_cache[pt]

        for _ in range(samples_per_patient):
            leads = random.sample(ds.leads, k=2)
            lead_p1 = leads[0]
            lead_p2 = leads[1]
            start_p1 = np.random.randint(0, data.shape[1] - input_size)
            start_p2 = np.random.randint(0, data.shape[1] - input_size)
            # start_p2 = start_p1

            x1 = np.nan_to_num(data[lead_p1, start_p1 : start_p1 + input_size].squeeze()).astype(np.float32)
            x2 = np.nan_to_num(data[lead_p2, start_p2 : start_p2 + input_size].squeeze()).astype(np.float32)

            if ds.sampling_rate != target_rate:
                x1 = pk.signal.resample_signal(x1, ds.sampling_rate, target_rate, axis=0)
                x2 = pk.signal.resample_signal(x2, ds.sampling_rate, target_rate, axis=0)
            # END IF
            yield x1, x2
        # END FOR
    # END FOR

def preprocess(x: npt.NDArray, preprocesses: list[hk.PreprocessParams], sample_rate: float) -> npt.NDArray:
    """Preprocess data pipeline

    Args:
        x (npt.NDArray): Input data
        preprocesses (list[PreprocessParams]): Preprocess parameters
        sample_rate (float): Sample rate

    Returns:
        npt.NDArray: Preprocessed data
    """
    return hk.datasets.preprocess_pipeline(x, preprocesses=preprocesses, sample_rate=sample_rate)


def augment(x: npt.NDArray, augmentations: list[hk.AugmentationParams], sample_rate: float) -> npt.NDArray:
    """Augment data pipeline

    Args:
        x (npt.NDArray): Input data
        augmentations (list[AugmentationParams]): Augmentation parameters
        sample_rate (float): Sample rate

    Returns:
        npt.NDArray: Augmented data
    """

    return hk.datasets.augment_pipeline(x=x, augmentations=augmentations, sample_rate=sample_rate)

def prepare(
    x_y: tuple[npt.NDArray, npt.NDArray],
    sample_rate: float,
    preprocesses: list[hk.PreprocessParams],
    augmentations: list[hk.AugmentationParams],
    spec: tuple[tf.TensorSpec, tf.TensorSpec],
) -> tuple[npt.NDArray, npt.NDArray]:
    """Prepare dataset

    Args:
        x_y (tuple[npt.NDArray, npt.NDArray]): Input data
        sample_rate (float): Sampling rate
        preprocesses (list[PreprocessParams]): Preprocessing pipeline
        augmentations (list[AugmentationParams]): Augmentation pipeline
        spec (tuple[tf.TensorSpec, tf.TensorSpec]): Spec
        num_classes (int): Number of classes

    Returns:
        tuple[npt.NDArray, npt.NDArray]: Prepared data
    """
    x, y = x_y[0].copy(), x_y[1].copy()

    if augmentations:
        x = augment(x, augmentations, sample_rate)
        y = augment(y, augmentations, sample_rate)
    # END IF

    if preprocesses:
        x = preprocess(x, preprocesses, sample_rate)
        y = preprocess(y, preprocesses, sample_rate)
    # END IF

    x = x.reshape(spec[0].shape)
    y = y.reshape(spec[0].shape)

    return x, y

In [9]:
id_generator = functools.partial(hk.datasets.utils.uniform_id_generator, repeat=True)

feat_shape = (frame_size, 1)

ds_spec = (
    tf.TensorSpec(shape=feat_shape, dtype="float32"),
    tf.TensorSpec(shape=feat_shape, dtype="float32"),
)

train_prepare = functools.partial(
    prepare,
    sample_rate=sampling_rate,
    preprocesses=preprocesses,
    augmentations=augmentations,
    spec=ds_spec
)

train_datasets =[]
val_datasets = []
for ds in dsets:
    ds_gen = functools.partial(
        data_generator,
        ds=ds,
        frame_size=frame_size,
        samples_per_patient=samples_per_patient,
        target_rate=sampling_rate,
    )

    train_ds, val_ds = hk.datasets.train_val_dataloader(
        ds=ds,
        spec=ds_spec,
        data_generator=ds_gen,
        id_generator=id_generator,
        val_patients=val_percentage,
        val_pt_samples=samples_per_patient,
        val_size=val_size,
        preprocess=train_prepare,
        num_workers=os.cpu_count(),
    )
    train_datasets.append(train_ds)
    val_datasets.append(val_ds)
# END FOR


In [10]:
ds_weights = np.array([d.weight for d in datasets])
ds_weights = ds_weights / ds_weights.sum()

train_ds = tf.data.Dataset.sample_from_datasets(train_datasets, weights=ds_weights)
val_ds = tf.data.Dataset.sample_from_datasets(val_datasets, weights=ds_weights)

# Shuffle and batch datasets for training
train_ds = (
    train_ds.shuffle(
        buffer_size=buffer_size,
        reshuffle_each_iteration=True,
    )
    .batch(
        batch_size=batch_size,
        drop_remainder=False,
        num_parallel_calls=tf.data.AUTOTUNE,
    )
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)
val_ds = val_ds.batch(
    batch_size=batch_size,
    drop_remainder=True,
    num_parallel_calls=tf.data.AUTOTUNE,
)

In [11]:
x, y = next(iter(val_ds))
print(x.shape, y.shape)

(1024, 800, 1) (1024, 800, 1)


In [12]:
inputs = keras.Input(shape=(frame_size, 1), name="input")

encoder_params=dict(
    input_filters=24,
    input_kernel_size=(1, 9),
    input_strides=(1, 2),
    blocks=[
        dict(filters=32, depth=2, kernel_size=(1, 9), strides=(1, 2), ex_ratio=1, se_ratio=4, norm="layer"),
        dict(filters=48, depth=2, kernel_size=(1, 9), strides=(1, 2), ex_ratio=1, se_ratio=4, norm="layer"),
        dict(filters=64, depth=2, kernel_size=(1, 9), strides=(1, 2), ex_ratio=1, se_ratio=4, norm="layer"),
        dict(filters=80, depth=1, kernel_size=(1, 9), strides=(1, 2), ex_ratio=1, se_ratio=4, norm="layer"),
        dict(filters=96, depth=1, kernel_size=(1, 9), strides=(1, 2), ex_ratio=1, se_ratio=4, norm="layer"),
    ],
    output_filters=projection_width,
    include_top=True,
)

encoder = nse.models.efficientnet.efficientnetv2_from_object(
    x=inputs,
    params=encoder_params,
    num_classes=None
)


In [13]:
encoder.summary(print_fn=logger.info)
flops = nse.metrics.flops.get_flops(encoder, batch_size=1, fpath=os.devnull)
logger.info(f"Computation: {flops/1e6:0.2f} MFLOPs")
encoder_output = encoder(inputs)

In [14]:
projector_input = encoder_output
projector_output = keras.layers.Dense(projection_width, activation="relu6")(projector_input)
projector_output = keras.layers.Dense(projection_width)(projector_output)
projector = keras.Model(inputs=projector_input, outputs=projector_output, name="projector")
flops = nse.metrics.flops.get_flops(projector, batch_size=1, fpath=os.devnull)
projector.summary(print_fn=logger.info)
logger.debug(f"Projector requires {flops/1e6:0.2f} MFLOPS")

In [15]:
model = SimCLR(
    contrastive_augmenter=lambda x: x,
    encoder=encoder,
    projector=projector,
    # momentum_coeff=0.999,
    temperature=temperature,
    # queue_size=65536,
)

In [16]:
def get_scheduler():
    return keras.optimizers.schedules.CosineDecay(
        initial_learning_rate=learning_rate,
        decay_steps=steps_per_epoch * epochs,
    )

model.compile(
    contrastive_optimizer=keras.optimizers.Adam(get_scheduler()),
    probe_optimizer=keras.optimizers.Adam(get_scheduler()),
)

In [17]:
val_metric = "loss"

model_callbacks = [
    keras.callbacks.EarlyStopping(
        monitor=f"val_{val_metric}",
        patience=max(int(0.25 * epochs), 1),
        mode="max" if val_metric == "f1" else "auto",
        restore_best_weights=True,
    ),
    keras.callbacks.ModelCheckpoint(
        filepath=str(model_file),
        monitor=f"val_{val_metric}",
        save_best_only=True,
        mode="max" if val_metric == "f1" else "auto",
        verbose=1,
    ),
    keras.callbacks.CSVLogger(job_dir / "history.csv"),
]
if hk.utils.env_flag("TENSORBOARD"):
    model_callbacks.append(
        keras.callbacks.TensorBoard(
            log_dir=job_dir,
            write_steps_per_second=True,
        )
    )


model.fit(
    train_ds,
    steps_per_epoch=steps_per_epoch,
    verbose=2,
    epochs=epochs,
    validation_data=val_ds,
    callbacks=model_callbacks,
)

Epoch 1/100


I0000 00:00:1721671560.218993 1440263 service.cc:146] XLA service 0x74f6300176b0 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1721671560.219066 1440263 service.cc:154]   StreamExecutor device (0): NVIDIA GeForce RTX 4090, Compute Capability 8.9
I0000 00:00:1721671584.509272 1440263 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.



Epoch 1: val_loss improved from inf to 6.92869, saving model to /tmp/hk-foundation/model.keras
25/25 - 240s - 10s/step - c_acc: 0.0020 - loss: 6.9286 - r_acc: 0.0463 - val_c_acc: 9.7656e-04 - val_loss: 6.9287 - val_r_acc: 0.0352
Epoch 2/100

Epoch 2: val_loss did not improve from 6.92869
25/25 - 176s - 7s/step - c_acc: 9.7656e-04 - loss: 6.9301 - r_acc: 0.0358 - val_c_acc: 9.7656e-04 - val_loss: 6.9301 - val_r_acc: 0.0352
Epoch 3/100

Epoch 3: val_loss did not improve from 6.92869
25/25 - 181s - 7s/step - c_acc: 9.7656e-04 - loss: 6.9305 - r_acc: 0.0344 - val_c_acc: 9.7656e-04 - val_loss: 6.9305 - val_r_acc: 0.0312
Epoch 4/100

Epoch 4: val_loss did not improve from 6.92869
25/25 - 178s - 7s/step - c_acc: 9.7656e-04 - loss: 6.9308 - r_acc: 0.0352 - val_c_acc: 9.7656e-04 - val_loss: 6.9308 - val_r_acc: 0.0312
Epoch 5/100

Epoch 5: val_loss did not improve from 6.92869
25/25 - 173s - 7s/step - c_acc: 9.9609e-04 - loss: 6.9309 - r_acc: 0.0369 - val_c_acc: 9.7656e-04 - val_loss: 6.9309 - 

<keras.src.callbacks.history.History at 0x74f7dbbb9c70>

In [23]:
metrics = [
    keras.metrics.MeanAbsoluteError(name="mae"),
    keras.metrics.MeanSquaredError(name="mse"),
    keras.metrics.CosineSimilarity(name="cosine"),
]

optimizer = keras.optimizers.Adam(get_scheduler())
loss = keras.losses.MeanSquaredError()
encoder.compile(optimizer=optimizer, loss=loss, metrics=metrics)

encoder.fit(
    train_ds,
    steps_per_epoch=steps_per_epoch,
    verbose=2,
    epochs=epochs,
    validation_data=val_ds,
    callbacks=model_callbacks,
)

Epoch 1/100


ValueError: Dimensions must be equal, but are 800 and 128 for '{{node compile_loss/mean_squared_error/sub}} = Sub[T=DT_FLOAT](compile_loss/mean_squared_error/Squeeze, EfficientNetV2_1/dropout_7_1/stateless_dropout/SelectV2)' with input shapes: [?,800], [?,128].

In [35]:
from abc import abstractmethod
from typing import Callable

import keras
import tensorflow as tf


class ContrastiveModel(keras.Model):
    """Base class for contrastive learning models"""

    def __init__(
        self,
        encoder: keras.Model,
        projector: keras.Model,
        contrastive_augmenter: Callable[[keras.KerasTensor], keras.KerasTensor] | None = None,
        classification_augmenter: Callable[[keras.KerasTensor], keras.KerasTensor] | None = None,
        linear_probe: keras.Model | None = None,
    ):
        super().__init__()

        self.encoder = encoder
        self.projector = projector
        self.contrastive_augmenter = contrastive_augmenter
        self.classification_augmenter = classification_augmenter
        self.linear_probe = linear_probe

        self.probe_loss = None
        self.probe_optimizer = None
        self.contrastive_loss_tracker = None
        self.contrastive_optimizer = None
        self.contrastive_accuracy = None
        self.correlation_accuracy = None
        self.probe_accuracy = None

    @property
    def metrics(self):
        """List of metrics to track during training and evaluation"""
        return [
            self.contrastive_loss_tracker,
            self.correlation_accuracy,
            self.contrastive_accuracy,
            # self.probe_loss_tracker,
            # self.probe_accuracy,
        ]

    @abstractmethod
    def contrastive_loss(self, projections_1, projections_2):
        """Contrastive loss function"""
        raise NotImplementedError()

    def call(self, inputs, training=None, mask=None):
        """Forward pass through the encoder model"""
        return self.encoder(inputs, training=training, mask=mask)

    # pylint: disable=unused-argument,arguments-differ
    def compile(
        self,
        contrastive_optimizer: keras.optimizers.Optimizer,
        probe_optimizer: keras.optimizers.Optimizer | None = None,
        **kwargs,
    ):
        """Compile the model with the specified optimizers"""
        super().compile(**kwargs)

        self.contrastive_optimizer = contrastive_optimizer
        self.probe_optimizer = probe_optimizer

        # self.contrastive_loss is a method that will be implemented by the subclasses
        self.probe_loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True)

        self.contrastive_loss_tracker = keras.metrics.Mean(name="loss")
        self.contrastive_accuracy = keras.metrics.SparseCategoricalAccuracy(name="c_acc")
        self.correlation_accuracy = keras.metrics.SparseCategoricalAccuracy(name="r_acc")

        self.probe_accuracy = keras.metrics.SparseCategoricalAccuracy()

    def save(self, filepath, overwrite=True, save_format=None, **kwargs):
        """Save the encoder model to file

        Args:
            filepath (str): Filepath
            overwrite (bool, optional): Overwrite existing file. Defaults to True.
            save_format ([type], optional): Save format. Defaults to None.
        """
        self.encoder.save(filepath, overwrite, save_format, **kwargs)

    def reset_metrics(self):
        """Reset the metrics to their initial state"""
        self.contrastive_accuracy.reset_state()
        self.correlation_accuracy.reset_state()
        self.probe_accuracy.reset_state()

    def update_contrastive_accuracy(self, features_1, features_2):
        """Update the contrastive accuracy metric
        self-supervised metric inspired by the SimCLR loss
        """

        # cosine similarity: the dot product of the l2-normalized feature vectors
        features_1 = keras.ops.normalize(features_1, axis=1)
        features_2 = keras.ops.normalize(features_2, axis=1)
        similarities = keras.ops.matmul(features_1, keras.ops.transpose(features_2))

        # Push positive pairs to the diagonal
        batch_size = keras.ops.shape(features_1)[0]
        contrastive_labels = keras.ops.arange(batch_size)
        self.contrastive_accuracy.update_state(contrastive_labels, similarities)
        self.contrastive_accuracy.update_state(contrastive_labels, keras.ops.transpose(similarities))

    def update_correlation_accuracy(self, features_1, features_2):
        """Update the correlation accuracy metric
        self-supervised metric inspired by the BarlowTwins loss
        """

        # normalization so that cross-correlation will be between -1 and 1
        features_1 = (features_1 - keras.ops.mean(features_1, axis=0)) / keras.ops.std(features_1, axis=0)
        features_2 = (features_2 - keras.ops.mean(features_2, axis=0)) / keras.ops.std(features_2, axis=0)

        # the cross correlation of image representations should be the identity matrix
        batch_size = keras.ops.shape(features_1)[0]
        batch_size = keras.ops.cast(batch_size, dtype="float32")
        print(features_1.shape, features_2.shape, batch_size)
        print("DBG0", features_1.shape)
        cross_correlation = keras.ops.matmul(keras.ops.transpose(features_1), features_2) / batch_size
        print("DBG1", cross_correlation.shape)
        feature_dim = keras.ops.shape(features_1)[1]
        print("DBG2", feature_dim)
        correlation_labels = keras.ops.arange(feature_dim)
        print("DBG3", correlation_labels.shape)
        self.correlation_accuracy.update_state(correlation_labels, cross_correlation)
        print("DBG4", cross_correlation.shape)
        self.correlation_accuracy.update_state(correlation_labels, keras.ops.transpose(cross_correlation))

    def train_step(self, data):
        """Training step for the model"""
        pair1, pair2 = data

        # each input is augmented twice, differently
        augmented_inputs_1 = self.contrastive_augmenter(pair1)
        augmented_inputs_2 = self.contrastive_augmenter(pair2)
        with tf.GradientTape() as tape:
            # Encoder phase
            features_1 = self.encoder(augmented_inputs_1)
            features_2 = self.encoder(augmented_inputs_2)
            # Projection phase
            projections_1 = self.projector(features_1)
            projections_2 = self.projector(features_2)
            contrastive_loss = self.contrastive_loss(projections_1, projections_2)
        # END WITH

        # backpropagation
        gradients = tape.gradient(
            contrastive_loss,
            self.encoder.trainable_weights + self.projector.trainable_weights,
        )
        self.contrastive_optimizer.apply_gradients(
            zip(
                gradients,
                self.encoder.trainable_weights + self.projector.trainable_weights,
            )
        )

        self.contrastive_loss_tracker.update_state(contrastive_loss)

        self.update_contrastive_accuracy(features_1, features_2)
        self.update_correlation_accuracy(features_1, features_2)

        # # labels are only used in evalutation for probing
        # augmented_inputs = self.classification_augmenter(labeled_pair)
        # with tf.GradientTape() as tape:
        #     features = self.encoder(augmented_inputs)
        #     class_logits = self.linear_probe(features)
        #     probe_loss = self.probe_loss(labels, class_logits)
        # gradients = tape.gradient(probe_loss, self.linear_probe.trainable_weights)
        # self.probe_optimizer.apply_gradients(
        #     zip(gradients, self.linear_probe.trainable_weights)
        # )
        # self.probe_accuracy.update_state(labels, class_logits)

        return {m.name: m.result() for m in self.metrics}

    def test_step(self, data):
        """Test step for the model"""
        pair1, pair2 = data
        augmented_inputs_1 = self.contrastive_augmenter(pair1)
        augmented_inputs_2 = self.contrastive_augmenter(pair2)
        features_1 = self.encoder(augmented_inputs_1, training=False)
        features_2 = self.encoder(augmented_inputs_2, training=False)
        projections_1 = self.projector(features_1, training=False)
        projections_2 = self.projector(features_2, training=False)

        contrastive_loss = self.contrastive_loss(projections_1, projections_2)
        self.contrastive_loss_tracker.update_state(contrastive_loss)
        self.update_contrastive_accuracy(features_1, features_2)
        self.update_correlation_accuracy(features_1, features_2)

        return {m.name: m.result() for m in self.metrics}


class SimCLR(ContrastiveModel):
    """SimCLR model for self-supervised learning"""

    def __init__(
        self,
        encoder: keras.Model,
        projector: keras.Model,
        contrastive_augmenter: Callable[[keras.KerasTensor], keras.KerasTensor] | None = None,
        classification_augmenter: Callable[[keras.KerasTensor], keras.KerasTensor] | None = None,
        linear_probe: keras.Model | None = None,
        temperature: float = 0.1,
    ):
        super().__init__(
            encoder=encoder,
            projector=projector,
            contrastive_augmenter=contrastive_augmenter,
            classification_augmenter=classification_augmenter,
            linear_probe=linear_probe,
        )
        self.temperature = temperature

    def contrastive_loss(self, projections_1, projections_2):
        """Contrastive loss function for SimCLR"""
        # InfoNCE loss (information noise-contrastive estimation)
        # NT-Xent loss (normalized temperature-scaled cross entropy)

        # cosine similarity: the dot product of the l2-normalized feature vectors
        projections_1 = keras.ops.normalize(projections_1, axis=1)
        projections_2 = keras.ops.normalize(projections_2, axis=1)
        similarities = keras.ops.matmul(projections_1, keras.ops.transpose(projections_2)) / self.temperature

        # the temperature-scaled similarities are used as logits for cross-entropy
        batch_size = keras.ops.shape(projections_1)[0]
        contrastive_labels = keras.ops.arange(batch_size)
        loss1 = keras.losses.sparse_categorical_crossentropy(contrastive_labels, similarities, from_logits=True)
        loss2 = keras.losses.sparse_categorical_crossentropy(
            contrastive_labels, keras.ops.transpose(similarities), from_logits=True
        )
        return (loss1 + loss2) / 2


In [36]:
def get_scheduler():
    return keras.optimizers.schedules.CosineDecay(
        initial_learning_rate=learning_rate,
        decay_steps=steps_per_epoch * epochs,
    )


model = SimCLR(
    contrastive_augmenter=lambda x: x,
    encoder=encoder,
    projector=projector,
    # momentum_coeff=0.999,
    temperature=temperature,
    # queue_size=65536,
)

model.compile(
    contrastive_optimizer=keras.optimizers.Adam(get_scheduler()),
    probe_optimizer=keras.optimizers.Adam(get_scheduler()),
)

In [37]:
model.train_step((x, y))

(1024, 128) (1024, 128) tf.Tensor(1024.0, shape=(), dtype=float32)
DBG0 (1024, 128)
DBG1 (128, 128)
DBG2 128
DBG3 (128,)
DBG4 (128, 128)


{'loss': <tf.Tensor: shape=(), dtype=float32, numpy=6.931387>,
 'r_acc': <tf.Tensor: shape=(), dtype=float32, numpy=0.1015625>,
 'c_acc': <tf.Tensor: shape=(), dtype=float32, numpy=0.0014648438>}

In [None]:
keras.preprocessing