# Import libraries 

In [None]:
# Jax dedicated libraries
from flax import nnx
import optax
import jax
import jax.numpy as jnp # From this point on, there should not be numpy anymore but only jax.numpy
import jax.scipy as jsp
import orbax.checkpoint as ocp  # Checkpointing library

# Plotting libraries
import matplotlib.pyplot as plt

# Module functions
import ximinf.nn_train as nnt

# Import training data

In [None]:
# Load the simulated dataset from the local path
data = jnp.load("../data/simulations.npy")

# N is the size of the first axis of the dataset
N = jnp.shape(data)[0]
# M is derived from the size of the second axis as (width - 3) / 4.
# Integer division keeps M an int, so the layer sizes derived from it
# (2*M and 20*M) are valid integer dimensions instead of floats.
# NOTE(review): assumes (width - 3) is a multiple of 4 -- confirm against
# the simulation layout; any remainder is silently floored here.
M = (jnp.shape(data)[1] - 3) // 4



# Build a neural network

In [None]:
# Layer sizes are expressed as multiples of M; phi_batch is fixed to 1
Nsize_p, Nsize_r, phi_batch = 2*M, 20*M, 1

# Build the classifier from ximinf.nn_train with RNG streams seeded at 0.
# NOTE(review): the leading 0.05 is presumably the dropout rate -- confirm
# against the DeepSetClassifier signature.
model = nnt.DeepSetClassifier(
    0.05,
    Nsize_p,
    Nsize_r,
    phi_batch,
    rngs=nnx.Rngs(0),
)

# Render the model structure in the notebook output
nnx.display(model)

In [None]:
# Per-epoch history: one independent list for every (split, metric) pair,
# keyed as '<split>_<metric>'
metrics_history = {
    f'{split}_{metric}': []
    for split in ('train', 'test')
    for metric in ('loss', 'accuracy')
}

In [None]:
# Learning-rate schedule: starts at 3e-4 and decays by a factor of 0.9 per
# 1000 optimiser steps (staircase is left at its default, so the decay is
# applied smoothly rather than in discrete jumps)
learning_rate_schedule = optax.exponential_decay(
    init_value=3e-4,
    transition_steps=1000,
    decay_rate=0.9,
)

# First-moment decay rate (b1) handed to AdamW
momentum = 0.9

# Wrap the model and the AdamW transformation in an NNX optimiser
optimizer = nnx.Optimizer(
    model,
    optax.adamw(learning_rate=learning_rate_schedule, b1=momentum),
)

# Train NN

In [None]:
# Training cell: mini-batch training with per-epoch evaluation on the test
# set, early stopping, and a live loss/accuracy plot every 5 epochs.
# Relies on names defined by other cells: key, train_data_gpu,
# train_labels_gpu, test_data_gpu, test_labels_gpu, model, optimizer,
# metrics_history, N and M.

# Early stopping parameters
patience = 20 # Number of consecutive strikes tolerated before stopping
epochs = 1000 # Maximum number of epochs

batch_size = 1000 # Number of samples per batch

# Initialise stopping criteria
best_train_loss = jnp.inf
best_test_loss = jnp.inf
strikes = 0

