这是一个基于 PyTorch 的 CIFAR-10 简单示例项目,包含训练、测试、模型结构与可视化工具。此仓库用于教学或实验用途。
Foundation_FinalProject/
├─ checkpoint.pth # 训练后保存的模型权重(示例)
├─ main.py # 项目入口(使用 --train / --test 参数)
├─ train.py # 训练代码(包含 train_model 函数)
├─ test.py # 测试代码(包含 test_model 函数)
├─ models/simple_cnn.py # 简单 CNN 模型定义
├─ utils/plot_utils.py # 可视化工具(plot_loss)
├─ losses.csv # 训练后每个 epoch 的 loss(CSV 文件)
└─ data/ # CIFAR-10 数据(由 torchvision 下载或放置)
- 推荐 Python 3.10 或更高(项目中使用 Python 3.14)。
- 使用虚拟环境(venv)隔离依赖。
- 依赖项已经导出到
requirements.txt(基于当前 venv 的pip freeze)。
- 进入项目目录:
cd D:\Code\PythonProjects\Foundation_FinalProject- 创建并激活虚拟环境(如果尚未创建):
python -m venv .venv
Set-ExecutionPolicy -Scope Process -ExecutionPolicy Bypass
.\.venv\Scripts\Activate.ps1- 安装依赖:
pip install -r requirements.txt- 训练模型:
python main.py --train运行后会生成 checkpoint.pth(模型权重)和 losses.csv(包含每个 epoch 的 loss),并显示训练过程中的 epoch loss。
- 测试模型:
python main.py --test如果 checkpoint.pth 存在,test_model 会加载并评估模型并输出测试集准确率。
- 脚本
train.py中的train_model(save_csv=True, verbose=False, save_per_batch=False)支持参数:save_csv=True:保存losses.csv(每个 epoch loss)verbose=True:打印每个 batch 的 losssave_per_batch=True:生成losses_per_batch.csv(每 batch loss)
- 若要直接使用 Pyton 调用训练函数以传入参数:
python - << 'PY'
from train import train_model
train_model(save_csv=True, verbose=True, save_per_batch=True)
PY- 使用
utils.plot_utils.plot_loss(losses)显示训练损失曲线;losses.csv可以用 Excel / Pandas / 任何工具绘图。
- 请避免使用与其他包冲突的通用文件名(例如
test.py),以免导入时被错误解析为其它模块。若发生ImportError/ModuleNotFoundError,请检查当前 Python 解释器与项目路径。 - 如果
matplotlib或torch导入失败,确保虚拟环境已激活并安装了requirements.txt中指定的包。
- 增加更多命令行参数(例如:
--epochs、--batch-size、--lr、--save-batch),方便实验控制与重现。 - 将项目改成标准 Python 包形式(带
setup.py或pyproject.toml),便于复用。
示例代码,仅供学习与实验使用。
这是一个基于 PyTorch 的简单 CIFAR-10 训练/测试示例项目,包含训练、测试、模型结构和可视化工具。
Foundation_FinalProject/
├─ checkpoint.pth # 训练后保存的模型权重(示例)
├─ main.py # 主入口(通过 --train / --test 运行)
├─ train.py # 训练代码(train_model 函数)
├─ test.py # 测试代码(test_model 函数)
├─ models/simple_cnn.py # 简单 CNN 模型定义
├─ utils/plot_utils.py # 可视化函数(plot_loss)
├─ losses.csv # 训练后每个 epoch 的 loss(CSV)
└─ data/ # CIFAR-10 数据(由 torchvision 下载或预先放置)
- Python 3.11+(项目中使用的 venv 上是 Python 3.14,Python 3.10+ 推荐)
- 推荐使用虚拟环境(venv)来隔离依赖
- 依赖项已记录在
requirements.txt(基于当前 venv 的pip freeze)
- 进入项目目录:
cd D:\Code\PythonProjects\Foundation_FinalProject- 创建并激活虚拟环境(若尚未创建):
python -m venv .venv
Set-ExecutionPolicy -Scope Process -ExecutionPolicy Bypass
.\.venv\Scripts\Activate.ps1- 安装依赖:
pip install -r requirements.txt- 训练模型:
python main.py --train训练完成后会生成 checkpoint.pth(模型权重)和 losses.csv(每个 epoch 的平均 loss)。同时会弹出训练 loss 的图像。
- 测试模型:
python main.py --test如果 checkpoint.pth 可用,test.py 将加载并评估模型,打印测试集准确率。
- 可通过 Python 直接调用
train.train_model(save_csv=True, verbose=True, save_per_batch=True)记录更多信息。 utils/plot_utils.plot_loss(losses)可用于绘制训练曲线。- 若要保存更详细的批次级 loss,请在
train_model中启用save_per_batch=True。
- 避免使用仓库中通用或与库冲突的文件名(例如
test.py),因为可能会与其他包或脚本产生混淆。 - 如果你在运行时遇到
ImportError或ModuleNotFoundError:- 检查当前解释器是否指向项目虚拟环境(VS Code 左下角或
python --version/sys.executable)。 - 检查
__init__.py是否存在于utils/(已包含一个空的__init__.py)。
- 检查当前解释器是否指向项目虚拟环境(VS Code 左下角或
- 若
plot_utils不工作,请确认matplotlib已正确安装并激活在当前 venv 中。
- 将
train.py中的训练循环参数化(比如 epochs、batch_size、lr)并通过argparse暴露,以便从main.py调整实验参数。 - 添加 CLI 标志(如
--save-batch、--verbose)到main.py,以便更灵活地控制训练输出与 CSV 保存。
本仓库为示例/作业性质代码,请按照学校/实验室要求管理代码与数据。