In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"

import pickle
from pathlib import Path
import json
import PIL.Image
import torch
import torch.nn.functional as F
import torchvision.utils as vutils
from torch.utils.data import DataLoader, Dataset
import torchvision
import imageio
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_theme()
from sklearn.decomposition import PCA, IncrementalPCA
from sklearn.manifold import TSNE
import numpy as np
import functools
import pandas as pd
# from tqdm import tqdm


from stylegan2_ada_pytorch.torch_utils import misc
import stylegan2_ada_pytorch.dnnlib
import stylegan2_ada_pytorch.legacy
from stylegan2_ada_pytorch.projector import project
from stylegan2_ada_pytorch.training.dataset import ImageFolderDataset
from classifiers.models import CNN_MNIST

# Seed both random number generators used in this notebook.
np.random.seed(0)
torch.manual_seed(0)

# Use the GPU when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Root folder for results; training runs and classifier weights are read from it below.
path_results = Path.cwd().parent / 'results'
# path_results = Path('w:/results/stylegan2')

In [None]:
def postprocess_images(images):
    """Rescale generator output and crop away the border.

    Values are mapped from [-1, 1] to [0, 1] (anything outside is clamped),
    then only rows and columns 2..29 are kept.

    Args:
        images: 4D tensor (B x C x H x W).

    Returns:
        Tensor with values in [0, 1], sliced to ``[:, :, 2:30, 2:30]``.

    Raises:
        AssertionError: if ``images`` is not 4-dimensional.
    """
    n_dims = images.dim()
    assert n_dims == 4, "Expected 4D (B x C x H x W) image tensor, got {}D".format(n_dims)
    rescaled = torch.clamp((images + 1) / 2, min=0, max=1)
    return rescaled[..., 2:30, 2:30]  # remove padding

def plot_images(images, title=''):
    """Show a batch of images as a single grid in a new figure.

    The images are multiplied by 255 and cast to uint8 before plotting, so
    they are presumably expected in the [0, 1] range.

    Args:
        images: image batch accepted by ``vutils.make_grid``.
        title: figure title.
    """
    images_uint8 = (images * 255).to(torch.uint8)
    grid = vutils.make_grid(images_uint8.cpu(), pad_value=255)
    plt.figure()
    plt.imshow(grid.permute(1, 2, 0), vmin=0, vmax=255)  # channels last for imshow
    plt.axis('off')
    plt.grid(False)
    plt.title(title)


def generate_from_z(z):
    """Generate images from latent codes ``z`` in chunks of ``batch_size``.

    Uses the notebook globals ``G`` (generator) and ``batch_size``. The
    generator is always called with ``c=None``.
    NOTE(review): presumably only meant for unconditional generators - confirm.

    Args:
        z: tensor whose first dimension indexes the latent codes.

    Returns:
        The generator outputs of all chunks concatenated along dim 0.

    Raises:
        ValueError: if ``z`` contains no latent code.
    """
    n_codes = z.shape[0]
    if n_codes == 0:
        # Previously this fell through to an unbound local variable.
        raise ValueError('generate_from_z needs at least one latent code')
    # Collect the chunks and concatenate once instead of re-concatenating per chunk.
    batches = [
        G(z[start:start + batch_size], c=None, noise_mode='const', force_fp32=True)
        for start in range(0, n_codes, batch_size)
    ]
    return torch.cat(batches)


def plot_random_images(imgs):
    """Plot 100 images drawn at random (with replacement) from ``imgs`` on a 10-per-row grid.

    Args:
        imgs: 4D image tensor; it is passed through ``postprocess_images`` first.
    """
    # from generate.py: img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
    imgs = postprocess_images(imgs)
    plt.figure(figsize=(10, 10))
    plt.axis("off")
    picks = torch.randint(0, imgs.shape[0], (100,))
    grid = vutils.make_grid(imgs[picks].cpu(), pad_value=255, nrow=10)
    plt.imshow(grid.permute(1, 2, 0))

def plot_images_from_s(s):
    """Generate images from style-space codes ``s`` and plot them as a grid."""
    plot_images(postprocess_images(generate_img_from_s(s)))
    
def truncate(x, x_avg, psi):
    """Linearly interpolate from ``x_avg`` towards ``x`` with weight ``psi``.

    psi=0 returns the average value, psi=1 returns the original value, and
    0 < psi < 1 returns an interpolation between the two.
    """
    return torch.lerp(x_avg, x, psi)


def styleSpace_dict2vec(styleSpace_dict):
    """Concatenate a per-layer style dict into one tensor along dim 1.

    Layers are visited in the order of ``G.synthesis.block_resolutions`` and,
    within a block, conv0, conv1, torgb (resolution 4 has no conv0).
    1D entries are treated as a batch of one.

    Args:
        styleSpace_dict: mapping 'b{res}.{layer}' -> tensor.

    Returns:
        Tensor made of all layer entries concatenated along dim 1.
    """
    def as_batch(values):
        return values.unsqueeze(0) if values.dim() == 1 else values

    layer_keys = [
        f'b{res}.{layer}'
        for res in G.synthesis.block_resolutions
        for layer in ('conv0', 'conv1', 'torgb')
        if not (res == 4 and layer == 'conv0')
    ]
    return torch.cat([as_batch(styleSpace_dict[key]) for key in layer_keys], dim=1)


def styleSpace_vec2dict(styleSpace_vec):
    """Split a concatenated style vector back into a per-layer dict.

    Inverse of ``styleSpace_dict2vec``: layers are visited in the same order
    (block resolutions, then conv0 / conv1 / torgb, no conv0 at resolution 4)
    and each one receives a contiguous slice of the columns.

    Args:
        styleSpace_vec: 1D tensor (treated as a batch of one) or 2D tensor
            (batch x total number of style channels).

    Returns:
        Dict 'b{res}.{layer}' -> 2D tensor slice (batch x layer width).

    Raises:
        AssertionError: if the layer widths do not add up to the vector width.
    """
    if styleSpace_vec.dim() == 1:
        styleSpace_vec = styleSpace_vec.unsqueeze(0)
    styleSpace_dict = {}
    dim_base = 0
    for res in G.synthesis.block_resolutions:
        block = getattr(G.synthesis, f'b{res}')
        for layer in ['conv0', 'conv1', 'torgb']:
            if res == 4 and layer == 'conv0': continue
            block_layer = getattr(block, layer)
            # The style vector is the OUTPUT of the affine layer. Its weight is
            # stored as [out_features, in_features], so the width is shape[0]
            # (shape[1] is the w dimension and only matches by coincidence).
            dim_size = block_layer.affine.weight.shape[0]
            key = f'b{res}.{layer}'
            styleSpace_dict[key] = styleSpace_vec[:, dim_base:dim_base + dim_size]
            dim_base += dim_size
    assert dim_base == styleSpace_vec.shape[1]
    return styleSpace_dict


def compute_styleSpace_vec_idx2coord():
    """Map each column of the concatenated style vector to its layer coordinate.

    Layers are visited in the same order as in ``styleSpace_dict2vec``.

    Returns:
        Dict: flat index -> ('b{res}.{layer}', channel index inside that layer).
    """
    vec_idx2coord = {}
    idx = 0
    for res in G.synthesis.block_resolutions:
        block = getattr(G.synthesis, f'b{res}')
        for layer in ['conv0', 'conv1', 'torgb']:
            if res == 4 and layer == 'conv0': continue
            block_layer = getattr(block, layer)
            # Width of the style vector = output size of the affine layer, whose
            # weight is stored as [out_features, in_features] -> use shape[0].
            dim_size = block_layer.affine.weight.shape[0]
            for dim in range(dim_size):
                vec_idx2coord[idx] = (f'b{res}.{layer}', dim)
                idx += 1
    return vec_idx2coord

In [None]:
# path_model = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/afhqcat.pkl'
# path_model = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/afhqdog.pkl'
# path_model = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/afhqwild.pkl'
# path_model = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/brecahad.pkl'
# path_model = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/cifar10.pkl'
# path_model = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl'
# path_model = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metfaces.pkl'

# path_model = path_results / 'stylegan2-training-runs' / '00011-mnist_stylegan2_noise-cond-auto4-original'
# path_model = path_results / 'stylegan2-training-runs' / '00015-mnist_stylegan2_blur_noise-cond-auto4'
path_model = path_results / 'stylegan2-training-runs' / '00016-mnist_stylegan2_blur_noise_maxSeverity3_proba50-cond-auto4'

if not str(path_model).endswith('pkl'):
    # path_model is a training-run folder: pick the snapshot with the lowest FID
    best_fid = float('inf')
    best_model = None
    with open(path_model / 'metric-fid50k_full.jsonl', 'r') as json_file:
        for json_str in json_file:
            if not json_str.strip():
                continue  # tolerate blank lines in the metrics file
            json_line = json.loads(json_str)
            fid = json_line['results']['fid50k_full']
            if fid < best_fid:
                best_fid = fid
                best_model = json_line['snapshot_pkl']
    if best_model is None:
        raise ValueError('No snapshot listed in {}'.format(path_model / 'metric-fid50k_full.jsonl'))
    print('Best FID: {:.2f} ; best model : {}'.format(best_fid, best_model))
    path_model = path_model / best_model

    # NOTE: unpickling executes arbitrary code - only load snapshots you trust.
    with open(path_model, 'rb') as f:
        G = pickle.load(f)['G_ema'].to(device)  # torch.nn.Module

else:
    # The packages are imported under their full dotted names, so the bare names
    # `dnnlib` / `legacy` do not exist here; open_url is given a str, not a Path.
    with stylegan2_ada_pytorch.dnnlib.util.open_url(str(path_model)) as f:
        G = stylegan2_ada_pytorch.legacy.load_network_pkl(f)['G_ema'].to(device)

# On CPU, always run the generator in fp32.
if device == 'cpu': G.forward = functools.partial(G.forward, force_fp32=True)

# The generator is conditional when it has a non-empty label dimension.
conditional = G.c_dim > 0

# Register forward hooks that save intermediate values (images and style space).
# Each dict is overwritten on every forward pass of the generator.
intermediate_images_torgb = {}
intermediate_images_block = {}
styleSpace_values = {}


def get_torgb(name):
    """Hook factory: store the detached module output under ``name``."""
    def hook(module, input, output):
        intermediate_images_torgb[name] = output.detach()
    return hook


def get_block_img(name):
    """Hook factory: store the detached second element of the module output under ``name``."""
    def hook(module, input, output):
        intermediate_images_block[name] = output[1].detach()
    return hook


def get_styleSpace_values(name):
    """Hook factory: store the detached affine output (style vector) under ``name``."""
    def hook(module, input, output):
        styleSpace_values[name] = output.detach()
    return hook


for res in G.synthesis.block_resolutions:
    block = getattr(G.synthesis, f'b{res}')
    block.torgb.register_forward_hook(get_torgb(res))
    block.register_forward_hook(get_block_img(res))
    # the resolution-4 block is handled without conv0
    layers = ['conv1', 'torgb'] if res == 4 else ['conv0', 'conv1', 'torgb']
    for layer in layers:
        block_layer = getattr(block, layer)
        block_layer.affine.register_forward_hook(get_styleSpace_values(name=f'b{res}.{layer}'))

        
# backward hooks to get gradients relative to styleSpace
# styleSpace_grads[name] is overwritten on every backward pass with the gradient
# flowing into the output of the corresponding affine layer.
styleSpace_grads = {}
def get_styleSpace_grads(name):
    """Hook factory: store the detached gradient w.r.t. the affine output under ``name``."""
    def hook(self, grad_input, grad_output):
        styleSpace_grads[name] = grad_output[0].detach()
    return hook

# Only the backward hooks are registered here: the torgb / block forward hooks
# are already registered once above, and registering them a second time made
# every one of them fire twice per forward pass.
for res in G.synthesis.block_resolutions:
    block = getattr(G.synthesis, f'b{res}')
    for layer in ['conv0', 'conv1', 'torgb']:
        if res == 4 and layer == 'conv0': continue
        block_layer = getattr(block, layer)
        block_layer.affine.register_full_backward_hook(get_styleSpace_grads(name=f'b{res}.{layer}'))


# dict to convert index to coordinate for stylespace vectors
styleSpace_vec_idx2coord = compute_styleSpace_vec_idx2coord()


# function to move a given style dimension
def generate_img_new_style(ws, block_layer_name, index=0, direction=1):
    """Synthesize images after adding ``direction`` to one style channel.

    A temporary forward hook on the affine layer named by ``block_layer_name``
    adds ``direction`` to channel ``index`` of its output for every sample.

    Args:
        ws: latent codes, either (batch x w_dim) - repeated for all
            ``G.num_ws`` layers - or already 3D.
        block_layer_name: layer coordinate in the form 'b{res}.{layer}'.
        index: style channel(s) to modify in that layer.
        direction: value added to the selected channel(s).

    Returns:
        Output of ``G.synthesis`` with the modified style.
    """
    def move_style(index, direction):
        def hook(module, input, output):
            output[:, index] += direction
            return output
        return hook

    block_name, layer_name = block_layer_name.split('.')
    block = getattr(G.synthesis, block_name)
    block_layer = getattr(block, layer_name)

    if ws.dim() == 2:
        ws = ws.unsqueeze(1).repeat((1, G.num_ws, 1))

    handle = block_layer.affine.register_forward_hook(move_style(index, direction))
    try:
        img = G.synthesis(ws, noise_mode='const', force_fp32=True)
    finally:
        # Always detach the hook, otherwise a failed synthesis would leave the
        # style shift active for every later forward pass.
        handle.remove()

    return img
    
    
