### 1. Imports and Setup

In [1]:
import pandas as pd
import Logistic_bootstrap_metrics as lbm
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.metrics import roc_auc_score, balanced_accuracy_score, f1_score, recall_score, confusion_matrix
import statsmodels.api as sm
import torch
from transformers import BertTokenizer, BertModel, BertForTokenClassification

  from .autonotebook import tqdm as notebook_tqdm


### 2. Load Data and Split into Train/Test

In [2]:
# One table holds both splits; the "Split" column tags each row as Train or Test.
file_path = "data/TrainTest_Table.csv"
train_test_df = pd.read_csv(file_path)

train_df, test_df = (
    train_test_df.loc[train_test_df["Split"].eq(split_label)]
    for split_label in ("Train", "Test")
)

### 3. Define MMSE Question Mapping

In [3]:
# MMSE item column -> plain-language wording of that item. Items that share one
# instruction reuse a single string. Insertion order matters: the prompt generators
# iterate this dict, so it fixes the order in which prompts are emitted per visit.
# NOTE(review): the MMD/MML/MMR/MMO/MMW suffixes spell D-L-R-O-W ("WORLD" backward),
# and seven MMLTR items are mapped to spelling a five-letter word. Confirm against the
# dataset's data dictionary that the serial-7s and WORLD-backward labels are not swapped.
_REPEAT_WORD_QUESTION = "Repeat the word."
_SERIAL_SEVENS_QUESTION = "Count backward from 100 by 7s."
_SPELL_WORLD_QUESTION = "Spell the word 'WORLD' backward."

mmse_questions = {
    "MMYEAR": "What year is it?",
    "MMMONTH": "What month is it?",
    "MMDAY": "What day of the week is it?",
    "MMSEASON": "What season is it?",
    "MMDATE": "What is today’s date?",
    "MMSTATE": "What state are we in?",
    "MMCITY": "What city are we in?",
    "MMAREA": "What county are we in?",
    "MMHOSPIT": "What building are we in?",
    "MMFLOOR": "What floor are we on?",
    **dict.fromkeys(["WORD1", "WORD2", "WORD3"], _REPEAT_WORD_QUESTION),
    **dict.fromkeys(["MMD", "MML", "MMR", "MMO", "MMW"], _SERIAL_SEVENS_QUESTION),
    **{f"MMLTR{n}": _SPELL_WORLD_QUESTION for n in range(1, 8)},
    "WORD1DL": "Can you recall the first word from earlier?",
    "WORD2DL": "Can you recall the second word from earlier?",
    "WORD3DL": "Can you recall the third word from earlier?",
    "MMWATCH": "What is this object? (Watch)",
    "MMPENCIL": "What is this object? (Pencil)",
    "MMREPEAT": "Repeat after me: 'No ifs, ands, or buts.'",
    "MMHAND": "Take this paper in your right hand.",
    "MMFOLD": "Fold this paper in half.",
    "MMONFLR": "Place this paper on the floor.",
    "MMREAD": "Read this sentence aloud.",
    "MMWRITE": "Write a sentence.",
    "MMDRAW": "Copy this design.",
}

# MMSE item column -> one-sentence description of what the item assesses. Keys and
# their order mirror mmse_questions; the descriptions share a few common openings.
_TIME_ORIENTATION = "This question assesses the patient's orientation to time by asking for "
_PLACE_ORIENTATION = "This question assesses the patient's orientation to place by asking for "
_ABILITY = "This question tests the patient's ability to "

