In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import PIL.Image as Image
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from LookGenerator.networks.losses import WassersteinLoss, GradientPenalty, FineGANLoss
from LookGenerator.datasets.encoder_decoder_datasets import EncoderDecoderDataset
from LookGenerator.networks.fine_gan import *
from LookGenerator.networks_training.utils import check_path_and_creat
import LookGenerator.datasets.transforms as custom_transforms
from LookGenerator.networks.utils import get_num_digits, save_model

In [2]:
# Network inputs: resize to (256, 192), then Normalize with mean=0.5 and std=0.5
# per channel, i.e. x -> (x - 0.5) / 0.5 (maps [0, 1] to [-1, 1] if the dataset
# yields [0, 1] tensors — TODO confirm).
transform_input = transforms.Compose(
    [
        transforms.Resize((256, 192)),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)

# Reconstruction targets: same resize, then the project's MinMaxScale instead
# of the mean/std normalization used for the inputs.
transform_real = transforms.Compose(
    [
        transforms.Resize((256, 192)),
        custom_transforms.MinMaxScale(),
    ]
)

In [3]:
# DataLoader settings for the training split.
batch_size_train, pin_memory, num_workers = 32, True, 4

In [4]:
# Training split: the human and clothes transforms both use transform_input,
# the restored-human target uses transform_real.
# NOTE(review): machine-specific absolute path — adjust before running elsewhere.
train_dataset = EncoderDecoderDataset(
    transform_human=transform_input,
    transform_clothes=transform_input,
    transform_human_restored=transform_real,
    image_dir=r"C:\Users\DenisovDmitrii\Desktop\forEncoder\train",
)

In [5]:
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size_train,
    shuffle=True,  # reshuffle the training data every epoch
    num_workers=num_workers,
    pin_memory=pin_memory,
)

In [6]:
# Let cuDNN benchmark and cache the fastest convolution algorithms (inputs are
# resized to a fixed shape), and keep CUDA matmuls in full FP32 (TF32 disabled).
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = False

In [7]:
def _epoch_string(epoch, epoch_num):
    """Return ``epoch`` as a string left-padded with zeros.

    The number of leading zeros is ``get_num_digits(epoch_num)`` minus
    ``get_num_digits(epoch)``; a negative difference adds no padding.
    Used to build checkpoint file names that sort in epoch order.
    """
    padding = get_num_digits(epoch_num) - get_num_digits(epoch)
    return f"{'0' * padding}{epoch}"


