In [None]:
import numpy as np
import PIL.Image as Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import random, json
from torch.utils.data import DataLoader
import torch
import torchvision.transforms.functional as TF
from torchvision import models
import torch.optim as optim
import dnnlib, legacy
from data.GEN_data import FaceLandmarksDataset, Transforms
from training_modified import networks

DEVICE = torch.device('cuda' if torch.cuda.is_available() else "cpu")

# Reference-face dataset, one shuffled sample per batch.
train_dataset = FaceLandmarksDataset("train", transform=Transforms())
train_loader = DataLoader(train_dataset, 1, shuffle=True)

# Load the pretrained FFHQ networks from the pickle.
with dnnlib.util.open_url("pretrained/ffhq.pkl") as f:
    data = legacy.load_network_pkl(f)
    generator = data["G_ema"].to(DEVICE)
    discriminator = data["D"].to(DEVICE)

# Rebuild the generator with the project's modified network definition (its
# synthesis takes an extra perturbation vector) and copy the pretrained weights.
style_generator = networks.Generator(**generator.init_kwargs).to(DEVICE)
style_generator.load_state_dict(generator.state_dict())
style_generator.eval()

# Length of the flattened style-space perturbation vector passed to `synthesis`
# throughout this notebook.
SPV_DIM = 6048

# Inference only: avoid building an autograd graph for a full-resolution render.
with torch.no_grad():
    ran_z = torch.randn((1, 512)).to(DEVICE)
    spv = torch.zeros((1, SPV_DIM), device=DEVICE)

    ref_ws = style_generator.mapping(ran_z, 0)
    gen_images = style_generator.synthesis(ref_ws, spv)

# Map generator output to [0, 1] for display (assumes output is roughly in [-1, 1]).
gen_images_denorm = (gen_images * 127.5 + 128).clamp(0, 255) / 255

In [None]:
# Preview the first generated image at 256x256.
TF.to_pil_image(TF.resize(gen_images_denorm[0], (256, 256)))

In [None]:
from PIL import Image, ImageDraw


# Draw each landmark of the first reference sample as a red dot of radius `r`
# pixels onto `ref_img`, then save it.
# NOTE(review): neither `ref_img` nor `ref_landmarks` is defined by any earlier
# cell (ref_landmarks is first assigned much later), so this fails on a fresh kernel.
draw = ImageDraw.Draw(ref_img)

r = 4

for x, y in ref_landmarks[0].tolist():
    leftUpPoint = (x-r, y-r)
    rightDownPoint = (x+r, y+r)
    twoPointList = [leftUpPoint, rightDownPoint]
    draw.ellipse(twoPointList, fill="red")

ref_img.save("test.png")

In [None]:
# NOTE(review): `save_images` and `ref_images` are not defined anywhere in this
# notebook, and `output_dir` is given a file name ("./test.png") — confirm intent.
save_images(ref_images = ref_images, gen_images=gen_images_denorm, size=(256,256), output_dir="./test.png")

In [None]:
# Render the stored latent of image 23 with a zero perturbation in two ways:
# (1) capture the `b{res}.torgb` output through model.FeatureExtractor, and
# (2) call synthesis directly. The two cells below display both for comparison.
# NOTE(review): `model` is imported and `json_data` is loaded only in later
# cells, so this cell fails on a fresh kernel.
res = 1024

feature_extractor = model.FeatureExtractor(style_generator.synthesis, [f"b{res}.torgb"])

z = torch.tensor(json_data["images"][str(23)]["z"]).to(DEVICE)
ws = style_generator.mapping(z, 0)

spv_1 = torch.zeros((1, 6048), device=DEVICE)

# Single image, CHW -> HWC, mapped to uint8 pixels.
gen_image_1 = feature_extractor(ws, spv_1)[f"b{res}.torgb"][0]
gen_image_de_1 = (gen_image_1.permute(1,2,0) * 127.5 + 128).clamp(0, 255).to(torch.uint8)

