In [181]:
import torch
from torch import nn
from src.data import get_data
from src.models import get_encoder, get_decoder
from src.utils import softclip
from src.hessian import laplace
import yaml
from copy import deepcopy
from torch.nn.utils import parameters_to_vector, vector_to_parameters
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import GridSearchCV
import matplotlib.pyplot as plt
import numpy as np
import os
import warnings
from tqdm import tqdm

def get_model(encoder, decoder):
    
    net = deepcopy(encoder.encoder._modules)
    decoder = decoder.decoder._modules
    max_ = max([int(i) for i in net.keys()])
    for i in decoder.keys():
        net.update({f"{max_+int(i) + 1}": decoder[i]})

    return nn.Sequential(net)

# Load checkpoints for the different models

In [182]:
path_ae = "../weights/celeba/[no_conv_False]_[use_var_decoder_False]_"
with open(f"{path_ae}/config.yaml") as file:
    config = yaml.full_load(file)
ae_encoder = get_encoder(config, latent_size=config['latent_size'])
ae_encoder.load_state_dict(torch.load(f"{path_ae}/encoder.pth"))
ae_encoder.eval()

path_mcae = "../weights/celeba/mcdropout_ae/[no_conv_False]_[dropout_rate_0.2]_[use_var_decoder_False]_"
with open(f"{path_mcae}/config.yaml") as file:
    config = yaml.full_load(file)
mcae_encoder = get_encoder(config, latent_size=config['latent_size'], dropout=config["dropout_rate"])
mcae_encoder.load_state_dict(torch.load(f"{path_mcae}/encoder.pth"))

path_vae = "../weights/celeba/vae_[use_var_dec=False]/[no_conv_False]_[use_var_decoder_False]_"
with open(f"{path_vae}/config.yaml") as file:
    config = yaml.full_load(file)
vae_encoder_mu = get_encoder(config, latent_size=config['latent_size'])
vae_encoder_mu.load_state_dict(torch.load(f"{path_vae}/mu_encoder.pth"))
vae_encoder_var = get_encoder(config, latent_size=config['latent_size'])
vae_encoder_var.load_state_dict(torch.load(f"{path_vae}/var_encoder.pth"))

path_lae= "../weights/celeba/[backend_layer]_[approximation_mix]_[no_conv_False]_[train_samples_1]_/[backend_layer]_[approximation_mix]_[no_conv_False]_[train_samples_1]_"
with open(f"{path_lae}/config.yaml") as file:
    config = yaml.full_load(file)
lae_encoder = get_encoder(config, latent_size=config['latent_size'])
lae_decoder = get_decoder(config, latent_size=config['latent_size'])
lae_net = get_model(lae_encoder, lae_decoder).eval()
lae_net.load_state_dict(torch.load(f"{path_lae}/net.pth"))
mu_q = parameters_to_vector(lae_net.parameters()).cuda()
l = laplace.DiagLaplace()
h = torch.load(f"{path_lae}/hessian.pth").cpu()
sigma_q = l.posterior_scale(h).cuda()

# Construct function for loading subset of validation data

In [183]:
_, val_dataloader = get_data("celeba")
images, labels = [ ], [ ]

for batch_idx, batch in enumerate(val_dataloader):
    if batch_idx > 10:
        break
    x, y = batch
    images.append(x)
    labels.append(y)
images = torch.cat(images, 0)
labels = torch.cat(labels, 0)

print(images.shape)
print(labels.shape)

def get_dataset(n_select):
    idx = torch.randperm(images.shape[0])
    return images[idx[:n_select]], labels[idx[:n_select]], images[idx[n_select:]], labels[idx[n_select:]]

selected_images, selected_labels, eval_set_images, eval_set_labels = get_dataset(5)
print(selected_images.shape)
print(selected_labels.shape)
print(eval_set_images.shape)
print(eval_set_labels.shape)

torch.Size([352, 3, 64, 64])
torch.Size([352, 40])
torch.Size([5, 3, 64, 64])
torch.Size([5, 40])
torch.Size([347, 3, 64, 64])
torch.Size([347, 40])


# Construct functions for embedding data into representation space for the different models

In [184]:
def get_encoding_ae(data, n=1):
    return ae_encoder(data)

print('standard ae')
print(get_encoding_ae(images[:1], n=2).shape)

def get_encoding_mcae(data, n=1):
    return torch.cat([mcae_encoder(data) for _ in range(n)], 0)

print('monte-carlo ae')
print(get_encoding_mcae(images[:1], n=2).shape)

def get_encoding_vae(data, n=1):
    mu = vae_encoder_mu(data)
    sigma = torch.exp(softclip(vae_encoder_var(data), min=-3))
    return torch.cat([mu + torch.randn_like(sigma) * sigma for _ in range(n)])

print('vae')
print(get_encoding_vae(images[:1], n=2).shape)

def sample():
    return (mu_q + torch.randn_like(sigma_q) * sigma_q).cpu()
    
