In [1]:
# Download the project's helper modules (DDPM, VAE, DataReader, Noise) from GitHub
# so they can be imported as local modules below.
# NOTE(review): these URLs track the `main` branch, so a rerun may fetch code that
# differs from what the loaded checkpoints were trained with — consider pinning a
# commit SHA instead of `refs/heads/main`.
!wget 'https://raw.githubusercontent.com/Sergey-Pidenko/DDPM/refs/heads/main/DDPM.py' -O 'DDPM.py';
!wget 'https://raw.githubusercontent.com/Sergey-Pidenko/DDPM/refs/heads/main/VAE.py' -O 'VAE.py';
!wget 'https://raw.githubusercontent.com/Sergey-Pidenko/DDPM/refs/heads/main/DataReader.py' -O 'DataReader.py';
!wget 'https://raw.githubusercontent.com/Sergey-Pidenko/DDPM/refs/heads/main/Noise.py' -O 'Noise.py';

from IPython.display import clear_output
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm
import kagglehub
import gc

# Project-local modules fetched by the wget calls above.
import VAE, DDPM, DataReader, Noise

# Wipe this cell's output (the wget download logs).
clear_output()

In [2]:
# --- Reproducibility ---------------------------------------------------------
SEED = 42

# --- Data --------------------------------------------------------------------
# High-resolution FFHQ face images, downloaded via kagglehub.
high_res_dir = kagglehub.dataset_download("arnaud58/flickrfaceshq-dataset-ffhq")
batch_size = 8
num_workers = 4

# --- Diffusion schedule ------------------------------------------------------
# T forward-diffusion steps with a linear beta (variance) schedule that rises
# from 1e-4 to 0.02 across those steps.
T = 1000
betas = torch.linspace(start=1e-4, end=0.02, steps=T)

# --- DDPM training -----------------------------------------------------------
epochs = 10
learning_rate = 1e-4
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# --- Checkpoints -------------------------------------------------------------
best_val_loss = float('inf')  # lowest validation loss observed so far
best_model_path = 'best_model_DDPM.pth'  # where the best DDPM weights are written
# Pretrained weights loaded below (Kaggle input mounts).
ddpm_input_path = '/kaggle/input/ddpm-v1/pytorch/default/2/best_model_DDPM-2.pth'
vae_input_path = "/kaggle/input/best_model_vae/pytorch/default/1/best_model_VAE.pth"

# Conditioning network: a pretrained VAE-UNet restored from `vae_input_path`.
vae = VAE.VAEUNet(latent_dim=1024, bilinear=True)
vae.load_state_dict(torch.load(vae_input_path, map_location=device, weights_only=True))
vae = vae.to(device)
# Evaluation mode: the VAE is only used as a fixed conditioner.
vae.eval()
# Freeze the VAE. Only `ddpm.parameters()` are handed to the optimizer, so any
# gradient w.r.t. VAE weights is wasted memory/compute — and, since nothing
# zeroes them, it would silently accumulate in the VAE's `.grad` buffers.
vae.requires_grad_(False);

# Denoiser: UNet with 7 input channels (presumably noisy image + VAE condition +
# 1-channel time map, concatenated in the training loop) predicting 3-channel noise.
ddpm = DDPM.UNet(in_channels=7, out_channels=3, num_layers=4).to(device)
# Resume from previously trained weights (the model is already on `device`).
ddpm.load_state_dict(torch.load(ddpm_input_path, map_location=device, weights_only=True))

# Loss: MSE between predicted and true noise.
reconstruction_loss_fn = nn.MSELoss()

# Optimizer over the denoiser's parameters only.
optimizer = optim.Adam(ddpm.parameters(), lr=learning_rate)

In [3]:
# Seed the global RNG right before splitting so the split is reproducible.
torch.manual_seed(SEED)

dataset = DataReader.SuperResolutionDataset(dir=high_res_dir, size=10000)

# 80% train / 10% validation / everything left over goes to test.
n_samples = len(dataset)
train_size = int(0.8 * n_samples)
val_size = int(0.1 * n_samples)
test_size = n_samples - (train_size + val_size)

train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_size, val_size, test_size]
)

# Shared loader settings; only the training loader shuffles.
loader_kwargs = dict(batch_size=batch_size, num_workers=num_workers)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

