# CycleGAN Monet风格转换 - 独立项目

这是一个完整的Kaggle GAN竞赛解决方案，从零开始实现CycleGAN模型。

## 📋 项目特点

- ✅ **完全独立**: 不依赖外部项目
- ✅ **中文注释**: 详细的中文说明
- ✅ **可视化**: 实时查看训练过程和结果
- ✅ **自动保存**: 自动保存模型和样本
- ✅ **一键提交**: 自动生成Kaggle提交文件

## 🚀 使用说明

1. 确保数据已放在 `data/Image_Generation_Data_Kaggle/` 目录下，其中包含 `monet_tfrec/` 和 `photo_tfrec/` 两个子目录（内含 `*.tfrec` 文件）
2. 按顺序运行所有cells
3. 训练完成后自动生成提交文件

## 📊 数据说明

- **Monet图像**: 300张Monet原画
- **Photo图像**: 7038张照片
- **格式**: TFRecord文件，256x256像素


## 1. 导入库和设置


In [None]:
# Import the required libraries.
import os

# TF_CPP_MIN_LOG_LEVEL is read by TensorFlow's native runtime when its logging is
# initialised, which happens during `import tensorflow`. Setting it after the
# import (as this cell originally did) has little or no effect, so it must come
# before the import. Re-running the cell in an already-initialised kernel still
# cannot change it retroactively.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import random
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from datetime import datetime

# Seed every RNG the notebook touches so runs are reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)

# matplotlib: fonts that can render the Chinese titles used in this notebook,
# and keep the minus sign drawable with those fonts.
plt.rcParams['font.sans-serif'] = ['SimHei', 'Arial Unicode MS']
plt.rcParams['axes.unicode_minus'] = False

print("✅ 库导入成功！")
print(f"TensorFlow版本: {tf.__version__}")

# Detect GPUs and let TensorFlow grow GPU memory on demand instead of
# reserving it all up front.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    print(f"✅ 检测到 {len(gpus)} 个GPU")
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        # Per the TF GPU guide, memory growth must be configured before the GPUs
        # are initialised; otherwise set_memory_growth raises RuntimeError.
        print(f"GPU配置错误: {e}")
else:
    print("⚠️ 未检测到GPU，将使用CPU（速度会很慢）")


## 2. 配置参数

**可以根据需要修改以下参数**


In [None]:
# ==================== 配置参数 ====================
# All tunable knobs of the notebook; later cells read these module globals.

# --- Data locations -----------------------------------------------------------
DATA_ROOT = "data/Image_Generation_Data_Kaggle"
MONET_TFREC_PATH = os.path.join(DATA_ROOT, "monet_tfrec")  # Monet TFRecord shards
PHOTO_TFREC_PATH = os.path.join(DATA_ROOT, "photo_tfrec")  # photo TFRecord shards

# --- Model --------------------------------------------------------------------
IMAGE_SIZE = 256
CHANNELS = 3
LAMBDA_CYCLE = 10.0      # weight of the L1 cycle-consistency term (passed to CycleGAN)
LAMBDA_IDENTITY = 0.5    # NOTE(review): not referenced by any cell; identity_loss hard-codes its own 0.5

# --- Training -----------------------------------------------------------------
BATCH_SIZE = 8           # batch size (tune to GPU memory: 4/8/16)
EPOCHS = 20              # number of epochs (recommended: 20-50)
LEARNING_RATE = 2e-4     # Adam learning rate, shared by all four optimisers
BETA_1 = 0.5             # Adam beta_1

# --- Augmentation -------------------------------------------------------------
USE_AUGMENTATION = True      # master switch read by augment_image
AUGMENTATION_PROB = 0.5      # chance that augment_image applies its random 90-degree rotation

# --- Output -------------------------------------------------------------------
SAVE_DIR = "saves"
MODEL_NAME = "cyclegan_monet"
SAVE_SAMPLES = True          # NOTE(review): not referenced by any cell
NUM_SAMPLES_TO_SAVE = 10     # NOTE(review): not referenced by any cell

# Echo the key settings.
_rule = "=" * 50
_summary = (
    ("数据路径", DATA_ROOT),
    ("图像尺寸", f"{IMAGE_SIZE}x{IMAGE_SIZE}"),
    ("批次大小", BATCH_SIZE),
    ("训练轮数", EPOCHS),
    ("学习率", LEARNING_RATE),
    ("Cycle Loss权重", LAMBDA_CYCLE),
    ("使用数据增强", USE_AUGMENTATION),
)
print(_rule)
print("训练配置:")
print(_rule)
for _label, _value in _summary:
    print(f"{_label}: {_value}")
