<a href="https://colab.research.google.com/github/sdhnshu/Fusion-Vision/blob/dev/notebooks/explore_latent_space.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
!git clone https://github.com/sdhnshu/Fusion-Vision.git
!pip install ninja

%cd /content/Fusion-Vision
!git checkout dev
!git pull

%cd /content/Fusion-Vision/ganspace
import torch
from tqdm import tqdm
import numpy as np
import os
from PIL import Image
from models import get_instrumented_model
from decomposition import get_or_compute
from config import Config
from IPython.utils import io
from IPython.display import display
from IPython.display import Javascript
from ipywidgets import fixed
import ipywidgets as widgets

torch.autograd.set_grad_enabled(False)
torch.backends.cudnn.benchmark = True

/content/Fusion-Vision/ganspace


  "Distutils was imported before Setuptools. This usage is discouraged "


# Select the model and run PCA

In [4]:
# GANSpace configuration for StyleGAN2 / FFHQ on the 'style' layer, keeping 80 components.
# (use_w=True presumably means the decomposition is done in W space, and n the number of
# sampled latents — TODO confirm against Config.)
config = Config(
    model='StyleGAN2',
    layer='style',
    batch_mode=True,
    batch_size=5000,
    use_w=True,
    output_class='ffhq',
    components=80,
    n=1_000_000,
)
inst = get_instrumented_model(config.model, config.output_class,
                              config.layer, torch.device('cuda'),
                              use_w=config.use_w)
model = inst.model

# get_or_compute returns a path to an .npz archive. Indexing an NpzFile re-reads and
# decompresses the array on every access, so read 'lat_comp' once and close the archive.
# A missing key now raises KeyError instead of silently producing an empty list.
with np.load(get_or_compute(config, inst)) as comps:
    # One latent direction per principal component (row of 'lat_comp'); each row is
    # presumably shaped (1, latent_dim) so it broadcasts against w[l] — TODO confirm.
    latent_dirs = list(comps['lat_comp'])

"The first ten or so principal components, such as head rotation (E(u1, 0-2)) and lightness/background (E(u8, 5)), operate well in the range [−2...2], beyond which the image becomes unrealistic. In contrast, face roundness (E(u37, 0-4)) can work well in the range [−20...20], when using 0.7 as the truncation parameter."

— *GANSpace: Discovering Interpretable GAN Controls* (Härkönen et al., 2020), Section 3.4 "Layer-wise Edits"

# Generate samples along each principal component

In [6]:
# Render grids of edited samples for every PCA component. Each grid has one row per seed
# and one column per edit distance. The edit is applied separately to the coarse, middle
# and fine layer groups, and each grid is saved as a JPEG under out/directions/.
seed = 5
scale = 1  # global multiplier applied to every edit distance
model.truncation = 0.7
no_of_seeds = 3  # number of base latents, i.e. grid rows
tile = 400  # side length in pixels of each image in the grid

n_latents = model.get_max_latents()
w_base = model.sample_latent(no_of_seeds, seed=seed).cpu().numpy()
# One entry per latent layer. All entries start as the same base array; edited layers
# get new arrays (never modified in place), so sharing the reference is safe.
w = [w_base] * n_latents
out_dir = f'out/directions/{config.model}-{config.output_class}'
os.makedirs(out_dir, exist_ok=True)

# Coarse / middle / fine layer groups. The 4 and 8 split points are the ones used for this
# StyleGAN2-FFHQ setup; the fine group extends to the model's last latent layer.
layer_ranges = [(0, 4), (4, 8), (8, n_latents)]

# Loop over each component obtained from the PCA
for dir_idx, direction in enumerate(tqdm(latent_dirs, desc='Visualizing components')):
    # Early components only tolerate small steps before the image becomes unrealistic
    # (see the GANSpace quote above); later components need larger steps to be visible.
    if dir_idx <= 10:
        distances = [-2, -1, 0, 1, 2]
    elif dir_idx <= 20:
        distances = [-5, -2, 0, 2, 5]
    else:
        distances = [-20, -10, 0, 10, 20]

    # The same direction is added to every seed's latent. Build the batched copy once
    # per component rather than once per layer and distance.
    batched_direction = np.repeat(direction, no_of_seeds, axis=0)

    # Modify coarse, middle, fine layers
    for start, end in layer_ranges:
        # Canvas sized from the grid dimensions rather than a fixed 2000x1200.
        canvas_image = Image.new("RGB", (tile * len(distances), tile * no_of_seeds), 'white')

        for i, distance in enumerate(distances):
            latent = list(w)  # shallow copy; only the edited layers are replaced below
            for l in range(start, end):
                latent[l] = latent[l] + batched_direction * distance * scale

            # Generate every seed in one batch.
            imgs = model.sample_np(latent)
            # sample_np appears to drop the batch axis for a single sample (the interactive
            # viewer below uses its result directly as one image). Restore that axis so
            # imgs[j] stays valid when no_of_seeds == 1.
            imgs = np.reshape(imgs, (no_of_seeds,) + imgs.shape[-3:])

            # Column i = distance, row j = seed.
            for j in range(no_of_seeds):
                img = Image.fromarray((imgs[j] * 255).astype(np.uint8)).resize((tile, tile), Image.LANCZOS)
                canvas_image.paste(img, (tile * i, tile * j))

        canvas_image.save(f'{out_dir}/comp{dir_idx}layers{start}-{end}.jpg')

Visualizing components: 100%|██████████| 80/80 [03:52<00:00,  2.90s/it]


In [None]:
# Zip for download
!zip -r out/ffhq-directions.zip out/directions/StyleGAN2-ffhq

