In [1]:
# 导入必要的库
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from datetime import datetime
import cv2
from PIL import Image
import shutil
from tqdm import tqdm
import tensorflow as tf  # 仅用于读取TFRecord文件

# 设置随机种子
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)

# 配置matplotlib
plt.rcParams['font.sans-serif'] = ['SimHei', 'Arial Unicode MS']
plt.rcParams['axes.unicode_minus'] = False

print("✅ 库导入成功！")
print(f"PyTorch版本: {torch.__version__}")

# 检查GPU
if torch.cuda.is_available():
    print(f"✅ 检测到 {torch.cuda.device_count()} 个GPU")
    print(f"当前GPU: {torch.cuda.get_device_name(0)}")
    device = torch.device('cuda')
else:
    print("⚠️ 未检测到GPU，将使用CPU（速度会很慢）")
    device = torch.device('cpu')

print(f"使用设备: {device}")


✅ 库导入成功！
PyTorch版本: 2.5.1+cu121
✅ 检测到 1 个GPU
当前GPU: NVIDIA GeForce RTX 5070 Ti
使用设备: cuda


NVIDIA GeForce RTX 5070 Ti with CUDA capability sm_120 is not compatible with the current PyTorch installation.
The current PyTorch install supports CUDA capabilities sm_50 sm_60 sm_61 sm_70 sm_75 sm_80 sm_86 sm_90.
If you want to use the NVIDIA GeForce RTX 5070 Ti GPU with PyTorch, please check the instructions at https://pytorch.org/get-started/locally/



## 2. 配置参数


In [2]:
# ==================== 配置参数 ====================

# 数据路径
DATA_ROOT = "data/Image_Generation_Data_Kaggle"
MONET_TFREC_PATH = os.path.join(DATA_ROOT, "monet_tfrec")
PHOTO_TFREC_PATH = os.path.join(DATA_ROOT, "photo_tfrec")

# 模型参数
IMAGE_SIZE = 256
CHANNELS = 3
LAMBDA_CYCLE = 10.0
LAMBDA_IDENTITY = 0.5

# 训练参数
BATCH_SIZE = 8          # 批次大小（根据显存调整：4/8/16）
EPOCHS = 20             # 训练轮数（推荐：20-50）
LEARNING_RATE = 2e-4    # 学习率
BETA_1 = 0.5            # Adam优化器参数

# 数据增强
USE_AUGMENTATION = True
AUGMENTATION_PROB = 0.5

# 保存设置
SAVE_DIR = "saves"
MODEL_NAME = "pytorch_cyclegan_monet"
SAVE_SAMPLES = True
NUM_SAMPLES_TO_SAVE = 10

# 打印配置
print("=" * 50)
print("训练配置:")
print("=" * 50)
print(f"数据路径: {DATA_ROOT}")
print(f"图像尺寸: {IMAGE_SIZE}x{IMAGE_SIZE}")
print(f"批次大小: {BATCH_SIZE}")
print(f"训练轮数: {EPOCHS}")
print(f"学习率: {LEARNING_RATE}")
print(f"Cycle Loss权重: {LAMBDA_CYCLE}")
print(f"使用数据增强: {USE_AUGMENTATION}")
print(f"使用设备: {device}")
print("=" * 50)


训练配置:
数据路径: data/Image_Generation_Data_Kaggle
图像尺寸: 256x256
批次大小: 8
训练轮数: 20
学习率: 0.0002
Cycle Loss权重: 10.0
使用数据增强: True
使用设备: cuda


## 3. 数据加载


In [3]:
import re

def count_data_items(filenames):
    """计算TFRecord文件中的数据项数量"""
    n = [int(re.compile(r"-([0-9]*)\.").search(filename).group(1)) for filename in filenames]
    return np.sum(n)

def decode_image(image):
    """解码JPEG图像并归一化到[-1, 1]"""
    image = tf.image.decode_jpeg(image, channels=CHANNELS)
    image = (tf.cast(image, tf.float32) / 127.5) - 1
    image = tf.reshape(image, [IMAGE_SIZE, IMAGE_SIZE, CHANNELS])
    return image

def read_tfrecord(example):
    """读取TFRecord示例"""
    tfrecord_format = {
        'image_name': tf.io.FixedLenFeature([], tf.string),
        'image': tf.io.FixedLenFeature([], tf.string),
        'target': tf.io.FixedLenFeature([], tf.string)
    }
    example = tf.io.parse_single_example(example, tfrecord_format)
    image = decode_image(example['image'])
    return image

def load_dataset(filenames):
    """从TFRecord文件加载数据集"""
    dataset = tf.data.TFRecordDataset(filenames)
    dataset = dataset.map(read_tfrecord, num_parallel_calls=tf.data.AUTOTUNE)
    return dataset

