In [None]:
import numpy as np
import pandas as pd
from sklearn.svm import SVC, LinearSVC
from sklearn.preprocessing import Binarizer
from torchvision.utils import make_grid
from torchvision import transforms
from PIL import Image

In [None]:
# Mount Google Drive so the latent/label files under "My Drive" can be read.
from google.colab import drive

drive.mount('/content/drive',
            force_remount=True)

In [None]:
!git clone https://github.com/NVlabs/stylegan2-ada-pytorch.git
!pip install ninja

In [None]:
# Sanity check: list the cloned repo (this path is put on sys.path in the next cell).
!ls /content/stylegan2-ada-pytorch/

In [None]:
import sys
sys.path.insert(0, "/content/stylegan2-ada-pytorch")
import pickle
import os
import numpy as np
import PIL.Image
from IPython.display import Image
import matplotlib.pyplot as plt
import IPython.display
import torch
import dnnlib
import legacy

def seed2vec(G, seed):
  """Return a reproducible latent vector for `seed`.

  Samples a standard-normal array of shape (1, G.z_dim) from a dedicated
  ``np.random.RandomState(seed)``, so the same seed always gives the same z.
  """
  rng = np.random.RandomState(seed)
  latent = rng.randn(1, G.z_dim)
  return latent

def display_image(image):
  """Draw `image` on the current matplotlib axes with axes hidden, then show it."""
  ax = plt.gca()
  ax.axis('off')
  # plt.imshow == gca().imshow + sci(); keep the current-image bookkeeping identical.
  plt.sci(ax.imshow(image))
  plt.show()

def generate_image(G, z, truncation_psi):
    """Render one image with a *TensorFlow* StyleGAN2 generator (legacy API).

    NOTE(review): this helper targets the TF-era ``G.run`` / ``G.input_shapes``
    interface and needs ``tflib``, which this notebook never imports. It is also
    redefined later in this cell by the PyTorch ``generate_image``, so as
    written this definition is shadowed and never called.

    Args:
        G: TF generator exposing ``run`` and ``input_shapes``.
        z: latent batch passed straight to ``G.run``.
        truncation_psi: truncation factor; when None it is not passed to
            ``G.run`` at all.

    Returns:
        The first element of the batch returned by ``G.run`` (laid out as
        [height, width, channel] given the NHWC output transform).
    """
    # Render images for dlatents initialized from random seeds.
    Gs_kwargs = {
        'output_transform': dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True),
        'randomize_noise': False,
    }
    if truncation_psi is not None:
        Gs_kwargs['truncation_psi'] = truncation_psi

    # All-zero conditioning label shaped like the generator's second input.
    label = np.zeros([1] + G.input_shapes[1][1:])
    # Fix: the kwargs dict is Gs_kwargs; the original unpacked the undefined
    # name G_kwargs and raised NameError.
    images = G.run(z, label, **Gs_kwargs)  # [minibatch, height, width, channel]
    return images[0]

def get_label(G, device, class_idx):
  """Build the conditioning label tensor for generator `G`.

  Args:
    G: generator exposing ``c_dim`` (0 means unconditional).
    device: torch device to allocate the label on.
    class_idx: class to condition on, or None.

  Returns:
    A float tensor of shape ``(1, G.c_dim)``: all zeros for an unconditional
    network, otherwise one-hot at ``class_idx``.

  Raises:
    ValueError: if ``G`` is conditional and ``class_idx`` is None.
  """
  label = torch.zeros([1, G.c_dim], device=device)
  if G.c_dim != 0:
    if class_idx is None:
      # The original called ctx.fail(...), a leftover from the click-based
      # generate.py; `ctx` does not exist here, so it raised NameError instead.
      raise ValueError('Must specify class_idx when using a conditional network')
    label[:, class_idx] = 1
  elif class_idx is not None:
    print('warn: --class=lbl ignored when running on an unconditional network')
  return label

