<a href="https://colab.research.google.com/github/BenRodriguez1029/bayesflow-model-comparison/blob/main/simple_classification_jax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [156]:
import os

os.environ["KERAS_BACKEND"] = "jax"
import keras
from keras.utils import Sequence
from keras.layers import Layer, Dense, GlobalAveragePooling1D, Reshape
from keras.initializers import RandomNormal
from keras.models import Sequential
import keras.ops as ops

import numpy as np
import bayesflow as bf
import random

from losses import logistic_loss, exponential_loss, alpha_log_exponential_loss

print(keras.backend.backend())

jax


### Simulator and Data Preparation

In [176]:
# create model simulators

def prior_alternative():
    """Draw one mean for the alternative model from its standard-normal prior.

    Returns:
        A single float sampled from N(0, 1) using NumPy's global random state.
    """
    prior_mean, prior_std = 0, 1
    return np.random.normal(prior_mean, prior_std)

def sample_model_0(sample_size, n=30):
    """Simulate datasets from the null model (mean fixed at 0, unit variance).

    Args:
        sample_size: Number of datasets (rows) to simulate.
        n: Number of observations per dataset (columns).

    Returns:
        Array of shape ``(sample_size, n)`` of independent N(0, 1) draws taken
        from NumPy's global random state.
    """
    output_shape = (sample_size, n)
    return np.random.normal(0, 1, output_shape)

def sample_model_1(sample_size, n=30):
    """Simulate datasets from the alternative model (random mean, unit variance).

    One mean per dataset is drawn with ``prior_alternative``; the ``n``
    observations of that dataset are then normal draws centred on that mean.

    Args:
        sample_size: Number of datasets (rows) to simulate.
        n: Number of observations per dataset (columns).

    Returns:
        Array of shape ``(sample_size, n)``.
    """
    # All prior draws are taken before any observation is drawn.
    prior_means = np.array([prior_alternative() for _ in range(sample_size)])

    # Column vector, so each row's mean broadcasts across its n observations.
    row_means = prior_means[:, None]
    return np.random.normal(loc=row_means, scale=1, size=(sample_size, n))

In [177]:
class MyDataset(Sequence):
    """Mini-batch loader over in-memory arrays, built on the Keras ``Sequence``.

    Batches are taken through an index array, so the underlying ``X`` and ``y``
    are never reordered; only the index array is shuffled.

    Args:
        X: Indexable collection of inputs; must support ``len`` and indexing
            with an integer index array.
        y: Targets aligned with ``X`` along the first axis.
        batch_size: Maximum number of examples per batch.
        shuffle: If true, the visiting order is shuffled at construction and
            again on every ``on_epoch_end`` call.
        **kwargs: Forwarded unchanged to the ``Sequence`` constructor.
    """

    def __init__(self, X, y, batch_size=32, shuffle=True, **kwargs):
        super().__init__(**kwargs)
        self.X, self.y = X, y
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.indices = np.arange(len(X))
        # Apply the initial shuffle (if enabled) before the first epoch.
        self.on_epoch_end()

    def __len__(self):
        """Return the number of batches per epoch, counting a final partial one."""
        n_examples = len(self.X)
        per_batch = self.batch_size
        return (n_examples + per_batch - 1) // per_batch

    def __getitem__(self, index):
        """Return the ``(X_batch, y_batch)`` pair for batch number ``index``."""
        start = index * self.batch_size
        stop = (index + 1) * self.batch_size
        chosen = self.indices[start:stop]
        return self.X[chosen], self.y[chosen]

    def on_epoch_end(self):
        """Shuffle the visiting order in place when shuffling is enabled."""
        if not self.shuffle:
            return
        np.random.shuffle(self.indices)

In [178]:
# simulate data
sample_size = 2000

data_0 = sample_model_0(sample_size)
data_1 = sample_model_1(sample_size)

# 80/20 train/validation split, applied to each model separately so both
# splits contain the same number of datasets from each model.
split = int(0.8 * data_0.shape[0])
data_0_train, data_0_val = data_0[:split], data_0[split:]
data_1_train, data_1_val = data_1[:split], data_1[split:]