mmse_context = {
    "MMYEAR": _TIME_ORIENTATION + "the current year.",
    "MMMONTH": _TIME_ORIENTATION + "the current month.",
    "MMDAY": _TIME_ORIENTATION + "the current day of the week.",
    "MMSEASON": _TIME_ORIENTATION + "the current season.",
    "MMDATE": _TIME_ORIENTATION + "today's date.",
    "MMSTATE": _PLACE_ORIENTATION + "the current state.",
    "MMCITY": _PLACE_ORIENTATION + "the current city.",
    "MMAREA": _PLACE_ORIENTATION + "the current county.",
    "MMHOSPIT": _PLACE_ORIENTATION + "the current building.",
    "MMFLOOR": _PLACE_ORIENTATION + "the current floor.",
    **dict.fromkeys(["WORD1", "WORD2", "WORD3"], _ABILITY + "repeat a word for memory recall."),
    **dict.fromkeys(
        ["MMD", "MML", "MMR", "MMO", "MMW"],
        "This question tests the patient's attention and calculation skills by asking them to count backward from 100 by 7s.",
    ),
    **{f"MMLTR{n}": _ABILITY + "spell a word backward." for n in range(1, 8)},
    "WORD1DL": _ABILITY + "recall the first word presented earlier.",
    "WORD2DL": _ABILITY + "recall the second word presented earlier.",
    "WORD3DL": _ABILITY + "recall the third word presented earlier.",
    "MMWATCH": _ABILITY + "identify a common object (watch).",
    "MMPENCIL": _ABILITY + "identify a common object (pencil).",
    "MMREPEAT": _ABILITY + "repeat a complex sentence.",
    "MMHAND": _ABILITY + "follow a simple command involving their right hand.",
    "MMFOLD": _ABILITY + "follow a simple command to fold a piece of paper.",
    "MMONFLR": _ABILITY + "follow a simple command to place a paper on the floor.",
    "MMREAD": _ABILITY + "read and comprehend a sentence.",
    "MMWRITE": _ABILITY + "write a sentence.",
    "MMDRAW": _ABILITY + "copy a design.",
}

### 4. Generate MMSE Prompts

In [4]:
def _build_mmse_prompts(df, mmse_questions, mmse_context=None):
    """Build one prompt record per (row, recorded MMSE item); shared by both generators.

    Args:
        df: DataFrame with ``subject_id``, ``visit`` and ``AD`` columns plus any
            number of MMSE item columns.
        mmse_questions: mapping of item column -> question text. Its iteration order
            sets the order of prompts within a row. Items whose column is absent from
            ``df`` are ignored.
        mmse_context: optional mapping of item column -> description. When None, the
            prompt has no ``Context:`` line. Otherwise a ``Context:`` line is inserted
            between the question and the result, and items missing from the mapping
            get "No context available.".

    Returns:
        List of dicts with keys ``subject_id``, ``visit``, ``AD`` and ``MMSE Prompt``,
        ordered row by row, then item by item.
    """
    # Column membership is the same for every row, so resolve usable items once
    # instead of re-checking ``df.columns`` inside the row loop.
    available_items = [
        (mmse_var, question)
        for mmse_var, question in mmse_questions.items()
        if mmse_var in df.columns
    ]

    prompts = []
    for _, row in df.iterrows():
        subject_id = row["subject_id"]
        visit = row["visit"]
        ad = row["AD"]
        for mmse_var, question in available_items:
            score = row[mmse_var]
            if pd.isna(score):
                continue  # a missing (NaN) score produces no prompt
            prompt_text = f"Question: {question}\n"
            if mmse_context is not None:
                context = mmse_context.get(mmse_var, "No context available.")
                prompt_text += f"Context: {context}\n"
            # Any recorded value other than 1 is reported as 0.
            # NOTE(review): confirm the item coding; a non-NaN sentinel for missing
            # data would also be reported as 0 here.
            prompt_text += f"Result: {1 if score == 1 else 0}\n"
            prompts.append({
                "subject_id": subject_id,
                "visit": visit,
                "AD": ad,
                "MMSE Prompt": prompt_text,
            })
    return prompts


def generate_structured_mmse_prompts(df, mmse_questions):
    """Return question/result prompt records for every recorded MMSE item in ``df``.

    Each record is a dict with ``subject_id``, ``visit``, ``AD`` and an
    ``MMSE Prompt`` string of the form ``"Question: ...\\nResult: 0|1\\n"``.
    NaN scores and items whose column is absent from ``df`` are skipped.
    """
    return _build_mmse_prompts(df, mmse_questions)


