# YCITE: Model with Explanation

This notebook consists of the code used to evaluate the final results.

### The next section runs the model (with LIME), produces the outcomes for the dataset, and determines the results.

In [None]:
# The actual experiment: computes the cosine similarity score for each processed sentence

import pandas as pd
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from lime.lime_text import LimeTextExplainer
import spacy
from tqdm import tqdm
from sentence_transformers import SentenceTransformer, util
import signal
import random
import string

# Timeout handler for long-running LIME explanations
class TimeoutException(Exception):
    """Raised by the alarm signal handler when an operation runs past its time limit."""

def timeout_handler(signum, frame):
    """Signal handler that interrupts the running code with ``TimeoutException``.

    Both arguments are ignored; they exist only to satisfy the handler
    signature expected by ``signal.signal``.
    """
    raise TimeoutException()

# Install the handler so that signal.alarm(...) interrupts a slow LIME call
# by raising TimeoutException in the main thread.
signal.signal(signal.SIGALRM, timeout_handler)

# Load SpaCy's English model and stop words
nlp = spacy.load("en_core_web_sm")
stop_words = nlp.Defaults.stop_words
# Translation table that deletes every ASCII punctuation character
punctuation_table = str.maketrans("", "", string.punctuation)

# Initialize SentenceTransformer for BERT-based similarity
bert_model = SentenceTransformer('paraphrase-MiniLM-L6-v2')

# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the fine-tuned model and tokenizer
model_name = "allenai/scibert_scivocab_uncased"
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
# map_location remaps the stored tensors onto the selected device, so a
# checkpoint that was saved on a GPU still loads on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
model.load_state_dict(
    torch.load("./scibert_fine_tuned_model_LIME.pth", map_location=device)
)
model.to(device)
model.eval()  # inference mode: disables dropout

tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir='./local_model')

# Define prediction function for LIME
def predict_proba(texts):
    """Return class probabilities for a batch of sentences.

    This is the classifier callback handed to the LIME explainer.

    Args:
        texts: The sentences to score, in a form the tokenizer accepts.

    Returns:
        A NumPy array with one row of softmax probabilities per input text.
    """
    encoded = tokenizer(
        texts,
        return_tensors='pt',
        max_length=128,
        truncation=True,
        padding=True,
    )
    encoded = encoded.to(device)

    # Gradients are not needed for inference.
    with torch.no_grad():
        logits = model(**encoded).logits
        probabilities = torch.softmax(logits, dim=1)

    return probabilities.cpu().numpy()

# Initialize LIME explainer
# The class names are given in label order: the rest of this cell treats
# label 1 as 'Citation Needed' and any other label as 'No Citation Needed'.
explainer = LimeTextExplainer(class_names=['No Citation Needed', 'Citation Needed'])

# Helper function to clean and tokenize text
def clean_and_tokenize(text):
    """Return the non-stop-word tokens of *text*.

    ASCII punctuation is removed and the text lower-cased before it is
    tokenised with spaCy; tokens found in ``stop_words`` are dropped.
    """
    without_punctuation = text.translate(punctuation_table)
    doc = nlp(without_punctuation.lower())

    kept = []
    for tok in doc:
        if tok.text in stop_words:
            continue
        kept.append(tok.text)
    return kept

