In [None]:
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Segger Model Training\n",
    "\n",
    "This notebook demonstrates how to train the Segger model on spatial transcriptomics data."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 1: Setup and Environment\n",
    "\n",
    "First, we set up the environment by importing necessary libraries and ensuring that required directories exist."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "from pathlib import Path\n",
    "\n",
    "# These variables are read when the corresponding libraries initialise, so they must be\n",
    "# set BEFORE importing anything (torch, segger, ...) that may pull those libraries in.\n",
    "os.environ[\"USE_PYGEOS\"] = \"0\"  # ensure PyGEOS is not used\n",
    "# CUDA debugging aids: device-side assertions and synchronous kernel launches give\n",
    "# accurate error locations but noticeably slow training. Set to \"0\" for production runs.\n",
    "os.environ[\"PYTORCH_USE_CUDA_DSA\"] = \"1\"\n",
    "os.environ[\"CUDA_LAUNCH_BLOCKING\"] = \"1\"\n",
    "\n",
    "# Make the local `src` tree importable; this must precede the segger imports below.\n",
    "sys.path.insert(0, os.path.abspath('../../src'))\n",
    "\n",
    "import torch\n",
    "import lightning as L\n",
    "from torch_geometric.loader import DataLoader\n",
    "from torch_geometric.nn import to_hetero\n",
    "\n",
    "from segger.data.utils import XeniumDataset\n",
    "from segger.models.segger_model import Segger\n",
    "from segger.training.train import LitSegger"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data paths (relative to the notebook's working directory)\n",
    "TRAIN_DIR = Path('data_tidy/pyg_datasets/MNG_N173116IA/train_tiles/processed')\n",
    "VAL_DIR = Path('data_tidy/pyg_datasets/MNG_N173116IA/val_tiles/processed')\n",
    "\n",
    "# Data params\n",
    "DATA_CHUNK_SIZE = 20  # NOTE(review): not used anywhere in this notebook\n",
    "BATCH_SIZE_TRAIN = 4\n",
    "BATCH_SIZE_VAL = 4\n",
    "\n",
    "# Trainer params\n",
    "SEED = 42\n",
    "EPOCHS = 100\n",
    "ACCELERATOR = \"cuda\"\n",
    "STRATEGY = 'auto'\n",
    "PRECISION = \"16-mixed\"\n",
    "DEVICES = 4\n",
    "DEFAULT_ROOT_DIR = \"./models/MNG_big\"\n",
    "\n",
    "# Seed the Python, NumPy and torch RNGs so data shuffling and weight init are reproducible\n",
    "L.seed_everything(SEED)\n",
    "\n",
    "# Create the directory Lightning writes checkpoints and logs into\n",
    "os.makedirs(DEFAULT_ROOT_DIR, exist_ok=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2: Load and Process Data\n",
    "\n",
    "Load the datasets for training and validation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the tile datasets for training and validation\n",
    "xe_train_ds = XeniumDataset(root=TRAIN_DIR)\n",
    "xe_val_ds = XeniumDataset(root=VAL_DIR)\n",
    "\n",
    "# Fail fast on a wrong path instead of discovering an empty loader mid-training\n",
    "assert len(xe_train_ds) > 0, f\"no training tiles found under {TRAIN_DIR}\"\n",
    "assert len(xe_val_ds) > 0, f\"no validation tiles found under {VAL_DIR}\"\n",
    "\n",
    "len(xe_train_ds), len(xe_val_ds)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 3: Initialize and Train the Model\n",
    "\n",
    "Initialize the Segger model and the Lightning trainer, then train the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Graph schema for to_hetero: node types and (src, relation, dst) edge types\n",
    "HETERO_METADATA = (['tx', 'nc'], [('tx', 'belongs', 'nc'), ('tx', 'neighbors', 'tx')])\n",
    "\n",
    "# Build the base GNN, then convert it to a heterogeneous model; aggr='sum' sums the\n",
    "# messages arriving at the same node type from different edge types.\n",
    "segger_base = Segger(init_emb=8, hidden_channels=64, out_channels=16, heads=4)\n",
    "model = to_hetero(segger_base, HETERO_METADATA, aggr='sum')\n",
    "\n",
    "litsegger = LitSegger(model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Shuffle training tiles every epoch. Validation order is kept fixed: shuffling it adds\n",
    "# nondeterminism without benefit, and Lightning warns against shuffled val loaders.\n",
    "train_loader = DataLoader(xe_train_ds, batch_size=BATCH_SIZE_TRAIN, num_workers=0, pin_memory=True, shuffle=True)\n",
    "val_loader = DataLoader(xe_val_ds, batch_size=BATCH_SIZE_VAL, num_workers=0, pin_memory=True, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer = L.Trainer(\n",
    "    accelerator=ACCELERATOR,\n",
    "    strategy=STRATEGY,\n",
    "    precision=PRECISION,\n",
    "    devices=DEVICES,\n",
    "    max_epochs=EPOCHS,\n",
    "    default_root_dir=DEFAULT_ROOT_DIR,\n",
    ")\n",
    "\n",
    "trainer.fit(litsegger, train_loader, val_loader)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
