In [None]:
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 🤖 LSTM Flood Forecasting Model\n",
    "### 7-day ahead water level prediction\n",
    "\n",
    "This notebook demonstrates:\n",
    "- LSTM architecture for time series forecasting\n",
    "- Training with discharge and water level data\n",
    "- Model validation (best loss on a held-out validation split)\n",
    "- 7-day forecast generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('..')\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from torch.utils.data import DataLoader\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from src.config import Config\n",
    "from src.data_loader import ERA5DataLoader\n",
    "from src.models import FloodDataset, LSTMModel, FloodForecaster\n",
    "\n",
    "plt.style.use('seaborn-v0_8-darkgrid')\n",
    "\n",
    "# Seed the RNGs used below (dataset splitting / DataLoader shuffling, and\n",
    "# presumably the model's weight initialisation) so Restart & Run All\n",
    "# reproduces the same trained model and reported losses.\n",
    "SEED = 42\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1. Prepare Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load data\n",
    "station_key = 'koshi_chatara'\n",
    "station_config = Config.TARGET_STATIONS[station_key]\n",
    "\n",
    "loader = ERA5DataLoader(station_config, years_back=5)\n",
    "data = loader.fetch_era5_data()\n",
    "\n",
    "# Use the most recent year (365 days x 24 hourly rows) for training.\n",
    "# .iloc makes the slice unambiguously positional, independent of the index\n",
    "# type (plain [] slicing is label-based for some indexes, e.g. float).\n",
    "HOURS_PER_YEAR = 365 * 24\n",
    "train_data = data.iloc[-HOURS_PER_YEAR:]\n",
    "\n",
    "print(f\"Training data: {len(train_data):,} hours\")\n",
    "print(f\"Station: {station_config['name']}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2. Create Sequences"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 168 h (7 days) of history, 168 h (7 days) forecast horizon, stride=6\n",
    "LOOKBACK_HOURS = 168\n",
    "FORECAST_HOURS = 168\n",
    "WINDOW_STRIDE = 6\n",
    "dataset = FloodDataset(\n",
    "    train_data,\n",
    "    lookback=LOOKBACK_HOURS,\n",
    "    forecast=FORECAST_HOURS,\n",
    "    stride=WINDOW_STRIDE,\n",
    ")\n",
    "\n",
    "print(\"\\nDataset created:\")\n",
    "print(f\"Total sequences: {len(dataset)}\")\n",
    "print(f\"Input shape: {dataset.sequences.shape}\")\n",
    "print(f\"Target shape: {dataset.targets.shape}\")\n",
    "\n",
    "# Per-variable normalization statistics computed by FloodDataset\n",
    "print(\"\\nNormalization stats:\")\n",
    "for label, mean, std in [\n",
    "    (\"Discharge\", dataset.discharge_mean, dataset.discharge_std),\n",
    "    (\"Water Level\", dataset.level_mean, dataset.level_std),\n",
    "]:\n",
    "    print(f\"{label} - Mean: {mean:.2f}, Std: {std:.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 3. Train/Validation Split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Chronological 80/20 split: earliest windows train, latest windows validate.\n",
    "# With lookback = forecast = 168 h and stride=6, neighbouring windows share\n",
    "# most of their hours, so a random split would place near-duplicates of every\n",
    "# validation window in the training set and make the validation loss\n",
    "# optimistic. Assumes FloodDataset yields windows in time order (TODO confirm).\n",
    "# Windows straddling the boundary still overlap; skip roughly one window\n",
    "# span's worth of indices between the two subsets for a stricter split.\n",
    "train_size = int(0.8 * len(dataset))\n",
    "val_size = len(dataset) - train_size\n",
    "\n",
    "train_dataset = torch.utils.data.Subset(dataset, range(train_size))\n",
    "val_dataset = torch.utils.data.Subset(dataset, range(train_size, len(dataset)))\n",
    "\n",
    "train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)\n",
    "val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)\n",
    "\n",
    "print(f\"Train set: {train_size} sequences\")\n",
    "print(f\"Validation set: {val_size} sequences\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4. Train LSTM Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the forecaster and report what it will train on\n",
    "EPOCHS = 30\n",
    "forecaster = FloodForecaster()\n",
    "n_params = sum(param.numel() for param in forecaster.model.parameters())\n",
    "\n",
    "print(\"Training LSTM model...\")\n",
    "print(f\"Device: {forecaster.device}\")\n",
    "print(f\"Model parameters: {n_params:,}\\n\")\n",
    "\n",
    "# The value returned by train_silent is reported as the best validation loss\n",
    "best_loss = forecaster.train_silent(train_loader, val_loader, epochs=EPOCHS)\n",
    "\n",
    "print(\"\\nTraining complete!\")\n",
    "print(f\"Best validation loss: {best_loss:.6f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 5. Generate Forecast"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model input: the most recent 7 days (168 hourly rows) sampled every 6 h,\n",
    "# i.e. 168 // 6 = 28 steps. Taking the last 28 *hourly* rows would cover only\n",
    "# 28 hours instead of the 7-day history the model is described as using.\n",
    "# NOTE(review): assumes FloodDataset builds its inputs as window[::stride];\n",
    "# confirm the sampling/alignment against src.models.FloodDataset.\n",
    "history_hours = 168\n",
    "step_hours = 6\n",
    "lookback = history_hours // step_hours\n",
    "recent = data.iloc[-history_hours:]\n",
    "input_discharge = recent['discharge_cumecs'].values[::step_hours]\n",
    "input_level = recent['water_level_m'].values[::step_hours]\n",
    "assert len(input_discharge) == len(input_level) == lookback, \"not enough recent data for a full 7-day input window\"\n",
    "\n",
    "# Normalize with the statistics stored on `dataset` (built from train_data)\n",
    "input_norm = np.stack([\n",
    "    (input_discharge - dataset.discharge_mean) / (dataset.discharge_std + 1e-8),\n",
    "    (input_level - dataset.level_mean) / (dataset.level_std + 1e-8)\n",
    "], axis=-1)\n",
    "\n",
    "# Generate prediction; columns are read below as discharge (0) and water level (1)\n",
    "prediction = forecaster.predict(input_norm, dataset)\n",
    "\n",
    "print(\"7-Day Forecast Generated!\")\n",
    "print(f\"Predicted max discharge: {prediction[:, 0].max():.2f} m³/s\")\n",
    "print(f\"Predicted max water level: {prediction[:, 1].max():.2f} m\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 6. Visualize Forecast"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two stacked panels: discharge on top, water level with flood thresholds below\n",
    "fig, (ax_q, ax_h) = plt.subplots(2, 1, figsize=(14, 8))\n",
    "steps = np.arange(len(prediction))\n",
    "\n",
    "ax_q.plot(steps, prediction[:, 0], 'o-', color='#dc143c', linewidth=2, markersize=5)\n",
    "ax_q.set_title('7-Day Discharge Forecast', fontsize=12, weight='bold')\n",
    "ax_q.set_ylabel('Discharge (m³/s)')\n",
    "\n",
    "ax_h.plot(steps, prediction[:, 1], 'o-', color='#003893', linewidth=2, markersize=5)\n",
    "for level_key, colour, label in [\n",
    "    ('flood_stage', 'orange', 'Flood Stage'),\n",
    "    ('moderate_flood', 'red', 'Moderate Flood'),\n",
    "]:\n",
    "    ax_h.axhline(station_config[level_key], color=colour, linestyle='--', label=label)\n",
    "ax_h.set_title('7-Day Water Level Forecast', fontsize=12, weight='bold')\n",
    "ax_h.set_xlabel('Forecast Step (6-hour intervals)')\n",
    "ax_h.set_ylabel('Water Level (m)')\n",
    "ax_h.legend()\n",
    "\n",
    "for ax in (ax_q, ax_h):\n",
    "    ax.grid(True, alpha=0.3)\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ✅ Model Summary\n",
    "\n",
    "- **Architecture**: 2-layer LSTM with 128 hidden units\n",
    "- **Input**: 7-day history (discharge + water level)\n",
    "- **Output**: 7-day forecast (28 steps at 6-hour intervals)\n",
    "- **Training**: Early stopping with validation monitoring\n",
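    "- **Performance**: Only the best validation loss is reported; evaluate forecast skill on a held-out period (e.g. error per lead time, hits/misses of the flood-stage threshold) before relying on the model for real-time forecasting"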
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.8.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}