<a href="https://colab.research.google.com/github/HakureiPOI/CIFAR-10-wyx/blob/main/cifar10_main.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 基于 CIFAR-10 和 ResNet-18 的图像分类

In [1]:
# Mount Google Drive at /content/drive; the final cell of this notebook saves
# the trained model under /content/drive/MyDrive.
from google.colab import drive

drive.mount('/content/drive')

MessageError: Error: credential propagation was unsuccessful

---

## Part 1. 获取 CIFAR-10 数据集

In [None]:
import tensorflow as tf

In [None]:
# Load the CIFAR-10 train/test splits through the Keras dataset helper.
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.cifar10.load_data()

# Report the shape of every split array.
for split_name, split_array in [('X_train', X_train), ('X_test', X_test),
                                ('y_train', y_train), ('y_test', y_test)]:
    print(f'{split_name} shape: {split_array.shape}')

In [None]:
# CIFAR-10 is a 10-class image dataset; the position in this list is the
# integer label used in y_train / y_test.
class_names = (
    'airplane automobile bird cat deer '
    'dog frog horse ship truck'
).split()

In [None]:
X_train[0]      # one 32 × 32 image with 3 channels; raw pixel values, presumably uint8 in [0, 255] (preprocessing later divides by 255)

In [None]:
y_train[0]      # the label is one of ten discrete integer values 0–9 (an index into class_names)

In [None]:
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Show one randomly chosen training image per class in a 2 x 5 grid, and
# remember the chosen indices so the augmented versions can be shown later.
plt.figure(figsize=(12, 6))
images_idx = []

for class_id, class_name in enumerate(class_names):
    # Raveling handles a (N, 1) label array as well as a flat one; the
    # resulting indices are the row indices of matching labels.
    candidates = np.flatnonzero(y_train.ravel() == class_id)
    chosen = np.random.choice(candidates)
    images_idx.append(chosen)

    plt.subplot(2, 5, class_id + 1)
    plt.imshow(X_train[chosen])
    plt.title(class_name)
    plt.axis('off')

plt.show()

---

## Part 2. 对图像数据进行预处理