# Batched output, NCHW -> NHWC, mapped to uint8 pixels.
gen_image_2 = style_generator.synthesis(ws, spv_1)
gen_image_de_2 = (gen_image_2.permute(0,2,3,1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)

In [None]:
Image.fromarray(gen_image_de_1.cpu().numpy(), 'RGB')

In [None]:
Image.fromarray(gen_image_de_2[0].cpu().numpy(), 'RGB')

In [None]:
def _load_square(path, size=256):
    """Load an image file as a tensor resized to ``size`` x ``size``.

    The file is opened in a ``with`` block so its handle is closed once the
    pixels have been read (``Image.open`` is lazy and would otherwise keep it open).
    """
    with Image.open(path) as img:
        return TF.resize(TF.to_tensor(img), (size, size))


a = _load_square(json_data["images"]["31"]["dir"])
b = _load_square(json_data["images"]["23"]["dir"])

a_pil = TF.to_pil_image(a)
b_pil = TF.to_pil_image(b)

# Side by side: image 31 on the left, image 23 on the right.
out_img = Image.new("RGB", (512, 256))
out_img.paste(a_pil, (0, 0))
out_img.paste(b_pil, (256, 0))
out_img

In [None]:
import matplotlib.pyplot as plt

def show_fun(idx):
    """Plot how image ``idx``'s style vector deviates from the training-set mean.

    Flattens the per-layer lists stored in ``json_data["images"][str(idx)]["styles"]``
    into one vector (layer order, then channel order) and subtracts
    ``json_data["style_train_mean"]`` element-wise.

    Args:
        idx: image index (int or str) into ``json_data["images"]``.
    """
    record = json_data["images"][str(idx)]
    style = torch.tensor([value for layer_styles in record["styles"] for value in layer_styles])
    style_diff = style - torch.tensor(json_data["style_train_mean"])

    fig, ax = plt.subplots()
    ax.plot(style_diff)
    # Label the figure so it stands on its own when several are shown in a row.
    ax.set(title=f"Style deviation from train mean (image {idx})",
           xlabel="style channel", ylabel="style - train mean")
    plt.show()

In [None]:
# Plot each image in turn; a loop avoids echoing a `(None, None, None)` tuple.
for _idx in (23, 31, 34):
    show_fun(_idx)

In [None]:
# Train-set per-channel style mean, used to scale perturbations in later cells.
style_mean = torch.tensor(json_data["style_train_mean"]).to(DEVICE)

# Dataset image "9" serves as the fixed reference: its latent and its landmarks
# (landmarks get a leading batch dimension).
ref_record = json_data["images"]["9"]
ref_zs = torch.tensor(ref_record["z"]).to(DEVICE)
ref_lnd = torch.tensor(ref_record["landmark"]).to(DEVICE).unsqueeze(0)

In [None]:
# Compare the reference latent rendered with a zero perturbation (left) against
# the same latent with the perturbation predicted from its landmarks (right).
# NOTE(review): `ref_mapping_network` is created only in a later cell, so this
# cell fails on a fresh kernel.
spv_1 = torch.zeros((1, 6048)).to(DEVICE)

# Mapping-network output from the flattened landmarks, multiplied element-wise
# by the train-set style mean.
spw = ref_mapping_network(ref_lnd.reshape(ref_lnd.shape[0], -1))
spv_2 = spw * style_mean

ws = style_generator.mapping(ref_zs, 0)
gen_images_1 = style_generator.synthesis(ws, spv_1)
gen_images_denorm_1 = (gen_images_1 * 127.5 + 128).clamp(0, 255).to(torch.uint8)

gen_images_2 = style_generator.synthesis(ws, spv_2)
gen_images_denorm_2 = (gen_images_2 * 127.5 + 128).clamp(0, 255).to(torch.uint8)

gen_img_1 = TF.to_pil_image(TF.resize(gen_images_denorm_1[0], (256, 256)))
gen_img_2 = TF.to_pil_image(TF.resize(gen_images_denorm_2[0], (256, 256)))

out_img = Image.new("RGB", (512, 256))
out_img.paste(gen_img_1, (0, 0))
out_img.paste(gen_img_2, (256, 0))
out_img

In [None]:
def test_fun(chn, start, stop, step, spv_dim=6048, tile=256):
    """Render a horizontal strip of images sweeping one style-space channel.

    For every value ``v`` in ``range(start, stop, step)``, a zero perturbation
    vector with entry ``chn`` set to ``v`` is passed to
    ``style_generator.synthesis`` together with the mapped reference latent
    ``ref_zs``. Each result is resized to ``tile`` x ``tile`` and pasted
    left to right.

    Args:
        chn: index of the perturbed channel (must be < ``spv_dim``).
        start, stop, step: sweep values with ``range`` semantics (stop excluded;
            negative steps and non-divisible spans are supported).
        spv_dim: length of the perturbation vector expected by ``synthesis``.
        tile: side length in pixels of each thumbnail.

    Returns:
        PIL RGB image of size ``(len(range(start, stop, step)) * tile, tile)``.
    """
    values = range(start, stop, step)
    out_img = Image.new("RGB", (len(values) * tile, tile))

    # Pure inference: no autograd graph is needed for the renders.
    with torch.no_grad():
        # The latent does not depend on the swept value, so map it only once.
        ws = style_generator.mapping(ref_zs, 0)
        for col, value in enumerate(values):
            spv_ = torch.zeros((1, spv_dim), device=DEVICE)
            spv_[:, chn] = value

            gen_images_ = style_generator.synthesis(ws, spv_)
            gen_images_denorm_ = (gen_images_ * 127.5 + 128).clamp(0, 255).to(torch.uint8)

            gen_img = TF.to_pil_image(TF.resize(gen_images_denorm_[0], (tile, tile)))
            out_img.paste(gen_img, (col * tile, 0))
    return out_img

In [None]:
# Compare the reference latent with a zero perturbation (left) against the same
# latent with only channel 4938 set to 6 x its train-set mean (right).
spv_1 = torch.zeros((1, 6048), device=DEVICE)

spv_2 = torch.zeros((1, 6048), device=DEVICE)
spv_2[:, 4938] = 6

# Element-wise scaling: every other entry stays 0.
spv_2 = spv_2 * style_mean

ws = style_generator.mapping(ref_zs, 0)

gen_images_1 = style_generator.synthesis(ws, spv_1)
gen_images_denorm_1 = (gen_images_1 * 127.5 + 128).clamp(0, 255).to(torch.uint8)

gen_images_2 = style_generator.synthesis(ws, spv_2)
gen_images_denorm_2 = (gen_images_2 * 127.5 + 128).clamp(0, 255).to(torch.uint8)

gen_img_1 = TF.to_pil_image(TF.resize(gen_images_denorm_1[0], (256, 256)))
gen_img_2 = TF.to_pil_image(TF.resize(gen_images_denorm_2[0], (256, 256)))

out_img = Image.new("RGB", (512, 256))
out_img.paste(gen_img_1, (0, 0))
out_img.paste(gen_img_2, (256, 0))
out_img

In [None]:
# Style-channel sweeps of the reference latent `ref_zs`: each strip sets one
# channel to -400, -320, ..., 320 (range semantics, stop excluded).
# NOTE(review): these cells are copy-pasted; a single loop over the channel list
# with display() would keep the sweep definition in one place.
out = test_fun(38, -400, 400, 80)
out

In [None]:
out = test_fun(938, -400, 400, 80)
out

In [None]:
out = test_fun(1938, -400, 400, 80)
out

In [None]:
out = test_fun(2938, -400, 400, 80)
out

In [None]:
out = test_fun(3938, -400, 400, 80)
out

In [None]:
out = test_fun(4938, -400, 400, 80)
out

In [None]:
out = test_fun(5938, -400, 400, 80)
out

In [None]:
# Presumably 6000 rather than 6938, which would exceed the 6048-entry vector.
out = test_fun(6000, -400, 400, 80)
out

In [None]:
# Neighbours of channel 4938 (the channel perturbed in the comparison cell).
out = test_fun(4937, -400, 400, 80)
out

In [None]:
out = test_fun(4936, -400, 400, 80)
out

In [None]:
out = test_fun(4935, -400, 400, 80)
out

In [None]:
out = test_fun(4934, -400, 400, 80)
out

In [None]:
out = test_fun(4933, -400, 400, 80)
out

In [None]:
out = test_fun(4932, -400, 400, 80)
out

In [None]:
out = test_fun(4931, -400, 400, 80)
out

In [None]:
# Recompute every image's style vectors with the modified generator, attach
# them to the labels, and store per-channel train/validation means in label1.json.
with open("./dataset/gen_dataset/label.json", "r") as jsonf:
    json_data = json.load(jsonf)

# Style channels per non-ToRGB synthesis layer, in network order (sums to 6048).
ref_res = [512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 256, 256, 128, 128, 64, 64, 32]

json_data["train_dataset_num"] = 3000
json_data["validate_dataset_num"] = 1000

# Perturbation vector of all ones, one entry per style channel.
spv = torch.ones((1, sum(ref_res))).to(DEVICE)


def _collect_original_styles():
    """Return each non-ToRGB synthesis layer's ``original_style`` as a Python list.

    NOTE(review): relies on the modified layers exposing ``original_style``,
    presumably refreshed by the most recent ``synthesis`` call — confirm.
    """
    return [
        layer.original_style.squeeze().tolist()
        for block in style_generator.synthesis.children()
        for layer in block.children()
        if not isinstance(layer, networks.ToRGBLayer)
    ]


def _per_channel_mean(styles, layer_sizes):
    """Channel-wise mean over per-image styles, flattened in layer order.

    Args:
        styles: list of per-image styles, each indexable as ``[layer][channel]``.
        layer_sizes: number of channels to average in each layer.

    Returns:
        Flat list of ``sum(layer_sizes)`` floats.

    Raises:
        ValueError: if ``styles`` is empty (the mean is undefined).
    """
    if not styles:
        raise ValueError("cannot compute a style mean over an empty split")
    return [
        sum(style[layer][channel] for style in styles) / len(styles)
        for layer, n_channels in enumerate(layer_sizes)
        for channel in range(n_channels)
    ]


style_train = []
style_validate = []
# Inference only: no autograd graph is needed to read the style activations.
with torch.no_grad():
    for idx in range(len(json_data["images"])):
        record = json_data["images"][str(idx)]
        z = torch.tensor(record["z"]).to(DEVICE)
        ws = style_generator.mapping(z, 0)

        # Synthesis runs for its effect on the layers' `original_style`;
        # the generated image itself is not used here.
        style_generator.synthesis(ws, spv)

        original_style = _collect_original_styles()
        record["styles"] = original_style

        if record["type"] == "train":
            style_train.append(original_style)
        else:
            style_validate.append(original_style)

# Each split's mean comes from that split's own styles.
style_train_mean = _per_channel_mean(style_train, ref_res)
style_validate_mean = _per_channel_mean(style_validate, ref_res)

json_data["style_train_mean"] = style_train_mean
json_data["style_validate_mean"] = style_validate_mean

with open("./dataset/gen_dataset/label1.json", "w") as jsonf:
    json.dump(json_data, jsonf)

In [None]:
# Render dataset image 0 with an all-ones perturbation vector.
z = torch.tensor(json_data["images"][str(0)]["z"]).to(DEVICE)
ws = style_generator.mapping(z, 0)

# Style channels per layer; the all-ones spv has one entry per channel.
ref_res = [512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 256, 256, 128, 128, 64, 64, 32]
spv = torch.ones((1, sum(ref_res))).to(DEVICE)

raw_images = style_generator.synthesis(ws, spv)
gen_images = (raw_images * 127.5 + 128).clamp(0, 255)

In [None]:
# Gather every non-ToRGB synthesis layer, then concatenate its original and
# modified styles along dim 1.
style_layers = [
    layer
    for block in style_generator.synthesis.children()
    for layer in block.children()
    if not isinstance(layer, networks.ToRGBLayer)
]

original_style = torch.cat([layer.original_style for layer in style_layers], dim=1)
modified_style = torch.cat([layer.modified_style for layer in style_layers], dim=1)

In [None]:
import model
import face_alignment  # used for fa_network below; otherwise only imported in a later cell

face_lnd_estimator = model.FaceLandmarkEstimator()

with dnnlib.util.open_url("pretrained/ffhq.pkl") as f:
    data = legacy.load_network_pkl(f)
    generator = data["G_ema"].to(DEVICE)
    discriminator = data["D"].to(DEVICE)

style_generator = networks.Generator(**generator.init_kwargs).to(DEVICE)
fa_network = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False)
style_generator.load_state_dict(generator.state_dict())
style_generator.eval()