# function to modify a given style dimension
def generate_img_new_style2(ws, block_layer_name, index, s_style_min, s_style_max, s_shift=1, positive_direction=True):
    """Synthesize one image with a style channel moved towards a target value.

    The current value of the channel is read from ``styleSpace_values`` after a
    first forward pass; the channel is then shifted by
    ``s_shift * (target - current)``, where target is ``s_style_max`` when
    ``positive_direction`` is true and ``s_style_min`` otherwise.

    Args:
        ws: latent code of a single image, (1 x w_dim) or already 3D.
        block_layer_name: layer coordinate in the form 'b{res}.{layer}'.
        index: style channel to modify (a single int).
        s_style_min: target value when moving in the negative direction.
        s_style_max: target value when moving in the positive direction.
        s_shift: fraction of the distance to the target to move by.
        positive_direction: choose ``s_style_max`` (True) or ``s_style_min``.

    Returns:
        Output of ``G.synthesis`` with the modified style.

    Raises:
        AssertionError: if ``index`` is not an int or ``ws`` holds more than one image.
    """
    def move_style(index, weight_shift):
        def hook(module, input, output):
            output[:, index] += weight_shift
            return output
        return hook

    assert isinstance(index, int), 'Function only works for 1 style'
    assert ws.shape[0] == 1, 'Works only for 1 image' # orig_value only for 1 image

    if ws.dim() == 2:
        ws = ws.unsqueeze(1).repeat((1, G.num_ws, 1))

    with torch.no_grad():
        G.synthesis(ws, noise_mode='const', force_fp32=True) # first pass to get style vector from hook
    orig_value = styleSpace_values[block_layer_name][0, index]
    target_value = (s_style_max if positive_direction else s_style_min)
    weight_shift = s_shift * (target_value - orig_value)

    block_name, layer_name = block_layer_name.split('.')
    block = getattr(G.synthesis, block_name)
    block_layer = getattr(block, layer_name)

    handle = block_layer.affine.register_forward_hook(move_style(index, weight_shift))
    try:
        img = G.synthesis(ws, noise_mode='const', force_fp32=True)
    finally:
        # Always detach the hook, even if synthesis fails.
        handle.remove()

    return img


# function to modify a given style dimension
def generate_img_new_style3(ws, block_layer_name, index, s_std, strength=5, positive_direction=True):
    """Synthesize one image with a style channel shifted by a multiple of ``s_std``.

    The shift added to the channel is ``+strength * s_std`` when
    ``positive_direction`` is true and ``-strength * s_std`` otherwise.

    Args:
        ws: latent code of a single image, (1 x w_dim) or already 3D.
        block_layer_name: layer coordinate in the form 'b{res}.{layer}'.
        index: style channel to modify (a single int).
        s_std: scale of one unit of shift (callers pass the channel's standard deviation).
        strength: number of ``s_std`` units to move by.
        positive_direction: sign of the shift.

    Returns:
        Output of ``G.synthesis`` with the modified style.

    Raises:
        AssertionError: if ``index`` is not an int or ``ws`` holds more than one image.
    """
    def move_style(index, weight_shift):
        def hook(module, input, output):
            output[:, index] += weight_shift
            return output
        return hook

    assert isinstance(index, int), 'Function only works for 1 style'
    assert ws.shape[0] == 1, 'Works only for 1 image'

    if ws.dim() == 2:
        ws = ws.unsqueeze(1).repeat((1, G.num_ws, 1))

    d = 1 if positive_direction else -1
    weight_shift = d * strength * s_std

    block_name, layer_name = block_layer_name.split('.')
    block = getattr(G.synthesis, block_name)
    block_layer = getattr(block, layer_name)

    handle = block_layer.affine.register_forward_hook(move_style(index, weight_shift))
    try:
        img = G.synthesis(ws, noise_mode='const', force_fp32=True)
    finally:
        # Always detach the hook, even if synthesis fails.
        handle.remove()

    return img


# function to generate image from S
def generate_img_from_s(s):
    """Synthesize images directly from style-space values.

    Temporary forward hooks replace the output of every affine layer by the
    corresponding entry of ``s``; the ``ws`` fed to the generator are zeros and
    only provide the batch size.

    Args:
        s: dict 'b{res}.{layer}' -> (batch_size x layer width) tensor, or a
            concatenated style vector accepted by ``styleSpace_vec2dict``.

    Returns:
        Output of ``G.synthesis`` for the given styles.

    Raises:
        AssertionError: if the 'b4.conv1' entry is not 2-dimensional.
    """
    def set_style(values):
        def hook(module, input, output):
            # returning a value from a forward hook replaces the module output
            return values
        return hook

    if type(s) != dict: s = styleSpace_vec2dict(s)
    assert s['b4.conv1'].dim() == 2, 'Should be of 2 dimensions: batch_size x s_dim'
    batch_size = s['b4.conv1'].shape[0]

    handles = []
    try:
        for res in G.synthesis.block_resolutions:
            block = getattr(G.synthesis, f'b{res}')
            for layer in ['conv0', 'conv1', 'torgb']:
                if res == 4 and layer == 'conv0': continue
                block_layer = getattr(block, layer)
                values = s[f'b{res}.{layer}']
                handles.append(block_layer.affine.register_forward_hook(set_style(values)))

        dummy_ws = torch.zeros((batch_size, G.num_ws, G.w_dim), device=device)
        img = G.synthesis(dummy_ws, noise_mode='const', force_fp32=True)
    finally:
        # Remove every hook registered so far, even if a key is missing in `s`
        # or synthesis fails; otherwise later forward passes keep these styles.
        for h in handles: h.remove()

    return img

In [None]:
# Sample a few latent codes (and random labels when the generator is conditional).
n_images = 5
z = torch.randn([n_images, G.z_dim], device=device)    # latent codes
if conditional:
    digits = torch.randint(0, G.c_dim, (n_images,), device=device)
    c = F.one_hot(digits, G.c_dim)          
else:
    c = None
misc.print_module_summary(G, [z, c])

# Map z to w, synthesize and show the unmodified images.
ws = G.mapping(z, c, truncation_psi=1)
img = G.synthesis(ws, noise_mode='const', force_fp32=True)
img = postprocess_images(img)
plot_images(img, title='original')


# Same latent codes, with -1 added to style channel 3 of layer 'b4.conv1'.
img_ = generate_img_new_style(ws, block_layer_name='b4.conv1', index=3, direction=-1)
img_ = postprocess_images(img_)
plot_images(img_, title='with new style')

## Load classifier

In [None]:
# Folder holding the classifier weights.
path_classifiers = path_results / 'classifiers'

# predict digits
classifier_digits = CNN_MNIST(output_dim=10).to(device)
# classifier_digits.load_state_dict(torch.load(path_results / 'classifiers' / 'CNN_MNIST_weights_20220411_0826.pth', map_location=device)) # Confiance
# classifier_digits.load_state_dict(torch.load(path_results / 'classifiers' / 'CNN_MNIST_weights_20220210_1601.pth', map_location=device))
weights_digits = torch.load(path_classifiers / 'CNN_mnist_stylegan2_blur_noise_maxSeverity3_proba50_20220510_1124.pth', map_location=device)
classifier_digits.load_state_dict(weights_digits)
classifier_digits.eval()

# predict noise
classifier_noise = CNN_MNIST(output_dim=6).to(device)
# classifier_noise.load_state_dict(torch.load(path_results / 'classifiers' / 'CNN_noise_MNIST_weights_20220411_0841.pth', map_location=device)) # Confiance
weights_noise = torch.load(path_classifiers / 'CNN_MNIST_noise_weights_20220210_1728.pth', map_location=device)
classifier_noise.load_state_dict(weights_noise)
classifier_noise.eval()

# Classify the images generated from the current `ws`.
imgs = postprocess_images(G.synthesis(ws, noise_mode='const', force_fp32=True))
digit_pred = classifier_digits(imgs).argmax(dim=1).cpu()
noise_pred = classifier_noise(imgs).argmax(dim=1).cpu()

# Show at most the first 5 images with both predictions in the title.
plt.figure(figsize=(15, 5))
n_shown = min(n_images, 5)
for i in range(n_shown):
    plt.subplot(1, 10, i+1)
    plt.imshow(imgs[i].cpu().squeeze(), cmap='gray')
    plt.title(f'digit: {digit_pred[i].numpy()} \n noise: {noise_pred[i].numpy()}')
    plt.axis('off')

In [None]:
# sample images in latent space to get a set of latent codes
n_images = 1
z = torch.randn([n_images, G.z_dim], device=device)    # latent codes
if conditional:
    digits = torch.randint(0, G.c_dim, (n_images,), device=device)
    c = F.one_hot(digits, G.c_dim)          
else:
    c = None
ws = G.mapping(z, c, truncation_psi=1)

# Reference images and classifier scores, before any style is modified.
imgs_orig = G.synthesis(ws, noise_mode='const', force_fp32=True)
imgs_orig = postprocess_images(imgs_orig)
with torch.no_grad():
    digits_pred_orig = classifier_digits(imgs_orig).cpu()
    # confidence_orig and maxLogit_orig are the same quantity: the largest classifier output per image
    confidence_orig = digits_pred_orig.max(dim=1).values
    maxLogit_orig = digits_pred_orig.max(axis=1).values
    maxSoftmax_orig = F.softmax(digits_pred_orig, dim=1).max(axis=1).values

plot_images(imgs_orig)

## AttFind (slow!)

In [None]:
# For every style channel of every layer, add +1 to that channel and record how
# much the classifier's largest output changes relative to the original images.
# Entries are keyed '{layer}_{channel}_{direction}'.
confidence_change = pd.Series(name='confidence_change', dtype=np.float32)
# styleSpace_values is filled by the forward hooks; it only provides layer names and widths here
for block_layer, v in styleSpace_values.items():
    print(block_layer)
    for dimension in range(v.shape[1]):
        # for direction in [-1, 1]:
        direction = 1
        imgs_newStyle = generate_img_new_style(ws, block_layer, index=dimension, direction=direction)
        imgs_newStyle = postprocess_images(imgs_newStyle)
        with torch.no_grad():
            digits_pred_newStyle = classifier_digits(imgs_newStyle).cpu()
            confidence_newStyle = digits_pred_newStyle.max(axis=1).values

        # mean change over the batch
        confidence_change[f'{block_layer}_{dimension}_{direction}'] = (confidence_newStyle - confidence_orig).mean()

    # break

In [None]:
# Sort the style channels by the absolute change they caused, largest first.
confidence_change_sorted = confidence_change[confidence_change.abs().sort_values(ascending=False).index]

top_k = 5
titles_plot = ['original']
imgs_plot = [imgs_orig]
maxLogit = [maxLogit_orig]
for k, v in confidence_change_sorted[:top_k].items():
    # print(f'{k}: {v}')
    block_layer, dimension, direction = k.split('_')
    # NOTE(review): the image is generated with the OPPOSITE sign of the shift that
    # produced the Delta shown in the title - confirm this is intended.
    imgs_newStyle = generate_img_new_style(ws, block_layer, index=int(dimension), direction=-1*int(direction))
    imgs_newStyle = postprocess_images(imgs_newStyle)
    # raw string: '\D' is not a valid escape sequence in a normal string literal
    titles_plot.append(r'{}: $\Delta$={:.2f}'.format(k, v))
    imgs_plot.append(imgs_newStyle)
    with torch.no_grad():
        digits_pred_newStyle = classifier_digits(imgs_newStyle).cpu()
        # maxSoftmax_newStyle = F.softmax(digits_pred_newStyle, dim=1).max(axis=1).values
        maxLogit_newStyle = digits_pred_newStyle.max(axis=1).values
    maxLogit.append(maxLogit_newStyle)

# One row per image, one column per style (first column is the original).
fig, axs = plt.subplots(n_images, top_k+1, figsize=(20, 5))
for k, imgs in enumerate(imgs_plot):
    for i in range(n_images):
        ax = axs[i, k] if n_images > 1 else axs[k]
        ax.imshow(imgs[i].cpu().squeeze(), cmap='gray')
        ax.axis('off')
        ax.grid(False)
        # index the i-th image so this also works when n_images > 1
        ax.set_title(titles_plot[k] + '\n' + '{:.2f}'.format(maxLogit[k][i].item()))

## Gradient of score wrt W

In [None]:
# Sample one latent code (and a random label when the generator is conditional).
n_images = 1
z = torch.randn([n_images, G.z_dim], device=device)    # latent codes
if conditional:
    digits = torch.randint(0, G.c_dim, (n_images,), device=device)
    c = F.one_hot(digits, G.c_dim)          
else:
    c = None

# compute input
w = G.mapping(z, c, truncation_psi=1, truncation_cutoff=8)[:, 0, :] # keep only the first element
w.requires_grad = True  # so the classifier score can be differentiated w.r.t. w

# compute output
# the same w is repeated for all G.num_ws synthesis inputs
imgs = G.synthesis(w.unsqueeze(1).repeat((1, G.num_ws, 1)), noise_mode='const', force_fp32=True)
imgs = postprocess_images(imgs)

digits_pred = classifier_digits(imgs)
max_logit = digits_pred.max(axis=1).values.mean()
max_softmax = F.softmax(digits_pred, dim=1).max(axis=1).values.mean()
class_pred = F.softmax(digits_pred, dim=1).max(axis=1).indices

# grad of output relative to input
grad = torch.autograd.grad(max_logit, w)[0]
grad = grad.mean(axis=0)  # average over the batch -> one value per w dimension

plot_images(imgs)
plt.title('top class: {} : {:.0f}%'.format(class_pred.cpu().numpy(), 100*max_softmax.detach().cpu().numpy()))

In [None]:
# Keep only the top_k dimensions of w with the largest absolute gradient.
top_k = 10
top_k_idxs = grad.abs().topk(top_k).indices
mask = torch.zeros_like(grad)
mask[top_k_idxs] = 1

# Small step against the gradient, restricted to the selected dimensions.
w_new = w - 0.05*mask*grad

imgs = postprocess_images(
    G.synthesis(w_new.unsqueeze(1).repeat((1, G.num_ws, 1)), noise_mode='const', force_fp32=True)
)

digits_pred = classifier_digits(imgs)
# max_logit = digits_pred.max(axis=1).values
top_softmax = F.softmax(digits_pred, dim=1).max(axis=1)
max_softmax = top_softmax.values.mean()
class_pred = top_softmax.indices

plot_images(imgs)
plt.title('top class: {} : {:.0f}%'.format(class_pred.cpu().numpy(), 100*max_softmax.detach().cpu().numpy()))

## Gradient of score wrt style

### One image

In [None]:
# Sample one latent code and remember which class was requested from the generator.
n_images = 1
z = torch.randn([n_images, G.z_dim], device=device)    # latent codes
if conditional:
    digits = torch.randint(0, G.c_dim, (n_images,), device=device)
    c = F.one_hot(digits, G.c_dim)
    class_generated = c.argmax(axis=1)
else:
    # NOTE(review): class_generated is not defined on this branch but is used below - confirm
    # this cell is only meant for conditional generators.
    c = None

# compute input
w = G.mapping(z, c, truncation_psi=1, truncation_cutoff=8)[:, 0, :] # keep only the first element
w.requires_grad = True

# compute output
imgs_orig = G.synthesis(w.unsqueeze(1).repeat((1, G.num_ws, 1)), noise_mode='const', force_fp32=True)
imgs_orig = postprocess_images(imgs_orig)

