Skip to content
This repository was archived by the owner on Jan 9, 2025. It is now read-only.

TomyJan/CIFAR10_CNN_Image_Classification

Repository files navigation

CIFAR10 CNN 图像分类

基于CNN(卷积神经网络)的CIFAR10数据集图像分类项目,适合深度学习初学者学习和实践。

项目概述

本项目使用卷积神经网络(CNN)对CIFAR10数据集进行图像分类。CIFAR10是一个包含60000张32x32彩色图像的数据集,共有10个类别,每个类别6000张图像。这些类别包括:飞机、汽车、鸟、猫、鹿、狗、青蛙、马、船和卡车。

主要功能

  • 数据加载和预处理
  • CNN模型构建和训练
  • 模型评估和性能可视化
  • 预测结果分析和混淆矩阵生成

网络架构

项目使用了一个经典的CNN架构,具体包括:

  1. 第一个卷积块:

    • 两个卷积层(32个3x3卷积核)
    • ReLU激活函数
    • 最大池化层
    • Dropout(0.25)防止过拟合
  2. 第二个卷积块:

    • 两个卷积层(64个3x3卷积核)
    • ReLU激活函数
    • 最大池化层
    • Dropout(0.25)防止过拟合
  3. 全连接层:

    • Flatten层
    • Dense层(512个神经元)
    • Dropout(0.5)
    • 输出层(10个类别)

训练参数

  • 批次大小:32
  • 训练轮次:60(使用早停机制,当验证集损失在5个epoch内没有改善时停止训练)
  • 优化器:RMSprop
  • 学习率:0.0001
  • 支持数据增强(可选)
  • 模型保存:自动保存验证集准确率最高的模型

项目结构

├── train_cifar10_cnn.py              # 主要Python代码文件
├── predict.py                         # 使用训练好的模型进行预测
├── requirements.txt                    # 项目依赖文件
├── README.md                          # 项目文档
├── LICENSE                            # 许可证文件
├── saved_models/                      # 保存训练模型的目录
│   └── best_model.keras               # 训练过程中的最佳模型
├── test_images/                       # 测试图片目录
├── 数据分布.png                        # 数据集分布可视化
├── 混淆矩阵.png                        # 模型预测结果混淆矩阵
├── acc_loss.png                       # 训练过程准确率和损失曲线
└── classification_report.txt          # 模型评估报告(包含详细的分类指标)

可视化结果

项目提供了三种可视化结果和一个详细的评估报告:

  1. 数据分布图:显示训练集和测试集中各类别的数据分布情况
  2. 训练过程图:展示模型训练过程中的准确率和损失变化
  3. 混淆矩阵:直观显示模型在各个类别上的预测效果
  4. 分类报告:包含每个类别的精确率、召回率、F1分数等详细指标

环境要求

项目依赖以下主要包:

  • TensorFlow >= 2.16.1
  • NumPy >= 1.26.0
  • Pandas >= 2.2.0
  • Matplotlib >= 3.9.0
  • Seaborn >= 0.13.0
  • Scikit-learn >= 1.5.0

详细的依赖要求请参见 requirements.txt

使用说明

  1. 克隆项目到本地
  2. 安装所需依赖:
    pip install -r requirements.txt
  3. 训练模型(如果需要):
    python train_cifar10_cnn.py
  4. 使用训练好的模型进行预测:
    python predict.py <图片路径1> [图片路径2 ...]

示例图片

项目在 test_images 目录下提供了来自CIFAR10测试集的示例图片:

  • airplane_sample.png:飞机示例
  • automobile_sample.png:汽车示例
  • bird_sample.png:鸟类示例
  • cat_sample.png:猫示例
  • deer_sample.png:鹿示例
  • dog_sample.png:狗示例
  • frog_sample.png:青蛙示例
  • horse_sample.png:马示例
  • ship_sample.png:船示例
  • truck_sample.png:卡车示例

您可以使用这些示例图片来测试模型的预测效果:

# 预测单张图片
python predict.py test_images/airplane_sample.png

# 预测多张图片
python predict.py test_images/airplane_sample.png test_images/cat_sample.png test_images/ship_sample.png

您也可以使用自己的图片进行测试,建议使用与CIFAR10数据集类似的图片(包含完整的目标对象,背景相对简单)以获得更好的预测效果。

预测功能

项目提供了独立的预测脚本 predict.py,具有以下特点:

  • 支持预测单张或多张图片
  • 自动调整输入图片尺寸为32x32
  • 显示预测结果和置信度
  • 可视化预测图片
  • 输出所有类别的概率分布

学习要点

通过本项目,您可以学习:

  • CNN的基本架构和实现方法
  • 图像分类任务的完整处理流程
  • 模型评估和可视化技术
  • 深度学习项目的最佳实践