# Generate explanations with cleaned multi-word phrases
def generate_explanation(lime_keywords_pos, sentence, label):
    """Build a template-based explanation for a classified sentence.

    Args:
        lime_keywords_pos: Accepted for interface compatibility; not used.
        sentence: The sentence that was classified.
        label: Predicted class. ``1`` selects the "Citation Needed" templates;
            any other value selects the "No Citation Needed" templates.

    Returns:
        A ``(explanation_text, explanation_tokens)`` tuple. ``explanation_text``
        is the selected template sentences joined by spaces and
        ``explanation_tokens`` is the list of lower-cased, whitespace-split
        tokens of every phrase recorded in ``used_phrases``.

    Side effects:
        Rebinds the module-level global ``used_phrases`` to the phrases used
        for this sentence (callers in this file read it), and prints one
        ``Debug:`` line per explanation sentence added.
    """
    global used_phrases

    doc = nlp(sentence)

    # Candidate multi-word phrases: noun chunks, named entities, and
    # "<head> <adjectival modifier>" pairs.
    raw_phrases = (
        [chunk.text for chunk in doc.noun_chunks]
        + [entity.text for entity in doc.ents]
        + [f"{token.head.text} {token.text}" for token in doc if token.dep_ == "amod"]
    )

    # Normalise each phrase (lower-case, drop stop words, sort the words) and
    # de-duplicate while keeping first-occurrence order.  A set was used here
    # before, which made the "top 3" selection below depend on string-hash
    # ordering and therefore differ between interpreter runs.
    normalized_phrases = (
        ' '.join(sorted(word for word in phrase.lower().split() if word not in stop_words))
        for phrase in raw_phrases
    )
    multi_word_phrases = [phrase for phrase in dict.fromkeys(normalized_phrases) if phrase]

    # (subject, verb, objects) triples for every subject attached to a verb.
    relationships = []
    for token in doc:
        if token.dep_ in {"nsubj", "nsubjpass"} and token.head.pos_ == "VERB":
            objs = [obj.text for obj in token.head.children if obj.dep_ in {"dobj", "pobj", "attr"}]
            relationships.append((token.text, token.head.text, objs))

    explanation = []
    used_phrases = []  # Track used phrases throughout the explanation generation process

    if label == 1:  # "Citation Needed" template
        for subj, verb, objs in relationships:
            if not objs:
                continue
            obj_list = ", ".join(objs)
            standardized_subj = ' '.join(sorted(subj.lower().split()))
            if standardized_subj in used_phrases:
                continue
            explanation.append(
                f"The effect of action '{verb.upper()}' on '{subj}' and '{obj_list}' requires evidence to support the claim."
            )
            used_phrases.append(standardized_subj)
            print(f"Debug: Adding relationship explanation - subj: '{subj}', verb: '{verb}', objs: '{obj_list}'")

        for phrase in multi_word_phrases[:3]:  # Limit to the first 3 phrases to reduce redundancy
            standardized_phrase = ' '.join(sorted(phrase.lower().split()))
            if standardized_phrase in used_phrases:
                continue
            explanation.append(f"The phrase '{phrase}' presents a claim that requires substantiation.")
            used_phrases.append(standardized_phrase)
            print(f"Debug: Adding multi-word phrase explanation - phrase: '{phrase}'")

    else:  # "No Citation Needed" template: exactly one branch applies
        lowered_sentence = sentence.lower()
        if "result" in lowered_sentence or "method" in lowered_sentence:
            phrase_key = "result"
            debug_label = "'No Citation Needed'"
            explanation_text = "This sentence reports study-specific findings, not requiring citation."
        elif any(token.dep_ == "ROOT" and token.pos_ == "AUX" for token in doc):
            phrase_key = debug_label = "logical deduction"
            explanation_text = "This sentence reflects logical deduction and doesn't need citation."
        elif any(token.ent_type_ in {"DATE", "TIME", "PERCENT", "QUANTITY"} for token in doc):
            phrase_key = debug_label = "quantitative information"
            explanation_text = "This sentence provides quantitative information that is self-contained."
        elif any(token.text.lower() in {"this", "these", "our"} for token in doc):
            phrase_key = debug_label = "specific study findings"
            explanation_text = "This sentence refers to specific study findings, which are sufficiently explained and don't need external citation."
        else:
            phrase_key = debug_label = "general background information"
            explanation_text = "This sentence provides general background information, making external citation unnecessary in this context."

        explanation.append(explanation_text)
        used_phrases.append(phrase_key)
        print(f"Debug: Adding {debug_label} explanation - {explanation_text}")

    if not explanation:
        # Reached only for label 1 when no relationship or phrase was found.
        explanation.append("This sentence provides background or context without requiring citation.")
        print("Debug: Adding default explanation - This sentence provides background or context without requiring citation.")

    explanation_tokens = [
        token.lower() for phrase in used_phrases for token in phrase.split()
    ]
    return " ".join(explanation), explanation_tokens


# Calculate BERT Cosine Similarity
def calculate_bert_cosine_similarity(text1, text2):
    """Return the cosine similarity of the sentence embeddings of two texts.

    Both texts are encoded with ``bert_model`` and the similarity is returned
    as a plain Python number.
    """
    first_embedding, second_embedding = (
        bert_model.encode(text, convert_to_tensor=True) for text in (text1, text2)
    )
    return util.pytorch_cos_sim(first_embedding, second_embedding).item()


# Token-level comparison with intersection of original sentence, LLM explanation, and filtered LIME tokens
def token_level_comparison_with_intersection(sentence_tokens, llm_tokens, lime_tokens, used_phrases):
    """Compare explanation tokens against the sentence/LLM token overlap.

    Args:
        sentence_tokens: Tokens of the original sentence.
        llm_tokens: Tokens of the LLM explanation.
        lime_tokens: Accepted for interface compatibility; not used.
        used_phrases: Phrases whose whitespace-split, lower-cased tokens form
            the explanation side of the comparison.

    Returns:
        A ``(matched_tokens, intersection_tokens, filtered_lime_tokens)`` tuple:
        ``filtered_lime_tokens`` is every token of ``used_phrases`` (duplicates
        kept), ``intersection_tokens`` is the set of tokens present in both
        ``sentence_tokens`` and ``llm_tokens``, and ``matched_tokens`` lists the
        distinct phrase tokens that are also in the intersection, in
        first-occurrence order.
    """
    filtered_lime_tokens = [
        token.lower() for phrase in used_phrases for token in phrase.split()
    ]

    intersection_tokens = set(sentence_tokens) & set(llm_tokens)

    # Each token is matched at most once.  Phrases can share words, and
    # counting repeats would let the number of matches exceed the size of the
    # intersection, producing negative false-negative counts downstream.
    matched_tokens = [
        token for token in dict.fromkeys(filtered_lime_tokens)
        if token in intersection_tokens
    ]

    return matched_tokens, intersection_tokens, filtered_lime_tokens

