### Import libraries

In [2]:
import pickle
import torch
import numpy as np
import matplotlib.pyplot as plt
import PIL
from tqdm import tqdm

### Load the model

##### G is the Generator and D is the Discriminator 

In [3]:
with open('./training-runs/00001-stylegan3-r-celeb+arcane_64x64/network-snapshot-001402.pkl', 'rb') as fID:
    model = pickle.load(fID)
    G = model['G_ema'].cuda()  # torch.nn.Module
    D = model['D'].cuda()  # torch.nn.Module



### Utility functions that calls different part of the generator

In [4]:
def latent2img(z, truncation_psi=1):
    c = None
    img = G(z, c, truncation_psi=truncation_psi)
    img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)
    return PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB')

def seed2latent(seed):
    return torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).cuda()

def seed2img(seed, truncation_psi=1):
    z = seed2latent(seed)
    return latent2img(z, truncation_psi)

### Manually select example imagesand generate the latent vector for them

In [7]:
seed_list_1 = [1, 4, 8, 20, 21, 26, 27]
seed_list_2 = [3, 12, 30, 31, 34, 35, 36]
num_seed_1 = len(seed_list_1)
num_seed_2 = len(seed_list_2)

latent_list_1 = [seed2latent(seed) for seed in seed_list_1]
latent_list_2 = [seed2latent(seed) for seed in seed_list_2]

### Define a funcion that takes different truncation_psi and generate style mixing matrix plot 

In [15]:
def style_transfer_matrix(truncation_psi, scl_fctr=1.5):

    fig, axes = plt.subplots(num_seed_1+1, num_seed_2+1, figsize=((num_seed_2+1)*scl_fctr, (num_seed_1+1)*scl_fctr))

    for i in range(num_seed_1+1):
        for j in range(num_seed_2+1):
            if i == 0 and j == 0:
                axes[i, j].axis('off')
                axes[i, j].set_aspect('equal')
                continue
            if i == 0:
                img = latent2img(latent_list_2[j-1], truncation_psi=truncation_psi)
            elif j == 0:
                img = latent2img(latent_list_1[i-1], truncation_psi=truncation_psi)
            else:
                img = latent2img((latent_list_1[i-1] + latent_list_2[j-1]) / 2, truncation_psi=truncation_psi)
            axes[i, j].imshow(img)
            axes[i, j].axis('off')
            axes[i, j].set_aspect('equal')
    plt.subplots_adjust(wspace=0.02, hspace=0.02)
    fig.suptitle('truncation psi = {0:.2f}'.format(truncation_psi), fontsize=16, y=0.91)
    fig.savefig('./results/changing_psi/style_transfer_matrix_truncpsi_{0:.2f}.png'.format(truncation_psi))
    plt.close(fig)

style_transfer_matrix(truncation_psi=0)

### Run the style_transfer_matrix with different psi values and save images to folder

In [16]:
for psi in tqdm(np.arange(0, 1, 0.01)):
    style_transfer_matrix(truncation_psi=psi)

100%|██████████| 100/100 [04:14<00:00,  2.54s/it]


### Use the following code to generate transitions from one image to another

In [38]:
truncation_psi = 0.5

latent_diff = latent_list_2[2] - latent_list_1[0]

steps = 36

latent_diff_delta = latent_diff / steps

for i in range(steps+1):
    latent = latent_list_1[0] + i * latent_diff_delta
    img = latent2img(latent, truncation_psi=truncation_psi)
    fig, ax = plt.subplots()
    ax.imshow(img)
    ax.axis('off')
    fig.savefig(f'./results/style_mixing/celeb+arcane_{str(i).zfill(2)}.png')
    plt.close(fig)