In [None]:
#@title Setup model

import os
os.chdir('/content')

!unzip stylegan2.zip

!wget https://github.com/ninja-build/ninja/releases/download/v1.8.2/ninja-linux.zip
!sudo unzip ninja-linux.zip -d /usr/local/bin/
!sudo update-alternatives --install /usr/bin/ninja ninja /usr/local/bin/ninja 1 --force

In [4]:
#@title import modules

import numpy as np
from PIL import Image
import torch
import torchvision.transforms as transforms

%pip install -q ipywidgets
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

from stylegan2.model import Generator

In [5]:
#download checkpoint
#@title Setup files downloader
from pydrive.auth import GoogleAuth
from pydrive.drive import GoogleDrive
from google.colab import auth
from oauth2client.client import GoogleCredentials

# True: fetch checkpoints through an authenticated PyDrive session.
# False: fall back to the `gdown` command-line tool.
download_with_pydrive: bool = True
# Sub-directory joined between the parent of the working directory and
# "pretrained_models"; empty string means no extra level.
CODE_DIR: str = ''
class Downloader(object):
    """Fetch files from Google Drive into ``<parent of cwd>/<CODE_DIR>/pretrained_models``.

    Downloads go either through an authenticated PyDrive client
    (``use_pydrive=True``) or through the ``gdown`` command-line tool.
    """

    def __init__(self, use_pydrive):
        """
        Args:
            use_pydrive: if True, authenticate the Colab user and download via
                PyDrive; otherwise shell out to ``gdown``.
        """
        self.use_pydrive = use_pydrive
        current_directory = os.getcwd()
        self.save_dir = os.path.join(os.path.dirname(current_directory), CODE_DIR, "pretrained_models")
        os.makedirs(self.save_dir, exist_ok=True)
        print(self.save_dir)
        if self.use_pydrive:
            self.authenticate()

    def authenticate(self):
        """Authenticate the Colab user and store a PyDrive client in ``self.drive``."""
        auth.authenticate_user()
        gauth = GoogleAuth()
        gauth.credentials = GoogleCredentials.get_application_default()
        self.drive = GoogleDrive(gauth)

    def download_file(self, file_id, file_name):
        """Download Drive file ``file_id`` to ``<save_dir>/<file_name>``.

        Does nothing (besides a message) when the destination already exists.
        Data is first written to ``<file_name>.part`` and only moved into place
        once the download finished, so a failed or interrupted download never
        leaves a truncated file that a later call would treat as complete.

        Raises:
            subprocess.CalledProcessError / OSError: the ``gdown`` download
                failed (the former ``!gdown`` magic ignored such failures).
            Whatever PyDrive raises when its download fails.
        """
        import subprocess  # local import keeps this cell self-contained

        file_dst = os.path.join(self.save_dir, file_name)
        if os.path.exists(file_dst):
            print(f'{file_name} already exists!')
            return
        tmp_dst = file_dst + '.part'
        try:
            if self.use_pydrive:
                downloaded = self.drive.CreateFile({'id': file_id})
                downloaded.FetchMetadata(fetch_all=True)
                downloaded.GetContentFile(tmp_dst)
            else:
                # A full Drive URL is accepted by every gdown version, unlike
                # the `--id` flag, which newer releases deprecate.
                subprocess.run(
                    ['gdown', f'https://drive.google.com/uc?id={file_id}', '-O', tmp_dst],
                    check=True,
                )
            os.replace(tmp_dst, file_dst)
        finally:
            # Only present if the download failed before the rename.
            if os.path.exists(tmp_dst):
                os.remove(tmp_dst)

downloader = Downloader(download_with_pydrive)
# Google Drive file id of `stylegan2-ffhq-config-f.pt`. Named explicitly so the
# builtin `id()` is not shadowed for the rest of the notebook.
checkpoint_id = '1EM87UquaoQmk17Q8d5kYIAHqu0dkYqdT'
file_name = 'stylegan2-ffhq-config-f.pt'

downloader.download_file(checkpoint_id, file_name)


/pretrained_models


