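# Training, Saving, and Fine-Tuning a Logistic Regression Model for Medical Text Classification

This notebook generates a synthetic set of short medical sentences and labels each one as emergency (1) or non-emergency (0) with a keyword rule. It then builds TF-IDF plus keyword-count features and trains a logistic regression classifier with `SGDClassifier` mini-batch updates. The model is checkpointed after the initial training epochs and again after additional "fine-tuning" epochs.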

## Imports

In [22]:
import os
import re
import random
import joblib
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import accuracy_score, roc_auc_score, log_loss
from scipy.sparse import hstack, csr_matrix

## Define the medical terms

In [14]:
# Keyword vocabulary: category name -> list of lowercase terms/phrases.
# Used by sample_terms() to build synthetic sentences and by
# count_category_hits() to derive one `count_<category>` feature per key, so the
# key order fixes the order of those feature columns.
# 'medical_specialty' is never sampled by build_sentence(); it only feeds the
# count features.
# NOTE(review): some terms appear in more than one category ('diagnosis' in
# conditions + clinical; 'emergency' and 'cardiology' in conditions +
# medical_specialty), so a single occurrence raises several counts. Phrases also
# overlap single words ('chest pain' vs 'pain' / 'chest', 'heart rate' vs
# 'heart'), and count_category_hits counts each match separately.
# NOTE(review): 'emergency' is also the trigger word in label_sentence(), so the
# features can recover the labelling rule almost directly.
MEDICAL_TERMS = {
    'symptoms': ['pain', 'chest pain', 'shortness of breath', 'nausea', 'vomiting', 'dizziness', 'headache', 'fever', 'fatigue', 'weakness'],
    'conditions': ['hypertension', 'diabetes', 'cardiac', 'cardiology', 'emergency', 'diagnosis', 'syndrome', 'disease', 'disorder'],
    'demographics': ['patient', 'male', 'female', 'years old', 'age', 'elderly', 'adult', 'pediatric'],
    'clinical': ['diagnosis', 'treatment', 'symptoms', 'history', 'complaint', 'medication', 'therapy', 'procedure', 'examination'],
    'anatomy': ['heart', 'lung', 'brain', 'kidney', 'liver', 'stomach', 'chest', 'abdomen', 'extremities'],
    'medical_specialty': ['cardiology', 'neurology', 'gastroenterology', 'emergency', 'internal medicine', 'surgery'],
    'vitals': ['blood pressure', 'heart rate', 'temperature', 'respiratory rate', 'oxygen saturation', 'pulse'],
    'assessments': ['workup', 'evaluation', 'assessment', 'monitoring', 'follow-up', 'consultation']
}

## Reproducibility

In [15]:
SEED = 42
random.seed(SEED)     # stdlib RNG: drives sample_terms() / build_sentence()
np.random.seed(SEED)  # global NumPy RNG: drives the mini-batch shuffling

## Data generation

In [16]:
def sample_terms(category, k=2):
    """Draw up to ``k`` terms, without replacement, from ``MEDICAL_TERMS[category]``.

    If ``k`` exceeds the size of the category, every term is returned in
    random order rather than raising.
    """
    pool = MEDICAL_TERMS[category]
    n_pick = k if k < len(pool) else len(pool)
    return random.sample(pool, n_pick)

def build_sentence():
    """Compose one synthetic clinical sentence from randomly sampled vocabulary.

    Always draws one demographics, one clinical and one anatomy term, plus two
    symptom terms. Then it independently adds a condition (30% chance), a vital
    sign (25%) and an assessment term (25%). The terms are shuffled, joined with
    spaces, and followed by a fixed-template second sentence.
    """
    required = [('demographics', 1), ('clinical', 1), ('anatomy', 1), ('symptoms', 2)]
    optional = [('conditions', 0.3), ('vitals', 0.25), ('assessments', 0.25)]

    # The order of random draws below is significant for reproducibility under
    # the global seed: required samples, then each optional coin flip, then the
    # shuffle, then the two template choices.
    fragments = []
    for category, count in required:
        fragments.extend(sample_terms(category, k=count))
    for category, prob in optional:
        if random.random() < prob:
            fragments.extend(sample_terms(category, k=1))
    random.shuffle(fragments)

    subject = random.choice(['patient', 'case'])
    action = random.choice(['evaluation', 'monitoring', 'assessment', 'consultation'])
    return f"{' '.join(fragments)}. The {subject} requires {action}."

def label_sentence(text):
    """Return 1 (emergency) or 0 (non-emergency) for ``text`` via a keyword rule.

    Matching is a case-insensitive substring test. A sentence is an emergency if
    it mentions "emergency", or if it mentions both "chest pain" and
    "shortness of breath".
    """
    lowered = text.lower()
    has_emergency_word = 'emergency' in lowered
    has_cardiac_pair = all(phrase in lowered for phrase in ('chest pain', 'shortness of breath'))
    return int(has_emergency_word or has_cardiac_pair)

## Create the synthetic dataset

In [17]:
# Generate every sentence first, then label them with the deterministic rule.
N = 2000  # number of synthetic sentences
texts = []
for _ in range(N):
    texts.append(build_sentence())
labels = list(map(label_sentence, texts))
df = pd.DataFrame(dict(text=texts, label=labels))

## Train–test split

In [18]:
# Stratify on the label so train and test keep the same class proportions.
text_col = df["text"]
label_col = df["label"]
X_train, X_test, y_train, y_test = train_test_split(
    text_col,
    label_col,
    stratify=label_col,
    random_state=42,
    test_size=0.2,
)

## Feature Engineering

In [19]:
def count_category_hits(text_series: pd.Series, categories: "dict[str, list[str]] | None" = None) -> pd.DataFrame:
    """Count whole-word term matches per category for every text.

    Args:
        text_series: iterable of strings (typically a ``pd.Series``).
        categories: mapping ``category -> list of terms``. Defaults to the
            notebook-level ``MEDICAL_TERMS``.

    Returns:
        DataFrame with one row per input text and one ``count_<category>`` column
        per category (in the mapping's key order), holding int64 counts. The row
        index is a fresh ``RangeIndex``, not the input's index. Empty input
        still yields all ``count_<category>`` columns (zero rows), so the feature
        width stays constant.

    Matching is case-insensitive and bounded by ``\\b``. Each term is counted
    independently, so overlapping phrases (e.g. "chest pain" and "pain") both
    contribute.
    """
    if categories is None:
        categories = MEDICAL_TERMS

    # Build and compile each pattern once, rather than re-escaping and
    # re-concatenating it for every text.
    compiled = {
        f'count_{cat}': [re.compile(r'\b' + re.escape(term.lower()) + r'\b') for term in terms]
        for cat, terms in categories.items()
    }

    rows = []
    for txt in text_series:
        txt_l = txt.lower()
        rows.append({
            col: sum(len(pattern.findall(txt_l)) for pattern in patterns)
            for col, patterns in compiled.items()
        })
    # Explicit columns keep the schema stable even when ``rows`` is empty.
    return pd.DataFrame(rows, columns=list(compiled)).astype("int64")

def cat_count_transform(X):
    """Per-category term counts for ``X``, accepting several array-like inputs.

    A NumPy array is flattened with ``ravel`` and a list is wrapped as-is; both
    become a ``pd.Series``. Any other input (e.g. a Series) is forwarded
    unchanged to ``count_category_hits``.
    """
    if isinstance(X, np.ndarray):
        series = pd.Series(X.ravel())
    elif isinstance(X, list):
        series = pd.Series(X)
    else:
        series = X
    return count_category_hits(series)

# TF-IDF + counts
tfidf = TfidfVectorizer(lowercase=True, ngram_range=(1,2), min_df=2, max_df=0.95)


def build_feature_matrix(texts, vectorizer, fit=False):
    """Build the model input: TF-IDF columns followed by per-category term counts.

    Args:
        texts: text inputs accepted by both ``vectorizer`` and ``cat_count_transform``.
        vectorizer: a ``TfidfVectorizer``. It is fitted here when ``fit`` is true;
            otherwise only ``transform`` is called.
        fit: True for the training split only, so the vocabulary and IDF weights
            never see test data.

    Returns:
        ``(tfidf_part, counts_part, full)``: the sparse TF-IDF matrix, the count
        DataFrame, and their horizontal concatenation as a CSR matrix.
    """
    tfidf_part = vectorizer.fit_transform(texts) if fit else vectorizer.transform(texts)
    counts_part = cat_count_transform(texts)
    # format="csr" guarantees a row-sliceable result: training indexes rows with
    # X_train_full[batch_idx], which COO (hstack's default on older SciPy
    # releases) does not support.
    # NOTE(review): the counts are raw integers next to TF-IDF weights; consider
    # scaling them before SGD training.
    full = hstack([tfidf_part, csr_matrix(counts_part.values)], format="csr")
    return tfidf_part, counts_part, full


X_train_tfidf, X_train_counts, X_train_full = build_feature_matrix(X_train, tfidf, fit=True)
X_test_tfidf, X_test_counts, X_test_full = build_feature_matrix(X_test, tfidf)

In [20]:
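## Train, save, and fine-tune the model

Here "fine-tuning" means continuing mini-batch training on the same training split for additional epochs. Both checkpoints are evaluated on the held-out test split.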

In [23]:
# === Train → Save → Fine-tune → Save-again (SGD Logistic Regression) ===

# --- Config ---
base_epochs = 10       # passes over the training set before the first checkpoint
finetune_epochs = 10   # additional passes after that checkpoint ("fine-tuning")
batch_size = 256       # rows per partial_fit call
models_dir = "../models"
base_ckpt_path, ft_ckpt_path = (
    os.path.join(models_dir, fname)
    for fname in ("logreg_sgd_base.joblib", "logreg_sgd_finetuned.joblib")
)

# --- Model: logistic regression optimised by SGD ---
# max_iter only affects .fit(); training below uses .partial_fit(), which makes
# a single pass over each batch it is given.
clf = SGDClassifier(
    loss="log_loss",
    learning_rate="optimal",
    max_iter=1,
    tol=None,
    random_state=42,
)
classes = np.array([0, 1])  # full label set, required by the first partial_fit call

n_samples = X_train_full.shape[0]
indices = np.arange(n_samples)  # training-row order, reshuffled in place each epoch

def train_for_epochs(model, n_epochs, X=None, y=None):
    """Run ``n_epochs`` shuffled mini-batch passes of ``model.partial_fit``.

    Args:
        model: estimator supporting ``partial_fit``. It is updated in place and
            also returned.
        n_epochs: number of full passes over the data.
        X, y: training features and labels. When both are omitted, the global
            training split (``X_train_full`` / ``y_train``) is used. Pass new data
            to fine-tune on a different dataset.

    Returns:
        The same ``model`` object after training.

    Raises:
        ValueError: if only one of ``X`` / ``y`` is given.
    """
    if (X is None) != (y is None):
        raise ValueError("Pass both X and y, or neither.")
    if X is None:
        X, y = X_train_full, y_train
        # Default path reshuffles the global ``indices`` in place, so the shuffle
        # sequence continues across successive calls (base run, then fine-tune).
        order = indices
    else:
        order = np.arange(X.shape[0])

    # Convert labels once up front rather than once per mini-batch.
    y_arr = np.asarray(y)
    n_rows = len(order)

    for _ in range(n_epochs):
        np.random.shuffle(order)
        for start in range(0, n_rows, batch_size):
            batch_idx = order[start:start + batch_size]
            if hasattr(model, "classes_"):
                model.partial_fit(X[batch_idx], y_arr[batch_idx])
            else:
                # The first partial_fit call must be told every class, because a
                # single batch may not contain all of them.
                model.partial_fit(X[batch_idx], y_arr[batch_idx], classes=classes)
    return model

def eval_model(model, tag="", X=None, y=None):
    """Print and return accuracy, ROC AUC and log loss for a fitted binary classifier.

    Args:
        model: fitted estimator exposing ``predict`` and ``predict_proba``. Column
            1 of ``predict_proba`` is taken as the positive-class probability.
        tag: label printed in front of the metrics line.
        X, y: evaluation features and labels. When both are omitted, the global
            test split (``X_test_full`` / ``y_test``) is used. Pass a validation
            split instead when the metrics drive checkpoint choices, so the test
            set stays untouched.

    Returns:
        ``(acc, auc, loss)``.

    Raises:
        ValueError: if only one of ``X`` / ``y`` is given. ``roc_auc_score`` also
            raises if ``y`` contains a single class.
    """
    if (X is None) != (y is None):
        raise ValueError("Pass both X and y, or neither.")
    X_eval = X_test_full if X is None else X
    y_eval = y_test if y is None else y

    y_pred = model.predict(X_eval)
    y_proba = model.predict_proba(X_eval)[:, 1]
    acc = accuracy_score(y_eval, y_pred)
    auc = roc_auc_score(y_eval, y_proba)
    loss = log_loss(y_eval, y_proba, labels=[0,1])
    print(f"[{tag}] Test Acc: {acc:.3f} | AUC: {auc:.3f} | LogLoss: {loss:.4f}")
    return acc, auc, loss

# --- 1) Train initial epochs ---
clf = train_for_epochs(clf, base_epochs)

# --- 2) Evaluate & Save base model ---
base_metrics = eval_model(clf, tag=f"After {base_epochs} epochs")
# joblib.dump does not create missing parent directories, so make sure the
# checkpoint folder exists before the first save.
os.makedirs(models_dir, exist_ok=True)
joblib.dump(clf, base_ckpt_path)  # on-disk snapshot; training below continues on the in-memory clf
print(f"Saved base model to: {base_ckpt_path}")

# --- 3) Fine-tune: continue mini-batch training on the same training data ---
clf = train_for_epochs(clf, finetune_epochs)

# --- 4) Evaluate & Save fine-tuned model ---
ft_metrics = eval_model(clf, tag=f"After {base_epochs + finetune_epochs} epochs (fine-tuned)")
# NOTE(review): extra epochs are not guaranteed to help (compare base_metrics
# vs ft_metrics). Both are measured on the test split, so choosing a checkpoint
# from them leaks test information; prefer a validation split for that choice.
joblib.dump(clf, ft_ckpt_path)
print(f"Saved fine-tuned model to: {ft_ckpt_path}")

[After 10 epochs] Test Acc: 0.995 | AUC: 0.999 | LogLoss: 0.0270
Saved base model to: ../models\logreg_sgd_base.joblib
[After 20 epochs (fine-tuned)] Test Acc: 0.970 | AUC: 0.998 | LogLoss: 0.0505
Saved fine-tuned model to: ../models\logreg_sgd_finetuned.joblib
