# In [None]:
from tensorflow import keras
import tensorflow as tf
from tornet.data.loader import get_dataloader
from tornet.metrics.keras import metrics as tfm
import os
import argparse
import logging
import matplotlib.pyplot as plt
from custom_func import FalseAlarmRate, ThreatScore
import tensorflow_datasets as tfds
from tornet.models.keras.layers import CoordConv2D
import tornet.data.tfds.tornet.tornet_dataset_builder  # registers 'tornet'
import tqdm

# Environment setup: one directory serves as the TorNet root and the TFDS data dir.
TORNET_ROOT = "/home/ubuntu/tfds"
TFDS_DATA_DIR = TORNET_ROOT
DATA_ROOT = TORNET_ROOT
# Export both locations through the process environment in a single step.
os.environ.update({
    'TORNET_ROOT': DATA_ROOT,
    'TFDS_DATA_DIR': TFDS_DATA_DIR,
})

# Register custom layers
@keras.utils.register_keras_serializable()
class SpatialAttention(keras.layers.Layer):
    """Gate a feature map with a per-location attention mask.

    The mask comes from a single-filter sigmoid convolution applied to the
    concatenated mean and max of the input over its last axis; the input is
    then multiplied by that mask.

    Args:
        kernel_size: Kernel size of the mask convolution.
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, kernel_size=7, **kwargs):
        super().__init__(**kwargs)
        self.kernel_size = kernel_size
        self.conv = keras.layers.Conv2D(1, kernel_size=self.kernel_size, strides=1, padding='same', activation='sigmoid')

    def call(self, x):
        """Return ``x`` scaled by the attention mask computed from ``x``."""
        avg_pool = tf.reduce_mean(x, axis=-1, keepdims=True)
        max_pool = tf.reduce_max(x, axis=-1, keepdims=True)
        attention = self.conv(tf.concat([avg_pool, max_pool], axis=-1))
        # Plain broadcasted multiply; avoids instantiating a new Multiply
        # layer on every call.
        return x * attention

    def get_config(self):
        """Return the layer config, including ``kernel_size``, for saving."""
        config = super().get_config()
        config.update({"kernel_size": self.kernel_size})
        return config

@keras.utils.register_keras_serializable()
class ChannelAttention(keras.layers.Layer):
    """Scale the input by a sigmoid gate derived from its spatial average.

    The input is averaged over axes 1 and 2, passed through a one-unit sigmoid
    ``Dense`` layer, and the result is multiplied back onto the input.

    NOTE(review): ``Dense(1)`` yields a single gate value per sample rather
    than one per channel, and ``ratio`` does not influence the computation.
    Both are kept as-is so existing saved weights keep loading — confirm
    whether a per-channel bottleneck was intended.

    Args:
        ratio: Accepted for interface compatibility; stored and serialized but
            not used by the computation.
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, ratio=16, **kwargs):
        super().__init__(**kwargs)
        # Kept so the constructor argument survives a save/load round trip.
        self.ratio = ratio
        self.dense1 = keras.layers.Dense(1, activation='sigmoid')

    def call(self, x):
        """Return ``x`` scaled by the gate computed from its spatial mean."""
        avg_pool = tf.reduce_mean(x, axis=[1, 2], keepdims=True)
        attention = self.dense1(avg_pool)
        # Plain broadcasted multiply; avoids instantiating a new Multiply
        # layer on every call.
        return x * attention

    def get_config(self):
        """Return the layer config, including ``ratio``, for saving."""
        config = super().get_config()
        config.update({"ratio": self.ratio})
        return config

