<a href="https://colab.research.google.com/github/cedro3/StyleCLIP/blob/main/StyleCLIP_global.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# セットアップ

In [None]:
# Select TensorFlow 1.x for this runtime (Colab line magic).
%tensorflow_version 1.x
# Install pinned PyTorch / torchvision / torchaudio builds from the PyTorch wheel index.
! pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
# Extra Python packages (tqdm is used later for progress bars).
! pip install ftfy regex tqdm
# Install CLIP directly from its GitHub repository.
! pip install git+https://github.com/openai/CLIP.git
# Fetch the StyleCLIP sources; the next cell changes into StyleCLIP/global/.
! git clone https://github.com/cedro3/StyleCLIP.git

In [None]:
# input dataset name 
dataset_name='ffhq' # input dataset name, currently, only support ffhq

% cd StyleCLIP/global/

# input prepare data 
!python GetCode.py --dataset_name $dataset_name --code_type 'w'
!python GetCode.py --dataset_name $dataset_name --code_type 's'
!python GetCode.py --dataset_name $dataset_name --code_type 's_mean_std'

import tensorflow as tf
import numpy as np 
import torch
import clip
from PIL import Image
import pickle
import copy
import matplotlib.pyplot as plt
from MapTS import GetFs,GetBoundary,GetDt
from manipulate import Manipulator
# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the CLIP ViT-B/32 model together with its image preprocessing transform.
model, preprocess = clip.load("ViT-B/32", device=device)

# Build the manipulator and load the fs3 array for the dataset chosen above,
# instead of repeating the 'ffhq' literal, so the two stay in sync with
# `dataset_name`.
M = Manipulator(dataset_name=dataset_name)
fs3 = np.load(f'./npy/{dataset_name}/fs3.npy')
# Print small floats in fixed-point notation rather than scientific notation.
np.set_printoptions(suppress=True)

# 画像編集

In [None]:
# --- Select image ---

pt_folder = 'vec/'
pt_name = '009.pt' #@param {type:"string"}
# map_location='cpu' lets a latent file that was saved from a GPU load on a
# CPU-only runtime as well; the tensor is converted to a NumPy array right
# below, so it ends up on the CPU either way.
# NOTE(review): torch.load unpickles the file -- only load .pt files you trust.
latents = torch.load(pt_folder + pt_name, map_location='cpu')
w_plus = latents.cpu().detach().numpy()
M.dlatents = M.W2S(w_plus)

# Keep only the selected image(s) from every entry of M.dlatents.
img_indexs = [0]
dlatent_tmp = [layer_codes[img_indexs] for layer_codes in M.dlatents]
M.num_images = len(img_indexs)

# Render the reference picture with alpha set to [0]
# (presumably an edit strength of 0 leaves the image unedited -- confirm in manipulate.py).
M.alpha = [0]
M.manipulate_layers = [0]
try:
    codes, out = M.EditOneC(0, dlatent_tmp)
finally:
    # Always restore the layer setting so a failure here cannot leave the
    # manipulator restricted to layer 0 for the following cells.
    M.manipulate_layers = None
original = Image.fromarray(out[0,0]).resize((512,512))
original

In [None]:
# --- Text input ---

# Two prompts: a neutral description and the target description to edit towards.
neutral = 'face with hair' #@param {type:"string"}
target = 'smiling face with curly hair' #@param {type:"string"}

# GetDt receives the prompts as a list with the target first and the neutral second.
classnames = [target, neutral]
dt = GetDt(classnames, model)

In [None]:
# --- Set alpha & beta ---

beta = 0.1 #@param {type:"slider", min:0.08, max:0.3, step:0.01}
alpha = 2.5 #@param {type:"slider", min:-10, max:10, step:0.1}

# alpha is stored on the manipulator; beta is passed to GetBoundary as its threshold.
M.alpha = [alpha]
boundary_tmp2, c = GetBoundary(fs3, dt, M, threshold=beta)
codes = M.MSCode(dlatent_tmp, boundary_tmp2)
out = M.GenerateImg(codes)
generated = Image.fromarray(out[0,0])

