In [None]:
import os
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from diffusers import UNet2DConditionModel, DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from datasets import load_dataset
from PIL import Image
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm


In [None]:
# Load the preprocessed dataset: a dict holding parallel sequences of text
# prompts and image file paths.
# NOTE(review): torch.load unpickles arbitrary objects — only load files
# from a trusted source.
data_path = "data/processed_data.pt"
data = torch.load(data_path)

prompts = data["prompts"]
image_paths = data["image_paths"]

# Quick sanity check: record count plus the first prompt/path pair.
print(f"Количество записей: {len(prompts)}")
print(f"Пример запроса: {prompts[0]}")
print(f"Пример пути к изображению: {image_paths[0]}")


In [None]:
# Preview the first few images together with their prompts.
# The count is capped by the dataset size so a dataset with fewer than five
# records does not raise IndexError.
preview_count = min(5, len(prompts), len(image_paths))
for i in range(preview_count):
    # The context manager closes the file handle once the figure is drawn
    # (PIL loads lazily, so the image must be rendered inside the block).
    with Image.open(image_paths[i]) as img:
        plt.figure()
        plt.imshow(img)
        plt.title(prompts[i])
        plt.axis("off")
        plt.show()


In [None]:
# CLIP tokenizer and text encoder share the same checkpoint id.
clip_checkpoint = "openai/clip-vit-base-patch32"
tokenizer = CLIPTokenizer.from_pretrained(clip_checkpoint)
text_encoder = CLIPTextModel.from_pretrained(clip_checkpoint)

# Conditional UNet (the denoiser) and its noise scheduler.
# NOTE(review): both paths below look like placeholders — point them at real
# checkpoints, and confirm the UNet's cross-attention width matches the
# hidden size of the CLIP text encoder above.
unet = UNet2DConditionModel.from_pretrained("path_to_pretrained_unet")
scheduler = DDPMScheduler.from_pretrained("path_to_pretrained_scheduler")


In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Move both networks onto the selected device.
text_encoder = text_encoder.to(device)
unet = unet.to(device)


In [None]:
# Training hyperparameters.
num_epochs = 5        # full passes over the dataset
batch_size = 8        # samples per optimisation step
learning_rate = 5e-5  # Adam step size

# Only the UNet parameters are handed to the optimizer.
optimizer = torch.optim.Adam(params=unet.parameters(), lr=learning_rate)


In [None]:
from torchvision.transforms import Compose, Resize, ToTensor, Normalize

# Preprocessing: resize to 256x256, convert to a CHW float tensor in [0, 1],
# then map each of the three RGB channels to [-1, 1].
transform = Compose([
    Resize((256, 256)),
    ToTensor(),
    Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])


def load_image_tensor(path):
    """Open the image at `path` and return it as a normalised RGB tensor.

    The PIL image is handed to `transform` directly: `ToTensor` accepts a PIL
    image or ndarray (not a `torch.Tensor`), and `Resize` needs the PIL image
    (or a CHW tensor) to resize the spatial axes correctly. Converting to RGB
    guarantees three channels for greyscale / RGBA / palette files so the
    batch can be stacked. The file handle is closed on exit.
    """
    with Image.open(path) as img:
        return transform(img.convert("RGB"))


# Eagerly materialise (prompt, image_tensor) pairs.
# NOTE(review): this holds every image in memory; for large datasets switch
# to a lazy torch.utils.data.Dataset.
dataset = [
    (prompts[i], load_image_tensor(image_paths[i]))
    for i in range(len(prompts))
]

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)


In [None]:
# Fine-tune the UNet to predict the noise that was added to each image,
# conditioned on the CLIP embedding of the prompt.
# NOTE(review): the MSE target below is the raw noise, which assumes the
# scheduler is configured for epsilon prediction — confirm against
# scheduler.config.prediction_type.
unet.train()
text_encoder.eval()  # frozen: its parameters are not in the optimizer
num_train_timesteps = scheduler.config.num_train_timesteps

for epoch in range(num_epochs):
    epoch_loss = 0.0
    for prompts_batch, images_batch in tqdm(dataloader, desc=f"Epoch {epoch}"):
        # The default collate function yields the prompts as a tuple of str.
        inputs = tokenizer(list(prompts_batch), return_tensors="pt", padding=True, truncation=True)
        inputs = {key: val.to(device) for key, val in inputs.items()}

        images_batch = images_batch.to(device)

        # No autograd graph is needed through the frozen text encoder.
        with torch.no_grad():
            encoder_hidden_states = text_encoder(**inputs).last_hidden_state

        # Draw one random diffusion timestep per sample and noise the images
        # accordingly; add_noise requires the timesteps argument.
        noise = torch.randn_like(images_batch)
        timesteps = torch.randint(
            0,
            num_train_timesteps,
            (images_batch.shape[0],),
            device=device,
            dtype=torch.long,
        )
        noisy_images = scheduler.add_noise(images_batch, noise, timesteps)

        optimizer.zero_grad()
        # UNet2DConditionModel.forward takes (sample, timestep,
        # encoder_hidden_states); it has no `text_embeds` keyword.
        noise_pred = unet(
            noisy_images,
            timesteps,
            encoder_hidden_states=encoder_hidden_states,
        ).sample
        loss = torch.nn.functional.mse_loss(noise_pred, noise)

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    # max(..., 1) avoids ZeroDivisionError when the dataloader is empty.
    print(f"Epoch {epoch}, Loss: {epoch_loss / max(len(dataloader), 1):.4f}")


In [None]:
# Example data — hard-coded placeholder values, not measured losses.
losses = [0.123, 0.110, 0.105, 0.098, 0.090]

# One point per epoch, numbered from 1.
epoch_numbers = range(1, len(losses) + 1)

fig, ax = plt.subplots()
ax.plot(epoch_numbers, losses, marker="o")
ax.set_title("График потерь (Loss)")
ax.set_xlabel("Эпоха")
ax.set_ylabel("Loss")
plt.show()


In [None]:
# Generate a test image from a text prompt.
test_prompt = "A tribal tattoo design with sharp edges"
inputs = tokenizer([test_prompt], return_tensors="pt", padding=True, truncation=True)
inputs = {key: val.to(device) for key, val in inputs.items()}

# Number of reverse-diffusion steps; fewer is faster, and it must not exceed
# the number of timesteps the scheduler was trained with.
num_inference_steps = min(50, scheduler.config.num_train_timesteps)

unet.eval()
with torch.no_grad():
    text_embeddings = text_encoder(**inputs).last_hidden_state

    # Start from pure Gaussian noise and denoise step by step. A single UNet
    # forward pass only returns a noise prediction, not an image.
    noise = torch.randn((1, 3, 256, 256), device=device)
    generated_image = noise
    scheduler.set_timesteps(num_inference_steps)
    for t in tqdm(scheduler.timesteps, desc="Sampling"):
        noise_pred = unet(
            generated_image,
            t,
            encoder_hidden_states=text_embeddings,
        ).sample
        generated_image = scheduler.step(noise_pred, t, generated_image).prev_sample

# Convert for display: undo the [-1, 1] normalisation, clamp to the [0, 1]
# range imshow expects for float data, and reorder CHW -> HWC.
generated_image = (
    (generated_image[0].float() * 0.5 + 0.5)
    .clamp(0, 1)
    .permute(1, 2, 0)
    .cpu()
    .numpy()
)
plt.imshow(generated_image)
plt.title(test_prompt)
plt.axis("off")
plt.show()