In [50]:
class Generator_Wrapper:
  """Inference wrapper around the StyleGAN2 ``Generator(1024, 512, 8)``.

  Args:
    checkpoint_path: checkpoint file whose ``'g_ema'`` entry holds the generator
      state dict.
    device: torch device for the generator and for randomly sampled inputs.
  """

  STYLE_DIM = 512  # dimensionality of the random input codes

  def __init__(self, checkpoint_path = '../pretrained_models/stylegan2-ffhq-config-f.pt', device = 'cuda'):
    self.checkpoint_path = checkpoint_path
    self.device = torch.device(device)
    self.model = Generator(1024, self.STYLE_DIM, 8).to(self.device)
    self.init_weights()
    self.model.eval()  # inference only
    # Every generated image is average-pooled down to 256x256.
    self.face_pool = torch.nn.AdaptiveAvgPool2d((256, 256))

  def init_weights(self):
    """Load the ``'g_ema'`` weights from ``self.checkpoint_path`` into the model."""
    # map_location='cpu' loads regardless of the device the checkpoint was saved
    # from; load_state_dict then copies the tensors onto the model's device.
    checkpoint = torch.load(self.checkpoint_path, map_location = 'cpu')
    self.model.load_state_dict(checkpoint['g_ema'])

  def generate_image(self, input = None, input_is_latent = False, num_images = 1):
    """Run the generator without gradients.

    Args:
      input: codes fed to the generator; when None, ``num_images`` random
        (num_images, 512) codes are drawn on ``self.device``.
      input_is_latent: forwarded to the generator (True when ``input`` was
        returned as latents by a previous call).
      num_images: batch size used only when ``input`` is None.

    Returns:
      (images pooled to 256x256, latents returned by the generator).
    """
    if input is None:
      input = torch.randn((num_images, self.STYLE_DIM), device = self.device)

    with torch.no_grad():
      out, latents = self.model([input], input_is_latent = input_is_latent, return_latents = True)
      out = self.face_pool(out)
    return out, latents

  @staticmethod
  def _to_uint8_strip(out):
    """Convert a (B, C, H, W) tensor in [-1, 1] to an (H, B * W, C) uint8 array.

    The batch is laid out left to right. Values outside [-1, 1] are clamped. A
    single (C, H, W) image is treated as a batch of one.

    Raises:
      ValueError: if ``out`` is not 3-D/4-D or the batch is empty.
    """
    if out.dim() == 3:
      out = out.unsqueeze(0)
    if out.dim() != 4 or out.shape[0] == 0:
      raise ValueError(f'expected a non-empty (B, C, H, W) tensor, got shape {tuple(out.shape)}')
    batch, channels, height, width = out.shape
    pixels = ((out.detach().float().clamp(-1, 1) + 1) * 127.5).round().to(torch.uint8)
    # (B, C, H, W) -> (H, B, W, C) -> (H, B*W, C): concatenates the images horizontally.
    strip = pixels.permute(2, 0, 3, 1).reshape(height, batch * width, channels)
    return strip.cpu().numpy()

  def tensor2image(self, out):
    """Return a (B, 3, 256, 256) batch in [-1, 1] as one PIL image.

    The image is 256 pixels tall and 256 * B wide (array shape (256, 256 * B, 3)),
    with the batch laid out left to right.
    """
    return Image.fromarray(self._to_uint8_strip(out))

In [53]:
# Seed torch's RNGs so the sampled codes (and every cell built on them) are
# reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

G = Generator_Wrapper()
images, latents = G.generate_image(num_images = 4)

In [None]:
# Show the sampled batch as a single horizontal strip of 256x256 faces.
G.tensor2image(images)

In [None]:
#@title Latent interpolation

#TODO: generate two images with latents and visualize

In [117]:
#interpolate two latents
def interpolate_latents(interpolate_scale):
  #TODO: interpolate between two latents generated above. interpolate_scale is the interpolation coefficient

  res = Image.fromarray(np.zeros((256, 256, 3)).astype('uint8'))
  return res

In [None]:
#make slider and show
# Trailing ';' suppresses interact's return value (the function repr); the
# widget itself is still displayed.
interact(interpolate_latents, interpolate_scale=widgets.FloatSlider(value=0.5, min=0, max=1., step=0.05));

In [None]:
#@title Mixing latents

#TODO: make a base image with latents and show

In [None]:
#TODO: make 3 new faces. They will be edited using base latent

In [111]:
def mix_latents(mixing_point):
  #TODO: mix latents. the first mixing_point blocks should get the latent from base_latent and the rest come from target_latents
  #edited_latents = ...
  edited_images, _ = G.generate_image(edited_latents, input_is_latent = True)
  return G.tensor2image(edited_images)
 

In [None]:
# max=18 presumably equals the number of per-layer latents of the 1024px
# generator — TODO confirm against the latents' dim 1.
# Trailing ';' suppresses interact's return value (the function repr).
interact(mix_latents, mixing_point=widgets.IntSlider(value=0, min=0, max=18, step=1));