data_train = np.concatenate([data_0_train, data_1_train], axis=0)
data_val = np.concatenate([data_0_val, data_1_val], axis=0)

# Labels: 0 for datasets simulated from model 0, 1 for model 1.
label_train = np.concatenate([np.zeros(data_0_train.shape[0]), np.ones(data_1_train.shape[0])], axis=0)
label_val = np.concatenate([np.zeros(data_0_val.shape[0]), np.ones(data_1_val.shape[0])], axis=0)

# Mix the two classes; data and labels share one permutation to stay aligned.
perm = np.random.permutation(len(data_train))
data_train = data_train[perm]
label_train = label_train[perm]

# dataloader
batch_size = 32

# Both loaders now read `batch_size` instead of a hard-coded 32, so changing
# the variable above actually takes effect.
# NOTE(review): `[:, None, :]` inserts a singleton axis at position 1, giving
# inputs of shape (N, 1, n). Confirm this is the layout the summary network
# should receive, as opposed to `[:, :, None]` with shape (N, n, 1).
train_loader = MyDataset(data_train[:, None, :], label_train[:, None], batch_size=batch_size)
val_loader = MyDataset(data_val[:, None, :], label_val[:, None], batch_size=batch_size, shuffle=False)

### Network

In [179]:
class EvidenceNetwork(keras.Model):
    """Summary network + MLP + single-unit head for two-model comparison.

    The head's activation is selected by ``output``:

    * ``"p"``      -- sigmoid, output in (0, 1).
    * ``"K"``      -- softplus, output strictly positive.
    * ``"log(K)"`` -- linear, output unbounded.

    Every head starts with small kernel weights, so the initial prediction is
    close to the neutral value of its parameterisation (p = 0.5, K = 1,
    log K = 0).

    Args:
        output: One of ``"p"``, ``"K"`` or ``"log(K)"``.
        **kwargs: Forwarded to ``keras.Model``.

    Raises:
        ValueError: If ``output`` is not one of the supported names.
    """

    def __init__(self, output, **kwargs):
        super().__init__(**kwargs)

        # shared backbone network
        self.summary_network = bf.networks.DeepSet(summary_dim=8, dropout=None)
        self.classification_network = bf.networks.MLP(widths=[32] * 4, activation="silu", dropout=None)

        # output layer depends on output type
        if output == "p":
            self.output_layer = Dense(
                1,
                activation="sigmoid",
                kernel_initializer=RandomNormal(mean=0.0, stddev=0.01),
            )  # probability 0-1
        elif output == "K":
            # log(e - 1) is the softplus inverse of 1. It belongs in the bias:
            # with near-zero kernel weights the pre-activation is then about
            # log(e - 1) and the initial output is K ~= 1. Using it as the
            # kernel mean (as before) scaled it by the incoming features, so
            # the initial output was not anchored at 1.
            self.output_layer = Dense(
                1,
                activation="softplus",
                kernel_initializer=RandomNormal(mean=0.0, stddev=0.01),
                bias_initializer=keras.initializers.Constant(float(np.log(np.exp(1) - 1))),
            )  # strictly positive
        elif output == "log(K)":
            self.output_layer = Dense(
                1,
                activation=None,
                kernel_initializer=RandomNormal(mean=0.0, stddev=0.01),
            )  # unbounded
        else:
            raise ValueError("Invalid output type")

    def call(self, inputs, training=False):
        """Run summary network, MLP and output head in sequence.

        Args:
            inputs: Batch passed directly to the summary network.
            training: Forwarded to the summary and classification networks.

        Returns:
            The output head's activations for the batch.
        """
        x = self.summary_network(inputs, training=training)
        x = self.classification_network(x, training=training)
        return self.output_layer(x)

In [180]:
def bayes_accuracy(y_true, f_x):
    """Accuracy for a network whose raw output ``f_x`` is passed through a sigmoid.

    A sample is predicted as class 1 when ``sigmoid(f_x)`` exceeds 0.5.

    Args:
        y_true: Ground-truth 0/1 labels.
        f_x: Raw network outputs.

    Returns:
        Scalar fraction of predictions that equal the labels.
    """
    labels = ops.cast(y_true, "float32")
    predicted = ops.cast(ops.sigmoid(f_x) > 0.5, "float32")
    hits = ops.cast(ops.equal(predicted, labels), "float32")
    return ops.mean(hits)