注意事项

  • 首次运行时会自动下载CIFAR10数据集
  • 训练过程可能需要较长时间
  • 使用了早停机制来防止过拟合
  • 可以通过调整参数来优化模型性能

最近更新

  • 优化了项目结构,删除了冗余文件
  • 更新了依赖包版本要求
  • 改进了数据分布可视化方法
  • 简化了代码结构,提高了可读性
  • 添加了自动保存最佳模型功能
  • 优化了文件命名,使其更加规范和专业
  • 添加了独立的预测脚本,支持使用训练好的模型进行预测

技术原理

卷积和池化操作

最大池化运算原理

最大池化(Max Pooling)是CNN中的一个重要操作,主要用于:

  1. 降维:减少特征图的空间大小,降低计算复杂度
  2. 特征提取:保留区域内最显著的特征
  3. 位置不变性:提高模型对细微位置变化的鲁棒性

本项目中使用的 MaxPooling2D(pool_size=(2, 2)) 的工作原理:

  • 使用2×2的滑动窗口
  • 每次取窗口内4个值中的最大值
  • 步长为2,即不重叠
  • 输出特征图的尺寸减半

例如,对于一个4×4的特征图:

原始特征图:      最大池化后:
1  2  3  4       6  8
5  6  7  8   →   14 16
9  10 11 12
13 14 15 16

特征提取过程

在我们的CNN模型中,特征提取是一个层层递进的过程:

  1. 第一个卷积块:

    • 第一层3×3卷积提取基本特征(如边缘、颜色变化)
    • 第二层3×3卷积组合这些基本特征
    • 最大池化层保留最显著的特征,降维到16×16
  2. 第二个卷积块:

    • 使用更多的卷积核(64个)提取更复杂的特征
    • 再次池化降维到8×8
    • 此时特征更加抽象,可以表示物体的部分结构

关键参数解析

学习率及其影响

学习率是模型训练中最重要的超参数之一,它决定了模型参数更新的步长:

  • 过大的学习率:模型可能难以收敛,在最优解附近震荡
  • 过小的学习率:收敛速度慢,可能陷入局部最优

本项目选择 0.0001 的学习率原因:

  1. 模型较深,需要较小的学习率保持稳定
  2. RMSprop优化器本身会自适应调整学习率
  3. 实验表明该值能在训练速度和稳定性之间取得良好平衡

学习率对比实验结果:

学习率 最终准确率 收敛速度 训练稳定性
0.01 65% 不稳定
0.001 75% 中等 较稳定
0.0001 78% 稳定

损失函数选择

本项目使用 categorical_crossentropy(分类交叉熵)损失函数的原因:

  1. 适用性:

    • 专门用于多分类问题
    • 与softmax输出层完美配合
    • 对于互斥的类别效果最好
  2. 数学原理:

    • 计算预测概率分布与真实分布的差异
    • 鼓励模型对正确类别给出高置信度
    • 惩罚错误预测,特别是高置信度的错误预测
  3. 与其他损失函数的比较:

    损失函数 优点 缺点 适用场景
    Categorical Crossentropy 训练稳定,适合多分类 计算开销较大 多分类问题
    MSE 计算简单 梯度可能消失 回归问题
    Hinge Loss 适合SVM 对异常值敏感 二分类问题

模型调优实验

网络结构参数

  1. 卷积层数量影响:

    • 两个卷积块(当前方案):准确率78%
    • 一个卷积块:准确率70%
    • 三个卷积块:准确率77%,但训练时间增加50%
  2. 卷积核数量影响:

    • 32-64(当前方案):准确率78%
    • 16-32:准确率75%,模型更小
    • 64-128:准确率79%,但显著增加计算量
  3. Dropout率影响:

    Dropout率 训练准确率 测试准确率 过拟合程度
    0.1 85% 76% 严重
    0.25(当前) 82% 78% 轻微
    0.5 78% 77% 很小

优化策略

  1. 批次大小选择:

    • 32(当前方案):训练稳定,内存占用适中
    • 16:训练更慢,但略微提高准确率(+0.5%)
    • 64:训练更快,但准确率略有下降(-1%)
  2. 数据增强影响:

    • 不使用(当前方案):训练更快,准确率78%
    • 使用水平翻转:准确率提升到79%,但训练时间增加30%
    • 使用旋转和缩放:准确率达到80%,但训练时间翻倍

最佳实践建议

根据实验结果,我们建议:

  1. 使用当前的网络结构(2个卷积块)
  2. 保持0.25的Dropout率防止过拟合
  3. 批次大小32在效率和效果间取得平衡
  4. 如果训练时间允许,可以开启数据增强

许可证

本项目采用 MPL-2.0 许可证,详见 LICENSE 文件

本项目代码优化及文档编写完全使用 Cursor 完成, 第一个 commit 为原始代码

About

CIFAR10 图像分类

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

 
 
 

Contributors

Languages