In [None]:
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Embedding Analysis\n",
    "## Analysis of embeddings generated from the WHO dataset\n",
    "\n",
    "This notebook analyzes:\n",
    "- Embedding statistics (value distributions and vector norms)\n",
    "- Vector space visualization (PCA and t-SNE)\n",
    "- Clustering analysis (K-means)\n",
    "- Similarity patterns, including a sample query retrieval test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Standard library\n",
    "import json\n",
    "import sys\n",
    "import warnings\n",
    "from pathlib import Path\n",
    "\n",
    "# Make the parent directory importable so the project package `src` can be found\n",
    "sys.path.append('..')\n",
    "\n",
    "# Third-party\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from sklearn.cluster import KMeans\n",
    "from sklearn.decomposition import PCA\n",
    "from sklearn.manifold import TSNE\n",
    "from sklearn.metrics.pairwise import cosine_similarity\n",
    "\n",
    "# Silence every warning for the rest of the session.\n",
    "# NOTE(review): this also hides library deprecation warnings - consider narrowing it.\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "# Project-local modules (resolved via the sys.path entry above)\n",
    "from src.embedding.model_loader import ModelLoader\n",
    "from src.embedding.embedder import EmbeddingGenerator\n",
    "from src.config import settings\n",
    "\n",
    "# Global plot styling\n",
    "sns.set_style('whitegrid')\n",
    "plt.rcParams['figure.figsize'] = (12, 8)\n",
    "\n",
    "print(\"✓ Imports successful\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Load Chunks and Embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load chunks. JSON text is UTF-8 by spec; passing the encoding explicitly keeps\n",
    "# non-ASCII text decoding identically on every platform (open() otherwise uses the\n",
    "# locale's encoding, e.g. cp1252 on many Windows setups).\n",
    "chunks_path = Path('../data/processed/chunks.json')\n",
    "if chunks_path.exists():\n",
    "    with open(chunks_path, 'r', encoding='utf-8') as f:\n",
    "        chunks = json.load(f)\n",
    "    print(f\"✓ Loaded {len(chunks)} chunks\")\n",
    "else:\n",
    "    print(\"❌ Chunks not found! Run: python scripts/build_index.py\")\n",
    "    chunks = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the precomputed embedding matrix; `embeddings` is None when the file is missing\n",
    "embeddings_path = Path('../data/processed/embeddings.npy')\n",
    "if not embeddings_path.exists():\n",
    "    print(\"❌ Embeddings not found! Run: python scripts/build_index.py\")\n",
    "    embeddings = None\n",
    "else:\n",
    "    embeddings = np.load(embeddings_path)\n",
    "    print(f\"✓ Loaded embeddings with shape: {embeddings.shape}\")\n",
    "    # shape[1] is reported as the vector dimension, shape[0] as the vector count\n",
    "    print(f\"  - Dimension: {embeddings.shape[1]}\")\n",
    "    print(f\"  - Number of vectors: {embeddings.shape[0]}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Embedding Statistics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if embeddings is not None:\n",
    "    # Global summary statistics over every value in the matrix\n",
    "    print(\"Embedding Statistics:\")\n",
    "    for label, reducer in ((\"Mean: \", embeddings.mean),\n",
    "                           (\"Std:  \", embeddings.std),\n",
    "                           (\"Min:  \", embeddings.min),\n",
    "                           (\"Max:  \", embeddings.max)):\n",
    "        print(f\"{label}{reducer():.6f}\")\n",
    "    \n",
    "    # L2 norm of each row: unit-normalized embeddings have norms close to 1\n",
    "    row_norms = np.linalg.norm(embeddings, axis=1)\n",
    "    print(f\"\\nVector norms (should be ~1 if normalized):\")\n",
    "    print(f\"Mean norm: {row_norms.mean():.6f}\")\n",
    "    print(f\"Std norm:  {row_norms.std():.6f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Distribution of embedding values\n",
    "if embeddings is not None:\n",
    "    # Seeded generator so the same dimensions are sampled on every run;\n",
    "    # never request more dimensions than the embeddings actually have.\n",
    "    dim_rng = np.random.default_rng(42)\n",
    "    n_dims_shown = min(5, embeddings.shape[1])\n",
    "    sample_dims = dim_rng.choice(embeddings.shape[1], n_dims_shown, replace=False)\n",
    "    \n",
    "    fig, axes = plt.subplots(1, n_dims_shown, figsize=(14, 5), squeeze=False)\n",
    "    for i, (ax, dim) in enumerate(zip(axes[0], sample_dims)):\n",
    "        ax.hist(embeddings[:, dim], bins=50, alpha=0.7, color=f'C{i}')\n",
    "        ax.set_title(f'Dim {dim}')\n",
    "        ax.set_xlabel('Value')\n",
    "    axes[0, 0].set_ylabel('Frequency')\n",
    "    \n",
    "    fig.suptitle('Distribution of Sample Embedding Dimensions', fontsize=14, fontweight='bold', y=1.02)\n",
    "    fig.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Dimensionality Reduction - PCA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if embeddings is not None:\n",
    "    print(\"Running PCA...\")\n",
    "    # Fit all components (min(n_samples, n_features)) instead of a fixed 50: a fixed\n",
    "    # 50 fails on small inputs and may never reach the 95% threshold used below.\n",
    "    pca = PCA()\n",
    "    pca.fit(embeddings)\n",
    "    \n",
    "    # Explained variance (only the first `n_plot` components are drawn)\n",
    "    cumsum_variance = np.cumsum(pca.explained_variance_ratio_)\n",
    "    n_plot = min(50, len(cumsum_variance))\n",
    "    component_ids = np.arange(1, n_plot + 1)\n",
    "    \n",
    "    plt.figure(figsize=(12, 5))\n",
    "    \n",
    "    # Individual variance\n",
    "    plt.subplot(1, 2, 1)\n",
    "    plt.bar(component_ids, pca.explained_variance_ratio_[:n_plot], alpha=0.7, color='skyblue')\n",
    "    plt.xlabel('Principal Component')\n",
    "    plt.ylabel('Explained Variance Ratio')\n",
    "    plt.title('PCA - Individual Variance Explained')\n",
    "    plt.grid(axis='y', alpha=0.3)\n",
    "    \n",
    "    # Cumulative variance\n",
    "    plt.subplot(1, 2, 2)\n",
    "    plt.plot(component_ids, cumsum_variance[:n_plot], marker='o', linewidth=2, markersize=4)\n",
    "    plt.axhline(y=0.95, color='r', linestyle='--', label='95% variance')\n",
    "    plt.xlabel('Number of Components')\n",
    "    plt.ylabel('Cumulative Explained Variance')\n",
    "    plt.title('PCA - Cumulative Variance Explained')\n",
    "    plt.legend()\n",
    "    plt.grid(alpha=0.3)\n",
    "    \n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "    \n",
    "    # Find components needed for 95% variance. np.argmax returns 0 when nothing\n",
    "    # is True, so only trust it when the threshold is actually reached.\n",
    "    reaches_95 = cumsum_variance >= 0.95\n",
    "    n_components_95 = int(np.argmax(reaches_95)) + 1 if reaches_95.any() else len(cumsum_variance)\n",
    "    print(f\"\\nComponents needed for 95% variance: {n_components_95}\")\n",
    "    if len(cumsum_variance) >= 2:\n",
    "        print(f\"Variance explained by first 2 components: {cumsum_variance[1]:.2%}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. 2D Visualization with PCA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if embeddings is not None and chunks:\n",
    "    # Row i of `embeddings` must describe chunks[i]; fail loudly if they drifted apart\n",
    "    assert len(chunks) == len(embeddings), (\n",
    "        f\"{len(chunks)} chunks vs {len(embeddings)} embeddings - rebuild the index\"\n",
    "    )\n",
    "    \n",
    "    # PCA to 2D (random_state pins sklearn's randomized SVD solver used on large inputs)\n",
    "    pca_2d = PCA(n_components=2, random_state=42)\n",
    "    embeddings_2d = pca_2d.fit_transform(embeddings)\n",
    "    \n",
    "    # Extract disease names for coloring. dict.fromkeys keeps first-appearance order,\n",
    "    # unlike set(), whose ordering of strings changes between Python sessions.\n",
    "    disease_names = [chunk['disease_name'] for chunk in chunks]\n",
    "    unique_diseases = list(dict.fromkeys(disease_names))\n",
    "    disease_to_idx = {d: i for i, d in enumerate(unique_diseases)}\n",
    "    colors = [disease_to_idx[d] for d in disease_names]\n",
    "    \n",
    "    plt.figure(figsize=(14, 10))\n",
    "    plt.scatter(\n",
    "        embeddings_2d[:, 0], \n",
    "        embeddings_2d[:, 1],\n",
    "        c=colors,\n",
    "        cmap='tab20',\n",
    "        alpha=0.6,\n",
    "        s=30\n",
    "    )\n",
    "    \n",
    "    plt.xlabel(f'PC1 ({pca_2d.explained_variance_ratio_[0]:.2%} variance)', fontsize=12)\n",
    "    plt.ylabel(f'PC2 ({pca_2d.explained_variance_ratio_[1]:.2%} variance)', fontsize=12)\n",
    "    plt.title('2D PCA Visualization of Embeddings', fontsize=14, fontweight='bold')\n",
    "    plt.grid(alpha=0.3)\n",
    "    \n",
    "    # Label the first point of up to 10 diseases (only when there are few enough to read)\n",
    "    if len(unique_diseases) <= 20:\n",
    "        for disease in unique_diseases[:10]:\n",
    "            idx = disease_names.index(disease)\n",
    "            plt.annotate(\n",
    "                disease,\n",
    "                (embeddings_2d[idx, 0], embeddings_2d[idx, 1]),\n",
    "                fontsize=8,\n",
    "                alpha=0.7\n",
    "            )\n",
    "    \n",
    "    plt.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. t-SNE Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if embeddings is not None and len(embeddings) < 5000:\n",
    "    print(\"Running t-SNE (this may take a few minutes)...\")\n",
    "    \n",
    "    # Use PCA first for speed. Cap the component count for small inputs, and pin\n",
    "    # random_state because sklearn may pick its randomized SVD solver here.\n",
    "    n_pca_components = min(50, *embeddings.shape)\n",
    "    pca_50 = PCA(n_components=n_pca_components, random_state=42)\n",
    "    embeddings_pca_50 = pca_50.fit_transform(embeddings)\n",
    "    \n",
    "    # perplexity must be below the sample count. The iteration count stays at the\n",
    "    # sklearn default (1000): `n_iter` was deprecated in scikit-learn 1.5 (renamed\n",
    "    # `max_iter`) and removed in 1.7, so passing it breaks on current releases.\n",
    "    tsne = TSNE(n_components=2, random_state=42, perplexity=min(30, len(embeddings) - 1))\n",
    "    embeddings_tsne = tsne.fit_transform(embeddings_pca_50)\n",
    "    \n",
    "    # Reuse the per-disease color indices from section 4; plain points without chunks\n",
    "    point_colors = colors if chunks else None\n",
    "    \n",
    "    plt.figure(figsize=(14, 10))\n",
    "    plt.scatter(\n",
    "        embeddings_tsne[:, 0],\n",
    "        embeddings_tsne[:, 1],\n",
    "        c=point_colors,\n",
    "        cmap='tab20',\n",
    "        alpha=0.6,\n",
    "        s=30\n",
    "    )\n",
    "    \n",
    "    plt.xlabel('t-SNE Dimension 1', fontsize=12)\n",
    "    plt.ylabel('t-SNE Dimension 2', fontsize=12)\n",
    "    plt.title('t-SNE Visualization of Embeddings', fontsize=14, fontweight='bold')\n",
    "    plt.grid(alpha=0.3)\n",
    "    \n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "    \n",
    "    print(\"✓ t-SNE visualization complete\")\n",
    "elif embeddings is not None:\n",
    "    print(f\"Skipping t-SNE: too many samples ({len(embeddings)}). Use a subset for t-SNE.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Clustering Analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if embeddings is not None:\n",
    "    # Determine optimal number of clusters using elbow method\n",
    "    # (K up to 20, but at most one cluster per 10 embeddings)\n",
    "    max_clusters = min(20, len(embeddings) // 10)\n",
    "    if max_clusters < 2:\n",
    "        # Otherwise K_range would be empty and an empty plot would be drawn\n",
    "        print(f\"Skipping elbow method: need at least 20 embeddings, got {len(embeddings)}\")\n",
    "    else:\n",
    "        K_range = range(2, max_clusters + 1)\n",
    "        inertias = []\n",
    "        \n",
    "        print(f\"Testing K-means with K from 2 to {max_clusters}...\")\n",
    "        for k in K_range:\n",
    "            kmeans = KMeans(n_clusters=k, random_state=42, n_init=10)\n",
    "            kmeans.fit(embeddings)\n",
    "            inertias.append(kmeans.inertia_)\n",
    "        \n",
    "        plt.figure(figsize=(10, 6))\n",
    "        plt.plot(K_range, inertias, marker='o', linewidth=2, markersize=8)\n",
    "        plt.xlabel('Number of Clusters (K)', fontsize=12)\n",
    "        plt.ylabel('Inertia', fontsize=12)\n",
    "        plt.title('K-Means Elbow Method', fontsize=14, fontweight='bold')\n",
    "        plt.grid(alpha=0.3)\n",
    "        plt.tight_layout()\n",
    "        plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Apply K-means with chosen K\n",
    "if embeddings is not None:\n",
    "    # Row i of `embeddings` must describe chunks[i] for the cluster table to line up\n",
    "    assert len(chunks) == len(embeddings), (\n",
    "        f\"{len(chunks)} chunks vs {len(embeddings)} embeddings - rebuild the index\"\n",
    "    )\n",
    "    \n",
    "    # K=8 is a fixed choice (presumably read off the elbow plot - TODO confirm);\n",
    "    # capped so KMeans never asks for more clusters than there are samples.\n",
    "    n_clusters = min(8, len(embeddings))\n",
    "    print(f\"Applying K-means with {n_clusters} clusters...\")\n",
    "    \n",
    "    kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)\n",
    "    cluster_labels = kmeans.fit_predict(embeddings)\n",
    "    \n",
    "    # Attach labels on a new DataFrame instead of writing a 'cluster' key into every\n",
    "    # loaded chunk dict, so `chunks` stays exactly as read from disk.\n",
    "    cluster_df = pd.DataFrame(chunks).assign(cluster=cluster_labels.astype(int))\n",
    "    \n",
    "    print(\"\\nCluster Distribution:\")\n",
    "    cluster_counts = cluster_df['cluster'].value_counts().sort_index()\n",
    "    print(cluster_counts)\n",
    "    \n",
    "    # Visualize cluster sizes\n",
    "    plt.figure(figsize=(10, 6))\n",
    "    colors_bar = plt.cm.Set3(range(n_clusters))\n",
    "    plt.bar(cluster_counts.index, cluster_counts.values, color=colors_bar, edgecolor='black')\n",
    "    plt.xlabel('Cluster ID', fontsize=12)\n",
    "    plt.ylabel('Number of Chunks', fontsize=12)\n",
    "    plt.title('Cluster Size Distribution', fontsize=14, fontweight='bold')\n",
    "    plt.xticks(cluster_counts.index)\n",
    "    plt.grid(axis='y', alpha=0.3)\n",
    "    plt.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show sample diseases from each cluster\n",
    "if embeddings is not None:\n",
    "    print(\"\\nSample Diseases from Each Cluster:\")\n",
    "    print(\"=\"*80)\n",
    "    \n",
    "    # Cell-specific names, so section 4's `unique_diseases` global is not clobbered\n",
    "    for cluster_id in range(n_clusters):\n",
    "        cluster_members = cluster_df[cluster_df['cluster'] == cluster_id]\n",
    "        example_diseases = cluster_members['disease_name'].unique()[:5]\n",
    "        \n",
    "        print(f\"\\nCluster {cluster_id} ({len(cluster_members)} chunks):\")\n",
    "        for disease in example_diseases:\n",
    "            print(f\"  - {disease}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 7. Similarity Analysis"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if embeddings is not None:\n",
    "    print(\"Computing cosine similarity matrix...\")\n",
    "    \n",
    "    # Compute similarity for a seeded random subset. The 100-row cap keeps this cheap\n",
    "    # for any corpus size, so no overall corpus-size limit is needed.\n",
    "    sample_size = min(100, len(embeddings))\n",
    "    sim_rng = np.random.default_rng(42)\n",
    "    sample_indices = sim_rng.choice(len(embeddings), sample_size, replace=False)\n",
    "    sample_embeddings = embeddings[sample_indices]\n",
    "    \n",
    "    similarity_matrix = cosine_similarity(sample_embeddings)\n",
    "    \n",
    "    plt.figure(figsize=(12, 10))\n",
    "    sns.heatmap(\n",
    "        similarity_matrix,\n",
    "        cmap='coolwarm',\n",
    "        center=0,\n",
    "        square=True,\n",
    "        linewidths=0,\n",
    "        cbar_kws={\"shrink\": 0.8}\n",
    "    )\n",
    "    plt.title(f'Cosine Similarity Matrix (Sample of {sample_size} chunks)', fontsize=14, fontweight='bold')\n",
    "    plt.xlabel('Chunk Index')\n",
    "    plt.ylabel('Chunk Index')\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "    \n",
    "    # Statistics, excluding the diagonal (each chunk's similarity to itself)\n",
    "    off_diagonal_sims = similarity_matrix[~np.eye(sample_size, dtype=bool)]\n",
    "    \n",
    "    # A single sampled row has no off-diagonal entries (min() would raise)\n",
    "    if off_diagonal_sims.size:\n",
    "        print(f\"\\nSimilarity Statistics (excluding self-similarity):\")\n",
    "        print(f\"Mean similarity: {off_diagonal_sims.mean():.4f}\")\n",
    "        print(f\"Std similarity:  {off_diagonal_sims.std():.4f}\")\n",
    "        print(f\"Min similarity:  {off_diagonal_sims.min():.4f}\")\n",
    "        print(f\"Max similarity:  {off_diagonal_sims.max():.4f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 8. Query Similarity Test"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the tokenizer/model pair once so ad-hoc queries can be embedded below\n",
    "print(\"Loading embedding model for query testing...\")\n",
    "\n",
    "device = settings.DEVICE\n",
    "model_loader = ModelLoader(settings.EMBEDDING_MODEL, device)\n",
    "tokenizer, model = model_loader.load()\n",
    "embedder = EmbeddingGenerator(tokenizer, model, device)\n",
    "\n",
    "print(\"✓ Model loaded\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test query\n",
    "test_query = \"What are the symptoms of malaria?\"\n",
    "top_k = 10\n",
    "\n",
    "if embeddings is None or not chunks:\n",
    "    print(\"❌ Chunks and embeddings are required for the query test. Run: python scripts/build_index.py\")\n",
    "else:\n",
    "    print(f\"Test Query: '{test_query}'\\n\")\n",
    "    \n",
    "    # Generate query embedding\n",
    "    query_embedding = embedder.embed_query(test_query)\n",
    "    print(f\"Query embedding shape: {query_embedding.shape}\")\n",
    "    \n",
    "    # Compute similarities. Reshape to a single row so both 1-D (dim,) and\n",
    "    # 2-D (1, dim) query embeddings are accepted by cosine_similarity.\n",
    "    query_matrix = np.asarray(query_embedding).reshape(1, -1)\n",
    "    similarities = cosine_similarity(query_matrix, embeddings)[0]\n",
    "    \n",
    "    # Get top-k matches, highest similarity first\n",
    "    top_indices = np.argsort(similarities)[::-1][:top_k]\n",
    "    \n",
    "    print(f\"\\nTop {top_k} Most Similar Chunks:\")\n",
    "    print(\"=\"*80)\n",
    "    \n",
    "    for rank, idx in enumerate(top_indices, 1):\n",
    "        chunk = chunks[idx]\n",
    "        sim_score = similarities[idx]\n",
    "        \n",
    "        print(f\"\\n{rank}. Disease: {chunk['disease_name']}\")\n",
    "        print(f\"   Field: {chunk['field']}\")\n",
    "        print(f\"   Similarity: {sim_score:.4f}\")\n",
    "        print(f\"   Text: {chunk['text'][:150]}...\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 9. Summary"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if embeddings is not None:\n",
    "    # Assemble the report from results of the PCA (section 3) and clustering\n",
    "    # (section 6) cells, then emit it one line per entry.\n",
    "    rule = \"=\" * 80\n",
    "    report_lines = [\n",
    "        \"\\n\" + rule,\n",
    "        \"EMBEDDING ANALYSIS SUMMARY\",\n",
    "        rule,\n",
    "        f\"Total Embeddings: {len(embeddings)}\",\n",
    "        f\"Embedding Dimension: {embeddings.shape[1]}\",\n",
    "        \"\\nPCA Analysis:\",\n",
    "        f\"  - Components for 95% variance: {n_components_95}\",\n",
    "        f\"  - First 2 components variance: {cumsum_variance[1]:.2%}\",\n",
    "        \"\\nClustering:\",\n",
    "        f\"  - Number of clusters: {n_clusters}\",\n",
    "        f\"  - Largest cluster size: {cluster_counts.max()}\",\n",
    "        f\"  - Smallest cluster size: {cluster_counts.min()}\",\n",
    "        \"\\n\" + rule,\n",
    "    ]\n",
    "    print(*report_lines, sep=\"\\n\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}