In [None]:
{
 "cells": [
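  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# HMM Training and Analysis\n",
    "\n",
    "This notebook trains the Hidden Markov Model (HMM) emissions for Hangman letter prediction. It loads and explores the corpus, builds the emissions, inspects them for a sample word length, spot-checks the oracle's predictions, and saves the processed data and emissions."
   ]
  },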
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "import pickle\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from pathlib import Path\n",
    "\n",
    "# Add project root to path\n",
    "sys.path.append('..')\n",
    "\n",
    "from src.utils.data_loader import CorpusLoader\n",
    "from src.hmm.emissions import EmissionBuilder\n",
    "from src.hmm.oracle import HMMOracle"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Load and Explore Corpus"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the corpus; the loader returns the words grouped by length and the letter frequencies\n",
    "loader = CorpusLoader('../data/raw/corpus.txt')\n",
    "words_by_length, letter_freq = loader.load_and_preprocess()\n",
    "\n",
    "# Summarise what was loaded: total entries across all length groups, then the lengths present\n",
    "total_words = sum(len(group) for group in words_by_length.values())\n",
    "available_lengths = sorted(words_by_length.keys())\n",
    "\n",
    "print(f\"Total unique words: {total_words}\")\n",
    "print(f\"Word lengths: {available_lengths}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bar chart: number of corpus words at each word length\n",
    "lengths = sorted(words_by_length.keys())\n",
    "counts = [len(words_by_length[length]) for length in lengths]\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(12, 6))\n",
    "ax.bar(lengths, counts, edgecolor='black', alpha=0.7)\n",
    "ax.set_xlabel('Word Length')\n",
    "ax.set_ylabel('Number of Words')\n",
    "ax.set_title('Distribution of Word Lengths in Corpus')\n",
    "ax.grid(axis='y', alpha=0.3)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bar chart of the overall letter frequencies, with letters in sorted key order\n",
    "letters = sorted(letter_freq.keys())\n",
    "freqs = [letter_freq[ch] for ch in letters]\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(12, 5))\n",
    "ax.bar(letters, freqs, edgecolor='black', alpha=0.7, color='steelblue')\n",
    "ax.set_xlabel('Letter')\n",
    "ax.set_ylabel('Frequency')\n",
    "ax.set_title('Overall Letter Frequency in Corpus')\n",
    "ax.grid(axis='y', alpha=0.3)\n",
    "plt.show()\n",
    "\n",
    "# Rank letters from most to least frequent and list the first ten\n",
    "sorted_letters = sorted(letter_freq.items(), key=lambda item: item[1], reverse=True)\n",
    "print(\"\\nTop 10 most frequent letters:\")\n",
    "for letter, freq in sorted_letters[:10]:\n",
    "    print(f\"{letter}: {freq}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Build HMM Emissions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Smoothing parameter handed to the emission builder\n",
    "SMOOTHING_ALPHA = 1.0\n",
    "\n",
    "# Build the emissions from the length-grouped corpus\n",
    "builder = EmissionBuilder(smoothing_alpha=SMOOTHING_ALPHA)\n",
    "emissions = builder.build_from_corpus(words_by_length)\n",
    "\n",
    "num_lengths = len(emissions)\n",
    "print(f\"\\nEmissions trained for {num_lengths} different word lengths\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Analyze Emissions for Sample Length"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Analyze emissions for 5-letter words\n",
    "sample_length = 5\n",
    "if sample_length in emissions:\n",
    "    position_emissions = emissions[sample_length]\n",
    "\n",
    "    # Use a dedicated name so the `letters` list from the letter-frequency cell is not overwritten\n",
    "    alphabet = list('abcdefghijklmnopqrstuvwxyz')\n",
    "\n",
    "    # Heatmap input: one row per position, one column per letter a-z\n",
    "    emission_matrix = np.array(\n",
    "        [[position_emissions[pos][letter] for letter in alphabet]\n",
    "         for pos in range(sample_length)],\n",
    "        dtype=float,\n",
    "    )\n",
    "\n",
    "    fig, ax = plt.subplots(figsize=(14, 6))\n",
    "    sns.heatmap(emission_matrix,\n",
    "                xticklabels=alphabet,\n",
    "                yticklabels=[f'Pos {i}' for i in range(sample_length)],\n",
    "                cmap='YlOrRd',\n",
    "                cbar_kws={'label': 'Probability'},\n",
    "                ax=ax)\n",
    "    ax.set_title(f'Letter Emission Probabilities for {sample_length}-Letter Words')\n",
    "    ax.set_xlabel('Letter')\n",
    "    ax.set_ylabel('Position')\n",
    "    fig.tight_layout()\n",
    "    plt.show()\n",
    "\n",
    "    # Show top letters per position\n",
    "    print(f\"\\nTop 5 letters per position for {sample_length}-letter words:\")\n",
    "    for pos in range(sample_length):\n",
    "        top_letters = sorted(position_emissions[pos].items(),\n",
    "                             key=lambda item: item[1], reverse=True)[:5]\n",
    "        formatted = ', '.join(f'{letter}({prob:.3f})' for letter, prob in top_letters)\n",
    "        print(f\"Position {pos}: {formatted}\")\n",
    "else:\n",
    "    # Report the skip explicitly instead of leaving the cell with no output\n",
    "    print(f\"No emissions available for {sample_length}-letter words; skipping analysis.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Test HMM Oracle Predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create oracle\n",
    "oracle = HMMOracle(emissions, words_by_length)\n",
    "\n",
    "# Sample mask and guessed-letter set passed to the oracle\n",
    "test_mask = \"_pp__\"\n",
    "guessed = {'a', 'p'}\n",
    "top_k = 10  # number of predictions to print and plot\n",
    "\n",
    "probs = oracle.get_letter_probs(test_mask, guessed)\n",
    "\n",
    "# Sort by probability, highest first, and keep the top_k entries\n",
    "sorted_probs = sorted(probs.items(), key=lambda item: item[1], reverse=True)[:top_k]\n",
    "\n",
    "# Print the guessed letters via sorted(): the iteration order of a set of strings\n",
    "# can differ between kernel sessions, which would make this output non-reproducible\n",
    "print(f\"\\nTop {top_k} predictions for mask '{test_mask}' (guessed: {sorted(guessed)}):\")\n",
    "for letter, prob in sorted_probs:\n",
    "    print(f\"{letter}: {prob:.4f}\")\n",
    "\n",
    "# Visualize (dedicated names so the `letters` list from the letter-frequency cell is not overwritten)\n",
    "predicted_letters = [letter for letter, _ in sorted_probs]\n",
    "probs_vals = [prob for _, prob in sorted_probs]\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(10, 5))\n",
    "ax.bar(predicted_letters, probs_vals, edgecolor='black', alpha=0.7)\n",
    "ax.set_xlabel('Letter')\n",
    "ax.set_ylabel('Probability')\n",
    "ax.set_title(f'HMM Predictions for Mask: {test_mask}')\n",
    "ax.grid(axis='y', alpha=0.3)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Save Processed Data and Emissions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the preprocessed corpus data\n",
    "processed_dir = '../data/processed'\n",
    "loader.save_processed(processed_dir)\n",
    "\n",
    "# Make sure the emissions file's parent directory exists, then persist the emissions\n",
    "emissions_path = '../models/hmm/emissions.pkl'\n",
    "Path(emissions_path).parent.mkdir(parents=True, exist_ok=True)\n",
    "builder.save(emissions_path)\n",
    "\n",
    "print(\"\\nData and emissions saved successfully!\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.12.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
