In [None]:
import os
# Expose only physical GPU 2 to this process; this has to happen before CUDA is initialised.
os.environ["CUDA_VISIBLE_DEVICES"]="2"

import pickle
from pathlib import Path
import json
import PIL.Image
import torch
import torch.nn.functional as F
import torchvision.utils as vutils
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_theme()
from sklearn.decomposition import PCA, IncrementalPCA
from sklearn.manifold import TSNE
import numpy as np
import functools

# Project-local modules (presumably from the StyleGAN2-ADA repository and this project's classifiers).
from torch_utils import misc
import dnnlib
import legacy
from projector import project
from classifiers.models import CNN_MNIST
# Fall back to CPU when no GPU is visible.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Root folder holding 'stylegan2-training-runs' and 'classifiers' sub-folders.
path_results = Path.cwd().parent / 'results'
# Alternative locations used on other machines:
# path_results = Path('/home/jovyan/results')
# path_results = Path('/d/alecoz/projects/stylegan2-mnist-corrupted/results')
# path_results = Path('/home/jovyan/results_temp')

In [None]:
def generator_output_to_classifier_input(images):
    """Convert generator output to the input format expected by the MNIST classifiers.

    The generator emits images in [-1, 1]; the classifiers take rounded intensities
    in [0, 255] on a 28x28 canvas. The conversion uses torch ops only, so the batch
    does not round-trip through the CPU and numpy.

    Args:
        images: tensor (B, C, H, W) in generator scale; assumed to be 32x32 MNIST with
            2 pixels of padding on each side (TODO confirm for other datasets).

    Returns:
        Float tensor (B, C, 28, 28) with rounded values in [0, 255], on `device`.
    """
    lo, hi = -1, 1  # generator scale
    images = images.detach()  # classifier input never needs gradients
    images = (images - lo) * (255 / (hi - lo))  # classifier scale
    # torch.round rounds half to even, matching the previous np.rint behaviour.
    images = images.round().clamp(0, 255)
    images = images[:, :, 2:30, 2:30]  # remove padding (32x32 -> 28x28)
    return images.to(device)

def postprocess_images(img):
    """Map generator output from [-1, 1] to uint8 [0, 255] and move channels last.

    Same conversion as generate.py. A batch (B, C, H, W) becomes (B, H, W, C), a
    single image (C, H, W) becomes (H, W, C); tensors of any other rank are
    returned untouched.
    """
    channels_last = {4: (0, 2, 3, 1), 3: (1, 2, 0)}.get(img.dim())
    if channels_last is None:
        return img
    scaled = img.permute(*channels_last) * 127.5 + 128
    return scaled.clamp(0, 255).to(torch.uint8)


def generate_from_z(z, c=None, chunk_size=None):
    """Synthesize images from Z latent codes in mini-batches with the global generator `G`.

    Args:
        z: latent codes, one row per image.
        c: optional conditioning labels with one row per latent code; sliced together
            with `z`. Leave as None for unconditional generators.
        chunk_size: number of codes per forward pass; defaults to the notebook-level
            `batch_size`.

    Returns:
        Generator outputs for all chunks, concatenated in input order.

    Raises:
        ValueError: if `z` contains no latent codes (there is nothing to concatenate).
    """
    if chunk_size is None:
        chunk_size = batch_size  # notebook-level global
    if z.shape[0] == 0:
        raise ValueError('generate_from_z: z contains no latent codes')
    chunks = []
    for start in range(0, z.shape[0], chunk_size):
        c_chunk = None if c is None else c[start:start + chunk_size]
        chunks.append(G(z[start:start + chunk_size], c=c_chunk, noise_mode='const', force_fp32=True))
    # Concatenate once instead of growing a tensor inside the loop (quadratic copying).
    return torch.cat(chunks)


def plot_random_images(imgs, n=100, nrow=10):
    """Show a grid of `n` images drawn at random (with replacement) from `imgs`.

    Args:
        imgs: generator output, assumed (B, C, H, W) in [-1, 1].
        n: number of images to draw (default 100, a 10x10 grid).
        nrow: images per grid row.
    """
    imgs = postprocess_images(imgs)  # uint8, channels last: (B, H, W, C)
    picks = torch.randint(0, imgs.shape[0], (n,))
    # make_grid expects channels first, so undo postprocess_images' permutation.
    grid = vutils.make_grid(imgs[picks].permute(0, 3, 1, 2).cpu(), pad_value=255, nrow=nrow)
    plt.figure(figsize=(10, 10))
    plt.axis("off")
    plt.imshow(grid.permute(1, 2, 0))

    
def truncate(x, x_avg, psi):
    """Pull `x` toward `x_avg` (truncation trick).

    psi = 0 returns `x_avg`, psi = 1 returns `x` unchanged, and 0 < psi < 1
    interpolates linearly between the two (values outside [0, 1] extrapolate).
    """
    return torch.lerp(x_avg, x, psi)