for epoch in range(epochs):
    # Training mode is re-enabled every epoch because the evaluation pass
    # below switches the model to eval mode.
    model.train()

    # Shuffle the training data using JAX.
    key, subkey = jax.random.split(key)
    perm = jax.random.permutation(subkey, len(train_data_gpu))
    train_data_gpu = train_data_gpu[perm]
    train_labels_gpu = train_labels_gpu[perm]
    del perm

    epoch_train_loss = 0
    epoch_train_accuracy = 0
    num_train_batches = 0

    for i in range(0, len(train_data_gpu), batch_size):
        # Get the current batch of data and labels
        batch_data = train_data_gpu[i:i+batch_size]
        batch_labels = train_labels_gpu[i:i+batch_size]

        # Record the metrics on this batch, then perform a training step
        loss, _ = nnt.loss_fn(model, (batch_data, batch_labels))
        accuracy = nnt.accuracy_fn(model, (batch_data, batch_labels))
        epoch_train_loss += loss
        epoch_train_accuracy += accuracy
        num_train_batches += 1
        nnt.train_step(model, optimizer, (batch_data, batch_labels))

    # Log the training metrics, averaged over the batches actually run.
    # (Dividing by len(data) / batch_size is wrong whenever the dataset size
    # is not a multiple of batch_size, because the trailing partial batch
    # still contributes one full term to the sum.)
    current_train_loss = epoch_train_loss / num_train_batches
    metrics_history['train_loss'].append(current_train_loss)
    metrics_history['train_accuracy'].append(epoch_train_accuracy / num_train_batches)

    epoch_test_loss = 0
    epoch_test_accuracy = 0
    num_test_batches = 0

    # Evaluate in eval mode so the test metrics are not computed with the
    # model still in training mode.
    model.eval()

    # Compute the metrics on the test set using the same batching as training
    for i in range(0, len(test_data_gpu), batch_size):
        batch_data = test_data_gpu[i:i+batch_size]
        batch_labels = test_labels_gpu[i:i+batch_size]

        loss, _ = nnt.loss_fn(model, (batch_data, batch_labels))
        accuracy = nnt.accuracy_fn(model, (batch_data, batch_labels))
        epoch_test_loss += loss
        epoch_test_accuracy += accuracy
        num_test_batches += 1

    # Log the test metrics.
    current_test_loss = epoch_test_loss / num_test_batches
    metrics_history['test_loss'].append(current_test_loss)
    metrics_history['test_accuracy'].append(epoch_test_accuracy / num_test_batches)

    # Early Stopping Check
    # A strike is counted when the training loss reaches a new best while the
    # test loss is worse than its best (the train/test gap is widening).
    # Strikes are reset when the test loss improves or when the training
    # loss did not improve either.
    train_improved = current_train_loss < best_train_loss
    if current_test_loss < best_test_loss:
        best_test_loss = current_test_loss  # Update best test loss
        strikes = 0
    elif not train_improved:
        strikes = 0
    elif current_test_loss > best_test_loss:
        strikes += 1
    # Track the best training loss independently of the branch taken above;
    # inside the elif chain this update was effectively unreachable, leaving
    # best_train_loss at its initial infinite value.
    if train_improved:
        best_train_loss = current_train_loss # Update best train loss

    if strikes >= patience:
        print(f"\n Early stopping at epoch {epoch+1} due to {patience} consecutive increases in loss gap \n")
        break

    if epoch%5 == 0:
        # Plot loss and accuracy in subplots
        # NOTE(review): clear_output is expected to be imported elsewhere in
        # the notebook (e.g. from IPython.display) -- no import is visible here.
        clear_output(wait=True) # Clear the output to avoid cluttering

        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
        ax1.set_title(f'Loss for M:{M} and N:{N}')
        for dataset in ('train', 'test'):
            ax1.plot(metrics_history[f'{dataset}_loss'], label=f'{dataset}_loss')
        ax1.legend()
        ax1.set_yscale("log")

        ax2.set_title('Accuracy')
        for dataset in ('train', 'test'):
            ax2.plot(metrics_history[f'{dataset}_accuracy'], label=f'{dataset}_accuracy')
        ax2.legend()
        plt.show()

# Test NN

In [None]:
# Evaluate the trained classifier on the test set in mini-batches, then
# report the confusion-matrix counts and the metrics derived from them.
model.eval()  # evaluation mode (disables dropout, etc.)

batch_size = 128  # tune this to fit your RAM; lower -> safer

# Per-batch results, merged once the loop is done
all_logits = []
all_labels = []

num_samples = test_data_gpu.shape[0]
for start in range(0, num_samples, batch_size):
    window = slice(start, start + batch_size)

    # Logits predicted for this mini-batch
    all_logits.append(nnt.pred_step(model, test_data_gpu[window]))

    # Matching true labels, thresholded to booleans
    all_labels.append(test_labels_gpu[window] > 0.5)

# Merge the batches; a prediction is positive when sigmoid(logit) > 0.5
all_logits = jnp.concatenate(all_logits, axis=0)
all_labels = jnp.concatenate(all_labels, axis=0)
all_preds = jsp.special.expit(all_logits) > 0.5

# Confusion-matrix counts (both arrays are boolean at this point)
TP = jnp.sum(all_preds & all_labels)
TN = jnp.sum(~all_preds & ~all_labels)
FP = jnp.sum(all_preds & ~all_labels)
FN = jnp.sum(~all_preds & all_labels)

print(f"True positives : {TP}")
print(f"True negatives : {TN}")
print(f"False positives: {FP}")
print(f"False negatives: {FN}\n")

# Summary metrics derived from the counts above
specificity = TN / (TN + FP)
sensitivity = TP / (TP + FN)
precision   = TP / (TP + FP)
accuracy    = (TP + TN) / (TP + TN + FP + FN)

print(f"Accuracy   : {accuracy:.3f}")
print(f"Precision  : {precision:.3f}")
print(f"Sensitivity: {sensitivity:.3f}")
print(f"Specificity: {specificity:.3f}")

# Save NN to disk

In [None]:
# Prepare an empty checkpoint directory.
# NOTE(review): judging by its name, erase_and_create_empty wipes whatever
# already exists at this path; the path is also machine-specific -- confirm
# both are intended before running on another machine.
ckpt_dir = ocp.test_utils.erase_and_create_empty('/Users/atrigui/Documents/Stage/sbi_stage_trigui/saved_model_12_7_V7')

# Split the model into its graph definition, the RngKey and RngCount
# variables, and the remaining state. Only the remaining state is saved
# below; the RNG variables are filtered out of it.
_, rng_key, rng_count, state = nnx.split(model, nnx.RngKey, nnx.RngCount, ...)

# Display for debugging (optional)
nnx.display(state)

# Initialize the checkpointer
checkpointer = ocp.StandardCheckpointer()

# Save the state under <ckpt_dir>/state
checkpointer.save(ckpt_dir / 'state', state)

# The save may complete in the background; block until it has finished so
# the checkpoint on disk is complete before the notebook moves on or the
# kernel shuts down.
checkpointer.wait_until_finished()