@keras.utils.register_keras_serializable()
class FillNaNs(keras.layers.Layer):
    """Replace NaN entries of the input with a fixed fill value.

    Args:
        fill_val: Value substituted for NaNs; converted to a float32 tensor.
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, fill_val, **kwargs):
        super().__init__(**kwargs)
        # Keep the raw value as well: ``fill_val`` is a required constructor
        # argument, so it must be present in the saved config.
        self._fill_val_config = fill_val
        self.fill_val = tf.convert_to_tensor(fill_val, dtype=tf.float32)

    @tf.function(jit_compile=True)
    def call(self, x):
        """Return ``x`` with every NaN replaced by ``fill_val``."""
        return tf.where(tf.math.is_nan(x), self.fill_val, x)

    def get_config(self):
        """Return the layer config, including ``fill_val``, for saving."""
        config = super().get_config()
        config.update({"fill_val": self._fill_val_config})
        return config

@keras.utils.register_keras_serializable()
class FastNormalize(keras.layers.Layer):
    """Standardize the input as ``(x - mean) / (std + 1e-6)``.

    Args:
        mean: Iterable of offsets subtracted from the input (broadcast against
            it) — presumably per-channel statistics; confirm with the caller.
        std: Iterable of scales the shifted input is divided by.
        **kwargs: Forwarded to ``keras.layers.Layer``.
    """

    def __init__(self, mean, std, **kwargs):
        super().__init__(**kwargs)
        self.mean = tf.convert_to_tensor(mean, dtype=tf.float32)
        self.std = tf.convert_to_tensor(std, dtype=tf.float32)
        # List copies of the constructor arguments, returned by get_config.
        self._mean_list = list(mean)
        self._std_list = list(std)

    def call(self, x):
        """Return the standardized input."""
        # The small epsilon guards against division by a zero std.
        return (x - self.mean) / (self.std + 1e-6)

    def get_config(self):
        """Return the layer config, including ``mean`` and ``std``, for saving."""
        config = super().get_config()
        config.update({"mean": self._mean_list, "std": self._std_list})
        return config

@keras.utils.register_keras_serializable()
class ExpandDimsTwice(keras.layers.Layer):
    """Insert two singleton axes at position 1 of the input tensor."""

    def call(self, inputs):
        """Return ``inputs`` with two size-1 axes added after axis 0."""
        once = tf.expand_dims(inputs, axis=1)
        twice = tf.expand_dims(once, axis=1)
        return twice

@keras.utils.register_keras_serializable()
class StackAvgMax(tf.keras.layers.Layer):
    """Stack a sequence of tensors along a new axis at position 1."""

    def call(self, inputs):
        """Return the tensors in ``inputs`` stacked along a new axis 1."""
        stacked = tf.stack(inputs, axis=1)
        return stacked

# Number of evenly spaced score thresholds used for the PR curve and PR AUC.
_NUM_THRESHOLDS = 2000


def _collect_predictions(model, dataset, desc):
    """Run ``model`` over every batch of ``dataset`` and gather labels and scores.

    Args:
        model: Loaded Keras model; evaluated with ``predict_on_batch``.
        dataset: Iterable yielding ``(inputs, labels, weights)`` batches.
        desc: Text shown next to the tqdm progress bar.

    Returns:
        Tuple ``(labels, scores)``, each concatenated along axis 0.

    Raises:
        ValueError: If ``dataset`` yields no batches.
    """
    batch_labels = []
    batch_scores = []
    for inputs, labels, _ in tqdm.tqdm(dataset, desc=desc):
        batch_scores.append(model.predict_on_batch(inputs))
        batch_labels.append(labels)
    if not batch_labels:
        raise ValueError("Dataloader yielded no batches; nothing to evaluate")
    return tf.concat(batch_labels, axis=0), tf.concat(batch_scores, axis=0)


def _precision_recall_curve(labels, scores, num_thresholds):
    """Compute precision and recall at evenly spaced thresholds in [0, 1].

    Args:
        labels: Ground-truth binary labels.
        scores: Model scores compared against each threshold.
        num_thresholds: Number of thresholds, spaced evenly from 0 to 1.

    Returns:
        Tuple ``(precision, recall)`` of NumPy arrays, restricted to the
        thresholds at which recall is positive.
    """
    thresholds = [i / (num_thresholds - 1) for i in range(num_thresholds)]
    precision_metric = keras.metrics.Precision(thresholds=thresholds)
    recall_metric = keras.metrics.Recall(thresholds=thresholds)
    precision_metric.update_state(labels, scores)
    recall_metric.update_state(labels, scores)
    precision = precision_metric.result().numpy()
    recall = recall_metric.result().numpy()
    # With zero true positives the metric reports precision 0; plotting those
    # points would only draw a spike down the y-axis.
    detected = recall > 0
    return precision[detected], recall[detected]


def main():
    """Plot precision-recall curves for every ``.keras`` model in a directory.

    Parses the command line, loads the dataset once, evaluates each model on
    it, and shows one figure with a PR curve and PR AUC per model.

    Raises:
        ValueError: If ``--models_dir`` contains no ``.keras`` files or the
            dataloader yields no batches.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--models_dir", type=str, required=True, help="Directory containing .keras model files")
    # NOTE: --threshold is parsed but not used by the PR-curve computation.
    parser.add_argument("--threshold", type=float, default=0.5, help="Binary threshold for predictions")
    parser.add_argument("--dataloader", type=str, default="tensorflow-tfds", choices=["keras", "tensorflow", "tensorflow-tfds", "torch", "torch-tfds"])
    # NOTE(review): the split used to be hard-coded to "train" although the
    # dataset variable was named ds_test. The default is unchanged; pass
    # "--data_type test" if held-out evaluation was intended — confirm.
    parser.add_argument("--data_type", type=str, default="train", help="Dataset split passed to get_dataloader")
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)
    logging.info(f"Using backend: {keras.config.backend()}")
    logging.info(f"Using dataloader: {args.dataloader}")

    # Sorted so the evaluation and legend order is deterministic.
    model_paths = sorted(
        os.path.join(args.models_dir, name)
        for name in os.listdir(args.models_dir)
        if name.endswith('.keras')
    )
    if not model_paths:
        raise ValueError(f"No .keras models found in {args.models_dir}")

    # Load data once and reuse it for every model.
    dataset = get_dataloader(args.dataloader, DATA_ROOT, range(2020, 2023), args.data_type, 128,
                             weights={'wN': 1.0, 'w0': 1.0, 'w1': 1.0, 'w2': 1.0, 'wW': 1.0})

    plt.figure(figsize=(10, 8))

    for path in model_paths:
        model_name = os.path.basename(path)
        logging.info(f"Evaluating model: {path}")
        model = keras.models.load_model(path, safe_mode=False, compile=False)

        all_labels, all_preds = _collect_predictions(model, dataset, f"Evaluating {model_name}")

        # NOTE(review): assumes the model emits probabilities in [0, 1]; if it
        # emits logits, apply a sigmoid first — confirm against training code.
        precision, recall = _precision_recall_curve(all_labels, all_preds, _NUM_THRESHOLDS)

        auc = keras.metrics.AUC(curve='PR', num_thresholds=_NUM_THRESHOLDS)
        auc.update_state(all_labels, all_preds)
        pr_auc = auc.result().numpy()

        plt.plot(recall, precision, label=f"{model_name} (PR AUC: {pr_auc:.4f})")

    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('PR AUC Curve Comparison')
    plt.legend(loc='lower left')
    plt.grid(True)
    plt.tight_layout()
    plt.show()

if __name__ == "__main__":
    main()