# Summarize results for final analysis with total and micro/macro averages
def summarize_results(df, results, global_metrics):
    """Average the per-sentence metrics and attach the global totals.

    Args:
        df: Not used; kept in the signature for existing callers.
        results: Iterable of dicts with the keys ``"category"``,
            ``"bert_similarity"``, ``"precision"``, ``"recall"`` and ``"f1"``.
        global_metrics: Object stored unchanged under ``"total_metrics"``.

    Returns:
        A dict with ``"average_bert_similarity"`` (over all sentences, over
        categories 2 and 3, and over category 3 only),
        ``"average_token_metrics"`` and ``"total_metrics"``. The average of an
        empty group is ``0``.
    """
    def mean_or_zero(values):
        # An empty group averages to 0 instead of raising ZeroDivisionError.
        if not values:
            return 0
        return sum(values) / len(values)

    similarity_all = []
    similarity_cat_2_3 = []
    similarity_cat_3 = []
    precisions, recalls, f1_scores = [], [], []

    for entry in results:
        similarity = entry["bert_similarity"]
        category = entry["category"]

        similarity_all.append(similarity)
        if category in (2, 3):
            similarity_cat_2_3.append(similarity)
        if category == 3:
            similarity_cat_3.append(similarity)

        precisions.append(entry["precision"])
        recalls.append(entry["recall"])
        f1_scores.append(entry["f1"])

    average_similarity = {
        "all_sentences": mean_or_zero(similarity_all),
        "category_2_and_3": mean_or_zero(similarity_cat_2_3),
        "category_3": mean_or_zero(similarity_cat_3),
    }
    average_tokens = {
        "precision": mean_or_zero(precisions),
        "recall": mean_or_zero(recalls),
        "f1": mean_or_zero(f1_scores),
    }

    return {
        "average_bert_similarity": average_similarity,
        "average_token_metrics": average_tokens,
        "total_metrics": global_metrics,
    }


# Categorize and log sentences
def _precision_recall_f1(tp, fp, fn):
    """Return ``(precision, recall, f1)`` for the counts, using 0 for undefined ratios."""
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = (2 * precision * recall) / (precision + recall) if precision + recall > 0 else 0
    return precision, recall, f1


