In [None]:
from diffusers import UNet2DModel, DDIMScheduler
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from transformers import AdamW
from tqdm import tqdm
import torch

import os
import re


In [None]:
# --- Run configuration: hyperparameters and data location ---
batch_size = 64         # samples per batch handed to the DataLoader
learning_rate = 1e-4    # learning rate passed to the optimizer
num_epochs = 100        # NOTE(review): not read by any cell visible here - confirm intent
image_size = 64         # images are resized to image_size x image_size
data_path = "data/bedroom"  # root folder given to ImageFolder

# Use the GPU when PyTorch reports one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)


In [None]:
def find_folder_with_max_number(base_path: str, type_name: str) -> int:
    """Return the highest numeric suffix among ``<type_name>_<N>`` sub-folders.

    Args:
        base_path: Directory whose immediate children are scanned.
        type_name: Literal prefix expected before the underscore and number.

    Returns:
        The largest ``N`` among directories named exactly ``<type_name>_<N>``,
        or ``-1`` when ``base_path`` is not an existing directory or holds no
        such folder.
    """
    # A missing root just means "nothing saved yet": answer with the same -1
    # sentinel instead of letting os.listdir raise.
    if not os.path.isdir(base_path):
        return -1

    # re.escape keeps regex metacharacters in type_name literal; fullmatch
    # (below) rejects names with trailing text such as "<type_name>_3_backup",
    # which would otherwise yield a number with no matching folder name.
    pattern = re.compile(re.escape(type_name) + r"_(\d+)")

    matches = (
        pattern.fullmatch(name)
        for name in os.listdir(base_path)
        if os.path.isdir(os.path.join(base_path, name))
    )
    return max((int(m.group(1)) for m in matches if m), default=-1)


In [None]:
# Preprocessing pipeline: resize to a square of side image_size, convert to a
# tensor, then normalize each of the three channels with mean 0.5 and std 0.5.
channel_mean = [0.5, 0.5, 0.5]
channel_std = [0.5, 0.5, 0.5]
preprocessing_steps = [
    transforms.Resize(size=(image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=channel_mean, std=channel_std),
]
transform = transforms.Compose(preprocessing_steps)

# Images are read from the sub-folders of data_path.
dataset = datasets.ImageFolder(root=data_path, transform=transform)

# Batches are reshuffled on every pass over the dataset.
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)


In [None]:
# Look for the most recent local checkpoint; -1 means none was found.
last_epoch = find_folder_with_max_number("model_paths/ddim", "all_data")
print("Last epoch:", last_epoch)

# Resume from the newest local checkpoint when there is one, otherwise start
# from the public pretrained weights.
path = (
    f"model_paths/ddim/all_data_{last_epoch}"
    if last_epoch != -1
    else "google/ddpm-cifar10-32"
)


In [None]:
# Load the UNet weights from `path` (set by an earlier cell) and move the
# model to the selected device.
model = UNet2DModel.from_pretrained(path)
model.to(device)

# Use PyTorch's own AdamW: the transformers.AdamW class is deprecated in favour
# of torch.optim.AdamW. eps and weight_decay are pinned to the values the
# transformers implementation used by default (1e-6 and 0.0) so the optimizer
# hyperparameters stay as they were.
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=learning_rate,
    eps=1e-6,
    weight_decay=0.0,
)

# Noise scheduler used to corrupt training images; 1000 training timesteps.
scheduler = DDIMScheduler(num_train_timesteps=1000)


In [None]:
# Train the UNet to predict the noise added to each image and save a
# checkpoint after every epoch.
# NOTE(review): this range covers 9 epochs per run (last_epoch+1 .. last_epoch+9)
# and does not consult `num_epochs`; confirm that per-run cap is intended.
for epoch in range(last_epoch+1, last_epoch+10):
    model.train()

    # Accumulate the loss on-device so averaging does not force a
    # host/device sync on every step.
    epoch_loss = torch.zeros((), device=device)
    num_batches = 0

    for data in tqdm(dataloader, desc=f"Epoch {epoch}"):
        inputs, _ = data  # second element of each batch is unused
        inputs = inputs.to(device)

        noise = torch.randn_like(inputs)
        # One random timestep per sample. The upper bound is read from
        # scheduler.config: num_train_timesteps is a config value, and
        # accessing it directly on the scheduler object is deprecated.
        timesteps = torch.randint(
            0,
            scheduler.config.num_train_timesteps,
            (inputs.shape[0],),
            device=device,
            dtype=torch.long,
        )

        noisy_inputs = scheduler.add_noise(inputs, noise, timesteps)

        # The model is trained to recover the noise that was added.
        noise_pred = model(noisy_inputs, timesteps).sample
        loss = torch.nn.functional.mse_loss(noise_pred, noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.detach()
        num_batches += 1

    model.save_pretrained(f"model_paths/ddim/all_data_{epoch}")

    # Report the mean loss over the epoch rather than only the final batch;
    # an empty loader yields NaN instead of an unbound `loss`.
    mean_loss = epoch_loss.item() / num_batches if num_batches else float("nan")
    print(f"Epoch {epoch}, Loss: {mean_loss:.4f}")
