Skip to content

MichaelMaxAgent/MiniSFT

Repository files navigation

MiniSFT

一个轻量级的 T5 模型监督微调(SFT)训练框架

Python PyTorch Transformers License

📖 简介

MiniSFT 是一个基于 T5 模型的轻量级监督微调(Supervised Fine-Tuning)训练框架。该项目提供了完整的数据预处理、模型训练和推理功能,适合用于构建小型对话模型或文本生成任务。

✨ 主要特性

  • 🚀 简单易用:提供完整的训练和推理流程
  • 🎯 模块化设计:清晰的代码结构,易于扩展
  • 💬 对话支持:支持流式和非流式对话
  • 📊 数据预处理:内置数据生成和预处理工具
  • 🔧 灵活配置:通过配置文件轻松调整参数
  • 🎨 命令行界面:提供友好的 CLI 交互界面

📋 目录

🔧 安装

环境要求

  • Python 3.8+
  • PyTorch 2.0+
  • CUDA 11.0+ (推荐使用 GPU)

安装步骤

  1. 克隆仓库
git clone https://github.com/yourusername/MiniSFT.git
cd MiniSFT
  1. 安装依赖
pip install -r requirements.txt

🚀 快速开始

1. 准备数据

将您的原始文本放入 input.txt 文件中,然后运行文本分割工具:

python cut.py

这将把长文本分割成多个较短的片段,保存在 ./txt 目录下。

2. 生成训练数据

使用数据预处理工具生成问答对:

python data_preprocessing.py

生成的训练数据将保存在 ./data/sft_train.json 文件中。

注意: data_preprocessing.py 需要配置 API 密钥才能使用。请在代码中修改 api_key 变量。

3. 训练模型

修改 config.py 中的配置参数,然后运行训练脚本:

python sft_train.py

训练日志将保存在 ./logs 目录,模型检查点将保存在 ./model_save/sft 目录。

4. 运行对话

训练完成后,使用 CLI 界面与模型对话:

python cli_demo.py

命令说明:

  • 输入 exit - 退出程序
  • 输入 cls - 清屏

📁 项目结构

MiniSFT/
├── model/                      # 模型相关代码
│   ├── chat_model.py          # T5 模型封装
│   └── infer.py               # 推理逻辑
├── data/                       # 数据目录
│   └── sft_train.json         # 训练数据
├── model_save/                 # 模型保存目录
│   └── sft/                   # SFT 训练后的模型
├── logs/                       # 训练日志
├── txt/                        # 文本分割后的文件
├── config.py                   # 配置文件
├── sft_train.py               # 训练脚本
├── cli_demo.py                # 命令行对话界面
├── data_preprocessing.py      # 数据预处理脚本
├── cut.py                     # 文本分割工具
├── input.txt                  # 原始输入文本
├── requirements.txt           # 依赖包列表
└── README.md                  # 项目说明文档

📚 使用指南

数据格式

训练数据采用 JSON Lines 格式,每条数据包含 promptresponse 字段:

{
  "prompt": "什么是深度学习?",
  "response": "深度学习是机器学习的一个分支,它使用多层神经网络来学习数据的表示。"
}

模型训练

训练配置在 config.pySFTconfig 类中定义:

@dataclass
class SFTconfig:
    max_seq_len: int = 384 + 8
    batch_size: int = 19
    num_train_epochs: int = 600
    learning_rate: float = 5e-5
    # ... 更多配置

模型推理

推理配置在 config.pyInferConfig 类中定义:

@dataclass
class InferConfig:
    max_seq_len: int = 320
    mixed_precision: str = "bf16"
    model_dir: str = PROJECT_ROOT + '/model_save/sft/'
    # ... 更多配置

生成策略

模型支持多种生成策略,在 model/chat_model.py:15 中的 my_generate 方法中配置:

  • greedy: 贪婪搜索(最快)
  • beam: Beam Search(质量较好)
  • sampling: 采样生成(多样性)
  • contrastive: 对比搜索

⚙️ 配置说明

T5 模型配置

config.pyT5ModelConfig 类中定义模型架构:

@dataclass
class T5ModelConfig:
    d_ff: int = 3072              # 全连接层维度
    d_model: int = 768            # 词向量维度
    num_heads: int = 12           # 注意力头数
    d_kv: int = 64                # 每个头的维度
    num_decoder_layers: int = 10  # 解码器层数
    num_layers: int = 10          # 编码器层数

训练参数

主要训练参数说明:

参数 说明 默认值
batch_size 每个设备的批次大小 19
gradient_accumulation_steps 梯度累积步数 4
learning_rate 学习率 5e-5
num_train_epochs 训练轮数 600
fp16 是否使用混合精度训练 True
warmup_steps 预热步数 100

🤝 贡献指南

欢迎提交 Issue 和 Pull Request!

  1. Fork 本仓库
  2. 创建您的特性分支 (git checkout -b feature/AmazingFeature)
  3. 提交您的更改 (git commit -m 'Add some AmazingFeature')
  4. 推送到分支 (git push origin feature/AmazingFeature)
  5. 开启 Pull Request

📝 许可证

本项目采用 MIT 许可证。详见 LICENSE 文件。

🙏 致谢

📧 联系方式

如有问题或建议,欢迎通过以下方式联系:

⭐ Star History

如果这个项目对您有帮助,欢迎给个 Star!


Made with ❤️ by MiniSFT Team

About

No description, website, or topics provided.

Resources

License

Contributing

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors

Languages