In [None]:
# Choose the generator to analyse: either a training-run folder (resolved to its
# best-FID snapshot) or a direct URL/path to a .pkl file.
# Exactly one assignment should be left uncommented.
# path_model = Path("/home/jovyan/results/stylegan2-training-runs/00014-mnist_stylegan2-cond-auto1-dim64")
# path_model = Path("/home/jovyan/results/stylegan2-training-runs/00015-mnist_stylegan2-auto1-dim64")
# path_model = Path("/home/jovyan/results/stylegan2-training-runs/00016-mnist_stylegan2-cond-auto2-dim64_FAILED")

# path_model = path_results / 'stylegan2-training-runs' / '00010-mnist_stylegan2-auto2-dim512'
# path_model = path_results / 'stylegan2-training-runs' / '00014-mnist_stylegan2-cond-auto1-dim64'
# path_model = path_results / 'stylegan2-training-runs' / '00015-mnist_stylegan2-auto1-dim64'
# path_model = path_results / 'stylegan2-training-runs' / '00016-mnist_stylegan2-cond-auto2-dim64_FAILED'
# path_model = path_results / 'stylegan2-training-runs' / '00023-mnist_stylegan2_noise_blur-cond-auto2-dim512'
# path_model = path_results / 'stylegan2-training-runs' / '00024-mnist_stylegan2_noise_blur-auto2'
# path_model = path_results / 'stylegan2-training-runs' / '00028-mnist_stylegan2_noise-cond-auto1-dim64'
# path_model = path_results / 'stylegan2-training-runs' / '00029-mnist_stylegan2_noise-auto2-dim64'


# path_model = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/afhqcat.pkl'
# path_model = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/afhqdog.pkl'
# path_model = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/afhqwild.pkl'
# path_model = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/brecahad.pkl'
# path_model = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/cifar10.pkl'
# path_model = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl'
# path_model = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl'

# path_model = path_results / 'stylegan2-training-runs' / '00015-mnist_stylegan2_blur_noise-cond-auto4'
path_model = path_results / 'stylegan2-training-runs' / '00016-mnist_stylegan2_blur_noise_maxSeverity3_proba50-cond-auto4'

# Resolve `path_model` to a generator and load its EMA weights (G_ema).
# Set to False to keep every snapshot of the run instead of deleting the non-best ones.
delete_non_best_snapshots = True

if not str(path_model).endswith('pkl'):
    # Training-run folder: pick the snapshot with the lowest FID in the metric log.
    fid_log = path_model / 'metric-fid50k_full.jsonl'
    with open(fid_log, 'r') as json_file:
        records = [json.loads(line) for line in json_file if line.strip()]
    if not records:
        raise ValueError(f'No FID entries found in {fid_log}')
    # min() keeps the first of equal scores, like the original strict '<' scan.
    best_record = min(records, key=lambda r: r['results']['fid50k_full'])
    best_fid = best_record['results']['fid50k_full']
    best_model = best_record['snapshot_pkl']
    print('Best FID: {:.2f} ; best model : {}'.format(best_fid, best_model))

    # WARNING: irreversible -- removes every other snapshot of this run from disk.
    if delete_non_best_snapshots:
        for m in list(path_model.glob('*.pkl')):
            if m.name != best_model:
                m.unlink(missing_ok=True)

    path_model = path_model / best_model
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted snapshots.
    with open(path_model, 'rb') as f:
        G = pickle.load(f)['G_ema'].to(device)  # torch.nn.Module

else:
    with dnnlib.util.open_url(path_model) as f:
        G = legacy.load_network_pkl(f)['G_ema'].to(device)

# On CPU, force fp32 for every forward call (presumably because fp16 paths are GPU-only).
if device == 'cpu': G.forward = functools.partial(G.forward, force_fp32=True)

conditional = G.c_dim > 0

# Register forward hooks that capture intermediate images of the synthesis network:
#   intermediate_images_torgb[res] -> output of module `b{res}.torgb`
#   intermediate_images_block[res] -> second element of the output of block `b{res}`
# The hooks look the dicts up by global name at call time.
# Handles are kept so that re-running this cell replaces the hooks instead of
# stacking duplicates on the same modules.
for _handle in globals().get('hook_handles', []):
    _handle.remove()
hook_handles = []

intermediate_images_torgb = {}
def get_torgb(name):
    def hook(model, input, output):
        intermediate_images_torgb[name] = output.detach()
    return hook
intermediate_images_block = {}
def get_block_img(name):
    def hook(model, input, output):
        intermediate_images_block[name] = output[1].detach()
    return hook
for res in G.synthesis.block_resolutions:
    block = getattr(G.synthesis, f'b{res}')
    hook_handles.append(block.torgb.register_forward_hook(get_torgb(res)))
    hook_handles.append(block.register_forward_hook(get_block_img(res)))

