In [16]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import load_model
from tensorflow.keras.datasets import cifar10, cifar100
from sklearn.metrics import roc_auc_score

In [17]:
# Restore the trained CIFAR-10 classifier from its HDF5 (.h5) checkpoint.
# NOTE(review): relative path — it only resolves when the notebook's working
# directory contains this file; consider a path constant in a config cell.
model = load_model('cifar10_improved_model_50.h5')

# In-distribution data: CIFAR-10. load_data() always returns both splits; only
# the test split is used for this evaluation.
(x_train_cifar10, y_train_cifar10), (x_test_cifar10, y_test_cifar10) = cifar10.load_data()

# Preprocess the test split: scale uint8 pixels to [0, 1] and one-hot encode
# the labels over the 10 classes. Assumes this mirrors the preprocessing used
# at training time — TODO confirm against the training code.
x_test_cifar10 = x_test_cifar10.astype('float32') / 255.0
y_test_cifar10 = tf.keras.utils.to_categorical(y_test_cifar10, num_classes=10)

# Baseline check: the OOD study below is only meaningful if the classifier
# performs well on its own distribution.
cifar10_score = model.evaluate(x_test_cifar10, y_test_cifar10, verbose=0)
print(f"CIFAR-10 Model Evaluation Score (Loss, Accuracy): {cifar10_score}")



CIFAR-10 Model Evaluation Score (Loss, Accuracy): [0.49271273612976074, 0.8348000049591064]


In [18]:
# Out-of-distribution data: CIFAR-100 (the model never saw these classes).
# Only the test split is used; it gets the same [0, 1] pixel scaling as CIFAR-10.
(x_train_cifar100, y_train_cifar100), (x_test_cifar100, y_test_cifar100) = cifar100.load_data()
x_test_cifar100 = x_test_cifar100.astype('float32') / 255.0

# Per-class scores for every test image of both datasets.
cifar10_preds = model.predict(x_test_cifar10)
cifar100_preds = model.predict(x_test_cifar100)

# Confidence = highest class score per image. This is the maximum-softmax-
# probability baseline if the model ends in a softmax layer.
# NOTE(review): confirm the output activation; with raw logits this becomes a
# max-logit score instead. OOD (CIFAR-100) images are expected to score lower.
cifar10_confidence = cifar10_preds.max(axis=1)
cifar100_confidence = cifar100_preds.max(axis=1)

# AUROC inputs: all ID scores followed by all OOD scores, with matching labels
# 1.0 (in-distribution, CIFAR-10) and 0.0 (out-of-distribution, CIFAR-100),
# so higher confidence should indicate "in-distribution".
confidence_scores = np.concatenate((cifar10_confidence, cifar100_confidence))
labels = np.repeat([1.0, 0.0], [len(cifar10_confidence), len(cifar100_confidence)])

[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 8ms/step
[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 7ms/step


In [19]:
# Threshold-free separability of ID vs OOD samples by confidence:
# 1.0 means perfect separation and 0.5 means chance level.
auroc_score = roc_auc_score(y_true=labels, y_score=confidence_scores)
print("AUROC for Out-of-Distribution Detection: {:.4f}".format(auroc_score))

AUROC for Out-of-Distribution Detection: 0.7801
