# Импорт

In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import PIL.Image as Image

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import albumentations
from LookGenerator.networks.losses import PerceptualLoss
from LookGenerator.datasets.encoder_decoder_datasets import EncoderDecoderDataset
from LookGenerator.networks.trainer import Trainer
from LookGenerator.networks.encoder_decoder import EncoderDecoder
from LookGenerator.networks_training.utils import check_path_and_creat
import LookGenerator.datasets.transforms as custom_transforms

# Загрузка данных

In [2]:
from torchvision.transforms import InterpolationMode  # NOTE(review): not used in the visible cells

# Every pipeline first resizes to this fixed (height, width).
IMAGE_SIZE = (256, 192)


def _random_affine():
    """Build the geometric augmentation: scale in [0.8, 1], rotation in [-90, 90] degrees, fill 0.9."""
    return transforms.RandomAffine(scale=(0.8, 1), degrees=(-90, 90), fill=0.9)


def _color_jitter():
    """Build the photometric augmentation: brightness in [0.5, 1], contrast in [0.4, 1], hue in [0, 0.3]."""
    return transforms.ColorJitter(brightness=(0.5, 1), contrast=(0.4, 1), hue=(0, 0.3))


# NOTE(review): each RandomAffine / ColorJitter instance samples fresh parameters on every call.
# If EncoderDecoderDataset applies transform_human and transform_human_restored to a paired
# sample separately, input and target receive different rotations / colour shifts, and the
# pose points (transform_pose_points has no RandomAffine) are never rotated to match.
# TODO confirm against the dataset; if so, apply one shared random transform to the whole sample.

# Person image: resize, random affine, colour jitter, then the project's Normalize.
transform_human = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    _random_affine(),
    _color_jitter(),
    custom_transforms.Normalize(),
])

# Pose points: resize + min-max scaling only (no random augmentation).
transform_pose_points = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    custom_transforms.MinMaxScale(),
])

# Clothes image: colour jitter only (no geometric augmentation), then Normalize.
transform_clothes = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    _color_jitter(),
    custom_transforms.Normalize(),
])

# Restored-person image: same augmentations as transform_human, but min-max scaled.
transform_human_restored = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    _random_affine(),
    _color_jitter(),
    custom_transforms.MinMaxScale(),
])


In [3]:
# DataLoader settings shared by the train and validation loaders.
batch_size_train = 32
batch_size_val = 16
# Page-locked host memory only helps host->CUDA copies; without a CUDA device
# PyTorch warns and ignores it, so enable it only when CUDA is available.
pin_memory = torch.cuda.is_available()
# At most 16 workers, capped at the CPU count so small machines are not
# oversubscribed (os.cpu_count() may return None, hence the fallback).
num_workers = min(16, os.cpu_count() or 1)

In [4]:
# Training split: the transform_* pipelines carry the random augmentations,
# and the loader reshuffles the sample order every epoch.
# NOTE(review): absolute, machine-specific path; consider a configurable data root.
train_image_dir = r"C:\Users\DenisovDmitrii\Desktop\forEncoder\train"
train_dataset = EncoderDecoderDataset(
    image_dir=train_image_dir,
    transform_human=transform_human,
    transform_pose_points=transform_pose_points,
    transform_clothes=transform_clothes,
    transform_human_restored=transform_human_restored,
)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size_train,
    shuffle=True,
    pin_memory=pin_memory,
    num_workers=num_workers,
)

In [5]:
# Validation must be deterministic so the loss is comparable between epochs:
# same resizing and scaling as training, but no RandomAffine / ColorJitter.
val_transform_human = transforms.Compose([
    transforms.Resize((256, 192)),
    custom_transforms.Normalize(),
])
val_transform_clothes = transforms.Compose([
    transforms.Resize((256, 192)),
    custom_transforms.Normalize(),
])
val_transform_human_restored = transforms.Compose([
    transforms.Resize((256, 192)),
    custom_transforms.MinMaxScale(),
])

val_dataset = EncoderDecoderDataset(
    image_dir=r"C:\Users\DenisovDmitrii\Desktop\forEncoder\val",
    transform_human=val_transform_human,
    # transform_pose_points is already augmentation-free (resize + min-max scale).
    transform_pose_points=transform_pose_points,
    transform_clothes=val_transform_clothes,
    transform_human_restored=val_transform_human_restored,
)
val_dataloader = DataLoader(
    val_dataset, batch_size=batch_size_val, shuffle=False, pin_memory=pin_memory, num_workers=num_workers
)

# Обучение модели

In [None]:
# Seed torch's global RNG so weight initialisation and data shuffling are
# reproducible between runs (cudnn.benchmark, set below, may still add nondeterminism).
SEED = 42
torch.manual_seed(SEED)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# 6 input channels -> 3 output channels
# (assumption: two stacked RGB images in, one RGB image out -- TODO confirm).
model = EncoderDecoder(in_channels=6, out_channels=3).to(device)
# Build the optimizer after the model is on its final device, as the torch.optim docs recommend.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

criterion = PerceptualLoss(device, weight_per_pixel=1.0, weights_perceptual=[1.0, 1.0, 1.0, 1.0])
print(device)

In [3]:
# Directory handed to the Trainer as save_directory. check_path_and_creat
# presumably creates it when missing; its return value is this cell's output.
# NOTE(review): absolute, machine- and session-specific path.
save_directory = (
    r"C:\Users\DenisovDmitrii\OneDrive - ITMO UNIVERSITY"
    r"\peopleDetector\encoder\weights\session13"
)
check_path_and_creat(save_directory)

True

In [8]:
# Let cuDNN benchmark and cache the fastest convolution algorithms for the fixed input size.
torch.backends.cudnn.benchmark = True
# Keep CUDA matmuls in full FP32 (no TF32).
# NOTE(review): this flag covers matmuls only; convolution TF32 is controlled separately
# by torch.backends.cudnn.allow_tf32 -- confirm that is the intended split.
torch.backends.cuda.matmul.allow_tf32 = False

In [9]:
# Wire the model, loss and optimizer into the project's Trainer. save_step=1
# presumably checkpoints into save_directory every epoch -- TODO confirm in Trainer.
trainer = Trainer(
    model_=model,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    verbose=True,
    save_directory=save_directory,
    save_step=1,
)

In [None]:
# Run the full schedule; val_dataloader is presumably evaluated after each epoch.
num_epochs = 20
trainer.train(train_dataloader, val_dataloader, epoch_num=num_epochs)