### CS-441-Project Code
### Name: Atharva Chaudhari, Angad Chaudhary, Mihir Sahasrabudhe
### NetID: ac151, angadc2, mihirss2


In [None]:
# Name: Atharva Chaudhari, Angad Chaudhary, Mihir Sahasrabudhe

# NetID: ac151, angadc2, mihirss2



# Multi-model experiments using QUESTION + ANSWER text

# Importing the required libraries

import pandas as pd
import numpy as np
import re

from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
from sklearn.base import clone

from sklearn.feature_extraction.text import TfidfVectorizer

from sklearn.dummy import DummyClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC
from sklearn.ensemble import RandomForestClassifier, HistGradientBoostingClassifier
from sklearn.neural_network import MLPClassifier

from xgboost import XGBClassifier

# 1. Load MedQuad dataset

medquad_path = "medquad_300_focus10_realtopics.csv"   # path to the MedQuad CSV dataset
df = pd.read_csv(medquad_path)

print("Loaded medquad dataset:", df.shape)
print(df.head())

# ensure correct dtypes; missing question/answer cells become "" instead of the
# literal string "nan" (which astype(str) would produce and TF-IDF would pick up)
df["question"]   = df["question"].fillna("").astype(str)
df["answer"]     = df["answer"].fillna("").astype(str)
df["focus_area"] = df["focus_area"].astype(str)


# 2. Build scrub patterns from focus_area labels (scrubbing will be applied ONLY to QUESTION text)
#    Each label contributes its full (lower-cased) phrase plus each individual word,
#    e.g. "High Blood Pressure" -> "high blood pressure", "high", "blood", "pressure".


focus_terms = df["focus_area"].str.lower().unique().tolist()
print("\nFocus / disease terms:", focus_terms)

scrub_patterns = []
for term in focus_terms:
    base = term.replace("_", " ").replace("-", " ")
    # full phrase
    scrub_patterns.append(re.escape(base))
    # individual words
    for w in base.split():
        scrub_patterns.append(re.escape(w))

# De-duplicate, then sort longest-first (ties alphabetically) so the pattern order
# is identical on every run; iterating a plain set() of strings varies between
# interpreter runs because of string hash randomisation.
scrub_patterns = sorted(set(scrub_patterns), key=lambda p: (-len(p), p))
print("\nNumber of scrub patterns:", len(scrub_patterns))

def scrub_text(text: str, patterns) -> str:
    """Lower-case *text* and blank out every whole-word match of *patterns*.

    Each entry of *patterns* is treated as a regex fragment (callers pass
    ``re.escape``-d terms) and is wrapped in ``\\b`` word boundaries. Matches are
    replaced by a space, and runs of whitespace are then collapsed to a single
    space with leading/trailing whitespace removed.
    """
    cleaned = text.lower()
    for pattern in patterns:
        cleaned = re.sub(rf"\b{pattern}\b", " ", cleaned)
    # collapse whitespace runs and trim the ends
    return " ".join(cleaned.split())

# 3. Build the text inputs used by the experiments:
#    - qa_text = question + answer  (no scrubbing)
#    - scrubbed_question = question with disease terms removed
#    - qa_text_scrubbed = scrubbed_question + answer

df["scrubbed_question"] = df["question"].apply(lambda x: scrub_text(x, scrub_patterns))

df["qa_text"] = df["question"] + " " + df["answer"]
df["qa_text_scrubbed"] = df["scrubbed_question"] + " " + df["answer"]

print("\n=== Example original vs scrubbed QA text (scrub only question part) ===")
# head(5) works for datasets with fewer than 5 rows and for non-default indexes,
# unlike looking up df.loc[i, ...] for i in range(5).
for _, row in df.head(5).iterrows():
    print("\nOriginal Q:", row["question"])
    print("Scrubbed Q:", row["scrubbed_question"])
    print("Answer    :", row["answer"])
    print("QA text   :", row["qa_text"])
    print("QA scrub  :", row["qa_text_scrubbed"])


# 4. Encode labels (focus_area strings -> integer class ids)

label_encoder = LabelEncoder()
focus_labels = df["focus_area"].values
y = label_encoder.fit_transform(focus_labels)

print("\nLabel mapping:")
for label_id, label_name in enumerate(label_encoder.classes_):
    print(f"{label_id}: {label_name}")

# 5. Train/test split for MedQuad
#    1) Train on medquad (80%) test on medquad (20%) using original QA
#    2) Train on scrubbed QA (80%), test on original QA (20%)
#    One stratified split is shared by both experiments, so Exp1 and Exp2 are
#    evaluated on exactly the same 20% of rows.

train_idx, test_idx = train_test_split(
    np.arange(len(df)),
    test_size=0.2,
    random_state=42,
    stratify=y
)

# train_test_split returns row *positions*; index the underlying arrays
# positionally so this stays correct even if df lacks a default 0..n-1 index
# (df.loc would interpret these numbers as index labels).
qa_text_all = df["qa_text"].to_numpy()
qa_text_scrubbed_all = df["qa_text_scrubbed"].to_numpy()

# Exp1: original QA text
X_train_exp1 = qa_text_all[train_idx]
y_train_exp1 = y[train_idx]
X_test_exp1  = qa_text_all[test_idx]
y_test_exp1  = y[test_idx]

# Exp2: scrubbed QA for training, original QA for testing
X_train_exp2 = qa_text_scrubbed_all[train_idx]
y_train_exp2 = y[train_idx]
X_test_exp2  = qa_text_all[test_idx]
y_test_exp2  = y[test_idx]


# 6. Full MedQuad data for Experiments 3 & 4
#    3) Train on whole original QA, test on implicit dataset
#    4) Train on whole scrubbed QA, test on scrubbed implicit

X_train_exp3 = qa_text_all          # full original QA
y_train_exp3 = y

X_train_exp4 = qa_text_scrubbed_all # full scrubbed QA
y_train_exp4 = y

# ------------------------------------------------------------
# 7. Build implicit (symptom-style) dataset
#    Hand-written queries that describe symptoms rather than naming the
#    condition; used as an out-of-distribution test set in Exp3 / Exp4.

# Approximately 5 symptom-like queries per focus area.
# NOTE: keys must match df["focus_area"] values exactly. Focus areas with no
# entry contribute no implicit rows; mismatches are reported below.
# NOTE(review): a few queries still contain label words (e.g. "back pain",
# "kidney", "blood pressure"); only Exp4 scrubs them.
IMPLICIT_QUERIES = {
    "High Blood Pressure": [
        "After walking up a single flight of stairs my heart rate stays high and I feel pressure in my head for some time.",
        "At recent checkups the upper number on my reading is always high even when I sit calmly in the clinic.",
        "Sometimes when I am resting I feel my heart pounding and get mild headaches without any clear reason.",
        "I often wake up with slight blurred vision and a heavy feeling in my head in the morning.",
        "My home device usually shows high readings during the evening even when I feel normal.",
    ],
    "Type 2 Diabetes": [
        "Lately I feel thirsty all the time and need to use the washroom very frequently during the day.",
        "Small cuts on my skin are taking a long time to heal and I feel tired even after normal activity.",
        "I have recently gained weight around my stomach and often feel very hungry again soon after meals.",
        "Sometimes my vision gets slightly blurry especially in the evening after dinner.",
        "During my last routine blood test the sugar level was higher than normal but I had no major symptoms.",
    ],
    "Asthma": [
        "When I climb stairs or walk fast I start wheezing and feel tightness in my chest.",
        "I often wake up at night with coughing and a whistling sound when I breathe.",
        "Cold air and dust make it hard for me to take a deep breath and I feel breathless quickly.",
        "Sometimes after laughing hard I get a dry cough and shortness of breath.",
        "I feel my chest tightening when I am around strong smells or smoke even for a short time.",
    ],
    "Major Depressive Disorder": [
        "For the past few months I have lost interest in activities I used to enjoy and feel low most days.",
        "My sleep pattern is disturbed and I either sleep too much or struggle to fall asleep at all.",
        "I find it hard to focus on simple tasks and constantly feel guilty without a clear reason.",
        "Even small everyday tasks feel exhausting and I feel hopeless about the future.",
        "I have changes in appetite and my weight is fluctuating without intentionally dieting.",
    ],
    "Generalized Anxiety Disorder": [
        "I worry about many small things all day and find it difficult to control these thoughts.",
        "I often feel restless and my muscles feel tense even when nothing serious is happening.",
        "Before normal situations like meetings or phone calls my heart races and I feel very nervous.",
        "Sometimes I get stomach discomfort when I am stressed and my mind keeps imagining worst outcomes.",
        "I find it hard to relax and my mind keeps jumping from one worry to another.",
    ],
    "Chronic Back Pain": [
        "I have a dull ache in my lower back that has been present for several months.",
        "Sitting for long periods or lifting even light objects makes my back pain worse.",
        "Sometimes the pain spreads down into my legs when I bend forward.",
        "In the mornings my lower back feels stiff and it takes time before I can move comfortably.",
        "I feel a constant nagging pain in my back even when I am resting on a chair.",
    ],
    "Irritable Bowel Syndrome": [
        "I frequently have stomach cramps and alternating constipation and loose stools.",
        "After meals my abdomen becomes bloated and I pass gas more than usual.",
        "Stressful days seem to make my bowel discomfort and urgency worse.",
        "I often feel an urgent need to go to the washroom but sometimes only pass a small amount.",
        "Certain foods like spicy or greasy meals cause stomach pain and irregular motions for me.",
    ],
    "Osteoarthritis": [
        "My knee joints hurt when I climb stairs and feel stiff after sitting for a long time.",
        "In the morning my finger joints feel tight and painful for some time before loosening up.",
        "I hear a grinding sound in my knee when I walk and the joint sometimes swells slightly.",
        "Walking long distances has become uncomfortable because of pain in my hip joints.",
        "The pain in my joints worsens by the end of the day after regular activity.",
    ],
    "Chronic Kidney Disease": [
        "Recently my ankles and feet look swollen and I feel tired most of the time.",
        "I need to pass urine more often at night and sometimes notice it is foamy.",
        "I have a poor appetite and occasional nausea without any obvious stomach infection.",
        "During a routine check my blood test showed reduced kidney function and my blood pressure was high.",
        "I sometimes have muscle cramps and dry itchy skin along with tiredness.",
    ],
    "Coronary Artery Disease": [
        "When I walk fast or climb stairs I feel heaviness and tightness in the center of my chest.",
        "The discomfort in my chest sometimes spreads to my left arm and jaw during exertion.",
        "I quickly get short of breath during mild physical activity compared to earlier.",
        "Occasionally I feel pressure in my chest that eases when I rest for a few minutes.",
        "My doctor mentioned I have high cholesterol and a family history of heart problems.",
    ],
}