In [4]:
def time_condition(tensor, shape=(512, 512)):
    """Turn each timestep into a deterministic uniform-noise map (time conditioning).

    For every element ``t_i`` of ``tensor``, a ``shape``-sized map of values in
    [0, 1) is drawn from a CPU generator seeded with ``t_i``. The same timestep
    therefore always produces the same map. The values are identical to
    ``torch.manual_seed(t_i); torch.rand(*shape)``, but the global RNG is left
    untouched.

    Args:
        tensor: 1-D integer tensor of timesteps, one per batch item.
        shape: spatial size (H, W) of each map.

    Returns:
        CPU tensor of shape ``(len(tensor), 1, *shape)``.
    """
    batch_size = tensor.size(0)
    generated_tensors = torch.empty((batch_size, *shape))

    # A private generator instead of torch.manual_seed: reseeding the *global*
    # RNG here would make every later torch.randint / noise draw / shuffle seed
    # depend only on the last timestep of this batch (at most T distinct states).
    generator = torch.Generator()
    for i, seed in enumerate(tensor.tolist()):
        generator.manual_seed(int(seed))
        generated_tensors[i] = torch.rand(*shape, generator=generator)

    return generated_tensors.unsqueeze(1)

In [5]:
def _prepare_batch(cond_model, low_res, high_res, device):
    """Build the denoiser input and the target noise for one batch.

    Draws one random timestep per image and noises ``high_res`` with
    ``Noise.q_sample``. It then concatenates [noisy image, VAE condition of
    ``low_res``, time map] along dim 1.

    Returns:
        ``(inp, noise)``, both on ``device``.
    """
    # Use the real batch length: the last batch of a loader may be smaller
    # than the configured `batch_size`.
    n = high_res.size(0)
    t = torch.randint(0, T, (n,), dtype=torch.long)
    noisy_images, noise = Noise.q_sample(high_res, t, betas)

    low_res = low_res.to(device)
    noisy_images, noise = noisy_images.to(device), noise.to(device)
    t_cond = time_condition(tensor=t).to(device)

    # The VAE is a fixed conditioner; don't build an autograd graph through it.
    with torch.no_grad():
        cond, _, _ = cond_model(low_res)

    inp = torch.concat([noisy_images, cond, t_cond], dim=1)
    return inp, noise


