In [15]:
import torch, dnnlib
from general_utils import legacy
import torchvision.transforms.functional as TF
import numpy as np

from model import networks_stylegan2

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pretrained checkpoint and keep its "G_ema" entry on `device`.
# NOTE(review): legacy.load_network_pkl presumably unpickles the file — only
# point it at checkpoints from a trusted source.
with dnnlib.util.open_url("./pretrained/stylegan2-ffhq-256x256.pkl") as f:
    data = legacy.load_network_pkl(f)
    G = data["G_ema"].to(device)

# Rebuild the generator from the project's own class using the checkpoint's
# constructor kwargs, copy the weights across, and switch to eval mode.
stylegan_generator = networks_stylegan2.Generator(**G.init_kwargs).to(device)
stylegan_generator.load_state_dict(G.state_dict())
stylegan_generator.eval()

# All-zero label, a latent drawn from a fixed seed, and an all-zero control vector.
label = torch.zeros([1, G.c_dim], device=device)
z = torch.from_numpy(np.random.RandomState(0).randn(1, G.z_dim)).to(device)
# NOTE(review): 7424 is presumably the control-vector width expected by the
# modified generator — confirm against networks_stylegan2.
ctrlv = torch.zeros([1, 7424], device=device)

# The result is converted to uint8 right away, so no gradient is ever needed:
# skip building the autograd graph for this forward pass.
with torch.no_grad():
    img = stylegan_generator(z, label, ctrlv)
# Map the generator output to 0..255; clamp first so the uint8 cast cannot wrap.
img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8)


In [None]:
import os
from pathlib import Path
import random
import torch
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data.dataloader import DataLoader
from data_loader.gen_data import GeneratedDataset

import sys
import logging
from model.network import Network
from model.generator import Generator
import torchvision.transforms.functional as F

from writer import Writer
from trainer import Trainer
from general_utils import arglib
import matplotlib.pyplot as plt


class Args(object):
    """Plain attribute bag holding the experiment settings used by this notebook."""

    def __init__(self):
        # Fill the instance namespace in one call; keyword order is preserved,
        # so attributes are created in the order listed here.
        vars(self).update(
            name="exp01",
            resolution=256,
            load_checkpoint=False,
            train=True,
            dataset_path=Path("./dataset"),
            results_dir=Path("./output/"),
            train_data_size=50000,
            batch_size=6,
            reals=False,
            test_real_attr=True,
            train_real_attr=False,
            weights_dir=Path("./output/exp01/weights"),
        )


# Settings instance shared by the cells that follow.
args = Args()

def get_id_attr_sampler(dataset_length, cross_frequency):
    """Randomly partition dataset indices into an identity and an attribute sampler.

    The indices ``0 .. dataset_length - 1`` are shuffled with the global
    ``random`` module. The first ``dataset_length // (cross_frequency + 1)``
    shuffled indices go to the attribute sampler; all remaining ones go to
    the identity sampler.

    Returns:
        ``(id_sampler, attr_sampler)``, both ``SubsetRandomSampler`` instances.
    """
    attr_count = dataset_length // (cross_frequency + 1)

    shuffled = [*range(dataset_length)]
    random.shuffle(shuffled)

    id_part, attr_part = shuffled[attr_count:], shuffled[:attr_count]
    return SubsetRandomSampler(id_part), SubsetRandomSampler(attr_part)

# Device shared by every model and tensor created in this cell.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Writer.set_writer(args.results_dir)

# Locations of the pretrained weights handed to Network.
id_model_path = "./pretrained/resnet50_scratch_weight.pkl"
stylegan_G_path = "./pretrained/ffhq.pkl"
landmarks_model_path = "./pretrained/3DDFA/phase1_wpdc_vdc.pth.tar"

network = Network(args=args, id_model_path=id_model_path, base_generator_path=stylegan_G_path,
                  landmark_model_path=landmarks_model_path, device=DEVICE)
