In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="1"

import pickle
from pathlib import Path
import json
import PIL.Image
import torch
import torch.nn.functional as F
import torchvision.utils as vutils
from torch.utils.data import DataLoader, Dataset
import torchvision
import imageio
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_theme()
from sklearn.decomposition import PCA, IncrementalPCA
from sklearn.manifold import TSNE
from scipy.stats import wasserstein_distance
import numpy as np
import functools
import pandas as pd
# from tqdm import tqdm


from stylegan2_ada_pytorch.torch_utils import misc
import stylegan2_ada_pytorch.dnnlib
import stylegan2_ada_pytorch.legacy
from stylegan2_ada_pytorch.projector import project
from stylegan2_ada_pytorch.training.dataset import ImageFolderDataset
from classifiers.models import CNN_MNIST

# Seed the torch and numpy RNGs so the random latents / labels sampled below are reproducible.
torch.manual_seed(0)
np.random.seed(0)

# Use the GPU when available (CUDA_VISIBLE_DEVICES above selects which physical GPU).
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Root folder for trained generators, classifiers and projected data (../results from the notebook).
path_results = Path.cwd().parent / 'results'
# path_results = Path('w:/results/stylegan2')

In [None]:
def postprocess_images(images):
    """Convert raw generator output into classifier-ready images.

    Values are mapped from the generator range [-1, 1] to [0, 1] (anything
    outside is clamped), then only rows/columns 2..29 are kept, which removes
    the padding of 32x32 outputs and yields 28x28 images.

    Args:
        images: 4-D tensor laid out as B x C x H x W.

    Returns:
        Cropped tensor with values in [0, 1].
    """
    assert images.dim() == 4, "Expected 4D (B x C x H x W) image tensor, got {}D".format(images.dim())
    unit_range = ((images + 1) / 2).clamp(0, 1)
    # remove padding
    return unit_range[:, :, 2:30, 2:30]

def plot_images(images, title=''):
    """Display a batch of images (values in [0, 1]) as one grid in a new figure.

    Args:
        images: B x C x H x W tensor with values in [0, 1]; converted to uint8
            (multiplied by 255 and truncated) before plotting.
        title: figure title.
    """
    as_bytes = (images * 255).to(torch.uint8).cpu()
    # white (255) separators between tiles; HWC layout for imshow
    grid = vutils.make_grid(as_bytes, pad_value=255).permute(1, 2, 0)
    plt.figure()
    plt.imshow(grid, vmin=0, vmax=255)
    plt.axis('off')
    plt.grid(False)
    plt.title(title)


def generate_from_z(z, c=None, batch_size=64):
    """Generate images from Z latents with the global generator ``G``, in mini-batches.

    Args:
        z: latent codes; the first dimension indexes images.
        c: optional conditioning labels, sliced alongside ``z``. ``None`` keeps the
            original unconditional call (a conditional G needs labels here).
        batch_size: number of codes per forward pass. Previously this was read from
            a global defined much later in the notebook.

    Returns:
        Generator outputs for all codes, concatenated along the batch dimension.

    Raises:
        ValueError: if ``z`` is empty or ``batch_size`` is not positive.
    """
    if z.shape[0] == 0:
        raise ValueError('generate_from_z needs at least one latent code')
    if batch_size <= 0:
        raise ValueError(f'batch_size must be positive, got {batch_size}')

    chunks = []
    for start in range(0, z.shape[0], batch_size):
        stop = start + batch_size
        c_batch = None if c is None else c[start:stop]
        chunks.append(G(z[start:stop], c=c_batch, noise_mode='const', force_fp32=True))
    # single concatenation instead of growing a tensor inside the loop
    return torch.cat(chunks)


def plot_random_images(imgs, n_samples=100, nrow=10):
    """Plot a grid of images drawn at random (with replacement) from raw generator output.

    Args:
        imgs: raw generator output (4-D), converted with ``postprocess_images``.
        n_samples: number of images drawn; default 100 as before.
        nrow: images per grid row; default 10 as before.
    """
    imgs = postprocess_images(imgs).detach()
    picked = imgs[torch.randint(0, imgs.shape[0], (n_samples,))]
    # Images are floats in [0, 1]: pad with 1.0 (white). pad_value=255 made
    # imshow clip the grid and emit a warning.
    grid = vutils.make_grid(picked.cpu(), pad_value=1.0, nrow=nrow)
    plt.figure(figsize=(10, 10))
    plt.axis("off")
    plt.imshow(grid.permute(1, 2, 0))

