In [65]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

import VAE, DDPM, DataReader, Noise

In [47]:
# Параметры
high_res_dir = '/Users/sergei0000/.cache/kagglehub/datasets/arnaud58/flickrfaceshq-dataset-ffhq/versions/1'
batch_size = 1
num_workers = 0

dataset = DataReader.SuperResolutionDataset(high_res_dir)
# Определяем размеры разбиений
train_size = 8  # int(0.8 * len(dataset))
val_size = 1  # int(0.1 * len(dataset))
test_size = 1   # len(dataset) - train_size - val_size

# Разбиваем набор данных
train_dataset, val_dataset, test_dataset = random_split(dataset, [train_size, val_size, test_size])

# Создаем DataLoader для каждой выборки
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

# Проверка даталоадера
for low_res, high_res in train_loader:
    print(low_res.shape)
    break

torch.Size([1, 3, 128, 128])


In [31]:
vae = VAE.VAEUNet(latent_dim=1024, bilinear=True)
# Загружаем сохранённые веса в модель
vae.load_state_dict(torch.load("best_model_VAE.pth", map_location=torch.device('cpu')))
# Переводим модель в режим оценки
vae.eval()

input_tensor = torch.randn(4, 3, 128, 128)
output, mu, logvar = vae(input_tensor)
print("Output shape:", output.shape)  # должно выводить (4, 3, 512, 512)

  vae.load_state_dict(torch.load("best_model_VAE.pth", map_location=torch.device('cpu')))


Output shape: torch.Size([4, 3, 512, 512])


In [33]:
ddpm = DDPM.UNet(in_channels=3, out_channels=3, num_layers=4)
x, _, _  = vae(input_tensor)
output = ddpm(x)
print(output.shape)  # Ожидаемое: torch.Size([1, 3, 512, 512])


torch.Size([4, 3, 512, 512])


In [51]:
res = torch.concat([output, x], dim=1)
res.shape

torch.Size([4, 6, 512, 512])

In [53]:
high_res.shape

torch.Size([1, 3, 512, 512])

In [63]:
batch_size = 1
# Исходные изображения
x_start = high_res

# Количество шагов
T = 1000

# Настройка бета (variance schedule) от 1e-4 до 0.02 в течение T шагов
betas = torch.linspace(1e-4, 0.02, T)

# Пример генерации зашумленного изображения на случайном шаге t
t = torch.randint(0, T, (batch_size,), dtype=torch.long)
noisy_images, noise = Noise.q_sample(x_start, t, betas)

noisy_images.shape

cond, _, _ = vae(low_res)
cond.shape

inp = torch.concat([noisy_images, cond], dim=1)
inp.shape

torch.Size([1, 6, 512, 512])

In [68]:
def train(model, cond_model, train_loader, optimizer, loss_fn, device):
    model.train()
    for low_res, high_res in tqdm(train_loader, desc="Training"):
        t = torch.randint(0, T, (batch_size,), dtype=torch.long)
        noisy_images, noise = Noise.q_sample(x_start, t, betas)

        low_res, high_res = low_res.to(device), high_res.to(device)
        noisy_images, noise = noisy_images.to(device), noise.to(device)
        
        cond, _, _ = cond_model(low_res)

        inp = torch.concat([noisy_images, cond], dim=1)
        
        optimizer.zero_grad()
        # Прямой проход
        outputs = model(inp)
        # Вычисление потери
        loss = loss_fn(outputs, noise)
        # Назад и оптимизация
        loss.backward()
        optimizer.step()
        torch.cuda.empty_cache()

def validate(model, cond_model, val_loader, loss_fn, device):
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for low_res, high_res in tqdm(val_loader, desc="Validation"):
            t = torch.randint(0, T, (batch_size,), dtype=torch.long)
            noisy_images, noise = Noise.q_sample(x_start, t, betas)

            low_res, high_res = low_res.to(device), high_res.to(device)
            noisy_images, noise = noisy_images.to(device), noise.to(device)
            
            cond, _, _ = cond_model(low_res)

            inp = torch.concat([noisy_images, cond], dim=1)
            
            # Прямой проход
            outputs = model(inp)
            # Вычисление потери
            val_loss += loss_fn(outputs, noise).item()

    return val_loss / len(val_loader)

In [69]:
# Параметры обучения

# Количество шагов
T = 1000
# Настройка бета (variance schedule) от 1e-4 до 0.02 в течение T шагов
betas = torch.linspace(1e-4, 0.02, T)

epochs = 2
learning_rate = 0.0001
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Инициализация модели
ddpm = DDPM.UNet(in_channels=6, out_channels=3, num_layers=4)

# Функции потерь
reconstruction_loss_fn = nn.MSELoss()

# Оптимизатор
optimizer = optim.Adam(ddpm.parameters(), lr=learning_rate)

best_val_loss = float('inf')  # Инициализация лучшего значения валидационной потери
best_model_path = 'best_model_VAE.pth'  # Путь для сохранения лучшей модели

for epoch in range(1, epochs + 1):
    train(ddpm, vae, train_loader, optimizer, reconstruction_loss_fn, device)
    train_loss = validate(ddpm, vae, train_loader, reconstruction_loss_fn, device)
    val_loss = validate(ddpm, vae, val_loader, reconstruction_loss_fn, device)
    print(f'Epoch [{epoch}/{epochs}], Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}')
    
    # # Сохранение модели, если валидационная потеря улучшилась
    # if val_loss < best_val_loss:
    #     best_val_loss = val_loss
    #     torch.save(model.state_dict(), best_model_path)
    #     print(f'Model saved at epoch {epoch} with validation loss: {val_loss:.4f}')

Training: 100%|██████████| 8/8 [00:28<00:00,  3.62s/it]
Validation: 100%|██████████| 8/8 [00:09<00:00,  1.23s/it]
Validation: 100%|██████████| 1/1 [00:01<00:00,  1.25s/it]


Epoch [1/2], Train Loss: 0.9950, Validation Loss: 0.9955


Training: 100%|██████████| 8/8 [00:29<00:00,  3.65s/it]
Validation: 100%|██████████| 8/8 [00:09<00:00,  1.23s/it]
Validation: 100%|██████████| 1/1 [00:01<00:00,  1.26s/it]

Epoch [2/2], Train Loss: 0.9822, Validation Loss: 0.9773