implicit_rows = []

def add_implicit_example(question, focus):
    """Append one implicit (symptom-style) query with its focus-area label to implicit_rows."""
    implicit_rows.append({"question": question, "focus_area": focus})

focus_names = label_encoder.classes_

# Walk the classes in label-encoder order so rows are grouped per class in that order.
for fa in focus_names:
    for query in IMPLICIT_QUERIES.get(fa, []):
        add_implicit_example(query, fa)

# Surface name mismatches instead of silently dropping a class from the test set.
missing_focus = [fa for fa in focus_names if fa not in IMPLICIT_QUERIES]
unused_focus = sorted(set(IMPLICIT_QUERIES) - set(focus_names))
if missing_focus:
    print("\nWARNING: no implicit examples for focus areas:", missing_focus)
if unused_focus:
    print("\nWARNING: implicit examples defined for unknown focus areas:", unused_focus)

# Explicit columns keep the frame well-formed even if no rows were added.
implicit_df = pd.DataFrame(implicit_rows, columns=["question", "focus_area"])
print("\nLoaded implicit test dataset (created in code):", implicit_df.shape)
print(implicit_df.head())

implicit_df["question"]   = implicit_df["question"].astype(str)
implicit_df["focus_area"] = implicit_df["focus_area"].astype(str)

X_imp_original = implicit_df["question"].values
X_imp_scrubbed = np.array([scrub_text(q, scrub_patterns) for q in X_imp_original])

y_imp_text = implicit_df["focus_area"].values
y_imp = label_encoder.transform(y_imp_text)


# 8. Define models
#    Each value is an unfitted estimator template; run_experiment clones it
#    before fitting, so these objects themselves are never trained.

# XGBoost's multi-class objective needs the number of classes up front.
n_focus_classes = len(label_encoder.classes_)

models = dict(
    Dummy_most_frequent=DummyClassifier(strategy="most_frequent"),
    KNN_k5=KNeighborsClassifier(n_neighbors=5),
    LogisticRegression_L2=LogisticRegression(max_iter=2000, n_jobs=-1),
    LinearSVC=LinearSVC(),
    RandomForest_200=RandomForestClassifier(
        n_estimators=200,
        random_state=42,
        n_jobs=-1,
    ),
    HistGradientBoosting=HistGradientBoostingClassifier(random_state=42),
    MLP_256_128=MLPClassifier(
        hidden_layer_sizes=(256, 128),
        max_iter=500,
        random_state=42,
    ),
    XGBoost=XGBClassifier(
        n_estimators=300,
        learning_rate=0.1,
        max_depth=6,
        subsample=0.8,
        colsample_bytree=0.8,
        objective="multi:softmax",
        eval_metric="mlogloss",
        random_state=42,
        num_class=n_focus_classes,
        n_jobs=-1,
    ),
)

print("\nModels defined:")
for model_key in models:
    print(f" - {model_key}")


# 9. Helper to run one experiment for one model

def run_experiment(model_name, base_clf, train_texts, train_labels,
                   test_texts, test_labels, exp_tag, *,
                   dense_models=("HistGradientBoosting", "MLP_256_128")):
    """Fit TF-IDF plus a fresh copy of ``base_clf`` on one train/test pair and report results.

    A new TfidfVectorizer (unigrams + bigrams, min_df=2, at most 5000 features)
    is fitted on ``train_texts`` only; ``test_texts`` are only transformed.

    Parameters
    ----------
    model_name : str
        Name of the model; compared against ``dense_models`` to decide the input format.
    base_clf : estimator
        Template estimator. It is cloned with ``sklearn.base.clone``, so the
        passed object itself is never fitted.
    train_texts, test_texts : array-like of str
        Raw documents for training / evaluation.
    train_labels, test_labels : array-like of int
        Integer labels, decoded for the report via the module-level ``label_encoder``.
    exp_tag : str
        Heading printed before the results.
    dense_models : collection of str, keyword-only
        Model names whose estimators receive dense arrays instead of the
        sparse TF-IDF matrix.

    Returns
    -------
    float
        Accuracy on the test set.
    """
    print(f"\n[ {exp_tag} ]")
    tfidf = TfidfVectorizer(
        ngram_range=(1, 2),
        min_df=2,
        max_features=5000
    )
    X_train = tfidf.fit_transform(train_texts)
    X_test  = tfidf.transform(test_texts)

    # Some models require dense inputs
    if model_name in dense_models:
        X_train_in = X_train.toarray()
        X_test_in  = X_test.toarray()
    else:
        X_train_in = X_train
        X_test_in  = X_test

    clf = clone(base_clf)
    clf.fit(X_train_in, train_labels)
    y_pred = clf.predict(X_test_in)

    acc = accuracy_score(test_labels, y_pred)
    print(f"Accuracy: {acc:.3f}\n")

    true_labels_text = label_encoder.inverse_transform(test_labels)
    pred_labels_text = label_encoder.inverse_transform(y_pred)

    print("Classification report:")
    # zero_division=0 reports 0.0 precision/F1 for classes that are never
    # predicted (the same value the default "warn" setting yields) without
    # emitting an UndefinedMetricWarning for each such class.
    print(classification_report(
        true_labels_text,
        pred_labels_text,
        digits=3,
        zero_division=0
    ))

    return acc

# ------------------------------------------------------------
# 10. Run all experiments for all models
#     1) Exp1: Train 80% original QA, test 20% original QA  (MedQuad)
#     2) Exp2: Train 80% scrubbed QA, test 20% original QA (MedQuad)
#     3) Exp3: Train FULL original QA, test implicit questions
#     4) Exp4: Train FULL scrubbed QA, test scrubbed implicit questions

# One row per experiment:
# (summary column, train texts, train labels, test texts, test labels, printed heading)
experiments = [
    ("Exp1_OrigQA_80_20_acc",
     X_train_exp1, y_train_exp1, X_test_exp1, y_test_exp1,
     "Exp1: Orig QA 80/20 (train/test on MedQuad)"),
    ("Exp2_ScrubQA_80_20_acc",
     X_train_exp2, y_train_exp2, X_test_exp2, y_test_exp2,
     "Exp2: Scrubbed QA 80% train, test on original MedQuad 20%"),
    ("Exp3_OrigAll_Implicit_acc",
     X_train_exp3, y_train_exp3, X_imp_original, y_imp,
     "Exp3: Train on FULL original MedQuad QA, test on implicit"),
    ("Exp4_ScrubAll_Implicit_acc",
     X_train_exp4, y_train_exp4, X_imp_scrubbed, y_imp,
     "Exp4: Train on FULL scrubbed MedQuad QA, test on scrubbed implicit"),
]

summary_rows = []