def plot_images_from_s(s, title=''):
    """Synthesize images from StyleSpace values ``s`` and plot them as one grid.

    Args:
        s: StyleSpace values in any form accepted by ``generate_img_from_s``
            (flat tensor or per-layer dict).
        title: figure title. The default '' matches the previous behaviour, so
            callers that set ``plt.title`` afterwards still work.
    """
    with torch.no_grad():  # plotting only: no autograd graph needed
        imgs = generate_img_from_s(s)
    plot_images(postprocess_images(imgs), title=title)
    
def truncate(x, x_avg, psi):
    """Truncation trick: linearly pull ``x`` towards ``x_avg``.

    psi = 0 returns ``x_avg``, psi = 1 returns ``x`` unchanged, and values in
    between interpolate between the two.
    """
    return torch.lerp(x_avg, x, psi)


def styleSpace_dict2vec(styleSpace_dict):
    """Flatten a per-layer StyleSpace dict into one (batch, total_dim) tensor.

    Keys are read in synthesis order: for each resolution of
    ``G.synthesis.block_resolutions``, 'conv0' (absent for the 4x4 block),
    then 'conv1', then 'torgb'. A 1-D entry is treated as a batch of one.
    """
    chunks = []
    for res in G.synthesis.block_resolutions:
        layer_names = ['conv1', 'torgb'] if res == 4 else ['conv0', 'conv1', 'torgb']
        for layer in layer_names:
            values = styleSpace_dict[f'b{res}.{layer}']
            chunks.append(values.unsqueeze(0) if values.dim() == 1 else values)
    return torch.cat(chunks, dim=1)


def styleSpace_vec2dict(styleSpace_vec):
    """Split a flat StyleSpace tensor into per-layer slices keyed 'b{res}.{layer}'.

    Inverse of ``styleSpace_dict2vec``: layers are visited in the same order,
    and each slice is as wide as that layer's style vector.

    Args:
        styleSpace_vec: (batch, total_dim) tensor, or 1-D (treated as batch 1).

    Returns:
        Dict of (batch, layer_dim) views into ``styleSpace_vec``.

    Raises:
        AssertionError: if the vector width does not match the generator's total style width.
    """
    if styleSpace_vec.dim() == 1:
        styleSpace_vec = styleSpace_vec.unsqueeze(0)
    styleSpace_dict = {}
    dim_base = 0
    for res in G.synthesis.block_resolutions:
        block = getattr(G.synthesis, f'b{res}')
        for layer in ['conv0', 'conv1', 'torgb']:
            if res == 4 and layer == 'conv0': continue
            block_layer = getattr(block, layer)
            # affine.weight is [out_features, in_features]. The style width is the
            # affine OUTPUT size (shape[0]). shape[1] is w_dim, which only coincides
            # when a layer has exactly w_dim channels.
            dim_size = block_layer.affine.weight.shape[0]
            key = f'b{res}.{layer}'
            styleSpace_dict[key] = styleSpace_vec[:, dim_base:dim_base+dim_size]
            dim_base += dim_size
    assert dim_base == styleSpace_vec.shape[1]
    return styleSpace_dict


def compute_styleSpace_vec_idx2coord():
    """Map every flat StyleSpace index to its (layer key, channel) coordinate.

    Layers are enumerated in the same order as ``styleSpace_dict2vec`` /
    ``styleSpace_vec2dict``, so index i of a flat vector corresponds to
    channel ``coord[1]`` of layer ``coord[0]``.

    Returns:
        Dict ``{flat_index: ('b{res}.{layer}', channel)}``.
    """
    vec_idx2coord = {}
    idx = 0
    for res in G.synthesis.block_resolutions:
        block = getattr(G.synthesis, f'b{res}')
        for layer in ['conv0', 'conv1', 'torgb']:
            if res == 4 and layer == 'conv0': continue
            block_layer = getattr(block, layer)
            # Style width = affine output size = weight.shape[0]
            # (weight is [out_features, in_features]; shape[1] is w_dim).
            dim_size = block_layer.affine.weight.shape[0]
            key = f'b{res}.{layer}'
            for dim in range(dim_size):
                vec_idx2coord[idx] = (key, dim)
                idx += 1
    return vec_idx2coord

