### 模型定义

In [8]:
!pip install -q bitsandbytes

In [9]:
from diffusers import UNet2DConditionModel
import torch
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# 定义U-Net模型
UNet = UNet2DConditionModel(
    sample_size=64,
    in_channels=4,
    out_channels=4, 
    layers_per_block=2, 
    block_out_channels=(320, 640, 1280, 1280),
    cross_attention_dim=768
).to(DEVICE)

### 数据预处理方法定义

In [10]:
from transformers import CLIPTokenizer
from torchvision import transforms

# 数据预处理
transform = transforms.Compose([
    transforms.Resize(512, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.CenterCrop(512),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])


# 加载Text Encoder
tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="tokenizer")

### 噪声调度器定义

In [11]:
from diffusers import DDPMScheduler
# 定义噪声调度器
noise_scheduler = DDPMScheduler(
    num_train_timesteps=1000,  # 定义整个扩散过程被切分成了多少时间步
    beta_start=0.00085,
    beta_end=0.012,         # 这两个参数定义了噪声方差序列的起点和终点(接近起点加的噪声量小,接近终点加的噪声量大)
    beta_schedule="scaled_linear" # 噪声beta增长的策略,此处为线性增长
)

### 处理数据集

In [12]:
import os
from datasets import load_dataset
import datasets

# 1. 自动定位那个名字很长的文件
DATA_DIR = "/kaggle/input/pokemon-captions-dataset"
# 找到目录下那个 train-xxxxx 开头的文件
parquet_file = [f for f in os.listdir(DATA_DIR) if f.startswith("train")][0]
full_path = os.path.join(DATA_DIR, parquet_file)

print(f"检测到数据文件: {full_path}")

# 2. 使用 "parquet" 引擎加载 (关键修改点)
# 我们不传 data_dir，而是直接传 data_files
dataset = load_dataset("parquet", data_files={"train": full_path}, split="train")

IMG_COL = "image" 
TXT_COL = "en_text"  

dataset = dataset.cast_column(IMG_COL, datasets.Image())
# 3. 检查一下列名 (防止列名不叫 image 或 text)
print(">>> 数据集列名:", dataset.column_names)
# 这一步很重要！等下运行完看看 print 出来的是不 ['image', 'text']
# 如果不是，你需要修改 preprocess_data 函数里的 key

# --- 接下来的预处理代码 (保持原逻辑，但注意检查列名) ---
# 假设列名是 'image' 和 'text' (如果打印出来不一样，请在这里修改)


def preprocess_data(examples):
    # 处理图片
    images = [image.convert("RGB") for image in examples[IMG_COL]]
    examples["pixel_values"] = [transform(image) for image in images]
    
    # 处理文本
    text_inputs = tokenizer(
        examples[TXT_COL], 
        padding="max_length", 
        truncation=True, 
        max_length=tokenizer.model_max_length,
        return_tensors="pt"
    )
    examples["input_ids"] = text_inputs.input_ids
    return examples

# 应用处理
train_dataset = dataset.map(preprocess_data, batched=True, remove_columns=dataset.column_names)

检测到数据文件: /kaggle/input/pokemon-captions-dataset/train-00000-of-00001-b976ba0e28fc7cf1.parquet
>>> 数据集列名: ['image', 'en_text', 'ja_text']


In [13]:
# 验证函数:定性验证
from diffusers import StableDiffusionPipeline
import matplotlib.pyplot as plt

# 定义一个验证函数
def log_validation(unet, vae, text_encoder, tokenizer, args, epoch, step):
    print(f"\n>>> 正在生成验证图片 (Epoch {epoch} Step {step})...")
    
    # 1. 创建临时的 Pipeline
    # 我们复用已经加载到显存里的组件，不需要重新下载
    pipeline = StableDiffusionPipeline(
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        scheduler=noise_scheduler,
        safety_checker=None, # 验证时不需要安全检查器，省显存
        feature_extractor=None,
        requires_safety_checker=False
    )
    
    # 确保 Pipeline 在正确的设备上
    pipeline.set_progress_bar_config(disable=True) # 关闭进度条防止刷屏

    # 2. 定义验证用的提示词 (固定下来，方便对比)
    validation_prompts = [
        "a cute green pokemon with large eyes",
        "a red fire type pokemon",
        "drawing of a blue bird pokemon"
    ]
    
    # 3. 生成图片
    images = []
    # 开启混合精度推理 (跟训练保持一致)
    with torch.cuda.amp.autocast():
        for prompt in validation_prompts:
            # generator参数用于固定随机种子，保证每次生成的随机性一致，只看模型的变化
            generator = torch.Generator(device=DEVICE).manual_seed(42)
            image = pipeline(prompt, num_inference_steps=30, generator=generator).images[0]
            images.append(image)
            
            # 保存到硬盘
            save_path = f"/kaggle/working/val_epoch_{epoch}_step_{step}_{prompt[:10].replace(' ', '_')}.png"
            image.save(save_path)
    
    print(f">>> 验证图片已保存到 /kaggle/working/ 目录")
    
    # (可选) 如果你想在 Notebook 里直接看图，可以用 matplotlib
    # 注意：这可能会稍微阻塞一下训练
    # fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    # for i, img in enumerate(images):
    #     axs[i].imshow(img)
    #     axs[i].set_title(validation_prompts[i])
    #     axs[i].axis('off')
    # plt.show()
    
    # 4. 清理内存 (很重要！防止验证完显存没释放导致训练OOM)
    del pipeline
    torch.cuda.empty_cache()

# --- 在训练循环中调用 ---
# 在你的 for step, batch in enumerate(train_dataloader): 循环内部
# 找个地方插入：

