# Импорт

In [1]:
import os

import matplotlib.pyplot as plt
import numpy as np
import PIL.Image as Image

import albumentations

import torch
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

from LookGenerator.datasets.utils import load_image, prepare_image_for_segmentation, to_array_from_model_bin, show_array_multichannel
from LookGenerator.datasets.background_cut_dataset import PersonDataset
from LookGenerator.networks.losses import FocalLoss
from LookGenerator.networks.segmentation import UNet, train_unet
from LookGenerator.networks.utils import load_model
import LookGenerator.datasets.transforms as custom_transforms
from LookGenerator.networks_training.utils import check_path_and_creat
from LookGenerator.config.config import DatasetConfig
from LookGenerator.networks.trainer import Trainer

# Загрузка данных

In [2]:
# Torchvision pipeline for the network input: resize to 256x192, then normalize
# all three channels with mean 0.5 and std 0.25.
transform_input = transforms.Compose([
    transforms.Resize(size=(256, 192)),
    transforms.Normalize(mean=[0.5] * 3, std=[0.25] * 3),
])

# Pipeline for the target: same resize, followed by the project transforms
# MinMaxScale and ThresholdTransform(threshold=0.5)
# (presumably rescaling to [0, 1] and binarizing the mask — confirm in custom_transforms).
transform_output = transforms.Compose([
    transforms.Resize(size=(256, 192)),
    custom_transforms.MinMaxScale(),
    custom_transforms.ThresholdTransform(threshold=0.5),
])

In [3]:
# Settings forwarded to the DataLoader constructors in the dataset cells.
batch_size_train, batch_size_val = 24, 16  # batch sizes for the train / validation loaders
num_workers = 16                           # DataLoader num_workers
pin_memory = True                          # DataLoader pin_memory

In [4]:
# Albumentations pipeline with random augmentations; the dataset cells pass it
# to PersonDataset as `augment=`.
# NOTE(review): transform_input already normalizes with the same mean/std — if
# PersonDataset applies both pipelines to the image, it is normalized twice; confirm.
transform_train = albumentations.Compose([
    albumentations.Resize(height=256, width=192),
    # Every random augmentation below fires with probability 0.2.
    albumentations.RandomBrightnessContrast(
        brightness_limit=(0.1, 0.3),
        contrast_limit=(0.2, 0.7),
        p=0.2,
    ),
    albumentations.Equalize(p=0.2),
    albumentations.GaussNoise(p=0.2),
    albumentations.Affine(
        translate_percent=0.1,
        scale=(0.8, 1),
        rotate=(-90, 90),
        p=0.2,
    ),
    albumentations.Normalize(mean=(0.5, 0.5, 0.5), std=(0.25, 0.25, 0.25)),
])

# Deterministic counterpart (resize + normalize only); by its name, meant for validation.
transform_valid = albumentations.Compose([
    albumentations.Resize(height=256, width=192),
    albumentations.Normalize(mean=(0.5, 0.5, 0.5), std=(0.25, 0.25, 0.25)),
])

In [None]:
# Build the dataset configuration from the process environment variables;
# Config.DATASET_DIR is read by the dataset cells below.
Config = DatasetConfig(os.environ)

In [5]:
# Training dataset: read from Config.DATASET_DIR with the random augmentation
# pipeline. "image" / "image-parse-v3" are presumably sub-directory names for
# the inputs and the parse masks — confirm against PersonDataset.
train_dataset = PersonDataset(
    Config.DATASET_DIR,
    "image",
    "image-parse-v3",
    background_root_dir=r"",
    dir_name_background="ScreenShot",
    transform_input=transform_input,
    transform_output=transform_output,
    augment=transform_train,
)

# Shuffled loader for training.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size_train,
    shuffle=True,
    pin_memory=pin_memory,
    num_workers=num_workers,
)

# Cell output: (number of samples, number of batches).
len(train_dataset), len(train_dataloader)

(11647, 486)

In [6]:
# Validation dataset. It must use the deterministic `transform_valid` pipeline
# (resize + normalize only): evaluating on randomly augmented samples makes the
# validation metrics noisy and not comparable between epochs.
# NOTE(review): this cell reads the same Config.DATASET_DIR as the training cell,
# yet the recorded outputs show different sizes (11647 vs 2032) — verify that a
# separate validation directory is configured, otherwise validation data
# overlaps with the training data.
val_dataset = PersonDataset(
    Config.DATASET_DIR, "image", "image-parse-v3",
    background_root_dir=r"", dir_name_background="ScreenShot",
    transform_input=transform_input, transform_output=transform_output, augment=transform_valid
)

# No shuffling for validation so batches are visited in a fixed order.
val_dataloader = DataLoader(val_dataset, batch_size=batch_size_val, shuffle=False, pin_memory=pin_memory, num_workers=num_workers)

# Cell output: (number of samples, number of batches).
(len(val_dataset), len(val_dataloader))

(2032, 127)

