In [None]:
pip install git+https://github.com/openai/CLIP.git

In [None]:
!wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip
!sudo unzip ninja-linux.zip -d /usr/local/bin/
!sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force

In [None]:
# Fetch the StyleGAN2-ADA (PyTorch) repository and make it the working
# directory. Every relative path used from here on (./model, results/,
# out.mp4, output.zip) is therefore resolved inside the clone.
!git clone https://github.com/NVlabs/stylegan2-ada-pytorch.git
%cd stylegan2-ada-pytorch

In [None]:
import torch
import pandas as pd
import numpy as np
import pickle
import os
import zipfile
import clip

import matplotlib.image as mpimg
import matplotlib.pyplot as plt
from PIL import ImageEnhance, Image
import torchvision.utils as vutils
from torchvision.transforms import ToPILImage

In [None]:
dataset_name='afhqdog' #@param ['ffhq'] {allow-input: true}
# Name of the pretrained network; the checkpoint is fetched as
# <dataset_name>.pkl and cached under ./model/.
# exist_ok=True keeps this cell re-runnable (a bare makedirs raises
# FileExistsError the second time it is executed).
os.makedirs('./model', exist_ok=True)

# Earlier download logic for the 'stylegan2-<name>-config-f.pkl' URL layout,
# kept for reference only:
# if not os.path.isfile('./model/'+dataset_name+'.pkl'):
#         url='https://nvlabs-fi-cdn.nvidia.com/stylegan2/networks/'
#         name='stylegan2-'+dataset_name+'-config-f.pkl'
#         os.system('wget ' +url+name + '  -P  ./model/')
#         os.system('mv ./model/'+name+' ./model/'+dataset_name+'.pkl')
model_path = './model/' + dataset_name + '.pkl'
if not os.path.isfile(model_path):
    url='https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/'
    name=dataset_name+'.pkl'
    # `wget -P ./model/` already stores the file as ./model/<dataset_name>.pkl,
    # so no rename step is needed afterwards.
    status = os.system('wget ' +url+name + '  -P  ./model/ > /dev/null 2>&1')
    if status != 0 or not os.path.isfile(model_path):
        # Drop a partial download so the next run retries instead of trying
        # to unpickle a truncated file.
        if os.path.isfile(model_path):
            os.remove(model_path)
        raise RuntimeError('Could not download ' + url + name + ' (wget exit status ' + str(status) + ')')


In [None]:
# Load the pretrained generator selected by `dataset_name` from ./model/
# (the path was previously hard-coded to afhqdog.pkl, which broke as soon as
# another dataset name was chosen).
# SECURITY: pickle.load can execute arbitrary code contained in the file;
# only load checkpoints that come from a source you trust.
with open('./model/' + dataset_name + '.pkl', 'rb') as f:
    G = pickle.load(f)['G_ema']  # torch.nn.Module

In [None]:
# you can also download other pretrained stylegan models from nvlabs!
#!wget https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl -O metfaces.pkl

In [None]:
# Move the generator to the GPU. `device` is reused by later cells for the
# latent code and the tokenized text, so they all live on the same device.
device = 'cuda'
G = G.to(device)

In [None]:
# Optional: uncomment to delete everything under ~/.cache/torch_extensions.
#rm -rf ~/.cache/torch_extensions/*

In [None]:
# Sample one image from the StyleGAN generator and keep its latent code `w`,
# which is the starting point of the text-guided edit further down.
seed = 9
torch.manual_seed(seed)  # fixed seed, so the same z is drawn on every run
z = torch.randn(1, G.z_dim).to(device=device, dtype=torch.float32)  # latent codes
c = None  # class labels (not used in this example)
with torch.no_grad():
    # z -> w via the mapping network (truncation 0.7), then w -> image.
    w = G.mapping(z, c, truncation_psi=0.7)
    img = G.synthesis(w)
    #img = G(z,c)