# map_location lets GPU-saved checkpoints load when DEVICE fell back to CPU.
ref_mapping_network = model.RefMappingNetwork().to(DEVICE)
ref_mapping_network.load_state_dict(torch.load("./output/saved_model/sp/mapping_network.pth", map_location=DEVICE))
ref_mapping_network.eval()

# NOTE(review): pretrained=True downloads ImageNet weights that are immediately
# overwritten by the checkpoint below; kept because it also changes model config.
inception_v3 = models.inception_v3(pretrained=True).to(DEVICE)
inception_v3.load_state_dict(torch.load("./output/saved_model/sp/inception_v3.pth", map_location=DEVICE))
inception_v3.eval()
inception_features = model.FeatureExtractor(inception_v3, ["fc"])

# NOTE(review): `training_loader` is only created in a later cell, so this
# fails on a fresh kernel — TODO define the loader before this cell.
ref_images, ref_landmarks = next(iter(training_loader))
ref_images = ref_images.to(DEVICE)
ref_landmarks = ref_landmarks.view(ref_landmarks.shape[0], -1).to(DEVICE)

# seed = random.randint(0, 2**23 - 1)
# z = torch.tensor(np.random.RandomState(seed).randn(6, 512)).to(DEVICE)
# ws = style_generator.mapping(z, 0)