In [None]:
n_images = 2
z = torch.randn([n_images, G.z_dim], device=device)    # latent codes
if conditional:
    digits = torch.randint(0, G.c_dim, (n_images,))
    # F.one_hot has no `device` argument: build the labels, then move them.
    c = F.one_hot(digits, G.c_dim).to(device)
else:
    c = None

# Truncated generation (psi=0.5 on the first 8 style layers) with constant noise.
ws = G.mapping(z, c, truncation_psi=0.5, truncation_cutoff=8)
img = G.synthesis(ws, noise_mode='const', force_fp32=True)  # NCHW, float32, range ~[-1, +1]
img = postprocess_images(img)  # NHWC, uint8

# misc.print_module_summary(G, [z, c])

plt.figure()
# imshow needs host memory, so move the grid off the GPU.
plt.imshow(vutils.make_grid(img.permute(0, 3, 1, 2)).permute(1,2,0).cpu())
plt.axis('off')
plt.grid(False)
plt.title('random images')


# "Average" image: the all-zero Z code, with the first label when conditional
# (the label batch must match the single latent code).
z_avg = torch.zeros((1, G.z_dim), device=device)
c_avg = None if c is None else c[:1]
img_avg = G(z_avg, c_avg, noise_mode='const')
img_avg = postprocess_images(img_avg)

plt.figure()
plt.imshow(img_avg.squeeze(0).cpu())
plt.axis('off')
plt.grid(False)
plt.title('average image')

## Show intermediate layer outputs

In [None]:
# test having a constant input 4x4xdim "image"
# G.synthesis.b4.const.data =  torch.ones_like(G.synthesis.b4.const.data)

up = torch.nn.Upsample(size=[G.img_resolution, G.img_resolution])

z = torch.randn([1, G.z_dim], device=device)    # latent codes
c_single = None if c is None else c[:1]  # labels must match the single latent code
_ = G(z, c_single, noise_mode='const')  # runs the forward hooks that fill the dicts


def _show_intermediate(captured, title):
    """Plot each captured (1, C, h, w) tensor, upsampled to the output resolution, in one row."""
    fig, axs = plt.subplots(1, len(captured), figsize=(20, 3), squeeze=False)
    fig.suptitle(title)
    for ax, (res, tensor) in zip(axs.ravel(), captured.items()):
        # squeeze(0) drops only the batch dim; a bare squeeze() would also drop the
        # channel dim of grayscale outputs and skip postprocess_images' conversion.
        image = postprocess_images(up(tensor).cpu().squeeze(0))  # (H, W, C) uint8
        if image.shape[-1] == 1:
            image = image[..., 0]  # single channel: show as a 2-D image
        ax.imshow(image)
        ax.axis('off')
        ax.set_title(f'{res}x{res}')


# The hook dicts are read, not rebound, so later forward passes keep filling them.
_show_intermediate(intermediate_images_torgb, 'torgb outputs')
_show_intermediate(intermediate_images_block, 'block outputs')


## Sample latent codes and generate images

In [None]:
n_samples = 10000
batch_size = 1000  # images per synthesis pass (also read by generate_from_z)

# Sample gaussian noise and classes
zs = torch.randn([n_samples, G.z_dim], device=device)
if conditional:
    digits = torch.randint(0, G.c_dim, (n_samples,), device=device)
    c = F.one_hot(digits, G.c_dim)
else:
    c = None

# Latent codes in W, (n_samples, num_ws, w_dim), truncated toward the average W.
ws = G.mapping(zs, c, truncation_psi=0.5, truncation_cutoff=8)

# Generate images in batches; concatenate once instead of growing a tensor in the loop.
imgs = torch.cat([
    G.synthesis(ws[start:start + batch_size], noise_mode='const', force_fp32=True)
    for start in range(0, n_samples, batch_size)
])

# Keep one style vector per sample -> (n_samples, w_dim). The per-layer rows are
# presumably identical here (all layers fall under truncation_cutoff=8 for this model).
ws = ws[:, 0, :]

## PCA in W space

In [None]:
# PCA of the W codes. W has G.w_dim features, which bounds the number of components.
pca = PCA(n_components=G.w_dim)
pca.fit(ws.cpu().numpy())
# Note: pca.explained_variance_ is a *variance* (std**2) of the centred data with
# ddof=1, so the std of the projections, pca.transform(ws).std(axis=0, ddof=1),
# equals np.sqrt(pca.explained_variance_), not explained_variance_ itself.

In [None]:
n_images = 1
z = torch.randn([n_images, G.z_dim], device=device)    # latent codes
if conditional:
    digits = torch.tensor([9])  # torch.randint(0, 10, (n_images,))
    c = F.one_hot(digits, G.c_dim).to(device)  # class labels
else:
    c = None