In [None]:
# Generator to analyse: either a URL to an official pre-trained pickle, or a local
# training-run directory (then the snapshot with the lowest FID is loaded).
# path_model = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/afhqcat.pkl'
# path_model = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/afhqdog.pkl'
# path_model = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/afhqwild.pkl'
# path_model = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/brecahad.pkl'
# path_model = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/cifar10.pkl'
# path_model = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl'
# path_model = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl'

# path_model = path_results / 'stylegan2-training-runs' / '00011-mnist_stylegan2_noise-cond-auto4-original'
# path_model = path_results / 'stylegan2-training-runs' / '00015-mnist_stylegan2_blur_noise-cond-auto4'
path_model = path_results / 'stylegan2-training-runs' / '00016-mnist_stylegan2_blur_noise_maxSeverity3_proba50-cond-auto4'

# find best model in folder
if not str(path_model).endswith('pkl'): # local files
    # One JSON record per evaluated snapshot (blank lines are skipped).
    fid_file = path_model / 'metric-fid50k_full.jsonl'
    with open(fid_file, 'r') as json_file:
        fid_records = [json.loads(line) for line in json_file if line.strip()]
    if not fid_records:
        raise ValueError(f'No FID records found in {fid_file}')

    # min() returns the first record among ties, like a strict '<' scan.
    best_record = min(fid_records, key=lambda rec: rec['results']['fid50k_full'])
    best_fid = best_record['results']['fid50k_full']
    best_model = best_record['snapshot_pkl']
    print('Best FID: {:.2f} ; best model : {}'.format(best_fid, best_model))
    path_model = path_model / best_model

    # NOTE: pickle.load can execute arbitrary code; only load trusted snapshots.
    with open(path_model, 'rb') as f:
        G = pickle.load(f)['G_ema'].to(device)  # torch.nn.Module

else: # download pre-trained
    # `import stylegan2_ada_pytorch.dnnlib` / `.legacy` only bind the package name,
    # so the submodules must be reached through it (bare `dnnlib` / `legacy` are undefined).
    with stylegan2_ada_pytorch.dnnlib.util.open_url(path_model) as f:
        G = stylegan2_ada_pytorch.legacy.load_network_pkl(f)['G_ema'].to(device)

# On CPU, force fp32 for every call to G(...).
if device == 'cpu': G.forward = functools.partial(G.forward, force_fp32=True)

conditional = G.c_dim > 0

# Register hooks on G that record intermediate values on every pass:
#   intermediate_images_torgb[res]      <- output of G.synthesis.b{res}.torgb
#   intermediate_images_block[res]      <- output[1] of block G.synthesis.b{res}
#   styleSpace_values['b{res}.{layer}'] <- output of that layer's `affine` (forward)
#   styleSpace_grads['b{res}.{layer}']  <- grad_output[0] of that `affine` (backward)
# Each dict is overwritten by the next pass, so read it right after the pass of interest.
intermediate_images_torgb = {}
def get_torgb(name):
    def hook(module, input, output):
        intermediate_images_torgb[name] = output.detach()
    return hook
intermediate_images_block = {}
def get_block_img(name):
    def hook(module, input, output):
        intermediate_images_block[name] = output[1].detach()
    return hook
styleSpace_values = {}
def get_styleSpace_values(name):
    def hook(module, input, output):
        styleSpace_values[name] = output.detach()
    return hook
styleSpace_grads = {}
def get_styleSpace_grads(name):
    def hook(module, grad_input, grad_output):
        styleSpace_grads[name] = grad_output[0].detach()
    return hook

# Single registration loop, so every hook is attached exactly once and fires
# once per forward/backward pass.
for res in G.synthesis.block_resolutions:
    block = getattr(G.synthesis, f'b{res}')
    block.torgb.register_forward_hook(get_torgb(res))
    block.register_forward_hook(get_block_img(res))
    for layer in ['conv0', 'conv1', 'torgb']:
        if res == 4 and layer == 'conv0': continue  # b4 is handled without a conv0 layer
        affine = getattr(block, layer).affine
        key = f'b{res}.{layer}'
        affine.register_forward_hook(get_styleSpace_values(name=key))
        affine.register_full_backward_hook(get_styleSpace_grads(name=key))


