In [1]:
#@title Load the files
#@markdown Mount Google Drive so the following cells can read the paper XML files and the fine-tuned models stored there

# Mount Google Drive at /content/drive; later cells read the paper XML folder and the
# fine-tuned BERT models from paths under '/content/drive/My Drive/HTI thesis/'.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
#@title Downloading dependencies

# !pip uninstall -y torch torchvision torchaudio transformers
!pip cache purge
!pip install --upgrade --no-cache-dir torch torchvision torchaudio
!pip install --upgrade --no-cache-dir transformers


# Install other dependencies after PyTorch
!pip install nltk groq
import nltk
nltk.download('punkt_tab')


Found existing installation: torch 2.5.1+cu124
Uninstalling torch-2.5.1+cu124:
  Successfully uninstalled torch-2.5.1+cu124
Found existing installation: torchvision 0.20.1+cu124
Uninstalling torchvision-0.20.1+cu124:
  Successfully uninstalled torchvision-0.20.1+cu124
Found existing installation: torchaudio 2.5.1+cu124
Uninstalling torchaudio-2.5.1+cu124:
  Successfully uninstalled torchaudio-2.5.1+cu124
Found existing installation: transformers 4.47.1
Uninstalling transformers-4.47.1:
  Successfully uninstalled transformers-4.47.1