# The PCA was fitted on W codes, so walk along the component in W (not in Z).
idx_direction = 0
direction = torch.as_tensor(pca.components_[idx_direction, :], dtype=torch.float32, device=device)
k = 10*np.array([-2, -1, -0.5, 0, 0.5, 1, 2])

# Same truncation settings as the W codes the PCA was fitted on.
w = G.mapping(z, c, truncation_psi=0.5, truncation_cutoff=8)  # (1, num_ws, w_dim)

images = []
for step in k:
    w_new = w + float(step) * direction  # same shift on every style layer
    img = G.synthesis(w_new, noise_mode='const', force_fp32=True)
    images.append(postprocess_images(img).cpu())

images = torch.cat(images, 0)  # (len(k), H, W, C) uint8

plt.figure(figsize=(15,15))
plt.axis("off")
# make_grid expects channels first; postprocess_images returned channels last.
plt.imshow(vutils.make_grid(images.permute(0, 3, 1, 2), pad_value=255, nrow=len(k)).permute(1,2,0))

## Load and test classifiers

In [None]:
# predict digits
classifier_digits = CNN_MNIST(output_dim=10).to(device)
classifier_digits.load_state_dict(torch.load(path_results / 'classifiers' / 'CNN_MNIST_weights_20220209_1612.pth', map_location=device)) # DeepLab
# classifier_digits.load_state_dict(torch.load(path_results / 'classifiers' / 'CNN_MNIST_weights_20220210_1601.pth')) # Confiance
# Inference mode, in case CNN_MNIST has dropout / batch-norm layers.
classifier_digits.eval()

# predict noise
classifier_noise = CNN_MNIST(output_dim=6).to(device)
classifier_noise.load_state_dict(torch.load(path_results / 'classifiers' / 'CNN_MNIST_noise_weights_20220210_1728.pth', map_location=device)) # Confiance
classifier_noise.eval()


x = generator_output_to_classifier_input(imgs)
with torch.no_grad():  # predictions only: no autograd graph
    digit_pred = classifier_digits(x).argmax(dim=1).cpu()
    noise_pred = classifier_noise(x).argmax(dim=1).cpu()

plt.figure(figsize=(15, 5))
for i in range(10):
    idx = np.random.randint(0, len(x))
    plt.subplot(1, 10, i+1)
    plt.imshow(x[idx].cpu().squeeze(), cmap='gray')
    plt.title(f'digit: {digit_pred[idx].numpy()} \n noise: {noise_pred[idx].numpy()}')
    plt.axis('off')

## t-SNE

In [None]:
# 2-D t-SNE embeddings of the Z and W codes; a fixed random_state makes the layouts reproducible.
zs_embedded = TSNE(n_components=2, learning_rate='auto', init='random', random_state=0).fit_transform(zs.cpu().numpy())
ws_embedded = TSNE(n_components=2, learning_rate='auto', init='random', random_state=0).fit_transform(ws.cpu().numpy())

In [None]:
def _scatter_with_legend(ax, embedded, labels, title):
    """Scatter a 2-D embedding coloured by integer labels, with an opaque-marker legend."""
    scatter = ax.scatter(embedded[:, 0], embedded[:, 1], c=labels, alpha=0.2, cmap='jet')
    legend = ax.legend(*scatter.legend_elements(), loc="lower left", title=title)
    # `legendHandles` was deprecated in matplotlib 3.7 in favour of `legend_handles`.
    handles = getattr(legend, 'legend_handles', None)
    if handles is None:
        handles = legend.legendHandles
    for handle in handles:
        handle.set_alpha(1)  # markers in the legend stay fully visible
    ax.axis('off')


# One figure per latent space: noise-level predictions (left), digit predictions (right).
for space, embedded in (('Z', zs_embedded), ('W', ws_embedded)):
    fig, axs = plt.subplots(1, 2, figsize=(15, 5))
    fig.suptitle(f't-SNE in {space} space')
    _scatter_with_legend(axs[0], embedded, noise_pred, 'Noise')
    _scatter_with_legend(axs[1], embedded, digit_pred, 'Digits')

## Histograms Z and W space

In [None]:
# Top row: histogram of a random Z coordinate; bottom row: the same coordinate in W.
fig, axs = plt.subplots(2, 10, figsize=(15, 8))
for col in range(10):
    dim = np.random.randint(0, G.z_dim)
    for ax, codes, prefix in ((axs[0, col], zs, 'z'), (axs[1, col], ws, 'w')):
        ax.hist(codes[:, dim].cpu().numpy(), bins=100)
        ax.set(yticklabels=[])  # remove the tick labels
        ax.tick_params(left=False)  # remove the ticks
        ax.set_title(f'{prefix}_{dim}')

## Find noise direction in Z

