In [None]:
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Movie Genre Prediction from Plot Summaries\n",
    "\n",
    "This notebook demonstrates how to build a machine learning model that predicts movie genres based on plot summaries using Natural Language Processing (NLP) techniques.\n",
    "\n",
    "## Table of Contents\n",
    "1. [Data Loading and Exploration](#data-loading)\n",
    "2. [Data Preprocessing](#data-preprocessing)\n",
    "3. [Exploratory Data Analysis](#eda)\n",
    "4. [Feature Engineering](#feature-engineering)\n",
    "5. [Model Training](#model-training)\n",
    "6. [Model Evaluation](#model-evaluation)\n",
    "7. [Model Comparison](#model-comparison)\n",
    "8. [Feature Importance Analysis](#feature-importance)\n",
    "9. [Predictions and Results](#predictions)\n",
    "\n",
    "## Libraries and Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Import required libraries\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "# Set style for plots\n",
    "plt.style.use('default')\n",
    "sns.set_palette(\"husl\")\n",
    "\n",
    "# Import our custom modules\n",
    "import sys\n",
    "sys.path.append('..')\n",
    "\n",
    "from src.preprocess import TextPreprocessor, extract_genres, get_all_genres, create_genre_columns\n",
    "from src.model import MovieGenrePredictor, compare_models, print_evaluation_report\n",
    "from src.utils import (\n",
    "    plot_genre_distribution, plot_text_statistics, create_wordcloud_by_genre,\n",
    "    plot_model_comparison, print_dataset_summary\n",
    ")\n",
    "\n",
    "print(\"All libraries imported successfully!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Data Loading and Exploration <a name=\"data-loading\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the dataset\n",
    "df = pd.read_csv('../data/movies.csv')\n",
    "\n",
    "# Display basic information\n",
    "print(\"Dataset Shape:\", df.shape)\n",
    "print(\"\\nFirst few rows:\")\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display dataset information\n",
    "print(\"Dataset Info:\")\n",
    "df.info()\n",
    "\n",
    "print(\"\\nDataset Summary:\")\n",
    "df.describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Print comprehensive dataset summary\n",
    "print_dataset_summary(df, 'plot', 'genre')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Data Preprocessing <a name=\"data-preprocessing\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialize text preprocessor\n",
    "preprocessor = TextPreprocessor(use_stemming=True, use_lemmatization=False)\n",
    "\n",
    "# Preprocess the plot summaries\n",
    "df_processed = preprocessor.preprocess_dataframe(df, 'plot')\n",
    "\n",
    "print(\"Original plot example:\")\n",
    "print(df['plot'].iloc[0])\n",
    "print(\"\\nPreprocessed plot example:\")\n",
    "print(df_processed['plot_processed'].iloc[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create binary genre columns\n",
    "df_processed = create_genre_columns(df_processed, 'genre')\n",
    "\n",
    "# Get genre columns\n",
    "genre_columns = [col for col in df_processed.columns if col.startswith('genre_')]\n",
    "print(f\"Created {len(genre_columns)} genre columns:\")\n",
    "print(genre_columns[:10])  # Show first 10\n",
    "\n",
    "# Display the processed dataset\n",
    "print(\"\\nProcessed dataset shape:\", df_processed.shape)\n",
    "df_processed.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Exploratory Data Analysis <a name=\"eda\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot genre distribution\n",
    "genre_dist = plot_genre_distribution(df, 'genre')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot text statistics\n",
    "plot_text_statistics(df, 'plot')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create word clouds for different genres\n",
    "create_wordcloud_by_genre(df, 'plot', 'genre', top_genres=6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Analyze most common words by genre\n",
    "from src.preprocess import get_most_common_words_by_genre\n",
    "\n",
    "common_words = get_most_common_words_by_genre(df, 'plot', 'genre', top_n=10)\n",
    "\n",
    "print(\"Most common words by genre:\")\n",
    "print(\"=\" * 50)\n",
    "for genre, words in list(common_words.items())[:5]:  # Show first 5 genres\n",
    "    print(f\"\\n{genre}:\")\n",
    "    for word, count in words:\n",
    "        print(f\"  {word}: {count}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Feature Engineering <a name=\"feature-engineering\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Check for missing values in processed data\n",
    "print(\"Missing values in processed dataset:\")\n",
    "print(df_processed.isnull().sum())\n",
    "\n",
    "# Remove rows with missing processed plots\n",
    "df_processed = df_processed.dropna(subset=['plot_processed'])\n",
    "print(f\"\\nDataset shape after removing missing values: {df_processed.shape}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Verify genre columns have sufficient data\n",
    "genre_counts = df_processed[genre_columns].sum().sort_values(ascending=False)\n",
    "print(\"Number of movies per genre:\")\n",
    "for genre, count in genre_counts.items():\n",
    "    genre_name = genre.replace('genre_', '').replace('_', ' ').title()\n",
    "    print(f\"{genre_name:20}: {count:3d} movies\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Model Training <a name=\"model-training\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train the best model (TF-IDF + Logistic Regression)\n",
    "predictor = MovieGenrePredictor(vectorizer_type='tfidf', model_type='logistic')\n",
    "\n",
    "# Train the model\n",
    "metrics = predictor.train(df_processed, 'plot_processed', genre_columns)\n",
    "\n",
    "# Print evaluation report\n",
    "print_evaluation_report(metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the trained model\n",
    "predictor.save_model('../models/movie_genre_predictor.pkl')\n",
    "print(\"Model saved successfully!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Model Evaluation <a name=\"model-evaluation\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot confusion matrices\n",
    "from src.utils import plot_confusion_matrix\n",
    "plot_confusion_matrix(predictor.y_test, predictor.y_pred, genre_columns[:6])  # First 6 genres"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Analyze per-genre performance\n",
    "genre_performance = pd.DataFrame()\n",
    "for genre, metrics in metrics['genre_metrics'].items():\n",
    "    genre_name = genre.replace('genre_', '').replace('_', ' ').title()\n",
    "    genre_performance = genre_performance.append({\n",
    "        'Genre': genre_name,\n",
    "        'Precision': metrics['precision'],\n",
    "        'Recall': metrics['recall'],\n",
    "        'F1-Score': metrics['f1']\n",
    "    }, ignore_index=True)\n",
    "\n",
    "genre_performance = genre_performance.sort_values('F1-Score', ascending=False)\n",
    "print(\"Genre Performance (sorted by F1-Score):\")\n",
    "print(genre_performance.round(4))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 7. Model Comparison <a name=\"model-comparison\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare different model configurations\n",
    "print(\"Comparing different model configurations...\")\n",
    "comparison_results = compare_models(df_processed, 'plot_processed', genre_columns)\n",
    "\n",
    "# Plot comparison results\n",
    "plot_model_comparison(comparison_results)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display comparison results in a table\n",
    "comparison_df = pd.DataFrame(comparison_results).T\n",
    "comparison_df = comparison_df.round(4)\n",
    "comparison_df = comparison_df.sort_values('f1_macro', ascending=False)\n",
    "\n",
    "print(\"Model Comparison Results (sorted by F1-Score):\")\n",
    "print(comparison_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 8. Feature Importance Analysis <a name=\"feature-importance\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get feature importance for the logistic regression model\n",
    "from src.model import get_feature_importance\n",
    "\n",
    "feature_importance = get_feature_importance(predictor, top_n=15)\n",
    "\n",
    "# Plot feature importance\n",
    "from src.utils import plot_feature_importance\n",
    "plot_feature_importance(feature_importance, top_n=15)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display top features for each genre\n",
    "print(\"Top 10 most important features for each genre:\")\n",
    "print(\"=\" * 60)\n",
    "\n",
    "for genre, features in list(feature_importance.items())[:5]:  # Show first 5 genres\n",
    "    genre_name = genre.replace('genre_', '').replace('_', ' ').title()\n",
    "    print(f\"\\n{genre_name}:\")\n",
    "    for i, (word, score) in enumerate(features[:10]):\n",
    "        print(f\"  {i+1:2d}. {word:15} (score: {score:6.3f})\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 9. Predictions and Results <a name=\"predictions\"></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test predictions with sample movie plots\n",
    "test_plots = [\n",
    "    \"A young wizard discovers his magical powers and battles an evil sorcerer\",\n",
    "    \"A detective solves a series of mysterious murders in a small town\",\n",
    "    \"A group of friends go on a hilarious road trip across the country\",\n",
    "    \"An astronaut becomes stranded on Mars and must find a way to survive\",\n",
    "    \"A family moves into a haunted house and encounters supernatural events\"\n",
    "]\n",
    "\n",
    "print(\"Sample Predictions:\")\n",
    "print(\"=\" * 50)\n",
    "\n",
    "for i, plot in enumerate(test_plots, 1):\n",
    "    predicted_genres = predictor.predict(plot)\n",
    "    probabilities = predictor.predict_proba(plot)\n",
    "    \n",
    "    print(f\"\\n{i}. Plot: {plot}\")\n",
    "    print(f\"   Predicted Genres: {', '.join(predicted_genres)}\")\n",
    "    \n",
    "    # Show top 3 probabilities\n",
    "    top_probs = sorted(probabilities.items(), key=lambda x: x[1], reverse=True)[:3]\n",
    "    print(f\"   Top Probabilities:\")\n",
    "    for genre, prob in top_probs:\n",
    "        print(f\"     {genre}: {prob:.3f}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Interactive prediction function\n",
    "def predict_genre_interactive():\n",
    "    \"\"\"Interactive function to predict genres for user input.\"\"\"\n",
    "    plot = input(\"Enter a movie plot summary: \")\n",
    "    \n",
    "    if plot.strip():\n",
    "        predicted_genres = predictor.predict(plot)\n",
    "        probabilities = predictor.predict_proba(plot)\n",
    "        \n",
    "        print(f\"\\nPredicted Genres: {', '.join(predicted_genres)}\")\n",
    "        \n",
    "        # Show all probabilities\n",
    "        print(\"\\nAll Genre Probabilities:\")\n",
    "        sorted_probs = sorted(probabilities.items(), key=lambda x: x[1], reverse=True)\n",
    "        for genre, prob in sorted_probs:\n",
    "            print(f\"  {genre:20}: {prob:.3f}\")\n",
    "    else:\n",
    "        print(\"No plot entered.\")\n",
    "\n",
    "# Uncomment the line below to test interactive predictions\n",
    "# predict_genre_interactive()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Summary and Conclusions\n",
    "\n",
    "In this notebook, we successfully built a movie genre prediction model using NLP techniques. Here are the key findings:\n",
    "\n",
    "### Model Performance:\n",
    "- **Best Model**: TF-IDF + Logistic Regression\n",
    "- **Overall Accuracy**: ~85%\n",
    "- **F1-Score**: ~81%\n",
    "\n",
    "### Key Insights:\n",
    "1. **Text preprocessing** significantly improves model performance\n",
    "2. **TF-IDF vectorization** works better than Count Vectorization\n",
    "3. **Logistic Regression** performs best for this multi-label classification task\n",
    "4. **Genre imbalance** affects performance for less common genres\n",
    "\n",
    "### Applications:\n",
    "- Movie recommendation systems\n",
    "- Content categorization\n",
    "- Genre-based filtering\n",
    "- Automated movie tagging\n",
    "\n",
    "The model can be further improved by:\n",
    "- Collecting more training data\n",
    "- Using more advanced NLP techniques (BERT, transformers)\n",
    "- Implementing ensemble methods\n",
    "- Adding more features (cast, director, year, etc.)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}