# 例如在打印 Loss 的地方：
# if step % 500 == 0:  <-- 建议每 500 步或者每个 Epoch 结束测一次
#     log_validation(model, vae, text_encoder, tokenizer, None, epoch, step)

### 训练代码

In [None]:
from datasets import load_dataset
import torch.nn.functional as F
import torch
from PIL import Image
from torch.utils.data import DataLoader
from diffusers import AutoencoderKL
from transformers import CLIPTextModel
import torch.nn as nn
import bitsandbytes as bnb
# --- 引入混合精度工具 ---
from torch.cuda.amp import autocast, GradScaler

# 超参数设置
BATCH_SIZE = 2
EPOCHS = 40
LEARNING_RATE = 1e-5


model = UNet
if torch.cuda.device_count() > 1:
    print(f"检测到{torch.cuda.device_count()}张卡,开启并行加速")
    model = nn.DataParallel(model)

model.train()
vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae").to(DEVICE)
# vae.half()
vae.requires_grad_(False)  # 冻结VAE参数
text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="text_encoder").to(DEVICE)
# text_encoder.half()
text_encoder.requires_grad_(False) # 冻结Text Encoder参数


def collate_fn(datas):
    '''
    输入是字典列表,列表中元素的个数就是bs,现在就是要进行一个汇总,将列表中的每个元素的相同key对应的value进行堆叠
    '''
    pixel_values = torch.stack([torch.tensor(data["pixel_values"]) for data in datas])
    input_ids = torch.stack([torch.tensor(data["input_ids"]) for data in datas])
    return {"pixel_values": pixel_values, "input_ids": input_ids}

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
# 优化器
optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=LEARNING_RATE)

# 初始化缩放器
scaler = GradScaler()
# 训练循环
print("开始训练!")
for epoch in range(EPOCHS):
    for step, batch in enumerate(train_dataloader):
        # 阶段A:准备Latents和text embeddings
        # 1. 准备图像压缩(Latents)和text embeddings
        clean_images = batch["pixel_values"].to(DEVICE)
        text_input_ids = batch["input_ids"].to(DEVICE)
        with torch.no_grad():
            latents = vae.encode(clean_images).latent_dist.sample() * 0.18215
            encoder_hidden_states = text_encoder(text_input_ids)[0]
             
        # 阶段B:训练U-Net
        # 2. 准备增加噪声
        # 随机采样噪声
        noise = torch.randn_like(latents)
        bsz = latents.shape[0]

        # 3. 随机采样时间步
        timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=DEVICE).long()
        # 4. 前向加噪
        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

        with autocast():
            # 5. 模型预测:UNet尝试从noisy latents中预测出噪声
            noise_pred = model(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
            loss = F.mse_loss(noise_pred, noise)

        # 6. 反向传播 (使用 scaler)
        scaler.scale(loss).backward() # <--- 缩放 Loss 防止溢出
        scaler.step(optimizer)        # <--- 更新权重
        scaler.update()               # <--- 更新缩放器
        if step % 10 == 0:
            print(f"Epoch: {epoch} | Step: {step} | Loss: {loss.item():.4f}")
        if step % 50 == 0:
            log_validation(model, vae, text_encoder, tokenizer, None, epoch, step)

    print(f"Epoch: {epoch} completed.")
# --- 保存模型 ---
print("正在保存模型...")
model.save_pretrained("./my-trained-sd-pokemon")
print("模型已保存到 ./my-trained-sd-pokemon 文件夹")
print("Training finished.")

  scaler = GradScaler()


开始训练!


  with autocast():


Epoch: 0 | Step: 0 | Loss: 1.0939

>>> 正在生成验证图片 (Epoch 0 Step 0)...


  with torch.cuda.amp.autocast():


>>> 验证图片已保存到 /kaggle/working/ 目录
Epoch: 0 | Step: 10 | Loss: 0.8117
Epoch: 0 | Step: 20 | Loss: 0.3546
Epoch: 0 | Step: 30 | Loss: 0.2214
Epoch: 0 | Step: 40 | Loss: 0.3905
Epoch: 0 | Step: 50 | Loss: 0.4565

>>> 正在生成验证图片 (Epoch 0 Step 50)...
>>> 验证图片已保存到 /kaggle/working/ 目录
Epoch: 0 | Step: 60 | Loss: 0.2655
Epoch: 0 | Step: 70 | Loss: 0.1042
Epoch: 0 | Step: 80 | Loss: 0.1060
Epoch: 0 | Step: 90 | Loss: 0.1168
Epoch: 0 | Step: 100 | Loss: 0.3183

>>> 正在生成验证图片 (Epoch 0 Step 100)...
>>> 验证图片已保存到 /kaggle/working/ 目录
Epoch: 0 | Step: 110 | Loss: 0.3465
Epoch: 0 | Step: 120 | Loss: 0.2227
Epoch: 0 | Step: 130 | Loss: 0.3719
Epoch: 0 | Step: 140 | Loss: 0.0627
Epoch: 0 | Step: 150 | Loss: 0.0956

>>> 正在生成验证图片 (Epoch 0 Step 150)...
>>> 验证图片已保存到 /kaggle/working/ 目录
Epoch: 0 | Step: 160 | Loss: 0.0602
Epoch: 0 | Step: 170 | Loss: 0.5731
Epoch: 0 | Step: 180 | Loss: 0.1076
Epoch: 0 | Step: 190 | Loss: 0.2011
Epoch: 0 | Step: 200 | Loss: 0.3802

>>> 正在生成验证图片 (Epoch 0 Step 200)...
>>> 验证图片已保存到 /