# Импорт

In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import PIL.Image as Image

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import albumentations
from LookGenerator.networks.losses import PerceptualLoss, PerPixelLoss
from LookGenerator.datasets.encoder_decoder_datasets import EncoderDecoderDataset
from LookGenerator.networks.trainer import Trainer
from LookGenerator.networks.clothes_feature_extractor import ClothAutoencoder
from LookGenerator.networks.encoder_decoder import EncoderDecoder
from LookGenerator.networks_training.utils import check_path_and_creat
import LookGenerator.datasets.transforms as custom_transforms
from LookGenerator.networks.utils import load_model

# Загрузка данных

In [2]:
from torchvision.transforms import InterpolationMode

# Target spatial size (height, width) shared by every pipeline below, so the
# human, clothes and restored-human tensors can never drift apart in resolution.
IMAGE_SIZE = (256, 192)

# Human image: resize only.
# Disabled options kept for reference:
#   transforms.RandomAffine(scale=(0.8, 1), degrees=(-90, 90), fill=0.9)
#   transforms.ColorJitter(brightness=(0.5, 1), contrast=(0.4, 1), hue=(0, 0.3))
#   transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
transform_human = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
])

# Pose points: resize, then the project's MinMaxScale.
# NOTE(review): this pipeline is not passed to any dataset in the cells shown
# here -- confirm whether it is still needed.
transform_pose_points = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    custom_transforms.MinMaxScale(),
])

# Clothes image: resize, then normalize each channel with mean 0.5 / std 0.5.
# Disabled option kept for reference:
#   transforms.ColorJitter(brightness=(0.5, 1), contrast=(0.4, 1), hue=(0, 0.3))
transform_clothes = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.Normalize(mean=[0.5, 0.5, 0.5],
                         std=[0.5, 0.5, 0.5]),
])

# Restored human image: resize, then the project's MinMaxScale.
# Disabled options kept for reference:
#   transforms.RandomAffine(scale=(0.8, 1), degrees=(-90, 90), fill=0.9)
#   transforms.ColorJitter(brightness=(0.5, 1), contrast=(0.4, 1), hue=(0, 0.3))
transform_human_restored = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    custom_transforms.MinMaxScale(),
])


In [3]:
# Settings shared by the training and validation DataLoaders.
num_workers = 4     # forwarded to DataLoader(num_workers=...)
pin_memory = True   # forwarded to DataLoader(pin_memory=...)

# Samples per batch for each split.
batch_size_val = 16
batch_size_train = 32

In [4]:
# Training split: dataset read from the local "train" folder, served in
# shuffled batches.
train_dir = r"C:\Users\DenisovDmitrii\Desktop\forEncoder\train"

train_dataset = EncoderDecoderDataset(
    image_dir=train_dir,
    transform_human=transform_human,
    transform_clothes=transform_clothes,
    transform_human_restored=transform_human_restored,
)

train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size_train,
    shuffle=True,
    pin_memory=pin_memory,
    num_workers=num_workers,
)

In [5]:
# Validation split: same transforms as training, read from the local "val"
# folder and iterated in a fixed (unshuffled) order.
val_dir = r"C:\Users\DenisovDmitrii\Desktop\forEncoder\val"

val_dataset = EncoderDecoderDataset(
    image_dir=val_dir,
    transform_human=transform_human,
    transform_clothes=transform_clothes,
    transform_human_restored=transform_human_restored,
)

val_dataloader = DataLoader(
    val_dataset,
    batch_size=batch_size_val,
    shuffle=False,
    pin_memory=pin_memory,
    num_workers=num_workers,
)

In [6]:
# Sanity check: pull one batch from the training loader and show the shapes
# of the input and target tensors. Nothing is printed if the loader is empty.
first_batch = next(iter(train_dataloader), None)
if first_batch is not None:
    X, y = first_batch
    print(X.shape)
    print(y.shape)

torch.Size([32, 6, 256, 192])
torch.Size([32, 3, 256, 192])


# Функция потерь