digits_pred_orig = classifier_digits(imgs_orig)
# max_logit = digits_pred.max(axis=1).values
# max_softmax = F.softmax(digits_pred, dim=1).max(axis=1).values.mean()
# class_pred = F.softmax(digits_pred, dim=1).max(axis=1).indices
# classifier output / probability for the class the image was generated for
class_logit_orig = digits_pred_orig[torch.arange(n_images), class_generated]
class_softmax_orig = F.softmax(digits_pred_orig, dim=1)[torch.arange(n_images), class_generated]

# backpropagate gradient
loss = class_logit_orig.mean() # goal is to reduce the score for the class
loss.backward() # compute gradients and access them thanks to the backward hook
# styleSpace_grads was filled by the backward hooks during loss.backward()
styleSpace_grads_vec = styleSpace_dict2vec(styleSpace_grads)

plot_images(imgs_orig)
plt.title('Class generated: {}; predicted at {:.0f}%'.format(class_generated.squeeze().cpu().numpy(), 100*class_softmax_orig.squeeze().detach().cpu().numpy()))


In [None]:
# Rank the style channels by the absolute value of their batch-averaged gradient.
top_k = 10
styleSpace_grads_mean = styleSpace_grads_vec.mean(axis=0)
top_k_idxs = styleSpace_grads_mean.abs().topk(top_k).indices.cpu().numpy()
top_k_vals = styleSpace_grads_mean[top_k_idxs]

styleSpace_vec_idx2coord = compute_styleSpace_vec_idx2coord()

# The predictions below are moved to the CPU before being indexed; a CPU tensor
# cannot be indexed with a CUDA index tensor, so use a CPU copy of the labels.
class_generated_cpu = class_generated.cpu()

imgs_plot = [imgs_orig]
class_softmax_all = [class_softmax_orig]
for idx in top_k_idxs:
    block_layer_name, index = styleSpace_vec_idx2coord[idx]
    # move against the gradient to lower the class score
    direction = - styleSpace_grads_mean[idx].sign()
    imgs_newStyle = generate_img_new_style(w, block_layer_name, index, direction=3*direction)
    imgs_newStyle = postprocess_images(imgs_newStyle)
    imgs_plot.append(imgs_newStyle)
    with torch.no_grad():
        digits_pred_newStyle = classifier_digits(imgs_newStyle).cpu()
        class_softmax_newStyle = F.softmax(digits_pred_newStyle, dim=1)[torch.arange(n_images), class_generated_cpu]
    class_softmax_all.append(class_softmax_newStyle)

# One column per style (first column is the original); assumes a single image.
fig, axs = plt.subplots(n_images, top_k+1, figsize=(15, 5))
for k, img in enumerate(imgs_plot):
    ax =  axs[k]
    ax.imshow(img.detach().cpu().squeeze(), cmap='gray')
    ax.axis('off')
    ax.grid(False)
    title = 'original' if k == 0 else '$s_{' + str(top_k_idxs[k-1]) + '}$'
    title += '\n{}: {:.0f}%'.format(class_generated.squeeze().cpu().numpy(), 100*class_softmax_all[k].squeeze().detach().cpu().numpy())
    ax.set_title(title)


Look at images that are mostly well classified (condition = predicted class), and check which directions degrade the performance.

Which value should the direction take?

### Multiple images - one class

In [None]:
n_images = 1000
batch_size = 32
class_selected = 0

# Rejection sampling: keep drawing batches for class_selected and keep only the
# latent codes whose image the classifier assigns to that class, until n_images are found.
w_ok = None
while True:
    z = torch.randn([batch_size, G.z_dim], device=device)    # latent codes
    if conditional:
        digits = class_selected*torch.ones((batch_size, ), dtype=torch.int64, device=device)
        c = F.one_hot(digits, G.c_dim)
        class_generated = c.argmax(axis=1)
    else:
        c = None

    with torch.no_grad():
        # compute input
        w = G.mapping(z, c, truncation_psi=1, truncation_cutoff=8)[:, 0, :] # keep only the first element

        # compute output
        imgs_orig = postprocess_images(
            G.synthesis(w.unsqueeze(1).repeat((1, G.num_ws, 1)), noise_mode='const', force_fp32=True)
        )

        digits_pred_orig = classifier_digits(imgs_orig)
        softmax_orig = F.softmax(digits_pred_orig, dim=1)
        class_pred = softmax_orig.max(axis=1).indices
        class_softmax_orig = softmax_orig[torch.arange(batch_size), class_generated]

    # latent codes of this batch that were classified as the selected class
    w_batch_ok = w[class_pred == class_selected]
    if w_ok is None:
        w_ok = w_batch_ok
    else:
        w_ok = torch.cat((w_ok, w_batch_ok))

    if w_ok.shape[0] >= n_images:
        w_ok = w_ok[:n_images]
        break

# repeat each w for all G.num_ws synthesis inputs
ws_ok = w_ok.unsqueeze(1).repeat((1, G.num_ws, 1))

In [None]:
# For every well-classified latent code: record the style values, the class
# probability and the gradient of that probability w.r.t. the style values.
# compute input
class_generated = class_selected*torch.ones((batch_size, ), dtype=torch.int64, device=device)

class_softmax_orig_all = None
imgs_orig_all = None
styleSpace_values_vec_all = None
styleSpace_grads_vec_all = None
for ws in DataLoader(ws_ok, batch_size=batch_size):
    ws.requires_grad = True
    # compute output
    imgs_orig = G.synthesis(ws, noise_mode='const', force_fp32=True)
    imgs_orig = postprocess_images(imgs_orig)

    digits_pred_orig = classifier_digits(imgs_orig)
    # max_logit = digits_pred.max(axis=1).values
    # max_softmax = F.softmax(digits_pred, dim=1).max(axis=1).values.mean()
    class_pred = F.softmax(digits_pred_orig, dim=1).max(axis=1).indices
    class_logits_orig = digits_pred_orig[torch.arange(ws.shape[0]), class_generated[:ws.shape[0]]] # class_generated[:ws.shape[0]] to handle last batch
    class_softmax_orig = F.softmax(digits_pred_orig, dim=1)[torch.arange(ws.shape[0]), class_generated[:ws.shape[0]]] # class_generated[:ws.shape[0]] to handle last batch
    # sanity check: every image in ws_ok is expected to be classified as class_selected
    assert all(class_pred == class_selected), 'class_pred != class_selected'
    
    # style values (styleSpace_values from hook)
    # must be read right after the forward pass: the hooks overwrite the dict on every pass
    styleSpace_values_vec = styleSpace_dict2vec(styleSpace_values)
    
    # backpropagate gradient to get its values (styleSpace_grads from hook)
    # loss = class_logits_orig.mean() # goal is to reduce the score for the class
    loss = class_softmax_orig.mean() # goal is to reduce the score for the class, SEEMS TO WORK BETTER THAN USING LOGIT
    loss.backward() # compute gradients and access them thanks to the backward hook
    styleSpace_grads_vec = styleSpace_dict2vec(styleSpace_grads)
    
    # record variables
    class_softmax_orig_all = class_softmax_orig.detach().cpu() if class_softmax_orig_all is None else torch.cat((class_softmax_orig_all, class_softmax_orig.detach().cpu()))
    imgs_orig_all = imgs_orig.detach().cpu() if imgs_orig_all is None else torch.cat((imgs_orig_all, imgs_orig.detach().cpu()))
    styleSpace_values_vec_all = styleSpace_values_vec if styleSpace_values_vec_all is None else torch.cat((styleSpace_values_vec_all, styleSpace_values_vec))
    styleSpace_grads_vec_all = styleSpace_grads_vec if styleSpace_grads_vec_all is None else torch.cat((styleSpace_grads_vec_all, styleSpace_grads_vec))
    
    
plot_images(imgs_orig_all[:100])
plt.title('100 first images')

# Per-channel statistics of the style values over all collected images.
# The *_vec tensors are 1D; style_min / style_max are per-layer dicts whose
# values keep a leading batch dimension of size 1 (see styleSpace_vec2dict).
style_min_vec = styleSpace_values_vec_all.min(dim=0).values
style_min = styleSpace_vec2dict(style_min_vec)
style_max_vec = styleSpace_values_vec_all.max(dim=0).values
style_max = styleSpace_vec2dict(style_max_vec)
style_std_vec = styleSpace_values_vec_all.std(dim=0)

In [None]:
n_images = 10

# Rank the style channels by the absolute value of their gradient averaged over all images.
top_k = 5
styleSpace_grads_mean_all = styleSpace_grads_vec_all.mean(axis=0)
top_k_idxs = styleSpace_grads_mean_all.abs().topk(top_k).indices.cpu().numpy()
top_k_vals = styleSpace_grads_mean_all[top_k_idxs]

styleSpace_vec_idx2coord = compute_styleSpace_vec_idx2coord()

# Column 0 holds the original image / probability, columns 1..top_k the modified ones.
imgs_plot = torch.empty((n_images, top_k+1, 1, 28, 28))
imgs_plot[:, 0] = imgs_orig_all[:n_images]
class_softmax_all = torch.empty((n_images, top_k+1))
class_softmax_all[:, 0] = class_softmax_orig_all[:n_images]
for i, ws in enumerate(ws_ok[:n_images]):
    ws = ws.unsqueeze(0) # expand batch dim
    for j, idx in enumerate(top_k_idxs):
        block_layer_name, index = styleSpace_vec_idx2coord[idx]
        # move against the gradient: towards the max when the gradient is <= 0, towards the min otherwise
        direction = styleSpace_grads_mean_all[idx].sign() <= 0
        # direction = not(direction)
        # Scalar min / max of this channel are read from the flat vectors: the per-layer
        # dicts keep a leading batch dimension of size 1, so `style_min[name][index]`
        # would index that batch axis instead of the channel.
        imgs_newStyle = generate_img_new_style2(ws, block_layer_name, index, style_min_vec[idx], style_max_vec[idx], s_shift=2, positive_direction=direction)
        imgs_newStyle = postprocess_images(imgs_newStyle)
        with torch.no_grad():
            digits_pred_newStyle = classifier_digits(imgs_newStyle)
        class_softmax_newStyle = F.softmax(digits_pred_newStyle, dim=1)[0, class_selected]
        
        # record variables
        imgs_plot[i, j+1] = imgs_newStyle.detach().cpu()
        class_softmax_all[i, j+1] = class_softmax_newStyle.cpu()


fig, axs = plt.subplots(n_images, top_k+1, figsize=(8, 20))
for i in range(n_images): # for each image
    for s in range(top_k+1): # for each style
        ax = axs[i, s]
        ax.imshow(imgs_plot[i, s].squeeze(), vmin=0, vmax=1, cmap='gray')
        ax.axis('off')
        ax.grid(False)
        title = 'original' if s == 0 else '$s_{' + str(top_k_idxs[s-1]) + '}$'
        title += '\n{}: {:.0f}%'.format(class_selected, 100*class_softmax_all[i, s])
        ax.set_title(title)


## Find noise regions

### Get latent codes of low and high noise samples

In [None]:
batch_size = 32
n_images = 1000
class_selected = 0

# Keep sampling until more than n_images latent codes have been found for both
# noise class 0 ("low") and noise class 3 ("high"), as predicted by classifier_noise.
noise_pred_all = None
w_low_noise = None
w_high_noise = None
while True:
    z = torch.randn([batch_size, G.z_dim], device=device)    # latent codes
    if conditional:
        # digits = torch.randint(0, G.c_dim, (batch_size,), device=device)
        digits = class_selected*torch.ones((batch_size, ), dtype=torch.int64, device=device)
        c = F.one_hot(digits, G.c_dim)
    else:
        c = None

    ws = G.mapping(z, c, truncation_psi=1)
    imgs = postprocess_images(G.synthesis(ws, noise_mode='const'))

    noise_pred = classifier_noise(imgs).argmax(dim=1).cpu().int()

    # first w of each selected sample
    w_low_batch = ws[noise_pred == 0][:, 0, :]
    w_high_batch = ws[noise_pred == 3][:, 0, :]

    if w_low_noise is None:
        w_low_noise = w_low_batch
    else:
        w_low_noise = torch.cat((w_low_noise, w_low_batch))

    if w_high_noise is None:
        w_high_noise = w_high_batch
    else:
        w_high_noise = torch.cat((w_high_noise, w_high_batch))

    enough_low = len(w_low_noise) > n_images
    enough_high = len(w_high_noise) > n_images
    if enough_low and enough_high:
        w_low_noise = w_low_noise[:n_images]
        w_high_noise = w_high_noise[:n_images]
        break

In [None]:
def collect_styleSpace_values(w_all):
    """Run the generator on every latent code and return the resulting style values.

    The style values are read from ``styleSpace_values``, which the forward
    hooks overwrite on every synthesis pass. No gradient is needed here, so the
    passes run under ``torch.no_grad()``.

    Args:
        w_all: (n x w_dim) tensor of latent codes; each one is repeated for all
            ``G.num_ws`` synthesis inputs.

    Returns:
        (n x total number of style channels) tensor, rows in the order of ``w_all``.
    """
    collected = []
    with torch.no_grad():
        for w_batch in DataLoader(w_all, batch_size=batch_size):
            ws_batch = w_batch.unsqueeze(1).repeat((1, G.num_ws, 1))
            # the images themselves are not needed, only the hook side effect
            G.synthesis(ws_batch, noise_mode='const', force_fp32=True)
            # style values (styleSpace_values from hook)
            collected.append(styleSpace_dict2vec(styleSpace_values))
    return torch.cat(collected)


styleSpace_values_low_noise_vec = collect_styleSpace_values(w_low_noise)
styleSpace_values_high_noise_vec = collect_styleSpace_values(w_high_noise)


In [None]:
# Mean value of every style channel: high-noise samples first, then low-noise samples.
high_noise_mean = styleSpace_values_high_noise_vec.mean(0).cpu().numpy()
plt.figure()
plt.scatter(np.arange(len(high_noise_mean)), high_noise_mean, s=1)

low_noise_mean = styleSpace_values_low_noise_vec.mean(0).cpu().numpy()
plt.figure()
plt.scatter(np.arange(len(low_noise_mean)), low_noise_mean, s=1)

In [None]:
# Standardize every style channel with the statistics of both groups pooled,
# then take the difference between the group means (high noise - low noise).
styleSpace_values_pooled = torch.cat((styleSpace_values_high_noise_vec, styleSpace_values_low_noise_vec))
mean = styleSpace_values_pooled.mean(0)
std = styleSpace_values_pooled.std(0)
high_noise_standardized_mean = ((styleSpace_values_high_noise_vec - mean) / std).mean(0)
low_noise_standardized_mean = ((styleSpace_values_low_noise_vec - mean) / std).mean(0)
noise_direction = (high_noise_standardized_mean - low_noise_standardized_mean).cpu().numpy()

