# 📊 Evaluation & Visualization for BERT-GAN
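This notebook loads the trained BERT-GAN model, scores the labelled URLs in `data/data.csv`, and reports post-training evaluation metrics (accuracy, F1, ROC AUC). It also plots a confusion matrix, a ROC curve, and a precision-recall curve.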

In [None]:
# ✅ Step 1: Import Libraries
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, confusion_matrix, roc_curve, precision_recall_curve
import numpy as np
import tensorflow as tf
from transformers import DistilBertTokenizer
import pandas as pd


In [None]:
# ✅ Step 2: Load model and test data
# Load the trained model, tokenize every URL in data.csv, and score it.
# Tunables are named here so they are easy to audit and change.
MODEL_PATH = '../bert_gan_full_model_FINAL.py'  # FIXME(review): Keras load_model expects a .keras/.h5 file or SavedModel dir, not a .py script — confirm the real artifact path
DATA_PATH = '../data/data.csv'
MAX_LEN = 128          # tokenizer length; assumed to match training — TODO confirm
NOISE_DIM = 100        # width of the noise input passed to the model alongside input_ids
NOISE_SEED = (42, 0)   # fixed seed so predictions (and every metric below) reproduce on re-run
THRESHOLD = 0.5        # decision threshold on the model's output score (>THRESHOLD -> class 1)
LABEL_MAP = {'bad': 0, 'good': 1}

model = tf.keras.models.load_model(MODEL_PATH, compile=False)
data = pd.read_csv(DATA_PATH)

# NOTE(review): this scores every row of data.csv. If the model was trained on
# (part of) this file the metrics will be optimistic — confirm a held-out split.
assert data['url'].notna().all(), "data.csv has missing URLs"
y_mapped = data['label'].map(LABEL_MAP)
# .map() turns any label outside LABEL_MAP into NaN; fail here rather than deep inside sklearn.
unknown_labels = data.loc[y_mapped.isna(), 'label'].unique()
assert len(unknown_labels) == 0, f"unexpected labels in data.csv: {unknown_labels}"

tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
inputs = tokenizer(data['url'].tolist(), max_length=MAX_LEN, truncation=True,
                   padding='max_length', return_tensors='tf')
X_test = tf.cast(inputs['input_ids'], tf.int32)
y_test = y_mapped.to_numpy(dtype=np.float32)

# Stateless sampling: the same seed yields identical noise every time this cell runs.
z_noise = tf.random.stateless_normal((len(data), NOISE_DIM), seed=NOISE_SEED)
y_probs = model.predict([X_test, z_noise]).flatten()
y_preds = (y_probs > THRESHOLD).astype(int)

In [None]:
# ✅ Step 3: Print Evaluation Metrics
# Summary metrics at the decision threshold applied in Step 2.
# F1 is reported for both classes: sklearn's default pos_label=1 scores only
# the 'good' class, which hides how well 'bad' URLs are caught.
y_true = np.asarray(y_test)
metrics = {
    "Accuracy": accuracy_score(y_true, y_preds),
    "F1 Score (good=1)": f1_score(y_true, y_preds, pos_label=1),
    "F1 Score (bad=0)": f1_score(y_true, y_preds, pos_label=0),
    # roc_auc_score raises ValueError when y_true contains a single class.
    "AUC": roc_auc_score(y_true, y_probs) if np.unique(y_true).size > 1 else float("nan"),
}
pd.Series(metrics, name="value")

In [None]:
# ✅ Step 4: Confusion Matrix
# Rows are actual classes, columns are predicted classes. labels=[0, 1] pins the
# order and keeps a 2x2 shape even if one class never appears in the data.
class_names = ['bad', 'good']  # matches the 0/1 label encoding used in Step 2
cm = confusion_matrix(y_test, y_preds, labels=[0, 1])
fig, ax = plt.subplots()
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names, ax=ax)
ax.set(title='Confusion Matrix', xlabel='Predicted', ylabel='Actual')
plt.show()

In [None]:
# ✅ Step 5: ROC Curve
# True-positive rate vs. false-positive rate as the threshold on y_probs sweeps.
fpr, tpr, _ = roc_curve(y_test, y_probs)

fig, ax = plt.subplots()
ax.plot(fpr, tpr, label='ROC Curve')
ax.set(xlabel='False Positive Rate', ylabel='True Positive Rate', title='ROC Curve')
ax.legend()
ax.grid(True)
plt.show()

In [None]:
# ✅ Step 6: Precision-Recall Curve
# Precision against recall for every threshold on y_probs (recall on the x-axis).
precision, recall, _ = precision_recall_curve(y_test, y_probs)

fig, ax = plt.subplots()
ax.plot(recall, precision, label='Precision-Recall Curve')
ax.set_xlabel('Recall')
ax.set_ylabel('Precision')
ax.set_title('Precision-Recall Curve')
ax.legend()
ax.grid(True)
plt.show()