In [None]:
{
 "cells": [
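  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# KiTS23 肾脏肿瘤分割 - Google Colab训练脚本\n",
    "\n",
    "本笔记本用于在Google Colab上训练AMSFF（Adaptive Multi-Scale Feature Fusion，自适应多尺度特征融合）模型。\n",
    "\n",
    "运行前请确认：\n",
    "- 已启用GPU运行时；\n",
    "- 项目代码位于 `/content/kits23_segmentation`（对应代码中的 `PROJECT_ROOT`）；\n",
    "- KiTS23数据位于Google Drive的 `MyDrive/kits23/data` 目录（可通过 `DRIVE_PATH` 修改）。"
   ]
  },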
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Check whether a CUDA GPU is visible to PyTorch\n",
    "import torch\n",
    "\n",
    "cuda_ok = torch.cuda.is_available()\n",
    "print(f\"GPU是否可用: {cuda_ok}\")\n",
    "if cuda_ok:\n",
    "    # Report the name of device 0 and the total number of visible devices\n",
    "    print(f\"当前GPU: {torch.cuda.get_device_name(0)}\")\n",
    "    print(f\"GPU数量: {torch.cuda.device_count()}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Install required dependencies.\n",
    "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
    "# TODO: pin versions (pkg==x.y.z) once a known-good set is established, for reproducibility.\n",
    "%pip install -q nibabel numpy pandas scikit-learn tensorboard"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Mount Google Drive so the data and output folders under MyDrive are reachable\n",
    "from google.colab import drive\n",
    "\n",
    "MOUNT_POINT = '/content/drive'\n",
    "drive.mount(MOUNT_POINT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Working directories\n",
    "import os\n",
    "from pathlib import Path\n",
    "\n",
    "# Project folder on Google Drive -- adjust to your own Drive layout.\n",
    "# Paths are kept as plain str (not pathlib.Path) because they are written into config.json.\n",
    "DRIVE_PATH = '/content/drive/MyDrive/kits23'  # adjust to your setup\n",
    "DATA_DIR = os.path.join(DRIVE_PATH, 'data')\n",
    "OUTPUT_DIR = os.path.join(DRIVE_PATH, 'output')\n",
    "\n",
    "# Create the directories if missing (no-op when they already exist)\n",
    "os.makedirs(DATA_DIR, exist_ok=True)\n",
    "os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
    "\n",
    "print(f\"数据目录: {DATA_DIR}\")\n",
    "print(f\"输出目录: {OUTPUT_DIR}\")\n",
    "\n",
    "# makedirs silently creates an empty DATA_DIR when DRIVE_PATH is wrong or the data has\n",
    "# not been copied yet; warn here instead of surfacing a confusing error during loading.\n",
    "if not os.listdir(DATA_DIR):\n",
    "    print(f\"警告: 数据目录为空，请先将KiTS23数据放入 {DATA_DIR}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Project imports\n",
    "import os\n",
    "import sys\n",
    "import json\n",
    "import logging\n",
    "from datetime import datetime\n",
    "from collections import defaultdict\n",
    "\n",
    "# Make the project code importable.\n",
    "# NOTE(review): `import kits23_segmentation.<...>` needs the *parent* directory of the\n",
    "# `kits23_segmentation` package on sys.path. Adding PROJECT_ROOT itself assumes the\n",
    "# package is nested one level down (.../kits23_segmentation/kits23_segmentation/) -- confirm\n",
    "# the repo layout, or add '/content' instead.\n",
    "PROJECT_ROOT = '/content/kits23_segmentation'\n",
    "if not os.path.isdir(PROJECT_ROOT):\n",
    "    # No cell in this notebook clones/uploads the code, so fail early with a clear message.\n",
    "    raise FileNotFoundError(f\"未找到项目目录: {PROJECT_ROOT}，请先克隆或上传kits23_segmentation代码\")\n",
    "if PROJECT_ROOT not in sys.path:  # keep re-runs of this cell from adding duplicates\n",
    "    sys.path.append(PROJECT_ROOT)\n",
    "\n",
    "from kits23_segmentation.models.amsff_net import AdaptiveMultiScaleFeatureFusionNet\n",
    "from kits23_segmentation.data.data_loader import get_data_loader\n",
    "from kits23_segmentation.training.trainer import AMSFFTrainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Training configuration\n",
    "import random\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "\n",
    "SEED = 42  # fixed seed so repeated runs are comparable\n",
    "\n",
    "config = {\n",
    "    'data_dir': DATA_DIR,\n",
    "    'output_dir': OUTPUT_DIR,\n",
    "    'batch_size': 4,\n",
    "    'num_workers': 2,\n",
    "    'mixed_precision': True,\n",
    "    'max_epochs': 300,\n",
    "    'learning_rate': 1e-4,\n",
    "    'weight_decay': 1e-5,\n",
    "    'patch_size': [64, 64, 32],\n",
    "    'initial_channels': 32,\n",
    "    'depth': 4,\n",
    "    'seed': SEED,  # recorded so the run can be reproduced from config.json\n",
    "    'experiment_name': f'amsff_{datetime.now().strftime(\"%Y%m%d_%H%M%S\")}'\n",
    "}\n",
    "\n",
    "# Seed Python, NumPy and PyTorch (torch.manual_seed covers CPU and all CUDA devices)\n",
    "# before the data loaders and model are created. This makes runs repeatable to the\n",
    "# extent the project code draws from these global RNGs; cuDNN may still choose\n",
    "# nondeterministic kernels, so results are not guaranteed bit-identical.\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "# Persist the configuration next to the outputs\n",
    "config_path = os.path.join(OUTPUT_DIR, 'config.json')\n",
    "with open(config_path, 'w') as f:\n",
    "    json.dump(config, f, indent=2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Pick the compute device: CUDA when available, otherwise fall back to CPU\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device('cuda')\n",
    "else:\n",
    "    device = torch.device('cpu')\n",
    "print(f\"使用设备: {device}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Build the training and validation data loaders.\n",
    "# Both share data_dir / batch_size / num_workers; only is_training differs.\n",
    "# NOTE(review): presumably is_training selects the split (and augmentation) inside\n",
    "# get_data_loader -- confirm the validation loader never sees training cases.\n",
    "shared_loader_args = {\n",
    "    'data_dir': config['data_dir'],\n",
    "    'batch_size': config['batch_size'],\n",
    "    'num_workers': config['num_workers'],\n",
    "}\n",
    "\n",
    "train_loader = get_data_loader(**shared_loader_args, is_training=True)\n",
    "val_loader = get_data_loader(**shared_loader_args, is_training=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Build the AMSFF network and move it to the selected device.\n",
    "# NOTE(review): KiTS23 masks carry background, kidney, tumor and cyst labels;\n",
    "# confirm the data loader maps them onto the num_classes=3 used here.\n",
    "model = AdaptiveMultiScaleFeatureFusionNet(\n",
    "    in_channels=1,\n",
    "    num_classes=3,\n",
    "    initial_channels=config['initial_channels'],\n",
    "    depth=config['depth'],\n",
    ").to(device)\n",
    "\n",
    "# Parameter counts: every parameter vs. only those that receive gradients\n",
    "total_params = sum(param.numel() for param in model.parameters())\n",
    "trainable_params = sum(\n",
    "    param.numel() for param in model.parameters() if param.requires_grad\n",
    ")\n",
    "print(f\"模型总参数量: {total_params:,}\")\n",
    "print(f\"可训练参数量: {trainable_params:,}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Hand the model, the two loaders, the config and the device to the trainer\n",
    "trainer = AMSFFTrainer(\n",
    "    model=model,\n",
    "    config=config,\n",
    "    device=device,\n",
    "    train_loader=train_loader,\n",
    "    val_loader=val_loader,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Run training; the returned history and efficiency metrics feed the plotting and saving cells below\n",
    "(history, efficiency_metrics) = trainer.train_with_efficiency_monitoring()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Plot training curves.\n",
    "# Colab's default matplotlib font (DejaVu Sans) has no CJK glyphs, so Chinese titles,\n",
    "# labels and legends render as empty boxes; the figure text is therefore in English.\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "fig, (ax_loss, ax_dice) = plt.subplots(1, 2, figsize=(12, 4))\n",
    "\n",
    "# Loss curves\n",
    "ax_loss.plot(history['train_loss'], label='Train loss')\n",
    "ax_loss.plot(history['val_loss'], label='Val loss')\n",
    "ax_loss.set(title='Loss', xlabel='Epoch', ylabel='Loss')\n",
    "ax_loss.legend()\n",
    "\n",
    "# Tumor Dice curves\n",
    "ax_dice.plot(history['train_dice_tumor'], label='Train Dice (tumor)')\n",
    "ax_dice.plot(history['val_dice_tumor'], label='Val Dice (tumor)')\n",
    "ax_dice.set(title='Tumor Dice score', xlabel='Epoch', ylabel='Dice')\n",
    "ax_dice.legend()\n",
    "\n",
    "fig.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Save training results (history, efficiency metrics, config) as JSON\n",
    "def _to_jsonable(obj):\n",
    "    \"\"\"Fallback for json.dump: turn values json cannot encode into plain Python.\n",
    "\n",
    "    numpy arrays/scalars (e.g. float32) and torch tensors expose tolist(); sets\n",
    "    become lists; anything else is stringified so that saving never aborts after\n",
    "    a long training run.\n",
    "    \"\"\"\n",
    "    if hasattr(obj, 'tolist'):\n",
    "        return obj.tolist()\n",
    "    if isinstance(obj, (set, frozenset)):\n",
    "        return list(obj)\n",
    "    return str(obj)\n",
    "\n",
    "\n",
    "results = {\n",
    "    'history': history,\n",
    "    'efficiency_metrics': efficiency_metrics,\n",
    "    'config': config\n",
    "}\n",
    "\n",
    "results_path = os.path.join(OUTPUT_DIR, 'training_results.json')\n",
    "with open(results_path, 'w') as f:\n",
    "    json.dump(results, f, indent=2, default=_to_jsonable)\n",
    "\n",
    "print(f\"训练结果已保存到: {results_path}\")"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "name": "KiTS23训练脚本",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "name": "python3"
  },
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}