# Импорт

In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import PIL.Image as Image
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from LookGenerator.networks.losses import WassersteinLoss, GradientPenalty
from LookGenerator.datasets.refinement_dataset import RefinementGANDataset
from LookGenerator.networks.trainer import WGANGPTrainer
from LookGenerator.networks.refinement import RefinementGenerator, RefinementDiscriminator
from LookGenerator.networks_training.utils import check_path_and_creat
import LookGenerator.datasets.transforms as custom_transforms

# Загрузка данных

In [2]:
def _resize_and_scale():
    """Return a new pipeline: Resize to (256, 192), then the project's MinMaxScale transform."""
    steps = [
        transforms.Resize((256, 192)),
        custom_transforms.MinMaxScale(),
    ]
    return transforms.Compose(steps)


# Restored inputs and real targets get their own (identically configured) pipelines.
transform_restored = _resize_and_scale()
transform_real = _resize_and_scale()

In [3]:
# DataLoader settings.
batch_size_train = 192
# Pinned host memory only speeds up host->GPU copies; it has no effect without CUDA.
pin_memory = torch.cuda.is_available()
# Never request more workers than the machine has CPUs (os.cpu_count() may return None).
num_workers = min(10, os.cpu_count() or 1)

In [4]:
# Dataset locations. Override with environment variables on machines other than
# the original author's; the defaults are the original hard-coded paths.
RESTORED_IMAGES_DIR = os.environ.get(
    "REFINEMENT_RESTORED_DIR",
    r"C:\Users\DenisovDmitrii\Desktop\forRefinement\train",
)
REAL_IMAGES_DIR = os.environ.get(
    "REFINEMENT_REAL_DIR",
    r"C:\Users\DenisovDmitrii\Desktop\forEncoder\train\image",
)

# Fail early with an actionable message instead of a bare OS error from the dataset.
for _dataset_dir in (RESTORED_IMAGES_DIR, REAL_IMAGES_DIR):
    if not os.path.isdir(_dataset_dir):
        raise FileNotFoundError(
            f"Dataset directory not found: {_dataset_dir!r}. "
            "Set REFINEMENT_RESTORED_DIR / REFINEMENT_REAL_DIR to the correct folders."
        )

train_dataset = RefinementGANDataset(
    restored_images_dir=RESTORED_IMAGES_DIR,
    real_images_dir=REAL_IMAGES_DIR,
    transform_restored_images=transform_restored,
    transform_real_images=transform_real,
)

FileNotFoundError: [WinError 3] Системе не удается найти указанный путь: 'C:\\Users\\DenisovDmitrii\\Desktop\\forRefinement\\train'

In [None]:
# Shuffled training loader; batch size, worker count and pinning come from the config cell.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size_train,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
)

# Обучение модели

In [None]:
# Standalone networks. NOTE(review): `fit` trains the separate instances built in the
# `model` dict cell, and it creates its own Adam optimizers, so the optimizers below
# are not used by `fit`.
generator = RefinementGenerator()
discriminator = RefinementDiscriminator()

device = "cuda" if torch.cuda.is_available() else "cpu"

optimizer_generator, optimizer_discriminator = [
    torch.optim.Adam(net.parameters(), lr=5e-5) for net in (generator, discriminator)
]

# Separate loss objects for the critic and the generator objectives.
criterion_generator, criterion_discriminator = WassersteinLoss(), WassersteinLoss()
gradient_penalty = GradientPenalty(discriminator, device=device)

print(device)

In [None]:
# Per-session checkpoint folders for the two networks.
save_directory_generator = r"C:\Users\DenisovDmitrii\OneDrive - ITMO UNIVERSITY\peopleDetector\refinement\weights\generator\session1"
save_directory_discriminator = r"C:\Users\DenisovDmitrii\OneDrive - ITMO UNIVERSITY\peopleDetector\refinement\weights\discriminator\session1"

# Presumably creates each folder if it is missing (project helper; verify).
for _save_dir in (save_directory_generator, save_directory_discriminator):
    check_path_and_creat(_save_dir)


In [None]:
# Let cuDNN benchmark convolution algorithms (inputs are resized to a fixed size),
# and keep CUDA matmuls in full FP32 precision by disabling TF32.
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = False

In [None]:
#trainer = WGANGPTrainer(
#    generator=generator,
#    discriminator=discriminator,
#    optimizer_generator=optimizer_generator,
#    optimizer_discriminator=optimizer_discriminator,
#    criterion_generator=criterion_generator,
#    criterion_discriminator=criterion_discriminator,
#    gradient_penalty=gradient_penalty,
#    gp_weight=0.2,
#    save_step=1,
#    save_directory_discriminator=save_directory_discriminator,
#    save_directory_generator=save_directory_generator,
#    device=device,
#    verbose=True
#)

In [None]:
#trainer.train(train_dataloader, epoch_num=5)

In [None]:
# Fresh networks that `fit` trains, moved to the selected device.
# NOTE(review): these are new instances, distinct from the standalone `generator` /
# `discriminator`; `gradient_penalty` was constructed around that earlier discriminator.
model = {
    "discriminator": RefinementDiscriminator().to(device),
    "generator": RefinementGenerator().to(device)
}

# The losses are already instances; pass them through. Calling them here with no
# arguments would invoke the loss itself without predictions/targets and fail.
criterion = {
    "discriminator": criterion_discriminator,
    "generator": criterion_generator
}