# dict to convert index to coordinate for stylespace vectors
styleSpace_vec_idx2coord = compute_styleSpace_vec_idx2coord()


# function to move a given style dimension
def generate_img_new_style(ws, block_layer_name, index=0, direction=1):
    """Synthesize images after adding a constant to one style channel.

    A temporary forward hook on ``G.synthesis.<block>.<layer>.affine`` adds
    ``direction`` to column ``index`` of that layer's style output for every
    image in the batch. The hook is removed in a ``finally`` block, so an
    exception during synthesis cannot leave G permanently modified.

    Args:
        ws: W latents; a 2-D tensor is repeated over all ``G.num_ws`` layer inputs.
        block_layer_name: layer key such as ``'b4.conv1'``.
        index: channel (or index expression) within that layer's style vector.
        direction: amount added to the selected channel(s).

    Returns:
        Raw ``G.synthesis`` output.
    """
    def move_style(index, direction):
        def hook(module, input, output):
            output[:, index] += direction
            return output
        return hook

    block_name, layer_name = block_layer_name.split('.')
    block_layer = getattr(getattr(G.synthesis, block_name), layer_name)

    if ws.dim() == 2:
        ws = ws.unsqueeze(1).repeat((1, G.num_ws, 1))

    handle = block_layer.affine.register_forward_hook(move_style(index, direction))
    try:
        return G.synthesis(ws, noise_mode='const', force_fp32=True)
    finally:
        handle.remove()
    
    
# function to modify a given style dimension
def generate_img_new_style2(ws, block_layer_name, index, s_style_min, s_style_max, s_shift=1, positive_direction=True):
    """Synthesize one image with a single style channel moved towards an extreme value.

    A first (no-grad) pass reads the channel's current value through the
    ``styleSpace_values`` hook. The second pass adds
    ``s_shift * (target - current)``, where target is ``s_style_max`` (positive
    direction) or ``s_style_min``. The temporary hook is removed in a
    ``finally`` block, even if synthesis raises.

    Args:
        ws: W latents for exactly one image; a 2-D tensor is repeated over all ``G.num_ws`` inputs.
        block_layer_name: layer key such as ``'b4.conv1'``.
        index: single channel index (int).
        s_style_min, s_style_max: target values for the negative / positive direction.
        s_shift: fraction of the distance to the target to move (1 = reach it).
        positive_direction: move towards ``s_style_max`` if True, else ``s_style_min``.

    Returns:
        Raw ``G.synthesis`` output.
    """
    def move_style(index, weight_shift):
        def hook(module, input, output):
            output[:, index] += weight_shift
            return output
        return hook

    assert type(index) == int, 'Function only works for 1 style'
    assert ws.shape[0] == 1, 'Works only for 1 image' # orig_value only for 1 image

    if ws.dim() == 2:
        ws = ws.unsqueeze(1).repeat((1, G.num_ws, 1))

    with torch.no_grad():
        G.synthesis(ws, noise_mode='const', force_fp32=True) # first pass to get style vector from hook
    orig_value = styleSpace_values[block_layer_name][0, index]
    target_value = (s_style_max if positive_direction else s_style_min)
    weight_shift = s_shift * (target_value - orig_value)

    block_name, layer_name = block_layer_name.split('.')
    block_layer = getattr(getattr(G.synthesis, block_name), layer_name)
    handle = block_layer.affine.register_forward_hook(move_style(index, weight_shift))
    try:
        return G.synthesis(ws, noise_mode='const', force_fp32=True)
    finally:
        handle.remove()