def categorize_and_log(df, lime_timeout=10):
    """Classify every sentence, explain it, score the explanation and log the outcome.

    Each row is assigned to one of three categories, each with its own log file:

    * category 1 (``category_1.log``): predicted label differs from the expected one;
    * category 2 (``category_2.log``): labels agree but the sentence and the LLM
      explanation share no tokens;
    * category 3 (``category_3.log``): labels agree and tokens are shared.

    Aggregated results are written to ``final_analysis.log``.

    Args:
        df: DataFrame with the columns ``current_sentence``, ``explanation``
            and ``label``.
        lime_timeout: Seconds allowed for one LIME explanation. Sentences that
            exceed it are skipped and noted in ``category_1.log``.

    Returns:
        None. All output goes to the log files.
    """
    results = []  # Store per-sentence metrics for final analysis
    citation_needed_category_3_results = []  # For "Citation Needed" in category 3
    global_tp, global_fp, global_fn = 0, 0, 0  # For total precision/recall calculations

    # An explicit encoding keeps logging of non-ASCII sentences independent of
    # the platform's default locale.
    with open("category_1.log", "w", encoding="utf-8") as log1, \
            open("category_2.log", "w", encoding="utf-8") as log2, \
            open("category_3.log", "w", encoding="utf-8") as log3:
        for _, row in tqdm(df.iterrows(), total=len(df), desc="Processing sentences"):
            sentence = row['current_sentence']
            llm_explanation = row['explanation']
            expected_label = row['label']

            # Predict label
            predicted_probs = predict_proba([sentence])[0]
            predicted_label = int(predicted_probs.argmax())

            # Only the LIME call is time-limited.  The alarm is cancelled in
            # `finally` so that an unrelated exception cannot leave it armed
            # and fire later at an arbitrary point.
            signal.alarm(lime_timeout)
            try:
                explanation = explainer.explain_instance(
                    sentence,
                    predict_proba,
                    num_features=10,
                    num_samples=200
                )
            except TimeoutException:
                log1.write(f"Skipped sentence due to timeout: {sentence}\n")
                continue
            finally:
                signal.alarm(0)

            lime_results = explanation.as_list()

            # Generate explanation and extract tokens from used phrases
            generated_explanation, explanation_tokens = generate_explanation(
                lime_results, sentence, predicted_label
            )

            # Token-level comparison
            sentence_tokens = clean_and_tokenize(sentence)
            llm_tokens = clean_and_tokenize(llm_explanation)

            # explanation_tokens already holds the lower-cased tokens of the
            # phrases that were used, so it is passed as the phrase list too
            # instead of reading a global set as a side effect elsewhere.
            matched_tokens, intersection_tokens, filtered_lime_tokens = token_level_comparison_with_intersection(
                sentence_tokens, llm_tokens, explanation_tokens, explanation_tokens
            )

            # Count distinct tokens so that TP never exceeds the size of either
            # side and FP/FN cannot become negative when phrases share words.
            tp = len(set(matched_tokens))
            fp = len(set(explanation_tokens)) - tp
            fn = len(intersection_tokens) - tp
            precision, recall, f1 = _precision_recall_f1(tp, fp, fn)

            # Update global metrics
            global_tp += tp
            global_fp += fp
            global_fn += fn

            # Pick the category and its log file
            if predicted_label != expected_label:
                category = 1
                log_file = log1
            elif not intersection_tokens:
                category = 2
                log_file = log2
            else:
                category = 3
                log_file = log3

            # Log metrics for each sentence
            log_file.write("\n=== Debug Info ===\n")
            log_file.write(f"Sentence: {sentence}\n")
            log_file.write(f"Expected Label: {'Citation Needed' if expected_label == 1 else 'No Citation Needed'}\n")
            log_file.write(f"Predicted Label: {'Citation Needed' if predicted_label == 1 else 'No Citation Needed'}\n")
            log_file.write(f"Generated Explanation: {generated_explanation}\n")
            log_file.write(f"LLM Explanation: {llm_explanation}\n")
            log_file.write(f"Original Sentence Tokens: {sentence_tokens}\n")
            log_file.write(f"LLM Explanation Tokens: {llm_tokens}\n")
            log_file.write(f"Filtered LIME Tokens (from used phrases): {filtered_lime_tokens}\n")
            log_file.write(f"Intersection Tokens: {intersection_tokens}\n")
            log_file.write(f"Matched Tokens: {matched_tokens}\n")

            # Add TP, FP, FN to the log
            log_file.write(f"True Positives (TP): {tp}\n")
            log_file.write(f"False Positives (FP): {fp}\n")
            log_file.write(f"False Negatives (FN): {fn}\n")

            log_file.write(f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}\n")
            log_file.write("===================\n\n")

            # Encode once and reuse the score for both result lists.
            bert_similarity = calculate_bert_cosine_similarity(llm_explanation, generated_explanation)
            results.append({
                "category": category,
                "bert_similarity": bert_similarity,
                "precision": precision,
                "recall": recall,
                "f1": f1
            })

            if category == 3 and expected_label == 1:
                citation_needed_category_3_results.append({
                    "category": category,
                    "bert_similarity": bert_similarity,
                    "precision": precision,
                    "recall": recall,
                    "f1": f1
                })

    # Calculate global metrics for total precision/recall
    total_precision, total_recall, total_f1 = _precision_recall_f1(global_tp, global_fp, global_fn)

    global_metrics = {
        "total_precision": total_precision,
        "total_recall": total_recall,
        "total_f1": total_f1
    }

    # Summarize and log final analysis
    summary = summarize_results(df, results, global_metrics)
    citation_needed_category_3_summary = summarize_results(df, citation_needed_category_3_results, global_metrics)

    with open("final_analysis.log", "w", encoding="utf-8") as final_log:
        final_log.write("=== Final Analysis ===\n")
        final_log.write(f"Average BERT Similarity (All Sentences): {summary['average_bert_similarity']['all_sentences']:.4f}\n")
        final_log.write(f"Average BERT Similarity (Category 2 & 3): {summary['average_bert_similarity']['category_2_and_3']:.4f}\n")
        final_log.write(f"Average BERT Similarity (Category 3): {summary['average_bert_similarity']['category_3']:.4f}\n")
        final_log.write(f"Average Precision: {summary['average_token_metrics']['precision']:.4f}\n")
        final_log.write(f"Average Recall: {summary['average_token_metrics']['recall']:.4f}\n")
        final_log.write(f"Average F1: {summary['average_token_metrics']['f1']:.4f}\n")
        final_log.write("====================\n")

        # Additional summary for "Citation Needed" class within category 3
        final_log.write("\n=== Citation Needed Class within Category 3 Analysis ===\n")
        final_log.write(f"Average BERT Similarity: {citation_needed_category_3_summary['average_bert_similarity']['all_sentences']:.4f}\n")
        final_log.write(f"Token-Level Precision: {citation_needed_category_3_summary['average_token_metrics']['precision']:.4f}\n")
        final_log.write(f"Token-Level Recall: {citation_needed_category_3_summary['average_token_metrics']['recall']:.4f}\n")
        final_log.write(f"Token-Level F1: {citation_needed_category_3_summary['average_token_metrics']['f1']:.4f}\n")
        final_log.write("====================\n")

        # Debugging: Total Metrics
        final_log.write("\n=== Debugging Information: Total Metrics ===\n")
        final_log.write(f"Total TP (True Positives): {global_tp}\n")
        final_log.write(f"Total FP (False Positives): {global_fp}\n")
        final_log.write(f"Total FN (False Negatives): {global_fn}\n")
        final_log.write("====================\n")