# Predict a perturbation vector from the Inception "fc" features of the reference images.
ref_features = inception_features(ref_images)
spv = ref_mapping_network(ref_features["fc"][0])

# gen_images = style_generator.synthesis(ws, spv)
# gen_images = (gen_images * 127.5 + 128).clamp(0, 255)

# def show_landmarks(image, landmarks, bbox=None, retuire_bbox=False):
#     fig, ax = plt.subplots()
#     ax.imshow(image)
#     ax.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')
#     if retuire_bbox:
#         bbox = bbox[0]
#         rect = patches.Rectangle((bbox[0], bbox[3] - (bbox[3] - bbox[1])), bbox[2]- bbox[0], bbox[3] - bbox[1], linewidth=1, edgecolor='g', facecolor='none')
#         ax.add_patch(rect)
#     plt.pause(0.001)

# for i in range(6):
#     lnd = fa_network.get_landmarks_from_image(gen_images[i].permute(1,2,0))
#     if lnd != None:
#         show_landmarks(gen_images[i].detach().cpu().permute(1,2,0), lnd[0])

# pred_landmarks, unusual_index = face_lnd_estimator(gen_images)

In [None]:
# Print the predicted spv layer by layer: the first half of each layer's slice
# is labelled "weight", the second half "bias".
ref_res = [512, 512, 512, 512, 512, 512, 512, 512, 512, 512, 256, 256, 128, 128, 64, 64, 32]
spv_row = spv[0]
offset = 0
for n_ch in ref_res:
    half = int(n_ch / 2)
    print("weight:", spv_row[offset:offset + half].tolist())
    print("bias:", spv_row[offset + half:offset + n_ch].tolist())
    offset += n_ch

# Leave the window bounds with the same final values the original loop produced.
a = offset
b = offset + 1024

In [None]:
import face_alignment  # not imported by any earlier cell in notebook order

# Load the pickle first, then read G_ema/D from it (previously `data["G_ema"]`
# was read before `data` was reassigned, picking up a stale or undefined `data`).
with dnnlib.util.open_url("pretrained/ffhq.pkl") as f:
    data = legacy.load_network_pkl(f)
    generator = data["G_ema"].to(DEVICE)
    discriminator = data["D"].to(DEVICE)
style_generator = networks.Generator(**generator.init_kwargs).to(DEVICE)
fa_network = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False)
style_generator.load_state_dict(generator.state_dict())
style_generator.eval()

seed = random.randint(0, 2**23 - 1)
z = torch.tensor(np.random.RandomState(seed).randn(1, 512)).to(DEVICE)
ws = style_generator.mapping(z, 0)

ref_res = [1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 512, 512, 256, 256, 128, 128, 64]

# Per layer: ones for the first half of the slice ("weight"), zeros for the
# second half ("bias").
spv = []
for res in ref_res:
    w = torch.ones((1, int(res / 2)))
    b = torch.zeros((1, int(res / 2)))
    spv.append(torch.cat((w, b), 1))