# function to modify a given style dimension
def generate_img_new_style3(ws, block_layer_name, index, s_std, strength=5, positive_direction=True):
    """Synthesize one image with a single style channel shifted by ``strength * s_std``.

    The shift is added (``positive_direction=True``) or subtracted, via a
    temporary forward hook on the layer's ``affine``. The hook is removed in a
    ``finally`` block, even if synthesis raises.

    Args:
        ws: W latents for exactly one image; a 2-D tensor is repeated over all ``G.num_ws`` inputs.
        block_layer_name: layer key such as ``'b4.conv1'``.
        index: single channel index (int).
        s_std: scale of the shift (presumably the channel's standard deviation — confirm at call site).
        strength: multiplier applied to ``s_std``.
        positive_direction: sign of the shift.

    Returns:
        Raw ``G.synthesis`` output.
    """
    def move_style(index, weight_shift):
        def hook(module, input, output):
            output[:, index] += weight_shift
            return output
        return hook

    assert type(index) == int, 'Function only works for 1 style'
    assert ws.shape[0] == 1, 'Works only for 1 image'

    if ws.dim() == 2:
        ws = ws.unsqueeze(1).repeat((1, G.num_ws, 1))

    d = 1 if positive_direction else -1
    weight_shift = d * strength * s_std

    block_name, layer_name = block_layer_name.split('.')
    block_layer = getattr(getattr(G.synthesis, block_name), layer_name)
    handle = block_layer.affine.register_forward_hook(move_style(index, weight_shift))
    try:
        return G.synthesis(ws, noise_mode='const', force_fp32=True)
    finally:
        handle.remove()


# function to generate image from S
def generate_img_from_s(s):
    """Synthesize images whose StyleSpace values are set directly to ``s``.

    Temporary forward hooks replace the output of every style ``affine`` layer
    with the matching entry of ``s``. An all-zeros dummy W is passed because the
    injected styles override the affine outputs. All hooks are removed in a
    ``finally`` block. Otherwise an exception would leave them installed and
    every later ``G.synthesis`` call would silently use these styles.

    Args:
        s: dict keyed 'b{res}.{layer}' with 2-D (batch_size x s_dim) tensors, or
            a flat StyleSpace tensor accepted by ``styleSpace_vec2dict``.

    Returns:
        Raw ``G.synthesis`` output for the batch.
    """
    def set_style(values):
        def hook(module, input, output):
            return values
        return hook

    if not isinstance(s, dict): s = styleSpace_vec2dict(s)
    assert s['b4.conv1'].dim() == 2, 'Should be of 2 dimensions: batch_size x s_dim'
    batch_size = s['b4.conv1'].shape[0]

    handles = []
    try:
        for res in G.synthesis.block_resolutions:
            block = getattr(G.synthesis, f'b{res}')
            for layer in ['conv0', 'conv1', 'torgb']:
                if res == 4 and layer == 'conv0': continue
                block_layer = getattr(block, layer)
                values = s[f'b{res}.{layer}']
                handles.append(block_layer.affine.register_forward_hook(set_style(values)))

        dummy_ws = torch.zeros((batch_size, G.num_ws, G.w_dim), device=device)
        return G.synthesis(dummy_ws, noise_mode='const', force_fp32=True)
    finally:
        for h in handles: h.remove()

In [None]:
# Sample a few random latents (with random digit labels when G is conditional),
# print a summary of G's modules, and show the generated images.
n_images = 5
z = torch.randn([n_images, G.z_dim], device=device)    # latent codes
if conditional:
    digits = torch.randint(0, G.c_dim, (n_images,), device=device)
    c = F.one_hot(digits, G.c_dim)  # one-hot conditioning labels
else:
    c = None
misc.print_module_summary(G, [z, c])

# Map (z, c) to W, then synthesize; truncation_psi=1 means no truncation towards the mean.
ws = G.mapping(z, c, truncation_psi=1)
img = G.synthesis(ws, noise_mode='const', force_fp32=True)
img = postprocess_images(img)
plot_images(img, title='original')


# Same latents with style channel 3 of layer b4.conv1 shifted by -1.
img_ = generate_img_new_style(ws, block_layer_name='b4.conv1', index=3, direction=-1)
img_ = postprocess_images(img_)
plot_images(img_, title='with new style')

# Load classifiers (digits and noise)

In [None]:
# predict digits
classifier_digits = CNN_MNIST(output_dim=10).to(device)
# classifier_digits.load_state_dict(torch.load(path_results / 'classifiers' / 'CNN_MNIST_weights_20220411_0826.pth', map_location=device)) # Confiance
# classifier_digits.load_state_dict(torch.load(path_results / 'classifiers' / 'CNN_MNIST_weights_20220210_1601.pth', map_location=device))
classifier_digits.load_state_dict(torch.load(path_results / 'classifiers' / 'CNN_mnist_stylegan2_blur_noise_maxSeverity3_proba50_20220510_1124.pth', map_location=device))
classifier_digits.eval()

