{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Language Model Training for Truyện Kiều\n",
    "\n",
    "This notebook demonstrates how to train language models for generating new verses in the style of Truyện Kiều.\n",
    "\n",
    "We'll implement and train:\n",
    "1. A character-level n-gram model (statistical approach)\n",
    "2. An LSTM-based neural language model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "from tqdm.notebook import tqdm\n",
    "\n",
    "from src.preprocessor import KieuPreprocessor\n",
    "from src.language_model import KieuLanguageModelTrainer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Load and Prepare the Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the Truyện Kiều text\n",
    "preprocessor = KieuPreprocessor()\n",
    "verses = preprocessor.load_poem('../data/truyen_kieu.txt')\n",
    "\n",
    "print(f\"Loaded {len(verses)} verses from Truyện Kiều\")\n",
    "print(\"\\nFirst 5 verses:\")\n",
    "for i, verse in enumerate(verses[:5]):\n",
    "    print(f\"{i+1}. {verse}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. N-Gram Language Model Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize N-Gram model trainer\n",
    "ngram_trainer = KieuLanguageModelTrainer(model_type='ngram', n=3, smoothing=0.01)\n",
    "\n",
    "# Train the model\n",
    "print(\"Training N-Gram model...\")\n",
    "ngram_trainer.train(verses)\n",
    "\n",
    "# Save the model\n",
    "ngram_trainer.save('../models/kieu_ngram.pkl')\n",
    "print(\"N-Gram model saved to '../models/kieu_ngram.pkl'\")"
   ]
  },
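  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the n-gram approach concrete, the next cell is a minimal, self-contained sketch of a character-level trigram model with additive (Lidstone) smoothing: `P(c | h) = (count(h, c) + k) / (count(h) + k * |V|)`, where `h` is the previous n-1 characters, `V` is the set of predictable characters and `k` plays the role of `smoothing=0.01`. It illustrates the technique under our own assumptions and is not the code inside `KieuLanguageModelTrainer`: the class name `CharNGram`, the `^`/`$` start and end symbols, and the NFC normalization (so precomposed and decomposed Vietnamese diacritics count as the same character) are hypothetical choices. To stay self-contained it trains only on the poem's opening couplet, so its output only demonstrates sampling."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "import unicodedata\n",
    "from collections import Counter, defaultdict\n",
    "\n",
    "\n",
    "# Hypothetical sketch, not the implementation in src.language_model.\n",
    "class CharNGram:\n",
    "    def __init__(self, n=3, smoothing=0.01):\n",
    "        assert n >= 2, 'context needs at least one character'\n",
    "        self.n = n\n",
    "        self.smoothing = smoothing\n",
    "        self.counts = defaultdict(Counter)  # context -> counts of the next character\n",
    "        self.vocab = set()  # every character that can be predicted, including '$'\n",
    "\n",
    "    def _pad(self, text):\n",
    "        # '^' fills the context before the first character\n",
    "        return '^' * (self.n - 1) + unicodedata.normalize('NFC', text)\n",
    "\n",
    "    def train(self, verses):\n",
    "        for verse in verses:\n",
    "            padded = self._pad(verse) + '$'  # '$' marks the end of a verse\n",
    "            for i in range(self.n - 1, len(padded)):\n",
    "                context, char = padded[i - self.n + 1:i], padded[i]\n",
    "                self.counts[context][char] += 1\n",
    "                self.vocab.add(char)\n",
    "\n",
    "    def prob(self, context, char):\n",
    "        # (count(h, c) + k) / (count(h) + k * |V|); unseen contexts fall back to uniform\n",
    "        counter = self.counts.get(context, Counter())\n",
    "        total = sum(counter.values())\n",
    "        return (counter[char] + self.smoothing) / (total + self.smoothing * len(self.vocab))\n",
    "\n",
    "    def generate(self, prompt, max_length=50, seed=None):\n",
    "        rng = random.Random(seed)\n",
    "        text = self._pad(prompt)\n",
    "        chars = sorted(self.vocab)\n",
    "        for _ in range(max_length):\n",
    "            context = text[len(text) - (self.n - 1):]\n",
    "            weights = [self.prob(context, c) for c in chars]\n",
    "            next_char = rng.choices(chars, weights=weights, k=1)[0]\n",
    "            if next_char == '$':\n",
    "                break\n",
    "            text += next_char\n",
    "        return text[self.n - 1:]\n",
    "\n",
    "\n",
    "toy = CharNGram(n=3, smoothing=0.01)\n",
    "toy.train(['Trăm năm trong cõi người ta,', 'Chữ tài chữ mệnh khéo là ghét nhau.'])\n",
    "print(toy.generate('Trăm năm', max_length=40, seed=0))"
   ]
  },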
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Test the N-Gram Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test the N-Gram model with different prompts\n",
    "prompts = [\n",
    "    \"Trăm năm\",  # Famous opening\n",
    "    \"Tình duyên\",  # Love and fate\n",
    "    \"Hồng nhan\",  # Beauty\n",
    "    \"Sông núi\"   # Nature\n",
    "]\n",
    "\n",
    "print(\"Generating verses with N-Gram model:\\n\")\n",
    "for prompt in prompts:\n",
    "    generated = ngram_trainer.generate(prompt, max_length=50)\n",
    "    print(f\"Prompt: '{prompt}'\")\n",
    "    print(f\"Generated: '{generated}'\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. LSTM Language Model Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check if GPU is available\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "print(f\"Using device: {device}\")\n",
    "\n",
    "# Initialize LSTM model trainer\n",
    "lstm_trainer = KieuLanguageModelTrainer(\n",
    "    model_type='lstm',\n",
    "    embedding_dim=128,\n",
    "    hidden_dim=256,\n",
    "    num_layers=2,\n",
    "    dropout=0.2\n",
    ")\n",
    "\n",
    "# Train the model (this will take some time)\n",
    "print(\"Training LSTM model (this may take a while)...\")\n",
    "lstm_trainer.train(verses, epochs=5, batch_size=32, learning_rate=0.001)\n",
    "\n",
    "# Save the model\n",
    "lstm_trainer.save('../models/kieu_lstm.pkl')\n",
    "print(\"LSTM model saved to '../models/kieu_lstm.pkl'\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Test the LSTM Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test the LSTM model with the same prompts\n",
    "print(\"Generating verses with LSTM model:\\n\")\n",
    "for prompt in prompts:\n",
    "    generated = lstm_trainer.generate(prompt, max_length=50)\n",
    "    print(f\"Prompt: '{prompt}'\")\n",
    "    print(f\"Generated: '{generated}'\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Compare N-Gram and LSTM Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate multiple verses with both models\n",
    "comparison_prompts = [\n",
    "    \"Trăm năm\",\n",
    "    \"Sông Tiền\",\n",
    "    \"Trông ra\"\n",
    "]\n",
    "\n",
    "for prompt in comparison_prompts:\n",
    "    print(f\"\\nPrompt: '{prompt}'\")\n",
    "    \n",
    "    print(\"\\nN-Gram generated verses:\")\n",
    "    for i in range(3):\n",
    "        generated = ngram_trainer.generate(prompt, max_length=50)\n",
    "        print(f\"{i+1}. {generated}\")\n",
    "    \n",
    "    print(\"\\nLSTM generated verses:\")\n",
    "    for i in range(3):\n",
    "        generated = lstm_trainer.generate(prompt, max_length=50)\n",
    "        print(f\"{i+1}. {generated}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 7. Evaluating Generated Verses\n",
    "\n",
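    "We score each generated verse with `KieuVerseGenerator.evaluate_verse_quality`, a simple evaluation based on:\n",
    "1. Presence of a verse-ending punctuation mark (usually a comma or a period), as in the lines of Truyện Kiều\n",
    "2. Length of the verse, ideally 6 to 8 syllables (Vietnamese writes one syllable per space-separated word); lục bát lines alternate between 6 and 8 syllables\n",
    "3. A simple perplexity-based measure of fluency"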
   ]
  },
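  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The three criteria above can be sketched as small stand-alone functions. These are hypothetical helpers, not the scoring inside `KieuVerseGenerator.evaluate_verse_quality`: syllables are counted as space-separated tokens, the length score peaks at 6 or 8 syllables (the two line lengths of lục bát) with an arbitrary penalty of 0.25 per syllable off target, and perplexity uses the standard definition `exp(-(1/N) * sum(log p_i))`, where the `p_i` are the per-token probabilities a model assigns to the verse."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "# Hypothetical helpers for the three criteria above; not the scoring in src.verse_generator.\n",
    "\n",
    "def syllable_count(verse):\n",
    "    # Vietnamese writes one syllable per space-separated token\n",
    "    return len(verse.split())\n",
    "\n",
    "def has_verse_ending(verse):\n",
    "    # Criterion 1: the verse ends with a comma or a period\n",
    "    return verse.rstrip().endswith((',', '.'))\n",
    "\n",
    "def length_score(verse, target_lengths=(6, 8)):\n",
    "    # Criterion 2: 1.0 for a 6- or 8-syllable line, minus an arbitrary 0.25 per syllable off target\n",
    "    distance = min(abs(syllable_count(verse) - t) for t in target_lengths)\n",
    "    return max(0.0, 1.0 - 0.25 * distance)\n",
    "\n",
    "def perplexity(token_probs):\n",
    "    # Criterion 3: exp of the mean negative log-probability; lower means more fluent\n",
    "    if not token_probs:\n",
    "        raise ValueError('need at least one token probability')\n",
    "    return math.exp(-sum(math.log(p) for p in token_probs) / len(token_probs))\n",
    "\n",
    "\n",
    "for line in ['Trăm năm trong cõi người ta,', 'Chữ tài chữ mệnh khéo là ghét nhau.', 'Trăm năm']:\n",
    "    print(line, syllable_count(line), has_verse_ending(line), length_score(line))\n",
    "print(perplexity([0.5, 0.25, 0.125]))  # geometric-mean probability is 0.25, so this is 4 up to rounding"
   ]
  },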
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from src.verse_generator import KieuVerseGenerator\n",
    "\n",
    "# Load trained models into the verse generator\n",
    "ngram_generator = KieuVerseGenerator('../models/kieu_ngram.pkl', 'ngram')\n",
    "lstm_generator = KieuVerseGenerator('../models/kieu_lstm.pkl', 'lstm')\n",
    "\n",
    "# Generate verses with quality evaluation\n",
    "test_prompts = [\n",
    "    \"Trăm năm\",\n",
    "    \"Tình duyên\",\n",
    "    \"Mặt trăng\",\n",
    "    \"Hồng nhan\"\n",
    "]\n",
    "\n",
    "print(\"Evaluating generated verses:\\n\")\n",
    "\n",
    "for prompt in test_prompts:\n",
    "    print(f\"Prompt: '{prompt}'\")\n",
    "    \n",
    "    # Generate one verse with each model and print its quality scores\n",
    "    for model_name, gen in [(\"N-Gram\", ngram_generator), (\"LSTM\", lstm_generator)]:\n",
    "        samples = gen.generate_verse(prompt, num_samples=1)\n",
    "        if not samples:\n",
    "            continue\n",
    "        sample = samples[0]\n",
    "        quality = gen.evaluate_verse_quality(sample)\n",
    "        \n",
    "        print(f\"\\n{model_name}: '{sample}'\")\n",
    "        print(\"Quality scores:\")\n",
    "        for metric, score in quality.items():\n",
    "            print(f\"  - {metric}: {score:.2f}\")\n",
    "    \n",
    "    print(\"\\n\" + \"-\"*50)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 8. Generating Verse Pairs\n",
    "\n",
    "Truyện Kiều is written in lục bát meter: each six-syllable line is followed by an eight-syllable line, and the two form a couplet. Let's generate some pairs with `generate_verse_pair()`."
   ]
  },
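  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Beyond the 6 + 8 syllable pattern, lục bát ties the two lines of a couplet together by rhyme: the 6th syllable of the six-syllable line rhymes with the 6th syllable of the eight-syllable line, and the rhyming slots (plus the 8th syllable of the eight-syllable line) take level (bằng) tones, i.e. no tone mark or the huyền grave accent. The next cell is a rough, hypothetical checker for a generated pair, not part of `src.verse_generator`. It approximates the rime as whatever follows the initial consonant once tone marks are removed, so it accepts only exact rimes and rejects near rhymes the poem itself uses, such as `dâu`/`đau` in the couplet `Trải qua một cuộc bể dâu, / Những điều trông thấy mà đau đớn lòng.`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import unicodedata\n",
    "\n",
    "# Hypothetical lục bát checker; a rough approximation, not part of src.verse_generator.\n",
    "TONE_MARKS = {chr(0x300), chr(0x301), chr(0x303), chr(0x309), chr(0x323)}  # huyền, sắc, ngã, hỏi, nặng\n",
    "GRAVE = chr(0x300)  # huyền: the only marked tone that is level (bằng)\n",
    "INITIALS = ('ngh', 'ng', 'nh', 'ch', 'gh', 'gi', 'kh', 'ph', 'qu', 'th', 'tr',\n",
    "            'b', 'c', 'd', 'đ', 'g', 'h', 'k', 'l', 'm', 'n', 'p', 'r', 's', 't', 'v', 'x')\n",
    "\n",
    "def _clean(syllable):\n",
    "    return syllable.strip(',.;:!?').lower()\n",
    "\n",
    "def is_level_tone(syllable):\n",
    "    # Level (bằng) tones are ngang (no mark) and huyền (grave accent)\n",
    "    marks = {ch for ch in unicodedata.normalize('NFD', _clean(syllable)) if ch in TONE_MARKS}\n",
    "    return marks <= {GRAVE}\n",
    "\n",
    "def rime(syllable):\n",
    "    # Remove tone marks (keeping ă, â, ê, ô, ơ, ư), then strip the initial consonant(s)\n",
    "    decomposed = unicodedata.normalize('NFD', _clean(syllable))\n",
    "    toneless = unicodedata.normalize('NFC', ''.join(ch for ch in decomposed if ch not in TONE_MARKS))\n",
    "    for initial in INITIALS:\n",
    "        if toneless.startswith(initial) and len(toneless) > len(initial):\n",
    "            return toneless[len(initial):]\n",
    "    return toneless\n",
    "\n",
    "def check_luc_bat_pair(six_line, eight_line):\n",
    "    # 6 + 8 syllables; 6th of line 1 rhymes with 6th of line 2; rhyme slots carry level tones\n",
    "    six, eight = six_line.split(), eight_line.split()\n",
    "    if len(six) != 6 or len(eight) != 8:\n",
    "        return False\n",
    "    return rime(six[5]) == rime(eight[5]) and all(\n",
    "        is_level_tone(s) for s in (six[5], eight[5], eight[7]))\n",
    "\n",
    "\n",
    "print(check_luc_bat_pair('Trăm năm trong cõi người ta,', 'Chữ tài chữ mệnh khéo là ghét nhau.'))"
   ]
  },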
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate verse pairs\n",
    "print(\"Generating verse pairs:\\n\")\n",
    "\n",
    "# Use the LSTM model when a GPU is available; otherwise fall back to the cheaper N-Gram model\n",
    "generator = lstm_generator if torch.cuda.is_available() else ngram_generator\n",
    "\n",
    "for _ in range(5):\n",
    "    verse_pair = generator.generate_verse_pair()\n",
    "    print(f\"First verse: {verse_pair[0]}\")\n",
    "    print(f\"Second verse: {verse_pair[1]}\")\n",
    "    print()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 9. Generating Verses by Theme\n",
    "\n",
    "Let's generate verses related to specific themes from Truyện Kiều."
   ]
  },
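  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One simple way to steer generation toward a theme is to seed the model with a phrase associated with that theme. The next cell sketches that idea; it is an assumption, not necessarily what `generate_with_theme` does internally. The `THEME_SEEDS` mapping and `generate_with_theme_sketch` are hypothetical: the love, nature and beauty seeds reuse prompts from earlier in this notebook, while the fate and sadness seeds add phrases taken from the poem (`Chữ tài chữ mệnh`, `Buồn trông`). Any prompt-to-text callable can serve as the generator."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random\n",
    "\n",
    "# Hypothetical theme -> seed-phrase mapping; not necessarily how generate_with_theme works.\n",
    "THEME_SEEDS = {\n",
    "    'love': ['Tình duyên'],\n",
    "    'nature': ['Sông núi', 'Mặt trăng'],\n",
    "    'fate': ['Trăm năm', 'Chữ tài chữ mệnh'],\n",
    "    'beauty': ['Hồng nhan'],\n",
    "    'sadness': ['Buồn trông'],\n",
    "}\n",
    "\n",
    "def generate_with_theme_sketch(generate_fn, theme, num_samples=3, seed=None):\n",
    "    # generate_fn is any callable that maps a prompt string to generated text\n",
    "    if theme not in THEME_SEEDS:\n",
    "        raise KeyError(f'unknown theme: {theme}')\n",
    "    rng = random.Random(seed)\n",
    "    return [generate_fn(rng.choice(THEME_SEEDS[theme])) for _ in range(num_samples)]\n",
    "\n",
    "\n",
    "# Usage with the trainer from Section 2:\n",
    "# generate_with_theme_sketch(lambda p: ngram_trainer.generate(p, max_length=50), 'love')\n",
    "print(generate_with_theme_sketch(lambda p: p + ' ...', 'sadness', num_samples=2, seed=0))"
   ]
  },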
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Generate verses by theme\n",
    "themes = [\"love\", \"nature\", \"fate\", \"beauty\", \"sadness\"]\n",
    "\n",
    "print(\"Generating verses by theme:\\n\")\n",
    "\n",
    "for theme in themes:\n",
    "    print(f\"Theme: {theme}\")\n",
    "    # Use a separate name so the training corpus `verses` is not overwritten\n",
    "    theme_verses = generator.generate_with_theme(theme, num_samples=3)\n",
    "    for i, verse in enumerate(theme_verses, 1):\n",
    "        print(f\"{i}. {verse}\")\n",
    "    print()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 10. Conclusion\n",
    "\n",
    "We trained and compared two types of language models for generating verses in the style of Truyện Kiều:\n",
    "\n",
    "1. **N-Gram Model**:\n",
    "   - Simpler and faster to train, since training only counts n-gram frequencies\n",
    "   - Generally produces shorter sequences\n",
    "   - Conditions only on the previous n-1 characters (2 with `n=3`), so coherence can drift across a verse\n",
    "\n",
    "2. **LSTM Model**:\n",
    "   - More computationally intensive; training benefits from a GPU\n",
    "   - Can capture longer-range dependencies through its recurrent hidden state\n",
    "   - Generally produces more coherent verses\n",
    "\n",
    "Both models are available for generating new verses starting from a given phrase."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}