# Notebook 00: Data-Driven Specialty Taxonomy

## 🎯 What This Notebook Does

**Goal:** Transform 5,126 fine-grained `focus_area` labels into ~15-20 high-level medical specialties using embeddings + clustering.

**Why?** The raw dataset has too many categories (5,126) with too few examples each (most have fewer than 10 samples), so a classifier trained on the raw labels cannot learn most of its classes. Grouping related conditions into medical specialties (Oncology, Cardiology, etc.) turns this into a manageable multi-class problem.

**Approach:** Bottom-up, data-driven taxonomy construction:
1. **Embed** `focus_area` strings → semantic vectors (using BioBERT)
2. **Cluster** vectors → discover natural groupings
3. **Inspect** clusters → assign human-interpretable specialty labels
4. **Validate** with zero-shot classification
5. **Export** mapping + labeled dataframe

**End Product:** A new column `df["specialty"]` with ~15-20 medical categories, each with enough samples to train on (see the acceptance criteria below).

---

## 📋 Key Phases

1. ✅ Project setup & data loading
2. ✅ Deduplicate focus_area catalog
3. ✅ Embed focus_areas with BioBERT
4. ✅ Dimensionality reduction (UMAP) for visualization
5. ✅ Cluster with HDBSCAN + k-means
6. ✅ Inspect clusters & assign specialty names
7. ✅ Validate with zero-shot classification
8. ✅ Merge back to full dataframe
9. ✅ Export final taxonomy

---

## ✅ Acceptance Criteria

- [ ] Created 15-20 medical specialties from 5,126 focus_areas
- [ ] Each specialty has ≥50 samples (ideally ≥100)
- [ ] Clusters are semantically coherent (validated by inspection + zero-shot)
- [ ] Exported: `df_with_specialty.parquet` + `focus_area_to_specialty.json`
- [ ] Documented: Clustering params, model choices, manual decisions

## 🔧 Setup: Project Paths & Environment

**TODO 1:** Configure paths and create output directories.

**Hints:**
- Point to your MedQuad CSV (likely in `../../../datasets/medquad.csv`)
- Create an artifacts folder to save embeddings, clusters, and mappings
- Use `pathlib.Path` for cross-platform compatibility
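
A minimal sketch of this setup step, assuming the relative layout suggested in the hints (the `../../../datasets/medquad.csv` location is the hint's guess, not a verified path). `setup_paths` is a hypothetical helper that fails fast when the CSV is missing, instead of failing later inside `pd.read_csv`.

```python
from pathlib import Path

def setup_paths(data_dir, artifacts_dir, filename="medquad.csv"):
    """Resolve input/output paths and create the artifacts folder (hypothetical helper)."""
    data_dir = Path(data_dir).expanduser().resolve()
    input_path = data_dir / filename
    if not input_path.is_file():
        raise FileNotFoundError(f"Input CSV not found: {input_path}")
    artifacts_dir = Path(artifacts_dir).expanduser().resolve()
    # parents=True creates intermediate folders; exist_ok=True makes re-runs safe
    artifacts_dir.mkdir(parents=True, exist_ok=True)
    return input_path, artifacts_dir

# Usage (adjust to your checkout):
# INPUT_PATH, ARTIFACTS_DIR = setup_paths("../../../datasets", "../artifacts/specialty_taxonomy")
```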

In [None]:
# TODO 1: Set up project paths & environment

from pathlib import Path
import pandas as pd
import numpy as np
import json
import matplotlib.pyplot as plt
import seaborn as sns

# Define paths
# DATA_DIR = Path("../../../datasets")
# INPUT_PATH = DATA_DIR / "medquad.csv"
# ARTIFACTS_DIR = Path("../artifacts/specialty_taxonomy")

# Create output directory
# ARTIFACTS_DIR.mkdir(parents=True, exist_ok=True)

# print(f"✅ Input data: {INPUT_PATH}")
# print(f"✅ Artifacts will be saved to: {ARTIFACTS_DIR}")

## 📊 Load Data & Basic EDA

**TODO 2:** Load the MedQuad dataset and perform basic checks.

**What to check:**
- Dataset shape and columns
- Presence of `focus_area`, `question`, `answer`
- Number of unique focus_areas
- Top 20 most frequent focus_areas
- Any missing values

**Hints:**
- Create a normalized version of `focus_area` (strip whitespace, handle case)
- Keep `question` and `answer` for later sanity checks
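
One way to do the normalization, sketched under the assumption that `focus_area` holds strings or missing values. Lowercasing every label would also change the names you display later, so this sketch (hypothetical helpers `normalize_focus_area` and `case_collisions`) only trims and collapses whitespace, then *reports* labels that differ only by case so you can decide whether to merge them.

```python
import numpy as np
import pandas as pd

def normalize_focus_area(s: pd.Series) -> pd.Series:
    """Trim, collapse inner whitespace, and turn blank strings into NaN."""
    out = s.str.strip().str.replace(r"\s+", " ", regex=True)  # "Breast  Cancer" -> "Breast Cancer"
    return out.replace("", np.nan)  # blank labels now count as missing for dropna()

def case_collisions(norm: pd.Series) -> pd.DataFrame:
    """List labels that differ only by letter case (merging them is a judgment call)."""
    uniq = pd.Series(norm.dropna().unique())
    key = uniq.str.casefold()
    dup = key.duplicated(keep=False)  # keep=False flags every member of a collision group
    return (pd.DataFrame({"label": uniq[dup], "key": key[dup]})
              .sort_values(["key", "label"]).reset_index(drop=True))

# Usage:
# df['focus_area_norm'] = normalize_focus_area(df['focus_area'])
# display(case_collisions(df['focus_area_norm']))
```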

In [None]:
# TODO 2: Load data & basic EDA

# Load dataframe
# df = pd.read_csv(INPUT_PATH)

# Basic info
# print("Dataset shape:", df.shape)
# print("\nColumns:", list(df.columns))
# print("\nFirst few rows:")
# display(df.head())

# Check for required columns
# required_cols = ['focus_area', 'question', 'answer']
# assert all(col in df.columns for col in required_cols), f"Missing columns! Need: {required_cols}"

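# Normalize focus_area: trim and collapse repeated inner whitespace
# (lowercasing is optional; it would also change the labels you display later)
# df['focus_area_norm'] = df['focus_area'].str.strip().str.replace(r'\s+', ' ', regex=True)
# df['focus_area_norm'] = df['focus_area_norm'].replace('', np.nan)  # blank strings count as missing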

# Remove any rows with missing focus_area
# df = df.dropna(subset=['focus_area_norm'])
# print(f"\n✅ After removing NaN focus_areas: {df.shape[0]} rows")

# Unique focus_areas
# n_unique = df['focus_area_norm'].nunique()
# print(f"\n📊 Number of unique focus_areas: {n_unique}")

# Top 20 most common
# print("\nTop 20 most frequent focus_areas:")
# print(df['focus_area_norm'].value_counts().head(20))

## 🗂️ Build Unique Focus_Area Catalog

**TODO 3:** Create a catalog of unique focus_areas with their counts.

**Why deduplicate?**
- Clustering 5,126 unique strings is much faster than clustering 16,412 rows
- Avoids over-weighting duplicates
- Easier to inspect and label

**Hints:**
- Use `value_counts()` to get unique focus_areas with frequencies
- Keep the counts - they'll help prioritize important clusters
- Store as a list for embedding
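
To quantify the long tail described in the introduction, a small summary of how many focus_areas (and what share of rows) fall below a few count thresholds can help. This is a sketch; `long_tail_summary` and its default thresholds are illustrative choices, not part of the original notebook.

```python
import pandas as pd

def long_tail_summary(counts: pd.Series, thresholds=(2, 5, 10, 50)) -> pd.DataFrame:
    """For each threshold t, count focus_areas with fewer than t samples."""
    total_rows = counts.sum()
    rows = []
    for t in thresholds:
        below = counts[counts < t]
        rows.append({
            "threshold": t,
            "n_focus_areas_below": int(below.size),
            "pct_focus_areas": 100 * below.size / counts.size,  # share of unique labels
            "pct_rows": 100 * below.sum() / total_rows,          # share of dataset rows
        })
    return pd.DataFrame(rows)

# Usage: long_tail_summary(cat_counts)
```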

In [None]:
# TODO 3: Build unique focus_area catalog

# Get unique focus_areas with counts
# cat_counts = df['focus_area_norm'].value_counts()
# catalog = cat_counts.index.tolist()

# print(f"✅ Catalog size: {len(catalog)} unique focus_areas")
# print(f"   Total samples: {cat_counts.sum()}")
# print(f"   Most common: {cat_counts.iloc[0]} ('{catalog[0]}')")
# print(f"   Least common: {cat_counts.iloc[-1]} ('{catalog[-1]}')")
# print(f"   Median count: {cat_counts.median():.0f}")

# Distribution of counts
# plt.figure(figsize=(10, 4))
# plt.subplot(1, 2, 1)
# plt.hist(cat_counts, bins=50, edgecolor='black')
# plt.xlabel('Number of samples per focus_area')
# plt.ylabel('Frequency')
# plt.title('Distribution of Focus_Area Counts')

# plt.subplot(1, 2, 2)
# plt.hist(np.log10(cat_counts + 1), bins=50, edgecolor='black')
# plt.xlabel('Log10(count + 1)')
# plt.ylabel('Frequency')
# plt.title('Log-scale Distribution')
# plt.tight_layout()
# plt.savefig(ARTIFACTS_DIR / 'focus_area_distribution.png', dpi=150)
# plt.show()

## 🧬 Embed Focus_Areas with BioBERT

**TODO 4:** Convert focus_area strings to semantic vectors.

**Model Choice:**
- **Recommended:** `"pritamdeka/BioBERT-mnli-snli-scitail-mednli"` (medical domain)
- **Alternative:** `"all-mpnet-base-v2"` (general, strong performance)

**Why embeddings?**
- Capture semantic meaning ("Breast Cancer" ≈ "Lung Cancer" ≈ "Prostate Cancer")
- Enable clustering by similarity
- Standard NLP technique

**Hints:**
- Use `sentence-transformers` library
- Batch encode for efficiency
- **SAVE embeddings to disk** - expensive to recompute!
- Embeddings shape: `(n_catalog, embedding_dim)` e.g., `(5126, 768)`
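
Caching only helps if the cached rows still line up with `catalog`; if the normalization or ordering changes, a stale `.npy` file silently pairs vectors with the wrong strings. A minimal sketch, assuming you pass in any encoder callable (for example a lambda around `model.encode(...)`); the helper `cached_embeddings` and the companion file `catalog_strings.json` are hypothetical names.

```python
import json
from pathlib import Path
import numpy as np

def cached_embeddings(catalog, encode_fn, cache_dir):
    """Load embeddings if they were computed for this exact catalog; else compute and save.

    encode_fn: callable mapping a list of strings to a 2-D array, e.g.
        lambda texts: model.encode(texts, batch_size=64, convert_to_numpy=True)
    """
    cache_dir = Path(cache_dir)
    emb_path = cache_dir / "catalog_embeddings.npy"
    cat_path = cache_dir / "catalog_strings.json"  # hypothetical companion file
    if emb_path.exists() and cat_path.exists():
        if json.loads(cat_path.read_text(encoding="utf-8")) == list(catalog):
            return np.load(emb_path)
        # catalog changed since the cache was written: recompute to avoid misaligned rows
    emb = np.asarray(encode_fn(list(catalog)), dtype=np.float32)
    if emb.shape[0] != len(catalog):
        raise ValueError("expected one embedding per focus_area")
    cache_dir.mkdir(parents=True, exist_ok=True)
    np.save(emb_path, emb)
    cat_path.write_text(json.dumps(list(catalog)), encoding="utf-8")
    return emb
```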

In [None]:
# TODO 4: Embed focus_areas with BioBERT

# from sentence_transformers import SentenceTransformer

# Initialize model
# print("Loading embedding model...")
# model = SentenceTransformer("pritamdeka/BioBERT-mnli-snli-scitail-mednli")
# # Alternative: model = SentenceTransformer("all-mpnet-base-v2")

# Encode catalog (this may take a few minutes)
# print(f"Encoding {len(catalog)} focus_areas...")
# embeddings = model.encode(
#     catalog,
#     batch_size=64,
#     convert_to_numpy=True,
#     show_progress_bar=True
# )

# print(f"✅ Embeddings shape: {embeddings.shape}")

# Save to disk
# np.save(ARTIFACTS_DIR / "catalog_embeddings.npy", embeddings)
# print(f"✅ Saved embeddings to {ARTIFACTS_DIR / 'catalog_embeddings.npy'}")

# # Load embeddings (for future runs)
# # embeddings = np.load(ARTIFACTS_DIR / "catalog_embeddings.npy")

## 📉 Dimensionality Reduction: UMAP for Visualization

**TODO 5:** Reduce embeddings to 2D for visualization (optional but highly recommended).

**Why UMAP?**
- Visualize cluster structure before clustering
- Spot natural groupings and outliers
- Helps tune clustering hyperparameters
- Beautiful plots!

**Hints:**
- Use `metric='cosine'` for text embeddings
- `n_neighbors=15-30` balances local/global structure
- Save 2D coordinates for later plotting
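
UMAP builds its layout from each point's `n_neighbors` nearest neighbors under the chosen metric, so it is worth checking those neighbors directly in the full embedding space before trusting a 2-D plot (the 2-D layout can distort distances, and the clustering below runs on the full embeddings anyway). A small sketch; `nearest_neighbors` is a hypothetical helper and assumes no all-zero embedding rows.

```python
import numpy as np

def nearest_neighbors(embeddings, labels, query, k=5):
    """Return the k catalog entries most cosine-similar to `query` (excluding itself)."""
    x = np.asarray(embeddings, dtype=np.float64)
    x = x / np.linalg.norm(x, axis=1, keepdims=True)  # unit vectors: dot product = cosine
    i = labels.index(query)
    sims = x @ x[i]
    sims[i] = -np.inf                                  # exclude the query itself
    top = np.argsort(-sims)[:k]
    return [(labels[j], float(sims[j])) for j in top]

# Usage: nearest_neighbors(embeddings, catalog, "Glaucoma", k=10)
```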

In [None]:
# TODO 5: UMAP dimensionality reduction

# from umap import UMAP

# Fit UMAP
# print("Running UMAP (this may take a few minutes)...")
# umap_model = UMAP(
#     n_components=2,
#     metric='cosine',
#     n_neighbors=15,
#     min_dist=0.1,
#     random_state=42
# )
# coords_2d = umap_model.fit_transform(embeddings)

# print(f"✅ UMAP coordinates shape: {coords_2d.shape}")

# Save coordinates
# np.save(ARTIFACTS_DIR / "coords_2d.npy", coords_2d)

# Quick visualization (before clustering)
# plt.figure(figsize=(12, 8))
# plt.scatter(coords_2d[:, 0], coords_2d[:, 1], s=5, alpha=0.5)
# plt.title('UMAP Projection of Focus_Area Embeddings')
# plt.xlabel('UMAP 1')
# plt.ylabel('UMAP 2')
# plt.savefig(ARTIFACTS_DIR / 'umap_before_clustering.png', dpi=150)
# plt.show()

# print("💡 Look for natural clusters/islands in the plot above!")

## 🎯 Clustering Experiment #1: HDBSCAN

**TODO 6A:** Cluster with HDBSCAN (automatically determines number of clusters).

**HDBSCAN Advantages:**
- No need to specify number of clusters
- Finds dense regions
- Labels noise points as -1
- Good for irregular cluster shapes

**Hyperparameters to tune:**
- `min_cluster_size`: Minimum samples per cluster (try 20-60)
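- `min_samples`: Higher values make clustering more conservative (more points labeled as noise); defaults to `min_cluster_size` if unset
- `metric`: 'euclidean' is a good choice on L2-normalized embeddings, where it ranks neighbors the same way cosine distance does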

**Hints:**
- Normalize embeddings first (L2 norm) if using euclidean metric
- Start with `min_cluster_size=40`
- Some points will be labeled as noise (-1) - that's OK!
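
Why Euclidean distance is acceptable after L2 normalization: for unit vectors a and b, ‖a − b‖² = 2 − 2·cos(a, b), so Euclidean distance orders neighbors exactly as cosine distance does. A quick numeric check on random stand-in vectors (not real BioBERT output):

```python
import numpy as np

rng = np.random.default_rng(0)
a, b = rng.normal(size=(2, 768))   # two stand-in 768-d embedding vectors
a_n = a / np.linalg.norm(a)        # L2-normalize, as normalize(..., norm='l2') does per row
b_n = b / np.linalg.norm(b)

cosine = float(a_n @ b_n)
sq_euclid = float(np.sum((a_n - b_n) ** 2))

# For unit vectors: ||a - b||^2 == 2 - 2 * cos(a, b)
assert np.isclose(sq_euclid, 2 - 2 * cosine)
```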

In [None]:
# TODO 6A: HDBSCAN clustering

# import hdbscan
# from sklearn.preprocessing import normalize

# Normalize embeddings for better clustering
# embeddings_norm = normalize(embeddings, norm='l2', axis=1)

# Run HDBSCAN
# print("Running HDBSCAN...")
# clusterer_hdb = hdbscan.HDBSCAN(
#     min_cluster_size=40,
#     min_samples=5,
#     metric='euclidean',
#     cluster_selection_method='eom'
# )
# labels_hdb = clusterer_hdb.fit_predict(embeddings_norm)

# Analyze results
# n_clusters_hdb = len(set(labels_hdb)) - (1 if -1 in labels_hdb else 0)
# n_noise = list(labels_hdb).count(-1)

# print(f"\n✅ HDBSCAN Results:")
# print(f"   Number of clusters: {n_clusters_hdb}")
# print(f"   Noise points: {n_noise} ({100*n_noise/len(labels_hdb):.1f}%)")

# Save labels
# np.save(ARTIFACTS_DIR / "labels_hdbscan.npy", labels_hdb)

# Visualize clusters
# plt.figure(figsize=(12, 8))
# scatter = plt.scatter(coords_2d[:, 0], coords_2d[:, 1], 
#                       c=labels_hdb, s=5, alpha=0.5, cmap='tab20')
# plt.colorbar(scatter, label='Cluster ID')
# plt.title(f'HDBSCAN Clusters (n={n_clusters_hdb})')
# plt.xlabel('UMAP 1')
# plt.ylabel('UMAP 2')
# plt.savefig(ARTIFACTS_DIR / 'hdbscan_clusters.png', dpi=150)
# plt.show()

## 🎯 Clustering Experiment #2: K-Means

**TODO 6B:** Cluster with k-means (you specify number of clusters).

**K-Means Advantages:**
- Simple and fast
- Forces assignment (no noise points)
- You control number of clusters

**Choosing K:**
- Start with 30-50 clusters (fine-grained)
- Plan to merge similar clusters later
- Medical specialties typically: 15-25 major categories

**Hints:**
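- Use `n_init='auto'` (scikit-learn ≥ 1.2), which runs k-means once with the default `init='k-means++'` and 10 times otherwise; raise `n_init` if clusters change between runs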
- Compare with HDBSCAN results
- You'll likely use k-means results (easier to label)
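
`np.bincount(labels_km)` counts unique focus_areas per cluster, but the acceptance criteria (≥50 samples) are about dataset rows. Weighting each focus_area by its count gives the row view. A sketch with a hypothetical helper `cluster_sizes`; it assumes `counts` is aligned with the catalog order (true for `cat_counts.to_numpy()`, since `catalog = cat_counts.index.tolist()`).

```python
import numpy as np

def cluster_sizes(labels, counts):
    """Per-cluster size in unique focus_areas and in dataset rows (count-weighted).

    Noise points (label -1, HDBSCAN only) are excluded.
    """
    labels = np.asarray(labels)
    counts = np.asarray(counts, dtype=float)
    keep = labels >= 0
    n_clusters = int(labels[keep].max()) + 1 if keep.any() else 0
    n_unique = np.bincount(labels[keep], minlength=n_clusters)
    n_rows = np.bincount(labels[keep], weights=counts[keep], minlength=n_clusters)
    return n_unique, n_rows.astype(int)

# Usage:
# n_unique, n_rows = cluster_sizes(labels_km, cat_counts.to_numpy())
# print(f"Rows per cluster: min={n_rows.min()}, max={n_rows.max()}")
```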

In [None]:
# TODO 6B: K-Means clustering

# from sklearn.cluster import KMeans

# Choose K (start with 40)
# K = 40

# Run k-means
# print(f"Running k-means with K={K}...")
# km = KMeans(
#     n_clusters=K,
#     random_state=42,
#     n_init='auto'
# )
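# # embeddings_norm is created in the HDBSCAN cell; if you skipped it, run:
# # embeddings_norm = normalize(embeddings, norm='l2', axis=1)  # from sklearn.preprocessing
# labels_km = km.fit_predict(embeddings_norm)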

# print(f"\n✅ K-Means Results:")
# print(f"   Number of clusters: {K}")
# print(f"   Cluster sizes (unique focus_areas): min={np.bincount(labels_km).min()}, max={np.bincount(labels_km).max()}")

# Save labels
# np.save(ARTIFACTS_DIR / "labels_kmeans.npy", labels_km)

# Visualize clusters ('tab20' has only 20 colors, so with K=40 some clusters share a color)
# plt.figure(figsize=(12, 8))
# scatter = plt.scatter(coords_2d[:, 0], coords_2d[:, 1], 
#                       c=labels_km, s=5, alpha=0.5, cmap='tab20')
# plt.colorbar(scatter, label='Cluster ID')
# plt.title(f'K-Means Clusters (K={K})')
# plt.xlabel('UMAP 1')
# plt.ylabel('UMAP 2')
# plt.savefig(ARTIFACTS_DIR / 'kmeans_clusters.png', dpi=150)
# plt.show()

## 🔍 Inspect Clusters: Build Summary Table

**TODO 7:** Create a DataFrame with cluster assignments and inspect each cluster.

**What to inspect:**
- Top 20 most frequent focus_areas per cluster (increase `head(20)` for large clusters)
- Total samples per cluster
- Look for common themes (cancer types, heart conditions, eye diseases, etc.)

**Decision:** Choose which clustering to use (HDBSCAN or k-means)
- **Recommendation:** Start with k-means (easier to assign all points)

**Hints:**
- Create a catalog DataFrame: `[focus_area, count, cluster_label, coords_2d]`
- Loop through clusters and display top members
- Save per-cluster CSVs for detailed review
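
The summary cell prints clusters and saves one combined CSV; per-cluster files are easier to review side by side or share with a clinician. A sketch assuming `catalog_df` has `focus_area`, `count`, and `cluster` columns as built below; `export_cluster_csvs` and the `cluster_NNN.csv` naming are hypothetical.

```python
from pathlib import Path
import pandas as pd

def export_cluster_csvs(catalog_df: pd.DataFrame, out_dir, top_n=None):
    """Write one CSV per cluster, sorted by count, for offline review."""
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    written = []
    for c, grp in catalog_df.groupby("cluster", sort=True):
        grp = grp.sort_values("count", ascending=False)
        if top_n is not None:
            grp = grp.head(top_n)
        name = "noise" if c == -1 else f"{int(c):03d}"  # e.g. cluster_007.csv
        path = out_dir / f"cluster_{name}.csv"
        grp[["focus_area", "count"]].to_csv(path, index=False)
        written.append(path)
    return written

# Usage: export_cluster_csvs(catalog_df, ARTIFACTS_DIR / "clusters")
```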

In [None]:
# TODO 7: Build cluster summary table

# Choose which clustering to use (k-means recommended)
# cluster_labels = labels_km  # or labels_hdb

# Build catalog DataFrame
# catalog_df = pd.DataFrame({
#     'focus_area': catalog,
#     'count': [cat_counts[fa] for fa in catalog],
#     'cluster': cluster_labels,
#     'umap_x': coords_2d[:, 0],
#     'umap_y': coords_2d[:, 1]
# })

# print(f"✅ Catalog DataFrame shape: {catalog_df.shape}")
# display(catalog_df.head())

# Save full catalog with clusters
# catalog_df.to_csv(ARTIFACTS_DIR / 'catalog_with_clusters.csv', index=False)

# Print summary per cluster
# print("\n📊 Cluster Summaries:")
# for c in sorted(catalog_df['cluster'].unique()):
#     if c == -1:  # Skip noise if using HDBSCAN
#         continue
#     cluster_data = catalog_df[catalog_df['cluster'] == c].sort_values('count', ascending=False)
#     total_samples = cluster_data['count'].sum()
#     print(f"\n{'='*80}")
#     print(f"Cluster {c}: {len(cluster_data)} unique focus_areas, {total_samples} total samples")
#     print(f"{'='*80}")
#     print(cluster_data[['focus_area', 'count']].head(20).to_string(index=False))
#     print()

## 🏷️ Manual Labeling: Assign Specialty Names to Clusters

**TODO 8:** Create a mapping from `cluster_id` → `specialty_name`.

**Common Medical Specialties:**
- **Oncology** - All cancers (breast, lung, prostate, colorectal, skin, etc.)
- **Cardiology** - Heart conditions (heart failure, heart attack, cholesterol, blood pressure)
- **Neurology** - Brain conditions (stroke, Alzheimer's, Parkinson's, headaches)
- **Ophthalmology** - Eye conditions (glaucoma, macular degeneration, cataracts)
- **Endocrinology** - Hormone/metabolic (diabetes, thyroid, obesity)
- **Pulmonology** - Lung conditions (COPD, asthma, pulmonary fibrosis); lung cancer belongs under Oncology
- **Gastroenterology** - Digestive system
- **Nephrology** - Kidney conditions
- **Dermatology** - Skin conditions
- **Orthopedics** - Bone/joint conditions
- **Infectious Disease** - Infections (HIV, flu, etc.)
- **Mental Health** - Psychology/psychiatry
- **General Medicine** - Primary care, preventive health
- **Other** - Miscellaneous/uncategorized

**Strategy:**
1. Start with obvious clusters (e.g., cluster with all cancers → "Oncology")
2. Merge similar small clusters (e.g., heart-related clusters → "Cardiology")
3. Create "Other" for truly mixed clusters
4. Aim for 15-20 final specialties

**Hints:**
- Iterate! This is not one-and-done
- Save mapping to JSON for reproducibility
- Document your decisions
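
One reproducibility trap when saving the mapping: JSON object keys are always strings, so integer cluster IDs come back as `'0'`, `'-1'`, and so on, and `.map()` on an integer `cluster` column would then match nothing. (Also, `json` refuses NumPy integer keys such as `np.int64`, so convert IDs with `int()` if you build the dict programmatically.) A sketch with a toy mapping; `load_cluster_mapping` is a hypothetical helper.

```python
import json

cluster_to_specialty = {0: "Oncology", 1: "Cardiology", -1: "Other"}  # toy mapping

text = json.dumps(cluster_to_specialty, indent=2)
loaded = json.loads(text)
print(list(loaded))  # ['0', '1', '-1']: keys came back as strings

def load_cluster_mapping(text):
    """Parse a cluster -> specialty JSON mapping, restoring integer cluster IDs."""
    return {int(k): v for k, v in json.loads(text).items()}

restored = load_cluster_mapping(text)
assert restored == cluster_to_specialty
```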

In [None]:
# TODO 8: Create cluster → specialty mapping

# Manual mapping based on cluster inspection above
# cluster_to_specialty = {
#     0: "Oncology",
#     1: "Cardiology",
#     2: "Neurology",
#     3: "Ophthalmology",
#     4: "Endocrinology",
#     5: "Pulmonology",
#     6: "Gastroenterology",
#     7: "Nephrology",
#     8: "Dermatology",
#     9: "Orthopedics",
#     10: "Infectious Disease",
#     11: "Mental Health",
#     12: "General Medicine",
#     # ... continue for all clusters ...
#     -1: "Other"  # For noise (if using HDBSCAN)
# }

# TODO: Fill in the mapping based on your cluster inspection!
# Inspect the output from TODO 7 above and assign meaningful names

# Save mapping (JSON stores the int cluster IDs as string keys; convert with int(k) when reloading)
# with open(ARTIFACTS_DIR / 'cluster_to_specialty.json', 'w') as f:
#     json.dump(cluster_to_specialty, f, indent=2)

# print(f"✅ Created mapping for {len(cluster_to_specialty)} clusters")
# print(f"   Unique specialties: {len(set(cluster_to_specialty.values()))}")
# print("\nSpecialty distribution:")
# from collections import Counter
# print(Counter(cluster_to_specialty.values()))

## 🔄 Apply Specialty Labels to Catalog

**TODO 9:** Map cluster labels to specialty names in the catalog DataFrame.

**Steps:**
1. Add `specialty` column using `cluster_to_specialty` mapping
2. Handle any unmapped clusters → "Other"
3. Compute per-specialty sample counts
4. Check for class balance

**Goal:** Ensure each specialty has ≥50 samples (ideally ≥100)

**Hints:**
- Use `catalog_df['cluster'].map(cluster_to_specialty)`
- Fill missing values with "Other"
- Visualize specialty distribution
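
`fillna('Other')` silently absorbs clusters you forgot to label, so a missing mapping entry looks like a deliberate "Other". A sketch that reports unmapped cluster IDs before filling; `apply_mapping` is a hypothetical helper.

```python
import pandas as pd

def apply_mapping(clusters: pd.Series, mapping: dict, default="Other"):
    """Map cluster IDs to specialties and report IDs missing from the mapping."""
    unmapped = sorted(int(c) for c in set(clusters.unique()) - set(mapping))
    if unmapped:
        print(f"⚠️  Clusters with no specialty (-> '{default}'): {unmapped}")
    return clusters.map(mapping).fillna(default), unmapped

# Usage:
# catalog_df['specialty'], missing = apply_mapping(catalog_df['cluster'], cluster_to_specialty)
```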

In [None]:
# TODO 9: Apply specialty labels to catalog

# Map clusters to specialties
# catalog_df['specialty'] = catalog_df['cluster'].map(cluster_to_specialty)

# Fill any missing with "Other"
# catalog_df['specialty'] = catalog_df['specialty'].fillna('Other')

# Compute per-specialty stats
# specialty_counts = catalog_df.groupby('specialty').agg({
#     'focus_area': 'count',  # Number of unique focus_areas
#     'count': 'sum'          # Total samples
# }).rename(columns={'focus_area': 'n_unique_focus_areas', 'count': 'n_samples'})
# specialty_counts = specialty_counts.sort_values('n_samples', ascending=False)

# print("\n📊 Specialty Distribution:")
# print(specialty_counts)

# Check minimum samples
# min_samples = specialty_counts['n_samples'].min()
# print(f"\n✅ Minimum samples per specialty: {min_samples}")
# if min_samples < 50:
#     print("⚠️  Warning: Some specialties have <50 samples. Consider merging!")

# Save catalog with specialties
# catalog_df.to_csv(ARTIFACTS_DIR / 'catalog_with_specialty.csv', index=False)

# Visualize distribution
# plt.figure(figsize=(12, 6))
# specialty_counts['n_samples'].plot(kind='bar')
# plt.title('Samples per Specialty')
# plt.xlabel('Specialty')
# plt.ylabel('Number of Samples')
# plt.xticks(rotation=45, ha='right')
# plt.axhline(y=100, color='r', linestyle='--', label='Target: 100 samples')
# plt.legend()
# plt.tight_layout()
# plt.savefig(ARTIFACTS_DIR / 'specialty_distribution.png', dpi=150)
# plt.show()

## 🔄 Merge Specialty Mapping Back to Full DataFrame

**TODO 10:** Expand the focus_area → specialty mapping to all 16k+ rows.

**Steps:**
1. Merge `df` with `catalog_df` on `focus_area_norm`
2. Every row now has a `specialty` column
3. Verify no missing specialties (or fill with "Other")
4. Save the enriched dataframe

**Hints:**
- Use a left join to keep all original rows
- Check for any NaN specialties
- Save as parquet for efficiency
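
A left join can quietly go wrong in two ways: duplicate keys in the catalog multiply rows, and keys that fail to match produce NaN specialties. pandas' `merge` can check both via `validate="many_to_one"` and `indicator=True`. A sketch assuming the column names used in this notebook; `attach_specialty` is a hypothetical helper that also drops old `specialty`/`cluster` columns so the cell is safe to re-run.

```python
import pandas as pd

def attach_specialty(df, catalog_df):
    """Left-join specialty and cluster onto every row via the normalized focus_area."""
    mapping = (catalog_df[["focus_area", "specialty", "cluster"]]
               .rename(columns={"focus_area": "focus_area_norm"}))
    base = df.drop(columns=["specialty", "cluster"], errors="ignore")  # re-run safe
    out = base.merge(mapping, on="focus_area_norm", how="left",
                     validate="many_to_one",  # raises if catalog keys are duplicated
                     indicator=True)          # adds a _merge column: 'both' / 'left_only'
    n_unmatched = int((out["_merge"] == "left_only").sum())
    assert len(out) == len(df), "a many-to-one left join must not change the row count"
    return out.drop(columns="_merge"), n_unmatched

# Usage: df, n_unmatched = attach_specialty(df, catalog_df)
```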

In [None]:
# TODO 10: Merge specialty labels back to full dataframe

# Select columns to merge
# mapping_df = catalog_df[['focus_area', 'specialty', 'cluster']]

# Merge
# df = df.merge(
#     mapping_df,
#     left_on='focus_area_norm',
#     right_on='focus_area',
#     how='left',
#     suffixes=('', '_mapped')
# )

# # Drop duplicate focus_area column if created
# if 'focus_area_mapped' in df.columns:
#     df = df.drop(columns=['focus_area_mapped'])

# # Check for missing specialties
# missing_specialty = df['specialty'].isna().sum()
# if missing_specialty > 0:
#     print(f"⚠️  Warning: {missing_specialty} rows missing specialty. Filling with 'Other'.")
#     df['specialty'] = df['specialty'].fillna('Other')

# print(f"\n✅ Full dataframe shape: {df.shape}")
# print(f"   Columns: {list(df.columns)}")

# # Verify specialty distribution
# print("\n📊 Final specialty distribution:")
# print(df['specialty'].value_counts())

# # Save enriched dataframe
# df.to_parquet(ARTIFACTS_DIR / 'df_with_specialty.parquet', index=False)
# print(f"\n✅ Saved enriched dataframe to {ARTIFACTS_DIR / 'df_with_specialty.parquet'}")

## ✅ Validate with Zero-Shot Classification (Optional but Recommended)

**TODO 11:** Use zero-shot classification as a second opinion.

**Why zero-shot validation?**
- Independent check on your manual labeling
- Flags potential misclassifications
- Identifies ambiguous cases

**Approach:**
1. Sample 200-500 focus_areas
2. Run zero-shot classifier with your specialty list as candidates
3. Compare `specialty` (from clustering) vs. zero-shot prediction
4. Inspect disagreements

**Hints:**
- Use an NLI-trained model such as `facebook/bart-large-mnli` (the `zero-shot-classification` pipeline's default) or `microsoft/deberta-large-mnli`
- Focus on high-count focus_areas (more important to get right)
- This is slow - run on a sample!
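
The validation cell reports one unweighted agreement rate. Because the sample is drawn by count, a count-weighted rate and a confusion table are often more informative: they show *which* specialties the two methods disagree on. A sketch assuming `sample_df` has the `count`, `specialty`, and `zero_shot_specialty` columns built in the cell; `agreement_report` is a hypothetical helper, and the test data is a toy example, not real results.

```python
import pandas as pd

def agreement_report(sample_df: pd.DataFrame):
    """Compare cluster-derived specialties with zero-shot predictions."""
    agree = sample_df["specialty"] == sample_df["zero_shot_specialty"]
    unweighted = agree.mean()
    # weight each focus_area by how many rows it covers
    weighted = (agree * sample_df["count"]).sum() / sample_df["count"].sum()
    confusion = pd.crosstab(sample_df["specialty"], sample_df["zero_shot_specialty"],
                            rownames=["cluster"], colnames=["zero_shot"])
    return float(unweighted), float(weighted), confusion

# Usage: unweighted, weighted, confusion = agreement_report(sample_df)
```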

In [None]:
# TODO 11: Zero-shot validation (optional)

# from transformers import pipeline

# # Initialize zero-shot classifier
# print("Loading zero-shot classifier (this may take a minute)...")
# zero_shot_clf = pipeline(
#     "zero-shot-classification",
#     model="facebook/bart-large-mnli"
# )

# # Get unique specialties ('Other' is a weak NLI hypothesis, so it is left out)
# candidate_labels = sorted(s for s in catalog_df['specialty'].unique() if s != 'Other')
# print(f"\nCandidate specialties: {candidate_labels}")

# # Sample focus_areas for validation (prioritize high-count ones)
# sample_df = catalog_df.nlargest(200, 'count').copy()

# # Run zero-shot classification
# print(f"\nRunning zero-shot on {len(sample_df)} focus_areas...")
# zero_shot_predictions = []
# for idx, row in sample_df.iterrows():
#     result = zero_shot_clf(
#         row['focus_area'],
#         candidate_labels=candidate_labels,
#         multi_label=False
#     )
#     zero_shot_predictions.append(result['labels'][0])  # Top prediction

# sample_df['zero_shot_specialty'] = zero_shot_predictions

# # Compare with cluster-based specialty
# sample_df['agreement'] = sample_df['specialty'] == sample_df['zero_shot_specialty']
# agreement_rate = sample_df['agreement'].mean()

# print(f"\n✅ Agreement rate: {agreement_rate:.1%}")

# # Show disagreements
# disagreements = sample_df[~sample_df['agreement']].sort_values('count', ascending=False)
# print(f"\n⚠️  Disagreements ({len(disagreements)}):")
# print(disagreements[['focus_area', 'count', 'specialty', 'zero_shot_specialty']].head(20))

# # Save validation results
# sample_df.to_csv(ARTIFACTS_DIR / 'zero_shot_validation.csv', index=False)

## 📊 Final Quality Checks & Balance Analysis

**TODO 12:** Assess the final specialty distribution and decide on any final adjustments.

**Quality Checks:**
1. **Minimum samples:** Every specialty should have ≥50 samples
2. **Class balance:** Largest/smallest ratio should be <50:1
3. **Number of classes:** Should be 15-25, ideally 15-20 (too few = loss of granularity, too many = data scarcity)

**If problems exist:**
- Merge rare specialties into "Other" or related specialties
- Split very large specialties if semantically distinct

**Hints:**
- Create a threshold function to group rare classes
- Document any merging decisions
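
A sketch of the threshold function mentioned in the hints: it folds specialties below `min_samples` rows into "Other", or into a related specialty you name, and returns a log of what it did for the metadata. `collapse_rare` and its `merge_into` parameter are hypothetical.

```python
import pandas as pd

def collapse_rare(specialty: pd.Series, min_samples=50, other="Other", merge_into=None):
    """Fold specialties with fewer than `min_samples` rows into a target class.

    merge_into: optional {rare_specialty: target} overrides (e.g. a related
    specialty); any other rare specialty goes to `other`.
    Returns the new Series and a list of human-readable merge decisions.
    """
    merge_into = merge_into or {}
    counts = specialty.value_counts()
    rare = [s for s in counts.index if counts[s] < min_samples and s != other]
    remap = {s: merge_into.get(s, other) for s in rare}
    decisions = [f"Merged '{s}' ({counts[s]} rows) into '{t}'" for s, t in remap.items()]
    return specialty.replace(remap), decisions

# Usage:
# df['specialty'], merge_log = collapse_rare(df['specialty'], MIN_SAMPLES)
# (append merge_log to the 'manual_decisions' list in the exported metadata)
```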

In [None]:
# TODO 12: Final quality checks and balance

# Compute final stats
# specialty_final = df['specialty'].value_counts().sort_values(ascending=False)
# n_specialties = len(specialty_final)
# min_samples = specialty_final.min()
# max_samples = specialty_final.max()
# imbalance_ratio = max_samples / min_samples

# print("📊 Final Taxonomy Stats:")
# print(f"   Number of specialties: {n_specialties}")
# print(f"   Total samples: {len(df)}")
# print(f"   Min samples per specialty: {min_samples}")
# print(f"   Max samples per specialty: {max_samples}")
# print(f"   Imbalance ratio: {imbalance_ratio:.1f}:1")

# print("\nPer-specialty counts:")
# print(specialty_final)

# # Check acceptance criteria
# print("\n✅ Acceptance Criteria Check:")
# print(f"   ✓ Min samples ≥50? {min_samples >= 50}")
# print(f"   ✓ Number of classes 15-25? {15 <= n_specialties <= 25}")
# print(f"   ✓ Imbalance ratio <50:1? {imbalance_ratio < 50}")

# # Optional: Apply minimum sample threshold
# MIN_SAMPLES = 50
# rare_specialties = specialty_final[specialty_final < MIN_SAMPLES].index.tolist()
# if rare_specialties:
#     print(f"\n⚠️  Rare specialties (<{MIN_SAMPLES} samples): {rare_specialties}")
#     print("   Consider merging these into 'Other' or related specialties.")
#     # df.loc[df['specialty'].isin(rare_specialties), 'specialty'] = 'Other'

# # Visualize final distribution
# plt.figure(figsize=(14, 6))
# specialty_final.plot(kind='barh')
# plt.xlabel('Number of Samples')
# plt.ylabel('Specialty')
# plt.title(f'Final Specialty Distribution (n={n_specialties} classes)')
# plt.axvline(x=100, color='g', linestyle='--', label='Target: 100+')
# plt.axvline(x=50, color='r', linestyle='--', label='Minimum: 50')
# plt.legend()
# plt.tight_layout()
# plt.savefig(ARTIFACTS_DIR / 'final_specialty_distribution.png', dpi=150)
# plt.show()

## 💾 Export Final Taxonomy & Metadata

**TODO 13:** Save all artifacts for use in downstream notebooks.

**What to export:**
1. `df_with_specialty.parquet` - Full dataframe with specialty column
2. `focus_area_to_specialty.json` - Direct mapping for quick lookup
3. `taxonomy_metadata.json` - Document all decisions and parameters
4. Summary README

**Hints:**
- Include clustering parameters in metadata
- Note any manual merging decisions
- Record date and model versions
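
Two export details worth guarding against: values pulled straight from NumPy or pandas (for example `cat_counts.iloc[0]`, or any `np.int64`) are not serializable by the standard `json` module, and "model versions" is easier to record when library versions are captured programmatically. A sketch using `json.dump`'s `default=` hook and `importlib.metadata`; `to_builtin` and `package_versions` are hypothetical helpers.

```python
import json
from importlib.metadata import version, PackageNotFoundError
import numpy as np

def to_builtin(obj):
    """json `default=` hook: convert NumPy scalars/arrays to plain Python types."""
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

def package_versions(names):
    """Installed version of each package, or None if it is not installed."""
    out = {}
    for name in names:
        try:
            out[name] = version(name)
        except PackageNotFoundError:
            out[name] = None
    return out

metadata = {
    "clustering_params": {"n_clusters": np.int64(40)},  # e.g. a value taken from NumPy
    "per_specialty_counts": {"Oncology": np.int64(1234)},
}
text = json.dumps(metadata, indent=2, default=to_builtin)

# Usage in the export cell:
# metadata['library_versions'] = package_versions(
#     ['sentence-transformers', 'scikit-learn', 'hdbscan', 'umap-learn', 'transformers'])
# json.dump(metadata, f, indent=2, default=to_builtin)
```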

In [None]:
# TODO 13: Export final taxonomy and metadata

# 1. Create focus_area → specialty mapping
# focus_area_to_specialty = catalog_df.set_index('focus_area')['specialty'].to_dict()
# with open(ARTIFACTS_DIR / 'focus_area_to_specialty.json', 'w') as f:
#     json.dump(focus_area_to_specialty, f, indent=2)

# 2. Create metadata
# from datetime import datetime
# metadata = {
#     'created_date': datetime.now().isoformat(),
#     'input_data': str(INPUT_PATH),
#     'n_original_focus_areas': len(catalog),
#     'n_final_specialties': len(df['specialty'].unique()),
#     'total_samples': len(df),
#     'embedding_model': 'pritamdeka/BioBERT-mnli-snli-scitail-mednli',
#     'clustering_method': 'kmeans',  # or 'hdbscan'
#     'clustering_params': {
#         'n_clusters': K,  # if kmeans (K from TODO 6B)
#         # 'min_cluster_size': 40,  # if hdbscan
#     },
#     'specialty_list': sorted(df['specialty'].unique()),
#     'per_specialty_counts': df['specialty'].value_counts().to_dict(),
#     'manual_decisions': [
#         'Merged all cancer types into Oncology',
#         'Grouped heart conditions into Cardiology',
#         # Add your decisions here
#     ]
# }
# with open(ARTIFACTS_DIR / 'taxonomy_metadata.json', 'w') as f:
#     json.dump(metadata, f, indent=2)

# 3. Create summary README
# readme_text = f"""
# # Medical Specialty Taxonomy

# **Created:** {datetime.now().strftime('%Y-%m-%d')}

# ## Summary

# - Original focus_areas: {len(catalog)}
# - Final specialties: {len(df['specialty'].unique())}
# - Total samples: {len(df)}

# ## Method

# 1. Embedded focus_area strings using BioBERT
# 2. Clustered with k-means (K={K})
# 3. Manually labeled clusters as medical specialties
# 4. Validated with zero-shot classification
# 5. Merged to full dataframe

# ## Files

# - `df_with_specialty.parquet` - Full dataframe with specialty column
# - `focus_area_to_specialty.json` - Direct mapping
# - `taxonomy_metadata.json` - Detailed metadata
# - `catalog_with_specialty.csv` - Catalog with cluster assignments

# ## Specialties

# {df['specialty'].value_counts().to_string()}
# """
# with open(ARTIFACTS_DIR / 'README.md', 'w') as f:
#     f.write(readme_text)

# print("✅ Exported all artifacts to:", ARTIFACTS_DIR)
# print("\nFiles created:")
# print("  - df_with_specialty.parquet")
# print("  - focus_area_to_specialty.json")
# print("  - taxonomy_metadata.json")
# print("  - README.md")

## 🤔 Reflection

Answer these questions after completing the taxonomy:

1. **How many final specialties did you create? Is this reasonable?**

2. **What were the hardest clustering decisions? Why?**

3. **How did zero-shot validation help (if you ran it)?**

4. **What would you do differently if starting over?**

5. **Are there any specialties you're uncertain about?**

6. **What did you learn about medical domain knowledge from this exercise?**

**Your reflections:**

*Write your answers here*

## 📌 Summary

✅ Built data-driven specialty taxonomy  
✅ Reduced 5,126 focus_areas → 15-20 specialties  
✅ Each specialty has sufficient samples  
✅ Validated with zero-shot classification  
✅ Exported artifacts for downstream use

**Next:** `01_project_scope_and_data.ipynb` (now with manageable specialty categories!)

---

## 🎓 What You Learned

**NLP Skills:**
- Semantic embeddings with BioBERT
- Text clustering (HDBSCAN, k-means)
- Zero-shot classification for validation

**Data Science Skills:**
- Bottom-up taxonomy construction
- Handling high-cardinality, sparse labels (5,126 classes, most with <10 samples!)
- Dimensionality reduction (UMAP)
- Cluster inspection and labeling

**Domain Skills:**
- Medical specialty categorization
- Semantic grouping of conditions
- Balancing granularity vs. data availability

**Project Management:**
- Artifact management
- Reproducible workflows
- Documentation and metadata

---

*This notebook demonstrates advanced NLP techniques for real-world data challenges. The approach is generalizable to any multi-class problem with too many fine-grained labels!*