In [None]:
# load noise classifier (eval mode for deterministic inference)
classifier_noise = CNN_MNIST(output_dim=6).to(device)
classifier_noise.load_state_dict(torch.load(path_results / 'classifiers' / 'CNN_MNIST_noise_weights_20220210_1728.pth', map_location=device)) # Confiance
classifier_noise.eval()

x = generator_output_to_classifier_input(imgs)
with torch.no_grad():
    predictions = classifier_noise(x).cpu()
y_pred = predictions.argmax(dim=1)

# Keep latent codes of images with noise = 1 and 5. Boolean masking already yields
# 2-D tensors; a squeeze() would collapse a single match to 1-D.
z_noise_low = zs[y_pred == 1]
z_noise_high = zs[y_pred == 5]
if len(z_noise_low) == 0 or len(z_noise_high) == 0:
    raise ValueError(f'Need samples of both noise levels, got {len(z_noise_low)} low and {len(z_noise_high)} high')

avg_z_noise_low = z_noise_low.mean(dim=0)
avg_z_noise_high = z_noise_high.mean(dim=0)

# average of latent codes for noisy - average of latent codes for uncorrupted
noise_direction_z = avg_z_noise_high - avg_z_noise_low


# Show examples: each row moves one random code along the noise direction.
nb_interp = 9
images = []
for i in range(5):
    z = torch.randn(1, G.z_dim, device=device)
    if conditional:
        digits = torch.tensor([8])  # torch.randint(0, 10, (n_images,))
        c = F.one_hot(digits, G.c_dim).to(device)  # class labels
    else:
        c = None

    for alpha in np.linspace(-1, 1, num=nb_interp):
        z_new = z + float(alpha)*noise_direction_z
        # z_new = z_new / z_new.norm() * z.norm()
        img = G(z_new, c)
        images.append(postprocess_images(img).cpu())

images = torch.cat(images, 0)  # channels last

plt.figure(figsize=(15,15))
plt.axis("off")
# make_grid expects channels first; postprocess_images returned channels last.
plt.imshow(vutils.make_grid(images.permute(0, 3, 1, 2), pad_value=255, nrow=nb_interp).permute(1,2,0))

## Find noise direction in W

In [None]:
# load noise classifier (eval mode for deterministic inference)
classifier_noise = CNN_MNIST(output_dim=6).to(device)
classifier_noise.load_state_dict(torch.load(path_results / 'classifiers' / 'CNN_MNIST_noise_weights_20220210_1728.pth', map_location=device)) # Confiance
classifier_noise.eval()

x = generator_output_to_classifier_input(imgs)
with torch.no_grad():
    predictions = classifier_noise(x).cpu()
y_pred = predictions.argmax(dim=1)

# Keep W codes of images with noise = 1 and 5 (masking keeps them 2-D, no squeeze needed).
w_noise_low = ws[y_pred == 1]
w_noise_high = ws[y_pred == 5]
if len(w_noise_low) == 0 or len(w_noise_high) == 0:
    raise ValueError(f'Need samples of both noise levels, got {len(w_noise_low)} low and {len(w_noise_high)} high')

avg_w_noise_low = w_noise_low.mean(dim=0)
avg_w_noise_high = w_noise_high.mean(dim=0)

# average of latent codes for noisy - average of latent codes for uncorrupted
noise_direction_w = (avg_w_noise_high - avg_w_noise_low).to(device)

# Show examples: each row moves one random W code along the noise direction.
nb_interp = 9
images = []
for i in range(5):
    z = torch.randn(1, G.z_dim, device=device)
    if conditional:
        digits = torch.tensor([9])  # torch.randint(0, 10, (n_images,))
        c = F.one_hot(digits, G.c_dim).to(device)  # class labels
    else:
        c = None
    w = G.mapping(z, c, truncation_psi=0.5, truncation_cutoff=8)  # (1, num_ws, w_dim)

    for alpha in np.linspace(-1, 1, num=nb_interp):
        w_new = w + float(alpha)*noise_direction_w  # same shift on every style layer
        # w_new = w_new / w_new.norm() * w.norm()
        img = G.synthesis(w_new, noise_mode='const', force_fp32=True)
        images.append(postprocess_images(img).cpu())

images = torch.cat(images, 0)  # channels last

plt.figure(figsize=(15,15))
plt.axis("off")
# make_grid expects channels first; postprocess_images returned channels last.
plt.imshow(vutils.make_grid(images.permute(0, 3, 1, 2), pad_value=255, nrow=nb_interp).permute(1,2,0))

In [None]:
# Overlaid, normalised histograms of random coordinates: Z on the top row, W below.
fig, axs = plt.subplots(2, 10, figsize=(15, 8))
fig.suptitle('histograms for low noise samples (blue) and high noise samples (red)')
for col in range(10):
    dim = np.random.randint(0, G.z_dim)
    rows = ((axs[0, col], z_noise_low, z_noise_high, 'z'),
            (axs[1, col], w_noise_low, w_noise_high, 'w'))
    for ax, low, high, prefix in rows:
        ax.hist(low[:, dim].cpu().numpy(), bins=100, density=True, color='tab:blue', alpha=0.5)
        ax.hist(high[:, dim].cpu().numpy(), bins=100, density=True, color='tab:red', alpha=0.5)
        ax.set(yticklabels=[])  # remove the tick labels
        ax.tick_params(left=False)  # remove the ticks
        ax.set_title(f'{prefix}_{dim}')

