In [1]:
import math
import os
import torch
from PIL import Image
import numpy as np
from model import StyledGenerator
import tqdm
import matplotlib.pyplot as plt
from IPython.display import Video
import glob
import cv2
from natsort import natsorted # pip install natsort

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
## Convert generator output to displayable images
def make_image(tensor):
    """Convert a generator output batch into uint8 images.

    Args:
        tensor: 4-D tensor, assumed to be laid out as (N, C, H, W) (the
            permute below moves dim 1 last), with values nominally in [-1, 1].
            The input tensor is left unmodified.

    Returns:
        numpy uint8 array of shape (N, H, W, C). Values are clamped to
        [-1, 1], mapped linearly to [0, 255] and truncated (not rounded).
    """
    return (
        tensor.detach()
        # Out-of-place clamp: detach() shares storage with `tensor`, so an
        # in-place clamp_() here would silently overwrite the caller's tensor.
        .clamp(min=-1, max=1)
        .add(1)
        .div_(2)  # safe in place: operates on the fresh tensor from add()
        .mul(255)
        .type(torch.uint8)
        .permute(0, 2, 3, 1)
        .to("cpu")
        .numpy()
    )


In [3]:
# Load the pretrained generator
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Output resolution of the checkpoint below; `step` = log2(RESOLUTION) - 2 is
# passed to netG as its `step` argument.
RESOLUTION = 1024

netG = StyledGenerator(512)
load_result = netG.load_state_dict(
    torch.load('./checkpoints/stylegan-1024px-new.model', map_location=device)["g_running"],
    strict=False,
)
# strict=False would otherwise hide key mismatches; surface them so a wrong or
# partial checkpoint does not quietly leave layers at their initial weights.
if load_result.missing_keys or load_result.unexpected_keys:
    print('load_state_dict: missing keys: {}; unexpected keys: {}'.format(
        load_result.missing_keys, load_result.unexpected_keys))
netG = netG.to(device).eval()

# math.log2 is exact for powers of two (math.log(x, 2) divides two float logs).
step = int(math.log2(RESOLUTION)) - 2

In [4]:
def generateMix(face1_name, face2_name, mode='z_mode', samples=30, noise_seed=0):
    """Render frames that linearly interpolate between two saved latent vectors.

    Frame ``i`` uses ``lamda * latent1 + (1 - lamda) * latent2`` with
    ``lamda = i / (samples - 1)``. The sequence therefore starts at face2
    (lamda = 0) and ends exactly at face1 (lamda = 1).

    Args:
        face1_name: Path of the first latent, without the ``.npy`` suffix.
        face2_name: Path of the second latent, without the ``.npy`` suffix.
        mode: ``'z_mode'`` feeds the mixed latent through the generator as a
            normal input. Any other value passes it as ``mean_style`` with
            ``style_weight=0`` (W-space mixing).
        samples: Number of frames to write. ``0`` writes nothing; ``1`` writes
            face2 only.
        noise_seed: Seed applied to torch's global RNG before every frame, so
            any random noise the generator draws is identical across frames.
            ``None`` disables reseeding (fresh RNG state per frame).

    Frames are saved as ``{i}.png`` in ``./output/interpolation_z`` or
    ``./output/interpolation_w``; the directory is created if missing.
    NOTE: frames with higher indices left over from an earlier, longer run are
    not removed.

    Relies on the notebook globals ``netG``, ``step``, ``device`` and
    ``make_image``.
    """
    out_dir = './output/interpolation_z' if mode == 'z_mode' else './output/interpolation_w'
    os.makedirs(out_dir, exist_ok=True)

    # Load both latents; .float() guards against float64 arrays on disk.
    latent_in_1 = torch.from_numpy(np.load(face1_name + '.npy')).float().to(device)
    latent_in_2 = torch.from_numpy(np.load(face2_name + '.npy')).float().to(device)

    # Give 1-D latents a leading batch dimension.
    if latent_in_1.dim() == 1:
        latent_in_1 = latent_in_1.unsqueeze(0)
    if latent_in_2.dim() == 1:
        latent_in_2 = latent_in_2.unsqueeze(0)

    # Divide by (samples - 1) so the final frame is exactly face1; the max()
    # keeps samples == 1 from dividing by zero.
    denom = max(samples - 1, 1)

    with torch.no_grad():  # inference only: skip autograd bookkeeping
        for i in tqdm.tqdm(range(samples)):
            lamda = i / denom
            new_latent = lamda * latent_in_1 + (1 - lamda) * latent_in_2

            if noise_seed is not None:
                torch.manual_seed(noise_seed)

            if mode == 'z_mode':
                img_gen = netG([new_latent], step=step)
            else:
                # NOTE(review): assumes a StyledGenerator where style_weight=0
                # makes the generator use mean_style (i.e. new_latent) as the
                # style directly, with latent_in_1 only supplying batch
                # shape/device — confirm against model.py.
                img_gen = netG([latent_in_1], mean_style=new_latent, step=step, style_weight=0)

            img_ar = make_image(img_gen)
            Image.fromarray(img_ar[0]).save(os.path.join(out_dir, '{}.png'.format(i)))

