In [17]:
# This cell has been split into multiple cells for better readability and a tutorial-like structure. Please refer to the new cells.

## Complete CNN Image Classification Implementation

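This project demonstrates a complete Convolutional Neural Network (CNN) implementation for image classification. It is built with OpenCV and the TensorFlow 1.x graph API (`tf.placeholder`, `tf.layers`, sessions). It is structured into three main modules, followed by a command-line entry point (`main.py`, section 4):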

1.  **`data_preprocess.py`**: Handles all data preparation steps: renaming images, resizing them to uniform dimensions, and writing image paths and labels to a CSV file suitable for training.
2.  **`train.py`**: Contains the CNN model definition and the training loop. It loads preprocessed data, builds a multi-layer CNN, and trains it to classify images.
3.  **`predict.py`**: Implements the image recognition functionality. It loads the trained model and uses it to predict categories for new images or real-time video streams.

### 1. Data Preprocessing Module (`data_preprocess.py`)

This module is responsible for preparing the image dataset for model training. It includes functionalities to:

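-   **Rename images**: Move each image out of its category subfolder into the image directory and rename it to `<category_index>-<image_index>.jpg` (e.g. `0-12.jpg`), so the label can later be read from the file name.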
-   **Resize images**: Resize every image to a uniform size (200x200 pixels by default). Files are overwritten in place, and files that OpenCV cannot read are deleted.
-   **Generate CSV**: Create a CSV file that maps image paths to their respective labels, including one-hot encoded labels, which is essential for training the CNN.

In [10]:
import os
import cv2
import pandas as pd
import numpy as np


class DataPreprocess:
    """Prepare an image dataset for CNN training.

    Expected input layout: ``image_path`` contains one subfolder per category.
    The pipeline flattens it into ``image_path/<label>-<index>.jpg`` files,
    resizes them in place, and writes a CSV of image paths plus one-hot labels.

    WARNING: every step modifies ``image_path`` in place (files are moved,
    overwritten, and unreadable images are deleted). Work on a copy of the data.

    Args:
        image_path: dataset directory (one subfolder per category before ``rename``).
        csv_path: directory the CSV file is written to.
    """

    # File extensions treated as images (compared case-insensitively).
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, image_path="./picture/", csv_path="./"):
        self.image_path = image_path
        self.csv_path = csv_path

    def _is_image(self, file_name):
        """Return True if ``file_name`` ends with a supported image extension."""
        return file_name.lower().endswith(self.IMAGE_EXTENSIONS)

    def rename(self):
        """
        Flatten the dataset into ``<label>-<index>.jpg`` files.

        Category subfolders of ``image_path`` are visited in sorted name order and
        numbered 0, 1, 2, ...; every regular file inside a category is moved up into
        ``image_path`` as ``<label>-<index>.jpg``. Emptied category folders are
        removed afterwards.
        """
        print("开始重命名图片...")
        # Sorted + directories only: os.listdir order is arbitrary, and stray files
        # (e.g. .DS_Store) must not consume a label number, otherwise labels get
        # gaps and the training code's label_<k> columns go missing.
        categories = sorted(
            name for name in os.listdir(self.image_path)
            if os.path.isdir(os.path.join(self.image_path, name))
        )

        for label, category in enumerate(categories):
            category_path = os.path.join(self.image_path, category)
            images = sorted(
                name for name in os.listdir(category_path)
                if os.path.isfile(os.path.join(category_path, name))
            )
            for index, image_name in enumerate(images):
                old_name = os.path.join(category_path, image_name)
                new_name = os.path.join(self.image_path, "%d-%d.jpg" % (label, index))
                try:
                    os.rename(old_name, new_name)
                    print(f"重命名: {old_name} -> {new_name}")
                except OSError as e:
                    print(f"重命名失败: {e}")

        # Remove the now-empty category folders. os.rmdir (not os.removedirs) so
        # image_path itself and its parent directories are never deleted; it fails
        # harmlessly when a folder still contains files.
        for category in categories:
            tmp_path = os.path.join(self.image_path, category)
            try:
                os.rmdir(tmp_path)
                print(f"删除空文件夹: {tmp_path}")
            except OSError:
                pass

        print("图片重命名完成！")

    def resize_img(self, target_size=(200, 200)):
        """
        Resize every image directly inside ``image_path`` to ``target_size``, in place.

        ``target_size`` is passed straight to ``cv2.resize`` (width, height).
        Files that OpenCV cannot read or resize are deleted from disk.
        """
        print(f"开始调整图片尺寸为 {target_size}...")
        image_files = [f for f in sorted(os.listdir(self.image_path)) if self._is_image(f)]
        success_count = 0
        fail_count = 0

        for file in image_files:
            file_path = os.path.join(self.image_path, file)
            try:
                imread = cv2.imread(file_path)
                if imread is None:
                    print(f"无法读取图片: {file_path}")
                    os.remove(file_path)
                    fail_count += 1
                    continue

                resize = cv2.resize(imread, target_size)
                cv2.imwrite(file_path, resize)
                success_count += 1
                print(f"处理成功: {file} ({success_count}/{len(image_files)})")

            except Exception as e:
                print(f"处理失败: {file}, 错误: {e}")
                try:
                    os.remove(file_path)
                except OSError:
                    pass  # best effort: the file may already be gone
                fail_count += 1

        print(f"图片尺寸调整完成！成功: {success_count}, 失败: {fail_count}")

    def train_data_to_csv(self, csv_name="train.csv", num_classes=None):
        """
        Write image paths and labels to ``csv_path/csv_name``.

        Columns: ``path``, ``label`` (int) and one 0/1 column ``label_<k>`` per class.
        The label is the integer prefix of the ``<label>-<index>.jpg`` file name;
        image files without an integer prefix are skipped.

        Args:
            csv_name: output file name inside ``csv_path``.
            num_classes: if given, columns ``label_0 .. label_{num_classes-1}`` are
                always written (all-zero for classes with no images), so the CSV
                schema matches what the training code reads.

        Returns:
            Path of the written CSV file.
        """
        print("开始生成CSV文件...")
        data = []

        for file in sorted(os.listdir(self.image_path)):
            if not self._is_image(file):
                continue

            # Label prefix from the file name (format: 0-1.jpg).
            prefix = file.split('-')[0]
            if not prefix.isdigit():
                print(f"跳过无法解析标签的文件: {file}")
                continue

            data.append({
                "path": os.path.join(self.image_path, file),
                "label": int(prefix),
            })

        frame = pd.DataFrame(data, columns=['path', 'label'])

        # One-hot encode as 0/1 ints (recent pandas defaults to bool True/False);
        # integer labels also give numeric column order (label_2 before label_10).
        dummies = pd.get_dummies(frame['label'], prefix='label', dtype=int)
        if num_classes is not None:
            n_cols = max([num_classes] + [label + 1 for label in frame['label']])
            dummies = dummies.reindex(
                columns=[f'label_{k}' for k in range(n_cols)], fill_value=0
            )

        concat = pd.concat([frame, dummies], axis=1)

        csv_file = os.path.join(self.csv_path, csv_name)
        concat.to_csv(csv_file, index=False)

        print(f"CSV文件已生成: {csv_file}")
        print(f"总共处理了 {len(data)} 张图片")
        print("\nCSV文件预览:")
        print(concat.head())

        return csv_file

    def run_all(self):
        """
        Run rename -> resize -> CSV generation with default arguments.
        """
        print("=" * 60)
        print("开始数据预处理流程")
        print("=" * 60)

        # 1. Rename
        self.rename()

        # 2. Resize
        self.resize_img()

        # 3. Generate CSV
        self.train_data_to_csv()

        print("=" * 60)
        print("数据预处理完成！")
        print("=" * 60)


#### **Usage of Data Preprocessing**

To prepare your dataset, place your images in a directory (e.g. `./picture/`) with one subfolder per category. Then instantiate the `DataPreprocess` class and call its `run_all()` method. Every step modifies that directory in place, moving, overwriting and possibly deleting files, so run it on a copy of your original data.

In [11]:
# Example: Initialize and run data preprocessing
# Make sure to create a 'picture' directory and populate it with image subfolders like '0', '1', etc.
# For example:
# ./picture/0/image001.jpg
# ./picture/1/image002.jpg

# You might need to create dummy image files for this to run without error
# if not os.path.exists('./picture/0'):
#     os.makedirs('./picture/0')
# if not os.path.exists('./picture/1'):
#     os.makedirs('./picture/1')
# # Example: create a dummy image file
# cv2.imwrite('./picture/0/test_image_0.jpg', np.zeros((100,100,3), dtype=np.uint8))
# cv2.imwrite('./picture/1/test_image_1.jpg', np.zeros((100,100,3), dtype=np.uint8))

# preprocessor = DataPreprocess()
# preprocessor.run_all()


### 2. Model Training Module (`train.py`)

This module defines, trains, and saves the CNN model. It includes:

-   **Data Loading**: Reads the preprocessed CSV file and randomly splits its rows into training and test sets (about 80/20).
-   **Batch Generation**: Provides methods to load images and their corresponding one-hot encoded labels in batches during training.
-   **CNN Architecture**: Defines a multi-layer CNN with convolutional layers, max-pooling, batch normalization, and fully connected layers.
-   **Training Loop**: Manages the training process, including epoch iterations, loss calculation, accuracy evaluation, and model saving.

In [12]:
import tensorflow as tf
import pandas as pd
import numpy as np
import cv2
import os


class CNNModel:
    """CNN image classifier built with the TensorFlow 1.x graph API.

    Reads the CSV written by ``DataPreprocess.train_data_to_csv`` (columns
    ``path`` and one-hot ``label_<k>``), builds the network in a private
    ``tf.Graph`` and trains it with Adam.

    Args:
        csv_path: path of the training CSV.
        image_path: dataset directory (stored for reference; image paths are
            taken from the CSV).
        num_classes: number of output classes / ``label_<k>`` columns read.
        seed: optional seed for the train/test split, per-epoch shuffling and
            TensorFlow's graph-level RNG; ``None`` keeps runs random.
    """

    def __init__(self, csv_path="./train.csv", image_path="./picture/", num_classes=5, seed=None):
        self.csv_path = csv_path
        self.image_path = image_path
        self.num_classes = num_classes
        self.batch_size = 16
        self.start = 0
        self.has_next_batch = True

        # GPU config: grow GPU memory on demand instead of reserving it all.
        self.config = tf.ConfigProto()
        self.config.gpu_options.allow_growth = True

        self.load_data(seed=seed)

        # A private graph lets this cell be re-run (and an ImagePredictor be built
        # in the same kernel) without "Variable ... already exists" errors.
        self.graph = tf.Graph()
        with self.graph.as_default():
            if seed is not None:
                tf.set_random_seed(seed)
            self.build_model_graph()

    def load_data(self, train_ratio=0.8, seed=None):
        """
        Load the CSV and randomly split its rows into ``train_data`` / ``test_data``.

        Args:
            train_ratio: expected fraction of rows assigned to the training set.
            seed: seed for the split and later per-epoch shuffling; None = random.
        """
        print("加载数据集...")
        df = pd.read_csv(self.csv_path)

        self._rng = np.random.RandomState(seed)
        msk = self._rng.rand(len(df)) < train_ratio
        # reset_index so positional slicing in next_batch is unambiguous.
        self.train_data = df[msk].reset_index(drop=True)
        self.test_data = df[~msk].reset_index(drop=True)

        self.batches = len(self.train_data) // self.batch_size

        print(f"训练集大小: {len(self.train_data)}")
        print(f"测试集大小: {len(self.test_data)}")
        print(f"批次数量: {self.batches}")

    def _load_image(self, path):
        """Read ``path`` with OpenCV, resize to 200x200 and scale to float32 in [0, 1]."""
        img = cv2.imread(path)
        if img is None:
            raise IOError(f"无法读取图片: {path}")
        img = cv2.resize(img, (200, 200))
        return np.asarray(img, np.float32) / 255.0

    def _one_hot(self, row):
        """Return the ``num_classes`` one-hot values of a CSV row.

        ``.get`` with default 0: a class that had no images has no ``label_<k>``
        column in the CSV.
        """
        return [row.get(f'label_{i}', 0) for i in range(self.num_classes)]

    def _load_rows(self, frame):
        """Load images and one-hot labels for every row of ``frame`` as float32 arrays."""
        images = []
        labels = []
        for _, row in frame.iterrows():
            images.append(self._load_image(row['path']))
            labels.append(self._one_hot(row))
        return np.array(images, dtype=np.float32), np.array(labels, dtype=np.float32)

    def next_batch(self):
        """
        Return the next ``(images, labels)`` training batch.

        Returns ``(None, None)`` and clears ``has_next_batch`` once fewer than
        ``batch_size`` rows remain (the incomplete tail batch is dropped).
        """
        if self.start + self.batch_size > len(self.train_data):
            self.has_next_batch = False
            return None, None

        batch_data = self.train_data.iloc[self.start:self.start + self.batch_size]
        self.start += self.batch_size
        return self._load_rows(batch_data)

    def get_test_data(self):
        """
        Load all test images and labels as ``(images, labels)`` arrays.
        """
        return self._load_rows(self.test_data)

    def build_model_graph(self):
        """
        Build the CNN (4 conv blocks + 5 dense layers) in the current default graph.

        Defines ``self.x``, ``self.y``, ``self.training``, ``self.softmax``,
        ``self.predict``, ``self.loss``, ``self.accuracy``, ``self.opt`` and
        ``self.merged``. Layer/scope names must stay in sync with
        ``ImagePredictor.build_predict_graph`` so checkpoints can be restored.
        """
        print("构建CNN模型...")

        with tf.name_scope("input"):
            self.x = tf.placeholder(tf.float32, [None, 200, 200, 3], "x")
            self.y = tf.placeholder(tf.float32, [None, self.num_classes], "y")
            # tf.layers dropout/batch_normalization default to training=False, which
            # disables dropout and freezes BN statistics. This flag defaults to False
            # for evaluation; the training loop feeds True.
            self.training = tf.placeholder_with_default(False, shape=[], name="training")

        # Conv block 1
        with tf.variable_scope("conv_layer_1"):
            conv1 = tf.layers.conv2d(self.x, 64, [3, 3], activation=tf.nn.relu, name='conv1')
            max1 = tf.layers.max_pooling2d(conv1, [2, 2], [2, 2])
            bn1 = tf.layers.batch_normalization(max1, name='bn1', training=self.training)
            output1 = tf.layers.dropout(bn1, rate=0.5, name='dropout', training=self.training)

        # Conv block 2
        with tf.variable_scope("conv_layer_2"):
            conv2 = tf.layers.conv2d(output1, 64, [3, 3], activation=tf.nn.relu, name='conv2')
            max2 = tf.layers.max_pooling2d(conv2, [2, 2], [2, 2], name='max2')
            bn2 = tf.layers.batch_normalization(max2, training=self.training)
            output2 = tf.layers.dropout(bn2, rate=0.5, name='dropout', training=self.training)

        # Conv block 3
        with tf.variable_scope("conv_layer_3"):
            conv3 = tf.layers.conv2d(output2, 64, [3, 3], activation=tf.nn.relu, name='conv3')
            max3 = tf.layers.max_pooling2d(conv3, [2, 2], [2, 2], name='max3')
            output3 = tf.layers.batch_normalization(max3, name='bn3', training=self.training)

        # Conv block 4
        with tf.variable_scope("conv_layer_4"):
            conv4 = tf.layers.conv2d(output3, 32, [3, 3], activation=tf.nn.relu, name='conv4')
            max4 = tf.layers.max_pooling2d(conv4, [2, 2], [2, 2], name='max4')
            bn4 = tf.layers.batch_normalization(max4, name='bn4', training=self.training)
            flatten = tf.layers.flatten(bn4, 'flatten')

        # Dense 1
        with tf.variable_scope("fc_layer1"):
            fc1 = tf.layers.dense(flatten, 256, activation=tf.nn.relu)
            fc_bn1 = tf.layers.batch_normalization(fc1, name='bn1', training=self.training)
            dropout1 = tf.layers.dropout(fc_bn1, rate=0.5, training=self.training)

        # Dense 2
        with tf.variable_scope("fc_layer2"):
            fc2 = tf.layers.dense(dropout1, 128, activation=tf.nn.relu)
            dropout2 = tf.layers.dropout(fc2, rate=0.5, training=self.training)

        # Dense 3 (no activation)
        with tf.variable_scope("fc_layer3"):
            fc3 = tf.layers.dense(dropout2, 64)
            dropout3 = tf.layers.dropout(fc3, rate=0.5, training=self.training)

        # Dense 4 (no activation)
        with tf.variable_scope("fc_layer4"):
            fc4 = tf.layers.dense(dropout3, 32)

        # Dense 5: output logits
        with tf.variable_scope("fc_layer5"):
            fc5 = tf.layers.dense(fc4, self.num_classes)

        # Prediction and loss
        self.softmax = tf.nn.softmax(fc5, name='softmax')
        self.predict = tf.argmax(self.softmax, axis=1)
        self.loss = tf.reduce_mean(
            tf.nn.softmax_cross_entropy_with_logits_v2(logits=fc5, labels=self.y, name='loss')
        )

        # Accuracy
        correct_prediction = tf.equal(self.predict, tf.argmax(self.y, axis=1))
        self.accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

        # Optimizer. Batch-norm moving mean/variance are updated by ops in
        # UPDATE_OPS; they must run with each step or inference (training=False)
        # would use the initial statistics.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            self.opt = tf.train.AdamOptimizer(learning_rate=0.001).minimize(self.loss)

        # TensorBoard
        tf.summary.scalar("loss", self.loss)
        tf.summary.scalar("accuracy", self.accuracy)
        self.merged = tf.summary.merge_all()

        print("模型构建完成！")

    def _evaluate(self, sess, images, labels):
        """Return ``(loss, accuracy)`` over all ``images``, averaged per sample in batches."""
        total_loss = 0.0
        total_acc = 0.0
        for lo in range(0, len(images), self.batch_size):
            batch_x = images[lo:lo + self.batch_size]
            batch_y = labels[lo:lo + self.batch_size]
            loss_, acc_ = sess.run(
                [self.loss, self.accuracy],
                feed_dict={self.x: batch_x, self.y: batch_y},
            )
            total_loss += loss_ * len(batch_x)
            total_acc += acc_ * len(batch_x)
        return total_loss / len(images), total_acc / len(images)

    def train(self, epochs=100, save_interval=5, save_dir="./h5_dell/", log_dir="./log"):
        """
        Train for ``epochs`` epochs, evaluating on the test set after each epoch.

        Args:
            epochs: number of passes over the training set.
            save_interval: save a checkpoint every this many epochs (the final
                epoch is always saved).
            save_dir: checkpoint directory (created if missing).
            log_dir: TensorBoard log directory.

        Raises:
            ValueError: if the training set is smaller than one batch.
        """
        print("=" * 60)
        print("开始训练模型")
        print("=" * 60)

        if self.batches == 0:
            raise ValueError(
                f"训练集样本数 ({len(self.train_data)}) 少于 batch_size ({self.batch_size})"
            )

        # tf.train.Saver.save fails when the parent directory does not exist.
        os.makedirs(save_dir, exist_ok=True)
        # Load test images once instead of re-reading them from disk every epoch.
        test_x, test_y = self.get_test_data()

        with self.graph.as_default():
            saver = tf.train.Saver(max_to_keep=3)
            init = tf.global_variables_initializer()
        sess = tf.Session(graph=self.graph, config=self.config)
        writer = tf.summary.FileWriter(log_dir, graph=self.graph)

        try:
            sess.run(init)

            for epoch in range(epochs):
                # Reshuffle so every epoch sees different batches.
                self.train_data = self.train_data.sample(
                    frac=1, random_state=self._rng
                ).reset_index(drop=True)
                self.start = 0
                self.has_next_batch = True

                batch_idx = 1
                all_loss = 0.0
                all_acc = 0.0

                while batch_idx <= self.batches and self.has_next_batch:
                    train_x, train_y = self.next_batch()
                    if train_x is None:
                        break

                    _, loss_, accuracy_, merged_ = sess.run(
                        [self.opt, self.loss, self.accuracy, self.merged],
                        feed_dict={self.x: train_x, self.y: train_y, self.training: True},
                    )

                    all_loss += loss_
                    all_acc += accuracy_

                    # Progress bar
                    progress = "=" * batch_idx + ">" + "-" * (self.batches - batch_idx)
                    print(f"\repoch {epoch+1}/{epochs} -- batch: {batch_idx}/{self.batches} --> "
                          f"[{progress}] loss: {loss_:.4f}, acc: {accuracy_:.4f}", end="")

                    writer.add_summary(merged_, epoch * self.batches + batch_idx)
                    batch_idx += 1

                # Epoch statistics over the batches actually run.
                batches_run = max(batch_idx - 1, 1)
                mean_loss = all_loss / batches_run
                mean_acc = all_acc / batches_run
                print(f"\n===epoch {epoch+1}/{epochs}=== > mean loss: {mean_loss:.4f}, mean acc: {mean_acc:.4f}")

                # Test-set evaluation on the whole test set (training flag left False).
                if len(test_x) > 0:
                    test_loss_, test_acc_ = self._evaluate(sess, test_x, test_y)
                    print(f"===epoch {epoch+1}/{epochs}=== > test loss: {test_loss_:.4f}, test acc: {test_acc_:.4f}\n")

                # Save checkpoint
                if (epoch + 1) % save_interval == 0 or epoch + 1 == epochs:
                    save_path = saver.save(
                        sess, os.path.join(save_dir, "model.ckpt"), global_step=epoch + 1
                    )
                    print(f"模型已保存: {save_path}\n")
        finally:
            writer.close()
            sess.close()

        print("=" * 60)
        print("训练完成！")
        print("=" * 60)


#### **Usage of Model Training**

After preparing your data and generating `train.csv`, you can train the CNN model. Instantiate the `CNNModel` class and call its `train()` method, optionally specifying the number of epochs. Checkpoints are saved to `./h5_dell/` and TensorBoard logs are written to `./log`.

In [13]:
# Example: Initialize and train the model
# This assumes 'train.csv' has been generated by the DataPreprocess module.

# model = CNNModel()
# model.train(epochs=10)

### 3. Image Recognition Module (`predict.py`)

This module uses the trained CNN model to classify new images or perform real-time object recognition from a video stream. Key features include:

-   **Model Loading**: Restores the trained TensorFlow model from checkpoint files.
-   **Image Prediction**: Takes an image path, preprocesses it, and outputs the predicted class and confidence.
-   **Video Prediction**: Utilizes a webcam feed for real-time inference, displaying the recognized class and confidence directly on the video frames.
-   **Class Labels & Colors**: Defines human-readable class names and corresponding colors for visualization.

In [14]:
import tensorflow as tf
import cv2
import numpy as np
import os # Import os for path manipulation

class ImagePredictor:
    """Classify images or webcam frames with a checkpoint saved by ``CNNModel.train``.

    The inference network (no dropout, batch norm with ``training=False``) is
    built once in a private ``tf.Graph`` together with its ``tf.train.Saver``;
    its variable names must match the training graph so checkpoints restore.

    Args:
        model_path: directory searched with ``tf.train.latest_checkpoint``.
        confidence_threshold: predictions whose top softmax probability is below
            this value are reported as class 4 ("unknown").
    """

    def __init__(self, model_path="./h5_dell/", confidence_threshold=0.6):
        self.model_path = model_path
        self.confidence_threshold = confidence_threshold
        # Class id reported when confidence is below the threshold.
        self.unknown_class = 4

        # Class labels
        self.class_names = {
            0: "厨余垃圾 (Kitchen Waste)",
            1: "可回收垃圾 (Recyclable)",
            2: "有毒垃圾 (Hazardous)",
            3: "其它垃圾 (Other)",
            4: "未知类别 (Unknown)"
        }

        # Overlay colors (BGR)
        self.colors = {
            0: (0, 0, 255),      # red
            1: (0, 255, 255),    # yellow
            2: (0, 255, 0),      # green
            3: (255, 0, 255),    # magenta
            4: (128, 128, 128)   # gray
        }

        # Build the graph and its Saver together in one private graph. Creating the
        # Saver in one graph and restoring into a session on another fails.
        self.graph = tf.Graph()
        with self.graph.as_default():
            self.build_predict_graph()
            self.saver = tf.train.Saver()

    def build_predict_graph(self):
        """
        Build the inference network in the current default graph.

        Mirrors ``CNNModel.build_model_graph`` layer/scope names, without dropout
        and with batch norm in inference mode. Defines ``self.x``,
        ``self.probab`` and ``self.predict``.
        """
        with tf.name_scope("input"):
            self.x = tf.placeholder(tf.float32, [None, 200, 200, 3], "x")

        # Rebuild the network structure (must match training)
        with tf.variable_scope("conv_layer_1"):
            conv1 = tf.layers.conv2d(self.x, 64, [3, 3], activation=tf.nn.relu, name='conv1')
            max1 = tf.layers.max_pooling2d(conv1, [2, 2], [2, 2])
            bn1 = tf.layers.batch_normalization(max1, name='bn1', training=False)
            output1 = bn1

        with tf.variable_scope("conv_layer_2"):
            conv2 = tf.layers.conv2d(output1, 64, [3, 3], activation=tf.nn.relu, name='conv2')
            max2 = tf.layers.max_pooling2d(conv2, [2, 2], [2, 2], name='max2')
            bn2 = tf.layers.batch_normalization(max2, training=False)
            output2 = bn2

        with tf.variable_scope("conv_layer_3"):
            conv3 = tf.layers.conv2d(output2, 64, [3, 3], activation=tf.nn.relu, name='conv3')
            max3 = tf.layers.max_pooling2d(conv3, [2, 2], [2, 2], name='max3')
            bn3 = tf.layers.batch_normalization(max3, name='bn3', training=False)
            output3 = bn3

        with tf.variable_scope("conv_layer_4"):
            conv4 = tf.layers.conv2d(output3, 32, [3, 3], activation=tf.nn.relu, name='conv4')
            max4 = tf.layers.max_pooling2d(conv4, [2, 2], [2, 2], name='max4')
            bn4 = tf.layers.batch_normalization(max4, name='bn4', training=False)
            output = bn4
            flatten = tf.layers.flatten(output, 'flatten')

        with tf.variable_scope("fc_layer1"):
            fc1 = tf.layers.dense(flatten, 256, activation=tf.nn.relu)
            fc_bn1 = tf.layers.batch_normalization(fc1, name='bn1', training=False)

        with tf.variable_scope("fc_layer2"):
            fc2 = tf.layers.dense(fc_bn1, 128, activation=tf.nn.relu)

        with tf.variable_scope("fc_layer3"):
            fc3 = tf.layers.dense(fc2, 64)

        with tf.variable_scope("fc_layer4"):
            fc4 = tf.layers.dense(fc3, 32)

        with tf.variable_scope("fc_layer5"):
            fc5 = tf.layers.dense(fc4, 5)

        self.probab = tf.nn.softmax(fc5)
        self.predict = tf.argmax(self.probab, axis=1)

    def _open_session(self):
        """Return a session on ``self.graph`` with the latest checkpoint restored,
        or None (after printing a message) if no checkpoint exists."""
        ckpt = tf.train.latest_checkpoint(self.model_path)
        if not ckpt:
            print("未找到模型文件！")
            return None
        sess = tf.Session(graph=self.graph)
        try:
            self.saver.restore(sess, ckpt)
        except Exception:
            sess.close()
            raise
        print(f"模型已加载: {ckpt}")
        return sess

    def _preprocess(self, image):
        """Resize an OpenCV image to 200x200, scale to float32 [0, 1], add a batch axis."""
        resized = cv2.resize(image, (200, 200))
        return np.expand_dims(np.asarray(resized, np.float32) / 255.0, axis=0)

    def _classify(self, probab):
        """Map a (1, n_classes) softmax output to ``(class_id, confidence)``.

        Falls back to ``unknown_class`` when the top probability is below
        ``confidence_threshold``.
        """
        probs = np.asarray(probab)[0]
        predict_class = int(np.argmax(probs))
        confidence = float(probs[predict_class])
        if confidence < self.confidence_threshold:
            predict_class = self.unknown_class
        return predict_class, confidence

    def _overlay_text(self, predict_class, confidence):
        """Text drawn on images.

        cv2.putText's Hershey fonts replace characters they cannot render (the
        Chinese part of the class names) with '?', so only the ASCII part is drawn.
        """
        name = self.class_names[predict_class].encode("ascii", "ignore").decode().strip().strip("()")
        return f"{name or predict_class} ({confidence:.2%})"

    def predict_image(self, image_path, output_path='./prediction_result.jpg'):
        """
        Predict the class of a single image and save an annotated copy.

        Args:
            image_path: image file to classify.
            output_path: where the annotated image is written.

        Returns:
            ``(class_id, confidence)``, or None if no checkpoint is found or the
            image cannot be read.
        """
        print(f"加载图片: {image_path}")

        sess = self._open_session()
        if sess is None:
            return None

        try:
            image = cv2.imread(image_path)
            if image is None:
                print("无法读取图片！")
                return None
            probab = sess.run(self.probab, feed_dict={self.x: self._preprocess(image)})
        finally:
            sess.close()

        predict_class, confidence = self._classify(probab)

        print(f"\n预测结果:")
        print(f"  类别: {self.class_names[predict_class]}")
        print(f"  置信度: {confidence:.2%}")

        # Draw the result on a copy of the input image.
        result_image = image.copy()
        cv2.putText(
            result_image, self._overlay_text(predict_class, confidence), (10, 50),
            cv2.FONT_HERSHEY_SIMPLEX, 1.2, self.colors[predict_class], 3, cv2.LINE_AA
        )

        # cv2.imshow does not work in Colab, so the result is saved to disk; display
        # it with IPython.display.Image(filename=output_path) if needed.
        cv2.imwrite(output_path, result_image)
        print(f"Prediction result saved to {output_path}")

        return predict_class, confidence

    def predict_video(self, camera_id=0, show_window=True, max_frames=None):
        """
        Real-time recognition on a webcam stream.

        Args:
            camera_id: OpenCV capture device index.
            show_window: show annotated frames in an OpenCV window and stop on ESC.
                cv2.waitKey only receives key presses while a HighGUI window exists,
                so without a window ESC cannot end the loop. Set False on headless
                machines (e.g. Colab) and bound the run with ``max_frames``.
            max_frames: stop after this many frames; None runs until the stream
                ends (or ESC is pressed when ``show_window`` is True).
        """
        print(f"启动摄像头 {camera_id}...")

        sess = self._open_session()
        if sess is None:
            return

        # NOTE: cv2.VideoCapture usually cannot reach a webcam from Colab.
        capture = cv2.VideoCapture(camera_id)
        try:
            if not capture.isOpened():
                print("无法打开摄像头！")
                return

            print("按 ESC 键退出...")

            frame_count = 0
            while max_frames is None or frame_count < max_frames:
                ret, frame = capture.read()
                if not ret:
                    break
                frame_count += 1

                probab = sess.run(self.probab, feed_dict={self.x: self._preprocess(frame)})
                predict_class, confidence = self._classify(probab)

                cv2.putText(
                    frame, self._overlay_text(predict_class, confidence), (10, 50),
                    cv2.FONT_HERSHEY_SIMPLEX, 1.5, self.colors[predict_class], 3, cv2.LINE_AA
                )

                if show_window:
                    cv2.imshow("Real-time Recognition", frame)
                    if cv2.waitKey(1) == 27:  # ESC key
                        break
        finally:
            capture.release()
            if show_window:
                cv2.destroyAllWindows()
            sess.close()

        print("视频识别结束")


#### **Usage of Image Recognition**

To perform predictions, instantiate the `ImagePredictor` class. You can then use `predict_image()` for single images or `predict_video()` for real-time video streams. Note that `predict_video` needs a local webcam and display, so it generally will not work in a hosted environment such as Colab.

In [15]:
# Example: Predict a single image
# You would need a trained model saved in './h5_dell/' and a test image, e.g., './test.jpg'
# You can create a dummy test.jpg for testing purposes:
# cv2.imwrite('./test.jpg', np.zeros((200,200,3), dtype=np.uint8))

# predictor = ImagePredictor()
# predictor.predict_image('./test.jpg')

# Example: Predict from video (will likely not work directly in Colab without specific setups)
# predictor.predict_video()


### 4. Main Program Entry Point (`main.py`)

This section defines the `main` function which acts as the entry point for the entire application. It uses `argparse` to allow users to specify the operation mode (preprocess, train, predict_image, predict_video) and other parameters from the command line. This is typical for standalone scripts, but for a Colab notebook, you would typically run the class methods directly as shown in the usage examples above.

In [16]:
def main(argv=None):
    """Command-line entry point: run one pipeline stage selected by ``--mode``.

    Args:
        argv: argument list to parse instead of ``sys.argv[1:]``. Pass one
            explicitly when calling from a notebook, e.g.
            ``main(['--mode', 'train', '--epochs', '10'])``, because a Jupyter
            kernel's own ``sys.argv`` typically holds kernel launch options that
            this parser would reject.
    """
    import argparse

    parser = argparse.ArgumentParser(description='CNN Image Classification')
    parser.add_argument('--mode', type=str, required=True,
                        choices=['preprocess', 'train', 'predict_image', 'predict_video'],
                        help='运行模式')
    parser.add_argument('--image_path', type=str, default='./test.jpg',
                        help='测试图片路径 (predict_image模式使用)')
    parser.add_argument('--data_path', type=str, default='./picture/',
                        help='数据集路径')
    parser.add_argument('--epochs', type=int, default=100,
                        help='训练轮数')
    parser.add_argument('--model_path', type=str, default='./h5_dell/',
                        help='模型检查点目录 (predict模式使用)')
    parser.add_argument('--camera_id', type=int, default=0,
                        help='摄像头编号 (predict_video模式使用)')

    args = parser.parse_args(argv)

    if args.mode == 'preprocess':
        # Data preprocessing
        DataPreprocess(image_path=args.data_path).run_all()

    elif args.mode == 'train':
        # Model training
        CNNModel(image_path=args.data_path).train(epochs=args.epochs)

    elif args.mode == 'predict_image':
        # Single-image recognition
        ImagePredictor(model_path=args.model_path).predict_image(args.image_path)

    elif args.mode == 'predict_video':
        # Webcam recognition
        ImagePredictor(model_path=args.model_path).predict_video(camera_id=args.camera_id)


# The original code's `if __name__ == "__main__":` block is intended for script execution.
# In a Colab notebook, you would typically call the functions/methods directly
# or simulate the argparse behavior if needed.

# if __name__ == "__main__":
#     print("=" * 60)
#     print("CNN 图像分类项目")
#     print("=" * 60)

#     # Uncomment the desired section to run

#     # 1. 数据预处理
#     # preprocessor = DataPreprocess()
#     # preprocessor.run_all()

#     # 2. 训练模型
#     # model = CNNModel()
#     # model.train(epochs=100)

#     # 3. 图片识别
#     # predictor = ImagePredictor()
#     # predictor.predict_image('./test.jpg')

#     # 4. 视频识别
#     # predictor = ImagePredictor()
#     # predictor.predict_video()

#     main() # This line would execute the argparse logic