for model_name, base_clf in models.items():
    print("\n" + "="*80)
    print(f"MODEL: {model_name}")
    print("="*80)

    # "model" first, then one accuracy column per experiment in the order above
    summary_row = {"model": model_name}
    for column, X_tr, y_tr, X_te, y_te, tag in experiments:
        summary_row[column] = run_experiment(
            model_name,
            base_clf,
            X_tr,
            y_tr,
            X_te,
            y_te,
            exp_tag=tag
        )
    summary_rows.append(summary_row)

# ------------------------------------------------------------
# 11. Show accuracy summary

summary_df = pd.DataFrame(summary_rows)
print("\n==================== SUMMARY OF ALL EXPERIMENTS ====================")
print(summary_df.to_string(index=False))


Loaded medquad dataset: (300, 5)
                                            question  \
0                       What is high blood pressure?   
1                   What causes high blood pressure?   
2  What are the common symptoms of high blood pre...   
3      How is high blood pressure usually diagnosed?   
4       How is high blood pressure commonly treated?   

                                              answer  \
0  High blood pressure is a health condition that...   
1  High blood pressure usually develops from a mi...   
2  Common symptoms of high blood pressure can inc...   
3  High blood pressure is usually diagnosed by co...   
4  Treatment for high blood pressure often involv...   

                      source           focus_area  Emotion  
0  synthetic_medquad_focus10  High Blood Pressure  neutral  
1  synthetic_medquad_focus10  High Blood Pressure  neutral  
2  synthetic_medquad_focus10  High Blood Pressure  neutral  
3  synthetic_medquad_focus10  High Blood Pressure

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.100

Classification report:
                              precision    recall  f1-score   support

                      Asthma      0.100     1.000     0.182         5
           Chronic Back Pain      0.000     0.000     0.000         5
      Chronic Kidney Disease      0.000     0.000     0.000         5
     Coronary Artery Disease      0.000     0.000     0.000         5
Generalized Anxiety Disorder      0.000     0.000     0.000         5
         High Blood Pressure      0.000     0.000     0.000         5
    Irritable Bowel Syndrome      0.000     0.000     0.000         5
   Major Depressive Disorder      0.000     0.000     0.000         5
              Osteoarthritis      0.000     0.000     0.000         5
             Type 2 Diabetes      0.000     0.000     0.000         5

                    accuracy                          0.100        50
                   macro avg      0.010     0.100     0.018        50
                weighted avg      0.010     0.10

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.433

Classification report:
                              precision    recall  f1-score   support

                      Asthma      0.143     0.667     0.235         6
           Chronic Back Pain      0.500     0.500     0.500         6
      Chronic Kidney Disease      0.167     0.167     0.167         6
     Coronary Artery Disease      1.000     0.500     0.667         6
Generalized Anxiety Disorder      0.800     0.667     0.727         6
         High Blood Pressure      0.750     0.500     0.600         6
    Irritable Bowel Syndrome      1.000     0.333     0.500         6
   Major Depressive Disorder      1.000     0.167     0.286         6
              Osteoarthritis      1.000     0.333     0.500         6
             Type 2 Diabetes      1.000     0.500     0.667         6

                    accuracy                          0.433        60
                   macro avg      0.736     0.433     0.485        60
                weighted avg      0.736     0.43

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.280

Classification report:
                              precision    recall  f1-score   support

                      Asthma      0.152     1.000     0.263         5
           Chronic Back Pain      0.625     1.000     0.769         5
      Chronic Kidney Disease      0.000     0.000     0.000         5
     Coronary Artery Disease      0.000     0.000     0.000         5
Generalized Anxiety Disorder      0.000     0.000     0.000         5
         High Blood Pressure      0.500     0.600     0.545         5
    Irritable Bowel Syndrome      1.000     0.200     0.333         5
   Major Depressive Disorder      0.000     0.000     0.000         5
              Osteoarthritis      0.000     0.000     0.000         5
             Type 2 Diabetes      0.000     0.000     0.000         5

                    accuracy                          0.280        50
                   macro avg      0.228     0.280     0.191        50
                weighted avg      0.228     0.28

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.100

Classification report:
                              precision    recall  f1-score   support

                      Asthma      0.100     1.000     0.182         5
           Chronic Back Pain      0.000     0.000     0.000         5
      Chronic Kidney Disease      0.000     0.000     0.000         5
     Coronary Artery Disease      0.000     0.000     0.000         5
Generalized Anxiety Disorder      0.000     0.000     0.000         5
         High Blood Pressure      0.000     0.000     0.000         5
    Irritable Bowel Syndrome      0.000     0.000     0.000         5
   Major Depressive Disorder      0.000     0.000     0.000         5
              Osteoarthritis      0.000     0.000     0.000         5
             Type 2 Diabetes      0.000     0.000     0.000         5

                    accuracy                          0.100        50
                   macro avg      0.010     0.100     0.018        50
                weighted avg      0.010     0.10

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 1.000

Classification report:
                              precision    recall  f1-score   support

                      Asthma      1.000     1.000     1.000         6
           Chronic Back Pain      1.000     1.000     1.000         6
      Chronic Kidney Disease      1.000     1.000     1.000         6
     Coronary Artery Disease      1.000     1.000     1.000         6
Generalized Anxiety Disorder      1.000     1.000     1.000         6
         High Blood Pressure      1.000     1.000     1.000         6
    Irritable Bowel Syndrome      1.000     1.000     1.000         6
   Major Depressive Disorder      1.000     1.000     1.000         6
              Osteoarthritis      1.000     1.000     1.000         6
             Type 2 Diabetes      1.000     1.000     1.000         6

                    accuracy                          1.000        60
                   macro avg      1.000     1.000     1.000        60
                weighted avg      1.000     1.00

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.100

Classification report:
                              precision    recall  f1-score   support

                      Asthma      0.100     1.000     0.182         5
           Chronic Back Pain      0.000     0.000     0.000         5
      Chronic Kidney Disease      0.000     0.000     0.000         5
     Coronary Artery Disease      0.000     0.000     0.000         5
Generalized Anxiety Disorder      0.000     0.000     0.000         5
         High Blood Pressure      0.000     0.000     0.000         5
    Irritable Bowel Syndrome      0.000     0.000     0.000         5
   Major Depressive Disorder      0.000     0.000     0.000         5
              Osteoarthritis      0.000     0.000     0.000         5
             Type 2 Diabetes      0.000     0.000     0.000         5

                    accuracy                          0.100        50
                   macro avg      0.010     0.100     0.018        50
                weighted avg      0.010     0.10

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 1.000

Classification report:
                              precision    recall  f1-score   support

                      Asthma      1.000     1.000     1.000         6
           Chronic Back Pain      1.000     1.000     1.000         6
      Chronic Kidney Disease      1.000     1.000     1.000         6
     Coronary Artery Disease      1.000     1.000     1.000         6
Generalized Anxiety Disorder      1.000     1.000     1.000         6
         High Blood Pressure      1.000     1.000     1.000         6
    Irritable Bowel Syndrome      1.000     1.000     1.000         6
   Major Depressive Disorder      1.000     1.000     1.000         6
              Osteoarthritis      1.000     1.000     1.000         6
             Type 2 Diabetes      1.000     1.000     1.000         6

                    accuracy                          1.000        60
                   macro avg      1.000     1.000     1.000        60
                weighted avg      1.000     1.00

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.100

Classification report:
                              precision    recall  f1-score   support

                      Asthma      0.100     1.000     0.182         5
           Chronic Back Pain      0.000     0.000     0.000         5
      Chronic Kidney Disease      0.000     0.000     0.000         5
     Coronary Artery Disease      0.000     0.000     0.000         5
Generalized Anxiety Disorder      0.000     0.000     0.000         5
         High Blood Pressure      0.000     0.000     0.000         5
    Irritable Bowel Syndrome      0.000     0.000     0.000         5
   Major Depressive Disorder      0.000     0.000     0.000         5
              Osteoarthritis      0.000     0.000     0.000         5
             Type 2 Diabetes      0.000     0.000     0.000         5

                    accuracy                          0.100        50
                   macro avg      0.010     0.100     0.018        50
                weighted avg      0.010     0.10

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.983

Classification report:
                              precision    recall  f1-score   support

                      Asthma      1.000     0.833     0.909         6
           Chronic Back Pain      1.000     1.000     1.000         6
      Chronic Kidney Disease      1.000     1.000     1.000         6
     Coronary Artery Disease      1.000     1.000     1.000         6
Generalized Anxiety Disorder      1.000     1.000     1.000         6
         High Blood Pressure      1.000     1.000     1.000         6
    Irritable Bowel Syndrome      1.000     1.000     1.000         6
   Major Depressive Disorder      1.000     1.000     1.000         6
              Osteoarthritis      0.857     1.000     0.923         6
             Type 2 Diabetes      1.000     1.000     1.000         6

                    accuracy                          0.983        60
                   macro avg      0.986     0.983     0.983        60
                weighted avg      0.986     0.98

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.100