def generate_image(device, G, z, truncation_psi=1.0, noise_mode='const', class_idx=None):
  """Synthesize one RGB image from latent `z` with a PyTorch StyleGAN2-ADA generator.

  Args:
    device: torch device the generator lives on.
    G: generator called as ``G(z, label, truncation_psi=..., noise_mode=...)``.
    z: NumPy array or torch tensor, either a single latent of shape
      ``(G.z_dim,)`` or a batch ``(batch, G.z_dim)``; only the first sample
      of the batch is returned.
    truncation_psi: truncation factor forwarded to ``G``.
    noise_mode: noise mode forwarded to ``G`` (e.g. 'const').
    class_idx: conditioning class, resolved through ``get_label``.

  Returns:
    A ``PIL.Image.Image`` in 'RGB' mode for the first sample.
  """
  z = torch.as_tensor(z)
  if z.ndim == 1:
    # Accept a flat latent; G expects a leading batch dimension.
    z = z.unsqueeze(0)
  z = z.to(device)
  label = get_label(G, device, class_idx)
  # Inference only: no autograd graph is needed to render an image.
  with torch.no_grad():
    img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode)
  # NCHW -> NHWC, then v*127.5+128 clamped to [0, 255] as uint8
  # (assumes G's output is roughly in [-1, 1] — TODO confirm).
  img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
  return PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')

In [None]:
# Load the pre-trained StyleGAN2-ADA FFHQ generator (its 'G_ema' weights).
# NOTE: load_network_pkl unpickles the download; pickle can execute code, so
# only point URL at a source you trust.
URL = "https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada/pretrained/ffhq.pkl"

print(f'Loading networks from "{URL}"...')
# Prefer the GPU; fall back to CPU instead of crashing on a CPU-only runtime
# (CPU generation will be much slower — confirm the repo's CPU op path works for this checkpoint).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
with dnnlib.util.open_url(URL) as f:
    G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore

In [None]:
# Text file of latent vectors, one bracketed comma-separated vector per line (parsed below).
f = open('/content/drive/My Drive/gan_latent_dataset/latent_vector.txt', 'r')

In [None]:
m = 25_000  # number of training observations (latent rows to read)

In [None]:
X = np.zeros(shape=(m, 512), dtype=np.float64)  # preallocated feature matrix, one latent per row

In [None]:
# Parse the first m lines of the latent file into X. Each line holds one
# vector written as "[v0, v1, ...]"; only the text between the first '[' and
# the first ']' is parsed.
for i in range(m):
    s = f.readline()
    if not s:
        # readline() returns '' at EOF; fail clearly instead of with the
        # cryptic "substring not found" from s.index('[') below.
        raise ValueError(f'latent file ended after {i} lines; expected m={m}')
    index1 = s.index('[')
    index2 = s.index(']')
    l = [float(v) for v in s[index1+1:index2].split(',')]
    X[i, :] = l

In [None]:
# Done reading latents; release the file handle opened earlier.
f.close()

In [None]:
# Wrap the latent matrix in a DataFrame: one row per observation, 512 columns.
X = pd.DataFrame(data=X)

In [None]:
# Per-attribute values, stored attribute-major (transposed into one row per image in the next cell).
y = np.load(file='/content/drive/My Drive/gan_latent_dataset/prob_image.npy')

In [None]:
# The raw array has one row per attribute; transpose so each row is one image
# (presumably aligned with row i of X — the assert below checks the lengths).
# (An earlier experiment used only the columns ['blond', 'smiling'].)
ATTRIBUTE_COLUMNS = ['no double chin', 'no eyeglasses', 'female', 'no moustache', 'smiling', 'age']
y = pd.DataFrame(y.T, columns=ATTRIBUTE_COLUMNS)
# A length mismatch would misalign labels and latents.
assert len(y) == len(X), f'{len(y)} label rows vs {len(X)} latent rows'

In [None]:
# Preview the first rows of the per-attribute values loaded from prob_image.npy.
y.head()

In [None]:
# Values > 0.5 become 1, values <= 0.5 become 0 (sklearn Binarizer semantics).
binarizer = Binarizer(threshold=0.5)

In [None]:
def get_model(X, y, max_iter=10000, random_state=None):
    """Fit a linear SVM separating latents by one binarized attribute.

    Args:
        X: (n_samples, n_features) latent matrix.
        y: 1-D array of n_samples attribute values; binarized with the
            module-level ``binarizer``.
        max_iter: LinearSVC iteration cap (default matches the original 10000).
        random_state: forwarded to LinearSVC; set it for reproducible fits.

    Returns:
        The fitted ``LinearSVC``; its ``coef_`` is the normal of the
        separating hyperplane.
    """
    # reshape(-1, 1) rather than reshape(m, 1): take the row count from y
    # itself instead of silently depending on the global m.
    y = binarizer.fit_transform(np.asarray(y).reshape(-1, 1)).squeeze()
    svm_clf = LinearSVC(max_iter=max_iter, random_state=random_state)
    svm_clf.fit(X, y)
    return svm_clf

