# Colab 环境准备与训练
依次运行下面的单元，在 Google Colab 中准备运行环境并启动 `train.py`。

In [None]:
# Environment setup: check the GPU, upgrade pip, clone or update the repository,
# then install its requirements.
import pathlib
import subprocess
import sys

# nvidia-smi is only informational here; missing/failing GPU tooling must not abort setup.
try:
    subprocess.run(["nvidia-smi"], check=True)
except FileNotFoundError:
    print("nvidia-smi 未找到，当前环境可能不支持 GPU")
except subprocess.CalledProcessError:
    print("nvidia-smi 调用失败，GPU 可能不可用")

REPO_URL = "https://github.com/Daethalous/AI27_summarization.git"
REPO_NAME = "AI27_summarization"


def _resolve_repo_path(repo_name, start=None):
    """Return the directory that holds (or should hold) the repository checkout.

    The training cell calls ``os.chdir`` into the checkout, so a naive
    ``cwd / repo_name`` would point at a nested, non-existent copy when this
    cell is re-run. If *start* (default: the current working directory) is
    itself the checkout, it is returned; otherwise ``start / repo_name``.
    """
    start = pathlib.Path.cwd() if start is None else pathlib.Path(start)
    if start.name == repo_name and (start / ".git").exists():
        return start
    return start / repo_name


repo_path = _resolve_repo_path(REPO_NAME)

if repo_path.exists():
    # Fail with a clear message instead of a cryptic git error when the
    # directory exists but is not a git checkout.
    if not (repo_path / ".git").exists():
        raise RuntimeError(f"{repo_path.resolve()} 已存在但不是 git 仓库，请删除或重命名后重试")
    print(f"仓库已存在于 {repo_path.resolve()}，尝试拉取最新代码...")
    # --ff-only: never create merge commits or stop for a divergent-branch decision.
    subprocess.run(["git", "-C", str(repo_path), "pull", "--ff-only"], check=True)
else:
    print(f"正在克隆仓库: {REPO_URL}")
    # Clone into the resolved path explicitly so it always matches repo_path.
    subprocess.run(["git", "clone", REPO_URL, str(repo_path)], check=True)

subprocess.run([sys.executable, "-m", "pip", "install", "--upgrade", "pip"], check=True)

requirements_file = repo_path / "requirements.txt"
if not requirements_file.exists():
    raise FileNotFoundError(f"未找到依赖文件: {requirements_file}")
print(f"安装依赖: {requirements_file}")
subprocess.run([sys.executable, "-m", "pip", "install", "-r", str(requirements_file)], check=True)

# Import after installation so a freshly installed torch can be picked up.
import torch
print(f"PyTorch 版本: {torch.__version__}")
print(f"CUDA 是否可用: {torch.cuda.is_available()}")

CalledProcessError: Command '['git', 'clone', 'https://github.com/<your-account>/summarization.git', '/content/summarization']' returned non-zero exit status 128.

## 启动训练
在完成环境准备后，导入 `train.py` 的入口函数 `main` 并启动训练。可根据需要修改下方 `args` 中的参数（如 `config`、`epochs`、`batch_size`）来调整配置，或启用 Tiny 模式进行快速验证。

In [None]:
# Run the training script: locate the repository, make its `src` importable,
# and call `train.main` with an argparse-style namespace.
import argparse
import os
import sys
from pathlib import Path


def _find_project_root(repo_name, start=None):
    """Locate the repository checkout, tolerating a cwd that is already inside it.

    Checks ``start / repo_name`` first (the state right after the setup cell),
    then *start* itself (the state after this cell has already run ``os.chdir``).

    Raises:
        FileNotFoundError: if neither location exists.
    """
    start = Path.cwd() if start is None else Path(start)
    candidate = start / repo_name
    if candidate.exists():
        return candidate
    if start.name == repo_name:
        return start
    raise FileNotFoundError(f"未找到 {repo_name} 仓库，请先运行环境准备单元。")


PROJECT_ROOT = _find_project_root("AI27_summarization")

# Relative paths below (config, save_dir) are resolved against the repo root.
os.chdir(PROJECT_ROOT)

# Avoid stacking duplicate entries on sys.path each time the cell is re-run.
src_dir = str(PROJECT_ROOT / "src")
if src_dir not in sys.path:
    sys.path.insert(0, src_dir)

from train import main

# NOTE(review): fields set to None presumably fall back to the YAML config in
# train.py, and auto_download / no_auto_download look like a paired CLI flag —
# confirm against train.py's argument parser.
args = argparse.Namespace(
    config="configs/seq2seq_attn.yaml",
    data_dir=None,
    vocab_path=None,
    save_dir="./checkpoints",
    batch_size=None,
    max_src_len=None,
    max_tgt_len=None,
    epochs=None,
    lr=None,
    num_workers=2,
    dataset_version=None,
    auto_download=True,
    no_auto_download=False,
)

main(args)