In [None]:
# Sanity check on the first training batch only: print the tensor shapes,
# plot channel 0 of the first input and show the first target via the project helper.
for X, y in train_dataloader:
    print(X.shape, y.shape, sep="\n")

    plt.imshow(X.detach().numpy()[0, 0], cmap='binary')
    plt.show()

    show_array_multichannel(y.detach().numpy()[0], 15)

    # Disabled alternative visualisation (to_array_from_model_bin_transpose is
    # not imported in this notebook):
    # modelled_img = to_array_from_model_bin_transpose(transform_output(X.detach()))
    # plt.imshow(modelled_img,cmap = 'binary')
    # plt.show()
    # modelled_img = to_array_from_model_bin_transpose(transform_output(y.detach()))
    # plt.imshow(modelled_img,cmap = 'binary')
    # plt.show()

    # One batch is enough for a visual check.
    break

# Обучение модели

In [7]:
# Prefer the GPU when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# U-Net taking 3 input channels and producing 1 output channel.
model = UNet(in_channels=3, out_channels=1)
# model = load_model(model, r'')  # enable to load saved weights and continue training

# Adam over all model parameters, trained against the project's FocalLoss.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = FocalLoss()

print(device)

cuda


In [None]:
# NOTE: this entire cell is commented out and does nothing when run. It trains
# through train_unet and times the run, whereas the active cells below train
# through Trainer — presumably kept for reference; confirm before deleting.
# torch.backends.cuda.matmul.allow_tf32 = False
# torch.backends.cudnn.benchmark = True
# import datetime
# now = datetime.datetime.now()
# print("start time", now.strftime("%d-%m-%Y %H:%M"))
#TODO: specify the directory for saving the weights
# save_directory = ""
# check_path_and_creat(save_directory)
# train_history, val_history = train_unet(
#     model,
#     train_dataloader,
#     val_dataloader,
#     optimizer,
#     device=device,
#     epoch_num=20,
#     save_directory= save_directory
# )
# old = now
# now = datetime.datetime.now()
# print("end time", now.strftime("%d-%m-%Y %H:%M"))
# print("delta", now - old)


In [None]:
# Training run settings consumed by the Trainer cells below.
epoch_num = 30       # forwarded to trainer.train(epoch_num=...)
save_directory = ""  # forwarded to Trainer(save_directory=...); empty here — set before training

In [None]:
# Bundle the model, optimizer and loss into the project Trainer.
# save_step=2 presumably means "checkpoint every 2 epochs" — confirm in Trainer.
trainer = Trainer(
    model_=model,
    optimizer=optimizer,
    criterion=criterion,
    device=device,
    save_directory=save_directory,
    save_step=2,
    verbose=True,
)

In [None]:
# Start training with both loaders for epoch_num epochs.
trainer.train(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    epoch_num=epoch_num,
)

In [None]:
# Plot the history collected by the trainer during the run above.
trainer.draw_history_plots()

# Загрузка модели и проверка визуально качества работы

In [None]:
# Checkpoint to evaluate. The default is the machine-specific path this cell
# originally hard-coded; set the UNET_WEIGHTS_PATH environment variable to point
# at a different file without editing the notebook.
weights_path = os.environ.get(
    "UNET_WEIGHTS_PATH",
    r'C:\Users\Даша\PycharmProjects\SMBackEnd\LookGenerator\weights\unet_epoch_0_0.0161572862694324.pt',
)
model = load_model(model, weights_path)

# Switch to evaluation mode before inference (assuming load_model returns the
# torch module). As the cell's last expression this also displays the model.
model.eval()

In [None]:
# Placeholders — fill in before running the inference loop.
test_dir = r""        # root directory passed to load_image
test_folder = ""      # sub-folder passed to load_image
save_masks_dir = r""  # directory the inference loop writes masks into

# List the folder the images are loaded from.
# NOTE(review): assumes load_image joins test_dir / test_folder / name + extension
# — confirm. With an empty test_folder this lists test_dir exactly as before.
list_files = os.listdir(os.path.join(test_dir, test_folder))

# Base names (without extension) of the .jpg files only. The inference loop
# appends '.jpg' / '.png' itself, so each entry must be a plain string;
# file.split('.') produced lists and also broke names containing extra dots.
images = [
    stem
    for stem, ext in map(os.path.splitext, list_files)
    if ext == '.jpg'
]

In [None]:
# Run the model on every test image, show the input next to the predicted mask
# and save the mask as a PNG named after the image.
for image in images:
    print(image)
    img = load_image(test_dir, test_folder, image, '.jpg')
    img_to_model = prepare_image_for_segmentation(img, transform_input)

    # Inference only: skip autograd bookkeeping to save memory and time.
    # NOTE(review): assumes img_to_model is on the same device as the model — confirm.
    with torch.no_grad():
        modelled = model(img_to_model)
    mask = to_array_from_model_bin(modelled)

    # Subplot positions are 1-based in matplotlib (plt.subplot(1, 2, 0) raises
    # ValueError); plt.subplots sidesteps manual indexing altogether.
    fig, (ax_img, ax_mask) = plt.subplots(1, 2, figsize=(18, 6))
    ax_img.imshow(img)
    ax_img.set_title("input")
    ax_mask.imshow(mask)
    ax_mask.set_title("predicted mask")
    plt.show()  # render each pair right after its printed name

    # Join with a path separator instead of concatenating directory and file name.
    # NOTE(review): mode 'L' expects 8-bit values — confirm the dtype/range that
    # to_array_from_model_bin returns.
    Image.fromarray(mask, 'L').save(os.path.join(save_masks_dir, image + '.png'))