print(_rule)


## 3. 数据加载函数


In [None]:
import re

def count_data_items(filenames):
    """Return the total number of records across TFRecord shards.

    The shard file names encode their record count, e.g. ``monet00-60.tfrec``
    holds 60 records. Only the base name is parsed, so a ``-<digits>.`` sequence
    in a parent directory (e.g. ``run-2.0/...``) is not mistaken for the count.
    At least one digit is required after the hyphen.

    Args:
        filenames: iterable of shard paths.

    Returns:
        int: the summed record count (0 for an empty iterable).

    Raises:
        ValueError: if a file name does not contain ``-<count>.``.
    """
    pattern = re.compile(r"-(\d+)\.")
    total = 0
    for filename in filenames:
        match = pattern.search(os.path.basename(filename))
        if match is None:
            raise ValueError(f"无法从文件名解析记录数: {filename}")
        total += int(match.group(1))
    return total

def decode_image(image):
    """Decode JPEG bytes into a float32 image tensor scaled to [-1, 1].

    Args:
        image: scalar string tensor holding the JPEG-encoded bytes.

    Returns:
        float32 tensor reshaped to (IMAGE_SIZE, IMAGE_SIZE, CHANNELS).
    """
    pixels = tf.cast(tf.image.decode_jpeg(image, channels=CHANNELS), tf.float32)
    # [0, 255] -> [-1, 1], the same range as the generator's tanh output.
    scaled = (pixels / 127.5) - 1
    # decode_jpeg leaves height/width statically unknown; the reshape pins them.
    return tf.reshape(scaled, [IMAGE_SIZE, IMAGE_SIZE, CHANNELS])

def read_tfrecord(example):
    """Parse one serialized example and return its decoded image.

    The schema declares three required string features ('image_name', 'image',
    'target'), so every record must contain all of them; only 'image' is used,
    decoded through decode_image.
    """
    features = tf.io.parse_single_example(
        example,
        {
            'image_name': tf.io.FixedLenFeature([], tf.string),
            'image': tf.io.FixedLenFeature([], tf.string),
            'target': tf.io.FixedLenFeature([], tf.string),
        },
    )
    return decode_image(features['image'])

def load_dataset(filenames):
    """Build a tf.data pipeline of decoded images from TFRecord shard paths.

    Records are parsed in parallel with read_tfrecord. No shuffling, batching or
    repetition is applied here; callers add those.
    """
    return tf.data.TFRecordDataset(filenames).map(
        read_tfrecord, num_parallel_calls=tf.data.AUTOTUNE
    )

def augment_image(image):
    """Randomly flip and optionally rotate a single image.

    Returns ``image`` untouched when USE_AUGMENTATION is False. Otherwise a
    random left-right flip and a random up-down flip are always applied, and with
    probability AUGMENTATION_PROB the result is then rotated by k * 90 degrees,
    k drawn uniformly from {0, 1, 2, 3} (k == 0 leaves it unrotated).
    """
    if USE_AUGMENTATION:
        flipped = tf.image.random_flip_up_down(tf.image.random_flip_left_right(image))
        coin = tf.random.uniform([])

        def _rotated():
            quarter_turns = tf.random.uniform([], 0, 4, dtype=tf.int32)
            return tf.image.rot90(flipped, quarter_turns)

        # Explicit tf.cond: the same graph AutoGraph builds for a Python `if` on a
        # tensor, without relying on AutoGraph conversion.
        return tf.cond(coin < AUGMENTATION_PROB, _rotated, lambda: flipped)
    return image


print("✅ 数据加载函数定义完成")


## 4. 加载数据


In [None]:
# Collect the TFRecord shard paths for both domains.
monet_files = tf.io.gfile.glob(os.path.join(MONET_TFREC_PATH, "*.tfrec"))
photo_files = tf.io.gfile.glob(os.path.join(PHOTO_TFREC_PATH, "*.tfrec"))

print(f"找到 {len(monet_files)} 个Monet TFRecord文件")
print(f"找到 {len(photo_files)} 个Photo TFRecord文件")

# Fail fast: with no shards the record counts would be 0 and training could not run.
if not monet_files or not photo_files:
    raise FileNotFoundError(
        f"未找到TFRecord文件，请检查数据路径: {MONET_TFREC_PATH} / {PHOTO_TFREC_PATH}"
    )