# 20 channels with the largest absolute difference
top_k_dims = np.argsort(-np.abs(noise_direction))[:20]
print(top_k_dims)
plt.figure()
plt.scatter(np.arange(len(noise_direction)), noise_direction, s=1)

In [None]:
# Per-channel min and max of the style values, overlaid with the noise direction.
plt.figure()
style_min_np = styleSpace_dict2vec(style_min).cpu().numpy()[0]
style_max_np = styleSpace_dict2vec(style_max).cpu().numpy()[0]
for style_bound_np in (style_min_np, style_max_np):
    plt.scatter(np.arange(len(style_bound_np)), style_bound_np, s=1, alpha=0.5)
plt.scatter(np.arange(len(noise_direction)), noise_direction, s=1)


In [None]:
# Same comparison in W space: difference between the mean high-noise and mean low-noise w.
w_mean_difference = w_high_noise.mean(0) - w_low_noise.mean(0)
noise_direction_w = w_mean_difference.cpu().numpy()
plt.figure()
plt.scatter(np.arange(len(noise_direction_w)), noise_direction_w, s=1)

### Style dimensions with the most impact on noise : low-noise vs. high-noise samples

In [None]:
top_k_noise_dimensions = 5

# Channels ordered by decreasing absolute noise direction (computed once, outside the loop).
noise_dims_ranked = np.argsort(-np.abs(noise_direction))

for i in range(top_k_noise_dimensions):
    k = noise_dims_ranked[i]

    # Distribution of this channel for high-noise then low-noise samples, with
    # vertical lines at the channel's min and max (style_min_vec / style_max_vec).
    plt.figure()
    for group_values in (styleSpace_values_high_noise_vec, styleSpace_values_low_noise_vec):
        plt.hist(group_values[:, k].cpu().numpy(), bins=20, edgecolor='none', alpha=0.8)
    for bound_vec in (style_min_vec, style_max_vec):
        plt.axvline(bound_vec[k].cpu().numpy())

In [None]:
# For the first n_images well-classified samples, shift each of the top_k style
# channels most associated with noise and show the resulting images and class probabilities.
n_images = 10

top_k = 10
# channels with the largest absolute noise direction
top_k_idxs = (-np.abs(noise_direction)).argsort()[:top_k]

styleSpace_vec_idx2coord = compute_styleSpace_vec_idx2coord()

# Column 0 holds the original image / probability, columns 1..top_k the modified ones.
imgs_plot = torch.empty((n_images, top_k+1, 1, 28, 28))
imgs_plot[:, 0] = imgs_orig_all[:n_images]
class_softmax_all = torch.empty((n_images, top_k+1))
class_softmax_all[:, 0] = class_softmax_orig_all[:n_images]
for i, ws in enumerate(ws_ok[:n_images]):
    ws = ws.unsqueeze(0) # expand batch dim
    for j, idx in enumerate(top_k_idxs):
        block_layer_name, index = styleSpace_vec_idx2coord[idx]
        # shift in the direction of the sign of the noise direction for this channel
        direction = np.sign(noise_direction[idx]) >= 0
        # direction = not(direction)
        imgs_newStyle = generate_img_new_style3(ws, block_layer_name, index, s_std=style_std_vec[idx], strength=10, positive_direction=direction)
        imgs_newStyle = postprocess_images(imgs_newStyle)
        with torch.no_grad():
            digits_pred_newStyle = classifier_digits(imgs_newStyle)
        class_softmax_newStyle = F.softmax(digits_pred_newStyle, dim=1)[0, class_selected]
        
        # record variables
        imgs_plot[i, j+1] = imgs_newStyle.detach().cpu()
        class_softmax_all[i, j+1] = class_softmax_newStyle.cpu()


fig, axs = plt.subplots(n_images, top_k+1, figsize=(15, 20))
for i in range(n_images): # for each image
    for s in range(top_k+1): # for each style
        ax = axs[i, s]
        ax.imshow(imgs_plot[i, s].squeeze(), vmin=0, vmax=1, cmap='gray')
        ax.axis('off')
        ax.grid(False)
        title = 'original' if s == 0 else '$s_{' + str(top_k_idxs[s-1]) + '}$'
        title += '\n{}: {:.0f}%'.format(class_selected, 100*class_softmax_all[i, s])
        ax.set_title(title)


In [None]:
top_k = 20
top_k_idxs = (-np.abs(noise_direction)).argsort()[:top_k]

strength = 10  # shift size, as a multiple of the per-dimension style_std_vec value

# Move every low-noise sample along the top-k noise dimensions, following the sign of
# noise_direction in each of them.
styleSpace_values_low_noise_vec_shifted = styleSpace_values_low_noise_vec.clone()
for k in top_k_idxs:
    positive_direction = noise_direction[k] >= 0
    d = 1 if positive_direction else -1
    weight_shift = d * strength * style_std_vec[k]
    styleSpace_values_low_noise_vec_shifted[:, k] += weight_shift

# Render the first few samples of each set with the same pipeline, one figure per set.
n_images = 8
for style_vecs, title in (
    (styleSpace_values_low_noise_vec, 'low noise samples'),
    (styleSpace_values_low_noise_vec_shifted, 'low noise samples shifted to high noise'),
    (styleSpace_values_high_noise_vec, 'some high noise samples to compare'),
):
    imgs = generate_img_from_s(style_vecs[:n_images])
    imgs = postprocess_images(imgs)
    plot_images(imgs)
    plt.title(title)

### Distance from low-noise center

In [None]:
# Mean low-noise style vector, kept as a 1 x D matrix so it can be passed to cdist.
low_noise_center = styleSpace_values_low_noise_vec.mean(0, keepdim=True)

hist_style = dict(bins=20, edgecolor='none', alpha=0.8)

# Euclidean distance of every sample to that centre, using all style dimensions.
dist_low2center = torch.cdist(styleSpace_values_low_noise_vec, low_noise_center).cpu().numpy()
dist_high2center = torch.cdist(styleSpace_values_high_noise_vec, low_noise_center).cpu().numpy()

plt.figure()
plt.hist(dist_low2center, label='low noise', **hist_style)
plt.hist(dist_high2center, label='high noise', **hist_style)
plt.legend()
plt.title('distance from low noise center using all dimensions');


# Same distances, restricted to the 10 dimensions with the largest |noise_direction|.
top_k_dims = np.argsort(-np.abs(noise_direction))[:10]
low_noise_center_top = low_noise_center[:, top_k_dims]

dist_low2center = torch.cdist(styleSpace_values_low_noise_vec[:, top_k_dims], low_noise_center_top).cpu().numpy()
dist_high2center = torch.cdist(styleSpace_values_high_noise_vec[:, top_k_dims], low_noise_center_top).cpu().numpy()

plt.figure()
plt.hist(dist_low2center, label='low noise', **hist_style)
plt.hist(dist_high2center, label='high noise', **hist_style)
plt.legend()
plt.title('distance from low noise center using top noise dimensions');

In [None]:
# Classifier accuracy on the images generated from each group of W vectors.
# A sample counts as misclassified when the predicted class is not 0 (the class is hard-coded here).
for group_name, w_group in (('low', w_low_noise), ('high', w_high_noise)):
    nb_misclassified = 0
    with torch.no_grad():  # inference only: no autograd graph needed
        for w in DataLoader(w_group, batch_size):
            # Use the same W vector for every synthesis layer.
            imgs = G.synthesis(w.unsqueeze(1).repeat((1, G.num_ws, 1)), noise_mode='const')
            imgs = postprocess_images(imgs)
            digits_pred = classifier_digits(imgs)
            class_pred = F.softmax(digits_pred, dim=1).max(axis=1).indices
            nb_misclassified += (class_pred != 0).sum().item()
    # Normalise by the size of the group that was actually evaluated
    # (the low-noise accuracy was previously divided by len(w_high_noise)).
    accuracy = 100 - 100 * nb_misclassified / len(w_group)
    print(f'accuracy {group_name} noise samples: {accuracy}%')

In [None]:
n_images = 1000
batch_size = 32
class_selected = 0

# Every sample generated in this cell targets `class_selected`. The target labels are defined
# outside of the `conditional` branch so that `class_generated` also exists for an unconditional
# generator (it used to be assigned only when `conditional` was true, while being read in both cases).
class_generated = torch.full((batch_size,), class_selected, dtype=torch.int64, device=device)
if conditional:
    digits = class_generated
    c = F.one_hot(digits, G.c_dim)
else:
    c = None

w_well_batches, w_mis_batches = [], []
s_well_batches, s_mis_batches = [], []
n_generated = 0
while True:
    z = torch.randn([batch_size, G.z_dim], device=device)    # latent codes

    with torch.no_grad():
        # compute input
        w = G.mapping(z, c, truncation_psi=1, truncation_cutoff=8)[:, 0, :] # keep only the first element

        # compute output
        imgs = G.synthesis(w.unsqueeze(1).repeat((1, G.num_ws, 1)), noise_mode='const', force_fp32=True)
        imgs = postprocess_images(imgs)

        digits_pred = classifier_digits(imgs)
        softmax_per_class = F.softmax(digits_pred, dim=1)
        class_pred = softmax_per_class.max(axis=1).indices
        class_softmax = softmax_per_class[torch.arange(batch_size), class_generated]

    # style values (styleSpace_values from hook)
    styleSpace_values_vec = styleSpace_dict2vec(styleSpace_values)

    # Split the batch by whether the classifier recovered the selected class.
    is_wellclassified = class_pred == class_selected
    is_misclassified = class_pred != class_selected
    w_well_batches.append(w[is_wellclassified])
    w_mis_batches.append(w[is_misclassified])
    s_well_batches.append(styleSpace_values_vec[is_wellclassified])
    s_mis_batches.append(styleSpace_values_vec[is_misclassified])

    # Each batch lands entirely in one of the two groups, so the total grows by batch_size.
    n_generated += batch_size
    if n_generated >= n_images:
        break

w_wellclassified = torch.cat(w_well_batches)
w_misclassified = torch.cat(w_mis_batches)
s_wellclassified = torch.cat(s_well_batches)
s_misclassified = torch.cat(s_mis_batches)


In [None]:
hist_kwargs = dict(bins=20, edgecolor='none', alpha=0.8)

# Distance to low_noise_center: first over every style dimension, then restricted to top_k_dims.
for dims, dims_label in ((slice(None), 'all dimensions'), (top_k_dims, 'top noise dimensions')):
    center = low_noise_center[:, dims]
    dist_wellclassified2center = torch.cdist(s_wellclassified[:, dims], center).cpu().numpy()
    dist_misclassified2center = torch.cdist(s_misclassified[:, dims], center).cpu().numpy()

    plt.figure()
    plt.hist(dist_wellclassified2center, label='well-classified', **hist_kwargs)
    plt.hist(dist_misclassified2center, label='misclassified', **hist_kwargs)
    plt.legend()
    plt.title('distance from low noise center using ' + dims_label);

## Misclassified vs. well-classified

In [None]:
# class_selected = 0
class_selected = 'all'
n_images = 100000
batch_size = 32

# Target label of every sample: uniformly random classes, or a single fixed class.
if class_selected == 'all':
    digits = torch.randint(0, G.c_dim, (n_images,), dtype=torch.int64, device=device)
else:
    digits = class_selected*torch.ones((n_images, ), dtype=torch.int64, device=device)

# Collect the per-batch results in lists and concatenate once after the loop; calling
# torch.cat inside the loop would re-copy the growing tensors on every iteration.
z_batches, w_batches, s_batches, class_pred_batches = [], [], [], []
for labels in DataLoader(digits, batch_size):
    batch_size_t = len(labels)  # the last batch may be smaller than batch_size
    z = torch.randn([batch_size_t, G.z_dim], device=device)    # sample latent codes
    c = F.one_hot(labels, G.c_dim)

    with torch.no_grad():
        # compute input
        w = G.mapping(z, c, truncation_psi=1, truncation_cutoff=8)[:, 0, :] # keep only the first element

        # compute output
        imgs = G.synthesis(w.unsqueeze(1).repeat((1, G.num_ws, 1)), noise_mode='const', force_fp32=True)
        imgs = postprocess_images(imgs)

        digits_pred = classifier_digits(imgs)
        softmax_per_class = F.softmax(digits_pred, dim=1)
        class_pred_t = softmax_per_class.max(axis=1).indices
        class_softmax = softmax_per_class[torch.arange(batch_size_t), labels]

    # style values (styleSpace_values from hook)
    s_vec = styleSpace_dict2vec(styleSpace_values)

    z_batches.append(z)
    w_batches.append(w)
    s_batches.append(s_vec)
    class_pred_batches.append(class_pred_t)

z_all = torch.cat(z_batches)
w_all = torch.cat(w_batches)
s_all = torch.cat(s_batches)
class_predicted = torch.cat(class_pred_batches)

# Split every collected quantity by whether the prediction matches the target label.
is_wellclassified = class_predicted == digits
is_misclassified = class_predicted != digits

z_wellclassified = z_all[is_wellclassified]
z_misclassified = z_all[is_misclassified]
w_wellclassified = w_all[is_wellclassified]
w_misclassified = w_all[is_misclassified]
s_wellclassified = s_all[is_wellclassified]
s_misclassified = s_all[is_misclassified]
digits_wellclassified = digits[is_wellclassified]
digits_misclassified = digits[is_misclassified]


# Per-dimension statistics of the style vectors over all generated samples.
style_min_vec = s_all.min(dim=0).values
style_min = styleSpace_vec2dict(style_min_vec)
style_max_vec = s_all.max(dim=0).values
style_max = styleSpace_vec2dict(style_max_vec)
style_std_vec = s_all.std(dim=0)

n_well = w_wellclassified.shape[0]
n_mis = w_misclassified.shape[0]
print('Accuracy: {:.2f}% ; {}/{} misclassified samples'.format(100 * n_well / (n_well + n_mis), n_mis, (n_well + n_mis)))

In [None]:
n_samples = 10000

# 2-D t-SNE embedding of the first n_samples points of each latent space.
tsne_kwargs = dict(n_components=2, learning_rate='auto', init='random')
z_embedded = TSNE(**tsne_kwargs).fit_transform(z_all[:n_samples].cpu().numpy())
w_embedded = TSNE(**tsne_kwargs).fit_transform(w_all[:n_samples].cpu().numpy())
s_embedded = TSNE(**tsne_kwargs).fit_transform(s_all[:n_samples].cpu().numpy())