# Load the dataset
# categorize_and_log reads the columns 'current_sentence', 'explanation'
# and 'label' from each row.
training_data_path = "training_data.csv"
df = pd.read_csv(training_data_path)

# Run the categorization and logging
# Writes category_1.log, category_2.log, category_3.log and final_analysis.log
# to the current working directory.
categorize_and_log(df)



### This section performs similar actions to the previous section, but computes the shuffled results, which serve as the baseline.

In [None]:
# The shuffled version, which serves as the baseline.

import pandas as pd
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from lime.lime_text import LimeTextExplainer
import spacy
from tqdm import tqdm
import signal
from sentence_transformers import SentenceTransformer, util
from random import shuffle

# Timeout handler for long-running LIME explanations
class TimeoutException(Exception):
    """Raised by the alarm signal handler when an operation runs past its time limit."""

def timeout_handler(signum, frame):
    """Abort the interrupted code by raising ``TimeoutException``.

    The signal number and stack frame passed in by the ``signal`` module are
    not used.
    """
    raise TimeoutException()

# Install the handler so that signal.alarm(...) interrupts a slow LIME call
# by raising TimeoutException in the main thread.
signal.signal(signal.SIGALRM, timeout_handler)

# Load SpaCy's English model
nlp = spacy.load("en_core_web_sm")

# Initialize SentenceTransformer for BERT-based similarity
bert_model = SentenceTransformer('paraphrase-MiniLM-L6-v2')

# Load the fine-tuned model and tokenizer
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_name = "allenai/scibert_scivocab_uncased"
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
# map_location remaps the stored tensors onto the selected device, so a
# checkpoint that was saved on a GPU still loads on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
model.load_state_dict(
    torch.load("./scibert_fine_tuned_model_LIME.pth", map_location=device)
)
model.to(device)
model.eval()  # inference mode: disables dropout

tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir='./local_model')

# Define prediction function for LIME
def predict_proba(texts):
    """Return class probabilities for a batch of sentences.

    This is the classifier callback handed to the LIME explainer.

    Args:
        texts: The sentences to score, in a form the tokenizer accepts.

    Returns:
        A NumPy array with one row of softmax probabilities per input text.
    """
    encoded = tokenizer(
        texts,
        return_tensors='pt',
        max_length=128,
        truncation=True,
        padding=True,
    )
    encoded = encoded.to(device)

    # Gradients are not needed for inference.
    with torch.no_grad():
        logits = model(**encoded).logits
        probabilities = torch.softmax(logits, dim=1)

    return probabilities.cpu().numpy()

# Initialize LIME explainer
# The class names are given in label order: this cell treats label 1 as
# "Citation Needed" and any other label as "No Citation Needed".
explainer = LimeTextExplainer(class_names=['No Citation Needed', 'Citation Needed'])

# Helper function to clean and tokenize text
def clean_and_tokenize(text):
    """Return the text of every spaCy token in *text* whose text is non-empty.

    Unlike the variant used in the main experiment, no punctuation or
    stop-word filtering is applied here.
    """
    kept = []
    for tok in nlp(text):
        if not tok.text:
            continue
        kept.append(tok.text)
    return kept

# Generate explanations with cleaned multi-word phrases
def generate_explanation(lime_keywords_pos, sentence, label):
    """Build a simple template-based explanation for a classified sentence.

    Args:
        lime_keywords_pos: Accepted for interface compatibility; not used.
        sentence: The sentence that was classified.
        label: Predicted class. ``1`` means "Citation Needed"; any other value
            produces the generic "No Citation Needed" sentence.

    Returns:
        A ``(explanation_text, multi_word_phrases)`` tuple, where
        ``multi_word_phrases`` holds the distinct noun chunks, named entities
        and "<head> <adjectival modifier>" pairs found in the sentence, in
        first-occurrence order. For label 1 with no phrases the explanation
        text is an empty string.
    """
    doc = nlp(sentence)

    candidates = (
        [chunk.text for chunk in doc.noun_chunks]
        + [entity.text for entity in doc.ents]
        + [f"{token.head.text} {token.text}" for token in doc if token.dep_ == "amod"]
    )
    # De-duplicate while preserving order.  A set was used here before, which
    # made the "top 3" selection below depend on string-hash ordering and
    # therefore differ between interpreter runs.
    multi_word_phrases = list(dict.fromkeys(candidates))

    if label == 1:  # Citation Needed
        explanation = [
            f"The phrase '{phrase}' requires substantiation."
            for phrase in multi_word_phrases[:3]  # First 3 phrases
        ]
    else:  # No Citation Needed
        explanation = ["This sentence provides general background information."]

    return " ".join(explanation), multi_word_phrases