# Record counts (used later to size an epoch).
n_monet = count_data_items(monet_files)
n_photo = count_data_items(photo_files)

print(f"Monet图像数量: {n_monet}")
print(f"Photo图像数量: {n_photo}")


def _make_training_pipeline(files, cache):
    """Turn TFRecord shards into an infinite, shuffled, augmented, batched stream.

    Stage order matters:
      * cache() (optional) sits directly after decoding, on the *finite* dataset,
        so each decoded image is stored exactly once. The original cell cached
        after repeat(), i.e. on an infinite stream: the in-memory cache never
        completes a pass and keeps every element produced, so host memory grows
        for the entire run.
      * augment_image runs after cache/shuffle/repeat, so augmentations are
        re-drawn every time an image is visited.
    """
    ds = load_dataset(files)
    if cache:
        ds = ds.cache()
    ds = ds.shuffle(1000).repeat()
    ds = ds.map(augment_image, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.batch(BATCH_SIZE, drop_remainder=True)
    return ds.prefetch(tf.data.AUTOTUNE)


# The Monet set is small (about 300 images at 256*256*3*4 bytes ≈ 0.75 MiB each
# once decoded), so caching it in memory is cheap. The photo set (about 7000
# images, several GiB decoded) is streamed from disk to avoid exhausting RAM.
monet_ds = _make_training_pipeline(monet_files, cache=True)
photo_ds = _make_training_pipeline(photo_files, cache=False)

# Pair one Monet batch with one photo batch per training step.
dataset = tf.data.Dataset.zip((monet_ds, photo_ds))

print("✅ 数据加载完成")


## 5. 可视化数据样本


In [None]:
def _show_first_of_batches(pick, prefix):
    """Plot the first image of each of 8 batches from `dataset` in a 2x4 grid.

    `pick` selects the element of each (monet, photo) pair (0 = Monet,
    1 = photo); `prefix` labels the subplots and the figure title. Pixel values
    are mapped from [-1, 1] back to [0, 1] for display.
    """
    _, grid = plt.subplots(2, 4, figsize=(12, 6))
    for idx, pair in enumerate(dataset.take(8)):
        ax = grid.flat[idx]
        ax.imshow(pair[pick][0] * 0.5 + 0.5)
        ax.set_title(f'{prefix} {idx+1}')
        ax.axis('off')
    plt.suptitle(f'{prefix}样本')
    plt.tight_layout()
    plt.show()


print("Monet样本:")
_show_first_of_batches(0, 'Monet')

print("Photo样本:")
_show_first_of_batches(1, 'Photo')


## 6. 模型定义


In [None]:
import tensorflow_addons as tfa

def downsample(filters, size, apply_instancenorm=True, strides=2):
    """Return a Conv2D -> [InstanceNormalization] -> LeakyReLU block as a Sequential.

    With the default stride of 2 and 'same' padding the spatial size is halved.
    The conv kernel and the InstanceNorm gamma are initialised from N(0, 0.02);
    the conv has no bias.
    """
    stack = [
        tf.keras.layers.Conv2D(filters, size, strides=strides, padding='same',
                               kernel_initializer=tf.random_normal_initializer(0., 0.02),
                               use_bias=False),
    ]
    if apply_instancenorm:
        stack.append(tfa.layers.InstanceNormalization(
            gamma_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)))
    stack.append(tf.keras.layers.LeakyReLU())
    return tf.keras.Sequential(stack)

def upsample(filters, size, apply_dropout=False, strides=2):
    """Return a Conv2DTranspose -> InstanceNormalization -> [Dropout(0.5)] -> ReLU block.

    With the default stride of 2 and 'same' padding the spatial size doubles.
    The transposed-conv kernel and the InstanceNorm gamma are initialised from
    N(0, 0.02); the transposed conv has no bias.
    """
    stack = [
        tf.keras.layers.Conv2DTranspose(filters, size, strides=strides, padding='same',
                                        kernel_initializer=tf.random_normal_initializer(0., 0.02),
                                        use_bias=False),
        tfa.layers.InstanceNormalization(
            gamma_initializer=tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)),
    ]
    if apply_dropout:
        stack.append(tf.keras.layers.Dropout(0.5))
    stack.append(tf.keras.layers.ReLU())
    return tf.keras.Sequential(stack)