In [8]:
def fit(model, criterion, gradient_penalty, train_dl, device, epochs, g_lr, d_lr,
        save_directory_generator, save_directory_discriminator, save_step=1,
        gp_weight=10.0, n_critic=5, fake_target=-1.0):
    """Adversarially train a generator/discriminator pair with a gradient penalty.

    Args:
        model: dict with "generator" and "discriminator" ``nn.Module`` entries.
            The entries are reassigned to their ``.to(device)`` result each epoch.
        criterion: dict with a "discriminator" loss called as ``(preds, targets)``
            and a "generator" loss called as
            ``(preds, targets, fake_images, real_images)``.
        gradient_penalty: callable ``(discriminator, real, fake, device)``
            returning a scalar tensor.
        train_dl: iterable of ``(input_images, real_images)`` batches.
        device: device the models and batches are moved to.
        epochs: number of passes over ``train_dl``.
        g_lr, d_lr: Adam learning rates for generator / discriminator.
        save_directory_generator, save_directory_discriminator: checkpoint folders.
        save_step: write checkpoints every ``save_step`` epochs.
        gp_weight: multiplier applied to the gradient-penalty term.
        n_critic: the generator is updated on batches whose index is a multiple
            of ``n_critic`` (0, n_critic, 2*n_critic, ...); must be >= 1.
        fake_target: target value for fake images in the discriminator loss.
            Real images use 1.0.

    Returns:
        ``(losses_g, losses_d, real_scores, fake_scores)``: lists with one
        per-epoch mean each.

    Note:
        Checkpointing calls ``.to('cpu')`` on both models; ``nn.Module.to``
        works in place. The models are moved back to ``device`` at the start
        of the next epoch, so after a final checkpoint they stay on CPU.
    """
    model["discriminator"].train()
    model["generator"].train()
    torch.cuda.empty_cache()

    # Per-epoch history: exactly one entry per epoch in each list.
    losses_g = []
    losses_d = []
    real_scores = []
    fake_scores = []

    optimizer = {
        "discriminator": torch.optim.Adam(model["discriminator"].parameters(),
                                          lr=d_lr, betas=(0.5, 0.999)),
        "generator": torch.optim.Adam(model["generator"].parameters(),
                                      lr=g_lr, betas=(0.5, 0.999)),
    }

    for epoch in range(epochs):
        loss_d_per_epoch = []
        loss_g_per_epoch = []
        real_score_per_epoch = []
        fake_score_per_epoch = []

        # Checkpointing at the end of an epoch moves the models to CPU,
        # so bring them back to the training device every epoch.
        model["discriminator"] = model["discriminator"].to(device)
        model["generator"] = model["generator"].to(device)

        for iteration, (input_images, real_images) in enumerate(tqdm(train_dl)):
            input_images = input_images.to(device)
            real_images = real_images.to(device)

            # ---------------- Discriminator step ----------------
            optimizer["discriminator"].zero_grad()

            real_preds = model["discriminator"](real_images)
            real_targets = torch.ones(real_images.shape[0], 1, device=device)
            real_loss = criterion["discriminator"](real_preds, real_targets)

            fake_images = model["generator"](input_images)
            # Detached so the fake-score term does not backpropagate into the generator.
            fake_preds = model["discriminator"](fake_images.detach())
            # Fake images need a target opposite to the real one. With identical
            # targets the discriminator objective is the same for real and fake
            # inputs, so it cannot learn to tell them apart.
            # NOTE(review): the -1.0 default assumes WassersteinLoss is target-signed
            # (e.g. -mean(preds * targets)); confirm against LookGenerator.networks.losses.
            fake_targets = torch.full((fake_images.shape[0], 1), fake_target, device=device)
            fake_loss = criterion["discriminator"](fake_preds, fake_targets)

            # The non-detached fakes are passed to the penalty, as before. Its
            # implementation is not visible here and may rely on the input
            # carrying autograd history. Any generator gradients this creates are
            # cleared by the generator's zero_grad() before its own update.
            gp = gradient_penalty(model["discriminator"], real_images, fake_images, device)

            loss_d = real_loss + fake_loss + gp_weight * gp
            loss_d.backward()
            optimizer["discriminator"].step()

            loss_d_per_epoch.append(loss_d.item())
            real_score_per_epoch.append(torch.mean(real_preds).item())
            fake_score_per_epoch.append(torch.mean(fake_preds).item())

            # ---------------- Generator step (every n_critic batches) ----------------
            if iteration % n_critic == 0:
                optimizer["generator"].zero_grad()

                # Regenerate: the graph of the earlier forward pass may have been
                # freed by loss_d.backward().
                fake_images = model["generator"](input_images)

                # Try to fool the (just updated) discriminator.
                preds = model["discriminator"](fake_images)
                targets = torch.ones(real_images.shape[0], 1, device=device)
                loss_g = criterion["generator"](preds, targets, fake_images, real_images)

                loss_g.backward()
                optimizer["generator"].step()
                loss_g_per_epoch.append(loss_g.item())

        # Record per-epoch means. The generator mean used to be appended inside
        # the generator branch, so losses_g got several entries per epoch and
        # was misaligned with losses_d.
        losses_g.append(np.mean(loss_g_per_epoch))
        losses_d.append(np.mean(loss_d_per_epoch))
        real_scores.append(np.mean(real_score_per_epoch))
        fake_scores.append(np.mean(fake_score_per_epoch))

        # Log losses & scores (means over the epoch).
        print("Epoch [{}/{}], loss_g: {:.4f}, loss_d: {:.4f}, real_score: {:.4f}, fake_score: {:.4f}".format(
            epoch+1, epochs,
            losses_g[-1], losses_d[-1], real_scores[-1], fake_scores[-1])
        )

        if (epoch + 1) % save_step == 0:
            epoch_tag = _epoch_string(epoch, epochs)
            # os.path.join instead of a hardcoded "\\" so paths work off Windows too.
            save_model(
                model["discriminator"].to('cpu'),
                path=os.path.join(save_directory_discriminator,
                                  f"discriminator_epoch_{epoch_tag}.pt")
            )
            save_model(
                model["generator"].to('cpu'),
                path=os.path.join(save_directory_generator,
                                  f"generator_epoch_{epoch_tag}.pt")
            )

    return losses_g, losses_d, real_scores, fake_scores

