In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torchvision.utils import save_image
from diffusers import DDPMPipeline

from diffusers import UNet2DModel, DDPMScheduler
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from transformers import AdamW
from tqdm import tqdm
import torch

import os
import re


In [None]:
# Hyperparameters
latent_dim = 100            # length of the noise vector fed to the generator
batch_size = 64             # samples per DataLoader batch
lr = 0.0002                 # Adam learning rate shared by both optimizers
epochs = 200                # number of passes over the dataset
image_size = 64             # images are resized to image_size x image_size
data_path = "data/bedroom"  # root folder handed to ImageFolder
seed = 42                   # seed for torch's random number generators

# Seed torch's random number generators so weight initialisation, noise
# sampling and data shuffling are repeatable across "Restart & Run All".
torch.manual_seed(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)


In [None]:
# Download/load the pretrained DDPM checkpoint and place it on the selected
# device in one step; `.to()` hands back the pipeline itself.
pipeline = DDPMPipeline.from_pretrained("google/ddpm-cifar10-32").to(device)
# Bare expression so the notebook cell still displays the pipeline summary.
pipeline


In [None]:
class Generator(nn.Module):
    """Fully connected generator that maps noise vectors to RGB images.

    Args:
        latent_dim: Length of each input noise vector.
        image_size: Side length of the square output image.

    The final ``Tanh`` keeps every output value within [-1, 1].
    """

    def __init__(self, latent_dim, image_size):
        super().__init__()
        self.image_size = image_size
        # The second positional argument of BatchNorm1d is `eps`, not
        # `momentum`. Passing 0.8 positionally therefore added 0.8 to the
        # batch variance before normalising. Momentum is now passed by
        # keyword and `eps` keeps its default.
        # NOTE(review): PyTorch's momentum is the weight given to the *new*
        # batch statistic; if 0.8 was meant in the Keras sense (weight of
        # the running average), use momentum=0.2 instead -- confirm.
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256, momentum=0.8),
            nn.ReLU(inplace=True),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512, momentum=0.8),
            nn.ReLU(inplace=True),
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024, momentum=0.8),
            nn.ReLU(inplace=True),
            nn.Linear(1024, image_size * image_size * 3),
            nn.Tanh()
        )

    def forward(self, z):
        """Generate images from noise.

        Args:
            z: Tensor whose first dimension is the batch and whose last
                dimension is ``latent_dim``.

        Returns:
            Tensor of shape ``(batch, 3, image_size, image_size)``.
        """
        img = self.model(z)
        # Un-flatten the final linear layer's output into channel-first images.
        img = img.view(img.size(0), 3, self.image_size, self.image_size)
        return img


class Discriminator(nn.Module):
    """Fully connected discriminator producing one sigmoid score per image.

    Args:
        image_size: Side length of the square RGB images it receives.
    """

    def __init__(self, image_size):
        super().__init__()
        self.image_size = image_size

        # Number of values in one flattened 3-channel image.
        flat_features = image_size * image_size * 3
        layers = [
            nn.Linear(flat_features, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, img):
        """Score a batch of images.

        Args:
            img: Tensor whose first dimension is the batch; all remaining
                dimensions are flattened before the first linear layer.

        Returns:
            Tensor of shape ``(batch, 1)`` holding the sigmoid outputs.
        """
        n_samples = img.size(0)
        return self.model(img.view(n_samples, -1))


In [None]:
# Networks, moved to the selected device (generator is built first so the
# weight-initialisation order is unchanged).
generator = Generator(latent_dim=latent_dim, image_size=image_size).to(device)
discriminator = Discriminator(image_size=image_size).to(device)

# Binary cross-entropy on the discriminator's sigmoid output.
adversarial_loss = nn.BCELoss()

# One Adam optimizer per network, sharing learning rate and betas.
optimizer_G = optim.Adam(
    params=generator.parameters(), lr=lr, betas=(0.5, 0.999)
)
optimizer_D = optim.Adam(
    params=discriminator.parameters(), lr=lr, betas=(0.5, 0.999)
)

# Input pipeline: resize to a square, convert to a tensor, then shift/scale
# each channel with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(size=(image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

# Images are read from sub-folders of `data_path`; the folder-derived labels
# are ignored by the training loop.
dataset = datasets.ImageFolder(root=data_path, transform=transform)

dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)


In [None]:
# torch.save / save_image do not create missing parent folders, so make sure
# the output directories exist before the first epoch finishes.
os.makedirs("models", exist_ok=True)
os.makedirs("images", exist_ok=True)

# Converts the PIL images returned by the diffusion pipeline into tensors.
pil_to_tensor = transforms.ToTensor()

for epoch in range(epochs):
    # Reset per-epoch results so a run with no usable batch is detected
    # instead of silently reusing values from the previous epoch.
    g_loss = d_loss = gen_imgs = None

    # Wrap the dataloader itself (not `enumerate`) so tqdm knows the length.
    for imgs, _ in tqdm(dataloader, desc=f"Epoch {epoch}"):
        n_samples = imgs.size(0)

        # The generator's BatchNorm1d layers cannot compute batch statistics
        # from a single sample in training mode, so skip a size-1 batch
        # (e.g. a trailing remainder) rather than crash.
        if n_samples < 2:
            continue

        # Adversarial ground truths
        valid = torch.ones(n_samples, 1, device=device, dtype=torch.float)
        fake = torch.zeros(n_samples, 1, device=device, dtype=torch.float)

        # Configure input
        real_imgs = imgs.to(device)

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Sample noise as generator input
        z = torch.randn(n_samples, latent_dim, device=device)

        # Generate a batch of images
        gen_imgs = generator(z)

        # Loss measures generator's ability to fool the discriminator
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Loss for real images
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        # Loss for fake images; detach so no gradient reaches the generator
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        # Total discriminator loss
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()

    if gen_imgs is None:
        raise RuntimeError(
            "No batch with at least 2 samples was produced by the dataloader; "
            "cannot train or save samples."
        )

    print(f"[Epoch {epoch}/{epochs}]", end=" ")
    print(f"[D loss: {d_loss.item()}]", end=" ")
    print(f"[G loss: {g_loss.item()}]")

    # Save models
    torch.save(generator.state_dict(), f"models/generator_epoch_{epoch}.pth")
    torch.save(discriminator.state_dict(), f"models/discriminator_epoch_{epoch}.pth")

    # Save generated images every epoch (last batch of the epoch)
    save_image(gen_imgs.detach()[:25], f"images/{epoch}.png", nrow=5, normalize=True)

    # Generate samples using the diffusion model.
    # By default the pipeline returns a single PIL image per call, which
    # save_image cannot consume. Request 25 samples to fill the 5x5 grid and
    # convert them to one stacked tensor first.
    diffusion_samples = pipeline(batch_size=25, num_inference_steps=50).images
    diffusion_batch = torch.stack([pil_to_tensor(sample) for sample in diffusion_samples])
    save_image(diffusion_batch, f"images/diffusion_{epoch}.png", nrow=5, normalize=True)