## Domain with low noise samples
https://fr.m.wikipedia.org/wiki/Calcul_de_l%27enveloppe_convexe

https://en.m.wikipedia.org/wiki/Extreme_point

https://en.m.wikipedia.org/wiki/Convex_combination

Convex combinations of the samples only produce new samples that lie very close to their mean.

In [None]:
# x0 = np.array([[1, 1]])
# x1 = np.array([[2, 5]])
# x2 = np.array([[3, 4]])
# x3 = np.array([[2, 3]])


# X = np.concatenate((x0, x1, x2, x3))

# alpha = np.random.uniform(size=(100, len(X)))
# alpha = alpha / np.sum(alpha, axis=1, keepdims=1)

# x_cc = np.expand_dims(alpha[:, 0], 1)*x0 + np.expand_dims(alpha[:, 1], 1)*x1 + np.expand_dims(alpha[:, 2], 1)*x2 + np.expand_dims(alpha[:, 3], 1)*x3

# plt.figure()
# plt.scatter(X[:, 0], X[:, 1])
# plt.scatter(x_cc[:, 0], x_cc[:, 1])

In [None]:
# Keep latent codes of images with noise = 1 (boolean masking keeps a 2-D tensor).
z_noise_low = zs[noise_pred == 1].to(device)
if len(z_noise_low) < 2:
    raise ValueError(f'Need at least 2 low-noise codes to fit a Gaussian, got {len(z_noise_low)}')

# NOTE(review): generate_from_z is called without labels, i.e. this assumes an unconditional G.
images = generate_from_z(z_noise_low)
plot_random_images(images)

# Fit an independent Gaussian per Z coordinate, sample new codes from it and pull
# them halfway toward the mean (psi=0.5).
low_noise_mean = z_noise_low.mean(dim=0)
low_noise_distrib = torch.distributions.normal.Normal(low_noise_mean, z_noise_low.std(dim=0))
z_low_noise_samples = low_noise_distrib.sample((1000,))
z_low_noise_samples = truncate(z_low_noise_samples, low_noise_mean, 0.5)
images = generate_from_z(z_low_noise_samples)
plot_random_images(images)

# Noise level that the classifier assigns to the new samples.
classifier_noise.eval()
x = generator_output_to_classifier_input(images)
with torch.no_grad():
    predictions = classifier_noise(x).cpu()
y_pred = predictions.argmax(dim=1)

plt.figure()
sns.countplot(x=y_pred.numpy())

In [None]:
# Keep latent codes of images with noise = 5 (boolean masking keeps a 2-D tensor).
z_noise_high = zs[noise_pred == 5].to(device)
if len(z_noise_high) < 2:
    raise ValueError(f'Need at least 2 high-noise codes to fit a Gaussian, got {len(z_noise_high)}')

# NOTE(review): generate_from_z is called without labels, i.e. this assumes an unconditional G.
images = generate_from_z(z_noise_high)
plot_random_images(images)


# Fit an independent Gaussian per Z coordinate and sample new codes from it
# (psi=1: no pull toward the mean).
high_noise_mean = z_noise_high.mean(dim=0)
high_noise_distrib = torch.distributions.normal.Normal(high_noise_mean, z_noise_high.std(dim=0))
z_high_noise_samples = high_noise_distrib.sample((1000,))
z_high_noise_samples = truncate(z_high_noise_samples, high_noise_mean, 1)
images = generate_from_z(z_high_noise_samples)
plot_random_images(images)

# Noise level that the classifier assigns to the new samples.
classifier_noise.eval()
x = generator_output_to_classifier_input(images)
with torch.no_grad():
    predictions = classifier_noise(x).cpu()
y_pred = predictions.argmax(dim=1)

plt.figure()
sns.countplot(x=y_pred.numpy())


## Truncation

In [None]:
n_images = 5
z = torch.randn([n_images, G.z_dim], device=device)    # latent codes
if conditional:
    digits = torch.randint(0, G.c_dim, (n_images,))
    c = F.one_hot(digits, G.c_dim).to(device)  # F.one_hot takes no device argument
else:
    c = None

# Untruncated W codes for this demo. They go in a separate name so the global `ws`
# (the sampled W codes) is not overwritten for the following cells.
ws_demo = G.mapping(z, c, truncation_psi=1, truncation_cutoff=8)[:, 0, :]
w_avg = G.mapping.w_avg  # average W used by the truncation trick

