In [None]:
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# FactCheck-MM Advanced Visualizations\n",
    "\n",
    "## Overview\n",
    "This notebook provides advanced visualization capabilities for FactCheck-MM models including:\n",
    "- Cross-modal attention heatmaps\n",
    "- Grad-CAM visualizations for vision encoders\n",
    "- Audio attention maps for speech analysis\n",
    "- Interactive embeddings visualization\n",
    "- Publication-ready plots and figures\n",
    "\n",
    "## Method\n",
    "We implement specialized visualization techniques:\n",
    "- **Attention Visualization**: Cross-modal attention patterns\n",
    "- **Saliency Maps**: Grad-CAM for visual interpretability\n",
    "- **Embedding Projections**: t-SNE/UMAP for high-dimensional data\n",
    "- **Interactive Plots**: Plotly-based interactive visualizations\n",
    "\n",
    "All visualizations are designed for both analysis and publication quality."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Setup and imports\n",
    "import sys\n",
    "import os\n",
    "from pathlib import Path\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import plotly.express as px\n",
    "import plotly.graph_objects as go\n",
    "from plotly.subplots import make_subplots\n",
    "import plotly.figure_factory as ff\n",
    "import json\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from sklearn.manifold import TSNE\n",
    "from sklearn.decomposition import PCA\n",
    "import warnings\n",
    "\n",
    "# NOTE: this silences every warning for the whole kernel session, including\n",
    "# deprecation notices from the plotting libraries used below.\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "# Resolve the project root: the parent directory when the kernel was started\n",
    "# inside notebooks/, otherwise the current working directory itself.\n",
    "current_dir = Path.cwd()\n",
    "project_root = current_dir.parent if current_dir.name == 'notebooks' else current_dir\n",
    "\n",
    "# Guarded so that re-running this cell does not stack duplicate sys.path entries.\n",
    "if str(project_root) not in sys.path:\n",
    "    sys.path.insert(0, str(project_root))\n",
    "\n",
    "# Import project utilities\n",
    "from shared.utils.visualization import create_attention_heatmap, create_embedding_plot\n",
    "from shared.utils.interpretability import GradCAMVisualizer, AttentionVisualizer\n",
    "\n",
    "# Optional imports for advanced visualizations\n",
    "try:\n",
    "    import umap\n",
    "    UMAP_AVAILABLE = True\n",
    "except ImportError:\n",
    "    UMAP_AVAILABLE = False\n",
    "    print(\"UMAP not available. Install with: pip install umap-learn\")\n",
    "\n",
    "try:\n",
    "    from captum.attr import GradientShap, IntegratedGradients\n",
    "    CAPTUM_AVAILABLE = True\n",
    "except ImportError:\n",
    "    CAPTUM_AVAILABLE = False\n",
    "    print(\"Captum not available. Install with: pip install captum\")\n",
    "\n",
    "# Set style. Matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8*';\n",
    "# fall back to the earlier name so the notebook also runs on older releases.\n",
    "plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')\n",
    "sns.set_palette(\"husl\")\n",
    "\n",
    "# Create output directories (idempotent: existing directories are left alone)\n",
    "output_dir = project_root / 'outputs' / 'notebooks'\n",
    "docs_dir = project_root / 'docs' / 'figures'\n",
    "for directory in (output_dir, docs_dir):\n",
    "    directory.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "print(f\"Project root: {project_root}\")\n",
    "print(f\"Output directory: {output_dir}\")\n",
    "print(f\"Docs directory: {docs_dir}\")\n",
    "print(f\"UMAP available: {UMAP_AVAILABLE}\")\n",
    "print(f\"Captum available: {CAPTUM_AVAILABLE}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cross-Modal Attention Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create cross-modal attention heatmaps\n",
    "def _annotation_color(image, value):\n",
    "    \"\"\"\n",
    "    Pick the text color that contrasts with the heatmap cell drawn for `value`.\n",
    "\n",
    "    Args:\n",
    "        image: AxesImage returned by ``imshow``; its norm and colormap give the cell color.\n",
    "        value: Data value of the cell being annotated.\n",
    "\n",
    "    Returns:\n",
    "        str: 'white' for dark cells, 'black' for light cells.\n",
    "    \"\"\"\n",
    "    red, green, blue, _ = image.cmap(float(image.norm(value)))\n",
    "    # Rec. 709 luma weights give a brightness estimate for the cell color.\n",
    "    luminance = 0.2126 * red + 0.7152 * green + 0.0722 * blue\n",
    "    return 'white' if luminance < 0.5 else 'black'\n",
    "\n",
    "\n",
    "def _draw_attention_heatmap(ax, weights, row_labels, col_labels, cmap, title,\n",
    "                            xlabel, ylabel, min_weight, fontsize):\n",
    "    \"\"\"\n",
    "    Draw one annotated attention heatmap, with its colorbar, on `ax`.\n",
    "\n",
    "    Args:\n",
    "        ax: Target matplotlib Axes.\n",
    "        weights: 2D array of attention weights (rows x columns).\n",
    "        row_labels: Tick labels for the y axis, one per row of `weights`.\n",
    "        col_labels: Tick labels for the x axis, one per column of `weights`.\n",
    "        cmap: Colormap name passed to ``imshow``.\n",
    "        title: Panel title.\n",
    "        xlabel: x-axis label.\n",
    "        ylabel: y-axis label.\n",
    "        min_weight: Only cells strictly above this value get a numeric annotation.\n",
    "        fontsize: Font size of the numeric annotations.\n",
    "\n",
    "    Returns:\n",
    "        The AxesImage created by ``imshow``.\n",
    "    \"\"\"\n",
    "    image = ax.imshow(weights, cmap=cmap, aspect='auto')\n",
    "    ax.set_title(title, fontsize=14, fontweight='bold')\n",
    "    ax.set_xlabel(xlabel)\n",
    "    ax.set_ylabel(ylabel)\n",
    "    ax.set_xticks(range(len(col_labels)))\n",
    "    ax.set_yticks(range(len(row_labels)))\n",
    "    ax.set_xticklabels(col_labels, rotation=45, ha='right')\n",
    "    ax.set_yticklabels(row_labels)\n",
    "\n",
    "    # Annotate only the stronger cells; the text color adapts to the cell so the\n",
    "    # numbers stay legible on both the light and the dark end of the colormap.\n",
    "    for (row, col), weight in np.ndenumerate(weights):\n",
    "        if weight > min_weight:\n",
    "            ax.text(col, row, f'{weight:.2f}', ha='center', va='center',\n",
    "                    color=_annotation_color(image, weight), fontsize=fontsize)\n",
    "\n",
    "    plt.colorbar(image, ax=ax, label='Attention Weight')\n",
    "    return image\n",
    "\n",
    "\n",
    "def create_cross_modal_attention_viz():\n",
    "    \"\"\"\n",
    "    Create comprehensive cross-modal attention visualizations.\n",
    "\n",
    "    Builds a 2x2 figure from mock (seeded random) attention data: text-audio,\n",
    "    text-image and audio-image heatmaps plus a 3x3 modality summary.\n",
    "\n",
    "    Returns:\n",
    "        tuple: ``(fig, attention)`` where ``attention`` maps 'text_audio',\n",
    "        'text_image' and 'audio_image' to the row-normalized weight matrices.\n",
    "    \"\"\"\n",
    "    # Mock attention data for demonstration. Seeding the global NumPy RNG keeps the\n",
    "    # figure reproducible, but it also resets the random state seen by later cells.\n",
    "    np.random.seed(42)\n",
    "\n",
    "    text_tokens = [\"Oh\", \"great\", \",\", \"another\", \"Monday\", \"morning\", \"meeting\", \"[CLS]\", \"[SEP]\"]\n",
    "    audio_frames = [f\"Audio_{i}\" for i in range(8)]\n",
    "    image_patches = [f\"Patch_{i}\" for i in range(12)]\n",
    "\n",
    "    def mock_attention(n_queries, n_keys):\n",
    "        # Random scores pushed through a softmax so that every row sums to 1.\n",
    "        scores = np.random.rand(n_queries, n_keys)\n",
    "        return F.softmax(torch.tensor(scores), dim=1).numpy()\n",
    "\n",
    "    # All random draws happen here, in this order; the plotting below consumes none.\n",
    "    text_audio_attention = mock_attention(len(text_tokens), len(audio_frames))\n",
    "    text_image_attention = mock_attention(len(text_tokens), len(image_patches))\n",
    "    audio_image_attention = mock_attention(len(audio_frames), len(image_patches))\n",
    "\n",
    "    fig, axes = plt.subplots(2, 2, figsize=(16, 12))\n",
    "\n",
    "    # Plots 1-3: pairwise attention heatmaps\n",
    "    _draw_attention_heatmap(\n",
    "        axes[0, 0], text_audio_attention, text_tokens, audio_frames,\n",
    "        cmap='Blues', title='Text-Audio Cross-Modal Attention',\n",
    "        xlabel='Audio Frames', ylabel='Text Tokens', min_weight=0.1, fontsize=8)\n",
    "    _draw_attention_heatmap(\n",
    "        axes[0, 1], text_image_attention, text_tokens, image_patches,\n",
    "        cmap='Reds', title='Text-Image Cross-Modal Attention',\n",
    "        xlabel='Image Patches', ylabel='Text Tokens', min_weight=0.08, fontsize=7)\n",
    "    _draw_attention_heatmap(\n",
    "        axes[1, 0], audio_image_attention, audio_frames, image_patches,\n",
    "        cmap='Greens', title='Audio-Image Cross-Modal Attention',\n",
    "        xlabel='Image Patches', ylabel='Audio Frames', min_weight=0.1, fontsize=8)\n",
    "\n",
    "    # Plot 4: modality-level summary\n",
    "    # NOTE: each matrix is softmax-normalized per row, so its overall mean is exactly\n",
    "    # 1 / n_columns. This panel therefore reflects the number of audio frames and\n",
    "    # image patches rather than the attention pattern itself.\n",
    "    ax4 = axes[1, 1]\n",
    "    modalities = ['Text', 'Audio', 'Image']\n",
    "    text_audio_mean = np.mean(text_audio_attention)\n",
    "    text_image_mean = np.mean(text_image_attention)\n",
    "    audio_image_mean = np.mean(audio_image_attention)\n",
    "    attention_strengths = np.array([\n",
    "        [0, text_audio_mean, text_image_mean],\n",
    "        [text_audio_mean, 0, audio_image_mean],\n",
    "        [text_image_mean, audio_image_mean, 0]\n",
    "    ])\n",
    "\n",
    "    im4 = ax4.imshow(attention_strengths, cmap='viridis', aspect='auto')\n",
    "    ax4.set_title('Cross-Modal Attention Summary', fontsize=14, fontweight='bold')\n",
    "    ax4.set_xticks(range(len(modalities)))\n",
    "    ax4.set_yticks(range(len(modalities)))\n",
    "    ax4.set_xticklabels(modalities)\n",
    "    ax4.set_yticklabels(modalities)\n",
    "\n",
    "    # Label the off-diagonal cells only; the diagonal is a placeholder zero.\n",
    "    for (row, col), strength in np.ndenumerate(attention_strengths):\n",
    "        if row != col:\n",
    "            ax4.text(col, row, f'{strength:.3f}', ha='center', va='center',\n",
    "                     color=_annotation_color(im4, strength),\n",
    "                     fontsize=12, fontweight='bold')\n",
    "\n",
    "    plt.colorbar(im4, ax=ax4, label='Average Attention')\n",
    "\n",
    "    plt.tight_layout()\n",
    "    return fig, {\n",
    "        'text_audio': text_audio_attention,\n",
    "        'text_image': text_image_attention,\n",
    "        'audio_image': audio_image_attention\n",
    "    }\n",
    "\n",
    "# Create attention visualizations\n",
    "fig, attention_data = create_cross_modal_attention_viz()\n",
    "# Save through the figure handle rather than pyplot's implicit current figure.\n",
    "fig.savefig(output_dir / 'cross_modal_attention_heatmaps.png', dpi=300, bbox_inches='tight')\n",
    "fig.savefig(docs_dir / 'cross_modal_attention_heatmaps.png', dpi=300, bbox_inches='tight')\n",
    "plt.show()\n",
    "\n",
    "print(f\"Cross-modal attention heatmaps saved to:\")\n",
    "print(f\"  Analysis: {output_dir / 'cross_modal_attention_heatmaps.png'}\")\n",
    "print(f\"  Documentation: {docs_dir / 'cross_modal_attention_heatmaps.png'}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Grad-CAM Visualization for Vision Encoders"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create Grad-CAM visualizations for image understanding\n",
    "def create_gradcam_visualization():\n",
    "    \"\"\"\n",
    "    Create Grad-CAM visualizations for vision encoder interpretability.\n",
    "\n",
    "    Draws a 2x4 figure from mock data: the top row shows a random grayscale\n",
    "    image with a marker on the region of interest, the bottom row overlays a\n",
    "    synthetic heatmap (a blurred disc, rescaled to peak at 1) on that image.\n",
    "\n",
    "    Returns:\n",
    "        The assembled matplotlib figure.\n",
    "\n",
    "    Raises:\n",
    "        ImportError: If SciPy is not installed.\n",
    "    \"\"\"\n",
    "    # Imported once here instead of on every loop iteration; kept local because\n",
    "    # SciPy is treated as an optional dependency of this notebook.\n",
    "    from scipy import ndimage\n",
    "\n",
    "    # Seeding the global NumPy RNG keeps the mock images reproducible.\n",
    "    np.random.seed(42)\n",
    "\n",
    "    image_size = 224\n",
    "\n",
    "    # One row per scenario:\n",
    "    # (title, (center_y, center_x) of the heatmap, marker class, marker args,\n",
    "    #  marker label, (x, y) of the label, marker color)\n",
    "    scenarios = [\n",
    "        (\"Facial Expression (Sarcastic)\", (80, 112), plt.Circle, ((112, 80), 30),\n",
    "         'Face', (112, 50), 'yellow'),\n",
    "        (\"Gesture Analysis\", (150, 80), plt.Rectangle, ((65, 130), 30, 40),\n",
    "         'Hand', (80, 190), 'cyan'),\n",
    "        (\"Context Objects\", (112, 180), plt.Rectangle, ((160, 90), 40, 44),\n",
    "         'Object', (180, 150), 'lime'),\n",
    "        (\"Scene Understanding\", (180, 112), plt.Rectangle, ((80, 160), 64, 40),\n",
    "         'Context', (112, 210), 'orange'),\n",
    "    ]\n",
    "\n",
    "    fig, axes = plt.subplots(2, len(scenarios), figsize=(20, 10))\n",
    "\n",
    "    # Pixel coordinate grids, shared by every scenario.\n",
    "    rows, cols = np.ogrid[:image_size, :image_size]\n",
    "\n",
    "    for idx, (scenario, focus, patch_cls, patch_args, label, label_xy, color) in enumerate(scenarios):\n",
    "        center_y, center_x = focus\n",
    "\n",
    "        # Create mock original image (grayscale for simplicity)\n",
    "        original_image = np.random.rand(image_size, image_size)\n",
    "\n",
    "        # Mock heatmap: a binary disc of radius 40 px around the focus point,\n",
    "        # blurred with a Gaussian and rescaled so that its peak equals 1.\n",
    "        disc = ((cols - center_x) ** 2 + (rows - center_y) ** 2) < 40 ** 2\n",
    "        gradcam_heatmap = ndimage.gaussian_filter(disc.astype(float), sigma=15)\n",
    "        gradcam_heatmap = gradcam_heatmap / gradcam_heatmap.max()\n",
    "\n",
    "        # Plot original image with its region-of-interest marker\n",
    "        ax_orig = axes[0, idx]\n",
    "        ax_orig.imshow(original_image, cmap='gray', alpha=0.7)\n",
    "        ax_orig.set_title(f'{scenario}\\nOriginal Image', fontsize=12, fontweight='bold')\n",
    "        ax_orig.axis('off')\n",
    "        ax_orig.add_patch(patch_cls(*patch_args, fill=False, color=color, linewidth=2))\n",
    "        ax_orig.text(*label_xy, label, ha='center', color=color, fontweight='bold')\n",
    "\n",
    "        # Plot Grad-CAM overlay\n",
    "        ax_gradcam = axes[1, idx]\n",
    "        ax_gradcam.imshow(original_image, cmap='gray', alpha=0.4)\n",
    "        im = ax_gradcam.imshow(gradcam_heatmap, cmap='hot', alpha=0.6)\n",
    "        # Report the mean: the maximum is always 1.000 after the rescaling above.\n",
    "        ax_gradcam.set_title(f'Grad-CAM Heatmap\\nMean saliency: {gradcam_heatmap.mean():.3f}',\n",
    "                             fontsize=12, fontweight='bold')\n",
    "        ax_gradcam.axis('off')\n",
    "\n",
    "        # Add colorbar for the last subplot\n",
    "        if idx == len(scenarios) - 1:\n",
    "            cbar = plt.colorbar(im, ax=ax_gradcam, fraction=0.046, pad=0.04)\n",
    "            cbar.set_label('Attention Intensity', fontsize=10)\n",
    "\n",
    "    plt.suptitle('Grad-CAM Visualization for Vision Encoder Analysis',\n",
    "                 fontsize=16, fontweight='bold', y=0.98)\n",
    "    plt.tight_layout()\n",
    "    return fig\n",
    "\n",
    "# Create Grad-CAM visualizations\n",
    "# Only the SciPy import is guarded, so an unrelated ImportError raised while\n",
    "# plotting or saving is not misreported as a missing SciPy installation.\n",
    "try:\n",
    "    from scipy import ndimage\n",
    "except ImportError:\n",
    "    print(\"SciPy not available for Grad-CAM visualization. Install with: pip install scipy\")\n",
    "else:\n",
    "    fig = create_gradcam_visualization()\n",
    "    fig.savefig(output_dir / 'gradcam_visualizations.png', dpi=300, bbox_inches='tight')\n",
    "    fig.savefig(docs_dir / 'gradcam_visualizations.png', dpi=300, bbox_inches='tight')\n",
    "    plt.show()\n",
    "\n",
    "    print(f\"Grad-CAM visualizations saved to:\")\n",
    "    print(f\"  Analysis: {output_dir / 'gradcam_visualizations.png'}\")\n",
    "    print(f\"  Documentation: {docs_dir / 'gradcam_visualizations.png'}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Audio Attention Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create audio attention visualizations\n",
    "def create_audio_attention_visualization():\n",
    "    \"\"\"\n",
    "    Create audio attention maps for speech analysis.\n",
    "\n",
    "    Builds a 2x2 figure from mock (seeded random) data: a spectrogram panel with\n",
    "    word annotations, followed by pitch, rhythm and prosody attention panels,\n",
    "    each drawn together with a faint grayscale copy of the spectrogram.\n",
    "\n",
    "    Returns:\n",
    "        The assembled matplotlib figure.\n",
    "    \"\"\"\n",
    "    # Mock audio data and attention\n",
    "    np.random.seed(42)\n",
    "\n",
    "    n_frames = 100\n",
    "    n_bins = 80\n",
    "\n",
    "    # Mock spectrogram: uniform noise, shape (frequency bins, time frames)\n",
    "    spectrogram = np.random.rand(n_bins, n_frames)\n",
    "\n",
    "    # Boost three formant-like bands during the first 10 frames of every\n",
    "    # 20-frame window, which stand in for speech segments.\n",
    "    voiced = (np.arange(n_frames) % 20) < 10\n",
    "    formant_gains = ((slice(10, 20), 2), (slice(30, 40), 1.5), (slice(50, 60), 1.2))\n",
    "    for band, gain in formant_gains:\n",
    "        spectrogram[band, voiced] *= gain\n",
    "\n",
    "    # Attention maps start empty and are filled region by region.\n",
    "    pitch_attention = np.zeros((n_bins, n_frames))\n",
    "    rhythm_attention = np.zeros((n_bins, n_frames))\n",
    "    prosody_attention = np.zeros((n_bins, n_frames))\n",
    "\n",
    "    # Pitch: bins 5-24 across all frames\n",
    "    pitch_attention[5:25, :] = np.random.rand(20, n_frames) * 0.8\n",
    "\n",
    "    # Rhythm: a 5-frame burst over all bins, starting every 15 frames\n",
    "    for onset in range(0, n_frames, 15):\n",
    "        burst_width = min(5, n_frames - onset)\n",
    "        rhythm_attention[:, onset:onset + 5] = np.random.rand(n_bins, burst_width) * 0.6\n",
    "\n",
    "    # Prosody: bins 40-69 across all frames\n",
    "    prosody_attention[40:70, :] = np.random.rand(30, n_frames) * 0.7\n",
    "\n",
    "    fig, axes = plt.subplots(2, 2, figsize=(16, 12))\n",
    "\n",
    "    # Panel 1: the spectrogram itself\n",
    "    spec_ax = axes[0, 0]\n",
    "    spec_image = spec_ax.imshow(spectrogram, cmap='viridis', aspect='auto', origin='lower')\n",
    "    spec_ax.set_title('Audio Spectrogram\\n\"Oh great, another meeting\"', fontsize=14, fontweight='bold')\n",
    "    spec_ax.set_xlabel('Time Frames')\n",
    "    spec_ax.set_ylabel('Frequency Bins')\n",
    "    plt.colorbar(spec_image, ax=spec_ax, label='Magnitude')\n",
    "\n",
    "    # Word boundaries: a dashed line at each start frame, the word centered above\n",
    "    word_spans = [(0, 25, 'Oh'), (25, 45, 'great'), (45, 65, 'another'), (65, 100, 'meeting')]\n",
    "    for first_frame, last_frame, word in word_spans:\n",
    "        spec_ax.axvline(first_frame, color='white', linestyle='--', alpha=0.7)\n",
    "        spec_ax.text((first_frame + last_frame) / 2, n_bins - 5, word, ha='center',\n",
    "                     color='white', fontweight='bold', fontsize=10)\n",
    "\n",
    "    # Panels 2-4: attention map first, then the grayscale spectrogram on top of it\n",
    "    overlay_panels = (\n",
    "        (axes[0, 1], pitch_attention, 'Reds',\n",
    "         'Pitch Attention\\n(Fundamental Frequency Focus)'),\n",
    "        (axes[1, 0], rhythm_attention, 'Blues',\n",
    "         'Rhythm Attention\\n(Temporal Pattern Focus)'),\n",
    "        (axes[1, 1], prosody_attention, 'Greens',\n",
    "         'Prosody Attention\\n(Emotional Tone Focus)'),\n",
    "    )\n",
    "    for panel_ax, attention, cmap_name, panel_title in overlay_panels:\n",
    "        attention_image = panel_ax.imshow(attention, cmap=cmap_name, aspect='auto',\n",
    "                                          origin='lower', alpha=0.8)\n",
    "        panel_ax.imshow(spectrogram, cmap='gray', aspect='auto', origin='lower', alpha=0.3)\n",
    "        panel_ax.set_title(panel_title, fontsize=14, fontweight='bold')\n",
    "        panel_ax.set_xlabel('Time Frames')\n",
    "        panel_ax.set_ylabel('Frequency Bins')\n",
    "        plt.colorbar(attention_image, ax=panel_ax, label='Attention Weight')\n",
    "\n",
    "    plt.tight_layout()\n",
    "    return fig\n",
    "\n",
    "# Create audio attention visualization\n",
    "fig = create_audio_attention_visualization()\n",
    "for target_dir in (output_dir, docs_dir):\n",
    "    plt.savefig(target_dir / 'audio_attention_maps.png', dpi=300, bbox_inches='tight')\n",
    "plt.show()\n",
    "\n",
    "print(f\"Audio attention maps saved to:\")\n",
    "print(f\"  Analysis: {output_dir / 'audio_attention_maps.png'}\")\n",
    "print(f\"  Documentation: {docs_dir / 'audio_attention_maps.png'}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Interactive Embeddings Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create interactive embeddings visualization\n",
    "def create_interactive_embeddings():\n",
    "    \"\"\"\n",
    "    Create interactive embeddings visualization using t-SNE/UMAP.\n",
    "    \"\"\"\n",
    "    # Generate mock high-dimensional embeddings\n",
    "    np.random.seed(42)\n",
    "    \n",
    "    n_samples = 1000\n",
    "    embedding_dim = 768\n",
    "    \n",
    "    # Create mock embeddings for different tasks and datasets\n",
    "    embeddings = np.random.randn(n_samples, embedding_dim)\n",
    "    \n",
    "    # Create labels and metadata\n",
    "    tasks = np.random.choice(['Sarcasm', 'Fact_Check', 'Paraphrase'], n_samples, p=[0.5, 0.3, 0.2])\n",
    "    datasets = np.random.choice(['SARC', 'MMSD2', 'FEVER', 'LIAR', 'ParaNMT'], n_samples)\n",
    "    predictions = np.random.choice(['Correct', 'Incorrect'], n_samples, p=[0.75, 0.25])\n",
    "    confidence = np.random.uniform(0.4