def alpha_log_exponential_accuracy(y_true, y_pred, alpha=2.0, threshold=1.0):
    """Accuracy for a network that outputs the Bayes factor K directly.

    A sample is predicted as class 1 when ``y_pred`` exceeds ``threshold``.
    The default of 1.0 is the same rule the evaluation routine in this file
    applies to "K" outputs; the previous hard-coded 0.5 disagreed with it, so
    training-time and test-time accuracy were computed under different rules.

    Args:
        y_true: Ground-truth 0/1 labels.
        y_pred: Network outputs interpreted as K.
        alpha: Unused; kept so the signature stays unchanged for callers.
        threshold: Decision boundary on ``y_pred``.

    Returns:
        Accuracy averaged over the last axis (one value per leading index).
    """
    y_true = ops.cast(y_true, dtype="float32")
    pred_labels = ops.cast(ops.greater(y_pred, threshold), dtype="float32")
    correct = ops.equal(pred_labels, y_true)
    return ops.mean(ops.cast(correct, dtype="float32"), axis=-1)

# One evidence network per loss. Each gets the output head that matches what
# its loss expects and an accuracy metric for that parameterisation.

# Logistic loss on an unbounded output.
classifier_log = EvidenceNetwork(output="log(K)")
classifier_log.compile(
    optimizer="adam",
    loss=logistic_loss,
    metrics=[bayes_accuracy],
)

# Standard binary cross-entropy on a probability output.
classifier_ce = EvidenceNetwork(output="p")
classifier_ce.compile(
    optimizer="adam",
    loss="binary_crossentropy",
    metrics=["accuracy"],
)

# Exponential loss on an unbounded output.
classifier_exp = EvidenceNetwork(output="log(K)")
classifier_exp.compile(
    optimizer="adam",
    loss=exponential_loss,
    metrics=[bayes_accuracy],
)

# Alpha-log-exponential loss on a strictly positive output.
classifier_alpha_exp_log = EvidenceNetwork(output="K")
classifier_alpha_exp_log.compile(
    optimizer="adam",
    loss=alpha_log_exponential_loss,
    metrics=[alpha_log_exponential_accuracy],
)

### Training

In [181]:

epochs = 20

# Every classifier is trained on the same loaders for the same number of
# epochs, so the shared fit arguments are collected once.
shared_fit_args = dict(validation_data=val_loader, epochs=epochs)

history_log = classifier_log.fit(train_loader, **shared_fit_args)

history_ce = classifier_ce.fit(train_loader, **shared_fit_args)

history_exp = classifier_exp.fit(train_loader, **shared_fit_args)

history_alpha_exp_log = classifier_alpha_exp_log.fit(train_loader, **shared_fit_args)

Epoch 1/20
100/100 ━━━━━━━━━━━━━━━━━━━━ 16s 89ms/step - bayes_accuracy: 0.6733 - loss: 0.6017 - val_bayes_accuracy: 0.7950 - val_loss: 0.4960
Epoch 2/20
100/100 ━━━━━━━━━━━━━━━━━━━━ 6s 5ms/step - bayes_accuracy: 0.8005 - loss: 0.4433 - val_bayes_accuracy: 0.7862 - val_loss: 0.4843
Epoch 3/20
100/100 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - bayes_accuracy: 0.8288 - loss: 0.3962 - val_bayes_accuracy: 0.7912 - val_loss: 0.4406
Epoch 4/20
100/100 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - bayes_accuracy: 0.8366 - loss: 0.3834 - val_bayes_accuracy: 0.8037 - val_loss: 0.4386
Epoch 5/20
100/100 ━━━━━━━━━━━━━━━━━━━━ 1s 5ms/step - bayes_accuracy: 0.8521 - loss: 0.3576 - val_bayes_accuracy: 0.7950 - val_loss: 0.4515
Epoch 6/20
100/100 ━━━━━━━━━━━━━━━━━━━━ 0s 5ms/step - bayes_accuracy: 0.8562 - loss: 0.3480 - [output truncated]