In [5]:
def fit(model, criterion, train_dl, device, epochs, lr, gp_weight=0.2, n_critic=5):
    """Adversarially train a generator/critic pair and return per-epoch statistics.

    Args:
        model: dict with "discriminator" and "generator" modules, already on ``device``.
        criterion: dict with "discriminator" and "generator" loss callables, each
            called as ``loss(preds, targets)``.
        train_dl: iterable of ``(input_images, real_images)`` batches.
        device: device the batches are moved to.
        epochs: number of passes over ``train_dl``.
        lr: learning rate for both Adam optimizers (betas=(0.5, 0.999)).
        gp_weight: multiplier on the gradient-penalty term in the critic loss.
        n_critic: the critic is updated on every batch; the generator only on batches
            whose index is a multiple of ``n_critic`` (0, n_critic, 2*n_critic, ...).

    Returns:
        ``(losses_g, losses_d, real_scores, fake_scores)``. Each is a list with one
        entry per epoch: the mean over that epoch, or NaN if the epoch produced no values.

    Raises:
        ValueError: if ``n_critic`` < 1.

    Note:
        Uses the notebook-global ``gradient_penalty``, called as
        ``gradient_penalty(critic, real, fake, device)``. NOTE(review): that object
        was constructed with a different discriminator instance; confirm its
        ``__call__`` uses the critic passed in.
    """
    if n_critic < 1:
        raise ValueError(f"n_critic must be >= 1, got {n_critic}")

    critic = model["discriminator"]
    gen = model["generator"]
    critic.train()
    gen.train()
    torch.cuda.empty_cache()

    def _epoch_mean(values):
        # np.mean([]) emits a RuntimeWarning; return NaN quietly for an empty epoch.
        return np.mean(values) if values else float("nan")

    # Per-epoch histories.
    losses_g = []
    losses_d = []
    real_scores = []
    fake_scores = []

    optimizer = {
        "discriminator": torch.optim.Adam(critic.parameters(), lr=lr, betas=(0.5, 0.999)),
        "generator": torch.optim.Adam(gen.parameters(), lr=lr, betas=(0.5, 0.999)),
    }

    for epoch in range(epochs):
        loss_d_per_epoch = []
        loss_g_per_epoch = []
        real_score_per_epoch = []
        fake_score_per_epoch = []

        for iteration, (input_images, real_images) in enumerate(tqdm(train_dl)):
            input_images = input_images.to(device)
            real_images = real_images.to(device)

            # ---- Critic step (every batch) ----
            optimizer["discriminator"].zero_grad()

            # Real samples are scored against +1 targets.
            real_preds = critic(real_images)
            real_targets = torch.ones(real_images.shape[0], 1, device=device)
            real_loss = criterion["discriminator"](real_preds, real_targets)

            # Generated samples are scored against -1 targets.
            fake_images = gen(input_images)
            fake_preds = critic(fake_images)
            fake_targets = -torch.ones(fake_images.shape[0], 1, device=device)
            fake_loss = criterion["discriminator"](fake_preds, fake_targets)

            real_score_per_epoch.append(torch.mean(real_preds).item())
            fake_score_per_epoch.append(torch.mean(fake_preds).item())

            gp = gradient_penalty(critic, real_images, fake_images, device)

            # fake_images is not detached, so this backward also fills generator
            # grads; they are cleared by the generator's zero_grad before its step.
            loss_d = real_loss + fake_loss + gp_weight * gp
            loss_d.backward()
            optimizer["discriminator"].step()
            loss_d_per_epoch.append(loss_d.item())

            # ---- Generator step (every n_critic-th batch) ----
            if iteration % n_critic == 0:
                optimizer["generator"].zero_grad()

                # Regenerate with the current generator and try to get +1 from the critic.
                fake_images = gen(input_images)
                preds = critic(fake_images)
                targets = torch.ones(fake_images.shape[0], 1, device=device)
                loss_g = criterion["generator"](preds, targets)

                loss_g.backward()
                optimizer["generator"].step()
                loss_g_per_epoch.append(loss_g.item())

        # Record epoch means (one entry per epoch for every history list).
        losses_g.append(_epoch_mean(loss_g_per_epoch))
        losses_d.append(_epoch_mean(loss_d_per_epoch))
        real_scores.append(_epoch_mean(real_score_per_epoch))
        fake_scores.append(_epoch_mean(fake_score_per_epoch))

        # Log this epoch's mean losses & scores.
        print("Epoch [{}/{}], loss_g: {:.4f}, loss_d: {:.4f}, real_score: {:.4f}, fake_score: {:.4f}".format(
            epoch + 1, epochs,
            losses_g[-1], losses_d[-1], real_scores[-1], fake_scores[-1])
        )

    return losses_g, losses_d, real_scores, fake_scores

In [None]:
# Train the networks in `model` for 2 epochs at learning rate 5e-5.
history = fit(model, criterion, train_dataloader, device, epochs=2, lr=5e-5)

In [None]:
# Run the *trained* generator (model["generator"], optimized by `fit`) on one training
# sample and show the critic's score for the result. The standalone `generator` /
# `discriminator` objects are never trained, so they are not used here.
sample_input, sample_real = train_dataset[1]

gen_net = model["generator"]
critic_net = model["discriminator"]
gen_net.eval()
critic_net.eval()

with torch.no_grad():
    batch_input = sample_input.unsqueeze(0).to(device)  # add a leading batch dimension
    print(batch_input.shape)
    generated = gen_net(batch_input)
    critic_score = critic_net(generated)

# ToPILImage needs a CPU tensor. NOTE(review): assumes generator outputs lie in [0, 1].
generated_pil = transforms.ToPILImage()(generated[0].cpu())
generated_pil.show()
print(critic_score)