def get_encoding_lae(data, n=1):
    embeddings = [ ]
    def fw_hook_get_latent(module, input, output):
        embeddings.append(output.detach().cpu())
    hook = lae_net[20].register_forward_hook(fw_hook_get_latent)

    if n == 1:
        _ = lae_net(data)
        hook.remove()
        return embeddings[0]

    for i in range(n):
        vector_to_parameters(sample(), lae_net.parameters())
        _ = lae_net(data)
    hook.remove()
    return torch.cat(embeddings, 0)

print('lae')
print(get_encoding_lae(images[:1], n=2).shape)

standard ae
torch.Size([1, 128])
monte-carlo ae
torch.Size([2, 128])
vae
torch.Size([2, 128])
lae
torch.Size([2, 128])


# Run experiment

In [185]:
with torch.inference_mode():
    scale_factor = 100
    names = ['ae', 'mcae', 'vae', 'lae']
    label_size = [10]
    reps = 5

    res = np.zeros((reps, len(label_size), len(names), 40))
    for r in range(reps):
        for j, (name, encoder) in enumerate(zip(names, [get_encoding_ae, get_encoding_mcae, get_encoding_vae, get_encoding_lae])):
            for i, n_select in enumerate(label_size):
                print(r,i,j)
                selected_images, selected_labels, eval_set_images, eval_set_labels = get_dataset(n_select)
                eval_data = torch.cat([encoder(d.unsqueeze(0)) for d in eval_set_images], 0)
                train_data = torch.cat([encoder(d.unsqueeze(0), scale_factor) for d in selected_images], 0)
                train_labels = selected_labels.repeat_interleave(scale_factor, 0) if name != 'ae' else selected_labels
                with warnings.catch_warnings():
                    warnings.filterwarnings("ignore")
                    classifier = GridSearchCV(
                        KNeighborsClassifier(),
                        {'n_neighbors': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]},
                        cv=2,
                        refit=True,
                    )
                classifier.fit(train_data.detach().numpy(), train_labels.detach().numpy())
                preds = classifier.predict(eval_data.detach().numpy())
                acc = (preds == eval_set_labels.numpy()).mean(axis=0)
                res[r,i,j,:] = acc

0 0 0