def train(model, cond_model, train_loader, optimizer, loss_fn, device):
    """Run one training epoch of the denoiser ``model``.

    Args:
        model: denoiser being optimized.
        cond_model: conditioning network applied to the low-res inputs.
        train_loader: yields ``(low_res, high_res)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        loss_fn: loss between predicted and true noise.
        device: device to run the models on.

    Returns:
        Mean training loss over the epoch's batches (NaN if the loader is empty).
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for low_res, high_res in tqdm(train_loader, desc="Training"):
        inp, noise = _prepare_batch(cond_model, low_res, high_res, device)

        optimizer.zero_grad()
        outputs = model(inp)
        loss = loss_fn(outputs, noise)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1
        # Best-effort release of cached GPU memory between steps.
        torch.cuda.empty_cache()
        gc.collect()

    return total_loss / num_batches if num_batches else float('nan')


def validate(model, cond_model, val_loader, loss_fn, device):
    """Evaluate the mean noise-prediction loss of ``model`` over ``val_loader``.

    Timesteps and noise are drawn at random for every batch, as in training.

    Returns:
        Mean per-batch loss (NaN if the loader is empty).
    """
    model.eval()
    num_batches = len(val_loader)
    if num_batches == 0:
        return float('nan')

    val_loss = 0.0
    with torch.no_grad():
        for low_res, high_res in tqdm(val_loader, desc="Validation"):
            inp, noise = _prepare_batch(cond_model, low_res, high_res, device)
            outputs = model(inp)
            val_loss += loss_fn(outputs, noise).item()
            # Best-effort release of cached GPU memory between steps.
            torch.cuda.empty_cache()
            gc.collect()

    return val_loss / num_batches

In [6]:
# Training resumes at this epoch. Presumably the earlier epochs ran in a previous
# session whose weights were loaded from `ddpm_input_path`.
start_epoch = 3
# Full training state, saved after every epoch. A later session can restore the
# optimizer from it instead of starting Adam from scratch.
last_checkpoint_path = 'last_checkpoint_DDPM.pth'

for epoch in range(start_epoch, epochs + 1):
    train(ddpm, vae, train_loader, optimizer, reconstruction_loss_fn, device)
    # Training loss is re-measured with the same procedure as the validation
    # loss (eval mode, no grad), at the cost of an extra pass over the train set.
    train_loss = validate(ddpm, vae, train_loader, reconstruction_loss_fn, device)
    val_loss = validate(ddpm, vae, val_loader, reconstruction_loss_fn, device)
    print(f'Epoch [{epoch}/{epochs}], Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}')

    # Keep the weights with the lowest validation loss seen so far.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(ddpm.state_dict(), best_model_path)
        print(f'Model saved at epoch {epoch} with validation loss: {val_loss:.4f}')

    # Resumable checkpoint: model + optimizer state + progress.
    torch.save(
        {
            'epoch': epoch,
            'model_state_dict': ddpm.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'best_val_loss': best_val_loss,
        },
        last_checkpoint_path,
    )

Training: 100%|██████████| 1000/1000 [42:33<00:00,  2.55s/it]
Validation: 100%|██████████| 1000/1000 [16:58<00:00,  1.02s/it]
Validation: 100%|██████████| 125/125 [02:06<00:00,  1.02s/it]


Epoch [3/10], Train Loss: 0.0260, Validation Loss: 0.0253
Model saved at epoch 3 with validation loss: 0.0253


Training: 100%|██████████| 1000/1000 [42:46<00:00,  2.57s/it]
Validation: 100%|██████████| 1000/1000 [16:57<00:00,  1.02s/it]
Validation: 100%|██████████| 125/125 [02:07<00:00,  1.02s/it]


Epoch [4/10], Train Loss: 0.0293, Validation Loss: 0.0282


Training: 100%|██████████| 1000/1000 [42:47<00:00,  2.57s/it]
Validation: 100%|██████████| 1000/1000 [16:58<00:00,  1.02s/it]
Validation: 100%|██████████| 125/125 [02:07<00:00,  1.02s/it]


Epoch [5/10], Train Loss: 0.0231, Validation Loss: 0.0219
Model saved at epoch 5 with validation loss: 0.0219


Training: 100%|██████████| 1000/1000 [42:44<00:00,  2.56s/it]
Validation: 100%|██████████| 1000/1000 [17:06<00:00,  1.03s/it]
Validation: 100%|██████████| 125/125 [02:08<00:00,  1.03s/it]


Epoch [6/10], Train Loss: 0.0239, Validation Loss: 0.0228


Training: 100%|██████████| 1000/1000 [42:51<00:00,  2.57s/it]
Validation: 100%|██████████| 1000/1000 [17:05<00:00,  1.03s/it]
Validation: 100%|██████████| 125/125 [02:08<00:00,  1.03s/it]


Epoch [7/10], Train Loss: 0.0210, Validation Loss: 0.0198
Model saved at epoch 7 with validation loss: 0.0198


Training: 100%|██████████| 1000/1000 [42:47<00:00,  2.57s/it]
Validation: 100%|██████████| 1000/1000 [17:03<00:00,  1.02s/it]
Validation: 100%|██████████| 125/125 [02:08<00:00,  1.03s/it]


Epoch [8/10], Train Loss: 0.0195, Validation Loss: 0.0183
Model saved at epoch 8 with validation loss: 0.0183


Training: 100%|██████████| 1000/1000 [42:42<00:00,  2.56s/it]
Validation: 100%|██████████| 1000/1000 [17:02<00:00,  1.02s/it]
Validation: 100%|██████████| 125/125 [02:07<00:00,  1.02s/it]


Epoch [9/10], Train Loss: 0.0185, Validation Loss: 0.0172
Model saved at epoch 9 with validation loss: 0.0172


Training: 100%|██████████| 1000/1000 [42:39<00:00,  2.56s/it]
Validation: 100%|██████████| 1000/1000 [17:01<00:00,  1.02s/it]
Validation: 100%|██████████| 125/125 [02:08<00:00,  1.03s/it]

Epoch [10/10], Train Loss: 0.0216, Validation Loss: 0.0203



