In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import PIL.Image as Image
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

from LookGenerator.networks.losses import WassersteinLoss, PerPixelLoss, PerceptualLoss, GradientPenalty, FineGANWithMaskLoss
from LookGenerator.datasets.encoder_decoder_datasets import GenerativeDatasetWithMask
from LookGenerator.networks.fine_gan import *
from LookGenerator.networks.clothes_feature_extractor import ClothAutoencoder
from LookGenerator.networks.trainer import WGANGPTrainer
from LookGenerator.networks_training.utils import check_path_and_creat
from LookGenerator.networks.utils import get_num_digits, save_model, load_model
import LookGenerator.datasets.transforms as custom_transforms

In [2]:
transform_human = transforms.Compose([
    transforms.Resize((256, 192)),
    transforms.RandomAffine(scale=(0.8, 1), degrees=(-90,90), fill = 0.9),
    #transforms.ColorJitter(brightness=(0.5, 1), contrast=(0.4,1),  hue=(0, 0.3)),
    transforms.Normalize(mean=[0.5, 0.5, 0.5],
                         std=[0.5, 0.5, 0.5])
])

transform_clothes = transforms.Compose([
    transforms.Resize((256, 192)),
    # transforms.ColorJitter(brightness=(0.5, 1), contrast=(0.4,1),  hue=(0, 0.3)),
    transforms.Normalize(mean=[0.5, 0.5, 0.5],
                         std=[0.5, 0.5, 0.5])
])

transform_mask = transforms.Compose([
    transforms.Resize((256, 192)),
    custom_transforms.ThresholdTransform()
])

transform_human_restored = transforms.Compose([
    transforms.Resize((256, 192)),
    transforms.RandomAffine(scale=(0.8, 1), degrees=(-90,90), fill = 0.9),
    # transforms.ColorJitter(brightness=(0.5, 1), contrast=(0.4,1),  hue=(0, 0.3)),
    custom_transforms.MinMaxScale()
])

In [3]:
batch_size_train = 24
pin_memory = True
num_workers = 6

In [4]:
train_dataset = GenerativeDatasetWithMask(
    human_dir=r"C:\Users\DenisovDmitrii\Desktop\forEncoderNew\train\imageWithNoCloth",
    clothes_dir=r"C:\Users\DenisovDmitrii\Desktop\forEncoderNew\train\cloth",
    segmentation_mask_dir=r"C:\Users\DenisovDmitrii\Desktop\zalando-hd-resize\train\agnostic-v3.3",
    human_restored_dir=r"C:\Users\DenisovDmitrii\Desktop\forEncoderNew\train\image",
    transform_human=transform_human,
    transform_clothes=transform_clothes,
    transform_mask=transform_mask,
    transform_human_restored=transform_human_restored
)

In [5]:
train_dataloader = DataLoader(
    train_dataset, batch_size=batch_size_train, shuffle=True, pin_memory=pin_memory, num_workers=num_workers
)

In [6]:
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.benchmark = True

# Тренировщик

In [8]:
class WGANGPWithMaksTrainer(WGANGPTrainer):
    def __init__(
            self, generator, discriminator,
            optimizer_generator, optimizer_discriminator,
            criterion_generator, criterion_discriminator,
            gradient_penalty, gp_weight=0.2, save_step=1,
            save_directory_discriminator=r"", save_directory_generator=r"",
            device='cpu', verbose=True
    ):
        super().__init__(generator, discriminator, optimizer_generator, optimizer_discriminator, criterion_generator,
                         criterion_discriminator, gradient_penalty, gp_weight, save_step, save_directory_discriminator,
                         save_directory_generator, device, verbose)

    def _train_epoch(self, train_dataloader):
        self.discriminator_real_epoch_batches_loss = []
        self.discriminator_fake_epoch_batches_loss = []
        self.discriminator_epoch_batches_loss = []
        self.generator_epoch_batches_loss = []

        self.generator = self.generator.to(self.device)
        self.discriminator = self.discriminator.to(self.device)

        for iteration, (input_images, mask, real_images) in enumerate(tqdm(train_dataloader), 0):
            input_images = input_images.to(self.device)
            mask = mask.to(self.device)
            real_images = real_images.to(self.device)
            self._train_discriminator(input_images, real_images)

            if iteration % 5 == 0:
                self._train_generator(input_images, mask, real_images)

        loss_real = np.mean(self.discriminator_real_epoch_batches_loss)
        loss_fake = np.mean(self.discriminator_fake_epoch_batches_loss)
        loss_d = np.mean(self.discriminator_fake_epoch_batches_loss)
        loss_g = np.mean(self.generator_epoch_batches_loss)

        self.discriminator_real_history_epochs.append(loss_real)
        self.discriminator_fake_history_epochs.append(loss_fake)
        self.discriminator_history_epochs.append(loss_d)
        self.generator_history_epochs.append(loss_g)

        return loss_real, loss_fake, loss_d, loss_g


    def _train_generator(self, input_images, mask, real_images):
        self.discriminator.eval()
        self.generator.train()

        # Clear generator gradients
        self.optimizer_generator.zero_grad()

        # Generate fake images
        fake_images = self.generator(input_images)

        # Try to fool discriminator
        preds = self.discriminator(fake_images)
        targets = torch.ones(real_images.shape[0], 1, device=self.device)
        loss_g = self.criterion_generator(preds, targets, fake_images, mask, real_images)
        self.generator_history_batches.append(torch.mean(loss_g).item())
        self.generator_epoch_batches_loss.append(torch.mean(loss_g).item())

        # Update generator weights
        loss_g.backward()
        self.optimizer_generator.step()