# One vectorised comparison instead of 2 * n_samples element-wise tensor comparisons.
wellclassified = (class_predicted == digits)[:n_samples].cpu().numpy()
misclassified = np.logical_not(wellclassified)
colors = np.where(wellclassified, 'C0', 'C1').tolist()
labels = np.where(wellclassified, 'well-classified', 'misclassified').tolist()

spaces = (('in Z space', z_embedded), ('in W space', w_embedded), ('in S space', s_embedded))

# Figure 1: all points in a single scatter per space, coloured by classification outcome.
fig, axs = plt.subplots(1, 3, figsize=(15, 5))
fig.suptitle('t-SNE')
for ax, (space_title, embedded) in zip(axs, spaces):
    ax.scatter(embedded[:, 0], embedded[:, 1], c=colors, alpha=0.3)
    ax.set_title(space_title)
    ax.set_xticks([])
    ax.set_yticks([])


# Figure 2: the two groups drawn as separate, labelled scatters so that a legend can be shown.
fig, axs = plt.subplots(1, 3, figsize=(15, 5))
fig.suptitle('t-SNE')
for ax, (space_title, embedded) in zip(axs, spaces):
    ax.set_title(space_title)
    ax.scatter(embedded[wellclassified, 0], embedded[wellclassified, 1], c='C0', label='well-classified', alpha=0.1)
    ax.scatter(embedded[misclassified, 0], embedded[misclassified, 1], c='C1', label='misclassified', alpha=0.1)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.legend(loc='lower right')

In [None]:
n_images = 8

# Render the first few samples of each group from their style vectors, one figure per group.
for s_group, group_title in ((s_wellclassified, 'well-classified'), (s_misclassified, 'misclassified')):
    imgs = postprocess_images(generate_img_from_s(s_group[:n_images]))
    plot_images(imgs)
    plt.title(group_title)

In [None]:
# Mean style vector of the well-classified samples, kept as a 1 x D matrix for cdist.
wellclassified_center = s_wellclassified.mean(0).unsqueeze(0)

# Top perfo directions
# NORMALIZED: difference between the misclassified and well-classified group means after
# standardising both groups with the statistics of all samples.
s_mean, s_std = s_all.mean(0), s_all.std(0)
perfo_direction = (((s_misclassified - s_mean) / s_std).mean(0) - ((s_wellclassified - s_mean) / s_std).mean(0)).cpu().numpy()
# ORIGINAL
# perfo_direction = (s_misclassified.mean(0) - s_wellclassified.mean(0)).cpu().numpy()
# STYLESPACE
# s_norm_diff = (s_misclassified - s_wellclassified.mean(0)) / s_wellclassified.std(0)
# perfo_direction = (s_norm_diff.mean(0) / s_norm_diff.std(0)).cpu().numpy()
top_k_dims = (-np.abs(perfo_direction)).argsort()[:10]
# "Top softmax" variant (disabled):
# top_k_dims = styleSpace_grads_vec_all.mean(axis=0).abs().topk(top_k).indices.cpu().numpy()

# Same analysis twice: over every style dimension, then over the top perfo dimensions only.
for dims, dims_label in ((slice(None), 'all dimensions'), (top_k_dims, 'top perfo dimensions')):
    dist_wellclassified2center = torch.cdist(s_wellclassified[:, dims], wellclassified_center[:, dims]).cpu().numpy()
    dist_misclassified2center = torch.cdist(s_misclassified[:, dims], wellclassified_center[:, dims]).cpu().numpy()

    # Bin edges spanning BOTH groups: with edges taken from the well-classified distances only,
    # misclassified samples outside that range would be dropped from the accuracy curve and the
    # two overlaid histograms would not share the same bins.
    bins = np.histogram_bin_edges(
        np.concatenate((dist_wellclassified2center.ravel(), dist_misclassified2center.ravel())), bins=20)
    hist_well, _ = np.histogram(dist_wellclassified2center, bins=bins)
    hist_mis, _ = np.histogram(dist_misclassified2center, bins=bins)
    distance = (bins[:-1] + bins[1:]) / 2  # bin centres
    # Accuracy per bin; empty bins stay NaN (a gap in the curve) without a 0/0 warning.
    hist_total = hist_well + hist_mis
    accuracy = np.full(hist_total.shape, np.nan)
    np.divide(100 * hist_well, hist_total, out=accuracy, where=hist_total > 0)

    fig, axs = plt.subplots(1, 2, figsize=(15, 5))
    axs[0].hist(dist_wellclassified2center, bins=bins, edgecolor='none', alpha=0.8, label='well-classified')
    axs[0].hist(dist_misclassified2center, bins=bins, edgecolor='none', alpha=0.8, label='misclassified')
    axs[0].legend()
    axs[0].set_title('distance from well-classified center using ' + dims_label)
    axs[1].plot(distance, accuracy)
    axs[1].set_xlabel('distance from well-classified center using ' + dims_label)
    axs[1].set_ylabel('accuracy [%]')

In [None]:
# Mean style vector of all samples, kept as a 1 x D matrix for cdist.
s_center = s_all.mean(0).unsqueeze(0)

# Top perfo directions NORMALIZED VECTORS: difference between the misclassified and
# well-classified group means after standardising both groups with the statistics of all samples.
s_mean, s_std = s_all.mean(0), s_all.std(0)
perfo_direction = (((s_misclassified - s_mean) / s_std).mean(0) - ((s_wellclassified - s_mean) / s_std).mean(0)).cpu().numpy()
# Top perfo directions
# perfo_direction = (s_misclassified.mean(0) - s_wellclassified.mean(0)).cpu().numpy()
# Top perfo directions LIKE STYLESPACE
# s_norm_diff = (s_misclassified - s_wellclassified.mean(0)) / s_wellclassified.std(0)
# perfo_direction = (s_norm_diff.mean(0) / s_norm_diff.std(0)).cpu().numpy()
top_k_dims = (-np.abs(perfo_direction)).argsort()[:10]
# "Top softmax" variant (disabled):
# top_k_dims = styleSpace_grads_vec_all.mean(axis=0).abs().topk(top_k).indices.cpu().numpy()

# Same analysis twice: over every style dimension, then over the top perfo dimensions only.
for dims, dims_label in ((slice(None), 'all dimensions'), (top_k_dims, 'top perfo dimensions')):
    dist_wellclassified2center = torch.cdist(s_wellclassified[:, dims], s_center[:, dims]).cpu().numpy()
    dist_misclassified2center = torch.cdist(s_misclassified[:, dims], s_center[:, dims]).cpu().numpy()

    # Bin edges spanning BOTH groups: with edges taken from the well-classified distances only,
    # misclassified samples outside that range would be dropped from the accuracy curve and the
    # two overlaid histograms would not share the same bins.
    bins = np.histogram_bin_edges(
        np.concatenate((dist_wellclassified2center.ravel(), dist_misclassified2center.ravel())), bins=20)
    hist_well, _ = np.histogram(dist_wellclassified2center, bins=bins)
    hist_mis, _ = np.histogram(dist_misclassified2center, bins=bins)
    distance = (bins[:-1] + bins[1:]) / 2  # bin centres
    # Accuracy per bin; empty bins stay NaN (a gap in the curve) without a 0/0 warning.
    hist_total = hist_well + hist_mis
    accuracy = np.full(hist_total.shape, np.nan)
    np.divide(100 * hist_well, hist_total, out=accuracy, where=hist_total > 0)

    fig, axs = plt.subplots(1, 2, figsize=(15, 5))
    axs[0].hist(dist_wellclassified2center, bins=bins, edgecolor='none', alpha=0.8, label='well-classified')
    axs[0].hist(dist_misclassified2center, bins=bins, edgecolor='none', alpha=0.8, label='misclassified')
    axs[0].legend()
    axs[0].set_title('distance from global center using ' + dims_label)
    axs[1].plot(distance, accuracy)
    axs[1].set_xlabel('distance from global center using ' + dims_label)
    axs[1].set_ylabel('accuracy [%]')

In [None]:
# Style dimensions ranked by |perfo_direction|, largest first.
dims_by_relevance = np.argsort(-np.abs(perfo_direction))

# Sorted magnitudes: shows how quickly the per-dimension difference falls off.
plt.figure()
plt.scatter(np.arange(len(perfo_direction)), np.abs(perfo_direction[dims_by_relevance]))
plt.xlabel('dimension')
plt.ylabel('absolute difference')

# Value distribution of both groups in each of the top dimensions; the dashed lines mark the
# style_min_vec / style_max_vec values of that dimension.
top_k = 5
fig, axs = plt.subplots(1, top_k, figsize=(20, 5))
for i in range(top_k):
    k = dims_by_relevance[i]
    ax = axs[i]

    ax.set_title(r'$s_{' + str(k) + '}$')
    for s_group, group_label in ((s_wellclassified, 'well-classified'), (s_misclassified, 'misclassified')):
        ax.hist(s_group[:, k].cpu().numpy(), bins=20, edgecolor='none', alpha=0.8, label=group_label)
    for bound_vec in (style_min_vec, style_max_vec):
        ax.axvline(bound_vec[k].cpu().numpy(), color='k', ls='--')
    ax.legend()

In [None]:
top_k = 50
top_k_idxs = np.argsort(-np.abs(perfo_direction))[:top_k]

strength = 10


def _shift_top_k(s_vectors, flip_sign):
    """Return a copy of `s_vectors` shifted by `strength` * style_std_vec in every top-k dimension.

    The shift follows the sign of `perfo_direction` in each dimension, or the opposite sign
    when `flip_sign` is true.
    """
    shifted = s_vectors.clone()
    for k in top_k_idxs:
        positive_direction = perfo_direction[k] >= 0
        if flip_sign:
            positive_direction = not(positive_direction)
        d = 1 if positive_direction else -1
        shifted[:, k] += d * strength * style_std_vec[k]
    return shifted


def _render_uint8(s_vectors):
    """Generate images from style vectors and return them as uint8 CPU tensors scaled by 255."""
    rendered = postprocess_images(generate_img_from_s(s_vectors))
    return (rendered * 255).to(torch.uint8).cpu()


def _show_pair(imgs_top, imgs_bottom, title_top, title_bottom):
    """Draw two image grids one above the other, each with its own title."""
    fig, axs = plt.subplots(2, 1, figsize=(15, 4))
    for ax, grid_imgs, title in zip(axs, (imgs_top, imgs_bottom), (title_top, title_bottom)):
        ax.imshow(vutils.make_grid(grid_imgs, pad_value=255).permute(1,2,0), vmin=0, vmax=255)
        ax.axis('off')
        ax.set_title(title)


n_images = 8

# Well-classified samples, shifted along the sign of perfo_direction.
s_wellclassified_shifted = _shift_top_k(s_wellclassified, flip_sign=False)
imgs_orig = _render_uint8(s_wellclassified[:n_images])
imgs_corr = _render_uint8(s_wellclassified_shifted[:n_images])
_show_pair(imgs_orig, imgs_corr, 'well-classified samples', 'same samples after corruption')

# Misclassified samples, shifted against the sign of perfo_direction.
s_misclassified_shifted = _shift_top_k(s_misclassified, flip_sign=True)
imgs_orig = _render_uint8(s_misclassified[:n_images])
imgs_clean = _render_uint8(s_misclassified_shifted[:n_images])
_show_pair(imgs_orig, imgs_clean, 'misclassified samples', 'same samples after cleaning')

In [None]:
strength = 10
top_k_list = [1, 10, 20, 50, 100, 150, 250, 500, 1000, 2000, 3000, s_wellclassified.shape[1]]

# Accuracy of the initially well-classified samples after shifting their top_k dimensions.
df = pd.Series(index=top_k_list, dtype='float64')
show_n_images = 8
imgs_visualize = {}

# Loop-invariant: the ranking of the dimensions and the labels of the well-classified samples.
dims_by_relevance = (-np.abs(perfo_direction)).argsort()
digits_of_wellclassified = digits[class_predicted == digits]

for top_k in top_k_list:

    top_k_idxs = dims_by_relevance[:top_k]

    s_wellclassified_shifted = s_wellclassified.clone()
    for k in top_k_idxs:
        positive_direction = perfo_direction[k] >= 0
        # positive_direction = not(positive_direction)
        d = 1 if positive_direction else -1
        weight_shift = d * strength * style_std_vec[k]
        s_wellclassified_shifted[:, k] += weight_shift


    # get misclassifications nb resulting from the shift
    nb_misclassifications = 0
    batches = zip(DataLoader(s_wellclassified_shifted, batch_size), DataLoader(digits_of_wellclassified, batch_size))
    for batch_idx, (s, labels) in enumerate(batches):

        imgs = generate_img_from_s(s)
        imgs = postprocess_images(imgs)
        if batch_idx == 0:
            # Keep a preview from the first batch only; overwriting it on every batch would
            # retain the last batch, which may hold fewer than show_n_images samples.
            imgs_visualize[top_k] = imgs[:show_n_images]

        with torch.no_grad():
            digits_pred = classifier_digits(imgs)
            class_pred = F.softmax(digits_pred, dim=1).max(axis=1).indices
            # .item(): accumulate a Python int so that the count and the accuracy below are
            # plain numbers rather than 0-dim tensors.
            nb_misclassifications += (class_pred != labels).sum().item()

    accuracy = 100 - 100 * nb_misclassifications / s_wellclassified_shifted.shape[0]
    df[top_k] = accuracy
    print('Accuracy: {:.2f}% ; {}/{} misclassified samples'.format(accuracy, nb_misclassifications, s_wellclassified_shifted.shape[0]))

plt.figure()
df.plot()

In [None]:
strength = 5
# top_k_list = [1, 10, 20, 50, 100, 150, 250, 500, 1000, 2000, 3000, s_wellclassified.shape[1]]
top_k_list = [1, 5, 10, 20, 50, 100]

# Overall accuracy after the shift, indexed by the number of shifted dimensions.
df = pd.Series(index=top_k_list, dtype='float64')
show_n_images = 8
imgs_visualize = {}