[0mFiles removed: 0
Collecting torch
  Downloading torch-2.6.0-cp311-cp311-manylinux1_x86_64.whl.metadata (28 kB)
Collecting torchvision
  Downloading torchvision-0.21.0-cp311-cp311-manylinux1_x86_64.whl.metadata (6.1 kB)
Collecting torchaudio
  Downloading torchaudio-2.6.0-cp311-cp311-manylinux1_x86_64.whl.metadata (6.6 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.w

[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.


True

In [3]:
#@title Generate all text

import os
import pandas as pd
from lxml import etree
import nltk

# NLTK tokenizer data used by nltk.word_tokenize / nltk.sent_tokenize in this notebook
# (the setup cell additionally downloads 'punkt_tab').
nltk.download('punkt', quiet=True)

# Prefix map for the 'tei:'-prefixed XPath queries below (TEI P5 namespace).
NS = {'tei': 'http://www.tei-c.org/ns/1.0'}

def count_words(text: str) -> int:
    """Return how many tokens ``nltk.word_tokenize`` produces for *text*.

    This is a token count rather than a strict word count: whatever
    ``word_tokenize`` splits off as a separate token is counted too.
    """
    tokens = nltk.word_tokenize(text)
    return len(tokens)

def extract_authors(root, namespaces):
    """Extract author names from <tei:author> elements in <sourceDesc>.

    Each name is built as "first middle(s) surname" from the author's
    ``forename[@type="first"]``, every ``forename[@type="middle"]`` and ``surname``.
    Parts whose element is missing or has no text are left out.

    Authors that yield no name at all (an <author> element without name parts) are
    skipped. Otherwise callers that join the list with ", " would produce empty
    entries such as "A, , B".

    Args:
        root: parsed document root supporting ``.xpath(query, namespaces=...)``.
        namespaces: prefix map providing the ``tei`` prefix.

    Returns:
        List of author-name strings in document order.
    """
    authors = []
    for author in root.xpath('//tei:sourceDesc//tei:author', namespaces=namespaces):
        first = author.xpath('.//tei:forename[@type="first"]', namespaces=namespaces)
        middles = author.xpath('.//tei:forename[@type="middle"]', namespaces=namespaces)
        surname = author.xpath('.//tei:surname', namespaces=namespaces)

        parts = []
        if first and first[0].text:
            parts.append(first[0].text.strip())
        # Keep every middle forename, not just the first one.
        parts.extend(m.text.strip() for m in middles if m.text)
        if surname and surname[0].text:
            parts.append(surname[0].text.strip())

        full_name = " ".join(part for part in parts if part)
        if full_name:
            authors.append(full_name)
    return authors

def extract_metadata(root, namespaces):
    """Extract basic metadata: title, authors, keywords, and DOI.

    Args:
        root: parsed TEI document root (lxml element).
        namespaces: prefix map providing the ``tei`` prefix.

    Returns:
        dict with keys:
          'title'    -- full text of the first <titleStmt> title, whitespace-collapsed;
                        'No Title' when missing or blank.
          'authors'  -- list of names from extract_authors.
          'keywords' -- stripped text of each <textClass>/<keywords>/<term>.
          'doi'      -- DOI found in the <teiHeader>; 'No DOI' when missing or blank.
    """
    title_elem = root.xpath('//tei:titleStmt//tei:title', namespaces=namespaces)
    authors = extract_authors(root, namespaces)
    keywords_elem = root.xpath('//tei:textClass//tei:keywords//tei:term', namespaces=namespaces)
    # Restrict the DOI lookup to the header. A document-wide '//tei:idno[@type="DOI"]'
    # returns the first DOI anywhere in the file, which can belong to a cited work
    # when the header itself carries no DOI.
    doi_elem = root.xpath('//tei:teiHeader//tei:idno[@type="DOI"]', namespaces=namespaces)

    # itertext() also collects text inside inline markup (e.g. <hi>), which .text
    # would cut off. split()/join collapses the resulting whitespace.
    title = " ".join("".join(title_elem[0].itertext()).split()) if title_elem else ""
    keywords = [kw.text.strip() for kw in keywords_elem if kw.text]
    doi = doi_elem[0].text.strip() if doi_elem and doi_elem[0].text else ""

    return {
        'title': title or 'No Title',
        'authors': authors,
        'keywords': keywords,
        'doi': doi or 'No DOI'
    }

def classify_section(header: str) -> str:
    """Given a header string, classify which section it belongs to."""
    header_lower = header.lower()
    if 'abstract' in header_lower:
        return 'abstract'
    elif 'method' in header_lower or 'methods' in header_lower or 'Materials and Methods' in header_lower or 'model development' in header_lower or 'Feedback Sessions with Experts' in header_lower:
        return 'method'
    elif 'result' in header_lower or 'results' in header_lower or 'our view' in header_lower or 'reason 1: coordination is necessary to answer complex questions' in header_lower:
        return 'results'
    elif 'discussion' in header_lower:
        return 'discussion'
    elif 'conclusion' in header_lower:
        return 'conclusion'
    return None

def _local_name(element) -> str:
    """Return *element*'s tag without any '{namespace}' prefix ('' for comments/PIs)."""
    tag = element.tag
    if not isinstance(tag, str):  # comments / processing instructions have non-str tags
        return ''
    return tag.rsplit('}', 1)[-1]


def _outermost_text_units(node):
    """Return the <p>/<s> descendants of *node* in document order, without descending
    into an element once it has been collected.

    An <s> nested inside a <p> is therefore represented only by its <p>. Selecting
    both (as the XPath union './/p | .//s' does) would emit every sentence twice.
    """
    units = []
    for child in node:
        if _local_name(child) in ('p', 's'):
            units.append(child)
        else:
            units.extend(_outermost_text_units(child))
    return units


def _joined_text(node) -> str:
    """Join the whitespace-stripped, non-empty text fragments under *node* with spaces."""
    return " ".join(t.strip() for t in node.itertext() if t.strip())


def extract_abstract(root):
    """
    Extract the abstract text from either:
    1. The first element whose local name is 'abstract' (namespace-agnostic). Its
       outermost <p>/<s> units are joined; if it has none, all of its text is used.
    2. Otherwise (or if that yields no text), the first <div> in <body> that lacks a
       <head>, assuming it is the abstract.

    Returns the abstract as a single space-joined string, or "" if neither source
    yields text.
    """
    # 1. Standard <abstract>. iter() yields root itself first; skip it so only
    #    descendants are searched.
    walker = root.iter()
    next(walker, None)
    abstract_node = next((el for el in walker if _local_name(el) == 'abstract'), None)
    if abstract_node is not None:
        units = _outermost_text_units(abstract_node)
        if units:
            abstract_text = " ".join(_joined_text(u) for u in units).strip()
        else:
            abstract_text = _joined_text(abstract_node)
        if abstract_text:
            return abstract_text

    # 2. Fallback: the first body <div> without a <head>.
    div_nodes = root.xpath('//tei:text//tei:body//tei:div[not(tei:head)]', namespaces=NS)
    if div_nodes:
        first_div = div_nodes[0]
        paragraphs = first_div.xpath('.//tei:p', namespaces=NS)
        if paragraphs:
            abstract_text = " ".join(_joined_text(p) for p in paragraphs).strip()
        else:
            abstract_text = _joined_text(first_div)
        if abstract_text:
            return abstract_text

    return ""

def _div_head_info(div, namespaces):
    """Return (has_head, header_text) for *div*'s direct <tei:head> child.

    header_text is '' when there is no <head> or when the head's leading text is
    None (e.g. the heading text sits entirely inside a child element).
    """
    heads = div.xpath('./tei:head', namespaces=namespaces)
    if not heads:
        return False, ''
    return True, (heads[0].text or '').strip()


def _div_paragraph_text(div, namespaces) -> str:
    """Join the text of every <tei:p> under *div*: fragments stripped, space-separated."""
    paragraphs = div.xpath('.//tei:p', namespaces=namespaces)
    return " ".join(
        " ".join(t.strip() for t in p.itertext() if t.strip())
        for p in paragraphs
    ).strip()


def extract_main_sections(root, namespaces):
    """
    Collect body text into the main sections: abstract, method, results, discussion,
    conclusion.

    Walks every <tei:div> under <text>/<body> in document order:
      * A div whose <head> classifies (via classify_section) as a section starts that
        section. Its paragraph text, and that of following unclassified divs, is
        appended to it.
      * Divs before the first classified div are ignored.
      * Once in 'conclusion', the first div that has a <head> but is not classified
        stops further appending, until another classified div appears.

    The abstract is taken from the first div headed as an abstract. When that text is
    non-empty, every abstract-headed div is excluded from the walk so it is not counted
    twice.

    Returns:
        dict with keys 'abstract', 'method', 'results', 'discussion', 'conclusion'
        mapped to whitespace-trimmed text ('' when nothing was found).

    NOTE(review): '//tei:div' also matches nested divs and './/tei:p' includes the
    paragraphs of nested divs, so nested div structures would duplicate text.
    """
    divs = root.xpath('//tei:text//tei:body//tei:div', namespaces=namespaces)

    sections = {
        'abstract': "",
        'method': "",
        'results': "",
        'discussion': "",
        'conclusion': ""
    }

    # Classify each div once. A <head> with no leading text is treated as an empty
    # header everywhere, instead of crashing on None.strip() in the abstract filter.
    div_info = []
    for div in divs:
        has_head, header_text = _div_head_info(div, namespaces)
        div_info.append((div, has_head, classify_section(header_text)))

    abstract_from_div = ""
    for div, _, section_type in div_info:
        if section_type == 'abstract':
            abstract_from_div = _div_paragraph_text(div, namespaces)
            break

    if abstract_from_div:
        sections['abstract'] = abstract_from_div
        # Remove abstract divs from further processing to avoid duplication.
        div_info = [info for info in div_info if info[2] != 'abstract']

    current_section = None
    conclusion_ended = False

    for div, has_head, section_type in div_info:
        div_text = _div_paragraph_text(div, namespaces)

        if section_type:
            current_section = section_type
            conclusion_ended = False
            sections[current_section] += " " + div_text
        else:
            # A headed but unclassified div after the conclusion (e.g. back matter)
            # ends the conclusion.
            if has_head and current_section == 'conclusion':
                conclusion_ended = True

            if current_section and not conclusion_ended:
                sections[current_section] += " " + div_text

    return {name: text.strip() for name, text in sections.items()}

def process_xml_file(file_path):
    """Parse one TEI XML file into section-level rows.

    Returns a list of dicts with keys DOI, Title, Authors, Keywords, Section, Text and
    WordCount: first a 'title' row (only when a real title was found), then one row per
    non-empty section in the order abstract, method, results, discussion, conclusion.
    Authors and Keywords are ", "-joined strings.

    The abstract comes from extract_abstract, falling back to the abstract found by
    extract_main_sections.

    Returns None (after printing an error) when the file cannot be parsed.
    """
    try:
        root = etree.parse(file_path).getroot()
    except Exception as e:
        print(f"ERROR: Could not parse file {file_path}. Reason: {str(e)}")
        return None

    metadata = extract_metadata(root, NS)
    title = metadata['title']
    doi = metadata['doi']
    authors_joined = ", ".join(metadata['authors'])
    keywords_joined = ", ".join(metadata['keywords'])

    def make_row(section_name, body):
        # Every row repeats the paper-level metadata next to one chunk of text.
        return {
            "DOI": doi,
            "Title": title,
            "Authors": authors_joined,
            "Keywords": keywords_joined,
            "Section": section_name,
            "Text": body,
            "WordCount": count_words(body)
        }

    abstract_text = extract_abstract(root)
    sections = extract_main_sections(root, NS)
    abstract_text = abstract_text or sections['abstract']
    if abstract_text:
        sections['abstract'] = abstract_text.strip()

    rows = [make_row("title", title)] if title and title != "No Title" else []
    for section_name in ('abstract', 'method', 'results', 'discussion', 'conclusion'):
        body = sections[section_name].strip()
        if body:
            rows.append(make_row(section_name, body))
    return rows

def process_all_files_in_folder(folder_path):
    """Parse every XML file directly inside *folder_path* into one DataFrame.

    Files are processed in sorted filename order. ``os.listdir`` returns entries in
    arbitrary order, which would make the row order (and the paper dropdown built
    from it) differ between runs. The '.xml' extension check is case-insensitive.
    Files that ``process_xml_file`` cannot parse, or that yield no rows, contribute
    nothing.

    Returns:
        DataFrame with one row per (paper, section); empty if no rows were produced.
    """
    all_results = []
    for filename in sorted(os.listdir(folder_path)):
        if not filename.lower().endswith('.xml'):
            continue
        results = process_xml_file(os.path.join(folder_path, filename))
        if results:
            all_results.extend(results)

    return pd.DataFrame(all_results)

# Build the section-level table for every paper XML in the Drive folder.
folder_path = '/content/drive/My Drive/HTI thesis/Interview papers'
df = process_all_files_in_folder(folder_path)
all_text = df  # alias kept for compatibility; later cells read `df`
df


Unnamed: 0,DOI,Title,Authors,Keywords,Section,Text,WordCount
0,10.1123/ijspp.2019-0260,Optimizing Interval Training Through Power-Out...,"Arthur H Bossi, Cristian Mesquida, Louis Passf...","intensity prescription, time at VO 2 max, elit...",title,Optimizing Interval Training Through Power-Out...,10
1,10.1123/ijspp.2019-0260,Optimizing Interval Training Through Power-Out...,"Arthur H Bossi, Cristian Mesquida, Louis Passf...","intensity prescription, time at VO 2 max, elit...",abstract,Purpose: Maximal oxygen uptake ( VO 2 max) is ...,314
2,10.1123/ijspp.2019-0260,Optimizing Interval Training Through Power-Out...,"Arthur H Bossi, Cristian Mesquida, Louis Passf...","intensity prescription, time at VO 2 max, elit...",method,A total of 14 well-trained male cyclists volun...,2228
3,10.1123/ijspp.2019-0260,Optimizing Interval Training Through Power-Out...,"Arthur H Bossi, Cristian Mesquida, Louis Passf...","intensity prescription, time at VO 2 max, elit...",results,Statistics and effect size estimations from th...,321
4,10.1123/ijspp.2019-0260,Optimizing Interval Training Through Power-Out...,"Arthur H Bossi, Cristian Mesquida, Louis Passf...","intensity prescription, time at VO 2 max, elit...",discussion,"Consistent with our first hypothesis, well-tra...",1536
5,10.1123/ijspp.2019-0260,Optimizing Interval Training Through Power-Out...,"Arthur H Bossi, Cristian Mesquida, Louis Passf...","intensity prescription, time at VO 2 max, elit...",conclusion,In comparison with a HIIT session with constan...,44
6,10.1017/S1355771819000451,Human-human collaboration enhanced with emergi...,"Rick Knops, Irina Bianca Şerban, Steven Houben","CCS Concepts:, Human-centered computing → Visu...",title,Human-human collaboration enhanced with emergi...,8
7,10.1017/S1355771819000451,Human-human collaboration enhanced with emergi...,"Rick Knops, Irina Bianca Şerban, Steven Houben","CCS Concepts:, Human-centered computing → Visu...",abstract,Co-creation sessions have traditionally relied...,231
8,10.1017/S1355771819000451,Human-human collaboration enhanced with emergi...,"Rick Knops, Irina Bianca Şerban, Steven Houben","CCS Concepts:, Human-centered computing → Visu...",results,The increasing integration of AI in our daily ...,375
9,10.1017/S1355771819000451,Human-human collaboration enhanced with emergi...,"Rick Knops, Irina Bianca Şerban, Steven Houben","CCS Concepts:, Human-centered computing → Visu...",conclusion,We are interested in discussing opportunities ...,341


In [4]:
#@title Classify
import pandas as pd
import nltk
from transformers import BertForSequenceClassification, BertTokenizer
import torch
import torch.nn.functional as F
from tqdm import tqdm
import ipywidgets as widgets
from IPython.display import display

nltk.download('punkt')

# Stage-1 model: fine-tuned BERT labelling sentences causal (1) / non-causal (0).
model_path = '/content/drive/My Drive/HTI thesis/2_BERT2'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')  # GPU when available

tokenizer = BertTokenizer.from_pretrained(model_path)
model = BertForSequenceClassification.from_pretrained(model_path).to(device)

def classify_sentences_with_confidence(sentences, model, tokenizer, batch_size=32):
    """Classify *sentences* in mini-batches with a sequence-classification model.

    Each batch is tokenized (padded, truncated to 512 tokens) and moved to the
    notebook-level ``device``. It then goes through *model* in eval mode without
    gradients.

    Returns:
        (predictions, confidences): parallel lists with, per sentence, the argmax class
        id of the softmax output and that class's probability.
    """
    predicted_labels, label_probs = [], []
    model.eval()
    with torch.no_grad():
        for start in range(0, len(sentences), batch_size):
            encoded = tokenizer(
                sentences[start:start + batch_size],
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512,
            ).to(device)
            probabilities = F.softmax(model(**encoded).logits, dim=-1)
            predicted_labels.extend(probabilities.argmax(dim=-1).cpu().tolist())
            label_probs.extend(probabilities.max(dim=-1).values.cpu().tolist())
    return predicted_labels, label_probs

# One dropdown entry per distinct paper title (unique() keeps first-appearance order).
# NOTE(review): papers whose title could not be extracted all share the placeholder
# 'No Title' and would be merged into a single entry here.
paper_titles = df['Title'].unique().tolist()
dropdown = widgets.Dropdown(description='Select Paper:', options=paper_titles)
display(dropdown)

# Second-stage model: splits causal sentences into three classes. Its 0/1/2 outputs
# are shifted to 1/2/3 below so that 0 keeps meaning "non-causal". The model is
# loaded lazily and cached, so changing the dropdown does not reload it from Drive
# every time.
THREE_CLASS_MODEL_PATH = '/content/drive/My Drive/HTI thesis/3_BERT'
_three_class_model_cache = {}


def _load_three_class_model(model_dir):
    """Return the (model, tokenizer) pair stored in *model_dir*, loading it on first use.

    The model is moved to the notebook-level ``device``; later calls reuse the cached pair.
    """
    if model_dir not in _three_class_model_cache:
        loaded_model = BertForSequenceClassification.from_pretrained(model_dir).to(device)
        loaded_tokenizer = BertTokenizer.from_pretrained(model_dir)
        _three_class_model_cache[model_dir] = (loaded_model, loaded_tokenizer)
    return _three_class_model_cache[model_dir]


def classify_selected_paper(change):
    """Dropdown callback: classify every sentence of the selected paper in two stages.

    Stage 1 uses the notebook-level ``model``/``tokenizer`` and labels each sentence
    0 = non-causal or 1 = causal. Stage 2 uses the three-class model and re-labels the
    causal sentences as 1/2/3, storing its own probability in ``Confidence``.

    Sentence selection:
      * the paper title is classified once, as a single unfiltered sentence;
      * 'title' rows of ``df`` are then skipped, since the title is already covered;
      * 'method' sections are skipped;
      * other sections are split with ``nltk.sent_tokenize``, and sentences with fewer
        than 5 whitespace-separated words are dropped.

    Side effects: prints progress and sets two globals:
      * ``final_df``: one row per sentence in paper order, with columns DOI, Title,
        Authors, Section, Sentence, Causal, Confidence;
      * ``doi``: DOI of the last processed row.

    Args:
        change: ipywidgets change notification; ``change.new`` is the selected title.
    """
    global final_df, doi  # used by later cells

    selected_title = change.new
    print(f"Classifying paper: {selected_title}")

    selected_paper_df = df[df['Title'] == selected_title]

    sentence_rows = []
    title_processed = False

    for _, row in tqdm(selected_paper_df.iterrows(), total=selected_paper_df.shape[0], desc="Classifying full paper"):
        doi = row["DOI"]
        title = row["Title"]
        section = row["Section"]
        text = row["Text"]
        authors = row["Authors"]
        base = {"DOI": doi, "Title": title, "Authors": authors}

        if not title_processed:
            predictions, confidences = classify_sentences_with_confidence([title], model, tokenizer)
            for pred, conf in zip(predictions, confidences):
                sentence_rows.append({
                    **base,
                    "Section": "title",
                    "Sentence": title,
                    "Causal": pred,  # 0 = Non-causal, 1 = Causal
                    "Confidence": conf
                })
            title_processed = True

        # The title was classified just above. process_xml_file also emits a 'title'
        # row, which would otherwise add the title sentence a second time.
        if section.lower() in ("title", "method"):
            continue

        sentences = nltk.sent_tokenize(text)
        filtered_sentences = [s for s in sentences if len(s.split()) >= 5]

        if filtered_sentences:
            predictions, confidences = classify_sentences_with_confidence(filtered_sentences, model, tokenizer)
            for sent, pred, conf in zip(filtered_sentences, predictions, confidences):
                sentence_rows.append({
                    **base,
                    "Section": section,
                    "Sentence": sent,
                    "Causal": pred,  # 0 = Non-causal, 1 = Causal
                    "Confidence": conf
                })

    result_df = pd.DataFrame(sentence_rows)

    # Keep only stage-1 labels 0/1; rows with any other label are dropped.
    result_df = result_df[result_df['Causal'].isin([0, 1])].copy()

    # Stage 2: refine causal sentences in place so the paper order is preserved.
    causal_mask = result_df['Causal'] == 1
    causal_sentences = result_df.loc[causal_mask, 'Sentence'].tolist()
    if causal_sentences:
        three_model, three_tokenizer = _load_three_class_model(THREE_CLASS_MODEL_PATH)
        three_preds, three_confs = classify_sentences_with_confidence(
            causal_sentences, three_model, three_tokenizer)
        result_df.loc[causal_mask, 'Causal'] = [p + 1 for p in three_preds]
        # Overwrite the existing 'Confidence' column. Writing to a lowercase
        # 'confidence' column would leave the stage-1 value in the output.
        result_df.loc[causal_mask, 'Confidence'] = three_confs

    final_df = result_df[['DOI', 'Title', 'Authors', 'Section', 'Sentence', 'Causal', 'Confidence']].reset_index(drop=True)

    print(f"Classification complete for DOI: {doi}")
    print("final_df is ready for further use.")

# Attach the function to the dropdown menu: classify_selected_paper runs on changes
# of the widget's 'value' trait and stores its results in the globals final_df / doi.
dropdown.observe(classify_selected_paper, names='value')


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


Dropdown(description='Select Paper:', options=('Optimizing Interval Training Through Power-Output Variation Wi…

Classifying paper: Human-human collaboration enhanced with emerging technologies of AI


Classifying full paper: 100%|██████████| 4/4 [00:00<00:00,  6.06it/s]


Classification complete for DOI: 10.1017/S1355771819000451
final_df is ready for further use.
Classifying paper: Optimizing Interval Training Through Power-Output Variation Within the Work Intervals


Classifying full paper: 100%|██████████| 6/6 [00:00<00:00, 16.16it/s]


Classification complete for DOI: 10.1123/ijspp.2019-0260
final_df is ready for further use.


In [9]:
#@title Final UI

import pandas as pd
import nltk
from IPython.display import display, HTML
from groq import Groq
import plotly.graph_objects as go
import numpy as np
from google.colab import output
import re  # Import for regular expressions

# Download necessary NLTK data
nltk.download('punkt')

###################################################
# 1) Best Practices and Decision Tree Analysis
###################################################

# Synonyms for Randomized Controlled Trials.
# NOTE(review): decision_tree_analysis tests these as plain substrings of the
# lower-cased study type, so 'controlled trial' also matches e.g. "uncontrolled
# trial" or "non-randomized controlled trial" -- confirm this is intended.
rct_synonyms = [
    'randomized controlled trial', 'randomized trial', 'controlled trial',
    'randomized experiment', 'random controlled trial', 'rct',
    'random assignment trial', 'random allocation trial'
]

# Expanded list of hedging terms (single words and multi-word phrases, lower case)
hedging_terms = [
    'might', 'could', 'possibly', 'may', 'perhaps', 'likely', 'seems', 'suggests',
    'appears', 'potentially', 'probably', 'may be', 'could be',
    'might be', 'seem to', 'appear to', 'suggest that', 'indicate that', 'imply that'
]

# Conditional causal terms: hedge word + causal verb phrase
conditional_causal_terms = [
    'might cause', 'could lead to', 'may contribute to', 'appears to increase',
    'potentially affects', 'probably leads to', 'could cause', 'might lead to',
    'may influence', 'possibly results in'
]

# Ambiguous words categorized by type (the three groups presumably mirror the
# Correlational / Conditional Causal / Direct Causal categories below -- confirm
# against where they are used)
ambiguous_words_correlational = ['predict', 'forecast', 'estimate']
ambiguous_words_conditional = ['prove', 'ensure', 'warrant']
ambiguous_words_direct = ['guarantee', 'always', 'never']

# Define causal categories globally: display names for the Causal codes 1-3 in
# final_df; code 0 (non-causal) has no entry.
causal_categories = {1: 'Correlational', 2: 'Conditional Causal', 3: 'Direct Causal'}

# Best practices dictionary for warnings and examples.
# Keys are warning ids; decision_tree_analysis looks entries up by id (e.g.
# best_practices[12]['description'] for cross-section inconsistency). Ids 3, 7 and 8
# are not defined here. Entry 1 has "strict"/"moderate" variants instead of a single
# "description" -- presumably selected via decision_tree_analysis's guideline_mode
# argument (default 'strict'); confirm in that function.
# Values are HTML snippets (<b>, <i>, <br>), presumably rendered with IPython HTML.
# NOTE(review): entries 5 and 12 describe the same title/abstract/conclusion
# inconsistency; only 12 is referenced in the visible code.
best_practices = {
    1: {
        "strict": '''<b>Misuse of causal language in non-RCT studies:</b><br>
Observational studies, including cross-sectional designs, lack the rigorous controls necessary to establish definitive causal relationships. Using definitive causal language (e.g., "X causes Y") in such studies can mislead readers about the strength of the evidence. It's essential to use associative terms like "associated with" to accurately reflect the study's limitations.<br>''',
        "moderate": '''<b>Watch out using causal language in observational designs:</b><br>
In observational studies, it's crucial to carefully control for confounders or use advanced statistical methods to support any causal claims. Without these precautions, causal language may overstate the findings.<br>'''
    },
    2: {
        "description": '''<b>Excessive use of hedging terms:</b><br>
Hedging terms like <i>"might," "could," "possibly,"</i> and <i>"may"</i> are useful for conveying uncertainty. However, overusing these terms can lead to vagueness and weaken the impact of your statements. Aim to limit hedging to maintain clarity and confidence in your findings.<br>
'''
    },
    4: {
        "description": '''<b>Use of ambiguous words implying certainty:</b><br>
Words like <i>"predict," "prove,"</i> and <i>"guarantee"</i> often imply a level of certainty that may not be warranted, especially in observational studies. It's better to use terms that accurately reflect the level of certainty, such as <i>"suggests"</i> or <i>"indicates"</i>, to avoid overstating your findings.<br>
'''
    },
    5: {
        "description": '''<b>Inconsistency in Causal Classification Between Title, Abstract, and Conclusion:</b><br>
The title, abstract, and conclusion of a paper should consistently reflect the same level of causal language. Inconsistencies can confuse readers about the study's findings. Ensure that if the title implies causality, the abstract and conclusion also support this classification, and vice versa.<br>
'''
    },
    6: {
        "description": '''<b>Conditional causal language without addressing confounders:</b><br>
Using conditional causal language (e.g., <i>"might cause," "could lead to"</i>) in studies requires careful consideration of confounding variables. Failing to address confounders can lead to erroneous conclusions. Always acknowledge how confounders are controlled for, or clearly state the study's limitations if they are not.<br>
'''
    },
    9: {
        "description": '''<b>Provide numerical or statistical context:</b><br>
Distinguishing between statistical significance and practical significance is crucial. A result may be statistically significant but have a negligible effect size, which might not be practically meaningful. Always provide numerical context, such as effect sizes, confidence intervals, or p-values, to help readers interpret the real-world impact of your findings.<br>
'''
    },
    10: {
        "description": '''<b>Incorporate caveats and limitations:</b><br>
Including limitations in the methods section or discussion provides transparency about the study's scope and potential weaknesses. Caveats such as <i>"further research is needed"</i> or <i>"these findings cannot establish causation"</i> help contextualize the results and prevent misinterpretation. Acknowledging limitations enhances the study's credibility.<br>
'''
    },
    11: {
        "description": '''<b>Excessive Conditional Causal Statements in a Section:</b><br>
Having more than five conditional causal sentences (e.g., "might cause," "could lead to") within a single section can indicate an overreliance on speculative language. This may undermine the clarity and strength of your findings. Strive for balanced language to maintain the study's credibility.<br>
'''
    },
    12: {
        "description": '''<b>Inconsistent Causal Classifications Across Sections:</b><br>
The causal classifications among the title, abstract, and conclusion sections are inconsistent. This inconsistency can confuse readers about the study's findings and weaken the overall credibility.<br>
'''
    }
}

def _record_trigger(warnings_dict, key, header, sentence, section):
    """Add `sentence` as a trigger of warning `key`, creating the warning entry on first use.

    `header` supplies the entry's descriptive fields (e.g. 'description' and 'reason');
    it is only applied when the entry is created, so the first trigger's reason is kept.
    """
    entry = warnings_dict.get(key)
    if entry is None:
        entry = {'key': key, **header, 'trigger_sentences': []}
        warnings_dict[key] = entry
    entry['trigger_sentences'].append({
        'sentence': sentence,
        'section': section.capitalize()
    })


def decision_tree_analysis(df, study_type, guideline_mode='strict', max_conditional_per_section=5):
    """Apply rule-based reporting guidelines to classified sentences and collect warnings.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per sentence with columns 'Sentence', 'Section', 'Causal'
        (0 = non-causal, 1-3 = causal categories) and 'Confidence'.
    study_type : str
        Free-text study design; an RCT is detected by substring match against
        the notebook-level `rct_synonyms`.
    guideline_mode : str
        'strict' flags every causal claim in the Abstract/Conclusion; 'moderate'
        flags them only when no listed confounder is mentioned in the Method
        section; any other value flags them with a generic reason.
    max_conditional_per_section : int
        A section containing more conditional-causal (Causal == 2) sentences
        than this triggers warning 11.

    Returns
    -------
    list of dict
        One entry per triggered warning, in first-trigger order, each holding a
        'key', descriptive fields and 'trigger_sentences'. The caveats and
        limitations guideline (key 10) is always included.

    Relies on notebook-level globals defined in other cells: `best_practices`,
    `causal_categories`, `rct_synonyms`, `hedging_terms`, and
    `ambiguous_words_correlational` / `_conditional` / `_direct`.
    """
    warnings_dict = {}

    # Define confounders list (example list; should be customized as needed)
    confounders = ['age', 'gender', 'income', 'education', 'baseline health', 'prior conditions']

    # Preprocess study_type to determine if it's RCT or not
    study_type_lower = study_type.lower()
    is_rct = any(rct_synonym in study_type_lower for rct_synonym in rct_synonyms)

    # Confounders count as "addressed" if any of them is mentioned anywhere in the Method section.
    method_sentences = df.loc[df['Section'].str.lower() == 'method', 'Sentence']
    method_confounders_present = any(
        method_sentences.str.contains(re.compile(r'\b' + re.escape(conf) + r'\b', re.IGNORECASE)).any()
        for conf in confounders
    )

    # Causal sentences only, with a lowercase section column so section lookups are case-insensitive.
    causal_df = df[df['Causal'].isin([1, 2, 3])].assign(_section=lambda d: d['Section'].str.lower())

    # --- Warning 12: inconsistent causal classifications across key sections ---
    key_sections = ['title', 'abstract', 'conclusion']
    section_causal = causal_df.groupby('_section')['Causal'].max().to_dict()
    # NOTE(review): a key section with no causal sentences counts as 0 here, so e.g. a
    # non-causal title alone makes the sections "inconsistent"; confirm this is intended.
    classifications = {sec: section_causal.get(sec, 0) for sec in key_sections}

    if len(set(classifications.values())) > 1:
        key = 12
        per_section_classifications = {
            sec.capitalize(): causal_categories.get(cls, 'Unknown')
            for sec, cls in classifications.items() if cls != 0
        }

        distribution = causal_df[causal_df['_section'].isin(key_sections)]
        distribution_counts = (
            distribution.groupby(['_section', 'Causal']).size().unstack(fill_value=0).to_dict('index')
        )
        distribution_formatted = "".join(
            f"<b>{sec.capitalize()}:</b> "
            + ", ".join(f"{causal_categories.get(c, 'Unknown')}: {count}" for c, count in counts.items())
            + "<br>"
            for sec, counts in distribution_counts.items()
        )

        trigger_sentences = [
            {'sentence': sentence, 'section': sec.capitalize()}
            for sec in key_sections
            for sentence in causal_df.loc[causal_df['_section'] == sec, 'Sentence']
        ]

        warnings_dict[key] = {
            'key': key,
            'description': best_practices[key]['description'],
            'reason': "Inconsistent causal classifications across Title, Abstract, and Conclusion.",
            'per_section_classifications': per_section_classifications,
            'distribution': distribution_formatted,
            'trigger_sentences': trigger_sentences
        }

    # --- Warning 11 pre-pass: sections with too many conditional-causal sentences ---
    conditional_counts = causal_df.loc[causal_df['Causal'] == 2, '_section'].value_counts()
    excessive_conditional_sections = set(
        conditional_counts[conditional_counts > max_conditional_per_section].index
    )

    # Warning 4 word lists, checked in this order; the first matching category supplies the reason.
    ambiguous_checks = (
        (ambiguous_words_correlational, "Use of ambiguous words implying correlational certainty."),
        (ambiguous_words_conditional, "Use of ambiguous words implying conditional causal certainty."),
        (ambiguous_words_direct, "Use of ambiguous words implying direct causal certainty."),
    )

    # --- Per-sentence rules ---
    for _, row in df.iterrows():
        sentence = row['Sentence']
        causal_type = row['Causal']
        confidence = row['Confidence']
        section = row['Section']

        # Skip non-causal sentences
        if causal_type == 0:
            continue

        sentence_lower = sentence.lower()
        section_lower = section.lower()

        # Warning 1: causal claims in Abstract/Conclusion; the trigger condition depends on the mode.
        if section_lower in ('abstract', 'conclusion') and causal_type in [1, 2, 3]:
            if guideline_mode == 'strict':
                reason_field, reason = 'reason_strict', "Causal claims are not allowed in this mode."
            elif guideline_mode == 'moderate':
                # Moderate mode tolerates causal claims once confounders are addressed in Method.
                if method_confounders_present:
                    reason_field, reason = None, None
                else:
                    reason_field, reason = 'reason_moderate', "Causal claims without addressing confounders."
            else:
                reason_field, reason = 'reason_undefined', "Use of causal language without proper guidelines."

            if reason_field is not None:
                _record_trigger(warnings_dict, 1, {
                    'description_strict': best_practices[1]['strict'],
                    'description_moderate': best_practices[1]['moderate'],
                    reason_field: reason,
                }, sentence, section)

        # Warning 2: more than two hedging-term occurrences in one sentence.
        hedge_count = sum(sentence_lower.count(term) for term in hedging_terms)
        if hedge_count > 2:
            _record_trigger(warnings_dict, 2, {
                'description': best_practices[2]['description'],
                'reason': "Overuse of hedging terms leading to vagueness.",
            }, sentence, section)

        # Warning 11: conditional-causal sentence in a section that has too many of them.
        if causal_type == 2 and section_lower in excessive_conditional_sections:
            _record_trigger(warnings_dict, 11, {
                'description': best_practices[11]['description'],
                'reason': "Overuse of conditional causal terms in a single section.",
            }, sentence, section)

        # Warning 4: ambiguous wording; a sentence is recorded once even if it matches several lists.
        for words, reason in ambiguous_checks:
            if any(word in sentence_lower for word in words):
                _record_trigger(warnings_dict, 4, {
                    'description': best_practices[4]['description'],
                    'reason': reason,
                }, sentence, section)
                break

        # Warning 6: confident conditional-causal claim in a non-RCT without confounder handling.
        if not is_rct and causal_type == 2 and confidence > 0.90 and not method_confounders_present:
            _record_trigger(warnings_dict, 6, {
                'description': best_practices[6]['description'],
                'reason': "Conditional causal claims without addressing confounders.",
            }, sentence, section)

    # Always include the caveats-and-limitations guideline (Warning 10) as a general recommendation.
    key = 10
    if key not in warnings_dict:
        warnings_dict[key] = {
            'key': key,
            'description': best_practices[key]['description'],
            'reason': "Missing caveats and limitations in the study.",
            'trigger_sentences': []
        }

    return list(warnings_dict.values())

###################################################
# 2) Initialize Groq Client
###################################################
# Read the Groq API key from the environment (or prompt for it) so the secret never
# lives in the notebook. The key that used to be hardcoded here was stored in plain
# text and must be treated as compromised: revoke it in the Groq console.
import os
from getpass import getpass

client = Groq(api_key=os.environ.get("GROQ_API_KEY") or getpass("Groq API key: "))

###################################################
# 3) FUNCTION: analyze_method_text
###################################################
# Numbered headings requested in the Method-analysis prompt, mapped to study_info keys.
_METHOD_HEADING_FIELDS = {
    "1. Summary": "Summary",
    "2. Study Type": "Study Type",
    "3. Randomization": "Randomization",
    "4. Groups or Comparisons": "Groups or Comparisons",
    "5. Mediators / Confounders": "Mediators / Confounders",
    "6. Data Collection Method": "Data Collection Method",
    "7. Causal Inference Implications": "Causal Inference Implications"
}
# Fields stored as lists of items rather than a single string.
_METHOD_LIST_FIELDS = ("Groups or Comparisons", "Mediators / Confounders", "Data Collection Method")
# Bullet markers an LLM may emit in front of list items despite being asked for commas.
_BULLET_PREFIXES = ("- ", "* ", "• ")


def _store_method_field(study_info, field, lines):
    """Write the stripped text `lines` collected under one heading into `study_info[field]`."""
    if field in _METHOD_LIST_FIELDS:
        # Accept both "a, b, c" and one-item-per-line (optionally bulleted) replies.
        items = []
        for line in lines:
            for prefix in _BULLET_PREFIXES:
                if line.startswith(prefix):
                    line = line[len(prefix):]
                    break
            items.extend(part.strip() for part in line.split(",") if part.strip())
        study_info[field] = items
    else:
        # Skip blank lines so paragraphs are joined with single spaces.
        content = " ".join(line for line in lines if line)
        # Randomization is kept as a short title-cased tag (e.g. "Yes" / "No").
        study_info[field] = content.title() if field == "Randomization" else content


def _parse_method_response(response_content, doi, title):
    """Parse the LLM's '## N. Heading' formatted reply into the study_info dict.

    Text before the first recognised heading, or under an unrecognised one, is ignored.
    If a heading appears twice, its later occurrence wins.
    """
    study_info = {
        "DOI": doi,
        "Title": title,
        "Summary": "",
        "Study Type": "",
        "Randomization": "",
        "Groups or Comparisons": [],
        "Mediators / Confounders": [],
        "Data Collection Method": [],
        "Causal Inference Implications": ""
    }

    # (field or None, collected lines) per "## " heading, in reply order.
    segments = []
    current_field = None
    for raw_line in response_content.split("\n"):
        line = raw_line.strip()
        if line.startswith("## "):
            heading = line.replace("## ", "")
            current_field = next(
                (field for key, field in _METHOD_HEADING_FIELDS.items() if key in heading),
                None,
            )
            segments.append((current_field, []))
        elif current_field is not None:
            segments[-1][1].append(line)

    for field, lines in segments:
        if field is not None and lines:
            _store_method_field(study_info, field, lines)

    return study_info


def analyze_method_text(method_text, doi, title):
    """
    Ask the Groq LLM to summarize a paper's Method section and parse its reply.

    Sample size, population and statistics are deliberately not requested. The
    returned dict holds DOI, Title, Summary, Study Type, Randomization (a short
    title-cased tag), Causal Inference Implications (strings) and Groups or
    Comparisons / Mediators / Confounders / Data Collection Method (lists).
    Errors raised by the Groq client propagate to the caller.
    """

    prompt = f"""
You are a helpful assistant. Please analyze the **Method** section from a social science or behavioral research paper.
**Respond** in a structured format with headings exactly like this (be very concise):

## 1. Summary (max 100 words)
[Your concise summary of the methodology and study design.]

## 2. Study Type
[One of: "Randomized Controlled Trial (RCT)", "Cohort Study", "Case-Control Study",
"Cross-sectional Study", "Quasi-experimental", "Observational", etc.]

## 3. Randomization
["Yes" or "No"]

## 4. Groups or Comparisons
[Max 4 groups, separated by commas]

## 5. Mediators / Confounders
[List any mediators/confounders, look for phrases like: mediated by... controlled for.... If none, "Not stated". Comma-separated.]

## 6. Data Collection Method
[List data collection methods. If none, "Not stated". Comma-separated.]

## 7. Causal Inference Implications
[Explain what type of causality (correlational, conditional causal, or direct causal) is allowed with this study design e.g. randomization and controlled for confounders so hinting towards causality may be used.]

**Method Text**:
{method_text}
    """

    chat_completion = client.chat.completions.create(
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt}
        ],
        model="llama-3.3-70b-versatile"
    )
    # Message content may be None for an empty completion; parse it as an empty reply.
    response_content = chat_completion.choices[0].message.content or ""
    return _parse_method_response(response_content, doi, title)

###################################################
# 4) HIGHLIGHTING FUNCTION (wavy for <0.90)
###################################################
def highlight_causal_sentences(text, causal_sentences_dict, confidence_dict, categories):
    """Return `text` as HTML with each causal sentence wrapped in a styled, hoverable <span>.

    Parameters
    ----------
    text : str
        Raw text; it is split into sentences with `nltk.sent_tokenize`.
    causal_sentences_dict : dict
        Sentence -> causal code (0 or missing = non-causal, left unwrapped).
    confidence_dict : dict
        Sentence -> classifier confidence (defaults to 1.0 when missing).
        Below 0.90 the span uses the pale palette, a wavy decoration style and a
        "LOW CONFIDENCE" hover note.
    categories : dict
        Causal code -> human-readable category label.

    All sentence text and attribute values are HTML-escaped. The browser's
    `dataset`/`getAttribute` therefore returns the exact original sentence,
    category and hover HTML.
    """
    import html
    import nltk

    BRIGHT_COLORS = {1: '#1827FF', 2: '#2CBCFF', 3: '#FB43FF'}
    PASTEL_COLORS_HIGH = {1: '#C6D6FF', 2: '#C2F0FF', 3: '#FFC9FF'}
    PASTEL_COLORS_LOW = {1: '#EFF3FF', 2: '#EEFCFF', 3: '#FFF0FF'}

    parts = []
    for sentence in nltk.sent_tokenize(text):
        # Text content only needs <, > and & escaped.
        sentence_text = html.escape(sentence, quote=False)
        ctype = causal_sentences_dict.get(sentence, 0)
        if ctype == 0:
            parts.append(f"{sentence_text} ")
            continue

        confidence = confidence_dict.get(sentence, 1.0)
        category = categories.get(ctype, "Unknown")
        bright_color = BRIGHT_COLORS.get(ctype, '#000000')

        if confidence >= 0.90:
            bg_color = PASTEL_COLORS_HIGH.get(ctype, '#e0e0e0')
            dec_style = "solid"
            dec_thickness = "2px"
        else:
            bg_color = PASTEL_COLORS_LOW.get(ctype, '#f3f3f3')
            dec_style = "wavy"
            dec_thickness = "1.5px"

        span_style = (
            f"background-color: {bg_color}; "
            f"padding:0.2em; "
            f"border-radius:5px; "
            f"cursor:pointer; "
            f"text-decoration-line: none;"
        )

        if confidence < 0.90:
            hover_text = f"<div style='color:white;'><span style='font-weight:bold;'>⚠️ LOW CONFIDENCE</span><br><b>Confidence:</b> {confidence:.2f}<br><b>Classification:</b> {category}</div>"
        else:
            hover_text = f"<b>Classification:</b> {category}"

        # Attribute values are quote-escaped (backslashes do not escape quotes in HTML).
        hover_escaped = html.escape(hover_text, quote=True)
        sentence_escaped = html.escape(sentence, quote=True)
        classification_escaped = html.escape(category, quote=True)

        parts.append(f"""
        <span
           class="hover-sentence"
           style="{span_style}"
           data-brightcolor="{bright_color}"
           data-confidence="{confidence}"
           data-bgcolor="{bg_color}"
           data-decstyle="{dec_style}"
           data-decthickness="{dec_thickness}"
           data-sentence="{sentence_escaped}"
           data-classification="{classification_escaped}"
           data-hoverinfo="{hover_escaped}"
        >
          {sentence_text}
        </span>
        """)

    return "".join(parts)

###################################################
# 5) SINGLE STACKED BAR CHART
###################################################
def build_single_stacked_bar_chart(df):
    """Render per-section sentence counts by causal type as a horizontal stacked bar chart.

    Parameters
    ----------
    df : pandas.DataFrame
        Columns 'Section', 'Causal' (0-3 codes) and 'Count'. 'Title' rows are
        ignored, and section names are compared case-insensitively.

    Returns
    -------
    str
        An HTML snippet (a half-width flex container) embedding the plotly chart.
        Only Abstract, Results, Discussion and Conclusion are plotted; a missing
        section is shown with zero counts.
    """
    code_to_label = {0: 'Non-Causal', 1: 'Correlational', 2: 'Conditional Causal', 3: 'Direct Causal'}
    col_colors = {
        'Non-Causal': '#DEDEDE',
        'Correlational': '#1827FF',
        'Conditional Causal': '#2CBCFF',
        'Direct Causal': '#FB43FF'
    }
    # Stacking order of the bar segments, left to right.
    categories_stack = ['Non-Causal', 'Correlational', 'Conditional Causal', 'Direct Causal']
    # Plotly draws the first category at the bottom, so Abstract ends up on top.
    desired_order = ["Conclusion", "Discussion", "Results", "Abstract"]

    counts = (
        df[~df['Section'].str.lower().eq('title')]
        .assign(Section=lambda d: d['Section'].str.capitalize())
    )
    # aggfunc='sum' (not the default mean): after capitalizing, rows whose section names
    # differed only in case fall into the same cell and must be added, not averaged.
    df_pivot = (
        counts.pivot_table(index='Section', columns='Causal', values='Count',
                           aggfunc='sum', fill_value=0)
        .rename(columns=code_to_label)
        .reindex(columns=categories_stack, fill_value=0)
        .reindex(desired_order, fill_value=0)
    )
    sections = df_pivot.index.tolist()

    fig = go.Figure()
    for cat in categories_stack:
        fig.add_trace(
            go.Bar(
                name=cat,
                y=sections,
                x=df_pivot[cat].values,
                orientation='h',
                marker=dict(color=col_colors[cat]),
                hovertemplate=f"<b>{cat}</b>: %{{x}}<extra></extra>"
            )
        )

    fig.update_layout(
        barmode='stack',
        template='plotly_white',
        height=400,
        margin=dict(t=40, b=40, l=80, r=40),
        xaxis_title="Number of Sentences",
        yaxis_title="Sections",
        yaxis=dict(
            categoryorder='array',
            categoryarray=desired_order
        ),
        legend_title="Causal Type",
    )

    chart_html = fig.to_html(full_html=False, include_plotlyjs='cdn', config={'displayModeBar': False})
    container_html = f"""
    <div style='width: 48%; display: flex; flex-direction: column; justify-content: space-between;'>
        <h3 style='margin-top: 0; margin-bottom: 10px;'>Causal Classification Overview</h3>
        <div style='width: 100%; flex-grow: 1;'>
            {chart_html}
        </div>
    </div>
    """
    return container_html

###################################################
# 6) REGISTER CALLBACKS
###################################################

import logging

# Model used by the interactive explanation callbacks.
_CALLBACK_MODEL = "llama-3.3-70b-versatile"


def _chat_reply(prompt):
    """Send one user prompt to the Groq chat model and return the stripped reply text.

    Client errors propagate so each callback can report them in its own way.
    """
    chat_completion = client.chat.completions.create(
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": prompt}
        ],
        model=_CALLBACK_MODEL
    )
    return chat_completion.choices[0].message.content.strip()


def ask_explanation(sentence, classification):
    """Ask the LLM to critically justify `classification` for `sentence` (max ~100 words).

    Returns the model's explanation. On empty input it returns "Invalid input.";
    if the request fails it returns "Failed to generate explanation." and logs the error.
    """
    if not sentence or not classification:
        return "Invalid input."

    prompt = f"""
The following sentence has been classified as "{classification}":
"{sentence}"

Please evaluate this classification critically and provide a concise explanation (max 100 words) for why it fits this category.

Definitions:
1. **Correlational**: Describes an association between variables without implying causation.
2. **Conditional Causal**: Suggests one variable may influence another, but with uncertainty.
3. **Direct Causal**: Clearly states that one variable directly causes changes in another without doubt.

### Evaluation Criteria:
- Does the classification match the definition? Justify your reasoning.
    """

    try:
        return _chat_reply(prompt)
    except Exception:
        # Keep the UI responsive, but record why the call failed instead of discarding it.
        logging.getLogger(__name__).exception("ask_explanation: Groq request failed")
        return "Failed to generate explanation."


def ask_more(question, sentence, classification):
    """Ask the LLM to assess a user's follow-up `question` about a classified sentence.

    Returns the model's answer (max ~100 words). On empty input it returns
    "Invalid input."; if the request fails it returns "Failed to fetch answer."
    and logs the error.
    """
    if not question or not sentence or not classification:
        return "Invalid input."

    prompt = f"""
You previously explained why the following sentence was classified as "{classification}":
"{sentence}"

The user has a follow-up question or disagreement:
"{question}"

Please critically assess their point (max 100 words).
    """

    try:
        return _chat_reply(prompt)
    except Exception:
        logging.getLogger(__name__).exception("ask_more: Groq request failed")
        return "Failed to fetch answer."

# Register the callbacks so the rendered HTML/JS can invoke them from the browser.
output.register_callback('ask_explanation', ask_explanation)
output.register_callback('ask_more', ask_more)

def build_ui(final_df, all_text, doi, warnings=None, guideline_mode='strict'):
    """
    Renders the complete UI. The 'guideline_mode' parameter ensures that
    when we call build_ui from 'rebuild_ui_with_mode', the warnings and
    text displayed reflect the correct mode (strict or moderate).
    """
    global study_info_dict  # Declare as global to be accessible in callbacks

    # Filter just for this paper
    paper_df = final_df.loc[final_df['DOI'] == doi].copy()

    # Prepare data for highlighting
    causal_sentences_dict = {row['Sentence']: row['Causal'] for _, row in paper_df.iterrows()}
    confidence_dict = {row['Sentence']: row['Confidence'] for _, row in paper_df.iterrows()}

    # Build bar chart
    counts_by_section = (
        paper_df
        .groupby(['Section', 'Causal'])
        .size()
        .reset_index(name='Count')
    )
    chart = build_single_stacked_bar_chart(counts_by_section)

    # Analyze method text
    the_method_row = all_text[
        (all_text['DOI'] == doi) & (all_text['Section'].str.lower() == "method")
    ].copy()

    if the_method_row.empty:
        # If Method section is missing, set default empty values
        study_info_dict_local = {
            "DOI": doi,
            "Title": all_text.loc[all_text['DOI'] == doi, 'Title'].iloc[0] if not all_text.loc[all_text['DOI'] == doi, 'Title'].empty else "No Title",
            "Summary": "No Method section provided.",
            "Study Type": "Not stated",
            "Randomization": "Not stated",
            "Groups or Comparisons": [],
            "Mediators / Confounders": [],
            "Data Collection Method": [],
            "Causal Inference Implications": ""
        }
    else:
        the_method_row = the_method_row.iloc[0]
        title = the_method_row['Title']
        method_text = the_method_row['Text']
        study_info_dict_local = analyze_method_text(method_text, doi, title)

    # Assign to global variable
    study_info_dict = study_info_dict_local

    # Convert randomization to a colored badge
    if study_info_dict['Randomization'].lower() == "yes":
        random_badge_html = (
            "<span style='background-color:#1827FF; color:white; "
            "padding:2px 6px; border-radius:4px; font-weight:bold;'>"
            "Randomized</span>"
        )
    elif study_info_dict['Randomization'].lower() == "no":
        random_badge_html = (
            "<span style='background-color:#1827FF; color:white; "
            "padding:2px 6px; border-radius:4px; font-weight:bold;'>"
            "Not Randomized</span>"
        )
    else:
        random_badge_html = (
            "<span style='background-color:#ccc; color:black; "
            "padding:2px 6px; border-radius:4px; font-weight:bold;'>"
            "Not Stated</span>"
        )

    # Formatting helpers
    def make_br_list(items):
        return "<br>".join(f"\"{i.strip()}\"" for i in items) if items else "Not stated"

    groups_list = study_info_dict['Groups or Comparisons']
    mediators_list = study_info_dict['Mediators / Confounders']
    data_list = study_info_dict['Data Collection Method']

    groups_html = make_br_list(groups_list)
    mediators_html = make_br_list(mediators_list)
    data_html = make_br_list(data_list)

    # Build method summary table
    table_html = f"""
    <table style="border-collapse:collapse; margin-bottom:15px; width:100%;">
      <tr>
        <th style="border:1px solid #ccc; padding:6px;">Groups / Comparisons</th>
        <th style="border:1px solid #ccc; padding:6px;">Mediators / Confounders</th>
        <th style="border:1px solid #ccc; padding:6px;">Data Collection Method</th>
      </tr>
      <tr>
        <td style="border:1px solid #ccc; vertical-align:top; padding:6px;">
          {groups_html}
        </td>
        <td style="border:1px solid #ccc; vertical-align:top; padding:6px;">
          {mediators_html}
        </td>
        <td style="border:1px solid #ccc; vertical-align:top; padding:6px;">
          {data_html}
        </td>
      </tr>
    </table>
    """

    # # Bullet points for "Causal Inference Implications"
    # implications_raw = study_info_dict['Causal Inference Implications'] or ""
    # if implications_raw and implications_raw != "Not stated":
    #     lines = [ln.strip() for ln in re.split(r'\.\s*', implications_raw) if ln.strip()]
    #     bullet_html = "<ul>" + "".join(f"<li>{line}.</li>" for line in lines) + "</ul>"
    # else:
    #     bullet_html = "<p>Not stated.</p>"

    # Extracting "Causal Inference Implications" from study_info_dict
    implications_raw = study_info_dict['Causal Inference Implications'] or ""

    # Check if there are implications and they are not explicitly stated
    if implications_raw and implications_raw != "Not stated":
        # Split into lines and clean each line
        lines = [ln.strip() for ln in re.split(r'\.\s*', implications_raw) if ln.strip()]
        # Generate HTML without bullet points
        implications_html = "<p>" + "</p><p>".join(lines) + "</p>"
    else:
        implications_html = "<p>Not stated.</p>"


    summary_html = f"""
    <h3 style="margin-top:0;">Method Summary</h3>
    <p style="margin-top:0;">{study_info_dict['Summary']}</p>

    <p><b>Study Type:</b> {study_info_dict['Study Type']}</p>
    <p><b>Randomization:</b> {random_badge_html}</p>

    {table_html}

    <p><b>Causal Inference Implications:</b></p>
    {implications_html}
    """

    # Build the paper text (title, abstract, method, results, discussion, conclusion)
    all_sections = all_text.loc[all_text['DOI'] == doi].copy()
    section_order = ["title", "abstract", "method", "results", "discussion", "conclusion"]
    all_sections['Section'] = all_sections['Section'].str.lower().str.strip()
    all_sections['SectionOrder'] = all_sections['Section'].apply(
        lambda x: section_order.index(x) if x in section_order else len(section_order)
    )
    all_sections = all_sections.sort_values('SectionOrder').drop(columns=['SectionOrder'])

    highlighted_text_all = ""
    for _, row in all_sections.iterrows():
        sec_lower = row['Section']
        highlight = highlight_causal_sentences(
            row['Text'], causal_sentences_dict, confidence_dict, causal_categories
        )
        if sec_lower == 'title':
            highlighted_text_all += f"<h1 data-sec='title' style='font-size:28px; line-height:1.4; font-weight:bold; margin-bottom:4px;'>{highlight}</h1>"
            continue
        if sec_lower == 'method':
            # Add "Method Summary" heading only once
            highlighted_text_all += f"<h3 data-sec='{sec_lower}' style='margin-bottom:4px;'>{row['Section'].capitalize()}</h3>"
            highlighted_text_all += f"<p data-sec='{sec_lower}'>{highlight}</p>"
            # Insert the Method Summary below the Method section
            # highlighted_text_all += f"{summary_html}"
            continue
        highlighted_text_all += f"<h3 data-sec='{sec_lower}' style='margin-bottom:4px;'>{row['Section'].capitalize()}</h3>"
        highlighted_text_all += f"<p data-sec='{sec_lower}'>{highlight}</p>"

    # Decision Tree Warnings
    sentences_df = paper_df.loc[:, ['DOI', 'Section', 'Sentence', 'Causal', 'Confidence']].copy()
    study_type = study_info_dict['Study Type']

    # If warnings are not passed in, run analysis (keeps a default if user calls build_ui directly)
    if warnings is None:
        if guideline_mode is None:
            # Default to 'strict' if no mode is provided
            guideline_mode = 'strict'
        warnings = decision_tree_analysis(sentences_df, study_type, guideline_mode)

    # Generate warnings/feedback HTML
    if warnings:
        # Sort warnings to ensure specific warnings are on top
        warning_order = [12, 6, 1, 2, 4, 5, 9, 10, 11]  # Define desired order
        sorted_warnings = sorted(
            warnings,
            key=lambda x: warning_order.index(x['key']) if x['key'] in warning_order else len(warning_order)
        )

        warnings_feedback_html = "<div style='margin-bottom:20px;'><h3>Warnings and Feedback</h3><div style='display: flex; flex-wrap: wrap; gap: 20px;'>"
        for warning in sorted_warnings:
            # Check if it's Warning 1 with mode-specific descriptions
            if warning['key'] == 1:
                # It's Warning 1
                description_strict = warning.get('description_strict', '')
                description_moderate = warning.get('description_moderate', '')
                # Choose description based on mode
                if guideline_mode == 'strict':
                    selected_description = description_strict
                elif guideline_mode == 'moderate':
                    selected_description = description_moderate
                else:
                    selected_description = description_strict  # Fallback

                trigger_sentences = warning.get('trigger_sentences', [])
                # Define examples for Warning 1 based on mode
                if guideline_mode == 'strict':
                    example_html = '''
                    <p><b>Examples:</b></p>
                    <ul>
                        <li><b>Less accurate:</b> “Testosterone causes higher narcissism.”</li>
                        <li><b>More accurate:</b> “Testosterone levels were associated with narcissism, but this study does not confirm causality.”</li>
                    </ul>
                    <p><b>Causal claims are not allowed in this mode.</b></p>
                    '''
                elif guideline_mode == 'moderate':
                    example_html = '''
                    <p><b>Examples:</b></p>
                    <ul>
                        <li><b>Appropriate:</b> “Testosterone might be associated with higher narcissism levels after controlling for age and baseline health.”</li>
                    </ul>
                    '''
                else:
                    example_html = ''  # No examples for undefined mode

                # Extract trigger sentences count
                trigger_count = len(trigger_sentences)
                trigger_sentences_html = "".join(f'<li><i>{item["section"]}</i>: "{item["sentence"]}"</li>' for item in trigger_sentences)

                # Assign a unique ID to Warning 1 card for dynamic updates
                warnings_feedback_html += f"""
                <div class="warning-card" id="warning1-card" data-description-strict='{description_strict.replace("'", "&#39;")}' data-description-moderate='{description_moderate.replace("'", "&#39;")}'>
                  <h4>Warning</h4>
                  <p class="warning-description">{selected_description}</p>
                  <p><b>Trigger Sentences Count:</b> {trigger_count}</p>
                  <button onclick="toggleWarningDetails(this)">View Details</button>
                  <div class="warning-details" style="display: none;">
                    {example_html}
                    <p><b>Trigger Sentences:</b></p>
                    <ul>
                      {trigger_sentences_html}
                    </ul>
                  </div>
                </div>
                """
                continue  # Move to the next warning after handling Warning 1

            # For Warning 12, handle specially to include per-section details
            if warning['key'] == 12:
                description = warning['description']
                reason = warning['reason']
                distribution = warning.get('distribution', '')
                trigger_sentences = warning.get('trigger_sentences', [])

                # Format distribution
                distribution_html = f"""
                <p><b>Distribution:</b></p>
                {distribution}
                """

                # Format trigger sentences
                trigger_sentences_html = "".join(f'<li><i>{item["section"]}</i>: "{item["sentence"]}"</li>' for item in trigger_sentences)

                warnings_feedback_html += f"""
                <div class="warning-card">
                  <h4>Warning</h4>
                  <p>{description}</p>
                  <p>{reason}</p>
                  {distribution_html}
                  <button onclick="toggleWarningDetails(this)">View Details</button>
                  <div class="warning-details" style="display: none;">
                    <p><b>Trigger Sentences:</b></p>
                    <ul>
                      {trigger_sentences_html}
                    </ul>
                  </div>
                </div>
                """
                continue  # Move to the next warning after handling Warning 12

            # For Warning 6, include examples in details
            if warning['key'] == 6:
                description = warning['description']
                reason = warning['reason']
                trigger_sentences = warning.get('trigger_sentences', [])

                # Define examples for Warning 6
                examples_html = '''
                <p><b>Examples:</b></p>
                <ul>
                    <li><b>Without addressing confounders:</b> "Testosterone might cause narcissistic behavior."</li>
                    <li><b>With confounders addressed:</b> "Testosterone might be related to narcissistic behavior, though confounding variables may still exist."</li>
                </ul>
                '''

                # Format trigger sentences
                trigger_sentences_html = "".join(f'<li><i>{item["section"]}</i>: "{item["sentence"]}"</li>' for item in trigger_sentences)

                warnings_feedback_html += f"""
                <div class="warning-card">
                  <h4>Warning</h4>
                  <p>{description}</p>
                  <p>{reason}</p>
                  <button onclick="toggleWarningDetails(this)">View Details</button>
                  <div class="warning-details" style="display: none;">
                    {examples_html}
                    <p><b>Trigger Sentences:</b></p>
                    <ul>
                      {trigger_sentences_html}
                    </ul>
                  </div>
                </div>
                """
                continue  # Move to the next warning after handling Warning 6

            # For other warnings
            key = warning['key']
            description = warning.get('description', 'No description available.')
            reason = warning.get('reason', 'No reason provided.')
            trigger_sentences = warning.get('trigger_sentences', [])
            example_html = ""
            # Handle specific warnings with additional tips
            if key == 9:
                # Warning 9
                tip_html = """
                <p><b>Tip:</b> Distinguish between statistical and practical significance by providing effect sizes, confidence intervals, or p-values.</p>
                """
                warnings_feedback_html += f"""
                <div class="warning-card">
                  <h4>General Recommendation</h4>
                  <p>{description}</p>
                  <p>{reason}</p>
                  {tip_html}
                </div>
                """
                continue
            if key == 10:
                # Warning 10
                tip_html = """
                <p><b>Tip:</b> Include limitations in your methods or discussion sections to provide transparency about your study's scope and potential weaknesses.</p>
                """
                warnings_feedback_html += f"""
                <div class="warning-card">
                  <h4>General Recommendation</h4>
                  <p>{description}</p>
                  <p>{reason}</p>
                  {tip_html}
                </div>
                """
                continue
            if key == 4:
                # Warning 4
                trigger_count = len(trigger_sentences)
                trigger_sentences_html = "".join(f'<li><i>{item["section"]}</i>: "{item["sentence"]}"</li>' for item in trigger_sentences)

                warnings_feedback_html += f"""
                <div class="warning-card">
                  <h4>Warning</h4>
                  <p>{description}</p>
                  <p><b>Trigger Sentences Count:</b> {trigger_count}</p>
                  <button onclick="toggleWarningDetails(this)">View Details</button>
                  <div class="warning-details" style="display: none;">
                    <p><b>Trigger Sentences:</b></p>
                    <ul>
                      {trigger_sentences_html}
                    </ul>
                  </div>
                </div>
                """
                continue
            if key == 2:
                # Warning 2
                trigger_count = len(trigger_sentences)
                trigger_sentences_html = "".join(f'<li><i>{item["section"]}</i>: "{item["sentence"]}"</li>' for item in trigger_sentences)

                warnings_feedback_html += f"""
                <div class="warning-card">
                  <h4>Warning</h4>
                  <p>{description}</p>
                  <p><b>Trigger Sentences Count:</b> {trigger_count}</p>
                  <button onclick="toggleWarningDetails(this)">View Details</button>
                  <div class="warning-details" style="display: none;">
                    <p><b>Trigger Sentences:</b></p>
                    <ul>
                      {trigger_sentences_html}
                    </ul>
                  </div>
                </div>
                """
                continue
            if key == 5 or key == 11:
                # Warning 5 and 11
                trigger_count = len(trigger_sentences)
                trigger_sentences_html = "".join(f'<li><i>{item["section"]}</i>: "{item["sentence"]}"</li>' for item in trigger_sentences)

                warnings_feedback_html += f"""
                <div class="warning-card">
                  <h4>Warning</h4>
                  <p>{description}</p>
                  <p><b>Trigger Sentences Count:</b> {trigger_count}</p>
                  <button onclick="toggleWarningDetails(this)">View Details</button>
                  <div class="warning-details" style="display: none;">
                    <p><b>Trigger Sentences:</b></p>
                    <ul>
                      {trigger_sentences_html}
                    </ul>
                  </div>
                </div>
                """
                continue

        warnings_feedback_html += "</div></div>"
    else:
        warnings_feedback_html = """
        <div style='margin-bottom:20px;'>
          <h3>Warnings and Feedback</h3>
          <p>No warnings detected based on the current analysis.</p>
        </div>
        """

    mode_info_html = """
    <div style="margin-bottom:10px;">
      <p><b>Strict Mode:</b> Only RCTs can use causal language. Non-RCTs (observational studies) must use "association" instead.</p>
      <p><b>Moderate Mode:</b> Allows well-controlled observational studies to make causal claims if they account for confounders.</p>
    </div>
    """

    # Properly escape curly braces in CSS and JavaScript by doubling them
    final_html = f"""
    <style>
      #topBar {{
        background: none;
        padding: 10px 0;
        margin-bottom: 10px;
      }}
      #titleDivider {{
        border: none;
        border-bottom: 1px solid #ccc;
        margin: 10px 0;
      }}
      #outerContainer {{
        max-height: 85vh;
        overflow-y: auto;
        padding-right: 10px;
        font-family: Arial, sans-serif;
      }}
      #container {{
        display: flex;
        flex-direction: row;
        width: 100%;
      }}
      #mainContent {{
        flex: 1;
        margin-right: 20px;
        padding: 0 20px 100px 20px;
      }}
      #mainContent p {{
        line-height: 1.7;
        text-align: justify;
        margin-top: 6px;
        margin-bottom: 14px;
      }}
      #sidebar {{
        width: 300px;
        min-width: 300px;
        border: none !important;
        background: #fff;
        padding: 10px;
        position: relative;
        box-sizing: border-box;
      }}
      #sidebarOverlay {{
        position: absolute;
        width: 280px;
        background: #f9f9f9;
        padding: 10px;
        border-radius: 5px;
        display: none;
        box-shadow: 0 2px 5px rgba(0,0,0,0.2);
        z-index: 1000;
      }}
      #viewToggles {{
        margin: 10px 0 20px;
        display: flex;
        align-items: center;
        justify-content: space-between;
      }}
      #sectionToggles button {{
        margin-right: 5px;
        padding: 6px 10px;
        cursor: pointer;
        border-radius: 20px;
        border: 1px solid #ccc;
        background: #ccc;
        color: #333;
      }}
      #sectionToggles button.active {{
        background-color: #1827FF;
        color: white;
      }}
      #modeToggle {{
        display: flex;
        align-items: center;
        gap: 10px;
        flex-wrap: wrap;
      }}
      #modeToggle button {{
        padding: 6px 10px;
        cursor: pointer;
        border-radius: 20px;
        border: 1px solid #ccc;
        background: #ccc;
        color: #333;
      }}
      #modeToggle button.active {{
        background-color: #1827FF;
        color: white;
      }}
      #confidenceSliderContainer {{
        display: flex;
        flex-direction: column;
        align-items: flex-end;
      }}
      #confidenceSlider {{
        width: 350px;
        -webkit-appearance: none;
        appearance: none;
        height: 6px;
        border-radius: 3px;
        background: #ccc;
        cursor: pointer;
        background-image: linear-gradient(to right, #1827FF 0%, #1827FF 0%, #ccc 0%, #ccc 100%);
      }}
      #confidenceSlider::-webkit-slider-thumb {{
        -webkit-appearance: none;
        height: 20px;
        width: 20px;
        border-radius: 50%;
        background: #1827FF;
        cursor: pointer;
        border: none;
      }}
      #confidenceSlider::-moz-range-thumb {{
        height: 20px;
        width: 20px;
        border-radius: 50%;
        background: #1827FF;
        cursor: pointer;
        border: none;
      }}
      #confidenceSliderValueNumber {{
        font-size: 1.0em;
        font-weight: bold;
      }}
      #confidenceSliderValueLabel {{
        font-size: 1.0em;
        margin-left:6px;
        /* Removed position:absolute; */
      }}
      #sidebarOverlay .explain-btn,
      #sidebarOverlay .ask-more-btn,
      #sidebarOverlay .submit-ask-btn {{
        margin-top: 10px;
        padding: 5px 10px;
        font-size: 0.9em;
        cursor: pointer;
        border: none;
        background-color: #1827FF;
        color: white;
        border-radius: 3px;
      }}
      #sidebarOverlay .explain-btn:hover,
      #sidebarOverlay .ask-more-btn:hover,
      #sidebarOverlay .submit-ask-btn:hover {{
        background-color: #0f1fb5;
      }}
      #sidebarOverlay .ask-more-input {{
        width: 100%;
        padding: 5px;
        margin-top: 10px;
        box-sizing: border-box;
        border: 1px solid #ccc;
        border-radius: 3px;
      }}
      #sectionToggles {{
        margin-left: 20px;
        margin-top: 45px;
      }}
      #confidenceSliderContainer {{
        margin-right: 345px;
      }}
      #viewToggles {{
          border-bottom: 1px solid #ccc;
          padding-bottom: 10px;
          margin-bottom: 20px;
        }}
      .warning-card {{
        flex: 1 1 30%;
        max-width: 30%;
        border:1px solid #ddd;
        padding:15px;
        border-radius:5px;
        box-shadow: 2px 2px 5px rgba(0,0,0,0.1);
        position: relative;
        background-color: #f0f0f0; /* Updated to grey background */
        border-left: 5px solid #1827FF; /* Added blue stripe on the left */
      }}
      .warning-card h4 {{
        margin-top:0;
        font-size: 1.5em;
        color: black;
      }}
      .warning-card button {{
        background-color:#1827FF;
        color:white;
        border:none;
        padding:5px 10px;
        border-radius:3px;
        cursor:pointer;
      }}
      .warning-card button:hover {{
        background-color:#0f1fb5;
      }}
      .warning-details {{
        display: none;
        margin-top:10px;
      }}
    </style>

    <div id="outerContainer">
      <div id="topBar">
        <h1 style="font-size:28px; line-height:1.4; font-weight:bold; margin-bottom:8px;">{study_info_dict['Title']}</h1>
        <div style="font-size:0.9em; color:#666;">DOI: {study_info_dict['DOI']}</div>
        <hr id="titleDivider"/>
        <p style="margin-top:8px; font-size:0.95em; color:#333;">
          Our model classifies sentences into <b>Non-Causal</b>, <b>Correlational</b>,
          <b>Conditional Causal</b>, or <b>Direct Causal</b>. 1. Correlational: Describes an association between variables without implying causation. 2. Conditional Causal: Suggests one variable may influence another, but with uncertainty. 3. Direct Causal: Clearly states that one variable directly causes changes in another without doubt.
        </p>
      </div>

      <div style="display:flex; flex-wrap:wrap; gap:30px; margin-bottom:20px;">
        {chart}
        <div style="width:48%; display:flex; flex-direction:column; justify-content:flex-start;">
          {summary_html}
        </div>
      </div>

      <!-- Mode Toggle and Descriptions Above Warnings -->
      <div style="margin-bottom:20px;">
        {mode_info_html}
        <div id="modeToggle" style="display:flex; align-items:center; gap:10px;">
          <button class="modeToggleBtn {'active' if guideline_mode=='strict' else ''}" data-mode="strict">Strict Mode</button>
          <button class="modeToggleBtn {'active' if guideline_mode=='moderate' else ''}" data-mode="moderate">Moderate Mode</button>
        </div>
      </div>

      {warnings_feedback_html}

      <div id="viewToggles">
        <div id="sectionToggles" style="display:flex; gap:10px;">
          <button class="viewToggleBtn active" data-section="abstract">Abstract</button>
          <button class="viewToggleBtn active" data-section="method">Method</button>
          <button class="viewToggleBtn active" data-section="results">Results</button>
          <button class="viewToggleBtn active" data-section="discussion">Discussion</button>
          <button class="viewToggleBtn active" data-section="conclusion">Conclusion</button>
        </div>
        <div id="confidenceSliderContainer">
          <h3 style="margin:0; padding:0; font-size:1.0em; color:#333;">
            Confidence Score Slider
          </h3>
          <p style="margin:0 0 8px 0; font-size:0.85em; color:#666;">
            These scores show how strict or lenient the AI is, not how accurate it is.
          </p>
          <div style="display:flex; align-items:center; gap:10px;">
            <input
              type="range"
              id="confidenceSlider"
              min="0"
              max="100"
              step="1"
              value="0"
            />
            <div style="color:#333; width: 100px; text-align:right;">
              <span id="confidenceSliderValueNumber">0.00</span>
              <span id="confidenceSliderValueLabel" style="margin-left:6px;">Show All</span>
            </div>
          </div>
        </div>
      </div>

      <div id="container">
        <div id="mainContent">
          {highlighted_text_all}
        </div>
        <div id="sidebar">
          <div id="sidebarOverlay" data-sentence="" data-classification=""></div>
        </div>
      </div>
    </div>

    <script src="https://cdn.plot.ly/plotly-latest.min.js"></script>
    <script>
      // Function to toggle warning card details
      function toggleWarningDetails(button) {{
        const details = button.nextElementSibling;
        if (details.style.display === "none" || details.style.display === "") {{
          details.style.display = "block";
          button.textContent = "Hide Details";
        }} else {{
          details.style.display = "none";
          button.textContent = "View Details";
        }}
      }}

      const overlay = document.getElementById('sidebarOverlay');
      const sidebar = document.getElementById('sidebar');
      let lastHovered = null;
      let currentThreshold = 0.0;
      let lockedSentence = null; // Track the locked sentence

      document.addEventListener('mouseover', function(e) {{
        if (e.target.classList.contains('hover-sentence')) {{
          if (lockedSentence && lockedSentence !== e.target) {{
            // Do not affect locked sentence
            return;
          }}
          const conf = parseFloat(e.target.getAttribute('data-confidence'));
          const brightColor = e.target.getAttribute('data-brightcolor');
          const decStyle = e.target.getAttribute('data-decstyle');
          const decThickness = e.target.getAttribute('data-decthickness');

          if (conf >= currentThreshold) {{
            e.target.style.textDecorationLine = 'underline';
            e.target.style.textDecorationStyle = decStyle;
            e.target.style.textDecorationColor = brightColor;
            e.target.style.textDecorationThickness = decThickness;
          }} else {{
            e.target.style.textDecorationLine = 'none';
          }}

          if (lastHovered && lastHovered !== e.target && lastHovered !== lockedSentence) {{
            lastHovered.style.textDecorationLine = 'none';
          }}
          lastHovered = e.target;

          const classification = e.target.getAttribute('data-classification');
          const sentence = e.target.getAttribute('data-sentence');
          if (conf < currentThreshold) {{
            overlay.style.backgroundColor = '#444';
            overlay.style.color = 'white';
            overlay.innerHTML = `
              <div>
                <span style="font-weight:bold;">Below Current Filter</span><br>
                <b>Classification:</b> ${{classification}}
              </div>
            `;
          }} else {{
            if (conf < 0.90) {{
              overlay.style.backgroundColor = 'black';
              overlay.style.color = 'white';
              overlay.innerHTML = `
                <div>
                  <span style="font-weight:bold;">⚠️ LOW CONFIDENCE</span><br>
                  <b>Confidence:</b> ${{conf.toFixed(2)}}<br>
                  <b>Classification:</b> ${{classification}}<br>
                  <button class="explain-btn" onclick="fetchExplanation('${{sentence}}','${{classification}}')">
                    Explain
                  </button>
                </div>
              `;
            }} else {{
              overlay.style.backgroundColor = '#f9f9f9';
              overlay.style.color = 'black';
              overlay.innerHTML = `
                <div>
                  <b>Classification:</b> ${{classification}}<br>
                  <button class="explain-btn" onclick="fetchExplanation('${{sentence}}','${{classification}}')">
                    Explain
                  </button>
                </div>
              `;
            }}
          }}

          overlay.setAttribute('data-sentence', sentence);
          overlay.setAttribute('data-classification', classification);

          const sidebarRect = sidebar.getBoundingClientRect();
          const targetRect = e.target.getBoundingClientRect();
          const relativeTop = targetRect.top - sidebarRect.top + window.scrollY;
          overlay.style.top = relativeTop + 'px';
          overlay.style.left = '10px';
          overlay.style.display = 'block';
        }}
      }});

      document.addEventListener('mouseout', function(e) {{
        if (e.target.classList.contains('hover-sentence')) {{
          if (lockedSentence !== e.target) {{
            e.target.style.textDecorationLine = 'none';
          }}
        }}
      }});

      // Click event to toggle lock mode
      document.addEventListener('click', function(e) {{
        if (e.target.classList.contains('hover-sentence')) {{
          if (lockedSentence === e.target) {{
            // Unlock the sentence
            lockedSentence.classList.remove('locked');
            e.target.style.textDecorationLine = 'none';
            overlay.style.display = 'none';
            lockedSentence = null;
          }} else {{
            // Unlock previous if any
            if (lockedSentence) {{
              lockedSentence.classList.remove('locked');
              lockedSentence.style.textDecorationLine = 'none';
            }}
            // Lock the new sentence
            lockedSentence = e.target;
            lockedSentence.classList.add('locked');
            const conf = parseFloat(lockedSentence.getAttribute('data-confidence'));
            const brightColor = lockedSentence.getAttribute('data-brightcolor');
            const decStyle = lockedSentence.getAttribute('data-decstyle');
            const decThickness = lockedSentence.getAttribute('data-decthickness');
            lockedSentence.style.textDecorationLine = 'underline';
            lockedSentence.style.textDecorationStyle = decStyle;
            lockedSentence.style.textDecorationColor = brightColor;
            lockedSentence.style.textDecorationThickness = decThickness;

            const classification = lockedSentence.getAttribute('data-classification');
            const sentence = lockedSentence.getAttribute('data-sentence');

            if (conf < currentThreshold) {{
              overlay.style.backgroundColor = '#444';
              overlay.style.color = 'white';
              overlay.innerHTML = `
                <div>
                  <span style="font-weight:bold;">Below Current Filter</span><br>
                  <b>Classification:</b> ${{classification}}
                </div>
              `;
            }} else {{
              if (conf < 0.90) {{
                overlay.style.backgroundColor = 'black';
                overlay.style.color = 'white';
                overlay.innerHTML = `
                  <div>
                    <span style="font-weight:bold;">⚠️ LOW CONFIDENCE</span><br>
                    <b>Confidence:</b> ${{conf.toFixed(2)}}<br>
                    <b>Classification:</b> ${{classification}}<br>
                    <button class="explain-btn" onclick="fetchExplanation('${{sentence}}','${{classification}}')">
                      Explain
                    </button>
                  </div>
                `;
              }} else {{
                overlay.style.backgroundColor = '#f9f9f9';
                overlay.style.color = 'black';
                overlay.innerHTML = `
                  <div>
                    <b>Classification:</b> ${{classification}}<br>
                    <button class="explain-btn" onclick="fetchExplanation('${{sentence}}','${{classification}}')">
                      Explain
                    </button>
                  </div>
                `;
              }}
            }}

            overlay.setAttribute('data-sentence', sentence);
            overlay.setAttribute('data-classification', classification);

            const sidebarRect = sidebar.getBoundingClientRect();
            const targetRect = lockedSentence.getBoundingClientRect();
            const relativeTop = targetRect.top - sidebarRect.top + window.scrollY;
            overlay.style.top = relativeTop + 'px';
            overlay.style.left = '10px';
            overlay.style.display = 'block';
          }}
        }}
      }});

      async function fetchExplanation(sentence, classification) {{
        overlay.innerHTML = `
          <div>
            <b>Classification:</b> ${{classification}}<br>
            <i>Loading explanation...</i>
          </div>
        `;
        try {{
          const response = await google.colab.kernel.invokeFunction(
            'ask_explanation', [sentence, classification], {{}}
          );
          const explanation = response.data['text/plain'];
          overlay.innerHTML = `
            <div>
              <b>Classification:</b> ${{classification}}<br>
              <b>Explanation:</b> ${{explanation}}<br>
              <button class="ask-more-btn" onclick="showAskMore()">Ask More</button>
              <div id="askMoreSection" style="display:none; margin-top:10px;">
                <input type="text" id="askMoreInput" class="ask-more-input" placeholder="Type your question here...">
                <button class="submit-ask-btn" onclick="submitAskMore()">Submit</button>
              </div>
              <div id="askMoreResponse" style="margin-top:10px;"></div>
            </div>
          `;
        }} catch (error) {{
          overlay.innerHTML = `
            <div>
              <b>Classification:</b> ${{classification}}<br>
              <span style="color:red;">Failed to fetch explanation.</span>
            </div>
          `;
        }}
      }}

      function showAskMore() {{
        const askMoreSection = document.getElementById('askMoreSection');
        askMoreSection.style.display = 'block';
      }}

      async function submitAskMore() {{
        const questionInput = document.getElementById('askMoreInput');
        const question = questionInput.value.trim();
        const askMoreResponse = document.getElementById('askMoreResponse');
        const classification = overlay.getAttribute('data-classification');
        const sentence = overlay.getAttribute('data-sentence');

        if (!question) {{
          askMoreResponse.innerHTML = '<span style="color:red;">Please enter a question.</span>';
          return;
        }}

        askMoreResponse.innerHTML = '<i>Loading answer...</i>';
        try {{
          const response = await google.colab.kernel.invokeFunction(
            'ask_more', [question, sentence, classification], {{}}
          );
          const answer = response.data['text/plain'];
          askMoreResponse.innerHTML = `<b>Answer:</b> ${{answer}}`;
        }} catch (error) {{
          askMoreResponse.innerHTML = '<span style="color:red;">Failed to fetch answer.</span>';
        }}
      }}

      // Confidence slider
      const confidenceSlider = document.getElementById('confidenceSlider');
      const sliderValueNumber = document.getElementById('confidenceSliderValueNumber');
      const sliderValueLabel = document.getElementById('confidenceSliderValueLabel');

      confidenceSlider.addEventListener('input', function() {{
        const val = parseInt(this.value);
        currentThreshold = val / 100.0;

        sliderValueNumber.textContent = currentThreshold.toFixed(2);

        let label = "Show All";
        if (currentThreshold >= 0.99) {{
          label = "Very Strict";
        }} else if (currentThreshold >= 0.95) {{
          label = "Strict";
        }} else if (currentThreshold >= 0.90) {{
          label = "Lenient";
        }}
        sliderValueLabel.textContent = label;

        const fillPercent = val;
        this.style.background = `linear-gradient(to right, #1827FF 0%, #1827FF ${{fillPercent}}%, #ccc ${{fillPercent}}%, #ccc 100%)`;

        const sentences = document.querySelectorAll('.hover-sentence');
        sentences.forEach(span => {{
          const conf = parseFloat(span.getAttribute('data-confidence'));
          const originalBG = span.getAttribute('data-bgcolor');
          if (conf >= currentThreshold) {{
            span.style.backgroundColor = originalBG;
          }} else {{
            span.style.backgroundColor = 'transparent';
          }}
          span.style.textDecorationLine = 'none';
        }});

        // Hide overlay if lastHovered is below threshold
        if (overlay.style.display === 'block' && lastHovered) {{
          const hoveredConf = parseFloat(lastHovered.getAttribute('data-confidence'));
          if (hoveredConf < currentThreshold) {{
            overlay.style.display = 'none';
          }}
        }}
      }});

      // Section toggles
      const toggleBtns = document.querySelectorAll('.viewToggleBtn');
      toggleBtns.forEach(btn => {{
        btn.addEventListener('click', function() {{
          btn.classList.toggle('active');
          const sec = btn.getAttribute('data-section');
          const isActive = btn.classList.contains('active');
          const headings = document.querySelectorAll(`#mainContent h3[data-sec='${{sec}}']`);
          const paragraphs = document.querySelectorAll(`#mainContent p[data-sec='${{sec}}']`);

          headings.forEach(el => {{
            el.style.display = isActive ? 'block' : 'none';
          }});
          paragraphs.forEach(el => {{
            el.style.display = isActive ? 'block' : 'none';
          }});
        }});
      }});

      // Mode toggles
      const modeToggleBtns = document.querySelectorAll('.modeToggleBtn');
      modeToggleBtns.forEach(btn => {{
        btn.addEventListener('click', function() {{
          modeToggleBtns.forEach(b => b.classList.remove('active'));
          btn.classList.add('active');
          const selectedMode = btn.getAttribute('data-mode');

          // Update Warning 1 Description Based on Mode
          const warning1Card = document.getElementById('warning1-card');
          if (warning1Card) {{
            let selected_description = "";
            if (selectedMode === 'strict') {{
              selected_description = warning1Card.getAttribute('data-description-strict');
            }} else if (selectedMode === 'moderate') {{
              selected_description = warning1Card.getAttribute('data-description-moderate');
            }} else {{
              selected_description = warning1Card.getAttribute('data-description-strict'); // Fallback
            }}
            const descriptionPara = warning1Card.querySelector('.warning-description');
            if (descriptionPara) {{
              descriptionPara.innerHTML = selected_description;
            }}
          }}
        }});
      }});
    </script>
    """
    display(HTML(final_html))


[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [10]:
#@title Load Final UI

from google.colab import output
# Give the report generous vertical space in Colab's output iframe, then render it.
_resize_js = "google.colab.output.setIframeHeight('20000px')"
output.eval_js(_resize_js)

build_ui(final_df, all_text, doi)

Groups / Comparisons,Mediators / Confounders,Data Collection Method
"""Varied-intensity work intervals"" ""constant-intensity work intervals""","""age"" ""training habits"" ""cycling experience"" ""Not stated (but controlled for: environmental conditions"" ""meal consumption"" ""caffeine intake)""","""lactate threshold test"" ""maximal incremental test"" ""heart rate monitoring"" ""VO2 measurement"" ""blood sampling"" ""questionnaire"" ""cycle computer"""


In [None]:
#@title Final UI (Simplified for Text Input + Two-Stage Classification)

###############################
# 0) IMPORTS AND SETUP
###############################
import nltk
import torch
import torch.nn.functional as F
import pandas as pd
import ipywidgets as widgets
import plotly.graph_objects as go
from IPython.display import display, HTML, clear_output, update_display
from transformers import BertForSequenceClassification, BertTokenizer

# If you're using Colab / Jupyter:
try:
    from google.colab import output
    HAVE_GOOGLE_COLAB = True
except ImportError:
    HAVE_GOOGLE_COLAB = False

# Optional Groq client for the "Explain" / "Ask More" buttons.
# The API key must NOT be stored in the notebook. It is read from the GROQ_API_KEY
# environment variable or, in Colab, from the notebook's Secrets panel.
# The key that used to be hard-coded here was readable by anyone with a copy of
# this notebook, so it should be revoked and rotated.
import os

_groq_api_key = os.environ.get("GROQ_API_KEY", "").strip()
if not _groq_api_key and HAVE_GOOGLE_COLAB:
    try:
        from google.colab import userdata
        _groq_api_key = (userdata.get("GROQ_API_KEY") or "").strip()
    except Exception:
        # Secret not defined or access not granted: fall through to "Groq disabled".
        _groq_api_key = ""

HAVE_GROQ = False
if _groq_api_key:
    try:
        from groq import Groq
        client = Groq(api_key=_groq_api_key)
        HAVE_GROQ = True
    except Exception as exc:  # e.g. groq not installed, or client construction failed
        print(f"Groq client unavailable ({exc}); 'Explain' / 'Ask More' are disabled.")
else:
    print("GROQ_API_KEY not found; 'Explain' / 'Ask More' are disabled.")

nltk.download("punkt", quiet=True)

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

###############################
# 1) LOAD TWO-STAGE MODELS
###############################
# Stage 1 separates causal from non-causal sentences; stage 2 sub-types the causal ones.
# Point these at your own fine-tuned checkpoints if they live elsewhere.
model_path_stage1 = "/content/drive/My Drive/HTI thesis/2_BERT"
model_path_stage2 = "/content/drive/My Drive/HTI thesis/3_BERT"

model_stage1 = BertForSequenceClassification.from_pretrained(model_path_stage1)
tokenizer_stage1 = BertTokenizer.from_pretrained(model_path_stage1)
model_stage2 = BertForSequenceClassification.from_pretrained(model_path_stage2)
tokenizer_stage2 = BertTokenizer.from_pretrained(model_path_stage2)

# nn.Module.to() moves the parameters in place, so the return value is not needed.
for _model in (model_stage1, model_stage2):
    _model.to(device)
del _model

###############################
# 2) TWO-STAGE CLASSIFY FUNCTION
###############################
def _classify_batch(batch):
    """Run stage 1, then stage 2 on the stage-1 positives, for one batch of sentences.

    Returns ``(labels, confs)`` aligned with ``batch`` in the final label scheme
    0=Non-Causal, 1=Correlational, 2=Conditional, 3=Direct.
    """
    inputs1 = tokenizer_stage1(batch, return_tensors="pt", padding=True,
                               truncation=True, max_length=512).to(device)
    probs1 = F.softmax(model_stage1(**inputs1).logits, dim=-1)
    pred1 = torch.argmax(probs1, dim=-1).cpu().tolist()
    conf1 = torch.max(probs1, dim=-1).values.cpu().tolist()

    # Stage-1 negatives keep label 0 together with their stage-1 confidence.
    labels = [0] * len(batch)
    confs = list(conf1)

    causal_pos = [j for j, p in enumerate(pred1) if p > 0]
    if causal_pos:
        inputs2 = tokenizer_stage2([batch[j] for j in causal_pos], return_tensors="pt",
                                   padding=True, truncation=True, max_length=512).to(device)
        probs2 = F.softmax(model_stage2(**inputs2).logits, dim=-1)
        pred2 = torch.argmax(probs2, dim=-1).cpu().tolist()
        conf2 = torch.max(probs2, dim=-1).values.cpu().tolist()
        # Shift stage-2 classes (0..2) up by one so they follow Non-Causal (0).
        for j, p2, c2 in zip(causal_pos, pred2, conf2):
            labels[j] = p2 + 1
            confs[j] = c2
    return labels, confs


def two_stage_classify(
    sentences,
    batch_size=32,
    min_words=5,
    short_conf=0.75,
):
    """Label every sentence with the two-stage BERT pipeline.

    Stage1 => 0=Non-Causal, 1=Causal
    Stage2 => 0=Correlational, 1=Conditional, 2=Direct   (causal sentences only)
    Final  => 0=Non-Causal, 1=Corr, 2=Cond, 3=Direct

    Args:
        sentences: list of sentence strings.
        batch_size: number of sentences per forward pass.
        min_words: sentences with fewer whitespace-separated words are not sent
            to the models and are labelled Non-Causal.
        short_conf: confidence reported for those skipped short sentences.

    Returns:
        A list of ``(sentence, label, confidence)`` with exactly one entry per
        input sentence, in input order (duplicates included). Classified
        sentences are returned ``strip()``-ed; skipped ones are returned verbatim.
    """
    # Results are placed by input index, so order and duplicates survive. The old
    # set-difference approach dropped repeated short sentences and pushed sentences
    # altered by strip() to the end as extra "omitted" entries.
    results = [None] * len(sentences)
    model_idx = []
    for i, s in enumerate(sentences):
        if len(s.split()) >= min_words:
            model_idx.append(i)
        else:
            results[i] = (s, 0, short_conf)

    if model_idx:
        model_stage1.eval()
        model_stage2.eval()
        with torch.no_grad():
            for start in range(0, len(model_idx), batch_size):
                idx_batch = model_idx[start:start + batch_size]
                batch = [sentences[i].strip() for i in idx_batch]
                labels, confs = _classify_batch(batch)
                for i, s, lbl, conf in zip(idx_batch, batch, labels, confs):
                    results[i] = (s, lbl, conf)

    return results

###############################
# 3) CHART FUNCTION
###############################
def build_stacked_bar_chart(class_results):
    """Render the label distribution as a single horizontal stacked bar.

    Args:
        class_results: iterable of ``(sentence, label, confidence)`` with label in
            0..3 (0=Non-Causal, 1=Correlational, 2=Conditional Causal, 3=Direct Causal).
            Labels outside 0..3 are ignored rather than raising.

    Returns:
        An HTML fragment (``full_html=False``) that loads plotly.js from the CDN and
        hides the mode bar.
    """
    from collections import Counter

    label_map = {0: "Non-Causal", 1: "Correlational", 2: "Conditional Causal", 3: "Direct Causal"}
    bright_colors = {
        0: '#E0E0E0',
        1: '#1827FF',
        2: '#2CBCFF',
        3: '#FB43FF'
    }
    counts = Counter(lbl for _, lbl, _ in class_results)

    fig = go.Figure()
    # One row ("Distribution"): causal categories first, Non-Causal last in the stack.
    for lbl in (1, 2, 3, 0):
        fig.add_trace(go.Bar(
            name=label_map[lbl],
            y=["Distribution"],
            x=[counts.get(lbl, 0)],
            orientation="h",
            marker=dict(color=bright_colors[lbl]),
            # Plotly substitutes %{x}; a bare {x} would be displayed literally.
            hovertemplate=f"{label_map[lbl]}: %{{x}}<extra></extra>"
        ))

    fig.update_layout(
        barmode="stack",
        height=100,
        margin=dict(l=0, r=0, t=10, b=0),
        showlegend=True,
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
        xaxis=dict(title="Number of Sentences"),
        yaxis=dict(showticklabels=False)
    )

    return fig.to_html(full_html=False, include_plotlyjs='cdn', config={'displayModeBar': False})

###############################
# 4) HIGHLIGHT FUNCTION
###############################
def highlight_sentences(text, class_results):
    """Return ``text`` as HTML in which every causal sentence is a hover span.

    Args:
        text: original plain text. It is re-split with ``nltk.sent_tokenize`` and each
            sentence is looked up by exact string in ``class_results``.
        class_results: list of ``(sentence, label, confidence)`` with label in 0..3.
            Sentences missing from it are treated as label 0 (not highlighted).

    Returns:
        An HTML string. Non-causal sentences are emitted as escaped plain text. Causal
        ones become ``<span class="hover-sentence">`` elements carrying the data-*
        attributes (conf, bright, decstyle, decth, sent, cls) that the UI script reads.
    """
    import html  # stdlib: escapes sentence text for HTML content and attribute values

    # Label map for display
    label_map = {1: "Correlational", 2: "Conditional Causal", 3: "Direct Causal"}
    bright = {1: '#1827FF', 2: '#2CBCFF', 3: '#FB43FF'}
    pastel_high = {1: '#C6D6FF', 2: '#C2F0FF', 3: '#FFC9FF'}
    pastel_low = {1: '#EFF3FF', 2: '#EEFCFF', 3: '#FFF0FF'}

    # If the same sentence text appears more than once, the last entry wins.
    label_dict = {t[0]: t[1] for t in class_results}
    conf_dict = {t[0]: t[2] for t in class_results}

    out_html = ""
    for s in nltk.sent_tokenize(text):
        ctype = label_dict.get(s, 0)
        conf = conf_dict.get(s, 1.0)
        # html.escape (quote=True) converts < > & " ' to entities. Text such as "p < .05"
        # or quoted speech then cannot break the markup or the data-sent attribute, and
        # the browser decodes it back, so getAttribute('data-sent') yields the raw sentence.
        s_html = html.escape(s)

        if ctype == 0:
            # Non-causal
            out_html += f"{s_html} "
            continue

        # High-confidence predictions get a stronger pastel and a solid underline.
        if conf >= 0.90:
            bg_col = pastel_high.get(ctype, '#e0e0e0')
            dec_style = "solid"
            dec_thickness = "2px"
        else:
            bg_col = pastel_low.get(ctype, '#f3f3f3')
            dec_style = "wavy"
            dec_thickness = "1.5px"
        bright_color = bright.get(ctype, '#000')
        cat_name = label_map.get(ctype, "Non-Causal")

        snippet = f"""
        <span class="hover-sentence"
              style="background-color:{bg_col}; padding:0.2em; border-radius:5px; cursor:pointer;"
              data-conf="{conf}"
              data-bright="{bright_color}"
              data-decstyle="{dec_style}"
              data-decth="{dec_thickness}"
              data-sent="{s_html}"
              data-cls="{cat_name}"
        >
          {s_html}
        </span>
        """
        out_html += snippet + " "

    return out_html

###############################
# 5) EXPLAIN / ASK MORE (Groq)
###############################
# NOTE(review): both functions are invoked from the page's JavaScript through
# google.colab.kernel.invokeFunction, which reads response.data['text/plain'].
# For a str return value that is presumably the Python repr (quoted, with \n
# escapes). Confirm the displayed text looks right.

def _groq_chat(prompt, failure_message, model="llama3-8b-8192"):
    """Send one user prompt to the Groq chat API and return the stripped reply.

    Returns ``failure_message`` if the request (or reading its reply) raises an
    ordinary exception, e.g. network, auth or rate-limit errors.
    """
    try:
        chat_completion = client.chat.completions.create(
            messages=[
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": prompt}
            ],
            model=model
        )
        return chat_completion.choices[0].message.content.strip()
    except Exception:
        # Deliberate best-effort: the result is shown in the UI, so report failure
        # instead of raising. Unlike a bare `except:`, this no longer swallows
        # KeyboardInterrupt / SystemExit.
        return failure_message


def ask_explanation(sentence, classification):
    """Ask the LLM why ``sentence`` was given the label ``classification``.

    Returns the reply text, "Explanation not available or invalid." when Groq is
    unavailable or an argument is empty, or "Failed to fetch explanation." on error.
    """
    if not (HAVE_GROQ and sentence and classification):
        return "Explanation not available or invalid."
    prompt = f"""
The following sentence was classified as '{classification}':
\"{sentence}\"
Please explain briefly why it falls under this category.
"""
    return _groq_chat(prompt, "Failed to fetch explanation.")


def ask_more_followup(question, sentence, classification):
    """Answer a user follow-up ``question`` about a classified sentence.

    Returns the reply text, "Ask More not available or invalid." when Groq is
    unavailable or an argument is empty, or "Failed to fetch answer." on error.
    """
    if not (HAVE_GROQ and question and sentence and classification):
        return "Ask More not available or invalid."
    prompt = f"""
We had a sentence classified as '{classification}':
\"{sentence}\"
Follow-up question: \"{question}\"

Please provide a detailed but concise answer.
"""
    return _groq_chat(prompt, "Failed to fetch answer.")

# Expose the Groq helpers to the page's JavaScript under the names the UI script
# passes to google.colab.kernel.invokeFunction (Colab only).
if HAVE_GOOGLE_COLAB:
    for _cb_name, _cb_fn in (("ask_explanation", ask_explanation),
                             ("ask_more_followup", ask_more_followup)):
        output.register_callback(_cb_name, _cb_fn)
    del _cb_name, _cb_fn

###############################
# 6) BUILD & DISPLAY UI
###############################
def build_ui(section_label, text, classification_results):
    """Render the classification report as HTML and display it.

    Args:
        section_label: section name shown in the header (e.g. "abstract").
        text: the raw text that was classified.
        classification_results: list of ``(sentence, label, confidence)`` tuples as
            produced by ``two_stage_classify``.

    The page contains:
      - Title area
      - Stacked bar chart
      - Highlighted text
      - A side overlay that does NOT vanish on mouseout (includes an 'X' to close)
      - "Explain" & "Ask More" buttons, which call the Python callbacks
        'ask_explanation' / 'ask_more_followup' via google.colab.kernel.invokeFunction
      - Confidence slider

    NOTE(review): an earlier cell calls a build_ui(final_df, all_text, doi) with a
    different signature. This definition shadows it once this cell has run.
    """
    chart_html = build_stacked_bar_chart(classification_results)
    highlighted_html = highlight_sentences(text, classification_results)

    # One f-string: literal CSS/JS braces are doubled ({{ }}) and JS template
    # placeholders are written ${{...}}.
    html_code = f"""
    <style>
      #topBar {{
        background: #f4f4f4;
        padding: 20px;
        margin-bottom: 20px;
      }}
      #topBar h1 {{
        margin: 0;
        font-size: 2em;
      }}
      .paperSection {{
        font-size: 1.1em;
        margin: 5px 0;
        color: #333;
      }}
      #chartContainer {{
        width: 100%;
        margin: 20px 0;
      }}
      #container {{
        display: flex;
        flex-direction: row;
      }}
      #mainContent {{
        flex: 1;
        margin-right: 20px;
        padding: 10px;
      }}
      #sidebar {{
        width: 350px;
        min-width: 280px;
        border: 1px solid #ccc;
        background: #fff;
        padding: 10px;
        position: relative;
      }}
      #sidebarOverlay {{
        position: absolute;
        width: 90%;
        border: 1px solid #ccc;
        background: #f9f9f9;
        padding: 10px;
        border-radius: 5px;
        display: none;
        box-shadow: 0 2px 5px rgba(0,0,0,0.2);
        z-index: 9999;
      }}
      #mainContent p {{
        line-height: 1.6;
        text-align: justify;
      }}
      .close-overlay-btn {{
        float:right;
        background: #aaa;
        color: #fff;
        border: none;
        cursor: pointer;
        border-radius: 3px;
        font-weight: bold;
        padding: 0 6px;
        margin-bottom: 4px;
      }}
      .ask-more-input {{
        width:100%;
        padding:5px;
        margin-top:8px;
        box-sizing:border-box;
      }}
    </style>

    <div id="topBar">
      <h1>Classification Results</h1>
      <div class="paperSection">Section: {section_label}</div>
    </div>

    <div id="chartContainer">
      {chart_html}
    </div>

    <div style="margin:10px 0;">
      <b>Confidence Slider:</b> Sentences below this score get faded out.
      <input type="range" id="confSlider" min="0" max="100" step="1" value="0" style="width:300px;">
      <span id="confSliderValue" style="margin-left:6px;">0.00 (Show All)</span>
    </div>

    <div id="container">
      <div id="mainContent">
        <p>{highlighted_html}</p>
      </div>
      <div id="sidebar">
        <h4 style="margin-top:0;">Hover Details</h4>
        <p style='font-size:0.9em;color:#666;'>Hover over a highlighted sentence to see details.</p>
        <div id="sidebarOverlay" data-sentence="" data-classification=""></div>
      </div>
    </div>

    <script>
      const confSlider = document.getElementById('confSlider');
      const confSliderVal = document.getElementById('confSliderValue');
      const mainContent = document.getElementById('mainContent');
      const overlay = document.getElementById('sidebarOverlay');
      let currentThreshold = 0.0;
      let lastHovered = null;

      // Remember each highlight's rendered background so the slider can restore it
      // exactly, instead of re-deriving it from a duplicated colour table.
      document.querySelectorAll('.hover-sentence').forEach(sp => {{
        sp.dataset.origBg = sp.style.backgroundColor;
      }});

      confSlider.addEventListener('input', function(){{
        currentThreshold = parseInt(this.value, 10) / 100.0;
        confSliderVal.textContent = currentThreshold.toFixed(2);

        document.querySelectorAll('.hover-sentence').forEach(sp => {{
          const c = parseFloat(sp.getAttribute('data-conf'));
          if (c < currentThreshold) {{
            sp.style.backgroundColor = 'transparent';
            sp.style.textDecorationLine = 'none';
          }} else {{
            sp.style.backgroundColor = sp.dataset.origBg;
          }}
        }});
      }});

      document.addEventListener('mouseover', function(e){{
        const target = e.target;
        if (!target.classList || !target.classList.contains('hover-sentence')) {{
          return;
        }}
        const confVal = parseFloat(target.getAttribute('data-conf'));

        // Underline only sentences that pass the current filter.
        if (confVal >= currentThreshold) {{
          target.style.textDecorationLine = 'underline';
          target.style.textDecorationStyle = target.getAttribute('data-decstyle');
          target.style.textDecorationColor = target.getAttribute('data-bright');
          target.style.textDecorationThickness = target.getAttribute('data-decth');
        }} else {{
          target.style.textDecorationLine = 'none';
        }}

        if (lastHovered && lastHovered !== target) {{
          lastHovered.style.textDecorationLine = 'none';
        }}
        lastHovered = target;

        const classification = target.getAttribute('data-cls');
        const sentence = target.getAttribute('data-sent');
        // Keep the pair on the overlay. The buttons read it from there instead of
        // from inline-quoted onclick arguments, which broke on sentences containing ' or ".
        overlay.dataset.sentence = sentence;
        overlay.dataset.classification = classification;

        let contentHtml = '<button class="close-overlay-btn" onclick="document.getElementById(\\'sidebarOverlay\\').style.display=\\'none\\';">X</button>';

        if (confVal < 0.90) {{
          contentHtml += `
            <div style='background:black; color:white; padding:6px;'>
              <b>⚠ LOW CONFIDENCE</b><br/>
              Confidence: ${{confVal.toFixed(2)}}<br/>
              Classification: ${{classification}}<br/>
              <button style='margin-top:8px;' onclick="fetchExplanation()">Explain</button>
            </div>
          `;
        }} else {{
          contentHtml += `
            <div style='background:#f9f9f9; color:#000; padding:6px;'>
              <b>Classification:</b> ${{classification}}<br/>
              <button style='margin-top:8px;' onclick="fetchExplanation()">Explain</button>
            </div>
          `;
        }}

        overlay.innerHTML = contentHtml;
        overlay.style.display = 'block';

        const mainRect = mainContent.getBoundingClientRect();
        const targetRect = target.getBoundingClientRect();
        overlay.style.top = (targetRect.top - mainRect.top) + 'px';
      }});

      // We do NOT hide the overlay on mouseout, so the buttons stay clickable.

      async function fetchExplanation(sentence, classification){{
        sentence = sentence ?? overlay.dataset.sentence;
        classification = classification ?? overlay.dataset.classification;

        // Write into a dedicated box: the loading message is replaced (not left behind),
        // and a late reply lands in a detached box rather than in another sentence's overlay.
        let box = document.getElementById('explanationBox');
        if (!box) {{
          box = document.createElement('div');
          box.id = 'explanationBox';
          overlay.appendChild(box);
        }}
        box.innerHTML = '<p><i>Loading explanation...</i></p>';
        try {{
          const response = await google.colab.kernel.invokeFunction(
            'ask_explanation', [sentence, classification], {{}}
          );
          const explanation = response.data['text/plain'];
          box.innerHTML = `
            <p><b>Explanation:</b> <span class='explanationText'></span></p>
            <button style='margin-top:8px;' onclick="showAskMore()">Ask More</button>
            <div id='askMoreArea' style='display:none; margin-top:8px;'>
              <input type='text' id='askMoreInput' class='ask-more-input' placeholder='Follow-up question...'/>
              <button onclick="submitAskMore()">Submit</button>
            </div>
            <div id='askMoreOutput' style='margin-top:8px; color:#333;'></div>
          `;
          // Model output is shown as text, never parsed as HTML.
          box.querySelector('.explanationText').textContent = explanation;
        }} catch(err) {{
          box.innerHTML = '<p style="color:red;">Failed to fetch explanation.</p>';
        }}
      }}

      function showAskMore(){{
        const area = document.getElementById('askMoreArea');
        area.style.display = 'block';
      }}

      async function submitAskMore(){{
        const sentence = overlay.dataset.sentence;
        const classification = overlay.dataset.classification;
        if (!sentence) return;
        const question = document.getElementById('askMoreInput').value;
        const outDiv = document.getElementById('askMoreOutput');
        outDiv.innerHTML = '<i>Loading answer...</i>';
        try {{
          const response = await google.colab.kernel.invokeFunction(
            'ask_more_followup', [question, sentence, classification], {{}}
          );
          const answer = response.data['text/plain'];
          outDiv.innerHTML = '<b>Answer:</b> ';
          outDiv.appendChild(document.createTextNode(answer));
        }} catch(e) {{
          outDiv.innerHTML = '<span style="color:red;">Failed to get answer.</span>';
        }}
      }}
    </script>
    """

    display(HTML(html_code))

###############################
# 7) MASTER: classify_and_display
###############################
def classify_and_display(section, text):
    """Classify ``text`` sentence by sentence and render the report for ``section``.

    Pipeline: clear the cell output -> ``nltk.sent_tokenize`` -> ``two_stage_classify``
    -> ``build_ui``.

    NOTE(review): ``clear_output(wait=True)`` clears the output of the cell this runs
    in. When triggered from the "Classify Text" button, that presumably includes the
    input widgets too. Consider routing through an ``ipywidgets.Output`` area
    (confirm in Colab).
    """
    clear_output(wait=True)
    build_ui(section, text, two_stage_classify(nltk.sent_tokenize(text)))

###############################
# 8) WIDGETS FOR USER INPUT
###############################
# Input form: pick a section label, paste text, press the button to classify.
section_dropdown = widgets.Dropdown(
    description="Section:",
    options=("title", "abstract", "method", "results", "discussion", "conclusion"),
)
text_input = widgets.Textarea(
    layout=widgets.Layout(height="200px", width="100%"),
    placeholder="Paste or type text here...",
)
classify_button = widgets.Button(button_style="primary", description="Classify Text")


def on_classify_click(b):
    """Button callback: classify the textarea contents under the selected section."""
    section, text = section_dropdown.value, text_input.value
    classify_and_display(section, text)


classify_button.on_click(on_classify_click)
ui_box = widgets.VBox(children=[section_dropdown, text_input, classify_button])
display(ui_box)


VBox(children=(Dropdown(description='Section:', options=('title', 'abstract', 'method', 'results', 'discussion…