Classification report:
                              precision    recall  f1-score   support

                      Asthma      0.088     0.600     0.154         5
           Chronic Back Pain      0.000     0.000     0.000         5
      Chronic Kidney Disease      0.000     0.000     0.000         5
     Coronary Artery Disease      0.000     0.000     0.000         5
Generalized Anxiety Disorder      0.000     0.000     0.000         5
         High Blood Pressure      0.000     0.000     0.000         5
    Irritable Bowel Syndrome      0.000     0.000     0.000         5
   Major Depressive Disorder      0.000     0.000     0.000         5
              Osteoarthritis      0.222     0.400     0.286         5
             Type 2 Diabetes      0.000     0.000     0.000         5

                    accuracy                          0.100        50
                   macro avg      0.031     0.100     0.044        50
                weighted avg      0.031     0.10

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 1.000

Classification report:
                              precision    recall  f1-score   support

                      Asthma      1.000     1.000     1.000         6
           Chronic Back Pain      1.000     1.000     1.000         6
      Chronic Kidney Disease      1.000     1.000     1.000         6
     Coronary Artery Disease      1.000     1.000     1.000         6
Generalized Anxiety Disorder      1.000     1.000     1.000         6
         High Blood Pressure      1.000     1.000     1.000         6
    Irritable Bowel Syndrome      1.000     1.000     1.000         6
   Major Depressive Disorder      1.000     1.000     1.000         6
              Osteoarthritis      1.000     1.000     1.000         6
             Type 2 Diabetes      1.000     1.000     1.000         6

                    accuracy                          1.000        60
                   macro avg      1.000     1.000     1.000        60
                weighted avg      1.000     1.00

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.100

Classification report:
                              precision    recall  f1-score   support

                      Asthma      0.100     1.000     0.182         5
           Chronic Back Pain      0.000     0.000     0.000         5
      Chronic Kidney Disease      0.000     0.000     0.000         5
     Coronary Artery Disease      0.000     0.000     0.000         5
Generalized Anxiety Disorder      0.000     0.000     0.000         5
         High Blood Pressure      0.000     0.000     0.000         5
    Irritable Bowel Syndrome      0.000     0.000     0.000         5
   Major Depressive Disorder      0.000     0.000     0.000         5
              Osteoarthritis      0.000     0.000     0.000         5
             Type 2 Diabetes      0.000     0.000     0.000         5

                    accuracy                          0.100        50
                   macro avg      0.010     0.100     0.018        50
                weighted avg      0.010     0.10

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.983

Classification report:
                              precision    recall  f1-score   support

                      Asthma      1.000     0.833     0.909         6
           Chronic Back Pain      1.000     1.000     1.000         6
      Chronic Kidney Disease      1.000     1.000     1.000         6
     Coronary Artery Disease      1.000     1.000     1.000         6
Generalized Anxiety Disorder      1.000     1.000     1.000         6
         High Blood Pressure      1.000     1.000     1.000         6
    Irritable Bowel Syndrome      1.000     1.000     1.000         6
   Major Depressive Disorder      1.000     1.000     1.000         6
              Osteoarthritis      0.857     1.000     0.923         6
             Type 2 Diabetes      1.000     1.000     1.000         6

                    accuracy                          0.983        60
                   macro avg      0.986     0.983     0.983        60
                weighted avg      0.986     0.98

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.120

Classification report:
                              precision    recall  f1-score   support

                      Asthma      0.105     0.400     0.167         5
           Chronic Back Pain      0.000     0.000     0.000         5
      Chronic Kidney Disease      0.000     0.000     0.000         5
     Coronary Artery Disease      0.000     0.000     0.000         5
Generalized Anxiety Disorder      0.000     0.000     0.000         5
         High Blood Pressure      0.000     0.000     0.000         5
    Irritable Bowel Syndrome      0.000     0.000     0.000         5
   Major Depressive Disorder      0.000     0.000     0.000         5
              Osteoarthritis      0.143     0.800     0.242         5
             Type 2 Diabetes      0.000     0.000     0.000         5

                    accuracy                          0.120        50
                   macro avg      0.025     0.120     0.041        50
                weighted avg      0.025     0.12

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 1.000

Classification report:
                              precision    recall  f1-score   support

                      Asthma      1.000     1.000     1.000         6
           Chronic Back Pain      1.000     1.000     1.000         6
      Chronic Kidney Disease      1.000     1.000     1.000         6
     Coronary Artery Disease      1.000     1.000     1.000         6
Generalized Anxiety Disorder      1.000     1.000     1.000         6
         High Blood Pressure      1.000     1.000     1.000         6
    Irritable Bowel Syndrome      1.000     1.000     1.000         6
   Major Depressive Disorder      1.000     1.000     1.000         6
              Osteoarthritis      1.000     1.000     1.000         6
             Type 2 Diabetes      1.000     1.000     1.000         6

                    accuracy                          1.000        60
                   macro avg      1.000     1.000     1.000        60
                weighted avg      1.000     1.00

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Accuracy: 0.100

Classification report:
                              precision    recall  f1-score   support

                      Asthma      0.000     0.000     0.000         5
           Chronic Back Pain      0.000     0.000     0.000         5
      Chronic Kidney Disease      0.000     0.000     0.000         5
     Coronary Artery Disease      0.000     0.000     0.000         5
Generalized Anxiety Disorder      0.000     0.000     0.000         5
         High Blood Pressure      0.000     0.000     0.000         5
    Irritable Bowel Syndrome      0.000     0.000     0.000         5
   Major Depressive Disorder      0.000     0.000     0.000         5
              Osteoarthritis      0.100     1.000     0.182         5
             Type 2 Diabetes      0.000     0.000     0.000         5

                    accuracy                          0.100        50
                   macro avg      0.010     0.100     0.018        50
                weighted avg      0.010     0.10

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [None]:

# BERT-family experiments using QUESTION + ANSWER text

#   4 experiments per model:
#   Exp1: Train 80% orig QA (MedQuad), test 20% orig QA (MedQuad)
#   Exp2: Train 80% scrubbed QA, test 20% orig QA (MedQuad)
#   Exp3: Train FULL orig QA, test implicit questions
#   Exp4: Train FULL scrubbed QA, test scrubbed implicit questions

# Importing the required libraries
import pandas as pd
import numpy as np
import re

from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report

import torch
from torch.utils.data import Dataset

from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
    set_seed,
)

# Make runs reproducible (transformers.set_seed seeds the global RNGs)
set_seed(42)


# 1. Load MedQuad dataset

medquad_path = "medquad_300_focus10_realtopics.csv"
df = pd.read_csv(medquad_path)

print("Loaded medquad dataset:", df.shape)
print(df.head())

# Force the text and label columns to plain str values before any
# string processing below.
for _col in ("question", "answer", "focus_area"):
    df[_col] = df[_col].astype(str)

# 2. Build scrub patterns from focus_area labels
#    (scrubbing will be applied ONLY to QUESTION text)

focus_terms = df["focus_area"].str.lower().unique().tolist()
print("\nFocus / disease terms:", focus_terms)


def _phrase_and_words(term):
    """Yield the regex-escaped full phrase for *term*, then each escaped word.

    Underscores and hyphens in *term* are treated as spaces first.
    """
    phrase = term.replace("_", " ").replace("-", " ")
    yield re.escape(phrase)
    yield from (re.escape(word) for word in phrase.split())


# De-duplicate: several labels share words (e.g. "chronic", "disorder").
scrub_patterns = list(set(
    pattern for term in focus_terms for pattern in _phrase_and_words(term)
))
print("\nNumber of scrub patterns:", len(scrub_patterns))

def scrub_text(text: str, patterns) -> str:
    """Lower-case *text* and blank out every whole-word match of *patterns*.

    Each entry of *patterns* is a regex fragment (e.g. the output of
    ``re.escape``). It is wrapped in ``\\b`` word boundaries and every match
    is replaced by a single space. Finally, runs of whitespace are collapsed
    to one space and leading/trailing whitespace is removed.
    """
    scrubbed = text.lower()
    for fragment in patterns:
        scrubbed = re.sub(rf"\b{fragment}\b", " ", scrubbed)
    # split()/join() collapses whitespace runs and trims both ends
    return " ".join(scrubbed.split())

# 3. Build:
#    - scrubbed_question = question with disease terms removed
#    - qa_text           = question + answer
#    - qa_text_scrubbed  = scrubbed_question + answer
#
# NOTE(review): only the QUESTION is scrubbed. The printed sample answers
# still open with the condition name (e.g. "High blood pressure is ..."),
# so qa_text_scrubbed still carries the label in the answer part. Keep
# this in mind when interpreting the Exp2/Exp4 results.

