In [1]:
# retrain_improved_model.py - Better training with hyperparameter tuning

import os
import joblib
import numpy as np
import pandas as pd
from utils import preprocess_text, MEDICAL_PHRASES
from sentence_transformers import SentenceTransformer
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import MultinomialNB
from sklearn.svm import SVC
from sklearn.calibration import CalibratedClassifierCV
from sklearn.metrics import accuracy_score, f1_score, classification_report
import warnings
warnings.filterwarnings('ignore')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
DATA_PATH = "archive/Final_Augmented_dataset_Diseases_and_Symptoms.csv"
MODELS_DIR = "models"
os.makedirs(MODELS_DIR, exist_ok=True)

MIN_SAMPLES = 3
RANDOM_STATE = 42

print("=" * 70)
print("IMPROVED MODEL TRAINING - Better Accuracy Focus")
print("=" * 70)

IMPROVED MODEL TRAINING - Better Accuracy Focus


In [3]:
# Load data
print("\n[1/6] Loading dataset...")
df = pd.read_csv(DATA_PATH)
df.columns = [c.strip().lower().replace(" ", "_") for c in df.columns]
df = df.dropna(subset=["diseases"])

print(f"   Total samples: {len(df)}")
print(f"   Unique diseases: {df['diseases'].nunique()}")



[1/6] Loading dataset...
   Total samples: 246945
   Unique diseases: 773


In [4]:
# Build symptom text
if "symptom_text" not in df.columns or df["symptom_text"].isnull().all():
    symptom_cols = [c for c in df.columns if c != "diseases"]
    def combine_symptoms(row):
        return " ".join([col for col in symptom_cols if row.get(col, 0) == 1])
    df["symptom_text"] = df.apply(combine_symptoms, axis=1)

# Enhanced preprocessing
print("\n[2/6] Preprocessing with medical context preservation...")
df["clean_text"] = df["symptom_text"].fillna("").astype(str).apply(preprocess_text)

# Filter rare diseases
counts = df["diseases"].value_counts()
valid = counts[counts >= MIN_SAMPLES].index
df = df[df["diseases"].isin(valid)].reset_index(drop=True)

print(f"   After filtering: {len(df)} samples, {df['diseases'].nunique()} diseases")



[2/6] Preprocessing with medical context preservation...
   After filtering: 246914 samples, 748 diseases


In [5]:
# Encode labels
le = LabelEncoder()
df["label"] = le.fit_transform(df["diseases"])

# Split data
X = df["clean_text"].tolist()
y = df["label"].values

X_train_text, X_test_text, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=RANDOM_STATE
)

print(f"   Training samples: {len(X_train_text)}")
print(f"   Test samples: {len(X_test_text)}")


   Training samples: 197531
   Test samples: 49383


In [6]:
# Embedding
print("\n[3/6] Generating embeddings (this takes time)...")
embedder = SentenceTransformer("all-MiniLM-L6-v2")
X_train_emb = embedder.encode(X_train_text, convert_to_numpy=True, show_progress_bar=True)
X_test_emb = embedder.encode(X_test_text, convert_to_numpy=True, show_progress_bar=True)

# Compute class centroids
print("\n[4/6] Computing class centroids...")
class_centroids = {}
for lbl in np.unique(y_train):
    idx = np.where(y_train == lbl)[0]
    class_centroids[int(lbl)] = X_train_emb[idx].mean(axis=0)


INFO:sentence_transformers.SentenceTransformer:Use pytorch device_name: cpu
INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: all-MiniLM-L6-v2



[3/6] Generating embeddings (this takes time)...


Batches: 100%|█████████████████████████████████████████████████████████████████████| 6173/6173 [05:30<00:00, 18.67it/s]
Batches: 100%|█████████████████████████████████████████████████████████████████████| 1544/1544 [01:22<00:00, 18.64it/s]



[4/6] Computing class centroids...


In [7]:
# Train models with better hyperparameters
print("\n[5/6] Training models with optimized hyperparameters...")

# 1. Naive Bayes (baseline)
print("   Training Naive Bayes...")
nb = MultinomialNB(alpha=0.5)  # Reduced alpha for less smoothing
nb.fit(np.abs(X_train_emb), y_train)
y_pred_nb = nb.predict(np.abs(X_test_emb))
acc_nb = accuracy_score(y_test, y_pred_nb)
print(f"      NB Accuracy: {acc_nb:.4f}")

# 2. Logistic Regression with tuning
print("   Training Logistic Regression...")
lr = LogisticRegression(
    max_iter=3000,
    C=1.0,  # Regularization
    solver='lbfgs',
    class_weight='balanced',  # Handle class imbalance
    random_state=RANDOM_STATE
)
lr.fit(X_train_emb, y_train)
y_pred_lr = lr.predict(X_test_emb)
acc_lr = accuracy_score(y_test, y_pred_lr)
print(f"      LR Accuracy: {acc_lr:.4f}")

# 3. SVC with RBF kernel (better for complex patterns)
print("   Training SVC with RBF kernel...")
svc_base = SVC(
    kernel='rbf',
    C=10.0,  # Increased C for better fit
    gamma='scale',
    probability=True,
    class_weight='balanced',
    random_state=RANDOM_STATE,
    max_iter=5000
)
svc_base.fit(X_train_emb, y_train)



[5/6] Training models with optimized hyperparameters...
   Training Naive Bayes...
      NB Accuracy: 0.6772
   Training Logistic Regression...
      LR Accuracy: 0.8418
   Training SVC with RBF kernel...


0,1,2
,C,10.0
,kernel,'rbf'
,degree,3
,gamma,'scale'
,coef0,0.0
,shrinking,True
,probability,True
,tol,0.001
,cache_size,200
,class_weight,'balanced'