Traceback (most recent call last):
  File "c:\Users\nsde\Anaconda3\envs\laplace\lib\site-packages\sklearn\model_selection\_validation.py", line 761, in _score
    scores = scorer(estimator, X_test, y_test)
  File "c:\Users\nsde\Anaconda3\envs\laplace\lib\site-packages\sklearn\metrics\_scorer.py", line 418, in _passthrough_scorer
    return estimator.score(*args, **kwargs)
  File "c:\Users\nsde\Anaconda3\envs\laplace\lib\site-packages\sklearn\base.py", line 651, in score
    return accuracy_score(y, self.predict(X), sample_weight=sample_weight)
  File "c:\Users\nsde\Anaconda3\envs\laplace\lib\site-packages\sklearn\neighbors\_classification.py", line 214, in predict
    neigh_dist, neigh_ind = self.kneighbors(X)
  File "c:\Users\nsde\Anaconda3\envs\laplace\lib\site-packages\sklearn\neighbors\_base.py", line 727, in kneighbors
    raise ValueError(
ValueError: Expected n_neighbors <= n_samples,  but n_samples = 5, n_neighbors = 6

Traceback (most recent call last):
  File "c:\Users\nsde\A

0 0 1
0 0 2
0 0 3
1 0 0


Traceback (most recent call last):
  File "c:\Users\nsde\Anaconda3\envs\laplace\lib\site-packages\sklearn\model_selection\_validation.py", line 761, in _score
    scores = scorer(estimator, X_test, y_test)
  File "c:\Users\nsde\Anaconda3\envs\laplace\lib\site-packages\sklearn\metrics\_scorer.py", line 418, in _passthrough_scorer
    return estimator.score(*args, **kwargs)
  File "c:\Users\nsde\Anaconda3\envs\laplace\lib\site-packages\sklearn\base.py", line 651, in score
    return accuracy_score(y, self.predict(X), sample_weight=sample_weight)
  File "c:\Users\nsde\Anaconda3\envs\laplace\lib\site-packages\sklearn\neighbors\_classification.py", line 214, in predict
    neigh_dist, neigh_ind = self.kneighbors(X)
  File "c:\Users\nsde\Anaconda3\envs\laplace\lib\site-packages\sklearn\neighbors\_base.py", line 727, in kneighbors
    raise ValueError(
ValueError: Expected n_neighbors <= n_samples,  but n_samples = 5, n_neighbors = 6

Traceback (most recent call last):
  File "c:\Users\nsde\A

1 0 1
1 0 2
1 0 3
2 0 0


Traceback (most recent call last):
  File "c:\Users\nsde\Anaconda3\envs\laplace\lib\site-packages\sklearn\model_selection\_validation.py", line 761, in _score
    scores = scorer(estimator, X_test, y_test)
  File "c:\Users\nsde\Anaconda3\envs\laplace\lib\site-packages\sklearn\metrics\_scorer.py", line 418, in _passthrough_scorer
    return estimator.score(*args, **kwargs)
  File "c:\Users\nsde\Anaconda3\envs\laplace\lib\site-packages\sklearn\base.py", line 651, in score
    return accuracy_score(y, self.predict(X), sample_weight=sample_weight)
  File "c:\Users\nsde\Anaconda3\envs\laplace\lib\site-packages\sklearn\neighbors\_classification.py", line 214, in predict
    neigh_dist, neigh_ind = self.kneighbors(X)
  File "c:\Users\nsde\Anaconda3\envs\laplace\lib\site-packages\sklearn\neighbors\_base.py", line 727, in kneighbors
    raise ValueError(
ValueError: Expected n_neighbors <= n_samples,  but n_samples = 5, n_neighbors = 6

Traceback (most recent call last):
  File "c:\Users\nsde\A

2 0 1
2 0 2
2 0 3
3 0 0


Traceback (most recent call last):
  File "c:\Users\nsde\Anaconda3\envs\laplace\lib\site-packages\sklearn\model_selection\_validation.py", line 761, in _score
    scores = scorer(estimator, X_test, y_test)
  File "c:\Users\nsde\Anaconda3\envs\laplace\lib\site-packages\sklearn\metrics\_scorer.py", line 418, in _passthrough_scorer
    return estimator.score(*args, **kwargs)
  File "c:\Users\nsde\Anaconda3\envs\laplace\lib\site-packages\sklearn\base.py", line 651, in score
    return accuracy_score(y, self.predict(X), sample_weight=sample_weight)
  File "c:\Users\nsde\Anaconda3\envs\laplace\lib\site-packages\sklearn\neighbors\_classification.py", line 214, in predict
    neigh_dist, neigh_ind = self.kneighbors(X)
  File "c:\Users\nsde\Anaconda3\envs\laplace\lib\site-packages\sklearn\neighbors\_base.py", line 727, in kneighbors
    raise ValueError(
ValueError: Expected n_neighbors <= n_samples,  but n_samples = 5, n_neighbors = 6

Traceback (most recent call last):
  File "c:\Users\nsde\A

3 0 1
3 0 2
3 0 3
4 0 0


Traceback (most recent call last):
  File "c:\Users\nsde\Anaconda3\envs\laplace\lib\site-packages\sklearn\model_selection\_validation.py", line 761, in _score
    scores = scorer(estimator, X_test, y_test)
  File "c:\Users\nsde\Anaconda3\envs\laplace\lib\site-packages\sklearn\metrics\_scorer.py", line 418, in _passthrough_scorer
    return estimator.score(*args, **kwargs)
  File "c:\Users\nsde\Anaconda3\envs\laplace\lib\site-packages\sklearn\base.py", line 651, in score
    return accuracy_score(y, self.predict(X), sample_weight=sample_weight)
  File "c:\Users\nsde\Anaconda3\envs\laplace\lib\site-packages\sklearn\neighbors\_classification.py", line 214, in predict
    neigh_dist, neigh_ind = self.kneighbors(X)
  File "c:\Users\nsde\Anaconda3\envs\laplace\lib\site-packages\sklearn\neighbors\_base.py", line 727, in kneighbors
    raise ValueError(
ValueError: Expected n_neighbors <= n_samples,  but n_samples = 5, n_neighbors = 6

Traceback (most recent call last):
  File "c:\Users\nsde\A

4 0 1
4 0 2
4 0 3


# Extract overall accuracy

In [191]:
np.save('results_celeba_all.npy', res)
print(res.shape)
print(res.mean(0).mean(-1))

(5, 1, 4, 40)
(5, 4, 40)
(4, 40)
[0.75492203 0.74271442 0.73564815 0.74064327]
[0.7379386  0.72457602 0.73108187 0.74064327]


# Per label accuracy

In [195]:
import pandas as pd
csv_file = pd.read_csv("../data/celeba/list_attr_celeba.txt", index_col=0)
list(csv_file.keys())[1:]
df=pd.DataFrame(data=res, columns=list(csv_file.keys())[1:])
print(df.T)

                            0         1         2         3
5_o_Clock_Shadow     0.872320  0.754386  0.860624  0.795906
Arched_Eyebrows      0.506823  0.521442  0.549708  0.605848
Attractive           0.515595  0.500975  0.499025  0.536842
Bags_Under_Eyes      0.797271  0.699805  0.752437  0.682456
Bald                 0.985380  0.985380  0.985380  0.985380
Bangs                0.815789  0.827485  0.800195  0.801754
Big_Lips             0.607212  0.449318  0.524366  0.550292
Big_Nose             0.676413  0.769981  0.701754  0.633333
Black_Hair           0.680312  0.650097  0.581871  0.605263
Blond_Hair           0.804094  0.857700  0.793372  0.728070
Blurry               0.949318  0.950292  0.935673  0.943860
Brown_Hair           0.616959  0.636452  0.713450  0.577778
Bushy_Eyebrows       0.819688  0.751462  0.804094  0.783041
Chubby               0.883041  0.923977  0.846004  0.912865
Double_Chin          0.944444  0.897661  0.886940  0.873099
Eyeglasses           0.885965  0.899610 