In [7]:
class EncoderDecoderLoss(nn.Module):
    """
    Combined criterion for the encoder-decoder: a perceptual term plus a
    per-pixel term, added together without extra weighting.

    Args:
        device: device passed to ``PerceptualLoss`` and used to place
            ``PerPixelLoss`` (defaults to ``'cpu'``).
    """

    def __init__(self, device='cpu'):
        super().__init__()
        # Perceptual term, configured with four weights of 1.0.
        self.perceptual_loss = PerceptualLoss(device, weights_perceptual=[1.0, 1.0, 1.0, 1.0])
        # Per-pixel term, moved to the requested device.
        self.per_pixel_loss = PerPixelLoss().to(device)

    def forward(self, outputs, targets):
        """Return ``perceptual_loss(outputs, targets) + per_pixel_loss(outputs, targets)``."""
        perceptual_term = self.perceptual_loss(outputs, targets)
        pixel_term = self.per_pixel_loss(outputs, targets)
        return perceptual_term + pixel_term

# Обучение модели

In [8]:
# Checkpoint for the clothes autoencoder (epoch 39 of an earlier run).
clothes_weights_path = r"C:\Users\DenisovDmitrii\OneDrive - ITMO UNIVERSITY\peopleDetector\autoDegradation\weights\testClothes_L1Loss_4features\epoch_39.pt"

# Architecture settings for the clothes autoencoder.
# NOTE(review): these presumably have to match the run that produced the
# checkpoint above -- confirm before changing any of them.
clothes_autoencoder_config = dict(
    in_channels=3,
    out_channels=3,
    features=(8, 16, 32, 64),
    latent_dim_size=128,
    encoder_activation_func=nn.LeakyReLU(),
    decoder_activation_func=nn.ReLU(),
)

clothes_feature_extractor = ClothAutoencoder(**clothes_autoencoder_config)
# Rebind the name to whatever load_model returns for this model and path.
clothes_feature_extractor = load_model(clothes_feature_extractor, clothes_weights_path)

In [9]:
# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Encoder-decoder built around the clothes feature extractor:
# 6 input channels, 3 output channels.
model = EncoderDecoder(
    clothes_feature_extractor,
    in_channels=6,
    out_channels=3,
)

# Adam over all model parameters with learning rate 1e-3.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Combined perceptual + per-pixel criterion, built for the chosen device.
criterion = EncoderDecoderLoss(device=device)

print(device)



cuda


In [10]:
# Folder for this run's outputs; it is passed to Trainer as save_directory.
save_directory = r"C:\Users\DenisovDmitrii\OneDrive - ITMO UNIVERSITY\peopleDetector\newEncoder\weights\testBaseParams"

# Kept as the cell's last expression so its return value is displayed.
check_path_and_creat(save_directory)

True

In [11]:
# Enable cuDNN benchmark mode.
torch.backends.cudnn.benchmark = True

# Disable TF32 for CUDA matrix multiplications.
torch.backends.cuda.matmul.allow_tf32 = False

In [12]:
# Everything the Trainer needs, gathered in one place before construction.
trainer_settings = dict(
    model_=model,
    optimizer=optimizer,
    criterion=criterion,
    device=device,
    save_directory=save_directory,
    save_step=1,    # presumably "save every epoch" -- confirm against Trainer
    verbose=True,
)

trainer = Trainer(**trainer_settings)

In [13]:
# Start training with the train/validation loaders, requesting 20 epochs.
trainer.train(train_dataloader, val_dataloader, epoch_num=20)

start time 30-05-2023 23:21


 13%|█▎        | 47/364 [01:19<08:58,  1.70s/it] 


KeyboardInterrupt: 

In [None]:
# Plot the history recorded by the trainer so far.
# NOTE(review): what exactly is plotted depends on Trainer.draw_history_plots -- confirm.
trainer.draw_history_plots()

In [None]:
# Second training call with the same loaders and another 20 epochs requested.
# NOTE(review): whether this continues from the current weights or restarts
# depends on Trainer.train -- confirm before relying on it.
trainer.train(train_dataloader, val_dataloader, epoch_num=20)

In [None]:
# Re-plot the trainer's history after the second training call.
trainer.draw_history_plots()