In [8]:
print("   Using SVC predictions...")
y_pred_svc = svc_base.predict(X_test_emb)
acc_svc = accuracy_score(y_test, y_pred_svc)
print(f"      SVC Accuracy: {acc_svc:.4f}")

# Evaluation
print("\n[6/6] Evaluation Results:")
print("=" * 70)
print(f"Naive Bayes:              {acc_nb:.4f}")
print(f"Logistic Regression:      {acc_lr:.4f}")
print(f"SVC (RBF kernel):         {acc_svc:.4f}")
print("=" * 70)


   Using SVC predictions...
      SVC Accuracy: 0.8582

[6/6] Evaluation Results:
Naive Bayes:              0.6772
Logistic Regression:      0.8418
SVC (RBF kernel):         0.8582


In [10]:
# Best model report
best_model_name = max([("NB", acc_nb), ("LR", acc_lr), ("SVC", acc_svc)], key=lambda x: x[1])[0]
best_pred = {"NB": y_pred_nb, "LR": y_pred_lr, "SVC": y_pred_svc}[best_model_name]

print(f"\nBest Model: {best_model_name}")
print("\nDetailed Classification Report (Top 20 classes):")
print("-" * 70)

# ✅ Fix for mismatch between y_test and encoder classes
present_labels = np.unique(np.concatenate([y_test, best_pred]))

report = classification_report(
    y_test,
    best_pred,
    labels=present_labels,
    target_names=le.inverse_transform(present_labels),
    zero_division=0,
    output_dict=True
)

# Sort by support
sorted_classes = sorted(
    report.items(),
    key=lambda x: x[1].get('support', 0) if isinstance(x[1], dict) else 0,
    reverse=True
)

print(f"{'Disease':<30} {'Precision':>10} {'Recall':>10} {'F1-Score':>10} {'Support':>10}")
print("-" * 70)

for disease, metrics in sorted_classes[:20]:
    if isinstance(metrics, dict) and 'support' in metrics:
        print(f"{disease:<30} {metrics['precision']:>10.3f} {metrics['recall']:>10.3f} {metrics['f1-score']:>10.3f} {int(metrics['support']):>10}")




Best Model: SVC

Detailed Classification Report (Top 20 classes):
----------------------------------------------------------------------
Disease                         Precision     Recall   F1-Score    Support
----------------------------------------------------------------------
macro avg                           0.814      0.878      0.831      49383
weighted avg                        0.877      0.858      0.861      49383
cystitis                            0.983      0.697      0.815        244
nose disorder                       1.000      0.799      0.888        244
vulvodynia                          0.977      0.881      0.927        244
acute bronchitis                    0.861      0.687      0.764        243
complex regional pain syndrome      0.934      0.881      0.907        243
conjunctivitis due to allergy       0.986      0.881      0.930        243
diverticulitis                      0.877      0.819      0.847        243
esophagitis                         0.996

In [11]:
# Sort by support
sorted_classes = sorted(report.items(), key=lambda x: x[1].get('support', 0) if isinstance(x[1], dict) else 0, reverse=True)

print(f"{'Disease':<30} {'Precision':>10} {'Recall':>10} {'F1-Score':>10} {'Support':>10}")
print("-" * 70)

for disease, metrics in sorted_classes[:20]:
    if isinstance(metrics, dict) and 'support' in metrics:
        print(f"{disease:<30} {metrics['precision']:>10.3f} {metrics['recall']:>10.3f} {metrics['f1-score']:>10.3f} {int(metrics['support']):>10}")


Disease                         Precision     Recall   F1-Score    Support
----------------------------------------------------------------------
macro avg                           0.814      0.878      0.831      49383
weighted avg                        0.877      0.858      0.861      49383
cystitis                            0.983      0.697      0.815        244
nose disorder                       1.000      0.799      0.888        244
vulvodynia                          0.977      0.881      0.927        244
acute bronchitis                    0.861      0.687      0.764        243
complex regional pain syndrome      0.934      0.881      0.907        243
conjunctivitis due to allergy       0.986      0.881      0.930        243
diverticulitis                      0.877      0.819      0.847        243
esophagitis                         0.996      0.914      0.953        243
gastrointestinal hemorrhage         0.990      0.848      0.914        243
hypoglycemia                 

In [13]:
# Save models
print("\n[SAVING] Saving trained models...")
joblib.dump(nb, os.path.join(MODELS_DIR, "diagnobot_nb.pkl"))
joblib.dump(lr, os.path.join(MODELS_DIR, "diagnobot_lr.pkl"))
# joblib.dump(svc_cal, os.path.join(MODELS_DIR, "diagnobot_svc_calibrated.pkl"))
joblib.dump(le, os.path.join(MODELS_DIR, "diagnobot_label_encoder.pkl"))
joblib.dump(class_centroids, os.path.join(MODELS_DIR, "diagnobot_class_centroids.pkl"))

print(f"\nâœ… Models saved to '{MODELS_DIR}' directory")

# Save training metadata
metadata = {
    'training_date': pd.Timestamp.now().isoformat(),
    'total_diseases': len(le.classes_),
    'training_samples': len(X_train_text),
    'test_samples': len(X_test_text),
    'nb_accuracy': acc_nb,
    'lr_accuracy': acc_lr,
    'svc_accuracy': acc_svc,
    'best_model': best_model_name
}

joblib.dump(metadata, os.path.join(MODELS_DIR, "training_metadata.pkl"))

print("\nâœ… Training complete!")
print("=" * 70)



[SAVING] Saving trained models...

âœ… Models saved to 'models' directory

âœ… Training complete!