spv = torch.cat(spv, 1).to(DEVICE)

gen_images = style_generator.synthesis(ws, spv)
gen_images = (gen_images * 127.5 + 128).clamp(0, 255)

# FAN expects HWC pixels in [0, 255].
pred_ = fa_network.get_landmarks_from_image(gen_images.squeeze().permute(1, 2, 0))

In [None]:
# Network
with dnnlib.util.open_url("pretrained/ffhq.pkl") as f:
    data = legacy.load_network_pkl(f)
    generator = data["G_ema"].to(DEVICE)
    discriminator = data["D"].to(DEVICE)

style_generator = networks.Generator(**generator.init_kwargs).to(DEVICE)
style_space_discriminator = model.StyleSpaceDiscriminator().to(DEVICE)
style_discriminator = networks.Discriminator(**discriminator.init_kwargs).to(DEVICE)

face_lnd_estimator = model.FaceLandmarkEstimator()

style_generator.load_state_dict(generator.state_dict())
style_generator.eval()

style_discriminator.load_state_dict(discriminator.state_dict())
style_discriminator.eval()

ref_mapping_network = model.RefMappingNetwork().to(DEVICE)
ref_mapping_network.load_state_dict(torch.load("./output/saved_model/sp"+"/mapping_network.pth"))
ref_mapping_network.eval()

inception_v3 = models.inception_v3(init_weights=True).to(DEVICE)
inception_v3.load_state_dict(torch.load("./output/saved_model/sp"+"/inception_v3.pth"))
inception_v3.eval()
inception_features = model.FeatureExtractor(inception_v3, ["fc"])

inception_v3_optimizer = optim.Adam(inception_v3.parameters(), lr=0.01)
mapping_network_optimizer = optim.Adam(ref_mapping_network.parameters(), lr=0.01)
style_discriminator_optimizer = optim.Adam(style_discriminator.parameters(), lr=0.01)

# ref_res = [1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 512, 512, 256, 256, 128, 128, 64]

# spv = []
# for res in ref_res:
#     w = torch.ones((6, int(res / 2)))
#     b = torch.zeros((6, int(res / 2)))
#     spv.append(torch.cat((w, b), 1))

# spv = torch.cat(spv, 1).to(DEVICE)

In [None]:
# Draw one real batch, predict spv from its Inception "fc" features and render
# a random latent with it.
# NOTE(review): z has batch size 1 while the loader batch may be larger (6 in
# the loader cell) — confirm that synthesis broadcasts ws against spv as intended.
images, landmarks = next(iter(training_loader))
images = images.to(DEVICE)
landmarks = landmarks.view(landmarks.shape[0], -1).to(DEVICE)

seed = random.randint(0, 2**23 - 1)
z = torch.tensor(np.random.RandomState(seed).randn(1, 512)).to(DEVICE)
ws = style_generator.mapping(z, 0)

features = inception_features(images)
spv = ref_mapping_network(features["fc"][0])

generated_images = style_generator.synthesis(ws, spv)
generated_images = (generated_images * 127.5 + 128).clamp(0, 255)

In [None]:
ws.shape

In [None]:
# NOTE(review): generated_images were just mapped to [0, 255] whereas `images`
# look normalized (other cells de-normalize them with *127.5+128), so the
# discriminator sees fake and real inputs on different scales — confirm.
f_logit = style_discriminator(generated_images, 0)
r_logit = style_discriminator(TF.resize(images, (1024, 1024)), 0)

In [None]:
# Logistic (non-saturating) discriminator loss:
#   softplus(-D(real)) penalizes real images that receive low scores,
#   softplus(D(fake))  penalizes generated images that receive high scores.
# The second term previously used r_logit, so the fake logits never entered the loss.
d_loss = torch.nn.functional.softplus(-r_logit) + torch.nn.functional.softplus(f_logit)

In [None]:
d_loss

In [None]:
from os import listdir
from os.path import isfile, join

# Regular files in the saved landmark-image folder (listdir order is arbitrary).
onlyfiles = [f for f in listdir("./output/saved_img/lnd/") if isfile(join("./output/saved_img/lnd/", f))]

In [None]:
# Score each saved side-by-side image with the discriminator: columns [0, 256)
# hold the reference image and columns 256 onward hold the generated image.
# Collects (file name, fake logit, real logit) in `logits_list` and prints the
# logistic D loss per file.
logits_list = []
with torch.no_grad():
    for image_path in onlyfiles:
        # Close the file as soon as its pixels are read; force 3 channels.
        with Image.open(join("./output/saved_img/lnd/", image_path)) as _image:
            pair = TF.to_tensor(_image.convert("RGB"))
        gen_image = pair[:, :, 256:].to(DEVICE)
        ref_image = pair[:, :, 0:256].to(DEVICE)

        f_logits = style_discriminator(TF.resize(gen_image[None, ...], (1024, 1024)), 0)
        r_logits = style_discriminator(TF.resize(ref_image[None, ...], (1024, 1024)), 0)
        logits_list.append((image_path, f_logits.item(), r_logits.item()))

        print(f"Image:{image_path}", (torch.nn.functional.softplus(-r_logits)+torch.nn.functional.softplus(f_logits)).item())