In [None]:
def preprocess_image(image, label, training=True):
    """Normalize one CIFAR-10 image and optionally apply random augmentation.

    Args:
        image: A single image of shape (H, W, C). The blur below adds and then
            strips a batch axis, so a 3-D input is required. Raw pixel values
            are divided by 255, so they are presumably uint8 in [0, 255].
        label: The class label; passed through unchanged.
        training: Python bool. If True (the default, matching the original
            behavior), apply a random left-right flip and then, with 50%
            probability, a 3x3 Gaussian blur whose sigma is drawn uniformly
            from [0.5, 1.5). Pass False for evaluation data so that only the
            normalization is applied and metrics are deterministic.

    Returns:
        ``(image, label)`` with ``image`` cast to float32.
    """
    # Scale pixel values into [0, 1].
    image = tf.cast(image, tf.float32) / 255.0

    # Evaluation data must not be randomly perturbed.
    if not training:
        return image, label

    # Random horizontal flip. A vertical flip
    # (tf.image.random_flip_up_down) is intentionally left disabled.
    image = tf.image.random_flip_left_right(image)

    def get_gaussian_kernel(kernel_size, sigma, n_channels):
        """Return a normalized Gaussian kernel of shape [k, k, n_channels, 1] for depthwise_conv2d."""
        # 1-D Gaussian sampled at integer offsets -(k//2) .. k//2, normalized to sum 1.
        x = tf.range(-(kernel_size // 2), kernel_size // 2 + 1, dtype=tf.float32)
        x = tf.exp(-(x ** 2) / (2.0 * sigma ** 2))
        x = x / tf.reduce_sum(x)

        # Outer product of the 1-D kernel with itself gives the 2-D kernel [k, k].
        kernel2d = tf.tensordot(x, x, axes=0)
        kernel2d = kernel2d[:, :, tf.newaxis, tf.newaxis]          # [k, k, 1, 1]
        # Same kernel for every channel (channel multiplier 1).
        kernel2d = tf.repeat(kernel2d, n_channels, axis=2)         # [k, k, C, 1]
        return kernel2d

    def blur():
        kernel_size = 3
        sigma = tf.random.uniform((), minval=0.5, maxval=1.5)
        kernel = get_gaussian_kernel(kernel_size, sigma, tf.shape(image)[-1])
        # depthwise_conv2d needs a batch axis: add it, convolve, strip it.
        # 'SAME' padding zero-pads, so border pixels are slightly darkened.
        return tf.nn.depthwise_conv2d(image[tf.newaxis, ...], kernel,
                                      strides=[1, 1, 1, 1], padding='SAME')[0]

    # Blur with 50% probability; tf.cond keeps this usable inside tf.data.map.
    image = tf.cond(tf.random.uniform(()) > 0.5, blur, lambda: image)

    return image, label

In [None]:
# Re-display the images chosen above after preprocessing, to inspect the
# effect of the random flip / blur augmentation.
plt.figure(figsize=(12, 6))

for position, sample_idx in enumerate(images_idx, start=1):
    augmented, sample_label = preprocess_image(X_train[sample_idx], y_train[sample_idx])
    plt.subplot(2, 5, position)
    plt.imshow(augmented.numpy())
    plt.title(class_names[sample_label.item()])
    plt.axis('off')

plt.show()

In [None]:
# Build the tf.data input pipelines.
batch_size = 128


def normalize_image(image, label):
    """Scale pixels to float32 in [0, 1] without any random augmentation (for evaluation data)."""
    return tf.cast(image, tf.float32) / 255.0, label


# Training pipeline: shuffle (reshuffled every epoch) so batch composition
# and order differ between epochs, then apply the random augmentation.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((X_train, y_train))
    .shuffle(buffer_size=len(X_train), reshuffle_each_iteration=True)
    .map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

# Test pipeline: this split is passed to model.fit as validation_data, so it
# must only be normalized. Random flips/blur here would make val metrics and
# the checkpoint / early-stopping decisions based on them noisy.
test_dataset = (
    tf.data.Dataset.from_tensor_slices((X_test, y_test))
    .map(normalize_image, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
# Sanity-check the training pipeline: show one random image from the first batch.
batch_images, batch_labels = next(iter(train_dataset.take(1)))
pick = np.random.choice(batch_images.shape[0])

plt.figure(figsize=(4, 4))
plt.imshow(batch_images[pick].numpy())
plt.title(class_names[batch_labels[pick].numpy().item()])
plt.axis('off')
plt.show()

---

## Part 3. 搭建网络并进行训练

In [None]:
from tensorflow.keras import layers, models, regularizers

In [None]:
def resnet18_cifar10(
    input_shape=(32, 32, 3),
    num_classes=10,
    base_width=64,
    weight_decay=5e-4,
    dropout_rate=0.0,
):
    """Build a pre-activation ResNet-18 adapted to small CIFAR-style inputs.

    Layout: a 3x3 stride-1 stem convolution, four stages of two
    pre-activation basic blocks each (widths base_width x 1, 2, 4, 8; stages
    2-4 downsample by stride 2), a final BN-ReLU, global average pooling and
    a softmax Dense classifier.

    Args:
        input_shape: Shape of one input image, (H, W, C).
        num_classes: Number of output classes.
        base_width: Filter count of the stem and of the first stage.
        weight_decay: L2 factor applied to every conv kernel and the Dense kernel.
        dropout_rate: If > 0, a Dropout layer is inserted between the two
            3x3 convolutions of every block.

    Returns:
        An uncompiled keras Model named 'ResNet18_CIFAR10' that outputs class
        probabilities (softmax).
    """
    l2_reg = regularizers.l2(weight_decay)

    def make_conv(filters, kernel_size, stride, padding):
        # Shared factory: bias-free conv, He-normal init, L2-regularized kernel.
        return layers.Conv2D(
            filters, kernel_size, strides=stride, padding=padding,
            use_bias=False, kernel_initializer='he_normal',
            kernel_regularizer=l2_reg
        )

    def residual_block(block_in, filters, stride):
        # Pre-activation: BN -> ReLU come before the convolutions.
        preact = layers.BatchNormalization()(block_in)
        preact = layers.Activation('relu')(preact)

        # Projection shortcut (1x1 conv on the pre-activated tensor) when the
        # stride or channel count changes; identity shortcut otherwise.
        needs_projection = stride != 1 or block_in.shape[-1] != filters
        if needs_projection:
            shortcut = make_conv(filters, 1, stride, 'valid')(preact)
        else:
            shortcut = block_in

        # Main branch: two 3x3 convs, the first one carrying the stride.
        out = make_conv(filters, 3, stride, 'same')(preact)
        out = layers.BatchNormalization()(out)
        out = layers.Activation('relu')(out)
        if dropout_rate > 0:
            out = layers.Dropout(dropout_rate)(out)
        out = make_conv(filters, 3, 1, 'same')(out)

        return layers.Add()([out, shortcut])

    inputs = layers.Input(shape=input_shape)

    # CIFAR-style stem: 3x3 stride 1, avoiding early downsampling of small images.
    x = make_conv(base_width, 3, 1, 'same')(inputs)

    # ResNet-18: two blocks per stage. For a 32x32 input the stages run at
    # 32x32, 16x16, 8x8 and 4x4.
    for width_multiplier, first_stride in ((1, 1), (2, 2), (4, 2), (8, 2)):
        stage_filters = base_width * width_multiplier
        x = residual_block(x, stage_filters, first_stride)
        x = residual_block(x, stage_filters, 1)

    # Final pre-activation before pooling, as required by the pre-activation design.
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)

    x = layers.GlobalAveragePooling2D()(x)
    outputs = layers.Dense(
        num_classes, activation='softmax',
        kernel_regularizer=l2_reg
    )(x)

    return models.Model(inputs, outputs, name='ResNet18_CIFAR10')

In [None]:
# Instantiate the network with dropout 0.3 inside each residual block and
# print a per-layer summary.
model = resnet18_cifar10(
    input_shape=(32, 32, 3),
    num_classes=10,
    weight_decay=5e-4,
    dropout_rate=0.3,
)
model.summary()

In [None]:
from tensorflow.keras import optimizers, callbacks

In [None]:
# Training length, learning-rate schedule, optimizer and callbacks.
epochs = 100

# Batches per epoch. The pipeline uses .batch() without drop_remainder, so the
# final partial batch is also a step: round up, otherwise the cosine schedule
# below ends (and sits at its floor) before training does.
steps_per_epoch = (len(X_train) + batch_size - 1) // batch_size
total_steps = steps_per_epoch * epochs

# Cosine learning-rate decay (a decay schedule, not simulated annealing):
# starts at 3e-3 and decays over the whole run to alpha * 3e-3 = 3e-5.
lr_schedule = optimizers.schedules.CosineDecay(
    initial_learning_rate=3e-3,
    decay_steps=total_steps,
    alpha=1e-2
)

optimizer = optimizers.Adam(learning_rate=lr_schedule)

# NOTE(review): val_accuracy is computed on the test split (it is passed as
# validation_data to fit), so early stopping and checkpoint selection use
# test data; consider carving a separate validation split from X_train.

# Early stopping: stop after 15 epochs without a val_accuracy improvement.
# restore_best_weights rolls back to the best epoch; in some Keras versions
# this only happens when early stopping actually triggers — confirm for the
# installed version.
early_stop = callbacks.EarlyStopping(
    monitor='val_accuracy',
    mode='max',
    patience=15,
    restore_best_weights=True,
    verbose=1
)

# Save the model whenever val_accuracy reaches a new best.
ckpt = callbacks.ModelCheckpoint(
    'best_cifar10_adam.keras',
    monitor='val_accuracy',
    mode='max',
    save_best_only=True,
    verbose=1
)

In [None]:
# Integer labels + softmax outputs -> sparse categorical cross-entropy
# (from_logits left at its default, False).
model.compile(
    optimizer=optimizer,
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=["accuracy"],
)

# Train, evaluating on test_dataset after each epoch; the checkpoint and
# early-stopping callbacks both watch val_accuracy.
training_callbacks = [ckpt, early_stop]
history = model.fit(
    train_dataset,
    epochs=epochs,
    validation_data=test_dataset,
    callbacks=training_callbacks,
)

In [None]:
# Persist the trained model to Google Drive when Drive is actually mounted.
# If the drive.mount() cell failed, os.makedirs below would otherwise create
# /content/drive/MyDrive/... as a plain folder on the runtime's local disk,
# so the model would silently not end up on Drive.
# NOTE(review): `model` holds the best-epoch weights only if EarlyStopping
# restored them (version-dependent); the best checkpoint is also written by
# the ModelCheckpoint callback.
import os

drive_root = "/content/drive"
save_path = "/content/drive/MyDrive/cifar10_models/resnet18_best.keras"

if not os.path.ismount(drive_root):
    save_path = os.path.basename(save_path)
    print(f"Warning: Google Drive is not mounted at {drive_root}; saving to the working directory instead.")

save_dir = os.path.dirname(save_path)
if save_dir:  # os.makedirs('') would raise for a bare filename
    os.makedirs(save_dir, exist_ok=True)

model.save(save_path)
print("Saved to:", save_path)