In [7]:
save_directory = r'C:\Users\DenisovDmitrii\OneDrive - ITMO UNIVERSITY\peopleDetector\encoderGAN\weights\testBaseParam'
check_path_and_creat(save_directory)
save_directory_generator = os.path.join(save_directory, 'gen')
save_directory_discriminator = os.path.join(save_directory, 'discr')
check_path_and_creat(save_directory_generator)
check_path_and_creat(save_directory_discriminator)

True

In [8]:
clothes_feature_extractor = ClothAutoencoder(
    in_channels=3,
    out_channels=3,
    features=(8, 16, 32, 64),
    latent_dim_size=128,
    encoder_activation_func=nn.LeakyReLU(),
    decoder_activation_func=nn.ReLU()
)
clothes_feature_extractor = load_model(clothes_feature_extractor, r"C:\Users\DenisovDmitrii\OneDrive - ITMO UNIVERSITY\peopleDetector\autoDegradation\weights\testClothes_L1Loss_4features\epoch_39.pt")

In [9]:
generator = EncoderDecoderGenerator(
    clothes_feature_extractor=clothes_feature_extractor,
    in_channels=3, out_channels=3,
    final_activation_func=nn.Sigmoid()
)
discriminator = Discriminator(
    in_channels=3, batch_norm=False,
    dropout=True, sigmoid=False
)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

wasserstein_criterion = WassersteinLoss()
l1_criterion = PerPixelLoss()
perceptual_criterion = PerceptualLoss(
    device=device, weights_perceptual=[1.0, 1.0, 1.0, 1.0]
)
gradient_penalty = GradientPenalty(discriminator, device=device)

# criterion_generator = FineGANLoss(
#     adversarial_criterion=wasserstein_criterion, adv_loss_weight=0.25,
#     l1_criterion=True, l1_loss_weight=4,
#     perceptual=True, perceptual_loss_weight=1, device=device
# )

criterion_generator = FineGANWithMaskLoss(
    adversarial_criterion=wasserstein_criterion,
    l1_criterion=l1_criterion,
    perceptual_criterion=perceptual_criterion,
    weights=[1.0, 1.0, 0.0, 1.0, 0.0, 3.0],
    device=device
)
criterion_discriminator = WassersteinLoss()

print(device)



cuda


In [10]:
optimizer_generator = torch.optim.Adam(params=generator.parameters(), lr=0.00005)
optimizer_discriminator = torch.optim.Adam(params=discriminator.parameters(), lr=0.00005)

In [14]:
trainer = WGANGPWithMaksTrainer(
    generator=generator,
    discriminator=discriminator,
    optimizer_generator=optimizer_generator,
    optimizer_discriminator=optimizer_discriminator,
    criterion_generator=criterion_generator,
    criterion_discriminator=criterion_discriminator,
    gradient_penalty=gradient_penalty,
    gp_weight=10,
    save_step=1,
    save_directory_discriminator=save_directory_discriminator,
    save_directory_generator=save_directory_generator,
    device=device,
    verbose=True
)


In [15]:
# history = fit(model=model,
#               criterion=criterion,
#               gradient_penalty=gradient_penalty,
#               train_dl=train_dataloader,
#               device=device,
#               epochs=10,
#               g_lr=0.0001,
#               d_lr=0.0001,
#               save_directory_generator=save_directory_generator,
#               save_directory_discriminator=save_directory_discriminator)

In [16]:
epoch_num = 20

In [17]:
history = trainer.train(
    train_dataloader=train_dataloader,
    epoch_num=epoch_num
)

start time 02-06-2023 15:35