df["scrubbed_question"] = df["question"].apply(lambda x: scrub_text(x, scrub_patterns))

df["qa_text"] = df["question"] + " " + df["answer"]
df["qa_text_scrubbed"] = df["scrubbed_question"] + " " + df["answer"]

print("\n=== Example original vs scrubbed QA text (scrub only question) ===")
# Iterate over head(5) rather than range(5) + .loc. This does not raise
# KeyError when the frame has fewer than 5 rows or a non-default index.
for _, example in df.head(5).iterrows():
    print("\nOriginal Q:", example["question"])
    print("Scrubbed Q:", example["scrubbed_question"])
    print("Answer    :", example["answer"])
    print("QA text   :", example["qa_text"])
    print("QA scrub  :", example["qa_text_scrubbed"])

# 4. Encode labels (focus_area strings -> integer class ids, sorted order)

label_encoder = LabelEncoder()
y = label_encoder.fit_transform(df["focus_area"].values)
num_labels = len(label_encoder.classes_)

print("\nLabel mapping:")
for class_id, class_name in enumerate(label_encoder.classes_):
    print(f"{class_id}: {class_name}")


# 5. Train/test split for MedQuad (Experiments 1 & 2)
#    The split is over row positions and is stratified by label, so Exp1
#    and Exp2 use exactly the same train/test rows. df.loc is given these
#    positions as labels, which assumes df keeps its default RangeIndex.

train_idx, test_idx = train_test_split(
    np.arange(len(df)),
    stratify=y,
    test_size=0.2,
    random_state=42,
)

# Exp1: original QA for training AND testing
X_train_exp1, y_train_exp1 = df.loc[train_idx, "qa_text"].values, y[train_idx]
X_test_exp1, y_test_exp1 = df.loc[test_idx, "qa_text"].values, y[test_idx]

# Exp2: scrubbed QA for training, original QA for testing
X_train_exp2, y_train_exp2 = df.loc[train_idx, "qa_text_scrubbed"].values, y[train_idx]
X_test_exp2, y_test_exp2 = df.loc[test_idx, "qa_text"].values, y[test_idx]


# 6. Full MedQuad data for Experiments 3 & 4
#    Every MedQuad row is used for training; the implicit questions built
#    further down are the test set for these two experiments.

y_train_exp3 = y
y_train_exp4 = y
X_train_exp3 = df["qa_text"].values           # full original QA
X_train_exp4 = df["qa_text_scrubbed"].values  # full scrubbed QA


# 7. Build implicit (symptom-style) dataset in code
#    (same style as classical models)
implicit_rows = []  # list of {"question": ..., "focus_area": ...} dicts


def add_implicit_example(question, focus):
    """Append one implicit *question* labelled with *focus* to implicit_rows."""
    row = dict(question=question, focus_area=focus)
    implicit_rows.append(row)

# Hand-written symptom-style questions, keyed by the exact focus_area label.
# A label missing from this table simply gets no implicit examples.
# NOTE(review): some questions contain words that are also scrub patterns
# (e.g. "high", "pain", "blood", "kidney"). Those words are removed in the
# scrubbed variant used by Exp4 but kept in the Exp3 test text.
_IMPLICIT_QUESTIONS = {
    "High Blood Pressure": [
        "After walking up a single flight of stairs my heart rate stays high and I feel pressure in my head for some time.",
        "At recent checkups the upper number on my reading is always high even when I sit calmly in the clinic.",
        "Sometimes when I am resting I feel my heart pounding and get mild headaches without any clear reason.",
        "I often wake up with slight blurred vision and a heavy feeling in my head in the morning.",
        "My home device usually shows high readings during the evening even when I feel normal.",
    ],
    "Type 2 Diabetes": [
        "Lately I feel thirsty all the time and need to use the washroom very frequently during the day.",
        "Small cuts on my skin are taking a long time to heal and I feel tired even after normal activity.",
        "I have recently gained weight around my stomach and often feel very hungry again soon after meals.",
        "Sometimes my vision gets slightly blurry especially in the evening after dinner.",
        "During my last routine blood test the sugar level was higher than normal but I had no major symptoms.",
    ],
    "Asthma": [
        "When I climb stairs or walk fast I start wheezing and feel tightness in my chest.",
        "I often wake up at night with coughing and a whistling sound when I breathe.",
        "Cold air and dust make it hard for me to take a deep breath and I feel breathless quickly.",
        "Sometimes after laughing hard I get a dry cough and shortness of breath.",
        "I feel my chest tightening when I am around strong smells or smoke even for a short time.",
    ],
    "Major Depressive Disorder": [
        "For the past few months I have lost interest in activities I used to enjoy and feel low most days.",
        "My sleep pattern is disturbed and I either sleep too much or struggle to fall asleep at all.",
        "I find it hard to focus on simple tasks and constantly feel guilty without a clear reason.",
        "Even small everyday tasks feel exhausting and I feel hopeless about the future.",
        "I have changes in appetite and my weight is fluctuating without intentionally dieting.",
    ],
    "Generalized Anxiety Disorder": [
        "I worry about many small things all day and find it difficult to control these thoughts.",
        "I often feel restless and my muscles feel tense even when nothing serious is happening.",
        "Before normal situations like meetings or phone calls my heart races and I feel very nervous.",
        "Sometimes I get stomach discomfort when I am stressed and my mind keeps imagining worst outcomes.",
        "I find it hard to relax and my mind keeps jumping from one worry to another.",
    ],
    "Chronic Back Pain": [
        "I have a dull ache in my lower back that has been present for several months.",
        "Sitting for long periods or lifting even light objects makes my back pain worse.",
        "Sometimes the pain spreads down into my legs when I bend forward.",
        "In the mornings my lower back feels stiff and it takes time before I can move comfortably.",
        "I feel a constant nagging pain in my back even when I am resting on a chair.",
    ],
    "Irritable Bowel Syndrome": [
        "I frequently have stomach cramps and alternating constipation and loose stools.",
        "After meals my abdomen becomes bloated and I pass gas more than usual.",
        "Stressful days seem to make my bowel discomfort and urgency worse.",
        "I often feel an urgent need to go to the washroom but sometimes only pass a small amount.",
        "Certain foods like spicy or greasy meals cause stomach pain and irregular motions for me.",
    ],
    "Osteoarthritis": [
        "My knee joints hurt when I climb stairs and feel stiff after sitting for a long time.",
        "In the morning my finger joints feel tight and painful for some time before loosening up.",
        "I hear a grinding sound in my knee when I walk and the joint sometimes swells slightly.",
        "Walking long distances has become uncomfortable because of pain in my hip joints.",
        "The pain in my joints worsens by the end of the day after regular activity.",
    ],
    "Chronic Kidney Disease": [
        "Recently my ankles and feet look swollen and I feel tired most of the time.",
        "I need to pass urine more often at night and sometimes notice it is foamy.",
        "I have a poor appetite and occasional nausea without any obvious stomach infection.",
        "During a routine check my blood test showed reduced kidney function and my blood pressure was high.",
        "I sometimes have muscle cramps and dry itchy skin along with tiredness.",
    ],
    "Coronary Artery Disease": [
        "When I walk fast or climb stairs I feel heaviness and tightness in the center of my chest.",
        "The discomfort in my chest sometimes spreads to my left arm and jaw during exertion.",
        "I quickly get short of breath during mild physical activity compared to earlier.",
        "Occasionally I feel pressure in my chest that eases when I rest for a few minutes.",
        "My doctor mentioned I have high cholesterol and a family history of heart problems.",
    ],
}

focus_names = label_encoder.classes_

# Emit rows grouped by label, in label-encoder class order, keeping the
# per-label question order from the table above.
for fa in focus_names:
    for implicit_q in _IMPLICIT_QUESTIONS.get(fa, ()):
        add_implicit_example(implicit_q, fa)

implicit_df = pd.DataFrame(implicit_rows)
print("\nLoaded implicit test dataset (created in code):", implicit_df.shape)
print(implicit_df.head())

implicit_df = implicit_df.astype({"question": str, "focus_area": str})

# Test inputs: raw questions for Exp3, question-scrubbed versions for Exp4
X_imp_original = implicit_df["question"].values
X_imp_scrubbed = np.array(
    list(map(lambda q: scrub_text(q, scrub_patterns), X_imp_original))
)

# Encode with the encoder fitted on MedQuad so the class ids line up
y_imp_text = implicit_df["focus_area"].values
y_imp = label_encoder.transform(y_imp_text)


# 8. Torch Dataset for Trainer