In [None]:
# show images
def show_tensor_images(image_tensor, num_images = 16, size=(3, 64, 64), nrows = 4):
    """Display a batch of image tensors as a single matplotlib grid.

    Args:
        image_tensor: batch of images; values are shifted with (x + 1) / 2 and
            then clamped to [0, 1] before plotting.
        num_images: only the first ``num_images`` entries of the batch are shown.
        size: accepted for API compatibility; not read by this function.
        nrows: forwarded to ``make_grid`` as ``nrow`` (images per grid row).
    """
    rescaled = (image_tensor + 1) / 2
    # Detach and copy to host memory so matplotlib can consume the data.
    host_images = rescaled.detach().cpu().clamp_(0, 1)
    selected = host_images[:num_images]
    grid = vutils.make_grid(selected, nrow=nrows, padding=0)
    # Channel-first -> channel-last, the layout imshow expects.
    channels_last = grid.permute(1, 2, 0).squeeze()
    plt.imshow(channels_last)
    plt.axis('off')
    plt.show()

In [None]:
# Preview the freshly sampled image (the one the edit will start from).
show_tensor_images(img, num_images=1,size=(3,512,512))

In [None]:
# Sanity check: display the shape of the generated image tensor.
img.shape

In [None]:
# clip loss (calculate the similarity between generated images and the target text.)
class clip_loss(torch.nn.Module):
    """CLIP-based loss: low when the image matches the text prompt.

    Args:
        device: device the CLIP model is loaded on (default ``"cuda"``).
        stylegan_size: side length in pixels of the generator output
            (default 512). The image is upsampled x7 and average-pooled with a
            kernel of ``stylegan_size // 32``, giving 7 * 32 = 224 pixels per
            side for any size that is a multiple of 32. The default of 512
            reproduces the previous fixed kernel of 16.
    """

    def __init__(self, device="cuda", stylegan_size=512):
        super().__init__()
        self.model, self.preprocess = clip.load("ViT-B/32", device=device)
        # CLIP is used as a fixed critic: only the latent code is optimised,
        # so freezing its weights avoids accumulating gradients nobody reads.
        self.model.requires_grad_(False)
        self.upsample = torch.nn.Upsample(scale_factor=7)
        self.avg_pool = torch.nn.AvgPool2d(kernel_size=stylegan_size // 32)

    def forward(self, image, text):
        """Return ``1 - logits / 100`` for ``image`` against tokenized ``text``.

        ``self.model(image, text)[0]`` is the first of the two logit tensors
        returned by CLIP. NOTE(review): the division by 100 presumably undoes
        CLIP's logit scale so the value is roughly a cosine similarity - confirm.
        """
        image = self.avg_pool(self.upsample(image))
        similarity = 1 - self.model(image, text)[0]/100
        return similarity

In [None]:
# Instantiate the loss module; this loads the CLIP "ViT-B/32" model.
clipLoss = clip_loss()

In [None]:
# Describe the edit you want in plain English; the optimisation below pushes
# the generated image towards this prompt.
text= 'A really happy dog face with mouth open' # [play with me] e.g. "a really sad face"; "a dog with blue eyes"
tokenized_text = clip.tokenize([text]).to(device).long()

lr_rampup = 0.05  # learning-rate warm-up fraction (same value as get_lr's default `rampup`)
LR = 0.1  # peak learning rate handed to Adam / get_lr
epoch = 150  # number of optimisation steps
l2_lambda = 0.0025  # weight of the L2 penalty that keeps w_star close to the original w
save_intermediate_image_every = 1  # show/save an image every N steps (0 disables)
result_dir = 'results'  # directory that receives the saved images

In [None]:
import os
import math
import torchvision
from torch import optim


# The learning rate adjustment function.
def get_lr(t, initial_lr, rampdown=0.50, rampup=0.05):
    """Return the learning rate for progress ``t`` of the optimisation.

    Args:
        t: progress of the run (the caller passes ``(i + 1) / epoch``).
        initial_lr: peak learning rate.
        rampdown: the cosine decay covers the part of the run where
            ``(1 - t) / rampdown`` is below 1.
        rampup: the linear warm-up covers the part of the run where
            ``t / rampup`` is below 1.

    Returns:
        ``initial_lr`` scaled by the product of the cosine decay factor and
        the linear warm-up factor.
    """
    decay_phase = min(1, (1 - t) / rampdown)
    cosine_factor = 0.5 - 0.5 * math.cos(decay_phase * math.pi)
    warmup_factor = min(1, t / rampup)
    schedule = cosine_factor * warmup_factor
    return initial_lr * schedule



text_inputs = tokenized_text
os.makedirs(result_dir, exist_ok=True)

# Initialize the latent vector to be updated: a detached copy of the original
# latent `w`, so that `w` itself stays fixed as the reference for the L2 term.
w_star = w.detach().clone()
w_star.requires_grad = True

# `clipLoss` was already instantiated in its own cell above; it is reused here
# instead of loading the CLIP weights a second time.
optimizer = torch.optim.Adam([w_star], LR)

for i in range(epoch):
    # Adjust the learning rate (warm-up fraction taken from the config cell).
    t = (i+1) / epoch
    lr = get_lr(t, LR, rampup=lr_rampup)
    optimizer.param_groups[0]["lr"] = lr

    # Generate an image using the latent vector.
    img_gen = G.synthesis(w_star)

    # Calculate the loss value: CLIP text/image mismatch plus an L2 penalty
    # that keeps the edited latent close to the original one.
    c_loss = clipLoss(img_gen, text_inputs)
    l2_loss = ((w - w_star) ** 2).sum()
    loss = c_loss + l2_lambda * l2_loss

    # Get gradient and update the latent vector.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Log the current state.
    print(f"lr: {lr}, loss: {loss.item():.4f}")
    if save_intermediate_image_every > 0 and i % save_intermediate_image_every == 0:
        # Re-render with the updated latent so the saved frame reflects this step.
        with torch.no_grad():
            img_gen = G.synthesis(w_star)
        show_tensor_images(img_gen, num_images=1,size=(3,512,512))
        torchvision.utils.save_image(img_gen, os.path.join(result_dir, f"{str(i).zfill(5)}.png"), normalize=True)

with torch.no_grad():
    # Always render the final latent. Otherwise `img_gen` is the image from
    # before the last update whenever the last step was not saved, and is
    # undefined when `epoch` is 0.
    img_gen = G.synthesis(w_star)
    img_orig = G.synthesis(w, force_fp32=True)

# Save the initial image and the result image side by side.
final_result = torch.cat([img_orig, img_gen])
torchvision.utils.save_image(final_result.detach().cpu(), os.path.join(result_dir, "final_result.jpg"), normalize=True, scale_each=True)


In [None]:
# Show the edited image produced by the optimisation.
show_tensor_images(img_gen, num_images=1,size=(3,512,512))

In [None]:
# Show the image rendered from the original latent `w`, for comparison.
show_tensor_images(img_orig, num_images=1,size=(3,512,512))

In [None]:
# generate a video
!ffmpeg -r 15 -i results/%05d.png -c:v libx264 -vf fps=25 -pix_fmt yuv420p out.mp4

In [None]:
# zip the output
import datetime
def file2zip(packagePath, zipPath):
    """Write every file below ``packagePath`` into a new zip at ``zipPath``.

    Args:
        packagePath: directory whose files (including those in
            sub-directories) are added to the archive.
        zipPath: path of the archive to create; an existing file is overwritten.

    Entries are stored under their path relative to ``packagePath``. The
    previous version prefixed every name with a backslash, which produced
    entries literally called ``\\name`` on non-Windows systems.
    """
    # The context manager closes the archive even when a write fails.
    with zipfile.ZipFile(zipPath, 'w', zipfile.ZIP_DEFLATED) as archive:
        for path, _dirNames, fileNames in os.walk(packagePath):
            for name in fileNames:
                fullName = os.path.join(path, name)
                # relpath strips only the leading directory; str.replace would
                # also remove later occurrences of the same text in the path.
                arcName = os.path.relpath(fullName, packagePath)
                archive.write(fullName, arcName)


# In a notebook __name__ is "__main__", so this runs whenever the cell runs.
if __name__ == "__main__":
    # Folder path
    packagePath = './results'
    zipPath = './output.zip'
    # Remove a previous archive before writing the new one.
    if os.path.exists(zipPath):
        os.remove(zipPath)
    file2zip(packagePath, zipPath)
    # Print the current UTC time after zipping.
    print(datetime.datetime.utcnow())