# Импорт

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import PIL.Image as Image

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

from LookGenerator.datasets.refinement_dataset import RefinementDiscriminatorDataset
from LookGenerator.networks.refinement import RefinementDiscriminator
from LookGenerator.networks.trainer import Trainer
from LookGenerator.networks_training.utils import check_path_and_creat
import LookGenerator.datasets.transforms as custom_transforms

# Загрузка данных

In [None]:
# Input pipeline shared by the train and validation datasets:
# resize to (height, width) = image_size, then apply the project's MinMaxScale
# (presumably rescales pixel values into a fixed range — confirm in
# LookGenerator.datasets.transforms).
image_size = (256, 192)
transform = transforms.Compose(
    [
        transforms.Resize(image_size),
        custom_transforms.MinMaxScale(),
    ]
)

In [None]:
# DataLoader settings.
batch_size_train = 32
batch_size_val = 16
# Pinned (page-locked) host memory only speeds up host-to-GPU copies. Without
# CUDA, PyTorch warns and ignores it, so enable it only when a GPU is present
# (the notebook explicitly falls back to CPU further down).
pin_memory = torch.cuda.is_available()
# Never spawn more loader processes than there are CPUs: extra workers only add
# process overhead and make DataLoader emit a worker-count warning.
# os.cpu_count() may return None, hence the `or 1` fallback.
num_workers = min(16, os.cpu_count() or 1)

In [None]:
def _make_discriminator_dataset(fake_dir, real_dir):
    """Build a RefinementDiscriminatorDataset from one fake/real image-folder pair.

    Both splits use the module-level ``transform``.
    """
    return RefinementDiscriminatorDataset(
        fake_images_dir=fake_dir,
        real_images_dir=real_dir,
        transform=transform,
    )


# NOTE(review): all four directories are empty strings — presumably
# placeholders that must be filled in before the notebook is run.
train_dataset = _make_discriminator_dataset(fake_dir=r"", real_dir=r"")
val_dataset = _make_discriminator_dataset(fake_dir=r"", real_dir=r"")

In [None]:
# Options common to both loaders; only batch size and shuffling differ.
_loader_options = {"pin_memory": pin_memory, "num_workers": num_workers}

# Training batches are reshuffled each epoch; validation keeps a fixed,
# deterministic order.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size_train,
    shuffle=True,
    **_loader_options,
)
val_dataloader = DataLoader(
    val_dataset,
    batch_size=batch_size_val,
    shuffle=False,
    **_loader_options,
)

# Обучение модели

In [None]:
learning_rate = 5e-5

model = RefinementDiscriminator()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# Prefer the GPU when one is visible. The model is not moved here; presumably
# the Trainer does that with `device` — confirm in LookGenerator.networks.trainer.
device = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): BCELoss expects probabilities in [0, 1], so this assumes the
# discriminator ends with a sigmoid. If it returns raw logits, use
# nn.BCEWithLogitsLoss instead — TODO confirm.
criterion = nn.BCELoss()
print(device)

In [None]:
# Output directory passed to the Trainer as `save_directory`.
# NOTE(review): empty path — presumably a placeholder to fill in before running.
save_directory = r""
# Project helper from networks_training.utils; by its name it presumably creates
# the directory if it does not exist — confirm.
check_path_and_creat(save_directory)

In [None]:
# Let cuDNN benchmark convolution algorithms and cache the fastest one. This pays
# off here because every image is resized to the same fixed size.
torch.backends.cudnn.benchmark = True
# Keep CUDA matrix multiplies in full FP32 precision (no TF32 tensor-core path).
torch.backends.cuda.matmul.allow_tf32 = False

In [None]:
# Wire the model, optimizer and loss into the project's training loop.
trainer = Trainer(
    model_=model,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    save_directory=save_directory,
    save_step=1,  # presumably: save a checkpoint every epoch — confirm in Trainer
    verbose=True,
)

In [None]:
# Run training with per-epoch validation on val_dataloader.
num_epochs = 3
trainer.train(train_dataloader, val_dataloader, epoch_num=num_epochs)

In [None]:
# Presumably plots the loss/metric history collected during trainer.train(...)
# (inferred from the method name) — confirm in LookGenerator.networks.trainer.
trainer.draw_history_plots()