In [9]:
# Checkpoint folders for this training session (passed to fit()).
# NOTE(review): machine-specific absolute paths — adjust before running elsewhere.
save_directory_generator=r"C:\Users\DenisovDmitrii\OneDrive - ITMO UNIVERSITY\peopleDetector\encoderGAN\weights\gen\session2"
save_directory_discriminator=r"C:\Users\DenisovDmitrii\OneDrive - ITMO UNIVERSITY\peopleDetector\encoderGAN\weights\discr\session2"
# Project helper; presumably creates each folder when it is missing — confirm.
# The value of the last call is what the notebook displays as this cell's output.
check_path_and_creat(save_directory_generator)
check_path_and_creat(save_directory_discriminator)

True

In [None]:
# Build the networks. Keep this construction order: if the constructors draw
# random initial weights, reordering them would change the initialization.
# TODO: no random seed is set in the visible notebook, so runs are not reproducible.
generator = EncoderDecoderGenerator(in_channels=6)  # presumably human + clothes images stacked (3 + 3) — confirm
discriminator = Discriminator()

device = 'cuda' if torch.cuda.is_available() else 'cpu'

wasserstein_criterion = WassersteinLoss()
# Penalty helper bound to this discriminator object. The object is later moved
# to `device` in place by nn.Module.to(), so the helper sees the same module.
gradient_penalty = GradientPenalty(discriminator, device=device)

# Generator objective built from adversarial, L1 and perceptual terms with the
# given weights (how the weights combine is defined by FineGANLoss — confirm).
criterion_generator = FineGANLoss(
    adversarial_criterion=wasserstein_criterion, adv_loss_weight=0.25,
    l1_criterion=True, l1_loss_weight=4,
    perceptual=True, perceptual_loss_weight=1, device=device
)
criterion_discriminator = WassersteinLoss()

print(device)

In [11]:
# Bundle networks and losses under the keys fit() looks up, moving each one
# to the training device.
model = dict(
    discriminator=discriminator.to(device),
    generator=generator.to(device),
)

criterion = dict(
    discriminator=criterion_discriminator.to(device),
    generator=criterion_generator.to(device),
)

In [None]:
# Train for 10 epochs with equal Adam learning rates. fit() returns
# (losses_g, losses_d, real_scores, fake_scores) and, with its default
# save_step=1, writes checkpoints after every epoch.
history = fit(
    model=model,
    criterion=criterion,
    gradient_penalty=gradient_penalty,
    train_dl=train_dataloader,
    device=device,
    epochs=10,
    g_lr=1e-4,
    d_lr=1e-4,
    save_directory_generator=save_directory_generator,
    save_directory_discriminator=save_directory_discriminator,
)

In [13]:
# Validation split, using the same transforms as the training split.
# NOTE(review): machine-specific absolute path — adjust before running elsewhere.
test_dataset = EncoderDecoderDataset(
    transform_human=transform_input,
    transform_clothes=transform_input,
    transform_human_restored=transform_real,
    image_dir=r"C:\Users\DenisovDmitrii\Desktop\forEncoder\val",
)

In [None]:
# Qualitative check on one validation sample: generate an image on CPU, show it,
# and print the discriminator's score for it.
generator.to('cpu')
discriminator.to('cpu')
# Inference mode: eval() switches layers such as BatchNorm/Dropout (if the
# networks contain any) to their inference behaviour. fit() leaves them in
# train mode. no_grad() below avoids building an autograd graph.
generator.eval()
discriminator.eval()

image, real_image = test_dataset[2]
image = image.unsqueeze(0)  # add a batch dimension
print(image.shape)

with torch.no_grad():
    image = generator(image)
    imaged = discriminator(image)

image = transforms.ToPILImage()(image[0, :, :, :])
image.show()
print(imaged)
