## TEXT MODEL TRAINING

In [6]:
import os
import re
import numpy as np
import pandas as pd
from tqdm import tqdm

TEXT_ROOT = "../dataset/adress"   # <--- your .cha dataset root
TASKS = ["cookie", "fluency", "recall", "sentence"]

os.makedirs("../models", exist_ok=True)


In [7]:
def extract_patient_text(cha_path: str, speaker_tag="*PAR:") -> str:
    lines = []
    with open(cha_path, "r", encoding="utf-8", errors="ignore") as f:
        for line in f:
            if line.startswith(speaker_tag):
                # CHAT lines often contain tabs; keep right side if present
                content = line.split("\t", 1)[-1].strip()
                lines.append(content)

    text = " ".join(lines).lower()
    # light cleanup
    text = re.sub(r"\[.*?\]", " ", text)     # remove bracket tags
    text = re.sub(r"[^a-z\s']", " ", text)   # keep letters/apostrophes
    text = re.sub(r"\s+", " ", text).strip()
    return text


## Load transcripts into a dataframe

In [8]:
records = []

for label in ["Control", "Dementia"]:
    for task in TASKS:
        folder = os.path.join(TEXT_ROOT, label, task)
        if not os.path.exists(folder):
            continue

        for f in os.listdir(folder):
            if f.lower().endswith(".cha"):
                cha_path = os.path.join(folder, f)
                text = extract_patient_text(cha_path)

                if len(text) < 10:   # skip empty/too short
                    continue

                records.append({
                    "id": os.path.splitext(f)[0],   # e.g., "332-0" or "S001"
                    "task": task,
                    "label": label,
                    "text": text
                })

df = pd.DataFrame(records)
df.head(), df.shape


(      id    task    label                                               text
 0  002-0  cookie  Control  the scene is in the in the kitchen the mother ...
 1  002-1  cookie  Control  oh i see the sink is running over i see the st...
 2  002-2  cookie  Control  um a boy and a girl are in the kitchen with th...
 3  002-3  cookie  Control  okay it was summertime and mother and the chil...
 4  006-2  cookie  Control  clears throat wait un til i put my glasses on ...,
 (1287, 4))

### TF-IDF + PCA features

In [9]:
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.decomposition import PCA

texts = df["text"].tolist()

tfidf = TfidfVectorizer(
    max_features=300,
    ngram_range=(1, 2),
    stop_words="english"
)

X_tfidf = tfidf.fit_transform(texts).toarray()

# PCA (choose 30; adjust if needed)
pca = PCA(n_components=30, random_state=42)
X_text = pca.fit_transform(X_tfidf)

print("TFIDF shape:", X_tfidf.shape)
print("PCA shape:", X_text.shape)


TFIDF shape: (1287, 300)
PCA shape: (1287, 30)


### Encode labels + train/test split

In [10]:
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split

le = LabelEncoder()
y = le.fit_transform(df["label"].values)  # Control=0, Dementia=1

X_train, X_test, y_train, y_test = train_test_split(
    X_text, y, test_size=0.2, random_state=42, stratify=y
)

le.classes_


array(['Control', 'Dementia'], dtype=object)

## Train multiple models + TRAIN/TEST metrics table

In [11]:
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.ensemble import RandomForestClassifier
import joblib

def compute_metrics(y_true, y_pred):
    return {
        "Accuracy": accuracy_score(y_true, y_pred),
        "Precision": precision_score(y_true, y_pred, zero_division=0),
        "Recall": recall_score(y_true, y_pred, zero_division=0),
        "F1": f1_score(y_true, y_pred, zero_division=0)
    }

models = {
    "LogisticRegression": LogisticRegression(max_iter=2000, class_weight="balanced"),
    "SVM_RBF": SVC(kernel="rbf", probability=True, class_weight="balanced"),
    "RandomForest": RandomForestClassifier(
        n_estimators=300, random_state=42,
        class_weight="balanced",
        max_depth=12, min_samples_leaf=4
    )
}

results = []
trained = {}

for name, model in models.items():
    print(f"\nüöÄ Training: {name}")
    model.fit(X_train, y_train)

    pred_train = model.predict(X_train)
    pred_test  = model.predict(X_test)

    m_train = compute_metrics(y_train, pred_train)
    m_test  = compute_metrics(y_test, pred_test)

    results.append({
        "Model": name,
        "Train_Accuracy": m_train["Accuracy"],
        "Train_Precision": m_train["Precision"],
        "Train_Recall": m_train["Recall"],
        "Train_F1": m_train["F1"],
        "Test_Accuracy": m_test["Accuracy"],
        "Test_Precision": m_test["Precision"],
        "Test_Recall": m_test["Recall"],
        "Test_F1": m_test["F1"],
    })

    trained[name] = model
    print("üìä Test report:")
    print(classification_report(y_test, pred_test, target_names=le.classes_))

results_df = pd.DataFrame(results).sort_values("Test_F1", ascending=False)
results_df



üöÄ Training: LogisticRegression
üìä Test report:
              precision    recall  f1-score   support

     Control       0.57      0.92      0.71        50
    Dementia       0.98      0.84      0.90       208

    accuracy                           0.85       258
   macro avg       0.78      0.88      0.80       258
weighted avg       0.90      0.85      0.86       258


üöÄ Training: SVM_RBF
üìä Test report:
              precision    recall  f1-score   support

     Control       0.60      0.90      0.72        50
    Dementia       0.97      0.86      0.91       208

    accuracy                           0.86       258
   macro avg       0.79      0.88      0.82       258
weighted avg       0.90      0.86      0.87       258


üöÄ Training: RandomForest
üìä Test report:
              precision    recall  f1-score   support

     Control       0.74      0.80      0.77        50
    Dementia       0.95      0.93      0.94       208

    accuracy                           0

Unnamed: 0,Model,Train_Accuracy,Train_Precision,Train_Recall,Train_F1,Test_Accuracy,Test_Precision,Test_Recall,Test_F1
2,RandomForest,0.968902,0.996278,0.965144,0.980464,0.906977,0.95098,0.932692,0.941748
1,SVM_RBF,0.898931,0.995913,0.878606,0.933589,0.864341,0.972678,0.855769,0.910486
0,LogisticRegression,0.88241,0.988996,0.864183,0.922386,0.852713,0.977528,0.836538,0.901554


### Select best model

In [13]:
best_name = results_df.iloc[0]["Model"]
best_model = trained[best_name]

print("üèÜ Best TEXT model:", best_name)

joblib.dump(best_model, "../models/best_text_model.pkl")
joblib.dump(tfidf, "../models/text_tfidf.pkl")
joblib.dump(pca, "../models/text_pca.pkl")
joblib.dump(le, "../models/text_label_encoder.pkl")

results_df.to_csv("../models/text_model_results.csv", index=False)
print("‚úÖ Saved: best_text_model + tfidf + pca + label encoder + results csv")


üèÜ Best TEXT model: RandomForest
‚úÖ Saved: best_text_model + tfidf + pca + label encoder + results csv