### Testing

In [182]:
def test_model_per_class_by_output(model, output_type="p", sample_size=100, n=30):
    """Measure per-class accuracy of a classifier on freshly simulated data.

    The first ``sample_size // 2`` datasets come from ``sample_model_0`` and
    the remainder from ``sample_model_1``. Each class is scored with a single
    batched forward pass. Only backend-agnostic ``keras.ops`` and NumPy are
    used, so the function runs under the JAX backend (the previous version
    referenced an unimported ``tf`` module).

    Args:
        model: Callable model. Inputs are given the same singleton axis the
            training loaders insert, i.e. shape ``(batch, 1, n)``. If it
            returns a tuple, the first element is used as the prediction.
        output_type: Parameterisation of the model output, which fixes the
            decision threshold: ``"p"`` -> 0.5, ``"K"`` -> 1.0,
            ``"log(K)"`` -> 0.0.
        sample_size: Total number of simulated datasets across both classes.
        n: Number of observations per simulated dataset (now forwarded to the
            simulators; it was previously ignored).

    Returns:
        Tuple ``(acc_0, acc_1)`` of accuracies on model-0 and model-1 data.
        A class that receives no samples yields ``nan`` instead of raising
        ``ZeroDivisionError``.

    Raises:
        ValueError: If ``output_type`` is not a supported name.
    """
    # A prediction strictly above the threshold is read as "model 1".
    thresholds = {"p": 0.5, "K": 1.0, "log(K)": 0.0}
    if output_type not in thresholds:
        raise ValueError("Invalid output_type. Must be 'p', 'K', or 'log(K)'.")
    threshold = thresholds[output_type]

    def count_correct(samples, true_label):
        """Return how many rows of ``samples`` are classified as ``true_label``."""
        if len(samples) == 0:
            # Nothing to score; also avoids calling the model on an empty batch.
            return 0
        batch = ops.convert_to_tensor(samples[:, None, :], dtype="float32")
        y_pred = model(batch)
        if isinstance(y_pred, tuple):
            y_pred = y_pred[0]
        # Take the first output value per sample, as the per-sample [0, 0]
        # lookup did before batching.
        scores = ops.convert_to_numpy(y_pred).reshape(len(samples), -1)[:, 0]
        pred_labels = (scores > threshold).astype(int)
        return int(np.sum(pred_labels == true_label))

    sample_size = max(sample_size, 0)
    total_0 = sample_size // 2
    total_1 = sample_size - total_0

    correct_0 = count_correct(sample_model_0(total_0, n=n), 0)
    correct_1 = count_correct(sample_model_1(total_1, n=n), 1)

    acc_0 = correct_0 / total_0 if total_0 else float("nan")
    acc_1 = correct_1 / total_1 if total_1 else float("nan")

    print(f"Model 0 accuracy: {acc_0*100:.2f}% ({correct_0}/{total_0})")
    print(f"Model 1 accuracy: {acc_1*100:.2f}% ({correct_1}/{total_1})")

    return acc_0, acc_1

In [183]:
# Score every trained classifier on fresh simulations, pairing each with the
# output type that selects its decision threshold.
evaluation_runs = [
    ("Logistic", classifier_log, "log(K)"),
    ("Cross Entropy", classifier_ce, "p"),
    ("Exponential", classifier_exp, "log(K)"),
    ("Alpha Exponential Log", classifier_alpha_exp_log, "K"),
]

for run_title, run_classifier, run_output_type in evaluation_runs:
    print(run_title)
    test_model_per_class_by_output(run_classifier, output_type=run_output_type, sample_size=1000, n=30)
    print()

Logistic
Model 0 accuracy: 86.40% (432/500)
Model 1 accuracy: 69.40% (347/500)

Cross Entropy
Model 0 accuracy: 89.60% (448/500)
Model 1 accuracy: 71.80% (359/500)

Exponential
Model 0 accuracy: 95.20% (476/500)
Model 1 accuracy: 49.80% (249/500)

Alpha Exponential Log
Model 0 accuracy: 100.00% (500/500)
Model 1 accuracy: 0.00% (0/500)