# predict noise
classifier_noise = CNN_MNIST(output_dim=6).to(device)
# classifier_noise.load_state_dict(torch.load(path_results / 'classifiers' / 'CNN_noise_MNIST_weights_20220411_0841.pth', map_location=device)) # Confiance
classifier_noise.load_state_dict(torch.load(path_results / 'classifiers' / 'CNN_MNIST_noise_weights_20220210_1728.pth', map_location=device))
classifier_noise.eval()

# Classify the images generated from `ws` (previous cell). Inference only, so no
# autograd graph is built, and the tensors can be passed straight to matplotlib.
with torch.no_grad():
    imgs = G.synthesis(ws, noise_mode='const', force_fp32=True)
    imgs = postprocess_images(imgs)
    digit_pred = classifier_digits(imgs).argmax(dim=1).cpu()
    noise_pred = classifier_noise(imgs).argmax(dim=1).cpu()

# One panel per displayed image (at most 5) rather than a half-empty 10-column row.
n_show = min(n_images, 5)
plt.figure(figsize=(15, 5))
for i in range(n_show):
    plt.subplot(1, n_show, i+1)
    plt.imshow(imgs[i].cpu().squeeze(), cmap='gray')
    plt.title(f'digit: {digit_pred[i].numpy()} \n noise: {noise_pred[i].numpy()}')
    plt.axis('off')

# Project a test image into the generator's W space


In [None]:
# # load data
# transforms = torchvision.transforms.Compose([
#     torchvision.transforms.ToTensor(),
#     torchvision.transforms.Resize(32),
#     torchvision.transforms.Lambda(lambda x: 255*x)
# ])
# mnist_test = torchvision.datasets.MNIST(root='../data', train=False, transform=transforms)
#
# idx = 3
# img = mnist_test[idx][0]
# digit = mnist_test[idx][1]


dataset = 'mnistTest_stylegan2_blur_noise_maxSeverity3_proba50'
# dataset = 'mnistTest_stylegan2'

# original data
path_data = Path.cwd().parent / 'data/MNIST' / f'{dataset}.zip'
ds_original = ImageFolderDataset(path_data, use_labels=True)

# Pick one image; labels are stored one-hot, so argmax recovers the digit.
idx = 1
img = ds_original[idx][0]
digit = ds_original[idx][1].argmax()

img = torch.tensor(img, device=device)
digits_ = torch.tensor(digit, device=device)
c = F.one_hot(digits_, G.c_dim)

# project (the last element of projected_w_steps is the final estimate)
projected_w_steps = project(G, img, c, classifier_digits=classifier_digits, regularize_classif_weight=0, device=device, num_steps=1000)
# projected_w_steps = project(G, img, c, device=device, num_steps=3000)

# Output folder. Created unconditionally: the save calls below fail if it is missing.
outdir = 'proj_out'
os.makedirs(outdir, exist_ok=True)

target_pil = torchvision.transforms.ToPILImage()(img.repeat((3, 1, 1))/255)
target_uint8 = np.array(target_pil, dtype=np.uint8)

# # Render debug output: optional video of the optimization.
# video = imageio.get_writer(f'{outdir}/proj.mp4', mode='I', fps=10, codec='libx264', bitrate='16M')
# print (f'Saving optimization progress video "{outdir}/proj.mp4"')
# for projected_w in projected_w_steps:
#     synth_image = G.synthesis(projected_w.unsqueeze(0), noise_mode='const')
#     synth_image = synth_image.repeat((1, 3, 1, 1))
#     synth_image = (synth_image + 1) * (255/2)
#     synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
#     video.append_data(np.concatenate([target_uint8, synth_image], axis=1))
# video.close()

# Save final projected frame and W vector.
target_pil.save(f'{outdir}/target.png')
projected_w = projected_w_steps[-1]
with torch.no_grad():
    synth_image = G.synthesis(projected_w.unsqueeze(0), noise_mode='const')
# grayscale -> 3 channels, [-1, 1] -> [0, 255] uint8, HWC numpy for PIL
synth_image = synth_image.repeat((1, 3, 1, 1))
synth_image = (synth_image + 1) * (255/2)
synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
PIL.Image.fromarray(synth_image, 'RGB').save(f'{outdir}/proj.png')
np.savez(f'{outdir}/projected_w.npz', w=projected_w.detach().unsqueeze(0).cpu().numpy())

