# Fine-Tune Sentence Transformer for Medical Symptom Extraction

**Purpose:** Train a sentence transformer model to better understand how patients describe symptoms in natural language.

**Dataset:** 
- 200k doctor-patient conversations (ai-medical-chatbot.csv)
- 131+ canonical symptoms (symptoms.csv)

**Output:** Fine-tuned model saved to `models/medical_symptom_matcher/`

---

## Table of Contents
1. [Setup & Installation](#setup)
2. [Load Data](#load-data)
3. [Prepare Training Data](#prepare-training)
4. [Fine-Tune Model](#fine-tune)
5. [Evaluate Performance](#evaluate)
6. [Save Model & Metadata](#save)
7. [Quick Test Interface](#test)

---
## 1. Setup & Installation <a id='setup'></a>

In [None]:
# Install required packages
!pip install sentence-transformers torch pandas numpy scikit-learn tqdm

In [None]:
# Google Colab Setup
import os

# Create the working directories with os.makedirs instead of a `!mkdir`
# shell magic so the cell also runs outside IPython and on systems whose
# shell has no `mkdir -p`. exist_ok=True makes re-running the cell harmless.
for directory in ("data", "models"):
    os.makedirs(directory, exist_ok=True)

print("📁 Directories created!")
print("\n⚠️ IMPORTANT: Upload your datasets now!")
print("   1. Click the folder icon on the left sidebar")
print("   2. Navigate to the 'data' folder")
print("   3. Upload these files:")
print("      - ai-medical-chatbot.csv")
print("      - symptoms.csv")
print("\n💡 Or use the code below to upload via dialog:")
print("\nfrom google.colab import files")
print("uploaded = files.upload()")
print("# Then move files: !mv *.csv data/")

In [None]:
# Check GPU availability
import torch

# Report which accelerator (if any) this runtime exposes before training.
has_gpu = torch.cuda.is_available()
if not has_gpu:
    print("⚠️ No GPU detected. Training will be MUCH slower.")
    print("   Go to Runtime > Change runtime type > Select GPU")
else:
    # Device 0 is the only device inspected here.
    print(f"✅ GPU Detected: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

In [1]:
import pandas as pd
import numpy as np
import torch
import re
from sentence_transformers import SentenceTransformer, InputExample, losses, evaluation
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import json
import os
from collections import defaultdict

# Summarise the PyTorch build and whether CUDA can be used.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if not torch.cuda.is_available():
    print("⚠️ Running on CPU - training will be slower but still works!")
else:
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

  from .autonotebook import tqdm as notebook_tqdm


KeyboardInterrupt: 

---
## 2. Load Data <a id='load-data'></a>

In [None]:
# Load symptoms list
print("Loading symptoms...")

# The setup cell tells users to upload the CSVs into data/, so look there
# first; fall back to the working directory so existing layouts keep working.
symptoms_path = os.path.join('data', 'symptoms.csv')
if not os.path.exists(symptoms_path):
    symptoms_path = 'symptoms.csv'
symptoms_df = pd.read_csv(symptoms_path)

# Normalise to lowercase/stripped strings, then drop missing values, blank
# entries and duplicates (first occurrence wins, original order kept).
# A NaN entry would crash the later `symptom.lower()` call, and a blank entry
# is a substring of every text, so it would "match" every conversation.
normalized_symptoms = symptoms_df['symptoms'].dropna().astype(str).str.lower().str.strip()
SYMPTOMS = list(dict.fromkeys(s for s in normalized_symptoms if s))

print(f"✓ Loaded {len(SYMPTOMS)} symptoms")
print(f"\nFirst 10 symptoms:")
for i, symptom in enumerate(SYMPTOMS[:10], 1):
    print(f"  {i}. {symptom}")

In [None]:
# Load medical conversations
print("\nLoading medical conversations...")
print("⚠️ This may take a minute for 200k rows...")

# The setup cell tells users to upload the CSVs into data/, so look there
# first; fall back to the working directory so existing layouts keep working.
conversations_path = os.path.join('data', 'ai-medical-chatbot.csv')
if not os.path.exists(conversations_path):
    conversations_path = 'ai-medical-chatbot.csv'

conversations_df = pd.read_csv(conversations_path)
print(f"✓ Loaded {len(conversations_df):,} conversations")
print(f"\nDataset columns: {list(conversations_df.columns)}")

# Preview the first row. Guard against an empty file, and go through str() so
# a missing (NaN) cell is shown as "nan" instead of raising on the slice.
if len(conversations_df) > 0:
    first_row = conversations_df.iloc[0]
    print(f"\nFirst conversation example:")
    print(f"Description: {str(first_row['Description'])[:100]}...")
    print(f"Patient: {str(first_row['Patient'])[:150]}...")
    print(f"Doctor: {str(first_row['Doctor'])[:150]}...")
else:
    print("\n⚠️ The conversations file contains no rows.")

---
## 3. Prepare Training Data <a id='prepare-training'></a>

We'll create labeled training pairs (both positive and negative) where:
- **Text 1:** Patient's message
- **Text 2:** Symptom name
- **Label:** 1.0 (if symptom mentioned), 0.0 (if not mentioned)

This teaches the model to recognize when patient language matches a symptom.

In [None]:
def clean_text(text):
    """Return *text* lowercased with every whitespace run collapsed to one space.

    Missing values (anything ``pd.isna`` reports as NA) become an empty string;
    any other input is converted with ``str()`` first.
    """
    if pd.isna(text):
        return ""
    lowered = str(text).lower()
    # split() with no arguments also trims leading/trailing whitespace
    tokens = lowered.split()
    return ' '.join(tokens)

def extract_symptom_mentions(patient_text, doctor_text, symptoms_list):
    """
    Find which canonical symptoms are mentioned in a conversation.

    Matching is a case-insensitive *substring* heuristic over the patient and
    doctor text joined together. A symptom counts as mentioned when either
    the whole symptom string occurs in the text, or (for multi-word symptoms)
    every one of its words occurs somewhere in the text, in any order.

    Args:
        patient_text: Patient message (cleaned with clean_text before matching)
        doctor_text: Doctor reply (cleaned with clean_text before matching)
        symptoms_list: Canonical symptom names to look for

    Returns:
        List of the entries of symptoms_list that matched, in input order and
        with their original spelling. Blank and non-string entries never match.
    """
    combined_text = clean_text(patient_text) + " " + clean_text(doctor_text)

    mentioned_symptoms = []

    for symptom in symptoms_list:
        # Missing values read from a CSV arrive as float NaN; skip them
        # instead of crashing on .lower()
        if not isinstance(symptom, str):
            continue

        symptom_clean = symptom.lower().strip()

        # An empty string is a substring of every text, so a blank entry
        # would otherwise be reported for every single conversation
        if not symptom_clean:
            continue

        # Whole-phrase substring match
        if symptom_clean in combined_text:
            mentioned_symptoms.append(symptom)
            continue

        # Word-by-word match for multi-word symptoms (allows different word order)
        symptom_words = symptom_clean.split()
        if len(symptom_words) > 1 and all(word in combined_text for word in symptom_words):
            mentioned_symptoms.append(symptom)

    return mentioned_symptoms

# Test the function
# Smoke test on a hand-written exchange. Which names get printed depends on
# the entries present in SYMPTOMS, so there is no fixed expected output here.
test_patient = "I have a terrible headache and feel nauseous"
test_doctor = "You may have migraine. The nausea is common with headaches."
test_mentions = extract_symptom_mentions(test_patient, test_doctor, SYMPTOMS)
print(f"Test extraction: {test_mentions}")

In [None]:
def create_training_examples(conversations_df, symptoms_list, max_examples=50000, sample_negatives=True):
    """
    Create labeled (patient text, symptom) training pairs from conversations.

    For every conversation whose patient message has at least 5 words, each
    symptom found by extract_symptom_mentions becomes a positive pair
    (label 1.0). If sample_negatives is set, up to 2 symptoms that were NOT
    found are drawn at random for that conversation and added as negative
    pairs (label 0.0).

    Negatives are drawn with np.random.choice, i.e. NumPy's global random
    state; call np.random.seed(...) beforehand for a reproducible dataset.

    Args:
        conversations_df: DataFrame with Patient, Doctor columns
        symptoms_list: List of canonical symptoms
        max_examples: Maximum number of positive examples to create
            (negatives are added on top of this budget)
        sample_negatives: Whether to include negative examples (no symptom match)

    Returns:
        List of InputExample objects
    """
    training_examples = []
    positive_count = 0
    negative_count = 0

    print(f"\nCreating training examples from {len(conversations_df):,} conversations...")
    print(f"This will take 5-10 minutes...\n")

    for _, row in tqdm(conversations_df.iterrows(), total=len(conversations_df)):
        # Remaining positive budget; stop as soon as it is used up
        remaining = max_examples - positive_count
        if remaining <= 0:
            break

        patient_text = clean_text(row['Patient'])
        doctor_text = clean_text(row['Doctor'])

        # Skip very short messages
        if len(patient_text.split()) < 5:
            continue

        # Find mentioned symptoms
        mentioned_symptoms = extract_symptom_mentions(patient_text, doctor_text, symptoms_list)
        if not mentioned_symptoms:
            continue

        # Create positive examples, truncated so max_examples is a hard cap
        # even when one conversation mentions many symptoms
        for symptom in mentioned_symptoms[:remaining]:
            training_examples.append(
                InputExample(texts=[patient_text, symptom], label=1.0)
            )
            positive_count += 1

        if not sample_negatives:
            continue

        # Sample up to 2 negative symptoms per conversation. Every matched
        # symptom is excluded, including any cut off by the cap above.
        mentioned_set = set(mentioned_symptoms)
        unmentioned = [s for s in symptoms_list if s not in mentioned_set]
        num_negatives = min(2, len(unmentioned))
        if num_negatives == 0:
            continue

        negative_samples = np.random.choice(unmentioned, size=num_negatives, replace=False)
        for neg_symptom in negative_samples:
            # np.random.choice yields NumPy string scalars; store plain str
            # so positive and negative pairs carry the same text type
            training_examples.append(
                InputExample(texts=[patient_text, str(neg_symptom)], label=0.0)
            )
            negative_count += 1

    print(f"\n✓ Created {len(training_examples):,} training examples")
    print(f"  - Positive examples: {positive_count:,}")
    print(f"  - Negative examples: {negative_count:,}")
    # No negatives (sample_negatives=False or nothing matched) would
    # otherwise raise ZeroDivisionError here
    if negative_count:
        print(f"  - Ratio: {positive_count/negative_count:.2f}:1 (positive:negative)")
    else:
        print("  - Ratio: n/a (no negative examples)")

    return training_examples

In [None]:
# Create training examples
# Adjust max_examples based on your compute resources:
# - CPU: 10,000-20,000 examples
# - GPU: 50,000+ examples
# MAX_EXAMPLES bounds the number of POSITIVE pairs; the sampled negative
# pairs are added on top, so the returned list is longer than this value.

MAX_EXAMPLES = 30000  # Adjust this based on your hardware

train_examples = create_training_examples(
    conversations_df, 
    SYMPTOMS, 
    max_examples=MAX_EXAMPLES,
    sample_negatives=True
)

In [None]:
# Split into train and validation
# Each patient message produces several examples (one per matched symptom plus
# the sampled negatives). Splitting example-by-example would place the same
# message in both sets and inflate validation scores, so split by message:
# all examples that share a patient text land on the same side.
from sklearn.model_selection import GroupShuffleSplit

# Map each distinct patient text to a small integer id (cheaper for
# scikit-learn to handle than the long strings themselves)
_message_ids = {}
message_groups = [
    _message_ids.setdefault(example.texts[0], len(_message_ids))
    for example in train_examples
]

# test_size is the fraction of distinct messages held out, so the validation
# share of examples is approximately (not exactly) 10%
splitter = GroupShuffleSplit(n_splits=1, test_size=0.1, random_state=42)
train_idx, val_idx = next(
    splitter.split(np.arange(len(train_examples)), groups=message_groups)
)
train_data = [train_examples[i] for i in train_idx]
val_data = [train_examples[i] for i in val_idx]

print(f"\nDataset split:")
print(f"  Training: {len(train_data):,} examples")
print(f"  Validation: {len(val_data):,} examples")

# Show some examples
print(f"\n📝 Sample training examples:")
for i in range(min(3, len(train_data))):
    example = train_data[i]
    print(f"\nExample {i+1}:")
    print(f"  Patient text: {example.texts[0][:100]}...")
    print(f"  Symptom: {example.texts[1]}")
    print(f"  Label: {'MATCH ✓' if example.label == 1.0 else 'NO MATCH ✗'}")

---
## 4. Fine-Tune Model <a id='fine-tune'></a>

We'll fine-tune the `all-MiniLM-L6-v2` model using cosine similarity loss.

In [None]:
# Load base model
# The checkpoint name is kept in a variable because later cells reuse it.
base_model_name = 'sentence-transformers/all-MiniLM-L6-v2'
print("Loading base model...")
model = SentenceTransformer(base_model_name)
print(f"✓ Loaded {base_model_name}")
embedding_dim = model.get_sentence_embedding_dimension()
print(f"\nModel embedding dimension: {embedding_dim}")

In [None]:
# Create dataloaders
# Single source of truth for the batch size, so both loaders and the summary
# printed below cannot drift apart when it is tuned.
BATCH_SIZE = 16

train_dataloader = DataLoader(train_data, shuffle=True, batch_size=BATCH_SIZE)
val_dataloader = DataLoader(val_data, shuffle=False, batch_size=BATCH_SIZE)

# Define loss function
# CosineSimilarityLoss is built on the model being trained and is paired with
# the float labels (1.0 / 0.0) carried by the training examples.
train_loss = losses.CosineSimilarityLoss(model)

print(f"Training setup:")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Training batches: {len(train_dataloader)}")
print(f"  Validation batches: {len(val_dataloader)}")
print(f"  Loss function: CosineSimilarityLoss")

In [None]:
# Training configuration
NUM_EPOCHS = 3  # Usually 2-4 epochs is enough
WARMUP_STEPS = int(len(train_dataloader) * 0.1)  # warm-up length = 10% of the batches in ONE epoch (not of all NUM_EPOCHS epochs)

print(f"\nTraining configuration:")
print(f"  Epochs: {NUM_EPOCHS}")
print(f"  Warmup steps: {WARMUP_STEPS}")
print(f"  Total training steps: {len(train_dataloader) * NUM_EPOCHS}")

# Estimate training time
# The per-epoch figures (15-30 min on GPU, 60-120 min on CPU) are fixed
# constants scaled by NUM_EPOCHS; they do not take the dataset size into account.
if torch.cuda.is_available():
    print(f"\n⏱️ Estimated time: {NUM_EPOCHS * 15}-{NUM_EPOCHS * 30} minutes on GPU")
else:
    print(f"\n⏱️ Estimated time: {NUM_EPOCHS * 60}-{NUM_EPOCHS * 120} minutes on CPU")
    print(f"   (Consider using Google Colab with GPU if this is too slow)")

In [None]:
# Create output directory
os.makedirs('models', exist_ok=True)
output_path = 'models/medical_symptom_matcher'

# save_best_model only does something when fit() is given an evaluator to
# rank checkpoints with. Build one from the held-out validation pairs so the
# model kept at output_path is the best-scoring one, and so the validation
# split is actually used during training. The labels are 0.0/1.0, which is
# what the binary-classification evaluator expects.
val_evaluator = evaluation.BinaryClassificationEvaluator.from_input_examples(
    val_data, name='validation'
)

print(f"\n🚀 Starting training...\n")
print("="*60)

# Train the model
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    evaluator=val_evaluator,
    epochs=NUM_EPOCHS,
    warmup_steps=WARMUP_STEPS,
    output_path=output_path,
    show_progress_bar=True,
    save_best_model=True
)

print("\n" + "="*60)
print(f"✓ Training complete!")
print(f"✓ Model saved to: {output_path}")

---
## 5. Evaluate Performance <a id='evaluate'></a>

Let's test the fine-tuned model against the base model.

In [None]:
# Load both models for comparison
print("Loading models for comparison...")
base_model = SentenceTransformer(base_model_name)   # untouched pretrained checkpoint
finetuned_model = SentenceTransformer(output_path)  # weights written by training
print("✓ Models loaded")

# Pre-compute symptom embeddings
# Each model embeds the full symptom list once so the later per-text
# comparisons only need to encode the query text.
print("\nComputing symptom embeddings...")
encode_options = {'convert_to_tensor': True, 'show_progress_bar': True}
base_symptom_embeddings = base_model.encode(SYMPTOMS, **encode_options)
finetuned_symptom_embeddings = finetuned_model.encode(SYMPTOMS, **encode_options)
print("✓ Embeddings computed")

In [None]:
from sentence_transformers import util

def extract_symptoms_comparison(text, model, symptom_embeddings, symptoms_list, threshold=0.5, top_k=5):
    """
    Rank the symptoms whose embedding is close to the embedding of *text*.

    The text is encoded with *model* and compared by cosine similarity with
    the pre-computed symptom_embeddings (row i belongs to symptoms_list[i]).
    Symptoms scoring strictly above threshold are kept.

    Returns:
        Up to top_k dicts {'symptom': name, 'score': float}, highest score first.
    """
    query_embedding = model.encode(text, convert_to_tensor=True)
    # cos_sim returns a matrix; row 0 holds the scores for the single query
    scores = util.cos_sim(query_embedding, symptom_embeddings)[0]

    candidates = [
        {'symptom': symptoms_list[position], 'score': float(value)}
        for position, value in enumerate(scores)
        if value > threshold
    ]

    # Stable sort, so equal scores keep their symptoms_list order
    candidates.sort(key=lambda entry: entry['score'], reverse=True)
    return candidates[:top_k]

# Test cases that previously failed
test_cases = [
    "I've been sneezing all day",
    "My head hurts really bad",
    "I am coughing a lot",
    "My stomach area has been itching like crazy",
    "I can't stop sneezing and my nose is blocked",
    "Having terrible pounding in my temples and feel nauseous",
    "My back is really painful",
    "I feel dizzy and want to throw up"
]


def _print_matches(matches):
    """Print one bullet per match, or a placeholder line when there are none."""
    if not matches:
        print("   • No symptoms detected")
        return
    for match in matches:
        print(f"   • {match['symptom']} (confidence: {match['score']:.3f})")


banner = "=" * 80
print("\n" + banner)
print("COMPARISON: BASE MODEL vs FINE-TUNED MODEL")
print(banner)

for test_text in test_cases:
    print(f"\n📝 Input: \"{test_text}\"")
    print("-" * 80)

    # Base model results
    base_results = extract_symptoms_comparison(
        test_text, base_model, base_symptom_embeddings, SYMPTOMS, threshold=0.45
    )
    print("❌ BASE MODEL:")
    _print_matches(base_results)

    # Fine-tuned model results
    finetuned_results = extract_symptoms_comparison(
        test_text, finetuned_model, finetuned_symptom_embeddings, SYMPTOMS, threshold=0.45
    )
    print("\n✅ FINE-TUNED MODEL:")
    _print_matches(finetuned_results)

    # Difference in the NUMBER of detected symptoms (not in their correctness)
    improvement = len(finetuned_results) - len(base_results)
    if improvement < 0:
        print(f"\n⚠️ Note: {abs(improvement)} fewer symptoms (may be more precise)")
    elif improvement > 0:
        print(f"\n💡 Improvement: +{improvement} more symptoms detected")

    print()

In [None]:
# Quantitative evaluation on validation set
print("\n" + "="*80)
print("QUANTITATIVE EVALUATION ON VALIDATION SET")
print("="*80)

def evaluate_model(model, symptom_embeddings, val_examples, threshold=0.5, max_samples=1000, symptoms_list=None):
    """
    Evaluate pairwise match accuracy on (a prefix of) the validation set.

    For each example the patient text is encoded with *model* and compared by
    cosine similarity with the embedding of the example's symptom. The pair is
    predicted as a match (1.0) when the similarity is strictly above
    threshold, and the prediction is compared with the example's label.

    Args:
        model: Model used to encode the patient text
        symptom_embeddings: Embeddings where row i belongs to symptoms_list[i]
        val_examples: List of InputExample objects (texts=[patient, symptom])
        threshold: Similarity above which a pair counts as a match
        max_samples: Only the first max_samples examples are scored, for speed
        symptoms_list: Symptom names aligned with symptom_embeddings;
            defaults to the notebook-level SYMPTOMS list

    Returns:
        Fraction of scored examples predicted correctly, or 0.0 when there is
        nothing to score.

    Raises:
        ValueError: If an example's symptom is not in symptoms_list.
    """
    if symptoms_list is None:
        symptoms_list = SYMPTOMS

    sample = val_examples[:max_samples]
    # An empty validation set would otherwise end in a division by zero
    if not sample:
        return 0.0

    # One-time name -> row lookup instead of a linear list.index() per example.
    # setdefault keeps the FIRST row for a duplicated name, like list.index().
    symptom_to_idx = {}
    for idx, name in enumerate(symptoms_list):
        symptom_to_idx.setdefault(str(name), idx)

    correct = 0
    for example in tqdm(sample, desc="Evaluating"):
        text = example.texts[0]
        target_symptom = example.texts[1]
        true_label = example.label

        try:
            symptom_idx = symptom_to_idx[str(target_symptom)]
        except KeyError:
            # Same exception type list.index() raised for an unknown symptom
            raise ValueError(f"{target_symptom!r} is not in symptoms_list") from None

        # Get prediction
        text_embedding = model.encode(text, convert_to_tensor=True)
        similarity = float(util.cos_sim(text_embedding, symptom_embeddings[symptom_idx]))
        predicted_label = 1.0 if similarity > threshold else 0.0

        if predicted_label == true_label:
            correct += 1

    return correct / len(sample)

print("\nEvaluating base model...")
base_accuracy = evaluate_model(base_model, base_symptom_embeddings, val_data)
print(f"Base model accuracy: {base_accuracy:.2%}")

print("\nEvaluating fine-tuned model...")
finetuned_accuracy = evaluate_model(finetuned_model, finetuned_symptom_embeddings, val_data)
print(f"Fine-tuned model accuracy: {finetuned_accuracy:.2%}")

# Relative change with respect to the base model. It is undefined when the
# base accuracy is 0, so report 0.0 there instead of dividing by zero; the two
# absolute accuracies printed above remain the figures to rely on.
if base_accuracy > 0:
    improvement_pct = ((finetuned_accuracy - base_accuracy) / base_accuracy) * 100
else:
    improvement_pct = 0.0
    print("\n⚠️ Base accuracy is 0 - relative improvement is undefined, reporting 0.0")
print(f"\n🎯 Improvement: {improvement_pct:+.1f}%")

---
## 6. Save Model & Metadata <a id='save'></a>

In [None]:
# Save metadata
# Run summary stored next to the model weights; the accuracy values are cast
# to plain floats before serialisation.
metadata = dict(
    base_model=base_model_name,
    training_examples=len(train_data),
    validation_examples=len(val_data),
    num_symptoms=len(SYMPTOMS),
    epochs=NUM_EPOCHS,
    base_accuracy=float(base_accuracy),
    finetuned_accuracy=float(finetuned_accuracy),
    improvement_pct=float(improvement_pct),
)

with open(f'{output_path}/training_metadata.json', 'w') as f:
    f.write(json.dumps(metadata, indent=2))

print(f"✓ Metadata saved to {output_path}/training_metadata.json")
print(f"\n📦 Model package contents:")
print(f"   {output_path}/")
# Entries are listed in whatever order os.listdir returns them
for entry in os.listdir(output_path):
    print(f"   ├── {entry}")

---
## 7. Quick Test Interface <a id='test'></a>

Interactive testing of your fine-tuned model.

In [None]:
def test_symptom_extraction(text, threshold=0.45, top_k=5):
    """
    Print the fine-tuned model's symptom matches for *text* and return them.

    Uses the notebook-level finetuned_model, finetuned_symptom_embeddings and
    SYMPTOMS; threshold and top_k are forwarded to extract_symptoms_comparison.
    """
    divider = '=' * 60
    print(f"\n{divider}")
    print(f"Input: \"{text}\"")
    print(f"{divider}")

    results = extract_symptoms_comparison(
        text,
        finetuned_model,
        finetuned_symptom_embeddings,
        SYMPTOMS,
        threshold=threshold,
        top_k=top_k,
    )

    if not results:
        print("\n❌ No symptoms detected above threshold")
        print(f"   Try lowering threshold (currently {threshold})")
        return results

    print(f"\n✅ Detected {len(results)} symptom(s):\n")
    for rank, match in enumerate(results, start=1):
        score = match['score']
        # One bar segment per 0.05 of similarity
        confidence_bar = '█' * int(score * 20)
        print(f"  {rank}. {match['symptom']}")
        print(f"     Confidence: {score:.3f} {confidence_bar}")

    return results

# Test it!
# Three sample inputs run with the default threshold (0.45) and top_k (5).
# Each call prints its own report; the returned lists are not stored.
test_symptom_extraction("I've been feeling dizzy and have a terrible headache")
test_symptom_extraction("My throat is sore and I can't stop coughing")
test_symptom_extraction("I have chest pain and shortness of breath")

In [None]:
# Interactive testing (optional - uncomment to use)
# while True:
#     user_input = input("\nDescribe your symptoms (or 'quit' to exit): ")
#     if user_input.lower() in ['quit', 'exit', 'q']:
#         break
#     test_symptom_extraction(user_input)

---
## ✅ Training Complete!

### Next Steps:

1. **Your fine-tuned model is saved at:** `models/medical_symptom_matcher/`

2. **To use it in your chatbot notebook:**
```python
from sentence_transformers import SentenceTransformer
model = SentenceTransformer('models/medical_symptom_matcher')
```

3. **Integration code is ready** - just copy the `extract_symptoms_comparison()` function to your main chatbot!

### Performance Summary:
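The figures below are computed at run time and cannot be rendered in this static text. Read them from the output of the quantitative evaluation cell in section 5, or from `models/medical_symptom_matcher/training_metadata.json`:
- ✅ Base model accuracy (`base_accuracy`)
- ✅ Fine-tuned accuracy (`finetuned_accuracy`)
- ✅ Improvement (`improvement_pct`)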

### Tips:
- Adjust `threshold` parameter to control sensitivity (lower = more symptoms detected)
- If you need better performance, train with more examples or more epochs
- The model works best with complete sentences (not just single words)