In [None]:
# NOTE(review): `a` is whatever was last bound in the kernel (in source order an
# int or a tensor); TF.to_tensor expects a PIL image or ndarray — hidden state.
b = TF.to_tensor(a)[:,:,256:].to(DEVICE)

In [None]:
logits = style_discriminator(TF.resize(b[None, ...], (1024, 1024)), 0)

In [None]:
torch.nn.functional.softplus(logits)

In [None]:
torch.nn.functional.softplus(f_logit), torch.nn.functional.softplus(r_logit)

In [None]:
TF.to_pil_image(generated_images[0].detach().cpu().to(torch.uint8))

In [None]:
# NOTE(review): `importlib` and `sys` are imported only in a later cell.
importlib.reload(sys.modules["model"])
face_lnd_estimator = model.FaceLandmarkEstimator()

In [None]:
# NOTE(review): unpacked as (pred, ind) here, but a later cell treats the
# estimator output as a single tensor — the return convention apparently changed.
pred, ind = face_lnd_estimator(generated_images)

In [None]:
pred

In [None]:
def show_landmarks(image, landmarks, bbox=None, retuire_bbox=False):
    """Display an image with its landmarks and, optionally, a bounding box.

    Args:
        image: array-like accepted by ``Axes.imshow`` (callers pass HWC tensors).
        landmarks: 2-D array-like; column 0 is x and column 1 is y in pixels.
        bbox: sequence whose first element starts with ``(x0, y0, x1, y1)``;
            only read when ``retuire_bbox`` is true.
        retuire_bbox: draw ``bbox[0]`` as a green rectangle (the misspelled
            name is kept so existing keyword calls keep working).
    """
    fig, ax = plt.subplots()
    ax.imshow(image)
    ax.scatter(landmarks[:, 0], landmarks[:, 1], s=10, marker='.', c='r')
    if retuire_bbox:
        x0, y0, x1, y1 = bbox[0][:4]
        # Rectangle is anchored at the top-left corner (x0, y0).
        rect = patches.Rectangle((x0, y0), x1 - x0, y1 - y0,
                                 linewidth=1, edgecolor='g', facecolor='none')
        ax.add_patch(rect)
    plt.show()

In [None]:
generated_images.shape, pred.shape

In [None]:
# Overlay the estimator's landmarks for image 0 (CHW -> HWC for imshow).
show_landmarks(generated_images[0].permute(1,2,0).detach().cpu().to(torch.uint8), pred[0].detach().cpu())

In [None]:
# Re-import what the landmark experiments need and hot-reload project modules.
# NOTE(review): `ffhq_data.FaceLandmarksDataset/Transforms` shadow the
# `data.GEN_data` classes imported in the first cell; later dataset cells use
# whichever import ran last.
import sys, torch
import importlib
import face_alignment
import face_alignment.utils as fan_utils
import model
from ffhq_data import FaceLandmarksDataset, Transforms
from torch.utils.data import DataLoader
import torchvision.transforms.functional as TF
import matplotlib.pyplot as plt

importlib.reload(sys.modules["model"])
importlib.reload(sys.modules["face_alignment.utils"])
# reload() updates the module objects in place, so these re-imports only rebind
# the same objects.
import model
import face_alignment.utils as fan_utils

DEVICE = torch.device('cuda' if torch.cuda.is_available() else "cpu")

face_lnd_estimator = model.FaceLandmarkEstimator()

In [None]:
# Training loader over the ffhq_data dataset (batch of 6); landmarks are
# flattened to one row per sample.
training_dataset = FaceLandmarksDataset("training", scope=30000, transform=Transforms())
training_loader = DataLoader(training_dataset, 6, shuffle=True)

images, landmarks = next(iter(training_loader))
images = images.to(DEVICE)
landmarks = landmarks.view(landmarks.shape[0], -1).to(DEVICE)

In [None]:
# NOTE(review): `l` is not defined in any cell. If `landmarks` was meant, it is
# flattened above and would need `.view(-1, 2)` for show_landmarks — confirm.
i = 5
show_landmarks(images[i].permute(1,2,0).cpu(), l[i].cpu().detach())

In [None]:
# De-normalize the batch to [0, 255] and run the estimator's face detector on image 0.
image = (images * 127.5 + 128).clamp(0, 255)
frame = image[0].permute(1, 2, 0).detach().cpu()
d = face_lnd_estimator.fa_network.face_detector.detect_from_image(frame)

# The first detection starts with (x0, y0, x1, y1): derive its top-left corner and size.
first_face = d[0]
left, top = first_face[0], first_face[1]
bbox_width = first_face[2] - left
bbox_height = first_face[3] - top

In [None]:
# TF.crop takes (img, top, left, height, width).
a = TF.crop(image[0].to(torch.uint8), int(top), int(left), int(bbox_height), int(bbox_width))

In [None]:
TF.to_pil_image(a)

In [None]:
# Crop center = bbox midpoint shifted up by 12% of the bbox height; scale =
# (bbox width + height) / detector reference_scale. This appears to mirror
# face_alignment's own crop convention — confirm.
center = torch.tensor([d[0][2] - (d[0][2] - d[0][0]) / 2.0, d[0][3] - (d[0][3] - d[0][1]) / 2.0])
center[1] = center[1] - (d[0][3] - d[0][1]) * 0.12
scale = torch.tensor((d[0][2] - d[0][0] + d[0][3] - d[0][1]) / face_lnd_estimator.fa_network.face_detector.reference_scale)

