# In [None]:
# ids_validation.py
"""
IDS-LLM Validation Module for SecureMed-LLM

Performs validation of generated clinical reports using:
1. Rule-based checks for forbidden terms and contradictions
2. Clinical term verification
3. Anomaly detection using Isolation Forest
4. Overall IDS decision and evaluation

Author: Aya Lolo
"""

import re

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
from sklearn.ensemble import IsolationForest
from sklearn.metrics import confusion_matrix

# ========================
# Step 1: Load validation reports
# ========================
# Keep only the generated report text and its label. Rows missing either are
# dropped, and "generated" is renamed to "report" for the later steps.
df = pd.read_csv("D:/masterdata/100_validation_reports.csv")
df = df[["generated", "label"]].dropna().rename(columns={"generated": "report"})
print(f"✅ Loaded {len(df)} reports.")

# ========================
# Step 2: Rule-based validation
# ========================
# Phrases that must never appear in a clinical report; rule_based_verification
# rejects any report containing one of them.
forbidden_terms = [
    "magic",
    "alien",
    "miracle",
    "healed completely",
    "unreal",
    "all fine no matter what",
]

# Location terms are anchored at the START of a word only. This rejects
# "upright"/"bright" (no word boundary before "right") while still accepting
# inflected forms such as "lobes", "bilaterally" or "lowermost".
_LOCATION_PATTERN = re.compile(r"\b(?:right|left|bilateral|lower|upper|middle|lobe)")

def has_location(text):
    """Return True if *text* mentions an anatomical location term.

    Matching is case-insensitive. A plain substring test would treat the
    common radiology phrase "upright AP view" as the location "right", so each
    term must begin at a word boundary.
    """
    return _LOCATION_PATTERN.search(text.lower()) is not None

def has_finding(text):
    """Return True if *text* mentions a recognised radiological finding.

    Matching is a case-insensitive substring test. "opacities" is listed
    explicitly because the plural does not contain "opacity" as a substring,
    so reports such as "bilateral opacities" would otherwise be missed.
    NOTE(review): negated mentions ("no pneumonia") still count as findings.
    """
    lowered = text.lower()
    findings = (
        "consolidation",
        "opacity",
        "opacities",
        "effusion",
        "pneumonia",
        "cardiomegaly",
        "atelectasis",
    )
    return any(term in lowered for term in findings)

def has_contradiction(text):
    """Return True if *text* calls the lungs clear yet also mentions opacity.

    Both the singular "opacity" and the plural "opacities" are checked; the
    plural is not a superstring of the singular.
    NOTE(review): negated mentions such as "no focal opacity" are also flagged
    as contradictions; consider negation handling if that is not intended.
    """
    lowered = text.lower()
    mentions_opacity = "opacity" in lowered or "opacities" in lowered
    return "lungs are clear" in lowered and mentions_opacity

def rule_based_verification(text):
    """Return True if *text* passes all rule-based checks.

    A report fails if it contains a forbidden term or a clear-lungs/opacity
    contradiction. Otherwise it passes only when it names both a location and
    a finding.

    Forbidden terms are matched at the start of a word, so "alien" no longer
    fires inside "salient". Inflected forms such as "magical" still match.
    """
    lowered = text.lower()
    if any(re.search(r"\b" + re.escape(term), lowered) for term in forbidden_terms):
        return False
    if has_contradiction(text):
        return False
    return has_location(text) and has_finding(text)

df["rule_pass"] = df["report"].map(rule_based_verification)  # per-report rule-check result

# ========================
# Step 3: Clinical term check
# ========================
# Recognised clinical findings, matched as lowercase substrings. "opacities"
# is listed separately because the plural does not contain "opacity" and
# would otherwise be missed.
valid_terms = [
    "pneumonia",
    "cardiomegaly",
    "pleural effusion",
    "atelectasis",
    "opacity",
    "opacities",
    "consolidation",
]

def clinical_check(text):
    """Return True when *text* contains at least one entry of ``valid_terms``.

    The comparison is case-insensitive: the report is lowercased before each
    substring test.
    """
    lowered = text.lower()
    for term in valid_terms:
        if term in lowered:
            return True
    return False

