In [4]:
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.svm import LinearSVC
from sklearn.metrics import accuracy_score, classification_report
import joblib

# ============= 1. Load Dataset =============
df = pd.read_parquet("hf://datasets/ruslanmv/ai-medical-chatbot/dialogues.parquet")

# Combine patient + description into a single input text
df['combined_text'] = df['Patient'] + " " + df['Description']
df['label'] = df['Doctor']

# Drop rare doctor labels (less than 2 examples)
df = df.groupby('label').filter(lambda x: len(x) > 1)

# Define features and labels
X = df['combined_text']
y = df['label']

# ============= 2. Train-test split =============
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# ============= 3. Convert text to TF-IDF vectors =============
vectorizer = TfidfVectorizer(max_features=5000, stop_words="english")
X_train_tfidf = vectorizer.fit_transform(X_train)
X_test_tfidf = vectorizer.transform(X_test)

# ============= 4. Train Multiple Models =============
models = {
    "RandomForest": RandomForestClassifier(n_estimators=200, random_state=42, n_jobs=-1),
    "LogisticRegression": LogisticRegression(max_iter=2000, random_state=42, n_jobs=-1),
    "SVM": LinearSVC(random_state=42)
}

results = {}

for name, model in models.items():
    print(f"\n🔹 Training {name}...")
    model.fit(X_train_tfidf, y_train)
    y_pred = model.predict(X_test_tfidf)
    
    acc = accuracy_score(y_test, y_pred)
    results[name] = acc
    
    print(f"\n✅ {name} Accuracy: {acc:.4f}")
    print("Classification Report:")
    print(classification_report(y_test, y_pred))

# ============= 5. Pick Best Model =============
best_model_name = max(results, key=results.get)
best_model = models[best_model_name]

print(f"\n🚀 Best Model: {best_model_name} with Accuracy {results[best_model_name]:.4f}")

# Save best model + vectorizer
joblib.dump(best_model, "best_model.pkl")
joblib.dump(vectorizer, "vectorizer.pkl")

print("\n💾 Model and vectorizer saved successfully!")

# ============= 6. Sample Predictions =============
y_pred_best = best_model.predict(X_test_tfidf)
print("\n🔎 Example Predictions:\n")
for i in range(5):
    print("Patient+Question:", X_test.iloc[i][:200], "...")
    print("Actual Doctor Response:", y_test.iloc[i])
    print("Predicted Doctor Response:", y_pred_best[i])
    print("="*70)


# ============= 7. FastAPI Service =============
# Save this part in a separate file "api.py"
"""
from fastapi import FastAPI, Request, HTTPException
import joblib

# Load saved model + vectorizer
model = joblib.load("best_model.pkl")
vectorizer = joblib.load("vectorizer.pkl")

app = FastAPI()

API_KEY = "mysecretapikey123"  # 🔑 Your API Key

@app.post("/predict")
async def predict(request: Request):
    headers = request.headers
    if headers.get("x-api-key") != API_KEY:
        raise HTTPException(status_code=403, detail="Unauthorized")

    data = await request.json()
    text = data["text"]

    features = vectorizer.transform([text])
    prediction = model.predict(features).tolist()

    return {"prediction": prediction}
"""



🔹 Training RandomForest...

✅ RandomForest Accuracy: 0.7064
Classification Report:


  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        

  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                                        

'\nfrom fastapi import FastAPI, Request, HTTPException\nimport joblib\n\n# Load saved model + vectorizer\nmodel = joblib.load("best_model.pkl")\nvectorizer = joblib.load("vectorizer.pkl")\n\napp = FastAPI()\n\nAPI_KEY = "mysecretapikey123"  # 🔑 Your API Key\n\n@app.post("/predict")\nasync def predict(request: Request):\n    headers = request.headers\n    if headers.get("x-api-key") != API_KEY:\n        raise HTTPException(status_code=403, detail="Unauthorized")\n\n    data = await request.json()\n    text = data["text"]\n\n    features = vectorizer.transform([text])\n    prediction = model.predict(features).tolist()\n\n    return {"prediction": prediction}\n'