# 🧠 Embedding-Based Evaluation Deep Dive

## Overview

The embedding-based evaluation approach converts medical claims into vector representations and trains machine learning classifiers to predict binary outcomes. This method leverages the semantic understanding captured in your MGPT model's embeddings.

## 🔄 Process Flow

```
📊 Medical Claims Text
     ↓
🔗 MGPT Model API (/embeddings_batch)
     ↓
📊 Vector Embeddings (e.g., 768-dimensional)
     ↓
🤖 Train ML Classifiers
     ├── Logistic Regression
     ├── Support Vector Machine  
     └── Random Forest
     ↓
📈 Cross-Validation & Hyperparameter Tuning
     ↓
🎯 Best Model Selection
     ↓
📋 Test Set Evaluation
```

## 🚀 Step-by-Step Walkthrough

### Step 1: Embedding Generation

**What Happens:**
- Medical claims are sent to your MGPT model's `/embeddings_batch` endpoint
- Each claim is converted to a fixed-size vector (e.g., 768 dimensions)
- Embeddings capture semantic meaning of medical code sequences

**Example API Call:**
```http
POST /embeddings_batch

{
  "texts": [
    "N6320 G0378 |eoc| Z91048 M1710",
    "E119 76642 |eoc| K9289 O0903"
  ]
}
```

**Response** (vectors truncated here; each holds, e.g., 768 values):
```json
{
  "embeddings": [
    [0.1, -0.3, 0.7, ...],
    [0.2, 0.1, -0.4, ...]
  ]
}
```
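The request/response pair above can be wrapped in a small client. This is a sketch under assumptions, not the pipeline's code: it uses only the standard library, assumes the server accepts exactly the JSON body shown and returns the `embeddings` list in input order, and the `base_url` and helper names (`batched`, `post_embeddings_batch`, `embed_all`) are hypothetical. The `batch_size` default of 16 mirrors the `embedding_generation.batch_size` setting in Configuration Options.

```python
import json
import urllib.request
from typing import Callable, Iterator, List, Sequence

Vector = List[float]


def batched(items: Sequence[str], batch_size: int) -> Iterator[Sequence[str]]:
    """Yield consecutive slices of at most `batch_size` items."""
    for start in range(0, len(items), batch_size):
        yield items[start:start + batch_size]


def post_embeddings_batch(base_url: str, texts: Sequence[str],
                          timeout: float = 60.0) -> List[Vector]:
    """POST one batch to /embeddings_batch (hypothetical base_url) and return its vectors."""
    body = json.dumps({"texts": list(texts)}).encode("utf-8")
    request = urllib.request.Request(
        f"{base_url}/embeddings_batch",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    with urllib.request.urlopen(request, timeout=timeout) as response:
        return json.load(response)["embeddings"]


def embed_all(texts: Sequence[str],
              post: Callable[[Sequence[str]], List[Vector]],
              batch_size: int = 16) -> List[Vector]:
    """Embed every text with one API call per batch, preserving input order."""
    embeddings: List[Vector] = []
    for batch in batched(texts, batch_size):
        vectors = post(batch)
        if len(vectors) != len(batch):
            raise ValueError(f"expected {len(batch)} embeddings, got {len(vectors)}")
        embeddings.extend(vectors)
    return embeddings
```

For example, `embed_all(claims, lambda batch: post_embeddings_batch("http://localhost:8000", batch))`, where the URL is a placeholder. Passing the `post` function in keeps the batching logic testable without a running server.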

### Step 2: Data Preparation

**Input Format:**
```python
# After embedding generation, data looks like:
{
  "mcids": ["CLAIM_123", "CLAIM_456", ...],
  "labels": [1, 0, 1, 0, ...],
  "embeddings": [
    [0.1, -0.3, 0.7, ...],  # Embedding for CLAIM_123
    [0.2, 0.1, -0.4, ...],  # Embedding for CLAIM_456
    ...
  ]
}
```

**Train/Test Split:**
- Data is split into training (80%) and test (20%) sets
- Stratified splitting maintains label distribution
- MCIDs are preserved for result tracking
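One way to implement the split described above is scikit-learn's `train_test_split`. This is a sketch assuming the Step 2 dictionary layout; `test_size=0.2` matches the 80/20 split, while the helper name `split_embeddings` and the `seed` value are illustrative. Passing `mcids` as an extra array shuffles it with the same indices, so every MCID stays aligned with its embedding and label.

```python
import numpy as np
from sklearn.model_selection import train_test_split


def split_embeddings(data: dict, test_size: float = 0.2, seed: int = 42):
    """Stratified train/test split that keeps each MCID aligned with its row.

    Returns X_train, X_test, y_train, y_test, mcids_train, mcids_test.
    """
    X = np.asarray(data["embeddings"], dtype=np.float32)
    y = np.asarray(data["labels"])
    mcids = np.asarray(data["mcids"])
    if not (len(X) == len(y) == len(mcids)):
        raise ValueError("mcids, labels and embeddings must have the same length")
    # stratify=y keeps the label ratio (roughly) equal in both splits.
    return train_test_split(
        X, y, mcids, test_size=test_size, stratify=y, random_state=seed
    )
```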

### Step 3: Classifier Training

**Three Classifier Types:**

#### 1. Logistic Regression
- **Best for**: Interpretable results, fast training
- **Hyperparameters tuned**: C (regularization), penalty (L1/L2), solver
- **Advantages**: Provides probability scores, feature importance

#### 2. Support Vector Machine (SVM)
- **Best for**: High-dimensional data, robust to overfitting
- **Hyperparameters tuned**: C, kernel (RBF/linear), gamma
- **Advantages**: Effective with limited data, memory efficient

#### 3. Random Forest
- **Best for**: Robust predictions, handling noisy data
- **Hyperparameters tuned**: n_estimators, max_depth, min_samples_split
- **Advantages**: Feature importance, resistant to overfitting
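In scikit-learn terms, the three classifiers and the hyperparameters listed above could be expressed as (estimator, grid) pairs. The Logistic Regression values are taken from the `hyperparameter_search` configuration in Configuration Options; the SVM and Random Forest values are illustrative assumptions, not the pipeline's defaults. The SVM grid is a list of two dicts because `gamma` only affects the RBF kernel, which avoids fitting redundant linear-kernel combinations.

```python
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import ParameterGrid
from sklearn.svm import SVC

CANDIDATES = {
    "logistic_regression": (
        LogisticRegression(max_iter=5000),
        # liblinear and saga both accept the l1 and l2 penalties
        {"C": [0.001, 0.01, 0.1, 1, 10, 100],
         "penalty": ["l1", "l2"],
         "solver": ["liblinear", "saga"]},
    ),
    "svm": (
        # ROC-AUC scoring can use decision_function, so probability=True is not needed
        SVC(),
        [{"kernel": ["linear"], "C": [0.1, 1, 10]},
         {"kernel": ["rbf"], "C": [0.1, 1, 10], "gamma": ["scale", "auto"]}],
    ),
    "random_forest": (
        RandomForestClassifier(random_state=0),
        {"n_estimators": [100, 300],
         "max_depth": [None, 10, 20],
         "min_samples_split": [2, 5]},
    ),
}

for name, (_, grid) in CANDIDATES.items():
    print(name, len(ParameterGrid(grid)), "combinations")
# logistic_regression 24 combinations
# svm 9 combinations
# random_forest 12 combinations
```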

**Cross-Validation Process:**
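```python
from sklearn.model_selection import GridSearchCV

# 5-fold cross-validated grid search per classifier, scored by ROC-AUC.
# `candidates` holds (estimator, param_grid) pairs; X_train and y_train come from Step 2.
for estimator, param_grid in candidates:
    search = GridSearchCV(estimator, param_grid, cv=5, scoring="roc_auc")
    search.fit(X_train, y_train)
    # search.best_params_ and search.best_score_ hold the best combination for this classifier
```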
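A self-contained, runnable version of the loop above follows. It uses synthetic data from `make_classification` in place of real MGPT embeddings, and its deliberately tiny grids are assumptions rather than the pipeline's search space. Selection follows the rule stated above: the classifier with the highest mean cross-validated ROC-AUC wins.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.svm import SVC

# Synthetic stand-in for claim embeddings: 200 claims x 32 dimensions.
X, y = make_classification(n_samples=200, n_features=32, n_informative=8, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=0
)

# Small illustrative grids so the example runs in seconds.
candidates = {
    "logistic_regression": (LogisticRegression(max_iter=1000), {"C": [0.1, 1, 10]}),
    "svm": (SVC(), {"C": [0.1, 1, 10], "kernel": ["linear", "rbf"]}),
    "random_forest": (RandomForestClassifier(random_state=0),
                      {"n_estimators": [50, 100], "max_depth": [None, 5]}),
}

cv_results = {}
for name, (estimator, grid) in candidates.items():
    # Add n_jobs=-1 to use all CPU cores, as the config's n_jobs: -1 does.
    search = GridSearchCV(estimator, grid, cv=5, scoring="roc_auc")
    search.fit(X_train, y_train)
    cv_results[name] = search
    print(f"{name}: CV ROC-AUC={search.best_score_:.3f} params={search.best_params_}")

best_name = max(cv_results, key=lambda n: cv_results[n].best_score_)
print("best model:", best_name)
```

Because `GridSearchCV` refits the winning combination on all of `X_train` by default (`refit=True`), `cv_results[best_name].best_estimator_` is ready for the Step 4 test-set evaluation.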

### Step 4: Model Evaluation

**Metrics Calculated:**

| Metric | Formula | Interpretation |
|--------|---------|----------------|
| **Accuracy** | (TP + TN) / (TP + TN + FP + FN) | Overall correctness |
| **Precision** | TP / (TP + FP) | Of predicted positives, how many are correct |
| **Recall** | TP / (TP + FN) | Of actual positives, how many were found |
| **F1-Score** | 2 × (Precision × Recall) / (Precision + Recall) | Harmonic mean of precision/recall |
| **ROC-AUC** | Area under ROC curve | Ability to distinguish classes |

**Confusion Matrix:**
```
                Predicted
              0      1
Actual   0   TN     FP
         1   FN     TP
```
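The formulas in the table can be checked by hand against scikit-learn. The eight labels and scores below are made-up numbers for illustration; hard predictions come from thresholding the scores at 0.5, while ROC-AUC uses the raw scores.

```python
from sklearn.metrics import (accuracy_score, confusion_matrix, f1_score,
                             precision_score, recall_score, roc_auc_score)

# Toy labels and model scores for eight claims (illustrative values).
y_true = [1, 0, 1, 1, 0, 0, 1, 0]
y_score = [0.9, 0.6, 0.4, 0.8, 0.2, 0.7, 0.95, 0.1]
y_pred = [int(s >= 0.5) for s in y_score]  # hard predictions at a 0.5 threshold

# sklearn lays out the binary matrix as [[TN, FP], [FN, TP]], as in the diagram above.
tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()

accuracy = (tp + tn) / (tp + tn + fp + fn)
precision = tp / (tp + fp)
recall = tp / (tp + fn)
f1 = 2 * precision * recall / (precision + recall)

print(f"TN={tn} FP={fp} FN={fn} TP={tp}")  # TN=2 FP=2 FN=1 TP=3
print(f"accuracy={accuracy:.3f} precision={precision:.3f} "
      f"recall={recall:.3f} f1={f1:.3f}")  # accuracy=0.625 precision=0.600 recall=0.750 f1=0.667
print(f"roc_auc={roc_auc_score(y_true, y_score):.3f}")  # roc_auc=0.875
```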

## 📊 Configuration Options

### Embedding Generation Settings

```yaml
embedding_generation:
  batch_size: 16                    # Claims per API request
  save_interval: 100                # Checkpoint frequency
  max_sequence_length: 512          # Token limit per claim
  output_format: "json"             # json or csv
  resume_from_checkpoint: true      # Resume if interrupted
```
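The `save_interval` and `resume_from_checkpoint` options suggest a checkpoint-and-resume loop. The sketch below is one hypothetical way to implement it, not the pipeline's code: it assumes `save_interval` counts claims, stores progress as a JSON object keyed by MCID, and takes the embedding call as a function (for instance a wrapper around `/embeddings_batch`). The function name `generate_with_checkpoints` is illustrative.

```python
import json
import os
from typing import Callable, Dict, List, Sequence

Vector = List[float]


def generate_with_checkpoints(
    mcids: Sequence[str],
    texts: Sequence[str],
    embed_batch: Callable[[Sequence[str]], List[Vector]],
    checkpoint_path: str,
    batch_size: int = 16,
    save_interval: int = 100,
    resume_from_checkpoint: bool = True,
) -> Dict[str, Vector]:
    """Embed claims in batches, checkpointing at the first batch boundary
    after every `save_interval` newly embedded claims."""
    done: Dict[str, Vector] = {}
    if resume_from_checkpoint and os.path.exists(checkpoint_path):
        with open(checkpoint_path) as f:
            done = json.load(f)  # claims embedded before the interruption

    def save() -> None:
        tmp = checkpoint_path + ".tmp"
        with open(tmp, "w") as f:
            json.dump(done, f)
        os.replace(tmp, checkpoint_path)  # atomic swap: never a half-written checkpoint

    # Skip claims that already have an embedding from a previous run.
    pending = [(m, t) for m, t in zip(mcids, texts) if m not in done]
    since_save = 0
    for start in range(0, len(pending), batch_size):
        batch = pending[start:start + batch_size]
        vectors = embed_batch([t for _, t in batch])
        done.update({m: v for (m, _), v in zip(batch, vectors)})
        since_save += len(batch)
        if since_save >= save_interval:
            save()
            since_save = 0
    save()  # final checkpoint containing every claim
    return done
```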

### Classification Settings

```yaml
classification:
  models: ["logistic_regression", "svm", "random_forest"]
  
  cross_validation:
    n_folds: 5                      # CV folds
    scoring: "roc_auc"              # Optimization metric
    n_jobs: -1                      # Parallel jobs
  
  hyperparameter_search:
    logistic_regression:
      C: [0.001, 0.01, 0.1, 1, 10, 100]
      penalty: ["l1", "l2"]
      solver: ["liblinear", "saga"]
```
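The `classification` block maps fairly directly onto scikit-learn's `GridSearchCV` arguments. The helper below is hypothetical; it assumes the YAML has already been parsed into a dictionary (for example with PyYAML's `yaml.safe_load`) and that `n_folds` is the number of CV folds.

```python
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV

# The `classification` block above, as a dictionary after YAML parsing.
config = {
    "models": ["logistic_regression", "svm", "random_forest"],
    "cross_validation": {"n_folds": 5, "scoring": "roc_auc", "n_jobs": -1},
    "hyperparameter_search": {
        "logistic_regression": {
            "C": [0.001, 0.01, 0.1, 1, 10, 100],
            "penalty": ["l1", "l2"],
            "solver": ["liblinear", "saga"],
        }
    },
}


def build_search(config: dict, model_name: str, estimator) -> GridSearchCV:
    """Turn one model's config section into a GridSearchCV (hypothetical helper)."""
    cv = config["cross_validation"]
    return GridSearchCV(
        estimator,
        param_grid=config["hyperparameter_search"][model_name],
        cv=cv["n_folds"],
        scoring=cv["scoring"],
        n_jobs=cv["n_jobs"],
    )


search = build_search(config, "logistic_regression", LogisticRegression(max_iter=5000))
```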

## 🎯 Practical Example

### Scenario: Diabetes Code Prediction

**Goal**: Predict if a medical claim contains diabetes-related codes

**Sample Data:**
```csv
mcid,claims,label
C001,"E119 Z9981 |eoc| N189 M549",1
C002,"K592 G9340 |eoc| R50 M255",0
C003,"E1022 Z794 |eoc| N183",1
```

C001 and C003 are labeled 1 because they contain the diabetes codes E119 and E1022; C002 contains no diabetes codes.
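For a scenario like this, labels could be derived from the codes themselves. The rule below is an illustrative assumption, not how the pipeline or your dataset assigns labels: a claim is positive if any token starts with one of the ICD-10-CM diabetes mellitus categories E08, E09, E10, E11 or E13. A prefix check cannot distinguish ICD-10 codes from HCPCS Level II codes that also begin with "E", so treat it as a sketch.

```python
import csv
import io

# Hypothetical labelling rule: ICD-10-CM diabetes mellitus category prefixes.
DIABETES_PREFIXES = ("E08", "E09", "E10", "E11", "E13")

SAMPLE = """mcid,claims,label
C001,"E119 Z9981 |eoc| N189 M549",1
C002,"K592 G9340 |eoc| R50 M255",0
C003,"E1022 Z794 |eoc| N183",1
"""


def has_diabetes_code(claims: str) -> bool:
    """True if any code (ignoring the |eoc| separator) starts with a diabetes prefix."""
    return any(
        token.startswith(DIABETES_PREFIXES)
        for token in claims.split()
        if token != "|eoc|"
    )


rows = list(csv.DictReader(io.StringIO(SAMPLE)))
for row in rows:
    print(row["mcid"], int(has_diabetes_code(row["claims"])), "labeled", row["label"])
```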

**Expected Workflow:**
1. Generate embeddings for all claims
2. Train classifiers to predict label (diabetes present=1, absent=0)
3. Evaluate which classifier performs best
4. Use best model for future diabetes prediction

**Configuration:**
```yaml
# Use 02_from_embeddings.yaml template
input:
  dataset_path: "data/diabetes_claims.csv"
  split_ratio: 0.8

pipeline_stages:
  embeddings: true
  classification: true
  evaluation: true
  target_word_eval: false  # Focus on embedding approach
```

## 📈 Interpreting Results

### Model Comparison Output

```json
{
  "logistic_regression": {
    "accuracy": 0.85,
    "precision": 0.82,
    "recall": 0.88,
    "f1_score": 0.85,
    "roc_auc": 0.91
  },
  "svm": {
    "accuracy": 0.87,
    "precision": 0.84,
    "recall": 0.90,
    "f1_score": 0.87,
    "roc_auc": 0.93
  },
  "random_forest": {
    "accuracy": 0.83,
    "precision": 0.80,
    "recall": 0.86,
    "f1_score": 0.83,
    "roc_auc": 0.89
  },
  "best_model": "svm"
}
```
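A comparison file in this shape can be ranked and sanity-checked programmatically. The sketch below parses the output above (the metric values are the document's example numbers), skips the `best_model` string, ranks models by ROC-AUC, and recomputes F1 from precision and recall as a consistency check.

```python
import json

RESULTS_JSON = """{
  "logistic_regression": {"accuracy": 0.85, "precision": 0.82, "recall": 0.88, "f1_score": 0.85, "roc_auc": 0.91},
  "svm": {"accuracy": 0.87, "precision": 0.84, "recall": 0.90, "f1_score": 0.87, "roc_auc": 0.93},
  "random_forest": {"accuracy": 0.83, "precision": 0.80, "recall": 0.86, "f1_score": 0.83, "roc_auc": 0.89},
  "best_model": "svm"
}"""

results = json.loads(RESULTS_JSON)
# Keep only per-model metric dicts (skips the "best_model" string).
models = {name: m for name, m in results.items() if isinstance(m, dict)}
ranking = sorted(models, key=lambda name: models[name]["roc_auc"], reverse=True)

for name in ranking:
    m = models[name]
    # F1 should match the harmonic mean of the reported precision and recall.
    f1 = 2 * m["precision"] * m["recall"] / (m["precision"] + m["recall"])
    print(f"{name:<20} roc_auc={m['roc_auc']:.2f} f1={m['f1_score']:.2f} recomputed_f1={f1:.2f}")

assert ranking[0] == results["best_model"]
```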

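### What This Tells You:
- **SVM performs best**: its ROC-AUC of 0.93 is the highest, so it is reported as `best_model`
- **High recall (0.90)** means the model finds 90% of actual diabetes claims; about 1 in 10 is missed
- **Good precision (0.84)** means 84% of the claims it flags as diabetes really are diabetes claims
- **All three classifiers reach ROC-AUC ≥ 0.89**, which suggests your MGPT embeddings capture diabetes-related patterns rather than one classifier happening to fit well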

## ⚙️ Advanced Configuration

### Performance Optimization

```yaml
# For large datasets
embedding_generation:
  batch_size: 32                    # Increase if server can handle
  save_interval: 50                 # More frequent checkpoints

model_api:
  batch_size: 64                    # Larger API batches
  timeout: 600                      # Longer timeout

classification:
  cross_validation:
    n_jobs: -1                      # Use all CPU cores
```

### Memory Management

```yaml
# For memory-constrained environments
data_processing:
  output_format: "csv"              # More memory efficient

embedding_generation:
  batch_size: 8                     # Smaller batches
  save_interval: 25                 # Frequent saves
```

## 🚨 Common Issues & Solutions

### Issue 1: Low Performance Scores
**Symptoms**: All classifiers get <70% accuracy
**Possible Causes**:
- Insufficient training data
- Poor label quality
- Model embeddings don't capture relevant patterns

**Solutions**:
- Increase dataset size
- Review label definitions
- Check if MGPT model is properly trained on medical data
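Before blaming the embeddings, compare against a trivial baseline: on a 70/30 label split, always predicting the majority class already scores about 70% accuracy, so "<70%" may mean the classifiers learned nothing at all. A sketch using scikit-learn's `DummyClassifier`; the helper name `compare_to_baseline` and the synthetic data are illustrative.

```python
import numpy as np
from sklearn.dummy import DummyClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score


def compare_to_baseline(X, y, cv=5):
    """Return mean CV accuracy of (majority-class baseline, logistic regression)."""
    baseline = cross_val_score(DummyClassifier(strategy="most_frequent"), X, y,
                               cv=cv, scoring="accuracy").mean()
    model = cross_val_score(LogisticRegression(max_iter=1000), X, y,
                            cv=cv, scoring="accuracy").mean()
    return baseline, model


rng = np.random.default_rng(0)
X = rng.normal(size=(200, 16))           # stand-in for claim embeddings
y = (rng.random(200) < 0.3).astype(int)  # labels unrelated to X, roughly 30% positive
baseline, model = compare_to_baseline(X, y)
print(f"majority-class baseline={baseline:.2f}  logistic regression={model:.2f}")
```

If the real classifiers barely beat the baseline, label quality or the embeddings are the likelier culprits. ROC-AUC is a useful second check, since a constant-prediction baseline always scores 0.5 on it.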

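### Issue 2: Overfitting
**Symptoms**: High training accuracy, low test accuracy
**Solutions**:
- Increase regularization (lower C values for Logistic Regression and SVM)
- Reduce model complexity (e.g. a linear instead of an RBF kernel, or a smaller `max_depth` for Random Forest)
- Use more cross-validation folds; this does not prevent overfitting by itself, but averages validation scores over more splits, so hyperparameters are less likely to be tuned to one lucky split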
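The train/test gap can be measured during cross-validation by asking scikit-learn's `cross_validate` for training scores as well. This sketch uses synthetic data with 10% label noise and compares an unlimited-depth Random Forest with a deliberately shallow one; the data and the `max_depth` values are illustrative assumptions.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_validate

# Synthetic stand-in for embeddings, with 10% label noise to invite overfitting.
X, y = make_classification(n_samples=300, n_features=64, n_informative=5,
                           flip_y=0.1, random_state=0)

gaps = {}
for max_depth in (None, 3):  # unlimited depth vs. a shallow forest
    scores = cross_validate(
        RandomForestClassifier(max_depth=max_depth, random_state=0),
        X, y, cv=5, scoring="roc_auc", return_train_score=True,
    )
    train = scores["train_score"].mean()
    val = scores["test_score"].mean()
    gaps[max_depth] = (train, val)
    print(f"max_depth={max_depth}: train={train:.3f} validation={val:.3f} gap={train - val:.3f}")
```

A training score near 1.0 with a much lower validation score is the overfitting signature described above; shrinking the model should narrow that gap.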

### Issue 3: Imbalanced Classes
**Symptoms**: High accuracy but poor recall for minority class
**Solutions**:
- The pipeline automatically handles class imbalance
- Focus on F1-score and ROC-AUC rather than accuracy
- Consider collecting more minority class examples
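The document does not say how the pipeline handles imbalance; one common scikit-learn mechanism is class weighting, shown here as an assumption rather than the pipeline's method. With `class_weight="balanced"`, each class is weighted by `n_samples / (n_classes * count)`, so errors on the rare class cost more. `LogisticRegression`, `SVC` and `RandomForestClassifier` all accept this argument.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.utils.class_weight import compute_class_weight

y = np.array([0] * 8 + [1] * 2)  # 80/20 imbalance
weights = compute_class_weight(class_weight="balanced", classes=np.array([0, 1]), y=y)
print({c: float(w) for c, w in zip([0, 1], weights)})  # {0: 0.625, 1: 2.5}

# Estimators apply the same rule when given class_weight="balanced".
clf = LogisticRegression(class_weight="balanced", max_iter=1000)
```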

## 🔗 Next Steps

- **[03_Target_Word_Evaluation.ipynb](03_Target_Word_Evaluation.ipynb)** - Learn the alternative evaluation approach
- **[05_Results_Analysis.ipynb](05_Results_Analysis.ipynb)** - Deep dive into result interpretation
- **[07_Advanced_Usage.ipynb](07_Advanced_Usage.ipynb)** - Production deployment tips

### Quick Commands:

**Embeddings Only:**
```bash
python main.py run-all --config configs/templates/01_embeddings_only.yaml
```

**Classification from Existing Embeddings:**
```bash
python main.py run-all --config configs/templates/02_from_embeddings.yaml
```