In [None]:
def image_grid(imgs, rows, cols):
    """Tile equally sized PIL images into a rows x cols grid (row-major).

    Args:
        imgs: sequence of exactly rows*cols PIL images; the cell size is
            taken from the first image.
        rows: number of grid rows.
        cols: number of grid columns.

    Returns:
        A new RGB ``PIL.Image.Image`` of size (cols*w, rows*h).

    Raises:
        ValueError: if ``imgs`` is empty or ``len(imgs) != rows * cols``.
    """
    if not imgs:
        raise ValueError('image_grid needs at least one image')
    if len(imgs) != rows * cols:
        raise ValueError(f'expected {rows * cols} images for a {rows}x{cols} grid, got {len(imgs)}')

    # Import PIL.Image locally: at module level the name `Image` is rebound to
    # IPython.display.Image (no `.new`) by `from IPython.display import Image`.
    from PIL import Image as PILImage

    w, h = imgs[0].size
    grid = PILImage.new('RGB', size=(cols * w, rows * h))
    for i, img in enumerate(imgs):
        # Row-major placement: column i % cols, row i // cols.
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid

In [1]:
# Fit one linear SVM per attribute; each hyperplane normal serves as an edit direction.
svm_models = dict()  # attribute (column) name -> fitted model

In [None]:
# Same latent matrix for every attribute; only the label column changes.
latents = X.values
for col in y.columns:
    model = get_model(latents, y[col].values)
    svm_models[col] = model

In [None]:
# Display the fitted models (attribute name -> LinearSVC).
svm_models

In [None]:
# Short handles to the per-attribute SVMs.
# NOTE(review): several original names do not match the attribute they hold
# (ng_model -> 'no moustache', nmo_model -> 'smiling', nm_model -> 'age').
# They are kept unchanged so existing cells keep working; prefer the
# descriptive aliases at the bottom in new code.
ndc_model = svm_models['no double chin']
ne_model = svm_models['no eyeglasses']
ng_model = svm_models['no moustache']
nmo_model = svm_models['smiling']
female_model = svm_models['female']
nm_model = svm_models['age']

no_moustache_model = svm_models['no moustache']
smiling_model = svm_models['smiling']
age_model = svm_models['age']

In [None]:
def shift_latent_vec(z, alpha, svm_model, normalize=False):
    """Move latent `z` along the normal of `svm_model`'s separating hyperplane.

    For a linear model with decision function w.z + b, a positive `alpha`
    increases the decision value (toward the class labelled 1) and a negative
    `alpha` decreases it.

    Args:
        z: latent vector(s); broadcast against the flattened normal.
        alpha: step size along the normal.
        svm_model: fitted linear model exposing ``coef_`` (e.g. LinearSVC).
        normalize: if True, step along the unit-length normal so `alpha` is a
            Euclidean distance independent of the SVM's weight scale. Default
            False reproduces the raw-``coef_`` step.

    Returns:
        ``z + alpha * direction``.

    Raises:
        ValueError: if ``normalize`` is True and ``coef_`` is all zeros.
    """
    direction = np.asarray(svm_model.coef_, dtype=float).squeeze()
    if normalize:
        norm = np.linalg.norm(direction)
        if norm == 0:
            raise ValueError('svm_model.coef_ is all zeros; no direction to shift along')
        direction = direction / norm
    return z + alpha * direction

In [None]:
# Render a reference face from latent row 74 of the training set.
z = X.iloc[74].to_numpy()
img = generate_image(device, G, z.reshape(1, 512))
display_image(img)

In [None]:
# Step the reference latent in the negative direction of the "no eyeglasses"
# SVM normal (presumably toward faces with eyeglasses), one frame per alpha.
# Uses the shared shift_latent_vec helper instead of repeating its formula;
# z + (-alpha)*coef equals z - alpha*coef exactly.
alphas = list(range(0, 100, 10))

l = []
for alpha in alphas:
    z_shifted = shift_latent_vec(z, -alpha, ne_model)
    img = generate_image(device, G, z_shifted.reshape(1, 512))
    l.append(img)

In [None]:
from PIL import Image

In [None]:
# Show the edit sequence as a single row, one column per alpha.
grid = image_grid(l, rows=1, cols=len(alphas))
grid