# 获取文件列表
monet_files = tf.io.gfile.glob(os.path.join(MONET_TFREC_PATH, "*.tfrec"))
photo_files = tf.io.gfile.glob(os.path.join(PHOTO_TFREC_PATH, "*.tfrec"))

print(f"找到 {len(monet_files)} 个Monet TFRecord文件")
print(f"找到 {len(photo_files)} 个Photo TFRecord文件")

# 计算数据项数量
n_monet = count_data_items(monet_files)
n_photo = count_data_items(photo_files)

print(f"Monet图像数量: {n_monet}")
print(f"Photo图像数量: {n_photo}")

# 加载数据集
monet_ds = load_dataset(monet_files)
photo_ds = load_dataset(photo_files)

print("✅ 数据加载完成")


找到 5 个Monet TFRecord文件
找到 20 个Photo TFRecord文件
Monet图像数量: 300
Photo图像数量: 7038
✅ 数据加载完成


## 4. PyTorch数据集类


In [4]:
class CycleGANDataset(Dataset):
    """CycleGAN数据集类"""
    
    def __init__(self, tf_dataset, transform=None):
        self.tf_dataset = tf_dataset
        self.transform = transform
        self.data = []
        
        # 将TensorFlow数据集转换为列表
        print("正在转换数据集...")
        for item in tqdm(tf_dataset, desc="转换数据"):
            # 转换为numpy数组
            img = item.numpy()
            # 转换为PIL图像
            img = Image.fromarray((img * 127.5 + 127.5).astype(np.uint8))
            self.data.append(img)
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        img = self.data[idx]
        
        if self.transform:
            img = self.transform(img)
        
        return img

# 数据变换
from torchvision import transforms

transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # 归一化到[-1, 1]
])

# 创建PyTorch数据集
print("创建Monet数据集...")
monet_dataset = CycleGANDataset(monet_ds.take(n_monet), transform)
print("创建Photo数据集...")
photo_dataset = CycleGANDataset(photo_ds.take(n_photo), transform)

# 创建数据加载器
monet_loader = DataLoader(monet_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
photo_loader = DataLoader(photo_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

print(f"✅ 数据集创建完成")
print(f"Monet数据集大小: {len(monet_dataset)}")
print(f"Photo数据集大小: {len(photo_dataset)}")


创建Monet数据集...
正在转换数据集...


转换数据: 300it [00:00, 3536.78it/s]


创建Photo数据集...
正在转换数据集...


转换数据: 7038it [00:03, 1891.39it/s]

✅ 数据集创建完成
Monet数据集大小: 300
Photo数据集大小: 7038





## 5. 模型定义


In [5]:
class InstanceNorm2d(nn.Module):
    """Instance Normalization层"""
    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super(InstanceNorm2d, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        
    def forward(self, x):
        # 计算均值和方差
        mean = x.mean(dim=(2, 3), keepdim=True)
        var = x.var(dim=(2, 3), keepdim=True, unbiased=False)
        
        # 归一化
        x = (x - mean) / torch.sqrt(var + self.eps)
        
        # 缩放和偏移
        x = x * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)
        
        return x

class ResidualBlock(nn.Module):
    """残差块"""
    def __init__(self, in_channels):
        super(ResidualBlock, self).__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 3, padding=1),
            InstanceNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, in_channels, 3, padding=1),
            InstanceNorm2d(in_channels)
        )
    
    def forward(self, x):
        return x + self.conv_block(x)

class Generator(nn.Module):
    """生成器（U-Net架构）"""
    def __init__(self, input_channels=3, output_channels=3, num_residual_blocks=6):
        super(Generator, self).__init__()
        
        # 初始卷积层
        self.initial = nn.Sequential(
            nn.Conv2d(input_channels, 64, 7, padding=3),
            InstanceNorm2d(64),
            nn.ReLU(inplace=True)
        )
        
        # 下采样
        self.down1 = nn.Sequential(
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            InstanceNorm2d(128),
            nn.ReLU(inplace=True)
        )
        
        self.down2 = nn.Sequential(
            nn.Conv2d(128, 256, 3, stride=2, padding=1),
            InstanceNorm2d(256),
            nn.ReLU(inplace=True)
        )
        
        # 残差块
        residual_blocks = []
        for _ in range(num_residual_blocks):
            residual_blocks.append(ResidualBlock(256))
        self.residual_blocks = nn.Sequential(*residual_blocks)
        
        # 上采样
        self.up1 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1),
            InstanceNorm2d(128),
            nn.ReLU(inplace=True)
        )
        
        self.up2 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
            InstanceNorm2d(64),
            nn.ReLU(inplace=True)
        )
        
        # 输出层
        self.output = nn.Sequential(
            nn.Conv2d(64, output_channels, 7, padding=3),
            nn.Tanh()
        )
    
    def forward(self, x):
        # 下采样
        d1 = self.initial(x)
        d2 = self.down1(d1)
        d3 = self.down2(d2)
        
        # 残差块
        residual = self.residual_blocks(d3)
        
        # 上采样
        u1 = self.up1(residual)
        u2 = self.up2(u1)
        
        # 输出
        output = self.output(u2)
        
        return output

