In [None]:
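# FineGAN refinement training: adversarially trains EncoderDecoderGenerator against Discriminator on RefinementGANDataset (restored image, real image) pairs.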

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import PIL.Image as Image
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from LookGenerator.networks.losses import FineGANLoss
from LookGenerator.datasets.refinement_dataset import RefinementGANDataset
from LookGenerator.networks.FineGAN import *
from LookGenerator.networks_training.utils import check_path_and_creat
import LookGenerator.datasets.transforms as custom_transforms

In [None]:
# Generator inputs ("restored" images): resize to 256x192, then normalise every
# channel with mean 0.5 / std 0.5, i.e. x -> (x - 0.5) / 0.5.
transform_restored = transforms.Compose(
    [
        transforms.Resize((256, 192)),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

# Ground-truth ("real") images: same resize, then the project's MinMaxScale
# (presumably rescales values into [0, 1] — confirm in LookGenerator.datasets.transforms).
transform_real = transforms.Compose(
    [
        transforms.Resize((256, 192)),
        custom_transforms.MinMaxScale(),
    ]
)

In [None]:
# DataLoader settings for the training split: batch size, pinned host memory, worker processes.
batch_size_train, pin_memory, num_workers = 64, True, 16

In [None]:
# Training dataset; fit() unpacks each batch as (input_images, real_images).
# NOTE(review): both directories are absolute paths on the original training machine.
train_dataset = RefinementGANDataset(
    transform_restored_images=transform_restored,
    transform_real_images=transform_real,
    restored_images_dir=r"C:\Users\DenisovDmitrii\Desktop\forRefinement\train",
    real_images_dir=r"C:\Users\DenisovDmitrii\Desktop\forEncoder\train\image",
)

In [None]:
# The previous cell already builds `train_dataset` from the same directories and
# transforms; rebuilding it here only repeated that work. Instead, fail fast if the
# configured directories yielded no training samples.
assert len(train_dataset) > 0, "RefinementGANDataset found no training samples"

In [None]:
# Shuffled mini-batches of (restored, real) image pairs for the training loop.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size_train,
    shuffle=True,
    num_workers=num_workers,
    pin_memory=pin_memory,
)

In [None]:
# Per-network checkpoint directories for training session 4.
# NOTE(review): absolute OneDrive paths from the original machine.
save_directory_generator = r"C:\Users\DenisovDmitrii\OneDrive - ITMO UNIVERSITY\peopleDetector\refinement\weights\generator\session4"
save_directory_discriminator = r"C:\Users\DenisovDmitrii\OneDrive - ITMO UNIVERSITY\peopleDetector\refinement\weights\discriminator\session4"

# Presumably creates each directory if it does not exist yet (project helper).
check_path_and_creat(save_directory_generator)
check_path_and_creat(save_directory_discriminator)


In [None]:
# Let cuDNN benchmark and cache the fastest convolution algorithms. This pays off
# because every image is resized to a fixed 256x192. Keep CUDA matmuls in full FP32
# (TF32 disabled).
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = False

In [None]:
def fit(model, criterion, train_dl, device, epochs, g_lr, d_lr):
    """Adversarially train ``model["generator"]`` against ``model["discriminator"]``.

    Each batch performs one discriminator update, then one generator update:

    * Discriminator update: real images are labelled 1 and generated images 0.
    * Generator update: generated images are labelled 1, so the generator is rewarded
      for fooling the discriminator.

    Args:
        model: dict with keys ``"generator"`` and ``"discriminator"`` holding the two
            networks. They are not moved here; they must already be on ``device``.
        criterion: dict with the same keys. Each entry is called as
            ``criterion[key](preds, targets)`` and must return a scalar loss tensor.
            Targets are built with the same shape and dtype as the discriminator output.
        train_dl: iterable of ``(input_images, real_images)`` batches. ``input_images``
            are fed to the generator and ``real_images`` to the discriminator.
        device: device every batch is moved to.
        epochs: number of passes over ``train_dl``.
        g_lr: Adam learning rate for the generator.
        d_lr: Adam learning rate for the discriminator.

    Returns:
        ``(losses_g, losses_d, real_scores, fake_scores)``: four lists with one entry
        per epoch. They hold, in order:

        * the mean generator loss;
        * the mean discriminator loss;
        * the mean discriminator output on real images;
        * the mean discriminator output on generated images.
    """
    generator = model["generator"]
    discriminator = model["discriminator"]
    discriminator.train()
    generator.train()
    torch.cuda.empty_cache()

    # Per-epoch history: every list receives exactly one value per epoch.
    losses_g = []
    losses_d = []
    real_scores = []
    fake_scores = []

    optimizer = {
        "discriminator": torch.optim.Adam(discriminator.parameters(),
                                          lr=d_lr, betas=(0.5, 0.999)),
        "generator": torch.optim.Adam(generator.parameters(),
                                      lr=g_lr, betas=(0.5, 0.999)),
    }

    for epoch in range(epochs):
        loss_d_per_epoch = []
        loss_g_per_epoch = []
        real_score_per_epoch = []
        fake_score_per_epoch = []
        for input_images, real_images in tqdm(train_dl):
            input_images = input_images.to(device)
            real_images = real_images.to(device)

            # ---- Discriminator step: push D(real) -> 1 and D(G(x)) -> 0 ----
            optimizer["discriminator"].zero_grad()

            real_preds = discriminator(real_images)
            real_loss = criterion["discriminator"](real_preds, torch.ones_like(real_preds))

            fake_images = generator(input_images)
            # detach(): this loss must only update the discriminator, so it must not
            # backpropagate into the generator.
            fake_preds = discriminator(fake_images.detach())
            fake_loss = criterion["discriminator"](fake_preds, torch.zeros_like(fake_preds))

            loss_d = real_loss + fake_loss
            loss_d.backward()
            optimizer["discriminator"].step()

            loss_d_per_epoch.append(loss_d.item())
            real_score_per_epoch.append(real_preds.mean().item())
            fake_score_per_epoch.append(fake_preds.mean().item())

            # ---- Generator step: make the updated discriminator output 1 on G(x) ----
            optimizer["generator"].zero_grad()
            # Reuse this batch's generated images. Their graph through the generator is
            # still intact because the discriminator step only saw a detached copy.
            preds = discriminator(fake_images)
            loss_g = criterion["generator"](preds, torch.ones_like(preds))
            loss_g.backward()
            optimizer["generator"].step()
            loss_g_per_epoch.append(loss_g.item())

        # Record per-epoch means (one entry per epoch for every history list).
        losses_g.append(np.mean(loss_g_per_epoch))
        losses_d.append(np.mean(loss_d_per_epoch))
        real_scores.append(np.mean(real_score_per_epoch))
        fake_scores.append(np.mean(fake_score_per_epoch))

        # Log this epoch's mean losses and scores.
        print("Epoch [{}/{}], loss_g: {:.4f}, loss_d: {:.4f}, real_score: {:.4f}, fake_score: {:.4f}".format(
            epoch+1, epochs,
            losses_g[-1], losses_d[-1], real_scores[-1], fake_scores[-1])
        )

    return losses_g, losses_d, real_scores, fake_scores

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Networks, presumably provided by the star import from LookGenerator.networks.FineGAN.
generator, discriminator = EncoderDecoderGenerator(), Discriminator()

# Both adversarial terms use binary cross-entropy; the generator's is wrapped in FineGANLoss.
# NOTE(review): BCELoss expects probabilities, so the Discriminator presumably ends in a sigmoid.
criterion_discriminator = nn.BCELoss()
criterion_generator = FineGANLoss(adversarial_criterion=nn.BCELoss())

print(device)

In [None]:
# Route each network to its own role. fit() feeds real and generated images to
# model["discriminator"], and input_images to model["generator"], so every key must
# hold the matching module.
model = {
    "discriminator": discriminator.to(device),
    "generator": generator.to(device),
}

criterion = {
    "discriminator": criterion_discriminator,
    "generator": criterion_generator,
}

In [None]:
EPOCHS = 30

history = fit(model=model,
              criterion=criterion,
              train_dl=train_dataloader,
              device=device,
              epochs=EPOCHS,
              g_lr=0.0004,
              d_lr=0.0008)

# Persist the trained weights into the session checkpoint directories created earlier.
# Without this, the weights are lost when the kernel stops.
torch.save(generator.state_dict(),
           os.path.join(save_directory_generator, f"generator_epoch{EPOCHS}.pt"))
torch.save(discriminator.state_dict(),
           os.path.join(save_directory_discriminator, f"discriminator_epoch{EPOCHS}.pt"))

In [None]:
# Qualitative check on one training sample: refine it with the generator and print the
# discriminator's output for the result. This runs on `device` instead of moving both
# networks to the CPU, so the training cells can be re-run afterwards without a device
# mismatch. eval() switches dropout/batch-norm layers (if any) to inference behaviour;
# fit() puts both networks back into train mode when it starts.
generator.eval()
discriminator.eval()

restored_image, real_image = train_dataset[1]
restored_batch = restored_image.unsqueeze(0).to(device)
print(restored_batch.shape)  # shape only; printing the full tensor floods the output

with torch.no_grad():  # inference only: no autograd graph needed
    refined_batch = generator(restored_batch)
    refined_score = discriminator(refined_batch)

# NOTE(review): ToPILImage scales float tensors by 255, so this assumes the generator
# output lies in [0, 1]. Confirm against EncoderDecoderGenerator's final activation.
refined_image = transforms.ToPILImage()(refined_batch[0].cpu())
fig, ax = plt.subplots()
ax.imshow(refined_image)
ax.set_title("Generator output for train_dataset[1]")
ax.axis("off")
plt.show()

print(refined_score)