# Dataset
train_dataset = GeneratedDataset(args, "train")
train_dataset_length = len(train_dataset)
train_id_sampler, train_attr_sampler = get_id_attr_sampler(train_dataset_length, 3)
# Batch size comes from args so that editing the config actually takes effect
# (it was previously a hard-coded 6 that merely happened to match).
train_id_loader = DataLoader(train_dataset, batch_size=args.batch_size, sampler=train_id_sampler)
train_attr_loader = DataLoader(train_dataset, batch_size=args.batch_size, sampler=train_attr_sampler)

id_images, id_zs = next(iter(train_id_loader))
# NOTE(review): the attribute batch is the identity batch itself and
# train_attr_loader is never drawn from in this cell — confirm this is intended.
attr_images = id_images

id_images = id_images.to(DEVICE)
id_zs = id_zs.to(DEVICE)
attr_images = attr_images.to(DEVICE)

# Identity embedding and attribute embedding
id_embedding = network.generator.id_encoder(id_images)
attr_embedding = network.generator.attr_encoder(attr_images)

# Attribute landmarks
attr_landmarks, attr_idx_list = network.generator.landmarks_detector(attr_images)

# Style+ embedding and control vector generation
feature_tag = torch.concat([id_embedding, attr_embedding], -1)
sp_embedding = network.generator.reference_encoder(feature_tag)
control_vector = network.generator.reference_decoder(sp_embedding)

gen_images = network.generator.stylegan_generator(id_zs, control_vector)

# Resize to the configured resolution instead of a hard-coded 256.
gen_images = F.resize(gen_images, (args.resolution, args.resolution))

# Convert RGB to BGR to make the generated image compatible with the landmark detector
gen_images = gen_images.flip(-3)

pred_landmarks, gen_idx_list = network.generator.landmarks_detector(gen_images)

In [None]:
# Signed difference between the landmarks predicted on the generated images
# and the landmarks detected on the attribute images.
# NOTE(review): assumes both landmark results have matching shapes — the
# detector also returns index lists, which may differ between the two calls.
loss = pred_landmarks - attr_landmarks

In [None]:
# Reduce the difference to a scalar and backpropagate through it.
loss.mean().backward()

In [None]:
# Show the .grad of the last parameter of the reference decoder, to check
# whether the backward pass reached it.
list(network.generator.reference_decoder.parameters())[-1].grad

In [None]:
# Reconcile the two batches when the landmark detector returned index lists
# of different lengths for the generated and the attribute images.
# NOTE(review): narrow(0, e, 1) SELECTS sample e (and squeeze() then drops the
# batch axis) rather than removing it; if the intent is to discard unmatched
# samples this needs a different implementation — confirm with the author.
if len(gen_idx_list) > len(attr_idx_list):
    for e in gen_idx_list:
        if e not in attr_idx_list:
            gen_images = gen_images.narrow(0, e, 1).squeeze()
else:
    for e in attr_idx_list:
        if e not in gen_idx_list:
            # narrow(dim, start, length): slice along the batch axis (dim 0)
            # starting at index e, mirroring the branch above. The previous
            # narrow(e, 1, 1) used the batch index as the dimension.
            attr_images = attr_images.narrow(0, e, 1).squeeze()

In [None]:
from general_utils import general_utils
# Write the identity, attribute and generated images together with both
# landmark sets to "./01.png" via the project helper.
# NOTE(review): 256 and 3 are presumably the image size and a layout/column
# count — confirm against general_utils.save_images.
general_utils.save_images(id_images, attr_images, gen_images, attr_landmarks, pred_landmarks, 256, "./01.png", 3)

In [None]:
import matplotlib.pyplot as plt
# Show one attribute image with landmark points drawn on top of it.
img_ori = attr_images[1].detach().cpu().permute(1,2,0)
# Clamp BEFORE casting: converting out-of-range floats to uint8 wraps around,
# after which clamp(0, 255) on a uint8 tensor can no longer repair anything.
img_ori = (img_ori * 127.5 + 128).clamp(0, 255).to(torch.uint8).numpy()
# NOTE(review): the image is sample 1 but the landmarks are entry 0 — confirm
# the two are meant to belong to the same face.
pts68 = attr_landmarks[0].cpu().detach()
height, width = img_ori.shape[:2]
# Reverse the last (channel) axis for display.
plt.imshow(img_ori[:,:,::-1])
#plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
plt.axis('off')

