In [1]:
import torch
import dlib
import torch.nn as nn
import os

from data_loader.ffhq_data import FFHQDataset
from torch.utils.data.dataloader import DataLoader
from general_utils.landmarks_utils import parse_roi_box_from_bbox

from pathlib import Path
from model.network import Network

from tqdm import tqdm

import cv2 as cv
import numpy as np

import torchvision.transforms.functional as TF
from torchvision.io import read_image
from torchvision.utils import make_grid


class Args(object):
    """Plain attribute container holding the run configuration for this notebook.

    An instance is passed to ``FFHQDataset`` and ``Network``; later cells
    overwrite ``name`` and ``weights_dir`` to switch between experiments.
    """

    def __init__(self):
        # Filesystem layout.
        self.dataset_path = Path("./dataset")
        self.results_dir = Path("./output")
        self.pretrained_models_path = Path("./pretrained")

        # Experiment identity; the weights folder is derived from the name.
        self.name = "exp_27-conv3--8"
        self.weights_dir = self.results_dir / (self.name + "/weights")

        # Data settings.
        self.resolution = 256
        self.train_data_size = 50000
        self.batch_size = 6
        self.cache = True

        # Mode switches.
        self.load_checkpoint = False
        self.train = False
        self.reals = False
        self.test_real_attr = True
        self.train_real_attr = False
        self.parameter_embedding = False