# Show examples: row i moves code i along the line through w_avg, psi in [-psi_max, psi_max].
psi_max = 2
nb_interp = 9
images = []
for i in range(n_images):
    for psi in np.linspace(-psi_max, psi_max, num=nb_interp):
        w_new = w_avg + psi*(ws_demo[i, :]-w_avg)
        w_new = w_new.repeat(1, G.synthesis.num_ws, 1)  # (1, num_ws, w_dim)
        img = G.synthesis(w_new, noise_mode='const', force_fp32=True)
        images.append(postprocess_images(img).cpu())

images = torch.cat(images, 0)

plt.figure(figsize=(15,15))
plt.axis("off")
plt.imshow(vutils.make_grid(images.permute((0, 3, 1, 2)), pad_value=255, nrow=nb_interp).permute(1,2,0))

In [None]:
# Sweep psi over [-psi_max, psi_max] for 1000 W codes drawn (with replacement) from `ws`;
# |psi| > 1 extrapolates away from w_avg.
psi_max = 10
nb_interp = 30
idx = torch.randint(len(ws), (1000,))
w = ws[idx, :]
images_per_psi, w_per_psi = [], []
for psi in np.linspace(-psi_max, psi_max, num=nb_interp):
    w_new_temp = w_avg + psi*(w-w_avg)
    # Feed the same w to every style layer of the synthesis network.
    img = G.synthesis(w_new_temp.unsqueeze(1).repeat(1, G.synthesis.num_ws, 1), noise_mode='const', force_fp32=True)
    images_per_psi.append(img)
    w_per_psi.append(w_new_temp)
# Concatenate once (psi-major order) instead of growing tensors inside the loop.
images = torch.cat(images_per_psi)  # generator scale, NCHW
w_new = torch.cat(w_per_psi)


In [None]:
# Digit-classifier logits for every swept image. They are computed in chunks and
# without autograd, because the sweep produces a large number of images.
classifier_digits.eval()
x = generator_output_to_classifier_input(images)
with torch.no_grad():
    predictions = torch.cat([classifier_digits(chunk).cpu() for chunk in torch.split(x, 1000)])

In [None]:
# Distance of each swept W code from w_avg against the classifier's top softmax probability.
distance = torch.cdist(w_new, w_avg.unsqueeze(0)).squeeze()
dist_from_center = distance.cpu()
probs = F.softmax(predictions, dim=1)
confidence = probs.max(dim=1).values.cpu()

plt.figure()
plt.scatter(dist_from_center, confidence, alpha=0.01)

## Convex combinations

In [None]:
def convex_combination(alpha, X):
    """Return sum_i alpha[i] * X[i, :], the rows of X weighted by alpha.

    The weights are used as given; callers normalise them to sum to 1 to obtain a
    true convex combination. Returns 0 when X has no rows.
    """
    assert alpha.shape[0] == X.shape[0]
    return sum((alpha[row] * X[row, :] for row in range(X.shape[0])), 0)

In [None]:
# Toy 2-D illustration: random convex combinations (blue) of 100 integer points
# (orange), expected to cluster near the points' mean.
corner_cases = np.random.randint(0, 10, (100, 2))

n_points = 100
x_cc = np.zeros((n_points, 2))
for row in range(n_points):
    weights = np.random.uniform(size=(len(corner_cases)))
    weights /= weights.sum()  # normalise so the weights sum to 1
    # weights = 0.1 * weights
    x_cc[row, :] = convex_combination(weights, corner_cases)

plt.figure()
plt.scatter(x_cc[:, 0], x_cc[:, 1])
plt.scatter(corner_cases[:, 0], corner_cases[:, 1])

In [None]:
# Scratch check: display a (50, 10) array of uniform samples from [0, 1).
np.random.uniform(size=(50, 10))

In [None]:
from scipy.spatial import ConvexHull
# Scratch check: indices of the hull vertices for 100 random points in the 5-D unit cube.
points = np.random.uniform(size=(100, 5))
hull = ConvexHull(points)
hull.vertices

In [None]:
# Generate images from random convex combinations of the low-noise Z codes.
nb_img = 100
corner_cases = z_noise_low.cpu().float()  # (n_codes, z_dim)
# One weight row per new image, normalised so each row sums to 1.
alpha = torch.from_numpy(np.random.uniform(size=(nb_img, len(corner_cases)))).float()
alpha = alpha / alpha.sum(dim=1, keepdim=True)
# (nb_img, n_codes) @ (n_codes, z_dim): each row is a convex combination of the codes.
z_new = (alpha @ corner_cases).to(device)

# Generate images
images = generate_from_z(z_new)
plot_random_images(images)

## Project