100%|██████████| 486/486 [05:53<00:00,  1.38it/s]


Epoch 0 of 19, discriminator loss: 0.00251
Epoch 0 of 19, generator loss: 4.11509
Epoch end time 02-06-2023 15:41


100%|██████████| 486/486 [06:34<00:00,  1.23it/s]


Epoch 1 of 19, discriminator loss: 0.00000
Epoch 1 of 19, generator loss: 3.10050
Epoch end time 02-06-2023 15:47


100%|██████████| 486/486 [07:00<00:00,  1.16it/s]


Epoch 2 of 19, discriminator loss: 0.00000
Epoch 2 of 19, generator loss: 2.88922
Epoch end time 02-06-2023 15:54


100%|██████████| 486/486 [06:40<00:00,  1.21it/s]


Epoch 3 of 19, discriminator loss: 0.00000
Epoch 3 of 19, generator loss: 2.78009
Epoch end time 02-06-2023 16:01


100%|██████████| 486/486 [06:20<00:00,  1.28it/s]


Epoch 4 of 19, discriminator loss: 0.00000
Epoch 4 of 19, generator loss: 2.72695
Epoch end time 02-06-2023 16:07


100%|██████████| 486/486 [06:15<00:00,  1.29it/s]


Epoch 5 of 19, discriminator loss: 0.00000
Epoch 5 of 19, generator loss: 2.65125
Epoch end time 02-06-2023 16:14


100%|██████████| 486/486 [05:54<00:00,  1.37it/s]


Epoch 6 of 19, discriminator loss: 0.00000
Epoch 6 of 19, generator loss: 2.61035
Epoch end time 02-06-2023 16:20


100%|██████████| 486/486 [05:52<00:00,  1.38it/s]


Epoch 7 of 19, discriminator loss: 0.00000
Epoch 7 of 19, generator loss: 2.53768
Epoch end time 02-06-2023 16:26


100%|██████████| 486/486 [05:23<00:00,  1.50it/s]


Epoch 8 of 19, discriminator loss: 0.00000
Epoch 8 of 19, generator loss: 2.51223
Epoch end time 02-06-2023 16:31


100%|██████████| 486/486 [05:20<00:00,  1.52it/s]


Epoch 9 of 19, discriminator loss: 0.00000
Epoch 9 of 19, generator loss: 2.49246
Epoch end time 02-06-2023 16:36


100%|██████████| 486/486 [05:22<00:00,  1.51it/s]


Epoch 10 of 19, discriminator loss: 0.00000
Epoch 10 of 19, generator loss: 2.45236
Epoch end time 02-06-2023 16:42


100%|██████████| 486/486 [05:27<00:00,  1.48it/s]


Epoch 11 of 19, discriminator loss: 0.00000
Epoch 11 of 19, generator loss: 2.47527
Epoch end time 02-06-2023 16:47


100%|██████████| 486/486 [05:27<00:00,  1.48it/s]


Epoch 12 of 19, discriminator loss: 0.00000
Epoch 12 of 19, generator loss: 2.43512
Epoch end time 02-06-2023 16:53


100%|██████████| 486/486 [05:28<00:00,  1.48it/s]


Epoch 13 of 19, discriminator loss: 0.00000
Epoch 13 of 19, generator loss: 2.41783
Epoch end time 02-06-2023 16:58


 15%|█▍        | 72/486 [00:55<04:39,  1.48it/s] 

In [18]:
trainer.draw_history_plots()

In [19]:
trainer.save_history_plots(save_directory)

In [20]:
trainer.create_readme(save_directory)

In [21]:
test_dataset = GenerativeDatasetWithMask(
    human_dir=r"C:\Users\DenisovDmitrii\Desktop\forEncoderNew\train\imageWithNoCloth",
    clothes_dir=r"C:\Users\DenisovDmitrii\Desktop\forEncoderNew\train\cloth",
    segmentation_mask_dir=r"C:\Users\DenisovDmitrii\Desktop\zalando-hd-resize\train\agnostic-v3.3",
    human_restored_dir=r"C:\Users\DenisovDmitrii\Desktop\forEncoderNew\train\image",
    transform_human=transform_human,
    transform_clothes=transform_clothes,
    transform_mask=transform_mask,
    transform_human_restored=transform_human_restored
)

In [22]:
generator.to('cpu')
discriminator.to('cpu')
image, real_image = test_dataset[2]
image = image.unsqueeze(0)
#print(image)
print(image.shape)
image = generator(image)
imaged = discriminator(image)
image = transforms.ToPILImage()(image[0, :, :, :])
image.show()
print(imaged)