# Main processing function with periodic and final score summaries
def _score_row(row, lime_timeout):
    """Return ``(original, shuffled)`` similarity for one row, or ``None`` if the row is skipped."""
    sentence = row['current_sentence']

    # Get predicted label
    predicted_label = int(predict_proba([sentence])[0].argmax())

    # Skip sentences where predicted label does not match expected label
    if predicted_label != row['label']:
        return None

    # Only the LIME call is time-limited.  The alarm is cancelled in `finally`
    # so that an unrelated exception cannot leave it armed.
    signal.alarm(lime_timeout)
    try:
        explanation = explainer.explain_instance(
            sentence,
            predict_proba,
            num_features=10,
            num_samples=200
        )
    except TimeoutException:
        print(f"Skipped sentence due to timeout: {sentence}")
        return None
    finally:
        signal.alarm(0)

    # Retrieve LIME results
    lime_results = explanation.as_list()
    lime_keywords_pos = [
        (word, token.pos_)
        for word, _ in lime_results
        for token in nlp(str(word))
    ]

    generated_explanation, _ = generate_explanation(
        lime_keywords_pos, sentence, predicted_label
    )

    # Encode the generated explanation once and compare it against both the
    # matching and the shuffled LLM explanation.
    generated_embedding = bert_model.encode(generated_explanation, convert_to_tensor=True)
    score_original = util.cos_sim(
        generated_embedding,
        bert_model.encode(row['explanation'], convert_to_tensor=True)
    ).item()
    score_shuffled = util.cos_sim(
        generated_embedding,
        bert_model.encode(row['shuffled_explanation'], convert_to_tensor=True)
    ).item()
    return score_original, score_shuffled


def process_and_compute_similarity(df, lime_timeout=10, seed=None):
    """Score generated explanations against matching and shuffled LLM explanations.

    The shuffled comparison serves as a baseline: each generated explanation is
    also compared with an LLM explanation drawn from a random other row.

    Args:
        df: DataFrame with the columns ``current_sentence``, ``explanation``
            and ``label``. It is modified in place: the columns
            ``shuffled_explanation``, ``bert_similarity_original`` and
            ``bert_similarity_shuffled`` are added.
        lime_timeout: Seconds allowed for one LIME explanation; sentences that
            exceed it are skipped.
        seed: Optional seed for the shuffle, for a reproducible baseline.
            ``None`` (the default) uses the global random state.

    Returns:
        The same DataFrame. Rows that were skipped (misclassified or timed out)
        hold NaN in the two similarity columns.
    """
    import random  # local import: only `shuffle` is imported at file level in this cell

    # Shuffle LLM explanations
    shuffled_explanations = df['explanation'].tolist()
    if seed is None:
        shuffle(shuffled_explanations)
    else:
        random.Random(seed).shuffle(shuffled_explanations)
    df['shuffled_explanation'] = shuffled_explanations

    # Lists to store results
    bert_scores_original = []
    bert_scores_shuffled = []
    # One slot per row, filled by position, so every score lands on the row it
    # was computed for.  Assigning only the kept scores would attach them to
    # the first rows of the frame instead.
    original_column = [float("nan")] * len(df)
    shuffled_column = [float("nan")] * len(df)

    rows = tqdm(df.iterrows(), total=len(df), desc="Processing sentences")
    for position, (_, row) in enumerate(rows):
        scores = _score_row(row, lime_timeout)
        if scores is not None:
            score_original, score_shuffled = scores
            bert_scores_original.append(score_original)
            bert_scores_shuffled.append(score_shuffled)
            original_column[position] = score_original
            shuffled_column[position] = score_shuffled

        # Print averages every 500 sentences, whether or not this particular
        # row was scored, and only once there is something to average.
        processed = position + 1
        if processed % 500 == 0 and bert_scores_original:
            avg_original = sum(bert_scores_original) / len(bert_scores_original)
            avg_shuffled = sum(bert_scores_shuffled) / len(bert_scores_shuffled)
            print(f"\nProcessed {processed} sentences:")
            print(f" - Average BERT Similarity (Original): {avg_original:.4f}")
            print(f" - Average BERT Similarity (Shuffled): {avg_shuffled:.4f}\n")

    # Store similarity scores in the DataFrame
    df['bert_similarity_original'] = original_column
    df['bert_similarity_shuffled'] = shuffled_column

    # Print final overall averages
    print("\n=== Final Summary ===")
    if bert_scores_original:
        final_avg_original = sum(bert_scores_original) / len(bert_scores_original)
        final_avg_shuffled = sum(bert_scores_shuffled) / len(bert_scores_shuffled)
        print(f" - Overall Mean BERT Similarity (Original): {final_avg_original:.4f}")
        print(f" - Overall Mean BERT Similarity (Shuffled): {final_avg_shuffled:.4f}")
    else:
        print(" - No sentences were scored.")

    return df