# Top perfo directions
# NORMALIZED
# Computed once BEFORE the loop: it does not depend on top_k, and computing it inside the loop
# (after the shift) made the first iteration shift along whatever `perfo_direction` an earlier
# cell had left behind.
s_mean, s_std = s_all.mean(0), s_all.std(0)
perfo_direction = (((s_misclassified - s_mean) / s_std).mean(0) - ((s_wellclassified - s_mean) / s_std).mean(0)).cpu().numpy()
# ORIGINAL
# perfo_direction = (s_misclassified.mean(0) - s_wellclassified.mean(0)).cpu().numpy()
# STYLESPACE
# s_norm_diff = (s_misclassified - s_wellclassified.mean(0)) / s_wellclassified.std(0)
# perfo_direction = (s_norm_diff.mean(0) / s_norm_diff.std(0)).cpu().numpy()
dims_by_relevance = (-np.abs(perfo_direction)).argsort()

for top_k in top_k_list:

    # shift images
    top_k_idxs = dims_by_relevance[:top_k]
    top_k_dims = top_k_idxs  # the distances below use the same dimensions as the shift
    s_shifted_all = s_all.clone()
    for k in top_k_idxs:
        positive_direction = perfo_direction[k] >= 0
        # positive_direction = not(positive_direction)
        d = 1 if positive_direction else -1
        weight_shift = d * strength * style_std_vec[k]
        s_shifted_all[:, k] += weight_shift


    # get classifications after the shift
    class_pred_batches = []
    for s in DataLoader(s_shifted_all, batch_size):

        imgs = generate_img_from_s(s)
        imgs = postprocess_images(imgs)

        with torch.no_grad():
            digits_pred = classifier_digits(imgs)
            class_pred_t = F.softmax(digits_pred, dim=1).max(axis=1).indices

        if not class_pred_batches: imgs_visualize[top_k] = imgs[:show_n_images] # save from 1st batch
        class_pred_batches.append(class_pred_t)
    class_predicted_after_shift = torch.cat(class_pred_batches)

    s_shifted_wellclassified = s_shifted_all[class_predicted_after_shift == digits]
    s_shifted_misclassified = s_shifted_all[class_predicted_after_shift != digits]
    accuracy = 100 * s_shifted_wellclassified.shape[0] / s_shifted_all.shape[0]
    df[top_k] = accuracy
    print('Accuracy: {:.2f}% ; {}/{} misclassified samples'.format(accuracy, s_shifted_misclassified.shape[0], s_shifted_all.shape[0]))

    # Distance of the shifted samples to the well-classified centre, in the shifted dimensions.
    dist_wellclassified2center = torch.cdist(s_shifted_wellclassified[:, top_k_dims], wellclassified_center[:, top_k_dims]).cpu().numpy()
    dist_misclassified2center = torch.cdist(s_shifted_misclassified[:, top_k_dims], wellclassified_center[:, top_k_dims]).cpu().numpy()

    # Bin edges spanning BOTH groups, so that no misclassified sample falls outside the bins
    # and both histograms are drawn on the same bins.
    bins = np.histogram_bin_edges(
        np.concatenate((dist_wellclassified2center.ravel(), dist_misclassified2center.ravel())), bins=20)
    hist_well, _ = np.histogram(dist_wellclassified2center, bins=bins)
    hist_mis, _ = np.histogram(dist_misclassified2center, bins=bins)
    distance = (bins[:-1] + bins[1:]) / 2  # bin centres
    # From here on `accuracy` is the per-bin accuracy; empty bins stay NaN (no 0/0 warning).
    hist_total = hist_well + hist_mis
    accuracy = np.full(hist_total.shape, np.nan)
    np.divide(100 * hist_well, hist_total, out=accuracy, where=hist_total > 0)


    # PLOT

    fig, axs = plt.subplots(1, 2, figsize=(15, 3))
    axs[0].hist(dist_wellclassified2center, bins=bins, edgecolor='none', alpha=0.8, label='well-classified')
    axs[0].hist(dist_misclassified2center, bins=bins, edgecolor='none', alpha=0.8, label='misclassified')
    axs[0].legend()
    axs[0].set_xlabel(f'distance from well-classified center using top {top_k} perfo dimensions')
    axs[0].set_ylabel('number of values')
    axs[1].plot(distance, accuracy)
    axs[1].set_xlabel(f'distance from well-classified center using top {top_k} perfo dimensions')
    axs[1].set_ylabel('accuracy [%]')

    plot_images(imgs_visualize[top_k])
    plt.title(top_k)
    plt.show()


plt.figure()
df.plot()

### Look at individual styles

In [None]:
# Well-classified samples standardised with the misclassified statistics; the relevance of a
# dimension is the mean of the standardised values divided by their standard deviation.
mis_mean, mis_std = s_misclassified.mean(0), s_misclassified.std(0)
s_norm_diff = (s_wellclassified - mis_mean) / mis_std
relevance = s_norm_diff.mean(0) / s_norm_diff.std(0)
perfo_direction = relevance.cpu().numpy()
# Dimensions with the largest absolute relevance (top_k comes from an earlier cell).
top_k_idxs = np.argsort(-np.abs(perfo_direction))[:top_k]
top_k_idxs

In [None]:
# Misclassified samples standardised with the well-classified statistics; the relevance of a
# dimension is the mean of the standardised values divided by their standard deviation.
well_mean, well_std = s_wellclassified.mean(0), s_wellclassified.std(0)
s_norm_diff = (s_misclassified - well_mean) / well_std
relevance = s_norm_diff.mean(0) / s_norm_diff.std(0)
perfo_direction = relevance.cpu().numpy()
# Dimensions with the largest absolute relevance (top_k comes from an earlier cell).
top_k_idxs = np.argsort(-np.abs(perfo_direction))[:top_k]
top_k_idxs

In [None]:
# Misclassified samples standardised with the statistics of all samples; the relevance of a
# dimension is the mean of the standardised values divided by their standard deviation.
all_mean, all_std = s_all.mean(0), s_all.std(0)
s_norm_diff = (s_misclassified - all_mean) / all_std
relevance = s_norm_diff.mean(0) / s_norm_diff.std(0)
perfo_direction = relevance.cpu().numpy()
# Dimensions with the largest absolute relevance (top_k comes from an earlier cell).
top_k_idxs = np.argsort(-np.abs(perfo_direction))[:top_k]
top_k_idxs

In [None]:
n_images = 10


# ORIGINAL
# perfo_direction = (s_misclassified.mean(0) - s_wellclassified.mean(0)).cpu().numpy()

# NORMALIZED
s_mean, s_std = s_all.mean(0), s_all.std(0)
perfo_direction = (((s_misclassified - s_mean) / s_std).mean(0) - ((s_wellclassified - s_mean) / s_std).mean(0)).cpu().numpy()

# STYLESPACE
# s_norm_diff = (s_misclassified - s_wellclassified.mean(0)) / s_wellclassified.std(0)
# perfo_direction = (s_norm_diff.mean(0) / s_norm_diff.std(0)).cpu().numpy()


# Number of top-ranked dimensions shifted in each column of the figure (0 = unmodified image).
top_k_list = [0, 1, 3, 6, 9, 12, 15, 20]
strength = 10
dims_by_relevance = (-np.abs(perfo_direction)).argsort()
# The confidence is only reported when a single integer class is selected (not 'all').
show_confidence = isinstance(class_selected, int)

# One slot per entry of top_k_list. Sizing these buffers with the global `top_k` only worked
# while a previous cell happened to leave a large enough value in it.
imgs_plot = torch.empty((len(top_k_list), n_images, 1, 28, 28))
class_softmax_all = torch.empty((len(top_k_list), n_images))

for i, top_k_subset in enumerate(top_k_list):

    top_k_idxs = dims_by_relevance[:top_k_subset]

    # manipulate images
    s_wellclassified_shifted = s_wellclassified[:n_images].clone()
    for k in top_k_idxs:
        positive_direction = perfo_direction[k] >= 0
        # positive_direction = not(positive_direction)
        d = 1 if positive_direction else -1
        weight_shift = d * strength * style_std_vec[k]
        s_wellclassified_shifted[:, k] += weight_shift

    imgs_newStyle = generate_img_from_s(s_wellclassified_shifted)
    imgs_newStyle = postprocess_images(imgs_newStyle)

    # record variables
    imgs_plot[i] = imgs_newStyle.detach().cpu()
    if show_confidence:
        with torch.no_grad():
            digits_pred_newStyle = classifier_digits(imgs_newStyle)
        # One confidence per image: indexing row 0 only would broadcast the first image's
        # confidence to every image of the column.
        class_softmax_newStyle = F.softmax(digits_pred_newStyle, dim=1)[:, class_selected]
        class_softmax_all[i] = class_softmax_newStyle.cpu()


fig, axs = plt.subplots(n_images, len(top_k_list), figsize=(12, 20))
for i in range(n_images): # for each image
    for j, nb_k in enumerate(top_k_list): # for each number of shifted dimensions
        ax = axs[i, j]
        ax.imshow(imgs_plot[j, i].squeeze(), vmin=0, vmax=1, cmap='gray')
        ax.axis('off')
        ax.grid(False)
        title = 'original' if nb_k == 0 else f'nb top_k: {nb_k}'
        if show_confidence: title += '\n{}: {:.0f}%'.format(class_selected, 100*class_softmax_all[j, i])
        ax.set_title(title)

### images vs. distance

In [None]:
n_images = 10
top_k = 10

# Top perfo directions
# "Performance direction": per-dimension difference between the mean
# standardised style value of misclassified and of well-classified samples
# (standardised with the mean / std of `s_all`).
# NORMALIZED
perfo_direction = (
    ((s_misclassified - s_all.mean(0)) / s_all.std(0)).mean(0)
    - ((s_wellclassified - s_all.mean(0)) / s_all.std(0)).mean(0)
).cpu().numpy()
# ORIGINAL
# perfo_direction = (s_misclassified.mean(0) - s_wellclassified.mean(0)).cpu().numpy()
# STYLESPACE
# s_norm_diff = (s_misclassified - s_wellclassified.mean(0)) / s_wellclassified.std(0)
# perfo_direction = (s_norm_diff.mean(0) / s_norm_diff.std(0)).cpu().numpy()

# Indices of the `top_k` dimensions with the largest |perfo_direction|.
top_k_dims = np.argsort(-np.abs(perfo_direction))[:top_k]

# Euclidean distance of every sample to the well-classified centre,
# restricted to the selected dimensions.
dist_wellclassified2center = torch.cdist(
    s_wellclassified[:, top_k_dims], wellclassified_center[:, top_k_dims]
).cpu().numpy()
dist_misclassified2center = torch.cdist(
    s_misclassified[:, top_k_dims], wellclassified_center[:, top_k_dims]
).cpu().numpy()

# Accuracy per distance bin; the bin edges come from the well-classified
# distances and are reused for the misclassified ones.
hist_well, bins = np.histogram(dist_wellclassified2center, bins=20)
hist_mis, _ = np.histogram(dist_misclassified2center, bins=bins)
distance = (bins[:-1] + bins[1:]) / 2  # bin centres
accuracy = 100 * hist_well / (hist_well+hist_mis)

fig, axs = plt.subplots(1, 2, figsize=(15, 5))
for dist_values, hist_label in (
    (dist_wellclassified2center, 'well-classified'),
    (dist_misclassified2center, 'misclassified'),
):
    axs[0].hist(dist_values, bins=20, edgecolor='none', alpha=0.8, label=hist_label)
axs[0].legend()
axs[0].set_title('distance from well-classified center using top perfo dimensions')
axs[1].plot(distance, accuracy)
axs[1].set(
    xlabel='distance from well-classified center using top perfo dimensions',
    ylabel='accuracy [%]',
)



# SAMPLE IMAGES
# Draw `n_images` well-classified samples at random (with replacement) among
# those closest to the well-classified centre.
distances_low2high = dist_wellclassified2center.ravel().argsort()
# Cap the candidate pool at the number of available samples so the draw
# cannot index out of range when fewer than 1000 samples are well-classified.
n_closest = min(1000, len(distances_low2high))
idx_images = distances_low2high[np.random.randint(n_closest, size=n_images)]

s = s_wellclassified[idx_images, :]
s_digits = digits_wellclassified[idx_images]


# Shift strengths, expressed in multiples of `style_std_vec[k]` per dimension.
strengths = np.linspace(0, 3, num=6)
imgs_plot = torch.empty((len(strengths), n_images, 1, 28, 28))
distances = []
confidences = []

for i, strength in enumerate(strengths):
    # manipulate images: push every selected dimension along the sign of
    # `perfo_direction` (flip the two signs below to move the opposite way)
    s_shifted = s.clone()
    for k in top_k_dims:
        d = 1 if perfo_direction[k] >= 0 else -1
        s_shifted[:, k] += d * strength * style_std_vec[k]

    # Pure inference: no autograd graph is needed for generation/classification.
    with torch.no_grad():
        imgs = postprocess_images(generate_img_from_s(s_shifted))
        digits_pred = classifier_digits(imgs)
        # Softmax score assigned to each image's own digit.
        class_softmax = F.softmax(digits_pred, dim=1)[torch.arange(n_images), s_digits]

    imgs_plot[i] = imgs
    distances.append(torch.cdist(s_shifted[:, top_k_dims], wellclassified_center[:, top_k_dims]).cpu().numpy())
    confidences.append(class_softmax.cpu().numpy())

# One row per image, one column per strength; the title shows the distance to
# the centre (d) and the classifier confidence (c).
fig, axs = plt.subplots(n_images, imgs_plot.shape[0], figsize=(12, 20))
for i in range(n_images): # for each image
    for j in range(imgs_plot.shape[0]): # for each strength
        ax = axs[i, j]
        ax.imshow(imgs_plot[j, i].squeeze(), vmin=0, vmax=1, cmap='gray')
        ax.axis('off')
        ax.grid(False)
        title = 'd={:.1f} ; c={:.0f}%'.format(distances[j][i, 0], 100*confidences[j][i])
        ax.set_title(title)


# Project

In [None]:
# load data
# MNIST test images as tensors resized to 32 px and rescaled from [0, 1] to [0, 255].
transforms = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Resize(32),
    torchvision.transforms.Lambda(lambda x: 255*x)
])
mnist_test = torchvision.datasets.MNIST(root='../data', train=False, transform=transforms)

idx = 3
img, digit = mnist_test[idx]

digits_ = torch.tensor(digit, device=device)
c = F.one_hot(digits_, G.c_dim)

# project
projected_w_steps = project(G, img, c, device=device)


def _synthesize_rgb_uint8(w):
    """Render one projected W code as a channels-last uint8 numpy image.

    The generator output is tiled three times along the channel axis (RGB for
    a single-channel generator) and mapped from [-1, 1] to [0, 255].
    """
    with torch.no_grad():  # rendering only, no gradients needed
        synth = G.synthesis(w.unsqueeze(0), noise_mode='const')
    synth = synth.repeat((1, 3, 1, 1))
    synth = (synth + 1) * (255/2)
    return synth.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()