def Generator():
    """Build the U-Net generator: (IMAGE_SIZE, IMAGE_SIZE, CHANNELS) in, same shape out.

    Eight stride-2 encoder blocks take 256x256 down to 1x1. Seven decoder blocks
    climb back to 128x128, each concatenated with the encoder activation of the
    matching resolution (skip connection). A final stride-2 Conv2DTranspose with
    tanh restores 256x256 and CHANNELS outputs in [-1, 1].
    """
    image_in = tf.keras.layers.Input(shape=[IMAGE_SIZE, IMAGE_SIZE, CHANNELS])

    # Encoder: 256 -> 128 -> 64 -> 32 -> 16 -> 8 -> 4 -> 2 -> 1 (no norm on the first block).
    encoder_filters = (64, 128, 256, 512, 512, 512, 512, 512)
    encoder = [downsample(f, 4, apply_instancenorm=(depth > 0))
               for depth, f in enumerate(encoder_filters)]

    # Decoder: 1 -> 2 -> ... -> 128; dropout on the three deepest blocks.
    decoder_filters = (512, 512, 512, 512, 256, 128, 64)
    decoder = [upsample(f, 4, apply_dropout=(depth < 3))
               for depth, f in enumerate(decoder_filters)]

    init = tf.random_normal_initializer(0., 0.02)
    to_image = tf.keras.layers.Conv2DTranspose(CHANNELS, 4,
                                               strides=2,
                                               padding='same',
                                               kernel_initializer=init,
                                               activation='tanh')  # (bs, 256, 256, 3)

    h = image_in
    encoder_outputs = []
    for block in encoder:
        h = block(h)
        encoder_outputs.append(h)

    # The 1x1 bottleneck has no partner; pair the remaining outputs deepest-first.
    for block, skip in zip(decoder, encoder_outputs[-2::-1]):
        h = block(h)
        h = tf.keras.layers.Concatenate()([h, skip])

    return tf.keras.Model(inputs=image_in, outputs=to_image(h))

def Discriminator():
    """Build the PatchGAN discriminator.

    Maps a (IMAGE_SIZE, IMAGE_SIZE, CHANNELS) image to a grid of raw logits
    (30x30x1 for 256x256 inputs), one per overlapping patch. There is no final
    activation; the losses apply the sigmoid via from_logits=True.
    """
    init = tf.random_normal_initializer(0., 0.02)
    gamma_init = tf.keras.initializers.RandomNormal(mean=0.0, stddev=0.02)

    image_in = tf.keras.layers.Input(shape=[IMAGE_SIZE, IMAGE_SIZE, CHANNELS], name='input_image')

    # Three stride-2 stages: 256 -> 128 -> 64 -> 32 (no norm on the first).
    h = image_in
    for filters, use_norm in ((64, False), (128, True), (256, True)):
        h = downsample(filters, 4, use_norm)(h)

    # Pad then 4x4 valid conv: 32 -> 34 -> 31.
    h = tf.keras.layers.ZeroPadding2D()(h)
    h = tf.keras.layers.Conv2D(512, 4, strides=1,
                               kernel_initializer=init,
                               use_bias=False)(h)
    h = tfa.layers.InstanceNormalization(gamma_initializer=gamma_init)(h)
    h = tf.keras.layers.LeakyReLU()(h)

    # Pad then 4x4 valid conv down to one logit channel: 31 -> 33 -> 30.
    h = tf.keras.layers.ZeroPadding2D()(h)
    logits = tf.keras.layers.Conv2D(1, 4, strides=1,
                                    kernel_initializer=init)(h)

    return tf.keras.Model(inputs=image_in, outputs=logits)


print("✅ 模型定义完成")


In [None]:
# Instantiate both generators and both discriminators (creation order preserved).
print("🏗️ 创建模型...")
monet_generator = Generator()          # photo -> Monet style
photo_generator = Generator()          # Monet style -> photo
monet_discriminator = Discriminator()  # scores Monet-domain images
photo_discriminator = Discriminator()  # scores photo-domain images

_gen_params = monet_generator.count_params()
_disc_params = monet_discriminator.count_params()
print(f"生成器参数量: {_gen_params:,}")
print(f"判别器参数量: {_disc_params:,}")

print("✅ 模型创建完成")


## 8. 损失函数定义


