In [None]:
import os
import torch
import shap
import numpy as np
import nltk
from nltk import sent_tokenize
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification, Trainer, TrainingArguments
from peft import get_peft_model, LoraConfig, TaskType
from datasets import load_dataset
from sklearn.metrics import classification_report, f1_score
from shap.maskers import Text

# === Setup NLTK ===
# Keep the tokenizer data in a local folder so the run does not depend on a
# system-wide NLTK data installation.
nltk_data_path = os.path.join(os.getcwd(), 'nltk_data')
os.makedirs(nltk_data_path, exist_ok=True)
# Register the folder once; re-running this cell must not pile up duplicates.
if nltk_data_path not in nltk.data.path:
    nltk.data.path.append(nltk_data_path)

# Newer NLTK releases ship the sentence tokenizer data as 'punkt_tab', older
# ones only know 'punkt'. nltk.download() normally reports a failure by
# returning False instead of raising, so the return value has to be checked
# in addition to catching errors.
try:
    punkt_ready = nltk.download('punkt_tab', download_dir=nltk_data_path)
except Exception as exc:  # not a bare except: Ctrl-C / SystemExit must propagate
    print(f"punkt_tab download failed ({exc}); falling back to punkt")
    punkt_ready = False
if not punkt_ready:
    nltk.download('punkt', download_dir=nltk_data_path)

# === Device setup ===
# Use the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"🚀 Using device: {device}")

# === Load and clean dataset ===
def get_cleaned_dataset(include_instruction=False):
    """Load the fake-news dataset, optionally stripping the prompt template.

    The ``text`` column is presumed to wrap each article in an instruction
    prompt: a fixed "Instruction: ... Input: " header before the article and
    a trailing "Output: <label>" line after it (verify against the dataset).

    Args:
        include_instruction: When True, the texts are returned exactly as
            loaded. When False (default), the instruction header and the
            trailing output line are removed from every split.

    Returns:
        The object returned by ``load_dataset`` with the ``text`` column of
        each split cleaned as described above.
    """
    prefix = "Instruction: Classify the following news article as real or fake.\n\nInput: "
    output_marker = "\n\nOutput: "
    # The two answers named in the instruction text above.
    known_labels = ("fake", "real")

    def strip_template(example):
        """Remove the prompt header/footer from one example if present."""
        text = example["text"]
        # Only cut what is really there: slicing by length alone would chop
        # arbitrary characters off any row that does not follow the template.
        if text.startswith(prefix):
            text = text[len(prefix):]
        body, marker, answer = text.rpartition(output_marker)
        # The trailing line spells out the answer, so it must not remain in
        # the classifier input.
        if marker and answer.strip().lower() in known_labels:
            text = body
        return {"text": text}

    dataset = load_dataset("Hasib18/fake-news-dataset")
    if include_instruction:
        return dataset
    for split in dataset:
        dataset[split] = dataset[split].map(strip_template)
    return dataset

# Articles with the instruction template removed (the function's default mode).
dataset = get_cleaned_dataset(include_instruction=False)
# Fast tokenizer for the same checkpoint that the classifier is built from.
tokenizer = DistilBertTokenizerFast.from_pretrained(
    "distilbert-base-uncased",
)

def tokenize(batch):
    """Tokenize the ``text`` entry of a batch with the module-level tokenizer.

    Sequences are truncated to, and padded up to, 512 tokens so every
    example has the same length.

    Args:
        batch: Mapping with a ``"text"`` entry (presumably a list of strings,
            since it is used with ``map(..., batched=True)``).

    Returns:
        The tokenizer output for the batch.
    """
    texts = batch["text"]
    encoded = tokenizer(
        texts,
        max_length=512,
        truncation=True,
        padding="max_length",
    )
    return encoded

# Tokenize every split in batches and rename the target column to "labels",
# the keyword under which the classifier's forward pass receives targets.
tokenized_dataset = dataset.map(tokenize, batched=True).rename_column("label", "labels")
# Hand out only the model inputs and the target, as torch tensors.
model_columns = ["input_ids", "attention_mask", "labels"]
tokenized_dataset.set_format("torch", columns=model_columns)

# === LoRA-wrapped DistilBERT ===
# Adapter settings: rank-8 LoRA on the attention query/value projections,
# no bias training, sequence-classification task.
lora_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    target_modules=["q_lin", "v_lin"],
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none",
)
# Two-class classification head on top of the pretrained encoder.
base_model = DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",
    num_labels=2,
)
model = get_peft_model(base_model, lora_config)
model.to(device)
# NOTE(review): the training loop presumably switches the model back to train
# mode, so this eval() likely only matters if the model is used before training.
model.eval()
model.print_trainable_parameters()