# save video
outdir = 'proj_out'

target_pil = torchvision.transforms.ToPILImage()(img.repeat((3, 1, 1))/255)
target_uint8 = np.array(target_pil, dtype=np.uint8)

# Render debug output: optional video and projected image and W vector.
os.makedirs(outdir, exist_ok=True)
print (f'Saving optimization progress video "{outdir}/proj.mp4"')
# The context manager closes the writer even if rendering a frame fails.
with imageio.get_writer(f'{outdir}/proj.mp4', mode='I', fps=10, codec='libx264', bitrate='16M') as video:
    for w_step in projected_w_steps:
        # Each frame shows the target (left) next to the current projection (right).
        video.append_data(np.concatenate([target_uint8, _synthesize_rgb_uint8(w_step)], axis=1))

# Save final projected frame and W vector.
target_pil.save(f'{outdir}/target.png')
projected_w = projected_w_steps[-1]
synth_image = _synthesize_rgb_uint8(projected_w)
PIL.Image.fromarray(synth_image, 'RGB').save(f'{outdir}/proj.png')
np.savez(f'{outdir}/projected_w.npz', w=projected_w.unsqueeze(0).cpu().numpy())

In [None]:
# Classifier output for the target image: add a batch axis, keep
# rows/columns 2..29 and divide by 255.
x_target = img.unsqueeze(0).to(device)[:, :, 2:30, 2:30] / 255
print(classifier_digits(x_target))

# Same for the projection: first channel of `synth_image`, with batch and
# channel axes added, identical crop and scaling.
x_proj = torch.from_numpy(synth_image[:, :, 0])[None, None].to(device)
x_proj = x_proj[:, :, 2:30, 2:30] / 255
print(classifier_digits(x_proj))

## Read projected data

In [None]:
dataset = 'mnistTest_stylegan2_blur_noise_maxSeverity3_proba50'
# dataset = 'mnistTest_stylegan2'

# Archive with the projected latent codes (`projected_w`) plus the sample
# `indices` and `digits` they belong to.
file = path_results / 'projected_data' / f'{dataset}_model00016_all.npz'
projected_data = np.load(file)

mnist_test = torchvision.datasets.MNIST(root='../data', train=False)

# Sanity checks: the archive must cover the MNIST test set in its original
# order. `np.array_equal` also fails cleanly (AssertionError) on a shape
# mismatch, where an element-wise `==` would not.
assert np.array_equal(projected_data['indices'], np.arange(len(mnist_test))), \
    'projected samples are not in MNIST test-set order'
assert np.array_equal(projected_data['digits'], mnist_test.targets.cpu().numpy()), \
    'projected labels do not match the MNIST test labels'

projected_w = torch.from_numpy(projected_data['projected_w']).to(device)
digits = torch.from_numpy(projected_data['digits']).to(device)


# PREDICT CLASSES OF PROJECTED DATA
# Regenerate every projected sample, classify it and record its style vector.

batch_size = 64
s_batches = []
pred_batches = []
for w in DataLoader(projected_w, batch_size=batch_size):
    with torch.no_grad():
        # compute output: the same w is repeated for all `G.num_ws` layers
        imgs = G.synthesis(w.unsqueeze(1).repeat((1, G.num_ws, 1)), noise_mode='const', force_fp32=True)
        imgs = postprocess_images(imgs)

        digits_pred = classifier_digits(imgs)
        class_pred_t = F.softmax(digits_pred, dim=1).max(axis=1).indices

    # style values (styleSpace_values from hook)
    s_batches.append(styleSpace_dict2vec(styleSpace_values))
    pred_batches.append(class_pred_t)

# Concatenate once at the end instead of re-copying the accumulated tensor
# on every batch.
s_proj_all = torch.cat(s_batches)
class_predicted = torch.cat(pred_batches)

# Split the style vectors by whether the projected image is classified as its label.
s_proj_wellclassified = s_proj_all[class_predicted == digits]
s_proj_misclassified = s_proj_all[class_predicted != digits]


# PREDICT CLASSES OF ORIGINAL DATA AND COMPARE WITH PROJECTED DATA

path_data = Path.cwd().parent / 'data/MNIST' / f'{dataset}.zip'
ds = ImageFolderDataset(path_data, use_labels=True)

# Classifier predictions on the original images (rows/columns 2..29, divided by 255).
class_predicted_mnist = None
for x, y in DataLoader(ds, batch_size=64):
    x = x[:, :, 2:30, 2:30] / 255
    digits_pred = classifier_digits(x.to(device))
    class_pred_t = F.softmax(digits_pred, dim=1).max(axis=1).indices
    if class_predicted_mnist is None:
        class_predicted_mnist = class_pred_t
    else:
        class_predicted_mnist = torch.cat((class_predicted_mnist, class_pred_t))

# Cross-tabulate "projection classified as its label" against "projection and
# original image get the same prediction":
#   tp: projection correct, agrees with the original image
#   fp: projection correct, disagrees with the original image
#   tn: projection wrong,   agrees with the original image
#   fn: projection wrong,   disagrees with the original image
proj_correct = class_predicted == digits
proj_wrong = class_predicted != digits
same_as_original = class_predicted == class_predicted_mnist

tp_idx = same_as_original[proj_correct]
fp_idx = ~tp_idx
tn_idx = same_as_original[proj_wrong]
fn_idx = ~tn_idx

tp = tp_idx.sum()
tn = tn_idx.sum()
fp = fp_idx.sum()
fn = fn_idx.sum()

assert tp + tn + fp + fn == len(class_predicted)

print('True positive:', tp.item())
print('True negative:', tn.item())
print('False positive:', fp.item())
print('False negative:', fn.item())


# FILTER OUT WRONG SAMPLES (FALSE POSITIVE OR NEGATIVE)
# Both right-hand sides are evaluated before the names are rebound, so the
# `_fp` / `_fn` subsets are taken from the unfiltered tensors.
s_proj_wellclassified_fp, s_proj_wellclassified = (
    s_proj_wellclassified[fp_idx],
    s_proj_wellclassified[tp_idx],
)
s_proj_misclassified_fn, s_proj_misclassified = (
    s_proj_misclassified[fn_idx],
    s_proj_misclassified[tn_idx],
)


# PLOT SAMPLES
# First `n_images` samples of each of the four groups.
n_images = 8
for samples, caption in (
    (s_proj_wellclassified, 'well-classified (correct)'),
    (s_proj_wellclassified_fp, 'well-classified (wrong)'),
    (s_proj_misclassified, 'misclassified (correct)'),
    (s_proj_misclassified_fn, 'misclassified (wrong)'),
):
    plot_images_from_s(samples[:n_images])
    plt.title(caption)





In [None]:
# Test sample 3 without axis ticks, followed by the image generated from the
# style vector stored at the same index in `s_proj_all`.
plt.figure()
plt.imshow(mnist_test[3][0], cmap='gray')
plt.gca().set(xticks=[], yticks=[])


plot_images_from_s(s_proj_all[3])


In [None]:
# Mean style vector of the well-classified samples, kept 2-D for `torch.cdist`.
wellclassified_center = s_wellclassified.mean(0, keepdim=True)

# All directions
# Distance of the projected samples to that centre, using every dimension.
dist_wellclassified2center = torch.cdist(s_proj_wellclassified, wellclassified_center).cpu().numpy()
dist_misclassified2center = torch.cdist(s_proj_misclassified, wellclassified_center).cpu().numpy()

# Accuracy per distance bin (bin edges taken from the well-classified distances).
hist_well, bins = np.histogram(dist_wellclassified2center, bins=20)
hist_mis, _ = np.histogram(dist_misclassified2center, bins=bins)
distance = (bins[:-1] + bins[1:]) / 2  # bin centres
accuracy = 100 * hist_well / (hist_well+hist_mis)

fig, axs = plt.subplots(1, 2, figsize=(15, 5))
for dist_values, hist_label in (
    (dist_wellclassified2center, 'well-classified'),
    (dist_misclassified2center, 'misclassified'),
):
    axs[0].hist(dist_values, bins=20, edgecolor='none', alpha=0.8, label=hist_label)
axs[0].legend()
axs[0].set_title('distance from well-classified center using all dimensions')
axs[1].plot(distance, accuracy)
axs[1].set(
    xlabel='distance from well-classified center using all dimensions',
    ylabel='accuracy [%]',
)

# Top perfo directions
# Per-dimension difference between the mean standardised style value of
# misclassified and of well-classified samples (standardised with `s_all`).
# NORMALIZED
perfo_direction = (
    ((s_misclassified - s_all.mean(0)) / s_all.std(0)).mean(0)
    - ((s_wellclassified - s_all.mean(0)) / s_all.std(0)).mean(0)
).cpu().numpy()
# ORIGINAL
# perfo_direction = (s_misclassified.mean(0) - s_wellclassified.mean(0)).cpu().numpy()
# STYLESPACE
# s_norm_diff = (s_misclassified - s_wellclassified.mean(0)) / s_wellclassified.std(0)
# perfo_direction = (s_norm_diff.mean(0) / s_norm_diff.std(0)).cpu().numpy()

# The 10 dimensions with the largest |perfo_direction|.
top_k_dims = np.argsort(-np.abs(perfo_direction))[:10]


# Same distances, restricted to those dimensions.
dist_wellclassified2center = torch.cdist(
    s_proj_wellclassified[:, top_k_dims], wellclassified_center[:, top_k_dims]
).cpu().numpy()
dist_misclassified2center = torch.cdist(
    s_proj_misclassified[:, top_k_dims], wellclassified_center[:, top_k_dims]
).cpu().numpy()

hist_well, bins = np.histogram(dist_wellclassified2center, bins=20)
hist_mis, _ = np.histogram(dist_misclassified2center, bins=bins)
distance = (bins[:-1] + bins[1:]) / 2  # bin centres
accuracy = 100 * hist_well / (hist_well+hist_mis)

fig, axs = plt.subplots(1, 2, figsize=(15, 5))
for dist_values, hist_label in (
    (dist_wellclassified2center, 'well-classified'),
    (dist_misclassified2center, 'misclassified'),
):
    axs[0].hist(dist_values, bins=20, edgecolor='none', alpha=0.8, label=hist_label)
axs[0].legend()
axs[0].set_title('distance from well-classified center using top perfo dimensions')
axs[1].plot(distance, accuracy)
axs[1].set(
    xlabel='distance from well-classified center using top perfo dimensions',
    ylabel='accuracy [%]',
)

### Move projected image to the well-classified center

In [None]:
n_images = 8

# First `n_images` misclassified projected samples, passed through `truncate`
# together with the well-classified centre and a factor of 0.9.
# NOTE(review): presumably this pulls each style vector towards the centre -
# confirm against the definition of `truncate`.
s = s_proj_misclassified[:n_images]
s_ = truncate(s, wellclassified_center, 0.9)

plot_images_from_s(s_)


In [None]:
# Parameters are set first so the cell does not depend on values left over
# from other cells (`n_images` is needed for the slice below).
n_images = 8
top_k = 20
strength = 10

# The `top_k` dimensions with the largest |perfo_direction|.
top_k_idxs = (-np.abs(perfo_direction)).argsort()[:top_k]

s = s_proj_wellclassified[:n_images]
s_shifted = s.clone()

# Shift each selected dimension by `strength` times `style_std_vec[k]` along
# the sign of `perfo_direction` (flip the two signs to move the opposite way).
for k in top_k_idxs:
    d = 1 if perfo_direction[k] >= 0 else -1
    s_shifted[:, k] += d * strength * style_std_vec[k]


imgs = generate_img_from_s(s)
imgs = postprocess_images(imgs)
plot_images(imgs)
plt.title('well classified samples')

imgs = generate_img_from_s(s_shifted)
imgs = postprocess_images(imgs)
plot_images(imgs)
plt.title('well classified samples corrupted')

# Figures for paper

In [None]:
# Font sizes used for the paper figures.
SMALL_SIZE = 12
MEDIUM_SIZE = 15
BIGGER_SIZE = 15

plt.rcParams.update({
    'font.size': SMALL_SIZE,          # default text size
    'axes.titlesize': SMALL_SIZE,     # axes title
    'axes.labelsize': MEDIUM_SIZE,    # x and y labels
    'xtick.labelsize': SMALL_SIZE,    # x tick labels
    'ytick.labelsize': SMALL_SIZE,    # y tick labels
    'legend.fontsize': SMALL_SIZE,    # legend
    'figure.titlesize': BIGGER_SIZE,  # figure title
})

### samples real vs generated images

In [None]:
n_images = 8

# REAL IMAGES

dataset = 'mnist_stylegan2_blur_noise_maxSeverity3_proba50'
path_data = Path.cwd().parent / 'data/MNIST' / f'{dataset}.zip'
ds = ImageFolderDataset(path_data, use_labels=True)

# Random sample indices (drawn with replacement).
rand_idx = np.random.randint(len(ds), size=n_images)

# Stack the selected images, keeping rows/columns 2..29 of each.
real_images = torch.stack([torch.tensor(ds[i][0][:, 2:30, 2:30]) for i in rand_idx])

plt.figure()
real_grid = vutils.make_grid(real_images.cpu(), pad_value=255)
plt.imshow(real_grid.permute(1,2,0), vmin=0, vmax=255)
plt.axis('off')
plt.grid(False)


# GENERATED IMAGES

z = torch.randn([n_images, G.z_dim], device=device)    # latent codes
c = None
if conditional:
    # Random class labels, one-hot encoded for the conditional generator.
    digits_ = torch.randint(0, G.c_dim, (n_images,), device=device)
    c = F.one_hot(digits_, G.c_dim)

ws = G.mapping(z, c, truncation_psi=1)
img = postprocess_images(G.synthesis(ws, noise_mode='const', force_fp32=True))
plot_images(img)

### t-SNE

In [None]:
# One scatter plot per 2-D embedding (z, w and s), coloured by the
# `wellclassified` mask and saved under `path_results / 'figures'`.
for embedding, fig_name in (
    (z_embedded, 'tsne_z'),
    (w_embedded, 'tsne_w'),
    (s_embedded, 'tsne_s'),
):
    plt.figure(figsize=(4, 4))
    for mask, color, legend_label in (
        (wellclassified, 'C0', 'well-classified'),
        (np.logical_not(wellclassified), 'C1', 'misclassified'),
    ):
        plt.scatter(embedding[mask, 0], embedding[mask, 1], c=color, label=legend_label, alpha=0.2)
    plt.legend(loc="lower left", fontsize='x-large')
    plt.xticks([])
    plt.yticks([])
    plt.tight_layout()
    plt.savefig(path_results / 'figures' / fig_name)