In [None]:
def discriminator_loss(real, generated):
    """Adversarial loss for a discriminator, left unreduced.

    Sigmoid cross-entropy pushes logits on real images toward 1 and logits on
    generated images toward 0; the two terms are averaged. With
    Reduction.NONE only the last axis is averaged by BinaryCrossentropy, so
    the batch and spatial axes are kept.
    """
    bce = tf.keras.losses.BinaryCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
    loss_on_real = bce(tf.ones_like(real), real)
    loss_on_fake = bce(tf.zeros_like(generated), generated)
    return 0.5 * (loss_on_real + loss_on_fake)

def generator_loss(generated):
    """Adversarial loss for a generator, left unreduced.

    Sigmoid cross-entropy against all-ones targets: the generator is rewarded
    when the discriminator's logits on its output look "real". Reduction.NONE
    keeps the batch and spatial axes.
    """
    bce = tf.keras.losses.BinaryCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
    target = tf.ones_like(generated)
    return bce(target, generated)

def calc_cycle_loss(real_image, cycled_image, LAMBDA):
    """Cycle-consistency loss: LAMBDA times the mean absolute error (L1).

    Compares an image with its reconstruction after a round trip through both
    generators; returns a scalar tensor.
    """
    mean_abs_error = tf.reduce_mean(tf.abs(real_image - cycled_image))
    return LAMBDA * mean_abs_error

def identity_loss(real_image, same_image, LAMBDA):
    """Identity loss: LAMBDA * 0.5 times the mean absolute error (L1).

    Penalises a generator for changing an image that is already in its target
    domain; returns a scalar tensor. The 0.5 factor is hard-coded here.
    """
    return LAMBDA * 0.5 * tf.reduce_mean(tf.abs(real_image - same_image))


print("✅ 损失函数定义完成")


## 9. CycleGAN模型类