Go through all the generated samples and note the interesting controls, together with the component and layer range each one belongs to (controls are often mixed and spread across several layers). Then use the UI below to isolate each control and save it as a `.npy` file: type a name in the **Name** box and press Enter.

# Explore the components using a UI

In [8]:
def name_direction(sender):
    """Save the selected PCA component and layer range under the name typed in the UI.

    Registered as the `name.on_submit` callback. `sender` is the widget that fired the
    event and is not used; the name is read from the global `name` Text widget. The
    function also reads the globals `component`, `component_no`, `start_layer`,
    `end_layer` and `config`.

    Writes `out/directions/<model>-<output_class>/c<component_no>-<name>.npy` containing a
    3-element object array `[direction, start_layer, end_layer]`. Load it back with
    `np.load(path, allow_pickle=True)`.
    """
    label = name.value.strip()
    # Reject names that would produce an unnamed file or point outside the output directory.
    if not label or os.sep in label or (os.altsep and os.altsep in label):
        print('Enter a non-empty name without path separators, then press Enter.')
        return

    out_dir = f'out/directions/{config.model}-{config.output_class}'
    os.makedirs(out_dir, exist_ok=True)
    # The entries have different shapes (an array and two ints), so the object dtype must be
    # explicit: NumPy >= 1.24 raises ValueError when building an array from ragged input.
    np_arr = np.array([component, start_layer.value, end_layer.value], dtype=object)
    file_stem = f'c{component_no}-{label}'
    np.save(f'{out_dir}/{file_stem}.npy', np_arr,
            allow_pickle=True, fix_imports=True)
    print(f'Direction saved as {file_stem}')
    
def display_sample_pytorch(seed, truncation, direction, distance, scale, start, end):
    """Render one sample with `direction` added to latent layers [start, end) and display it.

    Args:
        seed: seed passed to `model.sample_latent` for the base latent.
        truncation: value assigned to `model.truncation` before anything is generated.
        direction: latent direction added to each edited layer; presumably shaped like a
            row of `latent_dirs`, e.g. (1, latent_dim) — TODO confirm.
        distance: edit distance along `direction`.
        scale: multiplier on `distance`; the total offset is `direction * distance * scale`.
        start: first latent layer to edit (inclusive).
        end: last latent layer to edit (exclusive).

    Uses the global `model`. Output written while generating is captured and discarded;
    the resulting image is displayed at 500x500.
    """
    with io.capture_output():
        # Set truncation before sampling so it is in effect for the whole render, in the
        # same order as the batch-visualisation cell.
        model.truncation = truncation
        w_base = model.sample_latent(1, seed=seed).cpu().numpy()
        latent = [w_base] * model.get_max_latents()
        offset = direction * distance * scale  # identical for every edited layer
        for l in range(start, end):
            latent[l] = latent[l] + offset
        img = model.sample_np(latent)
        final_im = Image.fromarray((img * 255).astype(np.uint8)).resize((500, 500), Image.LANCZOS)
    display(final_im)

# Widgets for the interactive explorer. Every slider uses continuous_update=False, so an
# image is re-rendered only when the slider is released.
n_latents = model.get_max_latents()  # layer sliders span [0, n_latents]

initial_seed = np.random.randint(0, 100000)
seed = widgets.IntSlider(min=0, max=100000, step=1, value=initial_seed, description='Seed',
                         continuous_update=False)
truncation = widgets.FloatSlider(min=0, max=2, step=0.1, value=0.7, description='Truncation',
                                 continuous_update=False)
distance = widgets.FloatSlider(min=-10, max=10, value=0, description='Distance',
                               continuous_update=False)
scale = widgets.FloatSlider(min=0, max=10, value=1, description='Scale', continuous_update=False)
start_layer = widgets.IntSlider(min=0, max=n_latents, step=1, value=0, description='Start Layer',
                                continuous_update=False)
end_layer = widgets.IntSlider(min=0, max=n_latents, step=1, value=n_latents, description='End Layer',
                              continuous_update=False)
# Text has no `width` trait; widget sizing is controlled through `layout`.
name = widgets.Text(description="Name", layout=widgets.Layout(width='200px'))

def update_range_start(*args):
    """Keep end_layer >= start_layer by raising end_layer's lower bound."""
    end_layer.min = start_layer.value

def update_range_end(*args):
    """Keep start_layer <= end_layer by lowering start_layer's upper bound."""
    start_layer.max = end_layer.value

start_layer.observe(update_range_start, 'value')
end_layer.observe(update_range_end, 'value')

In [12]:
# Run this cell again if any issues with UI

# Load PCA component
component_no = 30
# Negative indices would silently wrap around and mislabel the saved file, so check the range.
assert 0 <= component_no < len(latent_dirs), \
    f'component_no must be in [0, {len(latent_dirs)})'
component = latent_dirs[component_no]

control = widgets.VBox([
    widgets.HBox([seed, truncation, name]),
    widgets.HBox([distance, scale]),
    widgets.HBox([start_layer, end_layer]),
])
preview = widgets.interactive_output(display_sample_pytorch,
                                     {'seed': seed, 'truncation': truncation, 'distance': distance,
                                      'scale': scale, 'start': start_layer, 'end': end_layer,
                                      'direction': fixed(component)})
display(control, preview)

# Press Enter after typing the name to save.
# Unregister first: re-running this cell would otherwise attach the callback again, and
# every Enter would then save (and print) once per run.
name.on_submit(name_direction, remove=True)
name.on_submit(name_direction)

VBox(children=(HBox(children=(IntSlider(value=11873, continuous_update=False, description='Seed', max=100000),…

Output()