# === Training config ===
# Half precision is only requested when a CUDA device is present.
use_fp16 = torch.cuda.is_available()
training_args = TrainingArguments(
    output_dir="./results",
    logging_dir="./logs",
    num_train_epochs=2,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
    weight_decay=0.01,
    fp16=use_fp16,
    logging_steps=1000,
    save_strategy="no",
    report_to="none",
)

# The test split doubles as the evaluation set.
# NOTE(review): recent transformers releases may prefer `processing_class=`
# over `tokenizer=` — confirm against the installed version.
trainer = Trainer(
    args=training_args,
    model=model,
    tokenizer=tokenizer,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
)

# === Train and save ===
print("\n🚀 Training LoRA-augmented DistilBERT...")
trainer.train()
trainer.save_model("./trained_distilbert_lora")

# === Evaluation ===
# Run the trained model over the test split and turn logits into class ids.
preds_output = trainer.predict(tokenized_dataset["test"])
labels = preds_output.label_ids
preds = np.argmax(preds_output.predictions, axis=1)

# f1_score is called with its defaults, as before.
test_f1 = f1_score(labels, preds)
print(f"\n🎯 F1 Score: {test_f1:.4f}")

# NOTE(review): the names assume label 0 = FAKE and 1 = REAL — confirm
# against the dataset's label encoding.
report = classification_report(labels, preds, target_names=["FAKE", "REAL"])
print("\n📋 Classification Report:")
print(report)

# === SHAP wrapper ===
def wrapped_model(texts):
    """Return class probabilities for a batch of texts.

    This is the prediction function handed to the SHAP explainer.

    Args:
        texts: A list of strings, or anything else (single string, NumPy
            scalar, array) which is first converted to a list of ``str``.

    Returns:
        NumPy array of softmax probabilities with one row per input text.
    """
    # A str or NumPy scalar is never a list, so one check covers every
    # non-list input.
    if not isinstance(texts, list):
        texts = [str(item) for item in np.atleast_1d(texts)]

    batch = tokenizer(
        texts,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )
    # Move every tensor of the encoding onto the model's device.
    batch = {name: tensor.to(device) for name, tensor in batch.items()}

    with torch.no_grad():
        outputs = model(**batch)
        probs = torch.softmax(outputs.logits, dim=-1)
    return probs.cpu().numpy()

# === SHAP Sentence Explanation ===
# Sample article; each sentence is explained separately below.
article = """
Trump claimed the election was stolen and widespread fraud occurred.
Officials from both parties denied these claims and upheld the results.
An independent audit found no evidence of tampering.
Social media platforms flagged the original post for misinformation.
"""
sentences = sent_tokenize(article)

# Echo the sentence split with 1-based numbering.
print("\n📝 Sentences:")
for number, sentence_text in enumerate(sentences, start=1):
    print(f"{number}. {sentence_text}")

print("\n🔍 Explaining with SHAP (terminal version)...")
# Mask text with the model's own tokenizer.
text_masker = Text(tokenizer)
explainer = shap.Explainer(wrapped_model, masker=text_masker)
shap_values = explainer(sentences)

# Predicted class
# Average the per-sentence probabilities and take the most likely class as
# the verdict for the whole article.
sentence_probs = wrapped_model(sentences)
pred_class = np.argmax(sentence_probs.mean(axis=0))
class_name = ('FAKE', 'REAL')[pred_class]
print(f"\n📊 Predicted class: {class_name}")

# Raw SHAP attributions, one entry per explained sentence.
contributions = shap_values.values

print(f"\n🧾 Sentence-level SHAP values for class: {class_name}")
print(f"🔍 shap_values.values shape: {contributions.shape}")
print(f"🔍 type of shap_values.values: {type(contributions)}")

for position, (sentence, token_values) in enumerate(zip(sentences, contributions), start=1):
    print(f"\n{position}. {sentence}")
    print(f"   Raw value shape: {token_values.shape}")
    try:
        # Each entry is indexed as [row, class] (presumably tokens x classes);
        # summing the predicted-class column gives one score per sentence.
        predicted_column = token_values[:, pred_class]
        scalar = float(predicted_column.sum())
        print(f"   ✅ Contribution to {class_name}: {scalar:.4f}")
    except Exception as e:
        # Best-effort report: a malformed entry is printed, not fatal.
        print(f"   ❌ Error: {e}")