In [None]:
class CycleGAN(tf.keras.Model):
    """Keras training wrapper around two generators and two discriminators.

    m_gen translates photos into Monet-style images and p_gen translates back.
    m_disc / p_disc score images of their own domain as real or generated.
    The four optimisers and four loss callables are injected through compile().
    """

    def __init__(self, monet_generator, photo_generator, monet_discriminator, photo_discriminator, lambda_cycle=10.0):
        super(CycleGAN, self).__init__()
        self.m_gen = monet_generator
        self.p_gen = photo_generator
        self.m_disc = monet_discriminator
        self.p_disc = photo_discriminator
        # Passed as LAMBDA to both the cycle and the identity loss functions.
        self.lambda_cycle = lambda_cycle

    def compile(self, m_gen_optimizer, p_gen_optimizer, m_disc_optimizer, p_disc_optimizer,
                gen_loss_fn, disc_loss_fn, cycle_loss_fn, identity_loss_fn):
        """Store one optimiser per sub-model plus the loss callables used by train_step."""
        super(CycleGAN, self).compile()
        self.m_gen_optimizer = m_gen_optimizer
        self.p_gen_optimizer = p_gen_optimizer
        self.m_disc_optimizer = m_disc_optimizer
        self.p_disc_optimizer = p_disc_optimizer
        self.gen_loss_fn = gen_loss_fn
        self.disc_loss_fn = disc_loss_fn
        self.cycle_loss_fn = cycle_loss_fn
        self.identity_loss_fn = identity_loss_fn

    def train_step(self, batch_data):
        """Run one optimisation step on a (monet_batch, photo_batch) pair.

        Returns:
            dict of scalar loss values for Keras' progress bar, callbacks and
            History. The injected adversarial losses may be unreduced per-patch
            tensors (e.g. Reduction.NONE). Returning those made every logged value
            an array, which broke ':.4f' formatting in callbacks and plotting of
            the history, so the reported values are their means.
        """
        real_monet, real_photo = batch_data

        with tf.GradientTape(persistent=True) as tape:
            # Photo -> Monet -> Photo
            fake_monet = self.m_gen(real_photo, training=True)
            cycled_photo = self.p_gen(fake_monet, training=True)

            # Monet -> Photo -> Monet
            fake_photo = self.p_gen(real_monet, training=True)
            cycled_monet = self.m_gen(fake_photo, training=True)

            # Identity mappings: each generator fed an image already in its output domain.
            same_monet = self.m_gen(real_monet, training=True)
            same_photo = self.p_gen(real_photo, training=True)

            # Discriminator logits on real and generated images.
            disc_real_monet = self.m_disc(real_monet, training=True)
            disc_fake_monet = self.m_disc(fake_monet, training=True)
            disc_real_photo = self.p_disc(real_photo, training=True)
            disc_fake_photo = self.p_disc(fake_photo, training=True)

            # Adversarial generator losses.
            monet_gen_loss = self.gen_loss_fn(disc_fake_monet)
            photo_gen_loss = self.gen_loss_fn(disc_fake_photo)

            # Cycle-consistency loss, shared by both generator objectives.
            total_cycle_loss = (self.cycle_loss_fn(real_monet, cycled_monet, self.lambda_cycle) +
                                self.cycle_loss_fn(real_photo, cycled_photo, self.lambda_cycle))

            # Full generator objectives: adversarial + cycle + identity.
            total_monet_gen_loss = (monet_gen_loss + total_cycle_loss +
                                    self.identity_loss_fn(real_monet, same_monet, self.lambda_cycle))
            total_photo_gen_loss = (photo_gen_loss + total_cycle_loss +
                                    self.identity_loss_fn(real_photo, same_photo, self.lambda_cycle))

            # Discriminator losses.
            monet_disc_loss = self.disc_loss_fn(disc_real_monet, disc_fake_monet)
            photo_disc_loss = self.disc_loss_fn(disc_real_photo, disc_fake_photo)

        # Gradients. For a non-scalar target, tape.gradient differentiates the sum
        # of its elements; that training behaviour is intentionally unchanged.
        monet_gen_grads = tape.gradient(total_monet_gen_loss, self.m_gen.trainable_variables)
        photo_gen_grads = tape.gradient(total_photo_gen_loss, self.p_gen.trainable_variables)
        monet_disc_grads = tape.gradient(monet_disc_loss, self.m_disc.trainable_variables)
        photo_disc_grads = tape.gradient(photo_disc_loss, self.p_disc.trainable_variables)
        # A persistent tape keeps its resources until it is released.
        del tape

        # Apply gradients.
        self.m_gen_optimizer.apply_gradients(zip(monet_gen_grads, self.m_gen.trainable_variables))
        self.p_gen_optimizer.apply_gradients(zip(photo_gen_grads, self.p_gen.trainable_variables))
        self.m_disc_optimizer.apply_gradients(zip(monet_disc_grads, self.m_disc.trainable_variables))
        self.p_disc_optimizer.apply_gradients(zip(photo_disc_grads, self.p_disc.trainable_variables))

        # Report scalars so logs/History hold plain numbers.
        return {
            'monet_gen_loss': tf.reduce_mean(total_monet_gen_loss),
            'photo_gen_loss': tf.reduce_mean(total_photo_gen_loss),
            'monet_disc_loss': tf.reduce_mean(monet_disc_loss),
            'photo_disc_loss': tf.reduce_mean(photo_disc_loss),
            'total_cycle_loss': tf.reduce_mean(total_cycle_loss),
        }


print("✅ CycleGAN模型类定义完成")


## 10. 创建优化器和编译模型


In [None]:
# One Adam optimiser per sub-model, all with the same hyper-parameters.
print("⚙️ 配置优化器...")
(monet_gen_optimizer,
 photo_gen_optimizer,
 monet_disc_optimizer,
 photo_disc_optimizer) = [tf.keras.optimizers.Adam(LEARNING_RATE, beta_1=BETA_1) for _ in range(4)]

# Wrap the four networks in the CycleGAN training model.
cyclegan = CycleGAN(monet_generator, photo_generator,
                    monet_discriminator, photo_discriminator,
                    lambda_cycle=LAMBDA_CYCLE)

# Attach the optimisers and loss callables used by CycleGAN.train_step.
cyclegan.compile(
    m_gen_optimizer=monet_gen_optimizer,
    p_gen_optimizer=photo_gen_optimizer,
    m_disc_optimizer=monet_disc_optimizer,
    p_disc_optimizer=photo_disc_optimizer,
    gen_loss_fn=generator_loss,
    disc_loss_fn=discriminator_loss,
    cycle_loss_fn=calc_cycle_loss,
    identity_loss_fn=identity_loss,
)

print("✅ 模型编译完成")


## 11. 开始训练

**训练可能需要较长时间，请耐心等待**


In [None]:
# Create a fresh, timestamped output directory for this run.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
save_dir = os.path.join(SAVE_DIR, f"{MODEL_NAME}_{timestamp}")
os.makedirs(save_dir, exist_ok=True)
print(f"模型将保存到: {save_dir}")