class Discriminator(nn.Module):
    """判别器（PatchGAN架构）"""
    def __init__(self, input_channels=3):
        super(Discriminator, self).__init__()
        
        self.model = nn.Sequential(
            # 第一层
            nn.Conv2d(input_channels, 64, 4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 第二层
            nn.Conv2d(64, 128, 4, stride=2, padding=1),
            InstanceNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 第三层
            nn.Conv2d(128, 256, 4, stride=2, padding=1),
            InstanceNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 第四层
            nn.Conv2d(256, 512, 4, stride=2, padding=1),
            InstanceNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            
            # 输出层
            nn.Conv2d(512, 1, 4, padding=1)
        )
    
    def forward(self, x):
        return self.model(x)

# 创建模型
print("创建模型...")
monet_generator = Generator().to(device)
photo_generator = Generator().to(device)
monet_discriminator = Discriminator().to(device)
photo_discriminator = Discriminator().to(device)

print(f"生成器参数量: {sum(p.numel() for p in monet_generator.parameters()):,}")
print(f"判别器参数量: {sum(p.numel() for p in monet_discriminator.parameters()):,}")
print("✅ 模型创建完成")


创建模型...
生成器参数量: 7,845,123
判别器参数量: 2,766,529
✅ 模型创建完成


In [6]:
# 损失函数
criterion_gan = nn.MSELoss()
criterion_cycle = nn.L1Loss()
criterion_identity = nn.L1Loss()

# 优化器
optimizer_G = optim.Adam(
    list(monet_generator.parameters()) + list(photo_generator.parameters()),
    lr=LEARNING_RATE, betas=(BETA_1, 0.999)
)

optimizer_D_monet = optim.Adam(monet_discriminator.parameters(), lr=LEARNING_RATE, betas=(BETA_1, 0.999))
optimizer_D_photo = optim.Adam(photo_discriminator.parameters(), lr=LEARNING_RATE, betas=(BETA_1, 0.999))

# 学习率调度器
scheduler_G = optim.lr_scheduler.LambdaLR(optimizer_G, lr_lambda=lambda epoch: 0.95 ** epoch)
scheduler_D_monet = optim.lr_scheduler.LambdaLR(optimizer_D_monet, lr_lambda=lambda epoch: 0.95 ** epoch)
scheduler_D_photo = optim.lr_scheduler.LambdaLR(optimizer_D_photo, lr_lambda=lambda epoch: 0.95 ** epoch)

print("✅ 损失函数和优化器配置完成")


✅ 损失函数和优化器配置完成


## 7. 训练函数


In [7]:
def train_epoch(monet_loader, photo_loader, epoch):
    """训练一个epoch"""
    monet_generator.train()
    photo_generator.train()
    monet_discriminator.train()
    photo_discriminator.train()
    
    total_loss_G = 0
    total_loss_D_monet = 0
    total_loss_D_photo = 0
    
    # 创建迭代器
    monet_iter = iter(monet_loader)
    photo_iter = iter(photo_loader)
    
    # 计算步数
    steps = min(len(monet_loader), len(photo_loader))
    
    for step in tqdm(range(steps), desc=f"Epoch {epoch+1}"):
        try:
            real_monet = next(monet_iter).to(device)
            real_photo = next(photo_iter).to(device)
        except StopIteration:
            break
        
        batch_size = real_monet.size(0)
        
        # 真实和假的标签
        real_label = torch.ones(batch_size, 1, 30, 30).to(device)
        fake_label = torch.zeros(batch_size, 1, 30, 30).to(device)
        
        # ==================== 训练生成器 ====================
        optimizer_G.zero_grad()
        
        # Identity loss
        loss_identity_monet = criterion_identity(monet_generator(real_monet), real_monet)
        loss_identity_photo = criterion_identity(photo_generator(real_photo), real_photo)
        loss_identity = (loss_identity_monet + loss_identity_photo) * LAMBDA_IDENTITY
        
        # GAN loss
        fake_monet = monet_generator(real_photo)
        fake_photo = photo_generator(real_monet)
        
        pred_fake_monet = monet_discriminator(fake_monet)
        pred_fake_photo = photo_discriminator(fake_photo)
        
        loss_G_monet = criterion_gan(pred_fake_monet, real_label)
        loss_G_photo = criterion_gan(pred_fake_photo, real_label)
        
        # Cycle loss
        cycled_monet = monet_generator(fake_photo)
        cycled_photo = photo_generator(fake_monet)
        
        loss_cycle_monet = criterion_cycle(cycled_monet, real_monet)
        loss_cycle_photo = criterion_cycle(cycled_photo, real_photo)
        loss_cycle = (loss_cycle_monet + loss_cycle_photo) * LAMBDA_CYCLE
        
        # 总生成器损失
        loss_G = loss_G_monet + loss_G_photo + loss_cycle + loss_identity
        loss_G.backward()
        optimizer_G.step()
        
        # ==================== 训练判别器 ====================
        # Monet判别器
        optimizer_D_monet.zero_grad()
        
        pred_real_monet = monet_discriminator(real_monet)
        pred_fake_monet = monet_discriminator(fake_monet.detach())
        
        loss_D_monet_real = criterion_gan(pred_real_monet, real_label)
        loss_D_monet_fake = criterion_gan(pred_fake_monet, fake_label)
        loss_D_monet = (loss_D_monet_real + loss_D_monet_fake) * 0.5
        
        loss_D_monet.backward()
        optimizer_D_monet.step()
        
        # Photo判别器
        optimizer_D_photo.zero_grad()
        
        pred_real_photo = photo_discriminator(real_photo)
        pred_fake_photo = photo_discriminator(fake_photo.detach())
        
        loss_D_photo_real = criterion_gan(pred_real_photo, real_label)
        loss_D_photo_fake = criterion_gan(pred_fake_photo, fake_label)
        loss_D_photo = (loss_D_photo_real + loss_D_photo_fake) * 0.5
        
        loss_D_photo.backward()
        optimizer_D_photo.step()
        
        # 累计损失
        total_loss_G += loss_G.item()
        total_loss_D_monet += loss_D_monet.item()
        total_loss_D_photo += loss_D_photo.item()
        
        # 打印进度
        if step % 50 == 0:
            print(f"Step {step}: G_loss={loss_G.item():.4f}, D_monet={loss_D_monet.item():.4f}, D_photo={loss_D_photo.item():.4f}")
    
    return total_loss_G / steps, total_loss_D_monet / steps, total_loss_D_photo / steps

print("✅ 训练函数定义完成")


✅ 训练函数定义完成


## 8. 开始训练


In [8]:
# 创建保存目录
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
save_dir = os.path.join(SAVE_DIR, f"{MODEL_NAME}_{timestamp}")
os.makedirs(save_dir, exist_ok=True)
print(f"模型将保存到: {save_dir}")

# 训练历史
history = {
    'G_loss': [],
    'D_monet_loss': [],
    'D_photo_loss': []
}

print("🚀 开始训练...")
print("=" * 50)

for epoch in range(EPOCHS):
    # 训练一个epoch
    avg_loss_G, avg_loss_D_monet, avg_loss_D_photo = train_epoch(monet_loader, photo_loader, epoch)
    
    # 记录历史
    history['G_loss'].append(avg_loss_G)
    history['D_monet_loss'].append(avg_loss_D_monet)
    history['D_photo_loss'].append(avg_loss_D_photo)
    
    # 更新学习率
    scheduler_G.step()
    scheduler_D_monet.step()
    scheduler_D_photo.step()
    
    # 打印epoch结果
    print(f"Epoch {epoch+1}/{EPOCHS}:")
    print(f"  G_loss: {avg_loss_G:.4f}")
    print(f"  D_monet_loss: {avg_loss_D_monet:.4f}")
    print(f"  D_photo_loss: {avg_loss_D_photo:.4f}")
    print(f"  Learning Rate: {scheduler_G.get_last_lr()[0]:.6f}")
    print("-" * 50)
    
    # 保存checkpoint
    if (epoch + 1) % 5 == 0:
        checkpoint = {
            'epoch': epoch,
            'monet_generator': monet_generator.state_dict(),
            'photo_generator': photo_generator.state_dict(),
            'monet_discriminator': monet_discriminator.state_dict(),
            'photo_discriminator': photo_discriminator.state_dict(),
            'optimizer_G': optimizer_G.state_dict(),
            'optimizer_D_monet': optimizer_D_monet.state_dict(),
            'optimizer_D_photo': optimizer_D_photo.state_dict(),
            'history': history
        }
        torch.save(checkpoint, os.path.join(save_dir, f'checkpoint_epoch_{epoch+1}.pth'))

print("✅ 训练完成！")


模型将保存到: saves\pytorch_cyclegan_monet_20251026_203754
🚀 开始训练...


Epoch 1:   0%|                                                                                  | 0/38 [00:00<?, ?it/s]


RuntimeError: CUDA error: no kernel image is available for execution on the device
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
