In [None]:
import torch
from diffusers import StableDiffusionPipeline
import torchvision
from datasets import load_dataset
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from tqdm import tqdm

In [None]:
import pytorch_lightning as pl

In [None]:
# Seed torch's global RNG so noise sampling, data shuffling and augmentation are
# reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

# Prefer the GPU but fall back to the CPU, so the notebook still runs (slowly) on
# machines without CUDA instead of failing on the first `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:

# Training resolution: 32x32 squares (larger sizes work too, at higher cost).
image_size = 32
# Reduce this if the GPU runs out of memory.
batch_size = 64

# Smithsonian butterflies subset from the Hugging Face Hub, train split only.
dataset = load_dataset("huggan/smithsonian_butterflies_subset", split="train")
# Alternatively, load images from a local folder:
# dataset = load_dataset("imagefolder", data_dir="path/to/folder")

# Per-image preprocessing + augmentation pipeline.
preprocess = transforms.Compose([
    transforms.Resize((image_size, image_size)),  # fixed square size
    transforms.RandomHorizontalFlip(),             # random left-right flip (augmentation)
    transforms.ToTensor(),                         # PIL image -> float tensor in (0, 1)
    transforms.Normalize([0.5], [0.5]),            # (x - 0.5) / 0.5 -> (-1, 1)
])


def transform(examples):
    """Convert a batch of dataset examples into preprocessed image tensors.

    Each entry of ``examples["image"]`` is converted to RGB (via ``.convert("RGB")``)
    and then passed through the module-level ``preprocess`` pipeline.

    Args:
        examples: batch mapping whose ``"image"`` key holds a sequence of images
            (presumably PIL images from the Hugging Face dataset).

    Returns:
        dict with a single ``"images"`` key: the preprocessed images, in input order.
    """
    rgb_tensors = []
    for pil_image in examples["image"]:
        rgb_tensors.append(preprocess(pil_image.convert("RGB")))
    return {"images": rgb_tensors}


# Register `transform` so images are preprocessed on the fly when items are accessed.
dataset.set_transform(transform)

# Shuffled mini-batches of preprocessed images for training.
train_dataloader = DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [None]:
# Peek at one batch. Slice the first 8 images *before* moving to the device, so
# only those 8 are copied instead of the whole batch.
xb = next(iter(train_dataloader))["images"][:8].to(device)
print("X shape:", xb.shape)

In [None]:
from diffusers import DDPMScheduler

# DDPM noise scheduler with the library's default beta schedule.
noise_scheduler = DDPMScheduler(
    num_train_timesteps=1000,  # number of discrete noising steps used during training
)

In [None]:
from diffusers import UNet2DModel

# Create a model
# Denoising UNet with four resolution levels: plain ResNet blocks on the first two
# down/last two up levels, ResNet + spatial self-attention on the innermost ones.
model = UNet2DModel(
    sample_size=image_size,                  # target image resolution
    in_channels=3,                           # RGB input
    out_channels=3,                          # one output channel per input channel
    layers_per_block=2,                      # ResNet layers per UNet block
    block_out_channels=(64, 128, 128, 256),  # width per level; more channels -> more parameters
    down_block_types=(
        "DownBlock2D",      # ResNet downsampling
        "DownBlock2D",
        "AttnDownBlock2D",  # ResNet downsampling + spatial self-attention
        "AttnDownBlock2D",
    ),
    up_block_types=(
        "AttnUpBlock2D",    # ResNet upsampling + spatial self-attention
        "AttnUpBlock2D",
        "UpBlock2D",        # ResNet upsampling
        "UpBlock2D",
    ),
).to(device)
model

In [None]:
# # Set the noise scheduler
# noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule="squaredcos_cap_v2")

# # Training loop
# optimizer = torch.optim.AdamW(model.parameters(), lr=4e-4)

# losses = []

# for epoch in range(30):
#     for step, batch in tqdm(enumerate(train_dataloader), total=len(train_dataloader)):
#         clean_images = batch["images"].to(device)
#         # Sample noise to add to the images
#         noise = torch.randn(clean_images.shape).to(clean_images.device)
#         bs = clean_images.shape[0]

#         # Sample a random timestep for each image
#         timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bs,), device=clean_images.device).long()

#         # Add noise to the clean images according to the noise magnitude at each timestep
#         noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)

#         # Get the model prediction
#         noise_pred = model(noisy_images, timesteps, return_dict=False)[0]

#         # Calculate the loss
#         loss = F.mse_loss(noise_pred, noise)
#         loss.backward()
#         losses.append(loss.item())

