# CycleGAN Monet风格转换 - 完整训练Notebook

本notebook用于在Jupyter Lab中训练CycleGAN模型，将照片转换为Monet风格的画作。

## 使用说明：
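1. 确保数据已放在 `root_path` 指定的目录下（需包含 `gan-getting-started/monet_tfrec/` 和 `gan-getting-started/photo_tfrec/`）
2. 在第2节中修改配置参数（至少要把 `root_path` 改为你的实际数据路径）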
3. 按顺序运行所有cells
4. 训练完成后生成提交文件


## 1. 导入必要的库


In [None]:
# Import required libraries.
import os

# Silence TensorFlow's native (C++) log output. The variable must already be in
# the environment when TensorFlow is imported to reliably take effect, so it is
# set here rather than after `import tensorflow`.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

import sys
import time
import random
import logging
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from datetime import datetime

# Seed the Python, NumPy and TensorFlow RNGs for reproducibility.
# NOTE(review): Config.seed in section 2 holds the same value but is not read by
# any cell shown here; keep them in sync if either is changed.
seed = 2021
random.seed(seed)
np.random.seed(seed)
tf.random.set_seed(seed)

# Make the project package importable (path is relative to the working directory).
sys.path.append('Something-of-a-Painter')

# Project modules.
from models import CycleGan
from layers.Augmentations import DiffAugment, DataAugment
from data_provider.data_factory import data_provider, get_gan_dataset
from utils.metrics import (
    discriminator_loss, generator_loss, calc_cycle_loss, identity_loss
)
from utils.tools import (
    display_samples, display_augmented_samples, display_generated_samples,
    predict_and_save, LogCallback, ClearMemory
)

# Let matplotlib render Chinese labels and the minus sign with a CJK font.
plt.rcParams['font.sans-serif'] = ['SimHei', 'Arial Unicode MS']
plt.rcParams['axes.unicode_minus'] = False

print("✓ 所有库导入成功！")
print(f"TensorFlow 版本: {tf.__version__}")
print(f"Python 版本: {sys.version}")

# Report available GPUs.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    print(f"✓ 检测到 {len(gpus)} 个GPU")
    for i, gpu in enumerate(gpus):
        print(f"  GPU {i}: {gpu}")
else:
    print("⚠ 未检测到GPU，将使用CPU（速度会很慢）")


## 2. 配置参数

**请根据你的实际情况修改以下参数！**


In [None]:
# ==================== Configuration ====================
class Config:
    """Training configuration, exposed as plain class attributes."""

    # ---------- Data location ----------
    # Point this at your data root; it must contain
    # gan-getting-started/monet_tfrec/ and gan-getting-started/photo_tfrec/.
    # Example: root_path = 'D:/data/kaggle/'
    root_path: str = '../../data/'

    # ---------- Training ----------
    batch_size: int = 8            # tune to GPU memory: 4 / 8 / 16
    train_epochs: int = 20         # 20-50 recommended
    steps_per_epoch: int = -1      # -1 => derived from dataset sizes when the data is loaded
    learning_rate: float = 2e-4

    # ---------- Model ----------
    height: int = 256              # image height
    width: int = 256               # image width
    channels: int = 3              # input channels (RGB)
    out_channels: int = 3          # output channels
    lambda_cycle: int = 10         # weight passed to the cycle- and identity-loss functions

    # ---------- Augmentation ----------
    ds_augment: bool = False       # dataset-level augmentation on/off
    diffaugment: list = ['color', 'translation', 'cutout']  # DiffAugment policies; empty list disables it

    # ---------- Misc ----------
    model_id: str = 'EXP'          # prefix of the save folder name
    seed: int = 2021               # NOTE(review): not read by any cell shown; the seed in effect is set in section 1
    use_wandb: bool = False        # requires wandb to be installed

    # ---------- Derived (normally left alone) ----------
    checkpoints: str = 'checkpoints'
    auto = tf.data.experimental.AUTOTUNE

# Instantiate the configuration used by every later cell.
args = Config()

# Sanity-check the data path before any loading happens: warn if it is still
# the placeholder default, or if the directory does not exist at all.
if args.root_path == '../../data/':
    print("⚠ 警告：请修改 root_path 为你实际的数据路径！")
if not os.path.isdir(args.root_path):
    print(f"⚠ 警告：数据路径不存在: {args.root_path}")

# Print the configuration summary.
print("=" * 60)
print("训练配置:")
print("=" * 60)
print(f"数据路径: {args.root_path}")
print(f"批次大小: {args.batch_size}")
print(f"训练轮数: {args.train_epochs}")
print(f"学习率: {args.learning_rate}")
print(f"使用数据集增强: {args.ds_augment}")
print(f"使用DiffAugmentation: {args.diffaugment}")
print("=" * 60)


## 3. 加载数据


In [None]:
import data_provider.data_factory as data_factory

# The data-loading functions take their image geometry and the tf.data
# AUTOTUNE value from module-level globals, so push them in from the config.
for _attr, _value in (
    ('HEIGHT', args.height),
    ('WIDTH', args.width),
    ('CHANNELS', args.channels),
    ('AUTO', args.auto),
):
    setattr(data_factory, _attr, _value)

# Logging: INFO level, timestamped messages.
logging.basicConfig(format='%(asctime)s - %(message)s', level=logging.INFO)
logger = logging.getLogger(__name__)

# Minimal argument object handed to data_provider.
class Args:
    """Argument namespace for ``data_provider``.

    Copies only the data-loading fields from a Config-like object.
    """

    _FIELDS = ('height', 'width', 'channels', 'auto', 'root_path', 'batch_size')

    def __init__(self, config):
        for field_name in self._FIELDS:
            setattr(self, field_name, getattr(config, field_name))

args_obj = Args(args)

# Load the datasets. Every later cell depends on gan_ds / monet_ds / photo_ds,
# so a failure is reported with a hint and then re-raised instead of being
# swallowed (which would only surface later as a confusing NameError).
print("正在加载数据...")
try:
    gan_ds, (n_monet, monet_ds), (n_photo, photo_ds) = data_provider(args_obj, logger)
except Exception as e:
    print(f"✗ 数据加载失败: {e}")
    print("请检查数据路径是否正确！")
    print(f"当前路径: {args.root_path}")
    raise

print(f"\n✓ 数据加载成功！")
print(f"Monet图像数量: {n_monet}")
print(f"Photo图像数量: {n_photo}")

# Derive steps_per_epoch when left at -1: a quarter of a pass over the larger
# domain per epoch, clamped to at least 1 so fit() never receives 0 steps
# (possible with a tiny dataset or a large batch size).
if args.steps_per_epoch == -1:
    args.steps_per_epoch = max(1, max(n_monet, n_photo) // args.batch_size // 4)
    print(f"自动设置 steps_per_epoch: {args.steps_per_epoch}")
else:
    print(f"使用手动设置 steps_per_epoch: {args.steps_per_epoch}")


## 4. 可视化数据样本

让我们看看数据的样本，确保数据加载正确。


In [None]:
# Show a 2x4 grid from each domain to confirm the data loaded correctly.
for _header, _label, _dataset in (
    ("Monet 样本:", "Monet", monet_ds),
    ("\nPhoto 样本:", "Photo", photo_ds),
):
    print(_header)
    display_samples(
        path=None,  # render inline in the notebook instead of saving to a file
        name=_label,
        ds=_dataset.batch(1),
        row=2,
        col=4,
    )


## 5. 构建模型


In [None]:
# Argument object used to initialise CycleGan.Model.
class ModelArgs:
    """Argument namespace for ``CycleGan.Model``.

    Copies geometry and augmentation settings from a Config-like object and
    fills in the fixed model options used by this notebook.
    """

    def __init__(self, config):
        # Fields copied verbatim from the config.
        for attr in ('height', 'width', 'channels', 'out_channels', 'ds_augment'):
            setattr(self, attr, getattr(config, attr))
        # Fixed options for this notebook.
        self.cycle_noise = 0
        self.model = 'CycleGan'
        self.transformer_blocks = 6
        # The config names it `use_wandb`; the model reads `wandb`.
        self.wandb = config.use_wandb

model_args = ModelArgs(args)

# Dataset-level augmentation layer (applied in train_step only when
# ds_augment is enabled).
dsaug_layer = DataAugment(args.height, args.width, args.channels)

# Build the networks from ONE CycleGan.Model instance. Previously the class was
# constructed four times and a single network kept from each, so most of the
# built networks were thrown away.
# NOTE(review): assumes CycleGan.Model creates all four networks on
# construction; confirm in models/CycleGan.py.
print("正在创建模型...")
cyclegan_builder = CycleGan.Model(model_args)
monet_generator = cyclegan_builder.m_gen  # Monet generator (photo -> Monet)
photo_generator = cyclegan_builder.p_gen  # photo generator (Monet -> photo)
monet_discriminator = cyclegan_builder.m_disc  # Monet discriminator
photo_discriminator = cyclegan_builder.p_disc  # photo discriminator

print("✓ 模型创建成功！")
print(f"\n生成器参数量: {monet_generator.count_params():,}")
print(f"判别器参数量: {monet_discriminator.count_params():,}")


## 6. 配置优化器和损失函数


In [None]:
from tensorflow.keras import optimizers

# One Adam optimizer per network, all with the configured learning rate and
# beta_1=0.5 (a common GAN setting). Created in the same order as the names.
(
    monet_generator_optimizer,
    photo_generator_optimizer,
    monet_discriminator_optimizer,
    photo_discriminator_optimizer,
) = (optimizers.Adam(args.learning_rate, beta_1=0.5) for _ in range(4))

print("✓ 优化器配置完成")
print(f"学习率: {args.learning_rate}")
print(f"优化器: Adam (beta_1=0.5)")


## 7. 创建CycleGAN模型并编译


In [None]:
# CycleGAN training model.
class CycleGanModel(tf.keras.Model):
    """CycleGAN training wrapper around two generators and two discriminators.

    ``m_gen`` maps photos to Monet-style images and ``p_gen`` maps Monet images
    to photos. ``m_disc`` and ``p_disc`` score images of each domain. Training
    is driven by Keras ``fit`` through the custom ``train_step`` below.

    Args:
        dsaug_layer: callable applied to both real batches when dataset-level
            augmentation is enabled.
        m_gen, p_gen: Monet and photo generators.
        m_disc, p_disc: Monet and photo discriminators.
        lambda_cycle: weight passed to the cycle- and identity-loss functions.
        ds_augment: whether ``train_step`` applies ``dsaug_layer``. ``None``
            (the default) falls back to the notebook-level ``args.ds_augment``,
            which is what earlier versions always read.
    """

    def __init__(self, dsaug_layer, m_gen, p_gen, m_disc, p_disc, lambda_cycle=10,
                 ds_augment=None):
        super(CycleGanModel, self).__init__()
        self.dsaug_layer = dsaug_layer
        self.m_gen = m_gen  # Monet generator (photo -> Monet)
        self.p_gen = p_gen  # photo generator (Monet -> photo)
        self.m_disc = m_disc  # Monet discriminator
        self.p_disc = p_disc  # photo discriminator
        self.lambda_cycle = lambda_cycle
        self.ds_augment = ds_augment

    def compile(self, m_gen_optimizer, p_gen_optimizer, m_disc_optimizer, p_disc_optimizer,
                gen_loss_fn, disc_loss_fn, cycle_loss_fn, identity_loss_fn, diffaugment):
        """Store the optimizers, loss functions and DiffAugment policy list.

        ``diffaugment`` is forwarded to ``DiffAugment``; an empty list (or
        ``None``) disables it.
        """
        super(CycleGanModel, self).compile()
        self.m_gen_optimizer = m_gen_optimizer
        self.p_gen_optimizer = p_gen_optimizer
        self.m_disc_optimizer = m_disc_optimizer
        self.p_disc_optimizer = p_disc_optimizer
        self.gen_loss_fn = gen_loss_fn
        self.disc_loss_fn = disc_loss_fn
        self.cycle_loss_fn = cycle_loss_fn
        self.identity_loss_fn = identity_loss_fn
        self.diffaugment = diffaugment

    def train_step(self, batch_data):
        """Run one optimisation step for all four networks.

        Args:
            batch_data: a ``(real_monet, real_photo)`` pair of image batches.

        Returns:
            dict of this step's losses; ``fit`` reports them and records them
            in ``history`` under the same keys.
        """
        real_monet, real_photo = batch_data

        # Optional dataset-level augmentation of the real batches.
        use_ds_augment = args.ds_augment if self.ds_augment is None else self.ds_augment
        if use_ds_augment:
            real_monet = self.dsaug_layer(real_monet)
            real_photo = self.dsaug_layer(real_photo)

        batch_size = tf.shape(real_monet)[0]

        with tf.GradientTape(persistent=True) as tape:
            # Photo -> Monet -> Photo
            fake_monet = self.m_gen(real_photo, training=True)
            cycled_photo = self.p_gen(fake_monet, training=True)

            # Monet -> Photo -> Monet
            fake_photo = self.p_gen(real_monet, training=True)
            cycled_monet = self.m_gen(fake_photo, training=True)

            # Identity mapping
            same_monet = self.m_gen(real_monet, training=True)
            same_photo = self.p_gen(real_photo, training=True)

            # DiffAugment feeds the Monet discriminator only. The augmented
            # tensors get their own names so that the cycle and identity losses
            # below keep comparing the un-augmented real batch with its
            # reconstruction; overwriting real_monet here would make those
            # losses compare augmented targets against reconstructions of the
            # un-augmented input.
            if self.diffaugment:
                both_monet = tf.concat([real_monet, fake_monet], axis=0)
                aug_monet = DiffAugment(both_monet, self.diffaugment)
                disc_input_real_monet = aug_monet[:batch_size]
                disc_input_fake_monet = aug_monet[batch_size:]
            else:
                disc_input_real_monet = real_monet
                disc_input_fake_monet = fake_monet

            # Discriminator outputs
            disc_real_monet = self.m_disc(disc_input_real_monet, training=True)
            disc_fake_monet = self.m_disc(disc_input_fake_monet, training=True)
            disc_real_photo = self.p_disc(real_photo, training=True)
            disc_fake_photo = self.p_disc(fake_photo, training=True)

            # Adversarial losses for the generators
            monet_gen_loss = self.gen_loss_fn(disc_fake_monet)
            photo_gen_loss = self.gen_loss_fn(disc_fake_photo)

            # Cycle-consistency loss, shared by both generators
            total_cycle_loss = (self.cycle_loss_fn(real_monet, cycled_monet, self.lambda_cycle) +
                                self.cycle_loss_fn(real_photo, cycled_photo, self.lambda_cycle))

            total_monet_gen_loss = (monet_gen_loss + total_cycle_loss +
                                    self.identity_loss_fn(real_monet, same_monet, self.lambda_cycle))
            total_photo_gen_loss = (photo_gen_loss + total_cycle_loss +
                                    self.identity_loss_fn(real_photo, same_photo, self.lambda_cycle))

            # Discriminator losses
            monet_disc_loss = self.disc_loss_fn(disc_real_monet, disc_fake_monet)
            photo_disc_loss = self.disc_loss_fn(disc_real_photo, disc_fake_photo)

        # Gradients for each network from the shared persistent tape.
        monet_gen_grads = tape.gradient(total_monet_gen_loss, self.m_gen.trainable_variables)
        photo_gen_grads = tape.gradient(total_photo_gen_loss, self.p_gen.trainable_variables)
        monet_disc_grads = tape.gradient(monet_disc_loss, self.m_disc.trainable_variables)
        photo_disc_grads = tape.gradient(photo_disc_loss, self.p_disc.trainable_variables)
        # A persistent tape holds its resources until the reference is dropped.
        del tape

        self.m_gen_optimizer.apply_gradients(zip(monet_gen_grads, self.m_gen.trainable_variables))
        self.p_gen_optimizer.apply_gradients(zip(photo_gen_grads, self.p_gen.trainable_variables))
        self.m_disc_optimizer.apply_gradients(zip(monet_disc_grads, self.m_disc.trainable_variables))
        self.p_disc_optimizer.apply_gradients(zip(photo_disc_grads, self.p_disc.trainable_variables))

        return {
            'monet_gen_loss': total_monet_gen_loss,
            'photo_gen_loss': total_photo_gen_loss,
            'monet_disc_loss': monet_disc_loss,
            'photo_disc_loss': photo_disc_loss,
            'total_cycle_loss': total_cycle_loss
        }

# Wrap the four networks in the trainable CycleGAN model.
gan_model = CycleGanModel(
    dsaug_layer=dsaug_layer,
    m_gen=monet_generator,
    p_gen=photo_generator,
    m_disc=monet_discriminator,
    p_disc=photo_discriminator,
    lambda_cycle=args.lambda_cycle,
)

# Attach optimizers, loss functions and the DiffAugment policy list.
gan_model.compile(
    # optimizers, one per network
    m_gen_optimizer=monet_generator_optimizer,
    p_gen_optimizer=photo_generator_optimizer,
    m_disc_optimizer=monet_discriminator_optimizer,
    p_disc_optimizer=photo_discriminator_optimizer,
    # loss functions from utils.metrics
    gen_loss_fn=generator_loss,
    disc_loss_fn=discriminator_loss,
    cycle_loss_fn=calc_cycle_loss,
    identity_loss_fn=identity_loss,
    # augmentation policy
    diffaugment=args.diffaugment,
)

print("✓ 模型编译完成！")


## 8. 开始训练

训练可能需要较长时间，请耐心等待。可以观察loss的变化来判断训练是否正常。


In [None]:
# Timestamped save directory for this run.
save_dir = "/".join([
    "Something-of-a-Painter",
    "saves",
    f"{args.model_id}-{datetime.now().strftime('%Y%m%d-%H%M%S')}",
])
os.makedirs(save_dir, exist_ok=True)
print(f"模型将保存到: {save_dir}")

# Train.
print("\n开始训练...")
print("=" * 60)

history = gan_model.fit(
    gan_ds,
    epochs=args.train_epochs,
    steps_per_epoch=args.steps_per_epoch,
    callbacks=[
        LogCallback(logger, log_interval=20),  # log every 20 steps
        ClearMemory(logger),                   # memory-cleanup callback from utils.tools
    ],
    verbose=1,
)

print("\n✓ 训练完成！")
print("=" * 60)
print("=" * 60)


## 9. 可视化训练损失

查看训练过程中的损失变化。


In [None]:
# Plot the per-epoch training losses recorded by fit(), one panel per group.
_loss_panels = (
    ('Generator Loss', (('monet_gen_loss', 'Monet Gen Loss'),
                        ('photo_gen_loss', 'Photo Gen Loss'))),
    ('Discriminator Loss', (('monet_disc_loss', 'Monet Disc Loss'),
                            ('photo_disc_loss', 'Photo Disc Loss'))),
    ('Cycle Consistency Loss', (('total_cycle_loss', 'Cycle Loss'),)),
)

plt.figure(figsize=(15, 5))
for _panel_idx, (_panel_title, _curves) in enumerate(_loss_panels, start=1):
    plt.subplot(1, 3, _panel_idx)
    for _metric, _curve_label in _curves:
        plt.plot(history.history[_metric], label=_curve_label)
    plt.title(_panel_title)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)

plt.tight_layout()
plt.show()


## 10. 可视化生成结果

让我们看看模型生成的Monet风格图像。


In [None]:
# Translate a few photos with the Monet generator and show them inline.
print("生成Monet风格图像样本...")
display_generated_samples(
    model=monet_generator,
    ds=photo_ds.batch(1),
    n_samples=6,
    path=None,  # display in the notebook rather than writing to disk
)


## 11. 保存模型


In [None]:
# Export the trained Monet generator.
# NOTE(review): under Keras 3, Model.save expects a ".keras" path and
# save_weights a ".weights.h5" path; confirm against the installed TF/Keras version.
model_save_path = os.path.join(save_dir, "model_m_gen")
monet_generator.save(model_save_path)
print(f"✓ 模型已保存到: {model_save_path}")

# Also checkpoint the weights of the whole CycleGanModel wrapper (which holds
# the four networks as attributes).
checkpoint_path = os.path.join(save_dir, "checkpoints")
os.makedirs(checkpoint_path, exist_ok=True)
gan_model.save_weights(
    os.path.join(checkpoint_path, "cp.ckpt")
)
print(f"✓ Checkpoint已保存到: {checkpoint_path}")


## 12. 生成提交文件

为Kaggle竞赛生成提交文件：用Monet生成器转换所有照片，并将生成的图像打包为 `images.zip`（竞赛要求7000-10000张图像）。


In [None]:
import shutil

# Output directory for the generated submission images.
images_dir = os.path.join(save_dir, "images")
os.makedirs(images_dir, exist_ok=True)

# Translate every photo with the Monet generator and write the results.
print("正在生成提交图像...")
predict_and_save(
    path=save_dir,
    input_ds=photo_ds.batch(1),
    generator_model=monet_generator
)

# Count the .jpg files that ended up in images_dir.
num_files = sum(1 for fname in os.listdir(images_dir) if fname.endswith('.jpg'))
print(f"\n✓ 已生成 {num_files} 张图像")

# Zip the image folder (creates <save_dir>/images.zip with the images at its root).
zip_path = shutil.make_archive(images_dir, 'zip', images_dir)
zip_size = os.path.getsize(zip_path) / (1024 * 1024)  # bytes -> MB

print(f"\n✓ ZIP文件已创建: {zip_path}")
print(f"文件大小: {zip_size:.2f} MB")

# Check the image count against the 7000-10000 range this notebook requires.
if 7000 <= num_files <= 10000:
    print("✓ 图像数量符合要求 (7000-10000张)")
elif num_files < 7000:
    print("⚠ 警告: 图像数量少于7000张！")
else:
    print("⚠ 警告: 图像数量超过10000张！")


## 完成！

你的模型已经训练完成，提交文件已生成！

### 下一步：
1. 找到生成的ZIP文件：`Something-of-a-Painter/saves/<model_id>-<时间戳>/images.zip`
2. 上传到Kaggle竞赛页面
3. 等待评估结果

### 提示：
- 如果想训练更长时间以获得更好的效果，可以增加 `train_epochs` 参数
- 可以尝试不同的数据增强配置
- 记得保存这个notebook和配置