style = "simple"

# Always bind `pts`: wrap a single landmark array in a list, pass an existing
# list/tuple through unchanged (previously `pts` stayed undefined in that case).
pts = pts68 if isinstance(pts68, (tuple, list)) else [pts68]
for i in range(len(pts)):
    if style == 'simple':
        # Row 0 is used as x and row 1 as y for every point set.
        plt.plot(pts[i][0, :], pts[i][1, :], 'o', markersize=2, color='r')
plt.show()

In [None]:
import matplotlib.pyplot as plt
# Show one generated image with its predicted landmark points drawn on top.
img_ori = gen_images[1].detach().cpu().permute(1,2,0)
# Clamp BEFORE casting: converting out-of-range floats to uint8 wraps around,
# after which clamp(0, 255) on a uint8 tensor can no longer repair anything.
img_ori = (img_ori * 127.5 + 128).clamp(0, 255).to(torch.uint8).numpy()
pts68 = pred_landmarks[1].cpu().detach()
height, width = img_ori.shape[:2]
# Reverse the last (channel) axis for display.
plt.imshow(img_ori[:,:,::-1])
#plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
plt.axis('off')

style = "simple"

# Always bind `pts`: wrap a single landmark array in a list, pass an existing
# list/tuple through unchanged (previously `pts` stayed undefined in that case).
pts = pts68 if isinstance(pts68, (tuple, list)) else [pts68]
for i in range(len(pts)):
    if style == 'simple':
        # Row 0 is used as x and row 1 as y for every point set.
        plt.plot(pts[i][0, :], pts[i][1, :], 'o', markersize=2, color='r')
plt.show()

In [None]:
import cv2
import numpy as np
import kornia as K
import torchvision.transforms.functional as F
from kornia.contrib import FaceDetector, FaceDetectorResult, FaceKeypoint


def scale_image(img: np.ndarray, size: int) -> np.ndarray:
    """Resize ``img`` by the factor ``size / width``, keeping the aspect ratio.

    The scaled width and height are each truncated to an integer and passed
    to ``cv2.resize`` as ``(width, height)``.
    """
    src_h, src_w = img.shape[:2]
    ratio = float(size) / src_w
    target_size = (int(src_w * ratio), int(src_h * ratio))
    return cv2.resize(img, target_size)


# cv2.imread signals failure by returning None rather than raising, so check
# the result explicitly instead of crashing later inside scale_image.
image_path = "./dataset/gen_dataset/data/1000/0.png"
img_raw = cv2.imread(image_path, cv2.IMREAD_COLOR)
if img_raw is None:
    raise FileNotFoundError(f"cv2.imread could not read {image_path!r}")
img_raw = scale_image(img_raw, 320)
# Separate copy that the drawing cell below annotates.
img_vis = img_raw.copy()

# Fall back to the CPU when CUDA is unavailable, like the other cells of this
# notebook, instead of unconditionally calling .cuda().
kornia_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Convert the OpenCV image to a batched tensor and reorder BGR -> RGB.
img = K.image_to_tensor(img_raw, keepdim=False).to(kornia_device)
img = K.color.bgr_to_rgb(img.float())
face_detector = FaceDetector().to(kornia_device)

# Run the detector and wrap every detection of the first image in a result object.
dets = face_detector(img)
dets = [FaceDetectorResult(o) for o in dets[0]]


def draw_keypoint(img: np.ndarray, det: FaceDetectorResult, kpt_type: FaceKeypoint) -> np.ndarray:
    """Draw the requested keypoint of ``det`` onto ``img`` and return the result of ``cv2.circle``.

    The keypoint coordinates are converted to integers; the marker is a
    circle of radius 2, thickness 2 and colour ``(255, 0, 0)``.
    """
    coords = det.get_keypoint(kpt_type).int().tolist()
    centre = tuple(coords)
    return cv2.circle(img, centre, 2, (255, 0, 0), 2)