# Load the training sample and run
# process_and_compute_similarity reads the columns 'current_sentence',
# 'explanation' and 'label' from each row.
training_data_path = "training_data.csv"
df = pd.read_csv(training_data_path)

# The function adds its columns to `df` and returns it; the result is
# saved without the index column.
processed_df = process_and_compute_similarity(df)
processed_df.to_csv("processed_with_bert_similarity.csv", index=False)


### This section is an interactive playground, where the user can input any sentence for classification and explanation.

In [1]:
# The interactive playground
import pandas as pd
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from lime.lime_text import LimeTextExplainer
import spacy
from sentence_transformers import SentenceTransformer, util
import signal
import string

# Timeout handler for long-running LIME explanations
class TimeoutException(Exception):
    """Raised by the alarm signal handler when an operation runs past its time limit."""

def timeout_handler(signum, frame):
    """Turn a delivered signal into a ``TimeoutException``.

    Neither the signal number nor the stack frame is inspected.
    """
    raise TimeoutException()

# Install the handler so that signal.alarm(...) interrupts a slow LIME call
# by raising TimeoutException in the main thread.
signal.signal(signal.SIGALRM, timeout_handler)

# Load SpaCy's English model and stop words
nlp = spacy.load("en_core_web_sm")
stop_words = nlp.Defaults.stop_words
# Translation table that deletes every ASCII punctuation character
punctuation_table = str.maketrans("", "", string.punctuation)

# Initialize SentenceTransformer for BERT-based similarity
bert_model = SentenceTransformer('paraphrase-MiniLM-L6-v2')

# Check if CUDA is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the fine-tuned model and tokenizer
model_name = "allenai/scibert_scivocab_uncased"
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
# map_location remaps the stored tensors onto the selected device, so a
# checkpoint that was saved on a GPU still loads on a CPU-only machine.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
model.load_state_dict(
    torch.load("./scibert_fine_tuned_model_LIME.pth", map_location=device)
)
model.to(device)
model.eval()  # inference mode: disables dropout

tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir='./local_model')

# Define prediction function for LIME
def predict_proba(texts):
    """Return class probabilities for a batch of sentences.

    This is the classifier callback handed to the LIME explainer.

    Args:
        texts: The sentences to score, in a form the tokenizer accepts.

    Returns:
        A NumPy array with one row of softmax probabilities per input text.
    """
    encoded = tokenizer(
        texts,
        return_tensors='pt',
        max_length=128,
        truncation=True,
        padding=True,
    )
    encoded = encoded.to(device)

    # Gradients are not needed for inference.
    with torch.no_grad():
        logits = model(**encoded).logits
        probabilities = torch.softmax(logits, dim=1)

    return probabilities.cpu().numpy()

# Initialize LIME explainer
# The class names are given in label order: this cell treats label 1 as
# "Citation Needed" and any other label as "No Citation Needed".
explainer = LimeTextExplainer(class_names=['No Citation Needed', 'Citation Needed'])

# Helper function to clean and tokenize text
def clean_and_tokenize(text):
    """Return the non-stop-word tokens of *text*.

    ASCII punctuation is removed and the text lower-cased before it is
    tokenised with spaCy; tokens found in ``stop_words`` are dropped.
    """
    without_punctuation = text.translate(punctuation_table)
    doc = nlp(without_punctuation.lower())

    kept = []
    for tok in doc:
        if tok.text in stop_words:
            continue
        kept.append(tok.text)
    return kept