### histograms dimensions S

In [None]:
top_k = 3
fig, axs = plt.subplots(1, top_k, figsize=(13, 4))

# Dimensions ordered by decreasing |perfo_direction|; one panel for each of
# the first `top_k`.
ranked_dims = np.argsort(-np.abs(perfo_direction))
for i, k in enumerate(ranked_dims[:top_k]):
    # axs[i].set_title(r'$s_{' + str(k) + '}$')
    for values, hist_label in (
        (s_wellclassified[:, k], 'well-classified'),
        (s_misclassified[:, k], 'misclassified'),
    ):
        axs[i].hist(values.cpu().numpy(), bins=20, edgecolor='none', alpha=0.8, label=hist_label)
    axs[i].set_ylabel('number of samples')
    axs[i].set_xlabel(r'value for dimension $s_{' + str(k) + '}$')
    # axs[i].axvline(style_min_vec[k].cpu().numpy(), color='k', ls='--')
    # plt.text(1.1*style_min_vec[k].cpu().numpy(), 100, 'empirical min',rotation=90)
    # axs[i].axvline(style_max_vec[k].cpu().numpy(), color='k', ls='--')
    # plt.text(1.1*style_max_vec[k].cpu().numpy(), 100, 'empirical max',rotation=90)
    axs[i].legend()
plt.tight_layout()
plt.savefig(path_results / 'figures' / 'hist_top_dims')

In [None]:
# Histograms for a few randomly chosen style dimensions, for comparison with
# the top-ranked ones.
n_random_dims = 3
rnd_idx = np.random.randint(len(perfo_direction), size=n_random_dims)

# Dimensions ordered by decreasing |perfo_direction|, computed once; the
# random values are used as ranks into this ordering.
ranked_dims = (-np.abs(perfo_direction)).argsort()

# The panel count follows the number of drawn dimensions, not whatever value
# `top_k` holds from another cell.
fig, axs = plt.subplots(1, len(rnd_idx), figsize=(13, 4))
for ax, k in zip(axs, ranked_dims[rnd_idx]):
    # ax.set_title(r'$s_{' + str(k) + '}$')
    ax.hist(s_wellclassified[:, k].cpu().numpy(), bins=20, edgecolor='none', alpha=0.8, label='well-classified')
    ax.hist(s_misclassified[:, k].cpu().numpy(), bins=20, edgecolor='none', alpha=0.8, label='misclassified')
    ax.set_ylabel('number of samples')
    ax.set_xlabel(r'value for dimension $s_{' + str(k) + '}$')
    # ax.axvline(style_min_vec[k].cpu().numpy(), color='k', ls='--')
    # plt.text(1.1*style_min_vec[k].cpu().numpy(), 100, 'empirical min',rotation=90)
    # ax.axvline(style_max_vec[k].cpu().numpy(), color='k', ls='--')
    # plt.text(1.1*style_max_vec[k].cpu().numpy(), 100, 'empirical max',rotation=90)
    ax.legend()
plt.tight_layout()
plt.savefig(path_results / 'figures' / 'hist_random_dims')

### accuracy vs distance

In [None]:
cut = 0.7
top_k_list = [10, 100, 2000]

# Mean style vector of the well-classified samples (kept 2-D for `torch.cdist`).
wellclassified_center = s_wellclassified.mean(0).unsqueeze(0)

# Top perfo directions
# Per-dimension difference between the mean standardised style value of
# misclassified and of well-classified samples. It does not depend on
# `top_k`, so it is computed (and ranked) once, outside the loop.
# NORMALIZED
perfo_direction = (((s_misclassified - s_all.mean(0)) / s_all.std(0)).mean(0) - ((s_wellclassified - s_all.mean(0)) / s_all.std(0)).mean(0)).cpu().numpy()
# ORIGINAL
# perfo_direction = (s_misclassified.mean(0) - s_wellclassified.mean(0)).cpu().numpy()
# STYLESPACE
# s_norm_diff = (s_misclassified - s_wellclassified.mean(0)) / s_wellclassified.std(0)
# perfo_direction = (s_norm_diff.mean(0) / s_norm_diff.std(0)).cpu().numpy()
ranked_dims = (-np.abs(perfo_direction)).argsort()

plt.figure(figsize=(10,5))
plt.xlabel('(normalized) distance from well-classified center')
plt.ylabel('accuracy [%]')

for top_k in top_k_list:
    top_k_dims = ranked_dims[:top_k]

    dist_wellclassified2center = torch.cdist(s_wellclassified[:, top_k_dims], wellclassified_center[:, top_k_dims]).cpu().numpy()
    dist_misclassified2center = torch.cdist(s_misclassified[:, top_k_dims], wellclassified_center[:, top_k_dims]).cpu().numpy()

    # Normalize: shift by the smallest well-classified distance, then divide
    # by `cut` times the largest shifted well-classified distance. The
    # misclassified array is updated first each time because the reference
    # value is read from the well-classified array before it changes.
    dist_misclassified2center -= dist_wellclassified2center.min()
    dist_wellclassified2center -= dist_wellclassified2center.min()
    dist_misclassified2center /= cut * dist_wellclassified2center.max()
    dist_wellclassified2center /= cut * dist_wellclassified2center.max()

    # Accuracy per distance bin; bins with no sample yield NaN.
    hist_well, bins = np.histogram(dist_wellclassified2center, bins=20)
    hist_mis, _ = np.histogram(dist_misclassified2center, bins=bins)
    distance = (bins[:-1] + bins[1:]) / 2  # bin centres
    accuracy = 100 * hist_well / (hist_well+hist_mis)

    plt.plot(distance[distance<1], accuracy[distance<1], label=f'using top {top_k} dimensions')

# All directions
dist_wellclassified2center = torch.cdist(s_wellclassified, wellclassified_center).cpu().numpy()
dist_misclassified2center = torch.cdist(s_misclassified, wellclassified_center).cpu().numpy()

# Normalize (same scheme as inside the loop)
dist_misclassified2center -= dist_wellclassified2center.min()
dist_wellclassified2center -= dist_wellclassified2center.min()
dist_misclassified2center /= cut * dist_wellclassified2center.max()
dist_wellclassified2center /= cut * dist_wellclassified2center.max()

hist_well, bins = np.histogram(dist_wellclassified2center, bins=20)
hist_mis, _ = np.histogram(dist_misclassified2center, bins=bins)
distance = (bins[:-1] + bins[1:]) / 2  # bin centres
accuracy = 100 * hist_well / (hist_well+hist_mis)

plt.plot(distance[distance<1], accuracy[distance<1], label='using all dimensions')
plt.legend()
plt.savefig(path_results / 'figures' / 'accuracy_vs_distance_wellclassified_center')

In [None]:
# Select samples of `s_all` lying at given normalised distances from the
# well-classified centre.
# NOTE(review): this cell looks unfinished - the loop stops after the first
# distance (unconditional `break`) and the selected `s` is not used here.
# It also reads `wellclassified_center` and `cut`, which this cell does not define.
distances = [0.1, 0.3, 0.5, 0.7, 0.9]
n_images = 8

# Distance of every sample to the centre, shifted so the minimum is 0 and
# divided by `cut` times the maximum shifted distance.
dist_all2center = torch.cdist(s_all, wellclassified_center).cpu().numpy()
dist_all2center -= dist_all2center.min()
dist_all2center /= cut * dist_all2center.max()

for d in distances:

    # Boolean mask of the samples whose normalised distance is within 0.01 of `d`.
    idx = (np.abs(dist_all2center - d) <= 0.01).squeeze()
    s = s_all[idx][:n_images]



    break

In [None]:
# Show `digits` as the cell output (inspection only).
digits

In [None]:
# Mean style vector over all samples (kept 2-D for `torch.cdist`).
s_center = s_all.mean(0).unsqueeze(0)

# Top perfo directions
# Per-dimension difference between the mean standardised style value of
# misclassified and of well-classified samples. It does not depend on
# `top_k`, so it is computed (and ranked) once, outside the loop.
# NORMALIZED
perfo_direction = (((s_misclassified - s_all.mean(0)) / s_all.std(0)).mean(0) - ((s_wellclassified - s_all.mean(0)) / s_all.std(0)).mean(0)).cpu().numpy()
# ORIGINAL
# perfo_direction = (s_misclassified.mean(0) - s_wellclassified.mean(0)).cpu().numpy()
# STYLESPACE
# s_norm_diff = (s_misclassified - s_wellclassified.mean(0)) / s_wellclassified.std(0)
# perfo_direction = (s_norm_diff.mean(0) / s_norm_diff.std(0)).cpu().numpy()
ranked_dims = (-np.abs(perfo_direction)).argsort()

plt.figure(figsize=(10,5))
plt.xlabel('(normalized) distance from global center')
plt.ylabel('accuracy [%]')

for top_k in top_k_list:
    top_k_dims = ranked_dims[:top_k]

    dist_wellclassified2center = torch.cdist(s_wellclassified[:, top_k_dims], s_center[:, top_k_dims]).cpu().numpy()
    dist_misclassified2center = torch.cdist(s_misclassified[:, top_k_dims], s_center[:, top_k_dims]).cpu().numpy()

    # Normalize: shift by the smallest well-classified distance, then divide
    # by `cut` times the largest shifted well-classified distance. The
    # misclassified array is updated first each time because the reference
    # value is read from the well-classified array before it changes.
    dist_misclassified2center -= dist_wellclassified2center.min()
    dist_wellclassified2center -= dist_wellclassified2center.min()
    dist_misclassified2center /= cut * dist_wellclassified2center.max()
    dist_wellclassified2center /= cut * dist_wellclassified2center.max()

    # Accuracy per distance bin; bins with no sample yield NaN.
    hist_well, bins = np.histogram(dist_wellclassified2center, bins=20)
    hist_mis, _ = np.histogram(dist_misclassified2center, bins=bins)
    distance = (bins[:-1] + bins[1:]) / 2  # bin centres
    accuracy = 100 * hist_well / (hist_well+hist_mis)

    plt.plot(distance[distance<1], accuracy[distance<1], label=f'using top {top_k} dimensions')

# All directions
dist_wellclassified2center = torch.cdist(s_wellclassified, s_center).cpu().numpy()
dist_misclassified2center = torch.cdist(s_misclassified, s_center).cpu().numpy()

# Normalize (same scheme as inside the loop)
dist_misclassified2center -= dist_wellclassified2center.min()
dist_wellclassified2center -= dist_wellclassified2center.min()
dist_misclassified2center /= cut * dist_wellclassified2center.max()
dist_wellclassified2center /= cut * dist_wellclassified2center.max()

hist_well, bins = np.histogram(dist_wellclassified2center, bins=20)
hist_mis, _ = np.histogram(dist_misclassified2center, bins=bins)
distance = (bins[:-1] + bins[1:]) / 2  # bin centres
accuracy = 100 * hist_well / (hist_well+hist_mis)

plt.plot(distance[distance<1], accuracy[distance<1], label='using all dimensions')
plt.legend()
plt.savefig(path_results / 'figures' / 'accuracy_vs_distance_global_center')

### Image manipulation : visualize corruption

In [None]:
top_k = 10
strength = 5
n_images = 8

# The `top_k` dimensions with the largest |perfo_direction|.
top_k_idxs = np.argsort(-np.abs(perfo_direction))[:top_k]

# ORIGINAL IMAGES
# Generate, classify, and read the score each image gets for its own digit.
imgs_orig = postprocess_images(generate_img_from_s(s_wellclassified[:n_images]))
digits_pred = classifier_digits(imgs_orig)
class_logit = digits_pred[torch.arange(n_images), digits_wellclassified[:n_images]]
class_softmax_orig = F.softmax(digits_pred, dim=1)[torch.arange(n_images), digits_wellclassified[:n_images]]

# CORRUPT
# Move each selected dimension along the sign of `perfo_direction`, by
# `strength` times `style_std_vec[k]`.
s_wellclassified_shifted = s_wellclassified[:n_images].clone()
for k in top_k_idxs:
    d = 1 if perfo_direction[k] >= 0 else -1
    s_wellclassified_shifted[:, k] += d * strength * style_std_vec[k]

imgs_corr = postprocess_images(generate_img_from_s(s_wellclassified_shifted))
digits_pred = classifier_digits(imgs_corr)
class_logit = digits_pred[torch.arange(n_images), digits_wellclassified[:n_images]]
class_softmax_corr = F.softmax(digits_pred, dim=1)[torch.arange(n_images), digits_wellclassified[:n_images]]

# CLEAN
# Same shift with the opposite sign.
s_wellclassified_shifted = s_wellclassified[:n_images].clone()
for k in top_k_idxs:
    d = -1 if perfo_direction[k] >= 0 else 1
    s_wellclassified_shifted[:, k] += d * strength * style_std_vec[k]

imgs_clean = postprocess_images(generate_img_from_s(s_wellclassified_shifted))
digits_pred = classifier_digits(imgs_clean)
class_logit = digits_pred[torch.arange(n_images), digits_wellclassified[:n_images]]
class_softmax_clean = F.softmax(digits_pred, dim=1)[torch.arange(n_images), digits_wellclassified[:n_images]]

# PLOT
# Scale the three image batches by 255, cast to uint8 and move them to the CPU.
imgs_orig = (imgs_orig * 255).to(torch.uint8).cpu()
imgs_corr = (imgs_corr * 255).to(torch.uint8).cpu()
imgs_clean = (imgs_clean * 255).to(torch.uint8).cpu()

# One image grid per row: original, corrupted and cleaned samples.
fig, axs = plt.subplots(3, 1, figsize=(10, 5))
panels = (
    (imgs_orig, 'well-classified samples'),
    (imgs_corr, f'after corruption (top_k={top_k}; strength={strength})'),
    (imgs_clean, f'after cleaning (top_k={top_k}; strength={strength})'),
)
for ax, (panel_imgs, panel_title) in zip(axs, panels):
    ax.imshow(vutils.make_grid(panel_imgs, pad_value=255).permute(1,2,0), vmin=0, vmax=255)
    ax.set_title(panel_title)
    ax.axis('off')
# plt.tight_layout()
plt.savefig(path_results / 'figures' / 'image_manipulation', bbox_inches='tight')
plt.show()

# Softmax scores of the true digit before / after corruption / after cleaning.
for class_scores in (class_softmax_orig, class_softmax_corr, class_softmax_clean):
    print(class_scores)



### images vs distance