# Импорт

In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import PIL.Image as Image

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

from LookGenerator.datasets.encoder_decoder_datasets import EncoderDecoderDataset
from LookGenerator.networks.trainer import Trainer
from LookGenerator.networks.encoder_decoder import EncoderDecoder

import LookGenerator.datasets.transforms as custom_transforms

# Загрузка данных

In [None]:
def _build_image_transform():
    """Return a fresh resize-then-normalize pipeline for 3-channel inputs.

    A new ``Compose`` (with its own ``Resize`` and ``Normalize`` instances)
    is created on every call, so the pipelines below do not share objects.
    """
    resize = transforms.Resize((256, 192))
    normalize = transforms.Normalize(
        mean=[0.5, 0.5, 0.5],
        std=[0.25, 0.25, 0.25],
    )
    return transforms.Compose([resize, normalize])


# Human, clothes, and restored-human inputs all use the same recipe.
transform_human = _build_image_transform()

# Pose points are resized to the same spatial size, then passed through the
# project's MinMaxScale transform instead of the fixed mean/std normalization.
transform_pose_points = transforms.Compose([
    transforms.Resize((256, 192)),
    custom_transforms.MinMaxScale(),
])

transform_clothes = _build_image_transform()

transform_human_restored = _build_image_transform()

In [None]:
# Mini-batch sizes for the training and validation data loaders.
batch_size_train, batch_size_val = 24, 16

In [None]:
# Transform keyword arguments shared by both dataset splits.
_dataset_transforms = dict(
    transform_human=transform_human,
    transform_pose_points=transform_pose_points,
    transform_clothes=transform_clothes,
    transform_human_restored=transform_human_restored,
)

# NOTE(review): image_dir is an empty placeholder for both splits; point each
# one at its own data directory before running.
train_dataset = EncoderDecoderDataset(image_dir="", **_dataset_transforms)
test_dataset = EncoderDecoderDataset(image_dir="", **_dataset_transforms)

# Training batches are reshuffled every epoch.
train_dataloader = DataLoader(
    train_dataset, batch_size=batch_size_train, shuffle=True
)

# The evaluation loader uses batch_size_val. It was previously built with
# batch_size_train, which left batch_size_val unused. Order is kept fixed
# (shuffle=False).
test_dataloader = DataLoader(
    test_dataset, batch_size=batch_size_val, shuffle=False
)

# Backward compatibility: this cell used to rebind `test_dataset` to the
# evaluation DataLoader, so keep that binding for any later cell relying on it.
# Prefer `test_dataloader` in new code.
test_dataset = test_dataloader

# Обучение модели

In [None]:
# Network under training: 23 input channels, 3 output channels.
# NOTE(review): 23 is presumably the channel count of the stacked inputs
# produced by the dataset - confirm against EncoderDecoderDataset.
model = EncoderDecoder(in_channels=23, out_channels=3)

# Adam over all model parameters with a learning rate of 1e-3.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

# Mean-squared-error loss between the network output and the target.
# (A cross-entropy criterion was left disabled here in an earlier revision.)
criterion = nn.MSELoss()

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

In [None]:
# Directory handed to the Trainer as `save_directory`; left empty here as a
# placeholder to be filled in before running.
save_directory: str = ""

In [None]:
# Gather the Trainer configuration in one mapping, then unpack it into the
# constructor as keyword arguments.
# NOTE(review): save_step=1 presumably means "checkpoint every epoch" and
# verbose=True presumably enables progress output - confirm in Trainer.
_trainer_options = {
    "model_": model,
    "optimizer": optimizer,
    "criterion": criterion,
    "device": device,
    "save_directory": save_directory,
    "save_step": 1,
    "verbose": True,
}
trainer = Trainer(**_trainer_options)