# Generate explanations with cleaned multi-word phrases
def generate_explanation(lime_keywords_pos, sentence, label):
    """Build a template-based explanation for a classified sentence.

    Args:
        lime_keywords_pos: Accepted for interface compatibility; not used.
        sentence: The sentence that was classified.
        label: Predicted class. ``1`` selects the "Citation Needed" templates;
            any other value selects the "No Citation Needed" templates.

    Returns:
        The selected template sentences joined by spaces.

    Side effects:
        Rebinds the module-level global ``used_phrases`` to the phrases used
        for this sentence.
    """
    global used_phrases

    doc = nlp(sentence)

    # Candidate multi-word phrases: noun chunks, named entities, and
    # "<head> <adjectival modifier>" pairs.
    raw_phrases = (
        [chunk.text for chunk in doc.noun_chunks]
        + [entity.text for entity in doc.ents]
        + [f"{token.head.text} {token.text}" for token in doc if token.dep_ == "amod"]
    )

    # Normalise each phrase (lower-case, drop stop words, sort the words) and
    # de-duplicate while keeping first-occurrence order.  A set was used here
    # before, which made the "top 3" selection below depend on string-hash
    # ordering and therefore differ between interpreter runs.
    normalized_phrases = (
        ' '.join(sorted(word for word in phrase.lower().split() if word not in stop_words))
        for phrase in raw_phrases
    )
    multi_word_phrases = [phrase for phrase in dict.fromkeys(normalized_phrases) if phrase]

    # (subject, verb, objects) triples for every subject attached to a verb.
    relationships = []
    for token in doc:
        if token.dep_ in {"nsubj", "nsubjpass"} and token.head.pos_ == "VERB":
            objs = [obj.text for obj in token.head.children if obj.dep_ in {"dobj", "pobj", "attr"}]
            relationships.append((token.text, token.head.text, objs))

    explanation = []
    used_phrases = []  # Track used phrases throughout the explanation generation process

    if label == 1:  # "Citation Needed" template
        for subj, verb, objs in relationships:
            if not objs:
                continue
            obj_list = ", ".join(objs)
            standardized_subj = ' '.join(sorted(subj.lower().split()))
            if standardized_subj in used_phrases:
                continue
            explanation.append(
                f"The effect of action '{verb.upper()}' on '{subj}' and '{obj_list}' requires evidence to support the claim."
            )
            used_phrases.append(standardized_subj)

        for phrase in multi_word_phrases[:3]:  # Limit to the first 3 phrases to reduce redundancy
            standardized_phrase = ' '.join(sorted(phrase.lower().split()))
            if standardized_phrase in used_phrases:
                continue
            explanation.append(f"The phrase '{phrase}' presents a claim that requires substantiation.")
            used_phrases.append(standardized_phrase)

    else:  # "No Citation Needed" template: exactly one branch applies
        lowered_sentence = sentence.lower()
        if "result" in lowered_sentence or "method" in lowered_sentence:
            phrase_key = "result"
            explanation_text = "This sentence reports study-specific findings, not requiring citation."
        elif any(token.dep_ == "ROOT" and token.pos_ == "AUX" for token in doc):
            phrase_key = "logical deduction"
            explanation_text = "This sentence reflects logical deduction and doesn't need citation."
        elif any(token.ent_type_ in {"DATE", "TIME", "PERCENT", "QUANTITY"} for token in doc):
            phrase_key = "quantitative information"
            explanation_text = "This sentence provides quantitative information that is self-contained."
        elif any(token.text.lower() in {"this", "these", "our"} for token in doc):
            phrase_key = "specific study findings"
            explanation_text = "This sentence refers to specific study findings, which are sufficiently explained and don't need external citation."
        else:
            phrase_key = "general background information"
            explanation_text = "This sentence provides general background information, making external citation unnecessary in this context."

        explanation.append(explanation_text)
        used_phrases.append(phrase_key)

    if not explanation:
        # Reached only for label 1 when no relationship or phrase was found.
        explanation.append("This sentence provides background or context without requiring citation.")

    return " ".join(explanation)

# Interactive sentence classification and explanation
def interactive_classification():
    """Classify and explain sentences typed by the user until they enter 'quit'.

    For each sentence the predicted label and a generated explanation are
    printed. The LIME step is limited to 10 seconds; on timeout a hint is
    printed and the next sentence is requested. Any other error is reported
    and the loop continues.
    """
    print("Enter a sentence to classify and explain, or type 'quit' to exit.")
    while True:
        sentence = input("\nEnter a sentence: ").strip()
        if sentence.lower() == "quit":
            print("Exiting interactive classification. Goodbye!")
            break
        if not sentence:
            # Nothing to classify; ask again instead of running the model on
            # an empty string.
            continue

        try:
            # Predict probabilities and determine label
            predicted_probs = predict_proba([sentence])[0]
            predicted_label = int(predicted_probs.argmax())
            label_name = "Citation Needed" if predicted_label == 1 else "No Citation Needed"

            # Generate explanation using LIME
            signal.alarm(10)  # Set a timeout for the explanation
            try:
                explanation = explainer.explain_instance(
                    sentence,
                    predict_proba,
                    num_features=10,
                    num_samples=200
                )
                lime_results = explanation.as_list()
                lime_keywords_pos = [
                    (word, token.pos_)
                    for word, _ in lime_results
                    for token in nlp(str(word))
                ]
            except TimeoutException:
                print("The explanation generation timed out. Try a shorter or simpler sentence.")
                continue
            finally:
                # Always disarm the alarm.  If it stayed armed after an
                # unrelated error it would fire during the next input() call,
                # outside any handler, and terminate the loop.
                signal.alarm(0)

            # Generate explanation text
            generated_explanation = generate_explanation(lime_keywords_pos, sentence, predicted_label)

            # Display results
            print("\n=== Classification ===")
            print(f"Sentence: {sentence}")
            print(f"Predicted Label: {label_name}")
            print(f"Generated Explanation: {generated_explanation}")

        except Exception as e:
            # Top-level boundary of the interactive loop: report and keep going.
            print(f"An error occurred: {str(e)}")

# Run interactive classification
# Blocks on input() until the user types 'quit'.
interactive_classification()


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at allenai/scibert_scivocab_uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  model.load_state_dict(torch.load("./scibert_fine_tuned_model_LIME.pth"))


Enter a sentence to classify and explain, or type 'quit' to exit.

Enter a sentence: quit
Exiting interactive classification. Goodbye!
