In [1]:
import torch
import torch.nn as nn
import numpy as np

import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as functional

from PIL import Image
import os

from LookGenerator.networks.segmentation import UNet
from LookGenerator.networks.bpgm.model.models import BPGM
from LookGenerator.networks.clothes_feature_extractor import ClothAutoencoder
from LookGenerator.networks.encoder_decoder import EncoderDecoder

import LookGenerator.datasets.transforms as custom_transforms
from LookGenerator.datasets.utils import prepare_image_for_segmentation
from LookGenerator.networks.utils import load_model

from tqdm import tqdm

In [2]:
human_path = r'C:\Users\DenisovDmitrii\Desktop\12channels\valData\som'
cloth_path = r'C:\Users\DenisovDmitrii\Desktop\zalando-hd-resize\train\cloth'

In [16]:
segmentation_bin_path = r"C:\Users\DenisovDmitrii\OneDrive - ITMO UNIVERSITY\peopleDetector\segmetationBackground\weights\testResultsFeatures_32_64_128_256_512\epoch_39.pt"

segmentation_multy_path = r"C:\Users\DenisovDmitrii\OneDrive - ITMO UNIVERSITY\peopleDetector\segmentationMulty\weights\testMulty_out_12_6features_20to640_fillBack\epoch_37.pt"

tps_path = r'C:\Users\DenisovDmitrii\OneDrive - ITMO UNIVERSITY\peopleDetector\tps\weights\test\epoch_02.pt'

clothes_feature_extractor_path = r"C:\Users\DenisovDmitrii\OneDrive - ITMO UNIVERSITY\peopleDetector\autoDegradation\weights\testClothes_L1Loss_4features\epoch_39.pt"

encoder_path = r"C:\Users\DenisovDmitrii\OneDrive - ITMO UNIVERSITY\peopleDetector\newEncoder\weights\testWithTPSMask\epoch_29.pt"

In [17]:
cloth_list = os.listdir(cloth_path)
len(cloth_list)

11647

In [18]:
human_list = os.listdir(human_path)
len(human_list)

18

In [19]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cuda


In [20]:
transforms_resize = transforms.Compose([
    transforms.Resize((256, 192))
])

transform_input_segmentation = transforms.Compose([
    transforms.Resize((256, 192)),
    transforms.Normalize(
        mean=[0.5, 0.5, 0.5],
        std=[0.25, 0.25, 0.25]
    )
])

transform_for_tps_and_encoder = transforms.Compose([
    transforms.Resize((256,192)),
    transforms.Normalize(
        mean=[0.5, 0.5, 0.5],
        std=[0.5, 0.5, 0.5]
    )
])

transform_output_segmentation = transforms.Compose([
    transforms.Resize((256, 192)),
    custom_transforms.MinMaxScale(),
    custom_transforms.ThresholdTransform(threshold=0.5)
])


In [21]:
toTensor = transforms.ToTensor()
toPIL = transforms.ToPILImage()

In [22]:
segmentation_bin = UNet(in_channels=3, out_channels=1, features=(32, 64, 128, 256, 512))
segmentation_bin = load_model(segmentation_bin, segmentation_bin_path)
segmentation_bin = segmentation_bin.to(device)


In [23]:
segmentation_multy = UNet(in_channels=3, out_channels=12,
                          features=(20, 40, 80, 160, 320, 640),
                          final_activation=nn.Softmax(dim=1))
segmentation_multy = load_model(segmentation_multy, segmentation_multy_path)
segmentation_multy = segmentation_multy.to(device)

In [24]:
tps = BPGM(in_channels=12, device=device)
tps = load_model(tps, tps_path)
tps = tps.to(device)

initialization method [normal]
initialization method [normal]


In [25]:
clothes_feature_extractor = ClothAutoencoder(
    in_channels=3,
    out_channels=3,
    features=(8, 16, 32, 64),
    latent_dim_size=128,
    encoder_activation_func=nn.LeakyReLU(),
    decoder_activation_func=nn.ReLU()
)
clothes_feature_extractor = load_model(clothes_feature_extractor, clothes_feature_extractor_path)
clothes_feature_extractor = clothes_feature_extractor.to(device)

In [26]:
encoder_decoder = EncoderDecoder(clothes_feature_extractor, in_channels=6, out_channels=3)
encoder_decoder = load_model(encoder_decoder, encoder_path)
encoder_decoder = encoder_decoder.to(device)

In [28]:
for human in tqdm(human_list):
    number_cloth = np.random.randint(0, 11647)
    # print(number_cloth)

    cloth = cloth_list[number_cloth]

    human_image = Image.open(os.path.join(human_path, human))
    human_image = transforms_resize(toTensor(human_image).unsqueeze(0)).to(device)
    img_to_segmentation = transform_input_segmentation(human_image)

    cloth_image = Image.open(os.path.join(cloth_path, cloth))
    cloth_image = toTensor(cloth_image).unsqueeze(0).to(device)
    cloth_to_model = transform_for_tps_and_encoder(cloth_image)

    segmentation_bin_out = transform_output_segmentation(segmentation_bin(img_to_segmentation).detach())
    segmentation_bin_out_bool = torch.tensor(segmentation_bin_out, dtype=torch.bool)
    segmentation_bin_out_clear = human_image * (~segmentation_bin_out_bool) + segmentation_bin_out_bool

    segmentation_multy_out = transform_output_segmentation(segmentation_multy(img_to_segmentation).detach())

    cwm = segmentation_multy_out[:,8,:,:]
    cwm = torch.tensor(cwm, dtype=torch.bool)
    theta = tps(segmentation_multy_out, cloth_to_model)

    warped = functional.grid_sample(cloth_to_model, theta, padding_mode='border', align_corners=True)
    warped = warped / 2 + 0.5
    warped = warped * cwm
    person = segmentation_bin_out_clear * (~cwm) + warped

    human_for_encoder = transform_for_tps_and_encoder(person)
    data_to_encoder = torch.cat((human_for_encoder, cloth_to_model), dim=1)
    model_out_from_encoder = encoder_decoder(data_to_encoder).to('cpu')
    # segmentation_bin_out_clear_cpu = segmentation_bin_out_clear.to('cpu')
    torchvision.utils.save_image(model_out_from_encoder, fr"C:\Users\DenisovDmitrii\Desktop\forEncoderNew\sameOut2\{human[:-4]}_{cloth[:-4]}.png")

  segmentation_bin_out_bool = torch.tensor(segmentation_bin_out, dtype=torch.bool)
  cwm = torch.tensor(cwm, dtype=torch.bool)
100%|██████████| 18/18 [00:01<00:00, 17.87it/s]


In [16]:
segmentation_bin_out_clear_cpu = segmentation_bin_out_clear.to('cpu')
for img in segmentation_bin_out_clear_cpu:
    toPIL(img).show()

In [14]:
segmentation_bin_out_cpu = segmentation_bin_out.to('cpu')
for img in segmentation_bin_out_cpu:
    toPIL(img).show()

In [None]:
segmentation_multy_out_cpu = segmentation_multy_out.to('cpu')
for img in segmentation_multy_out_cpu:
    for chanel in img:
        toPIL(chanel).show()

In [103]:
for img in person:
    toPIL(img).show()

In [17]:
for img in cloth_image:
    toPIL(img).show()

In [16]:
for img in model_out_from_encoder:
    toPIL(img).show()


In [20]:
for img in data_to_encoder[:, :3, :, :]:
    toPIL(img).show()
for img in data_to_encoder[:, 3:, :, :]:
    toPIL(img).show()