In [None]:
# Compare digit-classifier logits on the target image and on its projection.
# Both are cropped to rows/cols 2..29 and scaled from [0, 255] to [0, 1].
x_target = (img.unsqueeze(0)).to(device)
x_target = x_target[:, :, 2:30, 2:30] / 255

# synth_image is H x W x 3 uint8 with identical channels; keep one -> 1 x 1 x H x W
x_proj = torch.from_numpy(synth_image[:, :, 0]).unsqueeze(0).unsqueeze(0).to(device)
x_proj = x_proj[:, :, 2:30, 2:30] / 255

# Evaluate each input once and reuse the logits.
with torch.no_grad():
    logits_target = classifier_digits(x_target)
    logits_proj = classifier_digits(x_proj)
print(logits_target)
print(logits_proj)

# squared L2 distance between the two logit vectors
dist = (logits_target - logits_proj).square().sum()
print(dist)

In [None]:
# Deliberate halt so "Run All" stops here; the cells below load precomputed projections from disk.
raise RuntimeError('Stopping execution here on purpose; run the cells below manually.')

## Read projected data

In [None]:
# Dataset whose projections are analysed below; must match the projected file name.
dataset = 'mnistTest_stylegan2_blur_noise_maxSeverity3_proba50'
# dataset = 'mnistTest_stylegan2'
# options = ''
options = '_classifWeight0.001'  # file-name suffix of the projection run (presumably the classifier regularisation weight — confirm)

# original data
path_data = Path.cwd().parent / 'data/MNIST' / f'{dataset}.zip'
ds_original = ImageFolderDataset(path_data, use_labels=True)

# projected latent codes: one W vector per dataset image, with its digit label
file = path_results / 'projected_data' / f'{dataset}_model00016{options}_all.npz'
projected_data = np.load(file)
projected_w = torch.from_numpy(projected_data['projected_w']).to(device)
digits = torch.from_numpy(projected_data['digits']).to(device)

# Sanity checks: projections cover the whole dataset in order, with matching labels
# (dataset labels are one-hot, hence argmax).
assert all(projected_data['indices'] == np.arange(len(ds_original)))
assert all(projected_data['digits'] == [ds_original[i][1].argmax() for i in range(len(ds_original))])



# PREDICT CLASSES OF PROJECTED DATA
# Re-synthesize every projected latent (inference only), classify the image and
# keep its StyleSpace vector, which the styleSpace_values forward hooks record
# during G.synthesis.
s_proj_batches = []
class_pred_batches = []
batch_size = 64
for w in DataLoader(projected_w, batch_size=batch_size):
    with torch.no_grad():
        # one W vector per image, repeated for all G.num_ws layer inputs
        imgs = G.synthesis(w.unsqueeze(1).repeat((1, G.num_ws, 1)), noise_mode='const', force_fp32=True)
        imgs = postprocess_images(imgs)
        digits_pred = classifier_digits(imgs)
    # argmax of the logits equals argmax of their softmax
    class_pred_batches.append(digits_pred.argmax(dim=1))
    # read right after synthesis: the hook dict is overwritten by the next forward pass
    s_proj_batches.append(styleSpace_dict2vec(styleSpace_values))

# Concatenate once instead of re-concatenating inside the loop (quadratic copying).
s_proj_all = torch.cat(s_proj_batches)
class_predicted = torch.cat(class_pred_batches)

s_proj_wellclassified = s_proj_all[class_predicted == digits]
s_proj_misclassified = s_proj_all[class_predicted != digits]


# PREDICT CLASSES OF ORIGINAL DATA AND COMPARE WITH PROJECTED DATA

# path_data = Path.cwd().parent / 'data/MNIST' / f'{dataset}.zip'
# ds = ImageFolderDataset(path_data, use_labels=True)

mnist_pred_batches = []
with torch.no_grad():
    for x, _ in DataLoader(ds_original, batch_size=64):
        # [0, 255] -> [0, 1] and crop rows/cols 2..29, matching postprocess_images' output
        x = (x / 255)[:, :, 2:30, 2:30]
        mnist_pred_batches.append(classifier_digits(x.to(device)).argmax(dim=1))
class_predicted_mnist = torch.cat(mnist_pred_batches)