inp = face_lnd_estimator.differentiableCrop(image[0].permute(1,2,0), center, scale)

In [None]:
inp.shape

In [None]:
TF.to_pil_image(inp.permute(2,0,1).to(torch.uint8))

In [None]:
pred_ = face_lnd_estimator(generated_images)

In [None]:
torch.eye(3)

In [None]:
pred_.shape

In [None]:
i = 4
# Use the same index for the image and its landmarks (previously image 5 was
# overlaid with the landmarks of image 4).
show_landmarks(generated_images[i].to(torch.uint8).detach().cpu().permute(1,2,0), pred_[i].detach().cpu())

In [None]:
# Per-coordinate squared error; batch size comes from the targets instead of a hard-coded 6.
landmarks_batch_loss = torch.pow((landmarks - pred_.view(landmarks.shape[0], -1)), 2)

In [None]:
# Squared landmark error summed over generated images in which FAN finds
# exactly 68 landmarks for the first detected face.
# NOTE(review): torch.from_numpy(...) creates a fresh leaf, so this loss cannot
# back-propagate into the generator or the mapping network.
landmark_batch_loss = 0
landmark_training_loss = None  # stays None when no image yields a usable detection
for i in range(min(generated_images.shape[0], landmarks.shape[0])):
    pred_ = fa_network.get_landmarks_from_image(generated_images[i].permute(1, 2, 0))

    if pred_ is not None and len(pred_[0]) == 68:
        pred_landmark = torch.from_numpy(pred_[0]).requires_grad_()
        landmark_batch_loss += torch.pow((landmarks[i] - pred_landmark.view(-1).to(DEVICE)), 2)

if isinstance(landmark_batch_loss, torch.Tensor):
    landmark_training_loss = landmark_batch_loss.mean()

In [None]:
# Gradient of the mapping network's first parameter (None until a backward pass).
list(ref_mapping_network.parameters())[0].grad

In [None]:
# Render without gradients, then score the original vs. modified style vectors
# of every non-ToRGB layer with the style-space discriminator.
generated_images = style_generator.synthesis(ws, spv).detach()
generated_images = (generated_images * 127.5 + 128).clamp(0, 255).to(torch.uint8)

original_style = []
modified_style = []
for block in style_generator.synthesis.children():
    for layer in block.children():
        if not isinstance(layer, networks.ToRGBLayer):
            original_style.append(layer.original_style)
            modified_style.append(layer.modified_style)
        
original_style = torch.cat(original_style, dim=1)
modified_style = torch.cat(modified_style, dim=1)

original_style_score = style_space_discriminator(original_style)
modified_style_score = style_space_discriminator(modified_style)

In [None]:
# Facial landmark detection on the de-normalized real images.
# Every prediction is kept in `real_preds` (previously each iteration overwrote
# the last); `pred_` still holds the final one for the inspection cells.
a = (images * 127.5 + 128).clamp(0, 255).to(torch.uint8)
real_preds = []
for i in range(a.shape[0]):
    pred_ = fa_network.get_landmarks_from_image(a[i].permute(1, 2, 0))
    real_preds.append(pred_)

In [None]:
pred_[0]

In [None]:
landmarks[0]

In [None]:
# Disable spv input
# Per layer: ones for the first half ("weight") and zeros for the second half
# ("bias"), for a hard-coded batch of 10.
# NOTE(review): `ws` must also have batch size 10 here; in the visible cells it is
# created with batch 1 — confirm.
spw_dims = [1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 512, 512, 256, 256, 128, 128, 64]
spv = []
for i in spw_dims:
    spv.append(torch.cat((torch.ones((10, int(i / 2))), torch.zeros((10, int(i/2)))), dim=1))
spv = torch.cat(spv, dim=1).to(DEVICE)

generated_images = style_generator.synthesis(ws, spv).detach()
generated_images = (generated_images * 127.5 + 128).clamp(0, 255).to(torch.uint8)

In [None]:
# Returns (landmarks, scores, bboxes) when return_bboxes=True (index 2 = bboxes).
pred_landmarks = fa_network.get_landmarks_from_image(generated_images[0].permute(1,2,0), return_bboxes=True)

In [None]:
# Accumulate landmark and bbox errors between FAN predictions on the generated
# images and the dataset targets. Predicted landmarks are expressed relative to
# the predicted bbox: origin at its top-left corner, divided by its width/height.
landmark_acc_loss = 0
bbox_acc_loss = 0
for i in range(min(generated_images.shape[0], landmarks.shape[0])):
    pred_ = fa_network.get_landmarks_from_image(generated_images[i].permute(1, 2, 0), return_bboxes=True)
    # pred_[0] holds the landmarks (None when no face is detected); pred_[2] the bboxes.
    if pred_[0] is not None and len(pred_[0][0]) == 68:
        top = pred_[2][0][1]
        left = pred_[2][0][0]
        width = pred_[2][0][2] - pred_[2][0][0]
        height = pred_[2][0][3] - pred_[2][0][1]

        pred_landmark = torch.from_numpy(pred_[0][0]).requires_grad_() - torch.tensor([left, top])
        pred_landmark = pred_landmark / torch.tensor([width, height])

        # Squared error; the former `a**2 - b**2` is not a distance and can be negative.
        landmark_acc_loss += (landmarks[i] - pred_landmark.view(-1).to(DEVICE)).pow(2)

        # Squared error so signed differences cannot cancel in `.mean()`.
        # NOTE(review): assumes bboxs[i] uses the same (top, left, width, height) order.
        bbox_acc_loss += (bboxs[i] - torch.tensor([top, left, width, height]).to(DEVICE)).pow(2)