class TextClassificationDataset(Dataset):
    """Map-style torch Dataset pairing tokenizer output with integer labels.

    Parameters
    ----------
    encodings : mapping of str -> sequence
        Tokenizer output (e.g. ``input_ids``, ``attention_mask``). Each value
        is indexable by example position.
    labels : sequence of int
        Class id for each example; ``len(labels)`` defines the dataset size.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return one example as a dict of tensors, including ``"labels"``."""
        item = {key: torch.tensor(values[idx]) for key, values in self.encodings.items()}
        # Sequence-classification heads compute cross-entropy on class indices,
        # which expects int64 targets. Force torch.long so that labels coming
        # in as e.g. int32 arrays do not produce an int32 target tensor.
        item["labels"] = torch.tensor(self.labels[idx], dtype=torch.long)
        return item

# ------------------------------------------------------------
# 9. Define BERT-family models
#    alias (used in log headers / output dirs / summary) -> Hub model id

bert_models = {
    "bert-base-uncased": "bert-base-uncased",
    "PubMedBERT": "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext",
    "MiniLM_Sbert": "sentence-transformers/all-MiniLM-L6-v2",
}

print("\nBERT-style models defined:")
print("\n".join(f" - {alias}: {hub_id}" for alias, hub_id in bert_models.items()))

# ------------------------------------------------------------
# 10. Helper: run one BERT experiment

from sklearn.metrics import precision_recall_fscore_support

def compute_metrics(eval_pred):
    """Trainer metric hook: accuracy plus macro-averaged precision/recall/F1.

    ``eval_pred`` is unpacked as ``(logits, labels)``. The predicted class is
    the arg-max over the last axis of ``logits``. Precision/recall/F1 use
    ``zero_division=0``, so a class that is never predicted counts as 0
    instead of triggering a warning.
    """
    logits, labels = eval_pred
    predicted = np.argmax(logits, axis=-1)

    macro_p, macro_r, macro_f1, _support = precision_recall_fscore_support(
        labels, predicted, average="macro", zero_division=0
    )
    metrics = dict(
        accuracy=accuracy_score(labels, predicted),
        macro_precision=macro_p,
        macro_recall=macro_r,
        macro_f1=macro_f1,
    )
    return metrics

def run_bert_experiment(
    model_alias,
    model_name,
    train_texts,
    train_labels,
    test_texts,
    test_labels,
    exp_tag,
    max_length=256,
    seed=42,
):
    """Fine-tune one pretrained checkpoint and report accuracy on a test set.

    Parameters
    ----------
    model_alias : str
        Short name used in the log header and in the output directory name.
    model_name : str
        Model id/path passed to ``from_pretrained`` for tokenizer and model.
    train_texts, test_texts : iterable of str
        Raw input texts; converted to lists before tokenization.
    train_labels, test_labels : array-like of int
        Integer class ids produced by the module-level ``label_encoder``.
    exp_tag : str
        Experiment name. It is printed, and a sanitized copy goes into the
        output directory name.
    max_length : int, default 256
        Token truncation length.
    seed : int, default 42
        Seed applied before the classification head is created and passed to
        the Trainer. Each experiment is then reproducible on its own instead
        of depending on how many runs preceded it.

    Returns
    -------
    float
        Accuracy of the fine-tuned model on the test set.
    """
    import gc

    print(f"\n===== {model_alias} | {exp_tag} =====")

    # Re-seed per experiment. The classification head created below is
    # randomly initialised; with only a module-level seed its weights would
    # depend on the order in which experiments run.
    set_seed(seed)

    # 1) Tokenizer + tokenization (each split padded to its own longest item)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tok_kwargs = dict(padding=True, truncation=True, max_length=max_length)
    train_enc = tokenizer(list(train_texts), **tok_kwargs)
    test_enc = tokenizer(list(test_texts), **tok_kwargs)

    train_dataset = TextClassificationDataset(train_enc, train_labels)
    test_dataset = TextClassificationDataset(test_enc, test_labels)

    # 2) Model with a classification head sized to our label set
    id2label = {i: l for i, l in enumerate(label_encoder.classes_)}
    label2id = {l: i for i, l in id2label.items()}

    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=num_labels,
        id2label=id2label,
        label2id=label2id,
    )

    # 3) Training args – simplified for older transformers versions
    safe_exp = exp_tag.replace(" ", "_").replace("/", "_")
    output_dir = f"{model_alias}_{safe_exp}_out"

    training_args = TrainingArguments(
        output_dir=output_dir,
        learning_rate=2e-5,
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        num_train_epochs=3,
        seed=seed,
        report_to="none",  # completely disable wandb popups
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        compute_metrics=compute_metrics,
    )

    # 4) Train, then predict on the test split
    trainer.train()

    eval_output = trainer.predict(test_dataset)
    preds = np.argmax(eval_output.predictions, axis=-1)
    acc = accuracy_score(test_labels, preds)

    # Classification report with text labels
    true_text = label_encoder.inverse_transform(test_labels)
    pred_text = label_encoder.inverse_transform(preds)

    print(f"\nAccuracy: {acc:.4f}")
    print("\nClassification report:")
    # zero_division=0 matches compute_metrics and silences the
    # UndefinedMetricWarning raised when a class is never predicted (common on
    # the implicit test set). labels= keeps every class row in the report.
    print(classification_report(
        true_text,
        pred_text,
        labels=label_encoder.classes_,
        digits=3,
        zero_division=0,
    ))

    # 5) Release this run's model before the next experiment. Otherwise
    #    accelerator memory can accumulate over all (model x experiment) runs.
    del trainer, model
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return acc



# 11. Run all experiments for all BERT models
#     Each spec: (exp_tag, summary column, train X, train y, test X, test y)

experiments = [
    # Exp1: Train 80% orig QA, test 20% orig QA (MedQuad)
    ("Exp1_OrigQA_80_20_MedQuad", "Exp1_OrigQA_80_20_acc",
     X_train_exp1, y_train_exp1, X_test_exp1, y_test_exp1),
    # Exp2: Train 80% scrubbed QA, test 20% orig QA (MedQuad)
    ("Exp2_ScrubQA_80_20_MedQuad", "Exp2_ScrubQA_80_20_acc",
     X_train_exp2, y_train_exp2, X_test_exp2, y_test_exp2),
    # Exp3: Train FULL orig QA, test implicit questions
    ("Exp3_OrigAll_Implicit", "Exp3_OrigAll_Implicit_acc",
     X_train_exp3, y_train_exp3, X_imp_original, y_imp),
    # Exp4: Train FULL scrubbed QA, test scrubbed implicit
    ("Exp4_ScrubAll_Implicit", "Exp4_ScrubAll_Implicit_acc",
     X_train_exp4, y_train_exp4, X_imp_scrubbed, y_imp),
]

summary_rows = []
banner = "=" * 90

for model_alias, model_name in bert_models.items():
    print("\n" + banner)
    print(f"RUNNING MODEL: {model_alias} ({model_name})")
    print(banner)

    # One summary row per model; it is appended only after all runs finish
    row = {"model": model_alias}
    for exp_tag, column, X_tr, y_tr, X_te, y_te in experiments:
        row[column] = run_bert_experiment(
            model_alias,
            model_name,
            X_tr,
            y_tr,
            X_te,
            y_te,
            exp_tag=exp_tag,
        )
    summary_rows.append(row)

# 12. Show accuracy summary (one row per model, one column per experiment)

summary_df = pd.DataFrame(summary_rows)
print("\n==================== SUMMARY OF ALL BERT EXPERIMENTS ====================")
print(summary_df.to_string(index=False))


Loaded medquad dataset: (300, 5)
                                            question  \
0                       What is high blood pressure?   
1                   What causes high blood pressure?   
2  What are the common symptoms of high blood pre...   
3      How is high blood pressure usually diagnosed?   
4       How is high blood pressure commonly treated?   

                                              answer  \
0  High blood pressure is a health condition that...   
1  High blood pressure usually develops from a mi...   
2  Common symptoms of high blood pressure can inc...   
3  High blood pressure is usually diagnosed by co...   
4  Treatment for high blood pressure often involv...   

                      source           focus_area  Emotion  
0  synthetic_medquad_focus10  High Blood Pressure  neutral  
1  synthetic_medquad_focus10  High Blood Pressure  neutral  
2  synthetic_medquad_focus10  High Blood Pressure  neutral  
3  synthetic_medquad_focus10  High Blood Pressure

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss



Accuracy: 1.0000

Classification report:
                              precision    recall  f1-score   support

                      Asthma      1.000     1.000     1.000         6
           Chronic Back Pain      1.000     1.000     1.000         6
      Chronic Kidney Disease      1.000     1.000     1.000         6
     Coronary Artery Disease      1.000     1.000     1.000         6