def load_image_tensor(image_path):
    """Read an image file and return it as a float RGB tensor with a leading batch axis.

    Pixel values are mapped with ``(x - 127.5) / 128`` and the channel axis is
    flipped from OpenCV's BGR order to RGB.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    bgr = cv.imread(image_path)
    if bgr is None:
        # cv.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    chw = torch.from_numpy(bgr.transpose((2, 0, 1))).float()
    chw.sub_(127.5).div_(128)
    return chw.flip(-3)[None, ...]  # BGR -> RGB, then add the batch axis


args = Args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

ffhq_dataset = FFHQDataset(args)
real_dataset_dir = os.path.join(args.dataset_path, "ffhq256_dataset")

ffhq_dataloader = DataLoader(ffhq_dataset, batch_size=5, pin_memory=True)
ffhq_iter = iter(ffhq_dataloader)

# Attribute images (pose / expression references).
attr_path_list = ["./dataset/ffhq256_dataset/image/07000/07983.png",
                  "./dataset/ffhq256_dataset/image/04000/04986.png",
                  "./dataset/ffhq256_dataset/image/07000/07024.png"]

attr_list = [load_image_tensor(attr_path) for attr_path in attr_path_list]
attr_images = torch.cat(attr_list, 0).to(device)

# Identity images; the file stem doubles as the latent seed of each identity.
id_path_list = ["./dataset/gen_dataset/image/14000/13006.png",
                "dataset/gen_dataset/image/43000/42024.png",
                "dataset/gen_dataset/image/46000/45009.png"]

id_image_list = [load_image_tensor(id_path) for id_path in id_path_list]
id_seed_list = [int(os.path.splitext(os.path.basename(id_path))[0])
                for id_path in id_path_list]
id_images = torch.cat(id_image_list, 0).to(device)

# Every attribute image is paired with exactly one identity seed below.
num_pairs = attr_images.shape[0]
assert len(id_seed_list) == num_pairs, "need one identity seed per attribute image"

id_model_path = str(args.pretrained_models_path.joinpath(
    "20180402-114759-vggface2.pt"))
stylegan_G_path = str(args.pretrained_models_path.joinpath(
    "stylegan2-ffhq-256x256.pkl"))
landmarks_model_path = str(
    args.pretrained_models_path.joinpath('3DDFA/phase1_wpdc_vdc.pth.tar'))

network = Network(args=args, id_model_path=id_model_path, base_generator_path=stylegan_G_path,
                  landmarks_detector_path=landmarks_model_path, device=device)

# NOTE(review): the recorded output of this cell shows `_load` failing with a
# state_dict size mismatch (checkpoint 2816 vs. model 4928 outputs) -- confirm
# that `args` matches the configuration the "exp_27-conv3--8" weights were trained with.
network._load("")

zero_ctrlv = torch.zeros((num_pairs, 4928)).to(device)
# Zero blocks placed before / after the predicted control vector
# (1536 + 2816 + 576 = 4928 when the reference network emits 2816 values).
style_padding1 = torch.zeros((num_pairs, 1536)).to(device)
style_padding2 = torch.zeros((num_pairs, 576)).to(device)

zs_list = []
for id_seed in id_seed_list:
    z = torch.from_numpy(np.random.RandomState(id_seed).randn(1, 512)).to(device)
    zs_list.append(z)

zs = torch.cat(zs_list, 0)

with torch.no_grad():
    ws = network.generator.stylegan_generator.mapping(zs, 0)

    # The detector is fed the images with the channel axis flipped back.
    attr_lnds_result = network.generator.landmarks_detector(attr_images.flip(-3))
    if attr_lnds_result is None:
        # generate_image() in this notebook treats None as "no usable result".
        raise RuntimeError("landmarks_detector returned None for the attribute images")
    attr_lnds, attr_pose = attr_lnds_result[0], attr_lnds_result[1]

    attr_embeddings = network.generator.attr_encoder(attr_images)

    control_vectors = network.generator.reference_network(attr_embeddings)
    control_vectors = torch.cat([style_padding1, control_vectors, style_padding2], -1)

    gen_images = network.generator.stylegan_generator.synthesis(ws, control_vectors)

    # Map everything from [-1, 1] to [0, 1] for display.
    id_images = ((id_images + 1) / 2).clamp(0, 1)
    attr_images = ((attr_images + 1) / 2).clamp(0, 1)
    gen_images = ((gen_images + 1) / 2).clamp(0, 1)

    # Lay the batches out side by side along the width axis.
    gen_images = torch.cat([gen_image for gen_image in gen_images], 2)
    attr_images = torch.cat([attr_image for attr_image in attr_images], 2)

    # Bottom row: first identity image followed by the generated images.
    gen_results = torch.cat([id_images[0], gen_images], 2)

    # Top row: a white tile (aligned with the identity tile) followed by the attribute images.
    attr_images = torch.cat([torch.ones_like(id_images[0]), attr_images], 2)
    gen_results = torch.cat([attr_images, gen_results], 1)

TF.to_pil_image(gen_results)

Downloading: "https://download.pytorch.org/models/inception_v3_google-0cc3c7bd.pth" to /home/zhuo/.cache/torch/hub/checkpoints/inception_v3_google-0cc3c7bd.pth
100%|██████████| 104M/104M [00:01<00:00, 108MB/s]  


AttrEncoder loads checkpoint from: output/exp_27-conv3--8/weights/AttrEncoder.pth


RuntimeError: Error(s) in loading state_dict for Sequential:
	size mismatch for 6.weight: copying a param with shape torch.Size([2816, 4096]) from checkpoint, the shape in current model is torch.Size([4928, 4096]).
	size mismatch for 6.bias: copying a param with shape torch.Size([2816]) from checkpoint, the shape in current model is torch.Size([4928]).

In [None]:
import torch
import dlib
import torch.nn as nn
import os

from data_loader.ffhq_data import FFHQDataset
from torch.utils.data.dataloader import DataLoader
from general_utils.landmarks_utils import parse_roi_box_from_bbox

from pathlib import Path
from model.network import Network

from tqdm import tqdm

import cv2 as cv
import numpy as np

import torchvision.transforms.functional as TF
from torchvision.io import read_image
from torchvision.utils import make_grid


class Args(object):
    """Run configuration object (same fields as the ``Args`` defined in the first cell)."""

    def __init__(self):
        run_name = "exp_27-conv3--8"
        output_root = Path("./output")

        # Experiment name and the folders derived from it.
        self.name = run_name
        self.results_dir = output_root
        self.weights_dir = output_root.joinpath(run_name + "/weights")

        # Input locations.
        self.dataset_path = Path("./dataset")
        self.pretrained_models_path = Path("./pretrained")

        # Numeric settings.
        self.resolution, self.train_data_size, self.batch_size = 256, 50000, 6

        # Flags that are switched on for this notebook.
        self.test_real_attr = True
        self.cache = True

        # Flags that are switched off.
        for flag_name in ("load_checkpoint", "train", "reals",
                          "train_real_attr", "parameter_embedding"):
            setattr(self, flag_name, False)


args = Args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

ffhq_dataset = FFHQDataset(args)
real_dataset_dir = os.path.join(args.dataset_path, "ffhq256_dataset")

ffhq_dataloader = DataLoader(ffhq_dataset, batch_size=5, pin_memory=True)
ffhq_iter = iter(ffhq_dataloader)

# Attribute images (pose / expression references).
attr_path_list = ["./dataset/ffhq256_dataset/image/07000/07983.png",
                  "./dataset/ffhq256_dataset/image/04000/04986.png",
                  "./dataset/ffhq256_dataset/image/07000/07024.png"]

attr_list = []
for attr_path in attr_path_list:
    attr_image = cv.imread(attr_path)
    if attr_image is None:
        # cv.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read attribute image: {attr_path}")

    attr_image = torch.from_numpy(attr_image.transpose((2, 0, 1))).float()
    attr_image.sub_(127.5).div_(128)

    attr_image = attr_image.flip(-3)  # convert to RGB
    attr_list.append(attr_image[None, ...])

attr_images = torch.cat(attr_list, 0).to(device)

# Identity images; the file stem doubles as the latent seed of each identity.
id_path_list = ["./dataset/gen_dataset/image/14000/13006.png",
                "dataset/gen_dataset/image/43000/42024.png",
                "dataset/gen_dataset/image/46000/45009.png"]

id_image_list = []
id_seed_list = []
for id_path in id_path_list:
    id_image = cv.imread(id_path)
    if id_image is None:
        raise FileNotFoundError(f"Could not read identity image: {id_path}")

    id_image = torch.from_numpy(id_image.transpose((2, 0, 1))).float()
    id_image.sub_(127.5).div_(128)

    id_image = id_image.flip(-3)  # convert to RGB
    id_image_list.append(id_image[None, ...])

    id_image_seed = int(os.path.splitext(os.path.basename(id_path))[0])
    id_seed_list.append(id_image_seed)

# One entry per path in id_path_list, moved to the compute device.
id_images = torch.cat(id_image_list, 0).to(device)

# dlib face detector and 68-point shape predictor.
detector = dlib.get_frontal_face_detector()
predictor = dlib.shape_predictor(
    "./pretrained/shape_predictor_68_face_landmarks.dat")


def generate_image(args, id_images, id_seed_list, attr_images):
    """Render a comparison sheet for the experiment selected by ``args``.

    A new ``Network`` is built and ``_load("")`` is called on every invocation,
    so the settings on ``args`` at call time decide which weights are used.
    The module-level ``device`` is used for all tensors created here.

    Args:
        args: configuration object passed to ``Network``;
            ``pretrained_models_path`` is read directly.
        id_images: accepted for interface compatibility. The identity tile in
            the result is synthesised from the seed(s), not taken from this tensor.
        id_seed_list: either one int seed shared by the whole batch, or a
            sequence of int seeds (a single entry, or one per attribute image).
        attr_images: batch of attribute images on ``device``
            (presumably (N, 3, H, W) normalised with ``(x - 127.5) / 128`` -- confirm
            against the loading cell).

    Returns:
        A tensor in [0, 1] made of two rows concatenated along axis 1:
        a white tile followed by the attribute images, and the identity image
        of the first seed followed by the generated images.

    Raises:
        ValueError: if the batch is empty, the number of seeds does not match
            the batch, or the landmark detector returns ``None``.
    """
    id_model_path = str(args.pretrained_models_path.joinpath(
        "20180402-114759-vggface2.pt"))
    stylegan_G_path = str(args.pretrained_models_path.joinpath(
        "stylegan2-ffhq-256x256.pkl"))
    landmarks_model_path = str(
        args.pretrained_models_path.joinpath('3DDFA/phase1_wpdc_vdc.pth.tar'))

    network = Network(args=args, id_model_path=id_model_path, base_generator_path=stylegan_G_path,
                      landmarks_detector_path=landmarks_model_path, device=device)

    network._load("")

    batch_size = attr_images.shape[0]
    if batch_size == 0:
        raise ValueError("attr_images is empty")

    # Accept both a bare seed and a list of seeds; a single seed is shared
    # by every attribute image.
    if isinstance(id_seed_list, (int, np.integer)):
        seeds = [int(id_seed_list)]
    else:
        seeds = [int(seed) for seed in id_seed_list]
    if len(seeds) == 1:
        seeds = seeds * batch_size
    if len(seeds) != batch_size:
        raise ValueError(
            f"expected 1 or {batch_size} identity seeds, got {len(seeds)}")

    zs = torch.cat(
        [torch.from_numpy(np.random.RandomState(seed).randn(1, 512))
         for seed in seeds], 0).to(device)

    with torch.no_grad():
        # The detector is fed the images with the channel axis flipped back;
        # a None result means the batch cannot be used.
        attr_lnds_result = network.generator.landmarks_detector(
            attr_images.flip(-3))
        if attr_lnds_result is None:
            raise ValueError(
                "landmarks_detector returned None for the attribute images")

        ws = network.generator.stylegan_generator.mapping(zs, 0)

        attr_embedding = network.generator.attr_encoder(attr_images)
        control_vector = network.generator.reference_network(attr_embedding)

        # Identity reference: same latents, all-zero control vector of matching shape.
        gen_id_image = network.generator.stylegan_generator.synthesis(
            ws, torch.zeros_like(control_vector))
        gen_images = network.generator.stylegan_generator.synthesis(
            ws, control_vector)

        # Map from [-1, 1] to [0, 1] without touching the caller's tensors.
        id_tile = ((gen_id_image[0] + 1) / 2).clamp(0, 1)
        attr_tiles = ((attr_images + 1) / 2).clamp(0, 1)
        gen_tiles = ((gen_images + 1) / 2).clamp(0, 1)

        # Rows are built along the width axis, then stacked along the height axis.
        attr_row = torch.cat([torch.ones_like(id_tile), *attr_tiles], 2)
        gen_row = torch.cat([id_tile, *gen_tiles], 2)
        gen_results = torch.cat([attr_row, gen_row], 1)

    return gen_results

In [None]:
# Point the shared config at experiment "exp_11" and render its comparison sheet.
# generate_image needs the identity batch, the identity seeds and the attribute
# batch in addition to args.
args.name = "exp_11"
args.weights_dir = args.results_dir.joinpath(args.name+"/weights")
gen_results_0 = generate_image(args, id_images, id_seed_list, attr_images)

In [None]:
# Convert the tensor returned by generate_image to a PIL image; as the last
# expression of the cell it is rendered inline.
TF.to_pil_image(gen_results_0)

In [None]:
# Display the first three entries of gen_results_0 divided by 1000; the third
# is additionally multiplied by 180 / 3.1415.
# NOTE(review): this looks like a leftover from a version of generate_image
# that returned accumulated scores rather than an image -- confirm or delete.
(
    gen_results_0[0] / 1000,
    gen_results_0[1] / 1000,
    gen_results_0[2] / 1000 * 180 / 3.1415,
)

In [None]:
import dlib
import cv2 as cv
from imutils import face_utils
detector = dlib.get_frontal_face_detector()
predictor = dlib.shape_predictor(
    "./pretrained/shape_predictor_68_face_landmarks.dat")

# Load one still image from disk (the input is static, so it is processed once).
image_path = "./dataset/ffhq256_dataset/image/01000/01000.png"
image = cv.imread(image_path)
if image is None:
    # cv.imread reports failure by returning None instead of raising.
    raise FileNotFoundError(f"Could not read image: {image_path}")

# The detector and predictor are run on the grayscale version.
gray = cv.cvtColor(image, cv.COLOR_BGR2GRAY)

# Detect faces (0 = no upsampling passes).
rects = detector(gray, 0)

# For each detected face, predict the landmarks and draw them on the colour image.
for rect in rects:
    shape = face_utils.shape_to_np(predictor(gray, rect))

    for (x, y) in shape:
        # Plain Python ints keep cv.circle happy regardless of the numpy dtype.
        cv.circle(image, (int(x), int(y)), 2, (0, 255, 0), -1)

# Show the annotated image; waitKey lets the window render and blocks until a
# key is pressed, after which the window is closed.
cv.imshow("Output", image)
cv.waitKey(0)
cv.destroyAllWindows()

In [None]:
# Render one comparison sheet per experiment. The identity seeds come from
# id_seed_list, which is built in the setup cell together with id_images and
# attr_images, so this cell does not rely on names defined further down.
experiment_names = ("exp_05-no_concatenation",
                    "exp_06-add_style_regularizer",
                    "exp_11")

experiment_results = []
for experiment_name in experiment_names:
    args.name = experiment_name
    args.weights_dir = args.results_dir.joinpath(args.name+"/weights")
    experiment_results.append(
        generate_image(args, id_images, id_seed_list, attr_images))

gen_results_1, gen_results_2, gen_results_3 = experiment_results

In [None]:
# Keep everything from row 256 onwards of each sheet (i.e. drop its first
# 256 rows) and stack the three remainders along axis 1.
cropped_sheets = [sheet[:, 256:]
                  for sheet in (gen_results_1, gen_results_2, gen_results_3)]
gen_results = torch.cat(cropped_sheets, 1)

In [None]:
# Read the two saved sheets as float channel-first tensors and flip the
# channel axis (OpenCV loads BGR). Note that img01 comes from output2.png
# and img02 from output1.png.
img01, img02 = (
    torch.from_numpy(cv.imread(png_path).transpose((2, 0, 1))).float().flip(-3)
    for png_path in ("./output2.png", "./output1.png")
)

In [None]:
# Stack the two sheets along axis 1, apply the (x - 127.5) / 128 mapping used
# elsewhere in this notebook, then shift/scale to [0, 1] for display.
img = torch.cat([img01, img02], 1)
img = (img - 127.5) / 128
img = ((img + 1) / 2).clamp(0, 1)

In [None]:
# Render the merged sheet inline as a PIL image.
TF.to_pil_image(img)

In [None]:
# Average lnds[0, 0, 16] - lnds[0, 0, 0] over batches drawn from ffhq_iter.
# NOTE(review): presumably this is the horizontal distance between the two
# jaw end points of the first face in each batch -- confirm the layout of
# landmarks_detector's fifth output.
num_batches = 10000
face_scale = 0
num_valid = 0  # batches that actually contributed to the sum

pbar = tqdm(range(num_batches), ncols=80)
with torch.no_grad():  # pure measurement; do not keep autograd graphs alive
    for _ in pbar:
        try:
            images = next(ffhq_iter)
        except StopIteration:
            # The shared iterator ran out before num_batches were drawn.
            break
        images = images.to(device)

        lnds_results = network.generator.landmarks_detector(images.flip(-3))

        if lnds_results is None:
            continue

        lnds, poses = lnds_results[4], lnds_results[1]

        face_scale += lnds[0, 0, 16] - lnds[0, 0, 0]
        num_valid += 1

# Divide by the number of batches that were measured, not by the number requested.
mean_face_scale = face_scale / num_valid if num_valid else float("nan")

In [None]:
# Display the average computed in the previous cell.
mean_face_scale

In [None]:
# Load one identity image and one attribute image (paths relative to
# args.dataset_path). The identity file stem is reused as its latent seed.
id_img_path = "gen_dataset/image/8000/7029.png"
id_img_seed = int(os.path.splitext(os.path.basename(id_img_path))[0])

attr_img_path = "gen_dataset/image/1000/964.png"

id_image = cv.imread(str(args.dataset_path.joinpath(id_img_path)))
attr_image = cv.imread(str(args.dataset_path.joinpath(attr_img_path)))
if id_image is None or attr_image is None:
    # cv.imread reports failure by returning None instead of raising.
    raise FileNotFoundError(
        f"Could not read {id_img_path!r} and/or {attr_img_path!r} "
        f"under {args.dataset_path}")

# HWC BGR uint8 -> CHW RGB float on the notebook's `device`, mapped with (x - 127.5) / 128.
id_image = torch.from_numpy(id_image.transpose(
    (2, 0, 1))).float().to(device).flip(-3)
id_image.sub_(127.5).div_(128)
attr_image = torch.from_numpy(attr_image.transpose(
    (2, 0, 1))).float().to(device).flip(-3)
attr_image.sub_(127.5).div_(128)

# Identity on the left, attribute on the right.
show_image = torch.concat((id_image, attr_image), -1)
TF.to_pil_image(((show_image+1)/2))

In [None]:
# Run the landmark detector on each image (channel axis flipped back, batch
# axis added) and keep the third entry of each result.
landmarks_detector = network.generator.landmarks_detector

attr_lnd_results = landmarks_detector(attr_image.flip(-3).unsqueeze(0))
attr_calib_lnds = attr_lnd_results[2]

id_lnd_results = landmarks_detector(id_image.flip(-3).unsqueeze(0))
id_calib_lnds = id_lnd_results[2]

In [None]:
from PIL import Image, ImageDraw

# Despite its name, pil_attr holds the *identity* image (mapped to [0, 1]).
pil_attr = TF.to_pil_image((id_image + 1) / 2)
image_draw = ImageDraw.Draw(pil_attr)

# Third detector output for the identity image, axes 1 and 2 swapped, first sample.
id_points = id_lnd_results[2].permute(0, 2, 1)[0]

# Mark every point with a small blue dot (2 px bounding box).
for px, py in id_points:
    image_draw.ellipse([(px - 1, py - 1), (px + 1, py + 1)], fill="blue")

In [None]:
# Display the image with the landmark dots drawn on it.
pil_attr

In [None]:
# NOTE(review): this cell rebinds attr_lnds / id_lnds to their own permuted
# versions, so running it twice swaps the axes back. id_lnds, attr_R and id_R
# are not defined in the visible cells -- confirm where they come from.
attr_lnds = attr_lnds.permute(0, 2, 1)
id_lnds = id_lnds.permute(0, 2, 1)

# Multiply by the inverse of the transposed first matrix of each *_R.
attr_unrotate = attr_R[0].T.inverse()
id_unrotate = id_R[0].T.inverse()

attr_recov_lnds = attr_lnds @ attr_unrotate
id_recov_lnds = id_lnds @ id_unrotate

In [None]:
# Build an origin and three reference points, then multiply each by R[0].T.
# NOTE(review): src_lnds and R are not defined in the visible cells -- confirm
# where they come from before relying on this cell.
ori_o = (src_lnds[0, :, 0] + src_lnds.reshape(1, 68, 3)[0, 16, :]) / 2
ori_e1 = ori_o
# Created on the notebook's `device` (the only device handle this notebook defines).
ori_e2 = torch.tensor([128., 178., 0.], device=device)
ori_e3 = torch.tensor([128., 128., 50.], device=device)

rotation_t = R[0].T
new_o = ori_o[None, ...].mm(rotation_t)
new_e1 = ori_e1[None, ...].mm(rotation_t)
new_e2 = ori_e2[None, ...].mm(rotation_t)
new_e3 = ori_e3[None, ...].mm(rotation_t)

In [None]:
# Undo the previous multiplication by R[0].T by multiplying with its inverse.
inverse_rotation_t = R[0].T.inverse()
ori_e1, ori_e2, ori_e3 = (
    point.mm(inverse_rotation_t) for point in (new_e1, new_e2, new_e3)
)

In [None]:
# Display pil_attr again (it reflects anything drawn on it since the last display).
pil_attr

In [None]:
# Generate one image from the identity seed and the attribute image, then show
# it next to the identity image. All tensors live on the notebook's `device`.
style_padding = torch.zeros((1, 1344)).to(device)
with torch.no_grad():
    # Latent for the identity: the seed is the identity image's file stem.
    z = torch.from_numpy(np.random.RandomState(
        id_img_seed).randn(1, 512)).to(device)
    ws = network.generator.stylegan_generator.mapping(z, 0)

    id_embedding = network.generator.id_encoder(id_image[None, ...])
    attr_embedding = network.generator.attr_encoder(
        torch.broadcast_to(attr_image, [1, *attr_image.shape]))

    # Only the attribute embedding is fed to the reference branch here.
    feature_input = attr_embedding

    if args.parameter_embedding:
        # Separate encoder/decoder pairs for pose and expression.
        pose_sp_embedding = network.generator.reference_pose_encoder(
            feature_input)
        pose_control_vector = network.generator.reference_pose_decoder(
            pose_sp_embedding)
        expression_sp_embedding = network.generator.reference_expression_encoder(
            feature_input)
        expression_control_vector = network.generator.reference_expression_decoder(
            expression_sp_embedding)
        base_control_vector = 0
    else:
        # Single reference network producing the whole control vector.
        base_control_vector = network.generator.reference_network(
            feature_input)
        pose_control_vector = expression_control_vector = 0

    control_vector = base_control_vector + \
        pose_control_vector + expression_control_vector

    # Zero-pad the end of the control vector with 1344 entries.
    control_vector = torch.cat([control_vector, style_padding], -1)

    gen_image = network.generator.stylegan_generator.synthesis(
        ws, control_vector)

TF.to_pil_image(make_grid(
    [((id_image + 1) / 2).clamp(0, 1), ((gen_image[0] + 1) / 2).clamp(0, 1)]))

In [None]:
# Inspect the size of the reference network's output.
# NOTE(review): base_control_vector is the int 0 when args.parameter_embedding
# is set, in which case `.shape` does not exist.
base_control_vector.shape

In [None]:
def _to_uint8_299(batch):
    """Map a [-1, 1] image batch to uint8 [0, 255], resize to 299x299, move to CPU."""
    as_uint8 = ((batch + 1) / 2 * 255).clamp(0, 255).to(torch.uint8)
    return TF.resize(as_uint8, (299, 299)).cpu()


# NOTE(review): gen_id_image is not assigned at module level in the visible
# cells -- confirm it exists before running this cell.
train_img = _to_uint8_299(gen_image)
test_img = _to_uint8_299(gen_id_image)

In [None]:
# FID using the project's own implementation, applied to the two uint8
# 299x299 batches prepared in the previous cell.
from metrics.frechet_inception_distance import FIDScore

fid_score = FIDScore()

# Last expression of the cell, so the returned score is displayed.
fid_score.calculate_fid(train_img, test_img)

In [None]:
# Same comparison with torchmetrics: test_img is registered as the "real"
# distribution and train_img as the "fake" one.
# NOTE(review): FID needs more than one sample per distribution -- confirm the
# batch sizes of test_img / train_img before trusting (or running) compute().
from torchmetrics.image.fid import FrechetInceptionDistance

fid = FrechetInceptionDistance()
fid.update(test_img, real=True)
fid.update(train_img, real=False)
fid.compute()

In [None]:
def evaluation(network, id_image_path, attr_image_path_list, feature_embedding=False):
    """Generate one image per attribute image for a single identity.

    Paths are resolved against the module-level ``args.dataset_path`` and all
    tensors are created on the module-level ``device``.

    Args:
        network: object exposing ``generator`` with the encoders, the reference
            branch and the StyleGAN generator used below.
        id_image_path: path of the identity image, relative to
            ``args.dataset_path``; its file stem is parsed as the latent seed.
        attr_image_path_list: non-empty list of attribute image paths, relative
            to ``args.dataset_path``.
        feature_embedding: when true, use the separate pose / expression
            encoder-decoder pairs; otherwise use ``reference_network``.

    Returns:
        ``(gen_image, out_image, control_vector)`` where ``out_image`` is the
        attribute, generated and identity batches concatenated along the last axis.

    Raises:
        ValueError: if ``attr_image_path_list`` is empty.
        FileNotFoundError: if any image cannot be read.
    """
    def _read_normalised(relative_path):
        # HWC BGR uint8 file -> CHW RGB float tensor mapped with (x - 127.5) / 128.
        full_path = str(args.dataset_path.joinpath(relative_path))
        bgr = cv.imread(full_path)
        if bgr is None:
            # cv.imread reports failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {full_path}")
        rgb = torch.from_numpy(bgr.transpose(
            (2, 0, 1))).float().to(device).flip(-3)
        rgb.sub_(127.5).div_(128)
        return rgb

    num_attr = len(attr_image_path_list)
    if num_attr == 0:
        raise ValueError("attr_image_path_list must contain at least one path")

    id_image_seed = int(os.path.splitext(os.path.basename(id_image_path))[0])

    # Repeat the single identity image once per attribute image.
    id_image = _read_normalised(id_image_path)
    id_image = id_image.expand(num_attr, *id_image.shape)

    attr_image = torch.concat(
        [_read_normalised(attr_image_path)[None, ...]
         for attr_image_path in attr_image_path_list], 0)

    with torch.no_grad():
        # Same latent for every row: one identity, many attribute references.
        zs = torch.from_numpy(np.random.RandomState(id_image_seed).randn(
            1, 512)).to(device).expand(num_attr, 512)
        ws = network.generator.stylegan_generator.mapping(zs, 0)

        id_embedding = network.generator.id_encoder(id_image)
        attr_embedding = network.generator.attr_encoder(attr_image)

        # Identity and attribute embeddings are concatenated along the feature axis.
        feature_input = torch.concat([id_embedding, attr_embedding], -1)

        if feature_embedding:
            pose_sp_embedding = network.generator.reference_pose_encoder(
                feature_input)
            pose_control_vector = network.generator.reference_pose_decoder(
                pose_sp_embedding)
            expression_sp_embedding = network.generator.reference_expression_encoder(
                feature_input)
            expression_control_vector = network.generator.reference_expression_decoder(
                expression_sp_embedding)
            base_control_vector = 0
        else:
            base_control_vector = network.generator.reference_network(
                feature_input)
            pose_control_vector = expression_control_vector = 0

        control_vector = base_control_vector + \
            pose_control_vector + expression_control_vector

        gen_image = network.generator.stylegan_generator.synthesis(
            ws, control_vector)

        out_image = torch.concat((attr_image, gen_image, id_image), -1)

    return gen_image, out_image, control_vector

In [None]:
# One identity, five attribute references (paths relative to args.dataset_path).
id_image_path = "image/4000/3000.png"
# id_image_path = "image/1000/877.png"
attr_image_path_list = ["image/1000/964.png", "image/1000/865.png",
                        "image/1000/877.png", "image/1000/14.png", "image/1000/26.png"]

# Pass the identity path chosen above, and select the reference branch with
# args.parameter_embedding (the flag Args actually defines).
gen_image, out_image, control_vector = evaluation(
    network, id_image_path, attr_image_path_list, args.parameter_embedding)

In [None]:
# Show the first row of the evaluation output (attribute | generated | identity),
# mapped from [-1, 1] to [0, 1].
TF.to_pil_image(((out_image[0]+1)/2).clamp(0, 1))

In [None]:
# Ten largest-magnitude entries of control vector 0: (values, indices).
top_abs = torch.topk(control_vector[0].abs(), 10)
top_abs.values, top_abs.indices

In [None]:
# Ten largest-magnitude entries of control vector 1: (values, indices).
top_abs = torch.topk(control_vector[1].abs(), 10)
top_abs.values, top_abs.indices

In [None]:
# Ten largest-magnitude entries of control vector 2: (values, indices).
top_abs = torch.topk(control_vector[2].abs(), 10)
top_abs.values, top_abs.indices

In [None]:
# Ten largest-magnitude entries of control vector 3: (values, indices).
top_abs = torch.topk(control_vector[3].abs(), 10)
top_abs.values, top_abs.indices

In [None]:
# Ten largest-magnitude entries of control vector 4: (values, indices).
top_abs = torch.topk(control_vector[4].abs(), 10)
top_abs.values, top_abs.indices

In [None]:
import matplotlib.pyplot as plt

fig, ax = plt.subplots()

# Bar chart of every channel of control vector 0. The x positions are derived
# from the vector itself so the plot works for any control-vector width.
values = control_vector[0].tolist()
channels = list(range(len(values)))

ax.bar(channels, values)
ax.set(title="Control vector 0", xlabel="channel", ylabel="value")

plt.show()

In [None]:
import matplotlib.pyplot as plt

fig, ax = plt.subplots()

# Bar chart of every channel of control vector 1. The x positions are derived
# from the vector itself so the plot works for any control-vector width.
values = control_vector[1].squeeze().cpu().numpy()
channels = list(range(len(values)))

ax.bar(channels, values)
ax.set(title="Control vector 1", xlabel="channel", ylabel="value")

plt.show()

In [None]:
from torchmetrics.image.fid import FrechetInceptionDistance

fid = FrechetInceptionDistance(feature=2048)

# "Real" side: the identity image; "fake" side: the generated image. Both are
# mapped from [-1, 1] to uint8 [0, 255] and resized to 299x299 before being
# handed to the metric.
real_image = ((gen_id_image + 1) / 2 * 255).clamp(0, 255).to(torch.uint8).cpu()
real_image = TF.resize(real_image, (299, 299))
fake_image = ((gen_image + 1) / 2 * 255).clamp(0, 255).to(torch.uint8).cpu()
fake_image = TF.resize(fake_image, (299, 299))

fid.update(real_image, real=True)
fid.update(fake_image, real=False)
fid.compute()

In [None]:
# Root-mean-square error between the identity image and the generated image.
torch.nn.functional.mse_loss(gen_id_image, gen_image).sqrt().item()

In [None]:
from torchmetrics.functional import peak_signal_noise_ratio as psnr_metric

# PSNR between the generated image (prediction) and the identity image (target).
# Both tensors are in [-1, 1] in this notebook, so they are mapped to [0, 1]
# first; that is the range data_range=1.0 describes. id_image is expanded to
# the shape of gen_image so both operands have the same rank.
psnr_metric(
    ((gen_image + 1) / 2).clamp(0, 1),
    ((id_image + 1) / 2).clamp(0, 1).expand_as(gen_image),
    data_range=1.0,
).item()