def generate_contextual_mmse_prompts(df, mmse_questions, mmse_context):
    """Like ``generate_structured_mmse_prompts`` but with a ``Context:`` line per item.

    The context line comes from ``mmse_context``. Items that have no entry there use
    "No context available.".
    """
    return _build_mmse_prompts(df, mmse_questions, mmse_context)

# Build the prompt lists for each split. The structured variant has only the question
# and result; the contextual variant also carries a description of the item.
structured_mmse_prompts_train, structured_mmse_prompts_test = (
    generate_structured_mmse_prompts(split_df, mmse_questions)
    for split_df in (train_df, test_df)
)

contextual_mmse_prompts_train, contextual_mmse_prompts_test = (
    generate_contextual_mmse_prompts(split_df, mmse_questions, mmse_context)
    for split_df in (train_df, test_df)
)

### 5. Save Prompts to CSV

In [5]:
# One row per prompt record
df_prompts_train, df_prompts_test, df_context_train, df_context_test = map(
    pd.DataFrame,
    (
        structured_mmse_prompts_train,
        structured_mmse_prompts_test,
        contextual_mmse_prompts_train,
        contextual_mmse_prompts_test,
    ),
)

# Persist the prompt tables.
# NOTE(review): "Promts" is misspelled in the context file names; it is left unchanged
# because other code may read these exact paths.
for prompt_frame, csv_path in (
    (df_prompts_train, "data/MMSE_Prompts_Train.csv"),
    (df_prompts_test, "data/MMSE_Prompts_Test.csv"),
    (df_context_train, "data/MMSE_Context_Promts_Train.csv"),
    (df_context_test, "data/MMSE_Context_Promts_Test.csv"),
):
    prompt_frame.to_csv(csv_path, index=False)

### 6. Load Bio_ClinicalBERT and Define the CLS Embedding Extractor

In [None]:
# Tokenizer and encoder are loaded from the same pre-trained checkpoint.
model_name = "emilyalsentzer/Bio_ClinicalBERT"
tokenizer = BertTokenizer.from_pretrained(model_name)

# Module.eval() returns the module itself, so loading and switching to
# evaluation mode (dropout disabled) can be chained.
model = BertModel.from_pretrained(model_name).eval()