In [None]:
len(train_loader)

In [None]:
landmark_acc_loss.mean()

In [None]:
bbox_acc_loss.mean()

In [None]:
# Both accumulators stay plain ints when no face was detected; only combine real tensors.
# NOTE(review): otherwise `pred_train_step_loss` keeps any stale value from a previous run.
if isinstance(landmark_acc_loss, torch.Tensor) and isinstance(bbox_acc_loss, torch.Tensor):
    pred_train_step_loss = landmark_acc_loss.mean() + bbox_acc_loss.mean()

In [None]:
pred_train_step_loss.item()

In [None]:
show_landmarks(generated_images[0].detach().cpu().permute(1,2,0), pred_landmarks[0][0], pred_landmarks[2], retuire_bbox=True)

In [None]:
# Predicted bbox of face 0, stored as (x0, y0, x1, y1, ...).
top = pred_landmarks[2][0][1]
left = pred_landmarks[2][0][0]
width = pred_landmarks[2][0][2] - pred_landmarks[2][0][0]
height = pred_landmarks[2][0][3] - pred_landmarks[2][0][1]

In [None]:
# Landmarks relative to the bbox: shift to its top-left corner, scale by its size.
pred_lanmark = torch.tensor(pred_landmarks[0][0]) - torch.tensor([left, top])
pred_lanmark = pred_lanmark / torch.tensor([width, height])

In [None]:
# Per-coordinate squared error between target and bbox-normalized predicted landmarks
# (the former `a**2 - b**2` is not a distance and can be negative).
loss = (landmarks[0] - pred_lanmark.view(-1).to(DEVICE)).pow(2)

In [None]:
# NOTE(review): `bboxs` is not defined in any visible cell (only a commented-out
# `ref_bboxs`), so this relies on hidden kernel state.
bboxs[0]

In [None]:
# Squared per-component error, so signed differences cannot cancel in `.mean()`.
# NOTE(review): assumes bboxs[0] uses the same (top, left, width, height) order.
bbox_acc_loss = (bboxs[0] - torch.tensor([top, left, width, height]).to(DEVICE)).pow(2)

In [None]:
bbox_acc_loss.mean()

In [None]:
pred_lanmark.shape

In [None]:
# Score the concatenated original vs. modified styles of every non-ToRGB layer
# with the style-space discriminator.
original_style = []
modified_style = []
for block in style_generator.synthesis.children():
    for layer in block.children():
        if not isinstance(layer, networks.ToRGBLayer):
            original_style.append(layer.original_style)
            modified_style.append(layer.modified_style)

style_ori = torch.cat(original_style, dim=1)
style_mod = torch.cat(modified_style, dim=1)

score_ori = style_space_discriminator(style_ori)
score_mod = style_space_discriminator(style_mod)

In [None]:
# Split a flat batch of style parameters into per-layer chunks: the first layer
# stands alone and every following pair of layers is grouped into a tuple.
spw = torch.randn(10, 12096)
spw_dims = [1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 1024, 512, 512, 256, 256, 128, 128, 64]
# Alternative layouts tried earlier:
#   [512] * 10 + [256, 256, 128, 128, 64, 64, 32]
#   [512, 1024, 1024, 1024, 1024, 768, 384, 192, 96]

# One leading chunk plus complete pairs means an odd number of layers, and the
# sizes must cover spw exactly.
assert len(spw_dims) % 2 == 1
assert sum(spw_dims) == spw.shape[1]

chunks = torch.split(spw, spw_dims, dim=1)
spw_ = [chunks[0]] + list(zip(chunks[1::2], chunks[2::2]))
spw_idx = sum(spw_dims)  # total width consumed, as the original index bookkeeping left it

In [None]:
# Print each chunk's shape; paired chunks print both shapes on one line.
for chunk in spw_:
    members = chunk if isinstance(chunk, tuple) else (chunk,)
    print(*(t.shape for t in members))

In [None]:
import PIL.Image as Image

# Side by side: reference image on the left, generated image on the right.
# `generated_image` was undefined (typo for `generated_images`), and the
# full-resolution generated output is resized so it fills its 256 px slot
# instead of being cropped by paste().
# NOTE(review): assumes generated_images is uint8 in [0, 255] (as produced by the
# neutral-spv cell) and that images[0] is already in a displayable range — confirm.
a = TF.to_pil_image(images[0])
b = TF.to_pil_image(TF.resize(generated_images[0], (256, 256)))

out_img = Image.new("RGB", (512, 256))

out_img.paste(a, (0, 0))

out_img.paste(b, (256, 0))

out_img.save("./output/saved_img/01.jpg")

In [None]:
# NOTE(review): `pred_results` is not defined in any cell of this notebook
# (possibly meant `pred_landmarks`), so this relies on hidden kernel state.
pred_results[2][0][0]