Generalized Anxiety Disorder      1.000     1.000     1.000         6
         High Blood Pressure      1.000     1.000     1.000         6
    Irritable Bowel Syndrome      1.000     1.000     1.000         6
   Major Depressive Disorder      1.000     1.000     1.000         6
              Osteoarthritis      1.000     1.000     1.000         6
             Type 2 Diabetes      1.000     1.000     1.000         6

                    accuracy                          1.000        60
                   macro avg      1.000     1.000     1.000        60
                weighted avg      1.000     1.

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss



Accuracy: 0.7833

Classification report:
                              precision    recall  f1-score   support

                      Asthma      0.600     1.000     0.750         6
           Chronic Back Pain      1.000     1.000     1.000         6
      Chronic Kidney Disease      0.500     1.000     0.667         6
     Coronary Artery Disease      0.000     0.000     0.000         6
Generalized Anxiety Disorder      1.000     0.333     0.500         6
         High Blood Pressure      1.000     1.000     1.000         6
    Irritable Bowel Syndrome      1.000     0.500     0.667         6
   Major Depressive Disorder      1.000     1.000     1.000         6
              Osteoarthritis      0.667     1.000     0.800         6
             Type 2 Diabetes      1.000     1.000     1.000         6

                    accuracy                          0.783        60
                   macro avg      0.777     0.783     0.738        60
                weighted avg      0.777     0.

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss



Accuracy: 0.1000

Classification report:
                              precision    recall  f1-score   support

                      Asthma      0.000     0.000     0.000         5
           Chronic Back Pain      0.000     0.000     0.000         5
      Chronic Kidney Disease      0.100     1.000     0.182         5
     Coronary Artery Disease      0.000     0.000     0.000         5
Generalized Anxiety Disorder      0.000     0.000     0.000         5
         High Blood Pressure      0.000     0.000     0.000         5
    Irritable Bowel Syndrome      0.000     0.000     0.000         5
   Major Depressive Disorder      0.000     0.000     0.000         5
              Osteoarthritis      0.000     0.000     0.000         5
             Type 2 Diabetes      0.000     0.000     0.000         5

                    accuracy                          0.100        50
                   macro avg      0.010     0.100     0.018        50
                weighted avg      0.010     0.

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss



Accuracy: 0.1000

Classification report:
                              precision    recall  f1-score   support

                      Asthma      0.087     0.400     0.143         5
           Chronic Back Pain      0.000     0.000     0.000         5
      Chronic Kidney Disease      0.136     0.600     0.222         5
     Coronary Artery Disease      0.000     0.000     0.000         5
Generalized Anxiety Disorder      0.000     0.000     0.000         5
         High Blood Pressure      0.000     0.000     0.000         5
    Irritable Bowel Syndrome      0.000     0.000     0.000         5
   Major Depressive Disorder      0.000     0.000     0.000         5
              Osteoarthritis      0.000     0.000     0.000         5
             Type 2 Diabetes      0.000     0.000     0.000         5

                    accuracy                          0.100        50
                   macro avg      0.022     0.100     0.037        50
                weighted avg      0.022     0.

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


tokenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/385 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss



Accuracy: 1.0000

Classification report:
                              precision    recall  f1-score   support

                      Asthma      1.000     1.000     1.000         6
           Chronic Back Pain      1.000     1.000     1.000         6
      Chronic Kidney Disease      1.000     1.000     1.000         6
     Coronary Artery Disease      1.000     1.000     1.000         6
Generalized Anxiety Disorder      1.000     1.000     1.000         6
         High Blood Pressure      1.000     1.000     1.000         6
    Irritable Bowel Syndrome      1.000     1.000     1.000         6
   Major Depressive Disorder      1.000     1.000     1.000         6
              Osteoarthritis      1.000     1.000     1.000         6
             Type 2 Diabetes      1.000     1.000     1.000         6

                    accuracy                          1.000        60
                   macro avg      1.000     1.000     1.000        60
                weighted avg      1.000     1.

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss



Accuracy: 1.0000

Classification report:
                              precision    recall  f1-score   support

                      Asthma      1.000     1.000     1.000         6
           Chronic Back Pain      1.000     1.000     1.000         6
      Chronic Kidney Disease      1.000     1.000     1.000         6
     Coronary Artery Disease      1.000     1.000     1.000         6
Generalized Anxiety Disorder      1.000     1.000     1.000         6
         High Blood Pressure      1.000     1.000     1.000         6
    Irritable Bowel Syndrome      1.000     1.000     1.000         6
   Major Depressive Disorder      1.000     1.000     1.000         6
              Osteoarthritis      1.000     1.000     1.000         6
             Type 2 Diabetes      1.000     1.000     1.000         6

                    accuracy                          1.000        60
                   macro avg      1.000     1.000     1.000        60
                weighted avg      1.000     1.

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss



Accuracy: 0.2800

Classification report:
                              precision    recall  f1-score   support

                      Asthma      0.000     0.000     0.000         5
           Chronic Back Pain      0.227     1.000     0.370         5
      Chronic Kidney Disease      0.000     0.000     0.000         5
     Coronary Artery Disease      1.000     0.200     0.333         5
Generalized Anxiety Disorder      0.000     0.000     0.000         5
         High Blood Pressure      0.214     0.600     0.316         5
    Irritable Bowel Syndrome      0.800     0.800     0.800         5
   Major Depressive Disorder      0.000     0.000     0.000         5
              Osteoarthritis      0.000     0.000     0.000         5
             Type 2 Diabetes      1.000     0.200     0.333         5

                    accuracy                          0.280        50
                   macro avg      0.324     0.280     0.215        50
                weighted avg      0.324     0.

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss



Accuracy: 0.2200

Classification report:
                              precision    recall  f1-score   support

                      Asthma      0.000     0.000     0.000         5
           Chronic Back Pain      0.000     0.000     0.000         5
      Chronic Kidney Disease      0.000     0.000     0.000         5
     Coronary Artery Disease      1.000     0.400     0.571         5
Generalized Anxiety Disorder      0.000     0.000     0.000         5
         High Blood Pressure      0.167     0.400     0.235         5
    Irritable Bowel Syndrome      1.000     0.600     0.750         5
   Major Depressive Disorder      0.154     0.400     0.222         5
              Osteoarthritis      1.000     0.400     0.571         5
             Type 2 Diabetes      0.000     0.000     0.000         5

                    accuracy                          0.220        50
                   macro avg      0.332     0.220     0.235        50
                weighted avg      0.332     0.

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at sentence-transformers/all-MiniLM-L6-v2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss



Accuracy: 1.0000

Classification report:
                              precision    recall  f1-score   support

                      Asthma      1.000     1.000     1.000         6
           Chronic Back Pain      1.000     1.000     1.000         6
      Chronic Kidney Disease      1.000     1.000     1.000         6
     Coronary Artery Disease      1.000     1.000     1.000         6
Generalized Anxiety Disorder      1.000     1.000     1.000         6
         High Blood Pressure      1.000     1.000     1.000         6
    Irritable Bowel Syndrome      1.000     1.000     1.000         6
   Major Depressive Disorder      1.000     1.000     1.000         6
              Osteoarthritis      1.000     1.000     1.000         6
             Type 2 Diabetes      1.000     1.000     1.000         6

                    accuracy                          1.000        60
                   macro avg      1.000     1.000     1.000        60
                weighted avg      1.000     1.

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at sentence-transformers/all-MiniLM-L6-v2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss



Accuracy: 1.0000

Classification report:
                              precision    recall  f1-score   support

                      Asthma      1.000     1.000     1.000         6
           Chronic Back Pain      1.000     1.000     1.000         6
      Chronic Kidney Disease      1.000     1.000     1.000         6
     Coronary Artery Disease      1.000     1.000     1.000         6
Generalized Anxiety Disorder      1.000     1.000     1.000         6
         High Blood Pressure      1.000     1.000     1.000         6
    Irritable Bowel Syndrome      1.000     1.000     1.000         6
   Major Depressive Disorder      1.000     1.000     1.000         6
              Osteoarthritis      1.000     1.000     1.000         6
             Type 2 Diabetes      1.000     1.000     1.000         6

                    accuracy                          1.000        60
                   macro avg      1.000     1.000     1.000        60
                weighted avg      1.000     1.

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at sentence-transformers/all-MiniLM-L6-v2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss



Accuracy: 0.5800

Classification report:
                              precision    recall  f1-score   support

                      Asthma      0.833     1.000     0.909         5
           Chronic Back Pain      0.556     1.000     0.714         5
      Chronic Kidney Disease      0.500     0.400     0.444         5
     Coronary Artery Disease      0.600     0.600     0.600         5
