In [1]:
from pathlib import Path
from model.network import Network

import cv2 as cv
import numpy as np

import torch
import torchvision.transforms.functional as TF
from torchvision.io import read_image
from torchvision.utils import make_grid


class Args:
    """Attribute container holding this experiment's configuration.

    An instance is handed to ``Network`` as its ``args`` argument; the flags
    below are read there (their exact semantics live in that class).
    """

    def __init__(self):
        # Experiment identity and run mode.
        self.name = "exp_07"
        self.resolution = 256
        self.load_checkpoint = True
        self.train = False

        # Input / output locations (relative to the notebook's working dir).
        self.dataset_path = Path("./dataset/gen_dataset/")
        self.results_dir = Path("./output")
        self.pretrained_models_path = Path("./pretrained")
        # Checkpoints for this experiment: <results_dir>/<name>/weights
        self.weights_dir = self.results_dir / (self.name + "/weights")

        # Data sizing.
        self.train_data_size = 50000
        self.batch_size = 6

        # Real-image / attribute switches (interpreted by Network).
        self.reals = False
        self.test_real_attr = True
        self.train_real_attr = False

args = Args()

# Run on the GPU whenever torch can see one.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# Pretrained checkpoints, all resolved under args.pretrained_models_path.
# NOTE(review): .pkl/.pth files are presumably deserialized by Network —
# only point these at trusted files.
id_model_path = str(args.pretrained_models_path / "resnet50_scratch_weight.pkl")
stylegan_G_path = str(args.pretrained_models_path / "ffhq.pkl")
landmarks_model_path = str(args.pretrained_models_path / '3DDFA/phase1_wpdc_vdc.pth.tar')

# NOTE: the keyword is spelled `load_chackpoint` — presumably matching
# Network's signature; rename it there before "fixing" it here.
network = Network(
    args=args,
    id_model_path=id_model_path,
    base_generator_path=stylegan_G_path,
    landmarks_detector_path=landmarks_model_path,
    device=DEVICE,
    load_chackpoint=args.load_checkpoint,
)

Attr_Encoder loads checkpoint from: output/exp_07/weights/Attr_Encoder.pth


AttributeError: 'ReferenceNetwork' object has no attribute '_load'

In [None]:
def read_image_chw(path: Path) -> torch.Tensor:
    """Read an image with OpenCV and return it as a float CHW tensor on DEVICE.

    Pixel values are left unscaled and channels stay in OpenCV's BGR order.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``.
    """
    image = cv.imread(str(path))
    # cv.imread reports failure by returning None rather than raising.
    if image is None:
        raise FileNotFoundError(f"could not read image: {path}")
    return torch.from_numpy(image.transpose((2, 0, 1))).float().to(DEVICE)


id_image = read_image_chw(args.dataset_path.joinpath("image/1000/3.png"))
attr_image = read_image_chw(args.dataset_path.joinpath("image/1000/29.png"))

# flip(-3) reverses the channel axis (BGR -> RGB) for display only.
TF.to_pil_image(make_grid([id_image.flip(-3).to(torch.uint8), attr_image.flip(-3).to(torch.uint8)]))

In [None]:
# Scale raw pixels by (x - 127.5) / 128 for the attribute encoder (constants from
# the original notebook). Done out of place so re-running this cell does not
# normalize attr_image a second time.
# NOTE(review): the previous version also normalized id_image in place but never
# used it here — the identity embedding is taken from gen_id_image instead.
attr_input = attr_image.sub(127.5).div(128)

CONTROL_VECTOR_DIM = 6048  # width of the control vector passed to synthesis()

with torch.no_grad():
    # Stored latent for the identity sample; assumes shape (batch, 512) — TODO confirm.
    z = torch.from_numpy(np.load(str(args.dataset_path.joinpath("z/1000/3.npy")))).to(DEVICE)
    batch_size = z.shape[0]  # keep every batch dimension consistent with z
    ws = network.generator.stylegan_generator.mapping(z, 0)

    # All-zero control vector: presumably the un-edited reconstruction — confirm.
    ctrlv = torch.zeros((batch_size, CONTROL_VECTOR_DIM), device=DEVICE)
    gen_id_image = network.generator.stylegan_generator.synthesis(ws, ctrlv)
    gen_id_image = TF.resize(gen_id_image, (args.resolution, args.resolution))

    id_embedding = network.generator.id_encoder(gen_id_image)
    # Repeat the single attribute image across the batch (a view, no copy).
    attr_embedding = network.generator.attr_encoder(attr_input.expand(batch_size, *attr_input.shape))

    feature_tag = torch.cat([id_embedding, attr_embedding], dim=-1)

    pose_sp_embedding = network.generator.reference_pose_encoder(feature_tag)
    pose_control_vector = network.generator.reference_pose_decoder(pose_sp_embedding)
    expression_sp_embedding = network.generator.reference_expression_encoder(feature_tag)
    expression_control_vector = network.generator.reference_expression_decoder(expression_sp_embedding)

    # Pose and expression contributions are combined additively.
    control_vector = pose_control_vector + expression_control_vector

    gen_image = network.generator.stylegan_generator.synthesis(ws, control_vector)
    gen_image = TF.resize(gen_image, (args.resolution, args.resolution))

# Map generator output from [-1, 1] to [0, 1] and show the first sample of each.
TF.to_pil_image(make_grid([((gen_id_image[0] + 1) / 2).clamp(0, 1), ((gen_image[0] + 1) / 2).clamp(0, 1)]))

In [None]:
from torchmetrics.image.fid import FrechetInceptionDistance


def to_fid_input(images: torch.Tensor, size: int = 299) -> torch.Tensor:
    """Convert generator output to the uint8 CPU batch FrechetInceptionDistance expects.

    ``images`` is treated as being in [-1, 1] (the same mapping this notebook
    uses for display); the result is uint8 in [0, 255], resized to size x size.
    """
    images_uint8 = ((images + 1) / 2 * 255).clamp(0, 255).to(torch.uint8).cpu()
    return TF.resize(images_uint8, (size, size))


fid = FrechetInceptionDistance(feature=2048)

# Previously only the fake batch was resized (the real batch's resize result was
# overwritten); both sides now go through the same preprocessing.
real_image = to_fid_input(gen_id_image)
fake_image = to_fid_input(gen_image)

# NOTE: a single small batch per side gives a very noisy, biased FID estimate —
# treat this value as a smoke test, not a benchmark number.
fid.update(real_image, real=True)
fid.update(fake_image, real=False)
fid.compute()

In [None]:
# Root-mean-square pixel difference between reconstruction and edited image,
# in the generator's native value scale.
rmse = torch.nn.functional.mse_loss(gen_id_image, gen_image).sqrt()
rmse.item()

In [None]:
from torchmetrics.functional import peak_signal_noise_ratio as psnr_metric

# The generator images are on a [-1, 1] scale (this notebook maps them with
# (x + 1) / 2 for display). Bring both into [0, 1] so that data_range=1.0 really
# is the signal range; passing [-1, 1] data with data_range=1.0 under-reports
# PSNR by 10*log10(4) ~= 6 dB.
gen_id_image_01 = ((gen_id_image + 1) / 2).clamp(0, 1)
gen_image_01 = ((gen_image + 1) / 2).clamp(0, 1)

psnr_metric(gen_image_01, gen_id_image_01, data_range=1.0).item()