df["clinical_pass"] = df["report"].map(clinical_check)  # per-report clinical-term result

# ========================
# Step 4: Anomaly detection
# ========================
# Fit an Isolation Forest on sentence embeddings of the training reports, then
# mark each generated report as passing when the forest predicts it is an inlier.
# Drop only rows that lack the report text itself; a blanket dropna() would
# also discard usable reports that merely have NaNs in unrelated columns.
train_data = pd.read_csv("D:/masterdata/df_train.csv").dropna(subset=["text"])
real_reports = train_data["text"].astype(str).tolist()

encoder = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
X_train = encoder.encode(real_reports, show_progress_bar=True)
# Coerce to str the same way as the training texts before encoding.
X_test = encoder.encode(df["report"].astype(str).tolist(), show_progress_bar=True)

iso = IsolationForest(contamination=0.1, random_state=42)
iso.fit(X_train)
# IsolationForest.predict returns 1 for inliers and -1 for outliers.
df["anomaly_pass"] = (iso.predict(X_test) == 1)

# ========================
# Step 5: Final IDS decision
# ========================
# A report passes the IDS (1) only if the rule, clinical and anomaly checks all
# passed; any single failure yields 0.
df["IDS_decision"] = (df["rule_pass"] & df["clinical_pass"] & df["anomaly_pass"]).astype(int)

# ========================
# Step 6: Confusion matrix and metrics
# ========================
# Label 1 is treated as the positive class; IDS_decision == 1 means "passed".
y_true = df["label"].astype(int)
y_pred = df["IDS_decision"]

# With labels=[1, 0] the matrix is laid out as [[TP, FN], [FP, TN]].
cm = confusion_matrix(y_true, y_pred, labels=[1, 0])
tp, fn = cm[0]
fp, tn = cm[1]

rule_pass_rate = 100 * df["rule_pass"].mean()
clinical_pass_rate = 100 * df["clinical_pass"].mean()
anomaly_pass_rate = 100 * df["anomaly_pass"].mean()
# NOTE(review): this is accuracy, (TP + TN) / N, i.e. agreement between the
# IDS decision and the labels, not the share of reports the IDS passed.
# Both denominators are guarded: an empty frame, or a validation set with no
# label-0 reports, would otherwise divide by zero.
overall_pass_rate = 100 * ((tp + tn) / len(df)) if len(df) else float("nan")
fp_rate = 100 * (fp / (fp + tn)) if (fp + tn) else float("nan")

print("\n🔎 IDS Performance Summary:")
print(f"Rule-Based Pass Rate: {rule_pass_rate:.2f}%")
print(f"Clinical Check Pass Rate: {clinical_pass_rate:.2f}%")
print(f"Anomaly Detection Pass Rate: {anomaly_pass_rate:.2f}%")
print(f"✅ Overall IDS Pass Rate: {overall_pass_rate:.2f}%")
print(f"🚫 False Positive Rate: {fp_rate:.2f}%")
print(f"📊 Confusion Matrix: TP={tp}, FN={fn}, FP={fp}, TN={tn}")

# ========================
# Step 7: Save results
# ========================
# Persist every per-report check column together with the final IDS decision.
df.to_csv("D:/masterdata/validation_with_ids_results.csv", index=False)
print("\n📁 Results saved to: validation_with_ids_results.csv")

# ========================
# Step 8: Visualization
# ========================
# Bar chart of the per-module pass rates computed in Step 6, saved as a PNG.
modules = ["Rule-Based", "Clinical Params", "Anomaly Detection"]
pass_rates = [rule_pass_rate, clinical_pass_rate, anomaly_pass_rate]

x = np.arange(len(modules))
width = 0.35

plt.figure(figsize=(7, 4))
plt.bar(x, pass_rates, width, color="seagreen")
plt.xticks(x, modules)
plt.ylim(0, 100)
plt.ylabel("Pass %")
# Report the actual row count: dropna() in Step 1 can leave fewer than 100 reports.
plt.title(f"IDS-LLM Module Performance ({len(df)} Reports)")
plt.grid(axis="y", linestyle=":")
plt.tight_layout()
plt.savefig("D:/masterdata/ids_validation_performance.png", dpi=300)
plt.show()