In [None]:
def run_projection(
    network_pkl: str,
    target_fname: str,
    outdir: str,
    save_video: bool,
    seed: int,
    num_steps: int
):
    """Project a target image into the W space of a pretrained generator pickle.

    Adapted from projector.py for use inside the notebook. Writes into `outdir`:
    target.png, proj.png, projected_w.npz and, when `save_video` is True, proj.mp4
    showing the optimisation progress next to the target.

    Args:
        network_pkl: path or URL of the network pickle (opened with dnnlib.util.open_url).
        target_fname: image to project; it is center-cropped to a square, converted
            to RGB and resized to the generator resolution.
        outdir: output directory (created if missing).
        save_video: also write proj.mp4 (needs the `imageio` package, codec libx264).
        seed: seed for the numpy and torch RNGs.
        num_steps: number of optimisation steps passed to `project`.
    """
    from time import perf_counter  # used for timing but was never imported

    np.random.seed(seed)
    torch.manual_seed(seed)

    # Load networks. NOTE: always uses CUDA, independently of the notebook-level `device`.
    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as fp:
        G = legacy.load_network_pkl(fp)['G_ema'].requires_grad_(False).to(device) # type: ignore

    # Load target image: center crop to a square, then resize to the generator resolution.
    target_pil = PIL.Image.open(target_fname).convert('RGB')
    w, h = target_pil.size
    s = min(w, h)
    target_pil = target_pil.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2))
    target_pil = target_pil.resize((G.img_resolution, G.img_resolution), PIL.Image.LANCZOS)
    target_uint8 = np.array(target_pil, dtype=np.uint8)

    # Optimize projection.
    start_time = perf_counter()
    projected_w_steps = project(
        G,
        target=torch.tensor(target_uint8.transpose([2, 0, 1]), device=device), # pylint: disable=not-callable
        num_steps=num_steps,
        device=device,
        verbose=True
    )
    print (f'Elapsed: {(perf_counter()-start_time):.1f} s')

    # Render debug output: optional video and projected image and W vector.
    os.makedirs(outdir, exist_ok=True)
    if save_video:
        import imageio  # only needed for the video; imported lazily (was never imported)
        video = imageio.get_writer(f'{outdir}/proj.mp4', mode='I', fps=10, codec='libx264', bitrate='16M')
        print (f'Saving optimization progress video "{outdir}/proj.mp4"')
        for projected_w in projected_w_steps:
            synth_image = G.synthesis(projected_w.unsqueeze(0), noise_mode='const')
            synth_image = (synth_image + 1) * (255/2)
            synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
            video.append_data(np.concatenate([target_uint8, synth_image], axis=1))
        video.close()

    # Save final projected frame and W vector.
    target_pil.save(f'{outdir}/target.png')
    projected_w = projected_w_steps[-1]
    synth_image = G.synthesis(projected_w.unsqueeze(0), noise_mode='const')
    synth_image = (synth_image + 1) * (255/2)
    synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
    PIL.Image.fromarray(synth_image, 'RGB').save(f'{outdir}/proj.png')
    np.savez(f'{outdir}/projected_w.npz', w=projected_w.unsqueeze(0).cpu().numpy())

In [None]:
# # Load networks.
# network_pkl = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl'
# print('Loading networks from "%s"...' % network_pkl)
# device = torch.device('cuda')
# with dnnlib.util.open_url(network_pkl) as fp:
#     # G = legacy.load_network_pkl(fp)['G_ema'].requires_grad_(False).to(device) # type: ignore
#     G = pickle.load(fp)['G_ema'].requires_grad_(False).to(device) # type: ignore
        
        
#         # Load target image.
# target_pil = PIL.Image.open('face.png').convert('RGB')
# w, h = target_pil.size
# s = min(w, h)
# target_pil = target_pil.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2))
# target_pil = target_pil.resize((G.img_resolution, G.img_resolution), PIL.Image.LANCZOS)
# target_uint8 = np.array(target_pil, dtype=np.uint8)

# # Optimize projection.
# # start_time = perf_counter()
# projected_w_steps = project(
#     G,
#     target=torch.tensor(target_uint8.transpose([2, 0, 1]), device=device), # pylint: disable=not-callable
#     # num_steps=num_steps,
#     device=device,
#     verbose=True
# )
# # print (f'Elapsed: {(perf_counter()-start_time):.1f} s')

In [None]:
# url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
# with dnnlib.util.open_url(url) as f:
#     print('bla')
#     print(f)
#     # vgg16 = torch.jit.load(f).eval().to(device)

In [None]:
# import dnnlib

In [None]:
# with dnnlib.util.open_url(url) as f:
#     vgg16 = torch.jit.load(f)

In [None]:
import torch

In [None]:
# `vgg16` was only assigned in commented-out scratch code, so load it explicitly here
# (the metrics VGG16 TorchScript model from the StyleGAN2-ADA CDN).
vgg16_url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
with dnnlib.util.open_url(vgg16_url) as f:
    vgg16 = torch.jit.load(f)
vgg16.eval().to(device)