In [5]:
# Compose the interpolation frames into a video
def animatePics(mode='z_mode', img_size=(512, 512), fps=25):
    """Encode the PNG frames for ``mode`` into ``output.mp4`` and return it for display.

    Args:
        mode: ``'z_mode'`` reads ``./output/interpolation_z``; any other value
            reads ``./output/interpolation_w``.
        img_size: ``(width, height)`` every frame is resized to, in the order
            ``cv2.resize`` and ``cv2.VideoWriter`` expect.
        fps: Frame rate of the written video.

    Returns:
        IPython ``Video`` with the file embedded.

    Raises:
        FileNotFoundError: the frame directory contains no ``.png`` files.
        RuntimeError: OpenCV could not open a writer with any tried codec.
    """
    frame_dir = './output/interpolation_z' if mode == 'z_mode' else './output/interpolation_w'
    video_name = frame_dir + '/output.mp4'

    # Remove any previous video so it is rewritten from scratch.
    if os.path.exists(video_name):
        os.unlink(video_name)

    # Only PNG frames (other files would make cv2.imread return None);
    # natsorted keeps 2.png before 10.png.
    filelist = natsorted(glob.glob(os.path.join(frame_dir, '*.png')))
    if not filelist:
        raise FileNotFoundError('no .png frames found in {}'.format(frame_dir))

    # 'H264' is missing from many pip OpenCV builds, and VideoWriter then fails
    # silently, so check isOpened() and fall back. 'mp4v' output may not play
    # inline in every browser, but it at least produces a file.
    videoWriter = None
    for codec in ('H264', 'avc1', 'mp4v'):
        candidate = cv2.VideoWriter(video_name, cv2.VideoWriter_fourcc(*codec), fps, img_size)
        if candidate.isOpened():
            videoWriter = candidate
            break
        candidate.release()
    if videoWriter is None:
        raise RuntimeError('could not open a video writer for {}'.format(video_name))

    try:
        for path in tqdm.tqdm(filelist):
            img = cv2.imread(path)
            if img is None:  # unreadable file: cv2.imread returns None instead of raising
                continue
            videoWriter.write(cv2.resize(img, img_size))
    finally:
        # Always finalize the container, even if a frame fails mid-loop.
        videoWriter.release()

    return Video(video_name, embed=True)

In [6]:
# Z mode: mix the two Z latents (from fake2 towards fake3) over 50 frames
generateMix(
    './latents/z/fake3',
    './latents/z/fake2',
    mode='z_mode',
    samples=50,
)

100%|██████████████████████████████████████████████████████████████████████████████████| 50/50 [01:07<00:00,  1.34s/it]


In [8]:
# Encode the Z-mode frames as a 500x500 video and display it
animatePics(mode='z_mode', img_size=(500, 500))

In [None]:
# W mode: mix the two W latents (from fake2 towards fake3) over 50 frames
generateMix(
    './latents/w/fake3',
    './latents/w/fake2',
    mode='w_mode',
    samples=50,
)

In [None]:
# Encode the W-mode frames as a 500x500 video and display it
animatePics(mode='w_mode', img_size=(500, 500))

In [None]:
# Z-to-W mode: mix the W latents derived from the Z latents, in W space.
# NOTE: this writes into ./output/interpolation_w, overwriting the frames of
# the W-mode cell above (and the next cell overwrites its video).
generateMix(
    './latents/z/related_w/fake3',
    './latents/z/related_w/fake2',
    mode='w_mode',
    samples=50,
)

In [None]:
# Encode the Z-to-W frames (stored in the W output folder) as a 500x500 video
animatePics(mode='w_mode', img_size=(500, 500))