# Show the original and the manipulated image side by side.
# Drawing both panels in a loop removes the duplicated subplot code and, since
# the cell no longer ends on a bare expression, the axis-limits tuple returned
# by plt.axis() is not echoed as cell output.
plt.figure(figsize=(14,7), dpi=100)
panels = [(original, 'original'), (generated, 'manipulated')]
for position, (picture, caption) in enumerate(panels, start=1):
    plt.subplot(1, 2, position)
    plt.imshow(picture)
    plt.title(caption)
    plt.axis('off')

In [None]:
# --- Generate a sequence of edited images ---

max_alpha = 2.5 #@param {type:"slider", min:0, max:10, step:0.1}
# beta was a fixed 0.1 here, which silently replaced whatever value had been
# tuned with the slider in the preview cell. It is now a slider with the same
# default and range, so the sequence can use the same threshold as the preview.
beta = 0.1 #@param {type:"slider", min:0.08, max:0.3, step:0.01}
# Number of ramp frames: alpha is increased in steps of 0.1 up to max_alpha.
num = int(max_alpha*10)

from tqdm import trange
import os
import shutil

# Reset the pic folder so frames from a previous run do not leak into the video.
if os.path.isdir('pic'):
    shutil.rmtree('pic')
os.makedirs('pic', exist_ok=True)

# 画像生成関数
def generate_img(alpha, cnt):
    """Render one edited frame and save it next to the original picture.

    Relies on notebook globals defined in earlier cells: ``M``, ``fs3``,
    ``dt``, ``beta``, ``dlatent_tmp`` and ``original``.

    Args:
        alpha: value stored in ``M.alpha`` (as a one-element list) before
            the image is generated.
        cnt: frame counter; used, zero-padded to 6 digits, as the file name
            of the PNG written into ``./pic/``.

    Returns:
        None. The result is written to disk.
    """
    M.alpha = [alpha]
    direction, _channel_info = GetBoundary(fs3, dt, M, threshold=beta)
    rendered = M.GenerateImg(M.MSCode(dlatent_tmp, direction))
    frame = Image.fromarray(rendered[0, 0]).resize((512, 512))

    # Side-by-side canvas: original on the left, edited frame on the right.
    canvas_size = (original.width + frame.width, original.height)
    canvas = Image.new('RGB', canvas_size)
    for tile, x_offset in ((original, 0), (frame, original.width)):
        canvas.paste(tile, (x_offset, 0))

    canvas.save(f"./pic/{str(cnt).zfill(6)}.png")

# Generate the frames while varying alpha.
# Each phase is (arguments for trange, progress-bar label, step -> alpha):
#   1. 15 frames held at alpha = 0
#   2. a ramp from 0 upwards in steps of 0.1 (num frames)
#   3. 60 frames held at the final alpha (num/10)
phases = (
    ((15,), 'alpha = 0', lambda step: 0),
    ((0, num, 1), 'alpha = 0 -> max', lambda step: step/10),
    ((60,), 'alpha = max', lambda step: num/10),
)

cnt = 0
for range_args, label, alpha_for in phases:
    # The progress bar is created only when its phase starts, so the bars
    # appear one after another.
    for i in trange(*range_args, desc=label):
        generate_img(alpha_for(i), cnt)
        cnt += 1
     

In [None]:
# --- mp4動画の作成 ---

# 既に output.mp4 があれば削除する
import os
if os.path.exists('./output.mp4'):
   os.remove('./output.mp4')

# pic フォルダの画像から動画を生成
! ffmpeg -r 30 -i pic/%6d.png\
               -vcodec libx264 -pix_fmt yuv420p output.mp4

# movieフォルダへ名前を付けてコピー
import shutil
os.makedirs('movie', exist_ok=True)
shutil.copy('output.mp4', 'movie/'+target+'_'+pt_name[:-3]+'.mp4')

In [None]:
# --- Play the mp4 video ---

from IPython.display import HTML
from base64 import b64encode

# Read the video through a context manager so the file handle is closed
# right away instead of being left open until garbage collection.
with open('./output.mp4', 'rb') as video_file:
    mp4 = video_file.read()

# Embed the video in the notebook output as a base64 data URL.
data_url = 'data:video/mp4;base64,' + b64encode(mp4).decode()
HTML(f"""
<video width="70%" height="70%" controls>
      <source src="{data_url}" type="video/mp4">
</video>""")