In [None]:
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

# Evaluate the binary and multiclass classifiers on the test split, bundle the
# results with model/preprocessing metadata, hand everything to
# save_model_metadata, and print a short summary.

# Class names in label-index order: position i in the per-class arrays below
# is reported under multiclass_names[i].
# NOTE(review): assumes y_test_multiclass / multi_preds hold integer labels
# 0..4 in exactly this order (the original hard-coded indices imply it) --
# confirm against the label encoding used during preprocessing.
multiclass_names = ["normal", "dos", "probe", "r2l", "u2r"]
multiclass_labels = list(range(len(multiclass_names)))

# Overall metrics. average='weighted' averages the per-class scores weighted
# by each class's support in the true labels.
binary_acc = accuracy_score(y_test_binary, binary_preds)
binary_prec, binary_rec, binary_f1, _ = precision_recall_fscore_support(
    y_test_binary, binary_preds, average='weighted'
)

multi_acc = accuracy_score(y_test_multiclass, multi_preds)
multi_prec, multi_rec, multi_f1, _ = precision_recall_fscore_support(
    y_test_multiclass, multi_preds, average='weighted'
)

# Per-class metrics for multiclass. labels= is pinned so the returned arrays
# always have one entry per known class, in a fixed order. Without it,
# scikit-learn only reports labels that actually occur in y_true/y_pred, so a
# class missing from both would shift every later index (or raise IndexError).
per_class_prec, per_class_rec, per_class_f1, _ = precision_recall_fscore_support(
    y_test_multiclass, multi_preds, average=None, labels=multiclass_labels
)

# Architecture / size description of the trained models.
# NOTE(review): hidden_dims and dropout_rate are literals here; verify they
# match the values the models were actually built with.
model_metadata = {
    "input_dim": input_dim,
    "hidden_dims": [256, 128, 64],
    "dropout_rate": 0.3,
    "model_type": "deep_neural_network",
    "framework": "pytorch",
    "binary_params": count_parameters(binary_model),
    "multiclass_params": count_parameters(multiclass_model),
}

# Everything is converted to built-in float / nested lists (via float() and
# .tolist()) rather than NumPy scalars/arrays.
# NOTE(review): confusion_matrix is not imported in this cell; presumably it
# is imported earlier in the notebook.
metrics = {
    "binary": {
        "accuracy": float(binary_acc),
        "precision": float(binary_prec),
        "recall": float(binary_rec),
        "f1": float(binary_f1),
        "confusion_matrix": confusion_matrix(y_test_binary, binary_preds).tolist(),
    },
    "multiclass": {
        "accuracy": float(multi_acc),
        "precision": float(multi_prec),
        "recall": float(multi_rec),
        "f1": float(multi_f1),
        "per_class": {
            class_name: {
                "precision": float(per_class_prec[class_idx]),
                "recall": float(per_class_rec[class_idx]),
                "f1": float(per_class_f1[class_idx]),
            }
            for class_idx, class_name in enumerate(multiclass_names)
        },
        # Same pinned labels as the per-class metrics, so rows/columns line up
        # with multiclass_names even when a class is absent.
        "confusion_matrix": confusion_matrix(
            y_test_multiclass, multi_preds, labels=multiclass_labels
        ).tolist(),
    },
}

# Description of the preprocessing pipeline that produced the model inputs.
preprocessing_info = {
    "num_features": len(preprocessor.feature_names),
    "scaler": "StandardScaler",
    "categorical_encoding": "one-hot",
    "categorical_features": ["protocol_type", "service", "flag"],
    "preprocessor_path": "../data/processed/preprocessor.json",
}

save_model_metadata(
    "../models/model_metadata.json",
    model_metadata,
    metrics,
    preprocessing_info,
)

# Human-readable summary of the headline (weighted) metrics.
print("\n" + "=" * 60)
print("FINAL TEST SET RESULTS")
print("=" * 60)
print("Binary Classification:")
print(f"  Accuracy:  {binary_acc:.4f}")
print(f"  Precision: {binary_prec:.4f}")
print(f"  Recall:    {binary_rec:.4f}")
print(f"  F1-Score:  {binary_f1:.4f}")
print("\nMulticlass Classification:")
print(f"  Accuracy:  {multi_acc:.4f}")
print(f"  Precision: {multi_prec:.4f}")
print(f"  Recall:    {multi_rec:.4f}")
print(f"  F1-Score:  {multi_f1:.4f}")
print("=" * 60)