Generalized Anxiety Disorder      0.200     0.200     0.200         5
         High Blood Pressure      1.000     0.400     0.571         5
    Irritable Bowel Syndrome      0.556     1.000     0.714         5
   Major Depressive Disorder      0.000     0.000     0.000         5
              Osteoarthritis      0.600     0.600     0.600         5
             Type 2 Diabetes      0.750     0.600     0.667         5

                    accuracy                          0.580        50
                   macro avg      0.559     0.580     0.542        50
                weighted avg      0.559     0.

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at sentence-transformers/all-MiniLM-L6-v2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Step,Training Loss



Accuracy: 0.5200

Classification report:
                              precision    recall  f1-score   support

                      Asthma      0.625     1.000     0.769         5
           Chronic Back Pain      0.429     0.600     0.500         5
      Chronic Kidney Disease      0.250     0.200     0.222         5
     Coronary Artery Disease      0.600     0.600     0.600         5
Generalized Anxiety Disorder      0.500     0.400     0.444         5
         High Blood Pressure      1.000     0.200     0.333         5
    Irritable Bowel Syndrome      0.571     0.800     0.667         5
   Major Depressive Disorder      0.333     0.200     0.250         5
              Osteoarthritis      0.571     0.800     0.667         5
             Type 2 Diabetes      0.500     0.400     0.444         5

                    accuracy                          0.520        50
                   macro avg      0.538     0.520     0.490        50
                weighted avg      0.538     0.

### HuggingFace

In [None]:

# Build MiniLM-SBERT label centroids for MedQuad
# and save files for HuggingFace app.py


!pip install -q sentence-transformers
# Importing the required libraries
import os
import json
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
from sentence_transformers import SentenceTransformer

from sklearn.preprocessing import LabelEncoder

# --------------------------------------------
# 1. Config

MEDQUAD_PATH = "medquad_300_focus10_realtopics.csv"      # CSV file
MODEL_NAME   = "sentence-transformers/all-MiniLM-L6-v2"  # best model
EXPORT_DIR   = "minilm_sbert_encoder_artifacts"          # folder to save embeddings

os.makedirs(EXPORT_DIR, exist_ok=True)

print("Using model:", MODEL_NAME)
print("Saving artifacts to:", EXPORT_DIR)

# 2. Load MedQuad and build QA text

df = pd.read_csv(MEDQUAD_PATH)

print("Loaded MedQuad:", df.shape)
print(df.head())

df["question"]   = df["question"].astype(str)
df["answer"]     = df["answer"].astype(str)
df["focus_area"] = df["focus_area"].astype(str)

# Use QUESTION + ANSWER,
df["qa_text"] = df["question"] + " " + df["answer"]


# 3. Encode labels
label_encoder = LabelEncoder()
y = label_encoder.fit_transform(df["focus_area"].values)

label_names = list(label_encoder.classes_)
num_labels  = len(label_names)

print("\nLabels:")
for i, name in enumerate(label_names):
    print(f"{i}: {name}")
print("Total labels:", num_labels)


# 4. Load MiniLM-SBERT encoder

device = "cuda" if torch.cuda.is_available() else "cpu"
print("\nLoading SentenceTransformer on", device)

encoder = SentenceTransformer(MODEL_NAME, device=device)


# 5. Compute embeddings for all QA texts

texts = df["qa_text"].tolist()

print("\nEncoding all QA examples...")
embeddings = encoder.encode(
    texts,
    batch_size=32,
    convert_to_numpy=True,
    show_progress_bar=True,
    normalize_embeddings=True,  # cosine similarity friendly
)

print("Embeddings shape:", embeddings.shape)

# --------------------------------------------
# 6. Compute per-label centroids

centroids = np.zeros((num_labels, embeddings.shape[1]), dtype="float32")

print("\nComputing label centroids...")
for idx, label_name in enumerate(label_names):
    mask = (df["focus_area"].values == label_name)
    label_vecs = embeddings[mask]

    if len(label_vecs) == 0:
        raise ValueError(f"No examples found for label: {label_name}")

    centroids[idx] = label_vecs.mean(axis=0)

print("Centroids shape:", centroids.shape)

# --------------------------------------------
# 7. Save centroids + label names

centroids_path = os.path.join(EXPORT_DIR, "label_centroids.npy")
labels_path    = os.path.join(EXPORT_DIR, "label_names.json")

np.save(centroids_path, centroids)
with open(labels_path, "w") as f:
    json.dump(label_names, f, indent=2)

print("\nSaved:")
print(" -", centroids_path)
print(" -", labels_path)

# --------------------------------------------

# --------------------------------------------
MODEL_EXPORT_DIR = os.path.join(EXPORT_DIR, "minilm_sbert_encoder")

print("\nSaving MiniLM-SBERT model + tokenizer to:", MODEL_EXPORT_DIR)
encoder.save(MODEL_EXPORT_DIR)

print("\nDone! Contents of EXPORT_DIR:")
!ls -R $EXPORT_DIR


Using model: sentence-transformers/all-MiniLM-L6-v2
Saving artifacts to: minilm_sbert_encoder_artifacts
Loaded MedQuad: (300, 5)
                                            question  \
0                       What is high blood pressure?   
1                   What causes high blood pressure?   
2  What are the common symptoms of high blood pre...   
3      How is high blood pressure usually diagnosed?   
4       How is high blood pressure commonly treated?   

                                              answer  \
0  High blood pressure is a health condition that...   
1  High blood pressure usually develops from a mi...   
2  Common symptoms of high blood pressure can inc...   
3  High blood pressure is usually diagnosed by co...   
4  Treatment for high blood pressure often involv...   

                      source           focus_area  Emotion  
0  synthetic_medquad_focus10  High Blood Pressure  neutral  
1  synthetic_medquad_focus10  High Blood Pressure  neutral  
2  synthetic_m

Batches:   0%|          | 0/10 [00:00<?, ?it/s]

Embeddings shape: (300, 384)

Computing label centroids...
Centroids shape: (10, 384)

Saved:
 - minilm_sbert_encoder_artifacts/label_centroids.npy
 - minilm_sbert_encoder_artifacts/label_names.json

Saving MiniLM-SBERT model + tokenizer to: minilm_sbert_encoder_artifacts/minilm_sbert_encoder

Done! Contents of EXPORT_DIR:
minilm_sbert_encoder_artifacts:
label_centroids.npy  label_names.json  minilm_sbert_encoder

minilm_sbert_encoder_artifacts/minilm_sbert_encoder:
1_Pooling			   README.md
2_Normalize			   sentence_bert_config.json
config.json			   special_tokens_map.json
config_sentence_transformers.json  tokenizer_config.json
model.safetensors		   tokenizer.json
modules.json			   vocab.txt

minilm_sbert_encoder_artifacts/minilm_sbert_encoder/1_Pooling:
config.json

minilm_sbert_encoder_artifacts/minilm_sbert_encoder/2_Normalize:


In [None]:
# List the current working directory recursively to confirm the exported
# artifacts (CSV, artifacts folder, encoder files) are where we expect them.
!ls -R


.:
medquad_300_focus10_realtopics.csv  minilm_sbert_encoder.zip
minilm_sbert_encoder_artifacts	    sample_data

./minilm_sbert_encoder_artifacts:
label_centroids.npy  label_names.json  minilm_sbert_encoder

./minilm_sbert_encoder_artifacts/minilm_sbert_encoder:
1_Pooling			   README.md
2_Normalize			   sentence_bert_config.json
config.json			   special_tokens_map.json
config_sentence_transformers.json  tokenizer_config.json
model.safetensors		   tokenizer.json
modules.json			   vocab.txt

./minilm_sbert_encoder_artifacts/minilm_sbert_encoder/1_Pooling:
config.json

./minilm_sbert_encoder_artifacts/minilm_sbert_encoder/2_Normalize:

./sample_data:
anscombe.json		      mnist_test.csv
california_housing_test.csv   mnist_train_small.csv
california_housing_train.csv  README.md


In [None]:
import os
# Display only the sub-directories of the current working directory;
# plain files (e.g. the CSV and zip archives) are filtered out.
list(filter(os.path.isdir, os.listdir()))


['.config', 'minilm_sbert_encoder_artifacts', 'sample_data']

In [None]:
import shutil
from google.colab import files

# Bundle the exported artifacts folder into one zip file and hand it to
# Colab's browser-download helper.
FOLDER_NAME = "minilm_sbert_encoder_artifacts"
ZIP_NAME = f"{FOLDER_NAME}.zip"

# base_name -> writes <FOLDER_NAME>.zip in the current directory;
# root_dir  -> the folder's contents sit at the top level of the archive.
shutil.make_archive(base_name=FOLDER_NAME, format="zip", root_dir=FOLDER_NAME)

files.download(ZIP_NAME)


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>