# The training dataset repeats indefinitely, so an epoch is defined explicitly as
# one pass over the larger of the two domains.
steps_per_epoch = max(n_monet, n_photo) // BATCH_SIZE
print(f"每轮步数: {steps_per_epoch}")
print(f"总步数: {steps_per_epoch * EPOCHS}")


class LogCallback(tf.keras.callbacks.Callback):
    """Print every CycleGAN loss once per `log_interval` training batches."""

    # (printed label, key in the logs dict produced by CycleGAN.train_step)
    _FIELDS = (
        ("Monet Gen Loss", 'monet_gen_loss'),
        ("Photo Gen Loss", 'photo_gen_loss'),
        ("Monet Disc Loss", 'monet_disc_loss'),
        ("Photo Disc Loss", 'photo_disc_loss'),
        ("Cycle Loss", 'total_cycle_loss'),
    )

    def __init__(self, log_interval=50):
        super().__init__()
        self.log_interval = log_interval

    def on_train_batch_end(self, batch, logs=None):
        if not logs or batch % self.log_interval != 0:
            return
        # Values can arrive as arrays when the losses are unreduced
        # (Reduction.NONE); formatting an ndarray with ':.4f' raises TypeError,
        # so each value is collapsed to its mean first.
        parts = [f"{label}: {float(np.mean(logs[key])):.4f}" for label, key in self._FIELDS]
        print(f"Batch {batch}: " + ", ".join(parts))


callbacks = [
    LogCallback(log_interval=50),
    tf.keras.callbacks.ModelCheckpoint(
        filepath=os.path.join(save_dir, "checkpoint.h5"),
        save_weights_only=True,
        save_best_only=False,
        verbose=1
    )
]

# Train.
print("🚀 开始训练...")
print("=" * 50)

history = cyclegan.fit(
    dataset,
    steps_per_epoch=steps_per_epoch,
    epochs=EPOCHS,
    verbose=1,
    callbacks=callbacks
)

print("✅ 训练完成！")

print("✅ 训练完成！")


## 12. 可视化训练历史


In [None]:
# 绘制训练历史
fig, axes = plt.subplots(2, 2, figsize=(12, 8))

# Generator Loss
axes[0, 0].plot(history.history['monet_gen_loss'], label='Monet Gen')
axes[0, 0].plot(history.history['photo_gen_loss'], label='Photo Gen')
axes[0, 0].set_title('Generator Loss')
axes[0, 0].set_xlabel('Epoch')
axes[0, 0].set_ylabel('Loss')
axes[0, 0].legend()
axes[0, 0].grid(True)

# Discriminator Loss
axes[0, 1].plot(history.history['monet_disc_loss'], label='Monet Disc')
axes[0, 1].plot(history.history['photo_disc_loss'], label='Photo Disc')
axes[0, 1].set_title('Discriminator Loss')
axes[0, 1].set_xlabel('Epoch')
axes[0, 1].set_ylabel('Loss')
axes[0, 1].legend()
axes[0, 1].grid(True)

# Cycle Loss
axes[1, 0].plot(history.history['total_cycle_loss'], label='Cycle Loss')
axes[1, 0].set_title('Cycle Consistency Loss')
axes[1, 0].set_xlabel('Epoch')
axes[1, 0].set_ylabel('Loss')
axes[1, 0].legend()
axes[1, 0].grid(True)

# Combined Loss
axes[1, 1].plot(history.history['monet_gen_loss'], label='Monet Gen')
axes[1, 1].plot(history.history['photo_gen_loss'], label='Photo Gen')
axes[1, 1].plot(history.history['monet_disc_loss'], label='Monet Disc')
axes[1, 1].plot(history.history['photo_disc_loss'], label='Photo Disc')
axes[1, 1].set_title('All Losses')
axes[1, 1].set_xlabel('Epoch')
axes[1, 1].set_ylabel('Loss')
axes[1, 1].legend()
axes[1, 1].grid(True)

plt.tight_layout()
plt.show()


## 13. 显示生成结果


In [None]:
# Show 8 photos next to their Monet-style translations.
print("🎨 显示生成样本...")
fig, axes = plt.subplots(2, 8, figsize=(16, 4))

# Keep only the photo half of each (monet, photo) pair. These come from the
# training pipeline, so they are shuffled and randomly flipped/rotated.
photo_dataset = dataset.map(lambda x, y: y)

