In [None]:
import os
import re

from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler
import torch.nn.functional as F
from torch.optim import AdamW
from tqdm import tqdm
import torch.nn as nn
import torch
import timm



In [None]:
# Settings shared by the cells below.
image_size = 64             # side length passed to the Resize transform
batch_size = 64             # batch size handed to the DataLoader
num_epochs = 10             # number of training epochs (configuration value)
learning_rate = 1e-4        # learning rate given to the AdamW optimizer
data_path = "data/bedroom"  # root folder read by datasets.ImageFolder

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)


In [None]:
def find_folder_with_max_number(base_path: str, type_name: str) -> int:
    """Return the largest ``N`` among sub-folders of ``base_path`` named ``{type_name}_N``.

    Only immediate children that are directories are considered, and the whole
    folder name must be ``{type_name}_`` followed by digits.

    Args:
        base_path: Directory whose entries are scanned.
        type_name: Folder-name prefix; matched literally, not as a regular expression.

    Returns:
        The highest numeric suffix found, or ``-1`` when ``base_path`` is not an
        existing directory or contains no matching folder.
    """
    if not os.path.isdir(base_path):
        # Nothing to scan (e.g. first run before any folder exists): report
        # "not found" instead of letting os.listdir raise.
        return -1

    # re.escape keeps characters such as "." or "+" in the prefix from being
    # interpreted as regex operators.
    pattern = re.compile(re.escape(type_name) + r"_(\d+)")

    numbers = []
    for item in os.listdir(base_path):
        # fullmatch rejects names with trailing text ("<type_name>_3_backup"),
        # which callers could not map back to "<type_name>_<N>".
        found = pattern.fullmatch(item)
        if found and os.path.isdir(os.path.join(base_path, item)):
            numbers.append(int(found.group(1)))

    return max(numbers, default=-1)


In [None]:
# Preprocessing applied to every image: resize to image_size x image_size,
# convert to a tensor, then normalize each of the three channels with
# mean 0.5 and std 0.5.
preprocessing_steps = [
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
]
transform = transforms.Compose(preprocessing_steps)

# Images are read from data_path; the labels ImageFolder attaches are
# discarded later by the training loop.
dataset = datasets.ImageFolder(root=data_path, transform=transform)

# Batches are reshuffled on every pass over the data.
dataloader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)


In [None]:
# Find the highest-numbered "all_data_<N>" folder to decide where numbering resumes.
last_epoch = find_folder_with_max_number("model_paths/stable_diffusion", "all_data")
print("Last epoch:", last_epoch)

# NOTE(review): this scans model_paths/stable_diffusion for folders, while the
# training cell writes "all_data_<epoch>.pth" files under model_paths/resnet,
# and model_id is not read by any of the cells visible here — confirm whether
# this lookup should target the ResNet checkpoints instead.
if last_epoch != -1:
    # A numbered folder exists: point model_id at it.
    model_id = f"model_paths/stable_diffusion/all_data_{last_epoch}"
else:
    # -1 means no numbered folder was found: use the hub identifier and
    # restart the epoch counter from zero.
    model_id = "CompVis/stable-diffusion-v1-4"
    last_epoch = 0


In [None]:
class ResNetDiffusionModel(nn.Module):
    """Noise predictor built from a ResNet-18 feature extractor plus a timestep embedding.

    The image is reduced to a feature vector by the ResNet, the embedded
    timestep is added to that vector, and a linear layer maps the sum to a
    3-channel square image of side ``image_size``.

    Args:
        image_size: Side length of the predicted output. Defaults to 64, the
            size this model was originally hard-coded for.
    """

    def __init__(self, image_size: int = 64):
        super().__init__()
        self.image_size = image_size

        # Pretrained ResNet-18 with its classifier replaced by an identity, so
        # calling it returns the feature vector instead of class logits.
        self.resnet = timm.create_model("resnet18", pretrained=True)
        self.resnet.fc = nn.Identity()
        # Width of the feature vector, read from the backbone rather than
        # hard-coded.
        feature_dim = self.resnet.num_features

        # Maps the conditioned features to a flattened 3 x H x W prediction.
        self.fc = nn.Linear(feature_dim, 3 * image_size * image_size)

        # Embeds the scalar timestep into the same width as the image features.
        # NOTE(review): the timestep is fed in as a raw value (no scaling or
        # sinusoidal encoding) — confirm this is intended.
        self.time_embed = nn.Sequential(
            nn.Linear(1, feature_dim),
            nn.ReLU(),
            nn.Linear(feature_dim, feature_dim),
            nn.ReLU(),
        )

    def forward(self, x, timesteps):
        """Predict an output image for input batch ``x`` at the given ``timesteps``.

        Args:
            x: Image batch passed to the ResNet backbone.
            timesteps: One timestep per sample; cast to float and reshaped to
                a column before embedding.

        Returns:
            Tensor reshaped to ``(-1, 3, image_size, image_size)``.
        """
        features = self.resnet(x)
        t = self.time_embed(timesteps.float().view(-1, 1))
        out = self.fc(features + t)
        return out.view(-1, 3, self.image_size, self.image_size)


In [None]:
# Build the network and move it to the selected device.
model = ResNetDiffusionModel().to(device)

# AdamW over every model parameter, using the configured learning rate.
optimizer = AdamW(model.parameters(), lr=learning_rate)

# DDPM noise schedule: 1000 training timesteps with a linear beta schedule
# running from 1e-4 to 0.02.
scheduler = DDPMScheduler(
    beta_schedule="linear",
    beta_start=0.0001,
    beta_end=0.02,
    num_train_timesteps=1000,
)


In [None]:
# Train for `num_epochs` epochs, numbering them after `last_epoch`, and write
# one checkpoint per epoch.
checkpoint_dir = "model_paths/resnet"
# torch.save does not create missing folders; make sure the target exists
# before spending an epoch of compute.
os.makedirs(checkpoint_dir, exist_ok=True)

for epoch in range(last_epoch + 1, last_epoch + num_epochs + 1):
    model.train()
    loss_sum = 0.0     # sum of per-batch losses weighted by batch size
    sample_count = 0   # number of samples seen in this epoch

    for inputs, _ in tqdm(dataloader, desc=f"Epoch {epoch}"):
        inputs = inputs.to(device)

        # Draw the noise to add and one random timestep per sample.
        noise = torch.randn_like(inputs)
        timesteps = torch.randint(
            0,
            # Read from the scheduler's config: direct attribute access to
            # config values is deprecated in diffusers.
            scheduler.config.num_train_timesteps,
            (inputs.shape[0],),
            device=device,
        )

        noisy_inputs = scheduler.add_noise(inputs, noise, timesteps)

        # The model is trained to recover the noise that was added.
        noise_pred = model(noisy_inputs, timesteps)
        loss = F.mse_loss(noise_pred, noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_sum += loss.item() * inputs.shape[0]
        sample_count += inputs.shape[0]

    torch.save(model.state_dict(), f"{checkpoint_dir}/all_data_{epoch}.pth")
    # Report the mean loss over the epoch rather than the last batch's loss.
    print(f"Epoch {epoch}, Loss: {loss_sum / max(sample_count, 1):.4f}")