In [None]:
# Keypoints marked on every accepted detection, in drawing order.
keypoint_kinds = (
    FaceKeypoint.EYE_LEFT,
    FaceKeypoint.EYE_RIGHT,
    FaceKeypoint.NOSE,
    FaceKeypoint.MOUTH_LEFT,
    FaceKeypoint.MOUTH_RIGHT,
)

for b in dets:
    # Ignore detections scoring below 0.8.
    if b.score < 0.8:
        continue

    # Integer box corners, computed once and reused for logging and drawing.
    top_left = b.top_left.int().tolist()
    bottom_right = b.bottom_right.int().tolist()
    print(top_left, bottom_right)

    # Bounding box.
    img_vis = cv2.rectangle(img_vis, tuple(top_left), tuple(bottom_right), (0, 255, 0), 4)

    # Facial keypoints.
    for kind in keypoint_kinds:
        img_vis = draw_keypoint(img_vis, b, kind)

    # draw the text score, 12 px below the top edge of the box
    cx, cy = int(b.xmin), int(b.ymin + 12)
    img_vis = cv2.putText(
        img_vis, f"{b.score:.2f}", (cx, cy), cv2.FONT_HERSHEY_DUPLEX, 0.5, (255, 255, 255))

# Persist the annotated image.
cv2.imwrite("./test.png", img_vis)


In [None]:
# Print, for every attribute-encoder parameter, whether it is trainable.
# The encoder is reached through `network.generator`, where the notebook
# itself uses it; the only `G` bound in this notebook is the StyleGAN G_ema
# checkpoint, on which no attr_encoder is shown.
for p in network.generator.attr_encoder.parameters():
    print(p.requires_grad)


In [None]:
# Feed a random batch through the attribute encoder and report whether the
# output is attached to the autograd graph. The encoder is reached through
# `network.generator` (the only `G` in this notebook is the StyleGAN G_ema
# checkpoint) and the input is created on DEVICE, matching how the encoder is
# fed elsewhere in the notebook.
network.generator.attr_encoder(torch.randn((6, 3, 256, 256), requires_grad=True, device=DEVICE)).requires_grad


In [None]:
from data_loader.gen_data import FaceLandmarksDataset, Transforms
from torch.utils.data import DataLoader

# Landmark dataset for the "train" split, read in shuffled batches of six.
# Note: this rebinds `train_dataset`, which an earlier cell bound to a GeneratedDataset.
landmark_transforms = Transforms()
train_dataset = FaceLandmarksDataset("train", transform=landmark_transforms)
train_loader = DataLoader(dataset=train_dataset, batch_size=6, shuffle=True)


In [None]:
# Pull a single batch from the loader; each batch unpacks into four items.
ref_images, ref_landmarks, ref_zs, ref_bboxs = next(iter(train_loader))


In [None]:
# Inspect the shape of the image batch.
ref_images.shape


In [None]:
from model.id_encoder import ID_Encoder
import torch
from pathlib import Path


class Args(object):
    """Reduced settings bag used to build the ID encoder.

    Defining this class rebinds the name ``Args`` used earlier in the notebook.
    """

    def __init__(self):
        # Fill the instance namespace in one call; keyword order is preserved,
        # so attributes are created in the order listed here.
        vars(self).update(
            resolution=256,
            load_checkpoint=False,
            train=True,
            dataset_path=Path("./dataset"),
            train_data_size=50000,
            batch_size=6,
            reals=False,
            test_real_attr=True,
            train_real_attr=False,
        )


# Rebind `args` to the reduced settings and build the ID encoder from the
# pretrained weight file.
args = Args()
id_encoder = ID_Encoder(args, "./pretrained/resnet50_scratch_weight.pkl")


In [None]:
# Run a random batch of six 3x256x256 inputs through the encoder and show the output shape.
id_encoder(torch.randn((6, 3, 256, 256))).shape


In [None]:
# NOTE(review): calls a private helper of ID_Encoder; presumably it switches
# the encoder into its training configuration — confirm in model/id_encoder.
id_encoder._train()


In [None]:
# List every parameter of the encoder's base model together with its
# requires_grad flag.
for name, param in id_encoder.base_model.named_parameters():
    print(name, param.requires_grad)