for col, photo_batch in enumerate(photo_dataset.take(8)):
    stylised = monet_generator(photo_batch, training=False)
    panels = (
        (axes[0, col], photo_batch[0], f'Original {col+1}'),
        (axes[1, col], stylised[0], f'Generated {col+1}'),
    )
    for ax, img, title in panels:
        # [-1, 1] -> [0, 1] for display.
        ax.imshow(img * 0.5 + 0.5)
        ax.set_title(title)
        ax.axis('off')

plt.suptitle('Photo to Monet Style Transfer')
plt.tight_layout()
plt.show()


## 14. 保存模型


In [None]:
# Save all four networks as full-model HDF5 files in this run's directory.
# NOTE(review): they contain tfa InstanceNormalization layers, so reloading
# presumably needs tensorflow_addons importable / custom_objects — confirm.
print("💾 保存模型...")
for _name, _model in (
    ("monet_generator", monet_generator),
    ("photo_generator", photo_generator),
    ("monet_discriminator", monet_discriminator),
    ("photo_discriminator", photo_discriminator),
):
    _model.save(os.path.join(save_dir, f"{_name}.h5"))

print("✅ 模型保存完成")
print(f"模型保存路径: {save_dir}")


## 15. 生成提交文件


In [None]:
from PIL import Image
import shutil

# Re-read every photo in file order (no shuffle, no augmentation) for inference.
photo_files = tf.io.gfile.glob(os.path.join(PHOTO_TFREC_PATH, "*.tfrec"))
photo_ds_predict = load_dataset(photo_files)
# Predict in batches rather than one image per call. The generator has no
# cross-sample layers (instance norm; dropout is off with training=False), so each
# output image is independent of its batch-mates up to floating-point rounding.
photo_ds_predict = photo_ds_predict.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

# Output directory for the individual JPEGs.
output_dir = os.path.join(save_dir, "submission_images")
os.makedirs(output_dir, exist_ok=True)

print("🎨 生成提交图像...")
count = 0

for photo in photo_ds_predict:
    generated = monet_generator(photo, training=False)

    # tanh output [-1, 1] -> [0, 255]. Round instead of truncating (a bare
    # astype biases pixels downward) and clip so the uint8 cast cannot wrap.
    batch_pixels = np.clip(np.rint(generated.numpy() * 127.5 + 127.5), 0, 255).astype(np.uint8)

    for img_array in batch_pixels:
        # Files are numbered 1..N in dataset order.
        Image.fromarray(img_array).save(os.path.join(output_dir, f'{count+1}.jpg'))
        count += 1

        if count % 100 == 0:
            print(f"已生成 {count} 张图像")

print(f"总共生成了 {count} 张图像")

# Zip the folder contents (images at the archive root) next to the folder itself.
# NOTE(review): the Kaggle Monet GAN competition may require the archive to be
# named images.zip — confirm the required name before submitting.
zip_path = shutil.make_archive(output_dir, 'zip', output_dir)
zip_size = os.path.getsize(zip_path) / (1024 * 1024)

print(f"ZIP文件已创建: {zip_path}")
print(f"文件大小: {zip_size:.2f} MB")

# Check the image count against the 7000-10000 range.
if count < 7000:
    print("⚠️ 警告: 图像数量少于7000张！")
elif count > 10000:
    print("⚠️ 警告: 图像数量超过10000张！")
else:
    print("✅ 图像数量符合要求 (7000-10000张)")

print(f"\n🎉 提交文件已生成: {zip_path}")
print("请将此文件上传到Kaggle竞赛页面")


## 完成！

你的CycleGAN模型已经训练完成，提交文件已生成！

### 📊 训练总结

- ✅ 模型训练完成
- ✅ 训练历史已可视化
- ✅ 生成样本已展示
- ✅ 模型已保存
- ✅ 提交文件已生成

### 🎯 下一步

1. 找到生成的ZIP文件（位于 `saves/cyclegan_monet_<时间戳>/submission_images.zip`）
2. 上传到Kaggle竞赛页面
3. 等待评估结果

### 💡 优化建议

- 如果想获得更好的效果，可以增加训练轮数（EPOCHS）
- 可以尝试不同的超参数组合
- 可以添加更多的数据增强技术

### 📁 文件说明

- `monet_generator.h5`: 训练好的Monet生成器
- `submission_images.zip`: Kaggle提交文件
- `checkpoint.h5`: 训练检查点

**祝你竞赛成功！** 🎨🏆
