In [1]:
!wget 'https://raw.githubusercontent.com/Sergey-Pidenko/DDPM/refs/heads/main/DDPM.py' -O 'DDPM.py';
!wget 'https://raw.githubusercontent.com/Sergey-Pidenko/DDPM/refs/heads/main/VAE.py' -O 'VAE.py';
!wget 'https://raw.githubusercontent.com/Sergey-Pidenko/DDPM/refs/heads/main/DataReader.py' -O 'DataReader.py';
!wget 'https://raw.githubusercontent.com/Sergey-Pidenko/DDPM/refs/heads/main/UNet.py' -O 'UNet.py';

from IPython.display import clear_output
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
import kagglehub
import gc

import VAE, UNet, DDPM, DataReader
from DDPM import q_sample, get_time_condition, train, validate, denoise_image
clear_output()

In [2]:
# seed
SEED = 42
# Параметры
high_res_dir = kagglehub.dataset_download("arnaud58/flickrfaceshq-dataset-ffhq")
batch_size = 1
num_workers = 4

# Количество шагов
T = 1000
# Настройка бета (variance schedule) от 1e-4 до 0.02 в течение T шагов
betas = torch.linspace(1e-4, 0.02, T)

# Параметры модели DDPM
epochs = 10
learning_rate = 0.0001
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

best_val_loss = float('inf')  # Инициализация лучшего значения валидационной потери
best_model_path = 'best_model_DDPM_1.pth'  # Путь для сохранения лучшей модели
# ddpm_input_path = '/kaggle/input/ddpm-v1/pytorch/default/2/best_model_DDPM-2.pth'
vae_input_path = "/kaggle/input/best_model_vae/pytorch/default/1/best_model_VAE.pth"

# Загружаем модель
vae = VAE.VAEUNet(latent_dim=1024, bilinear=True)
# Загружаем сохранённые веса в модель
vae.load_state_dict(torch.load(vae_input_path, map_location=device, weights_only=True))
vae = vae.to(device)
# Переводим модель в режим оценки
vae.eval();

# Инициализация модели
ddpm = UNet.UNet(in_channels=7, out_channels=3, num_layers=4).to(device)
# ddpm.load_state_dict(torch.load(ddpm_input_path, map_location=device, weights_only=True))
ddpm = ddpm.to(device)

# Функции потерь
reconstruction_loss_fn = nn.MSELoss()

# Оптимизатор
optimizer = optim.Adam(ddpm.parameters(), lr=learning_rate)

dataset = DataReader.SuperResolutionDataset(dir=high_res_dir, size=10)
# Определяем размеры разбиений
train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size

# Разбиваем набор данных
torch.manual_seed(SEED)
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

# Создаем DataLoader для каждой выборки
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

In [3]:
for epoch in range(1):
    train(ddpm, vae, train_loader, optimizer, reconstruction_loss_fn, device, betas, T, batch_size)
    train_loss = validate(ddpm, vae, train_loader, reconstruction_loss_fn, device, betas, T, batch_size)
    val_loss = validate(ddpm, vae, val_loader, reconstruction_loss_fn, device, betas, T, batch_size)
    print(f'Epoch [{epoch}/{epochs}], Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}')

    # Сохранение модели, если валидационная потеря улучшилась
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(ddpm.state_dict(), best_model_path)
        print(f'Model saved at epoch {epoch} with validation loss: {val_loss:.4f}')

Training: 100%|██████████| 8/8 [00:03<00:00,  2.53it/s]
Validation: 100%|██████████| 8/8 [00:01<00:00,  7.83it/s]
Validation: 100%|██████████| 1/1 [00:00<00:00,  3.74it/s]


Epoch [0/10], Train Loss: 0.9933, Validation Loss: 0.9954
Model saved at epoch 0 with validation loss: 0.9954