#         # Update the model parameters with the optimizer
#         optimizer.step()
#         optimizer.zero_grad()

#     if (epoch + 1) % 5 == 0:
#         loss_last_epoch = sum(losses[-len(train_dataloader) :]) / len(train_dataloader)
#         print(f"Epoch:{epoch+1}, loss: {loss_last_epoch}")

In [None]:
class DiffusionModelPL(pl.LightningModule):
    """LightningModule that trains a noise-prediction model with the DDPM objective.

    Each training step adds Gaussian noise to a batch of clean images at random
    timesteps (via ``noise_scheduler.add_noise``) and regresses the model's output
    onto that noise with an MSE loss.

    Args:
        model: denoiser called as ``model(noisy_images, timesteps, return_dict=False)``;
            element ``[0]`` of the result is taken as the noise prediction.
        noise_scheduler: diffusers-style scheduler providing
            ``config.num_train_timesteps`` and ``add_noise(samples, noise, timesteps)``.
    """

    def __init__(self, model, noise_scheduler):
        super().__init__()
        self.model = model
        self.noise_scheduler = noise_scheduler
        # Detached per-step losses of the current epoch; averaged and cleared in
        # on_train_epoch_end.
        self._epoch_losses = []

    def training_step(self, batch, batch_idx):
        """Compute the noise-prediction MSE loss for one batch.

        Args:
            batch: mapping whose ``"images"`` entry is the batch of clean images.
            batch_idx: index of the batch within the epoch (unused).

        Returns:
            Scalar loss tensor that Lightning backpropagates.
        """
        # Clean images and the Gaussian noise that will be added to them.
        clean_images = batch["images"]
        noise = torch.randn_like(clean_images)
        batch_size = clean_images.shape[0]

        # One random timestep per image. Read the step count from `.config`:
        # direct attribute access to config values is deprecated in diffusers.
        num_train_timesteps = self.noise_scheduler.config.num_train_timesteps
        timesteps = torch.randint(
            0, num_train_timesteps,
            (batch_size,),
            device=self.device
        ).long()

        # Forward-diffuse the clean images to the sampled timesteps.
        noisy_images = self.noise_scheduler.add_noise(clean_images, noise, timesteps)

        # Predict the noise that was added.
        noise_pred = self.model(noisy_images, timesteps, return_dict=False)[0]

        # Regress the prediction onto the true noise.
        loss = F.mse_loss(noise_pred, noise)
        self.log("train_loss", loss, prog_bar=True)
        self._epoch_losses.append(loss.detach())
        return loss

    def configure_optimizers(self):
        """Return an AdamW optimizer over the wrapped model's parameters (lr=4e-4)."""
        return torch.optim.AdamW(self.model.parameters(), lr=4e-4)

    def on_train_epoch_end(self):
        """Every 5 epochs, print the mean per-step training loss of the epoch.

        The per-step loss buffer is cleared at the end of every epoch, so each
        printed value covers only that epoch's steps.
        """
        epoch_losses = self._epoch_losses
        self._epoch_losses = []
        # Skip printing if no step ran this epoch (e.g. empty dataloader).
        if (self.current_epoch + 1) % 5 == 0 and epoch_losses:
            avg_loss = torch.stack(epoch_losses).mean().item()
            print(f"Epoch: {self.current_epoch+1}, Loss: {avg_loss:.4f}")

# Wrap the UNet and its noise scheduler in the LightningModule defined above.
pl_model = DiffusionModelPL(model, noise_scheduler)

# Let Lightning pick the accelerator type (GPU/TPU/CPU, ...) and the device count
# automatically, then train for 30 epochs.
trainer = pl.Trainer(
    accelerator="auto",
    devices="auto",
    max_epochs=30,
)
trainer.fit(model=pl_model, train_dataloaders=train_dataloader)

In [None]:
from diffusers import DDPMPipeline

# Bundle the trained UNet with the same noise scheduler into an inference pipeline.
image_pipe = DDPMPipeline(unet=model, scheduler=noise_scheduler)
# Move the whole pipeline (including the UNet) to `device`.
# NOTE(review): Lightning may have moved the model back to CPU at the end of
# `fit`, which is why this explicit move matters.
image_pipe = image_pipe.to(device)

In [None]:
# Run the pipeline's full reverse-diffusion sampling loop with default settings.
pipeline_output = image_pipe()
# `images` is presumably a list of PIL images (diffusers' default output type).
generated_images = pipeline_output.images
generated_image = generated_images[0]
generated_image