# Function to extract CLS token embeddings
def extract_cls_embedding(texts, model, tokenizer, batch_size=64, max_length=512):
    """Return the final-layer [CLS] vector of each text as a NumPy array.

    Texts are tokenized and encoded in chunks of ``batch_size``, so memory stays
    bounded however many prompts are passed. The alternative, one padded forward
    pass over the whole prompt table, does not bound memory.

    Args:
        texts: sequence of strings to embed (a list, tuple or pandas Series works).
        model: encoder whose output exposes ``last_hidden_state`` and which has
            ``device`` and ``config.hidden_size`` attributes (e.g. the BertModel above).
        tokenizer: tokenizer matching ``model``.
        batch_size: number of texts per forward pass; must be >= 1.
        max_length: token length at which texts are truncated.

    Returns:
        Array of shape ``(len(texts), hidden_size)`` with rows in input order. An empty
        ``texts`` yields an array with zero rows.

    Raises:
        ValueError: if ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    texts = list(texts)
    batch_embeddings = []
    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        # Padding is per batch. The tokenizer's attention mask is forwarded with the
        # inputs, so padded positions should not change the CLS vector.
        inputs = tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
        inputs = {name: tensor.to(model.device) for name, tensor in inputs.items()}
        with torch.no_grad():
            outputs = model(**inputs)
        # Position 0 is the [CLS] token. .cpu() makes .numpy() safe if the model is on a GPU.
        batch_embeddings.append(outputs.last_hidden_state[:, 0, :].cpu().numpy())

    if not batch_embeddings:
        return np.empty((0, model.config.hidden_size), dtype=np.float32)
    return np.concatenate(batch_embeddings, axis=0)

### 7. Extract Embeddings

In [7]:
# Prompts without context
train_prompts = df_prompts_train["MMSE Prompt"].tolist()
test_prompts = df_prompts_test["MMSE Prompt"].tolist()

# Embed prompts without context: one CLS vector per prompt row
train_embeddings = extract_cls_embedding(train_prompts, model, tokenizer)
test_embeddings = extract_cls_embedding(test_prompts, model, tokenizer)

# Convert embeddings to DataFrames
train_embeddings_df = pd.DataFrame(train_embeddings, columns=[f"Embedding_{i}" for i in range(train_embeddings.shape[1])])
test_embeddings_df = pd.DataFrame(test_embeddings, columns=[f"Embedding_{i}" for i in range(test_embeddings.shape[1])])

# The concat below aligns by position, so each prompt row needs exactly one embedding row.
assert len(train_embeddings_df) == len(df_prompts_train)
assert len(test_embeddings_df) == len(df_prompts_test)

# This cell overwrites df_prompts_* with the concatenated frame, so on a re-run the
# frames already carry Embedding_* columns. Keep only the non-embedding columns so
# re-running replaces the embeddings instead of appending a duplicate set.
train_base_cols = [c for c in df_prompts_train.columns if not str(c).startswith("Embedding_")]
test_base_cols = [c for c in df_prompts_test.columns if not str(c).startswith("Embedding_")]
df_prompts_train = pd.concat([df_prompts_train[train_base_cols].reset_index(drop=True), train_embeddings_df], axis=1)
df_prompts_test = pd.concat([df_prompts_test[test_base_cols].reset_index(drop=True), test_embeddings_df], axis=1)

# Save the updated DataFrames to CSV, overwriting the prompt-only files from section 5
df_prompts_train.to_csv("data/MMSE_Prompts_Train.csv", index=False)
df_prompts_test.to_csv("data/MMSE_Prompts_Test.csv", index=False)

### 8. Extract Contextual Embeddings

In [None]:
# Prompts with context
train_context_prompts = df_context_train["MMSE Prompt"].tolist()
test_context_prompts = df_context_test["MMSE Prompt"].tolist()

# Embed prompts with context: one CLS vector per prompt row
train_context_embeddings = extract_cls_embedding(train_context_prompts, model, tokenizer)
test_context_embeddings = extract_cls_embedding(test_context_prompts, model, tokenizer)

# Convert embeddings to DataFrames
train_context_embeddings_df = pd.DataFrame(train_context_embeddings, columns=[f"Context_Embedding_{i}" for i in range(train_context_embeddings.shape[1])])
test_context_embeddings_df = pd.DataFrame(test_context_embeddings, columns=[f"Context_Embedding_{i}" for i in range(test_context_embeddings.shape[1])])

# The concat below aligns by position, so each prompt row needs exactly one embedding row.
assert len(train_context_embeddings_df) == len(df_context_train)
assert len(test_context_embeddings_df) == len(df_context_test)

# This cell overwrites df_context_* with the concatenated frame, so on a re-run the
# frames already carry Context_Embedding_* columns. Keep only the non-embedding
# columns so re-running replaces the embeddings instead of appending a duplicate set.
train_context_base_cols = [c for c in df_context_train.columns if not str(c).startswith("Context_Embedding_")]
test_context_base_cols = [c for c in df_context_test.columns if not str(c).startswith("Context_Embedding_")]
df_context_train = pd.concat([df_context_train[train_context_base_cols].reset_index(drop=True), train_context_embeddings_df], axis=1)
df_context_test = pd.concat([df_context_test[test_context_base_cols].reset_index(drop=True), test_context_embeddings_df], axis=1)

# Save the updated DataFrames to CSV, overwriting the prompt-only files from section 5.
# NOTE(review): "Promts" in these paths is a typo kept for compatibility with existing readers.
df_context_train.to_csv("data/MMSE_Context_Promts_Train.csv", index=False)
df_context_test.to_csv("data/MMSE_Context_Promts_Test.csv", index=False)