# Импорт

In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import PIL.Image as Image

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from LookGenerator.networks.losses import WassersteinLoss, GradientPenalty
from LookGenerator.datasets.refinement_dataset import RefinementGANDataset
from LookGenerator.networks.trainer import WGANGPTrainer
from LookGenerator.networks.refinement import RefinementGenerator, RefinementDiscriminator
from LookGenerator.networks_training.utils import check_path_and_creat
import LookGenerator.datasets.transforms as custom_transforms

# Загрузка данных

In [2]:
# Restored and real images go through the same preprocessing steps, but each
# stream still gets its own Compose instance.
_IMAGE_SIZE = (256, 192)  # handed to Resize, which reads a pair as (height, width)


def _make_preprocessing():
    """Build one resize + min-max scaling pipeline."""
    return transforms.Compose([
        transforms.Resize(_IMAGE_SIZE),
        # NOTE(review): MinMaxScale is project code; presumably it rescales
        # values into a fixed range -- confirm against its implementation.
        custom_transforms.MinMaxScale()
    ])


transform_restored = _make_preprocessing()
transform_real = _make_preprocessing()

In [3]:
# Settings consumed by the training DataLoader below.
num_workers = 10        # number of loader worker processes
pin_memory = True       # ask the loader to return batches in pinned host memory
batch_size_train = 192  # samples per training batch

In [4]:
# NOTE(review): both directories are absolute paths on one specific machine;
# anyone else running this notebook has to edit them first.
restored_images_dir = r"C:\Users\DenisovDmitrii\Desktop\forRefinement\train"
real_images_dir = r"C:\Users\DenisovDmitrii\Desktop\forEncoder\train\image"

# Dataset that is given a directory of restored images and a directory of real
# images, each with its own preprocessing pipeline.
train_dataset = RefinementGANDataset(
    restored_images_dir=restored_images_dir,
    real_images_dir=real_images_dir,
    transform_restored_images=transform_restored,
    transform_real_images=transform_real,
)

In [5]:
# Batched, shuffled loader over the training dataset; all tunables come from
# the configuration cell above.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size_train,
    shuffle=True,  # reshuffle the samples every epoch
    num_workers=num_workers,
    pin_memory=pin_memory,
)

# Обучение модели

In [6]:
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# The two networks of the GAN.
generator = RefinementGenerator()
discriminator = RefinementDiscriminator()

# Both networks are optimised with Adam at the same learning rate.
learning_rate = 5e-5
optimizer_generator = torch.optim.Adam(generator.parameters(), lr=learning_rate)
optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=learning_rate)

# Wasserstein loss for each network, plus a gradient-penalty term that is
# constructed from the discriminator.
criterion_generator = WassersteinLoss()
criterion_discriminator = WassersteinLoss()
gradient_penalty = GradientPenalty(discriminator, device=device)

print(device)

cuda


In [7]:
# Output directories for the generator and discriminator weights of this session.
# NOTE(review): absolute, machine-specific paths -- adjust before running elsewhere.
save_directory_generator = r"C:\Users\DenisovDmitrii\OneDrive - ITMO UNIVERSITY\peopleDetector\refinement\weights\generator\session1"
save_directory_discriminator = r"C:\Users\DenisovDmitrii\OneDrive - ITMO UNIVERSITY\peopleDetector\refinement\weights\discriminator\session1"

# NOTE(review): check_path_and_creat is a project helper; presumably it creates
# the directory when it is missing -- confirm against its implementation.
check_path_and_creat(save_directory_generator)
# Kept as the final expression so the cell still displays the helper's return value.
check_path_and_creat(save_directory_discriminator)


True

In [8]:
# Let cuDNN benchmark the available convolution algorithms and pick the fastest.
torch.backends.cudnn.benchmark = True
# Do not allow TF32 for CUDA matrix multiplications.
torch.backends.cuda.matmul.allow_tf32 = False

In [9]:
# Everything the WGAN-GP trainer needs, gathered in one place and then
# unpacked into the constructor.
trainer_kwargs = dict(
    # networks and their optimizers
    generator=generator,
    discriminator=discriminator,
    optimizer_generator=optimizer_generator,
    optimizer_discriminator=optimizer_discriminator,
    # loss terms
    criterion_generator=criterion_generator,
    criterion_discriminator=criterion_discriminator,
    gradient_penalty=gradient_penalty,
    gp_weight=0.2,  # NOTE(review): presumably the weight of the penalty term -- confirm in WGANGPTrainer
    # checkpointing
    save_step=1,  # NOTE(review): presumably "save every epoch" -- confirm in WGANGPTrainer
    save_directory_discriminator=save_directory_discriminator,
    save_directory_generator=save_directory_generator,
    # runtime
    device=device,
    verbose=True,
)

trainer = WGANGPTrainer(**trainer_kwargs)

In [None]:
# Run the training loop over the training loader for a fixed number of epochs.
num_epochs = 5
trainer.train(train_dataloader, epoch_num=num_epochs)

In [None]:
# Sanity check: push one training sample through the generator and score the
# generated image with the discriminator.
# real_image is not used below; it is unpacked only so the name stays defined.
restored_image, real_image = train_dataset[1]

# Put the input on whatever device the generator's weights live on, so the
# forward pass cannot fail with a CPU/GPU device mismatch.
model_device = next(generator.parameters()).device
input_batch = restored_image.unsqueeze(0).to(model_device)  # add a batch dimension
print(input_batch.shape)

# Run in eval mode without autograd bookkeeping, then restore whichever mode
# each network was in so a later training run is unaffected.
generator_was_training = generator.training
discriminator_was_training = discriminator.training
generator.eval()
discriminator.eval()
try:
    with torch.no_grad():
        generated_batch = generator(input_batch)
        imaged = discriminator(generated_batch)
finally:
    generator.train(generator_was_training)
    discriminator.train(discriminator_was_training)

# Convert the first (only) generated sample to a PIL image on the CPU.
image = transforms.ToPILImage()(generated_batch[0].cpu())
image.show()
print(imaged)