# Categories, named from the projection's point of view:
#   "positive" = projection classified as its ground-truth digit
#   "true"     = projection's prediction agrees with the prediction on the original image
proj_correct = class_predicted == digits
agrees_with_original = class_predicted == class_predicted_mnist

tp_idx = agrees_with_original[proj_correct]
tp = tp_idx.sum()
tn_idx = agrees_with_original[~proj_correct]
tn = tn_idx.sum()
fp_idx = ~agrees_with_original[proj_correct]
fp = fp_idx.sum()
fn_idx = ~agrees_with_original[~proj_correct]
fn = fn_idx.sum()

assert tp + tn + fp + fn == len(class_predicted)

print('True positive:', tp.item())
print('True negative:', tn.item())
print('False positive:', fp.item())
print('False negative:', fn.item())


# FILTER OUT WRONG SAMPLES (FALSE POSITIVE OR NEGATIVE)
# fp/fn/tp/tn masks index into the well-/mis-classified subsets, so the fp/fn
# selections must be taken before those subsets are overwritten with tp/tn.
s_proj_wellclassified_fp = s_proj_wellclassified[fp_idx]
s_proj_misclassified_fn = s_proj_misclassified[fn_idx]
s_proj_wellclassified = s_proj_wellclassified[tp_idx]
s_proj_misclassified = s_proj_misclassified[tn_idx]


# PLOT SAMPLES
# Regenerate the first n_images of each group from their StyleSpace vectors.
# "(correct)" = the projection's prediction agrees with the original image's,
# "(wrong)" = it disagrees. plt.title replaces the empty title set by plot_images.
n_images = 8
plot_images_from_s(s_proj_wellclassified[:n_images])
plt.title('well-classified (correct)')

plot_images_from_s(s_proj_wellclassified_fp[:n_images])
plt.title('well-classified (wrong)')

plot_images_from_s(s_proj_misclassified[:n_images])
plt.title('misclassified (correct)')

plot_images_from_s(s_proj_misclassified_fn[:n_images])
plt.title('misclassified (wrong)')


# Recorded results of the comparison above, per dataset / projection run:
# mnistTest_stylegan2_blur_noise_maxSeverity3_proba50
# True positive: 9509
# True negative: 158
# False positive: 87
# False negative: 246

# mnistTest_stylegan2_blur_noise_maxSeverity3_proba50_classifWeight0.001
# True positive: 9628
# True negative: 209
# False positive: 47
# False negative: 116

# mnistTest_stylegan2
# True positive: 9779
# True negative: 120
# False positive: 49
# False negative: 52

# Histograms of W: projected vs. sampled latents

In [None]:
# Compare one W coordinate for projected latents vs. latents sampled through the
# mapping network (same count; random labels when G is conditional).
dim = 5

z = torch.randn([projected_w.shape[0], G.z_dim], device=device)    # latent codes
if conditional:
    # Separate name: `digits` holds the ground-truth labels of the projected data
    # and is used by the analysis cells above; overwriting it broke re-runs.
    digits_sampled = torch.randint(0, G.c_dim, (projected_w.shape[0],), device=device)
    c = F.one_hot(digits_sampled, G.c_dim)
else:
    c = None

with torch.no_grad():
    ws = G.mapping(z, c, truncation_psi=1)
natural_w = ws[:, 0, :].cpu().numpy()

projected_dim = projected_w.cpu().numpy()[:, dim]
# shared bin edges so the two histograms are directly comparable
bins = np.histogram_bin_edges(np.concatenate([projected_dim, natural_w[:, dim]]), bins=30)

plt.figure()
plt.hist(projected_dim, bins=bins, alpha=0.6, label='projected')
plt.hist(natural_w[:, dim], bins=bins, alpha=0.6, label='natural')
plt.xlabel(f'W[{dim}]')
plt.ylabel('count')
plt.legend()

In [None]:
# 1-D Wasserstein distance between projected and sampled values of W coordinate `dim`.
wasserstein_distance(projected_w.cpu().numpy()[:, dim], natural_w[:, dim])

In [None]:
# Visual check of one sample: the original dataset image, then the image
# regenerated from its projected StyleSpace vector.
i = 510

plt.figure()
plt.imshow(ds_original[i][0][0], cmap='gray')  # first channel of the original image
plt.xticks([])
plt.yticks([])


plot_images_from_s(s_proj_all[i])
