In [1]:
import torch

In [2]:
from vae import test_nll_estimation, parse_arguments

In [3]:
from pathlib import Path

In [4]:
# Evaluation settings shared by every experiment in this notebook;
# get_approx_nlls forwards them as keyword arguments to test_nll_estimation.
# NOTE(review): word_dropout and freebits are identical for all experiments,
# including the ones whose checkpoints were trained with word dropout or
# free bits — presumably they only matter during training; confirm in vae.py.
args = {
    'data_path': '../Data/Dataset',
    # Run on the GPU when torch can see one, otherwise fall back to the CPU.
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'embedding_size': 300,
    'hidden_size': 256,
    'latent_size': 16,
    'num_layers': 1,
    'word_dropout': 1.0,
    'freebits': None,
    'model_save_path': 'models',
    'batch_size_valid': 64,
    'num_samples': 10,
}

In [11]:
def get_approx_nlls(saved_model_files):
    """Run the NLL estimation once for each saved model.

    Every entry of ``saved_model_files`` is passed to ``test_nll_estimation``
    as ``saved_model_file``, together with the notebook-level ``args``
    settings.

    Returns:
        list: One ``test_nll_estimation`` result per input entry, in input
        order (an empty list for empty input).
    """
    return [
        test_nll_estimation(saved_model_file=model_file, **args)
        for model_file in saved_model_files
    ]

In [6]:
def get_model_filenames(experiment_name='vanilla', num_runs=4, results_root='results_final'):
    """Return one saved-model path per training run of an experiment.

    For each run index ``i`` in ``0 .. num_runs - 1`` the directory
    ``<results_root>/results<i>/<experiment_name>/models/`` is listed and a
    single entry is selected from it.

    Args:
        experiment_name: Name of the experiment sub-directory.
        num_runs: Number of ``results<i>`` run directories to visit.
        results_root: Directory that contains the ``results<i>`` folders.

    Returns:
        list[str]: One path per run, ordered by run index.

    Raises:
        FileNotFoundError: If a models directory does not exist or is empty.
    """
    saved_model_files = []
    for i in range(num_runs):
        models_dir = Path(results_root) / f"results{i}" / experiment_name / "models"
        # iterdir() yields entries in arbitrary order, so sort them to make the
        # selection reproducible when a directory holds more than one file.
        candidates = sorted(models_dir.iterdir())
        if not candidates:
            # Fail with a message naming the directory instead of a bare
            # StopIteration from next() on an empty iterator.
            raise FileNotFoundError(f"No saved model found in {models_dir}")
        saved_model_files.append(str(candidates[0]))
    return saved_model_files

In [7]:
# NOTE(review): no visible cell reads this module-level name; the value that
# get_approx_nlls passes to test_nll_estimation is args['num_samples'].
num_samples = 10

# Vanilla NLL

In [8]:
# Look up one saved checkpoint per training run of the 'vanilla' experiment.
saved_model_files = get_model_filenames('vanilla')
# Bare expression so the notebook echoes the resolved paths as cell output.
saved_model_files

['results_final/results0/vanilla/models/sentence_vae_3500.pt',
 'results_final/results1/vanilla/models/sentence_vae_3500.pt',
 'results_final/results2/vanilla/models/sentence_vae_3500.pt',
 'results_final/results3/vanilla/models/sentence_vae_3500.pt']

In [None]:
# Resolve the checkpoints inside this cell rather than reading the shared
# `saved_model_files` variable, which every section reassigns: run out of
# order, the old form would silently score another experiment's models.
vanilla_nlls = get_approx_nlls(get_model_filenames('vanilla'))

In [13]:
# Show the per-run NLL estimates for the vanilla checkpoints.
print(vanilla_nlls)

[tensor(126.9911, device='cuda:0'), tensor(127.1313, device='cuda:0'), tensor(127.0736, device='cuda:0'), tensor(126.7324, device='cuda:0')]


# Word dropout

In [14]:
# Look up one saved checkpoint per training run of the 'word_dropout_066' experiment.
saved_model_files = get_model_filenames('word_dropout_066')
# Bare expression so the notebook echoes the resolved paths as cell output.
saved_model_files

['results_final/results0/word_dropout_066/models/sentence_vae_6000.pt',
 'results_final/results1/word_dropout_066/models/sentence_vae_4500.pt',
 'results_final/results2/word_dropout_066/models/sentence_vae_6000.pt',
 'results_final/results3/word_dropout_066/models/sentence_vae_6000.pt']

In [None]:
# Resolve the checkpoints inside this cell rather than reading the shared
# `saved_model_files` variable, which every section reassigns: run out of
# order, the old form would silently score another experiment's models.
wd_nlls = get_approx_nlls(get_model_filenames('word_dropout_066'))

In [16]:
# Show the per-run NLL estimates for the word-dropout checkpoints.
print(wd_nlls)

[tensor(126.3514, device='cuda:0'), tensor(127.2843, device='cuda:0'), tensor(126.4913, device='cuda:0'), tensor(126.8849, device='cuda:0')]


# Free Bits

In [17]:
# Look up one saved checkpoint per training run of the 'freebits_05' experiment.
saved_model_files = get_model_filenames('freebits_05')
# Bare expression so the notebook echoes the resolved paths as cell output.
saved_model_files

['results_final/results0/freebits_05/models/sentence_vae_FreeBits_0.5_3500.pt',
 'results_final/results1/freebits_05/models/sentence_vae_FreeBits_0.5_3500.pt',
 'results_final/results2/freebits_05/models/sentence_vae_FreeBits_0.5_3500.pt',
 'results_final/results3/freebits_05/models/sentence_vae_FreeBits_0.5_3500.pt']

In [None]:
# Resolve the checkpoints inside this cell rather than reading the shared
# `saved_model_files` variable, which every section reassigns: run out of
# order, the old form would silently score another experiment's models.
fb_nlls = get_approx_nlls(get_model_filenames('freebits_05'))

In [19]:
# Show the per-run NLL estimates for the free-bits checkpoints.
print(fb_nlls)

[tensor(121.0451, device='cuda:0'), tensor(120.7842, device='cuda:0'), tensor(120.9724, device='cuda:0'), tensor(121.1376, device='cuda:0')]


# MDR

In [20]:
# Look up one saved checkpoint per training run of the 'mdr10' experiment.
saved_model_files = get_model_filenames('mdr10')
# Bare expression so the notebook echoes the resolved paths as cell output.
saved_model_files

['results_final/results0/mdr10/models/sentence_vae_MDR_10.0_3500.pt',
 'results_final/results1/mdr10/models/sentence_vae_MDR_10.0_3500.pt',
 'results_final/results2/mdr10/models/sentence_vae_MDR_10.0_3500.pt',
 'results_final/results3/mdr10/models/sentence_vae_MDR_10.0_3500.pt']

In [None]:
# Resolve the checkpoints inside this cell rather than reading the shared
# `saved_model_files` variable, which every section reassigns: run out of
# order, the old form would silently score another experiment's models.
mdr_nlls = get_approx_nlls(get_model_filenames('mdr10'))

In [22]:
# Show the per-run NLL estimates for the MDR checkpoints.
print(mdr_nlls)

[tensor(119.8162, device='cuda:0'), tensor(119.8743, device='cuda:0'), tensor(120.0954, device='cuda:0'), tensor(119.5580, device='cuda:0')]


# Word dropout & Free Bits

In [23]:
# Look up one saved checkpoint per training run of the
# 'word_dropout_066_freebits_05' experiment.
saved_model_files = get_model_filenames('word_dropout_066_freebits_05')
# Bare expression so the notebook echoes the resolved paths as cell output.
saved_model_files

['results_final/results0/word_dropout_066_freebits_05/models/sentence_vae_FreeBits_0.5_4500.pt',
 'results_final/results1/word_dropout_066_freebits_05/models/sentence_vae_FreeBits_0.5_6000.pt',
 'results_final/results2/word_dropout_066_freebits_05/models/sentence_vae_FreeBits_0.5_6000.pt',
 'results_final/results3/word_dropout_066_freebits_05/models/sentence_vae_FreeBits_0.5_7000.pt']

In [None]:
# Resolve the checkpoints inside this cell rather than reading the shared
# `saved_model_files` variable, which every section reassigns: run out of
# order, the old form would silently score another experiment's models.
wd_fb_nlls = get_approx_nlls(get_model_filenames('word_dropout_066_freebits_05'))

In [25]:
# Show the per-run NLL estimates for the word-dropout + free-bits checkpoints.
print(wd_fb_nlls)

[tensor(121.7690, device='cuda:0'), tensor(121.3110, device='cuda:0'), tensor(121.2281, device='cuda:0'), tensor(121.8895, device='cuda:0')]


# Word dropout & MDR

In [26]:
# Look up one saved checkpoint per training run of the
# 'word_dropout_066_mdr_10' experiment.
saved_model_files = get_model_filenames('word_dropout_066_mdr_10')
# Bare expression so the notebook echoes the resolved paths as cell output.
saved_model_files

['results_final/results0/word_dropout_066_mdr_10/models/sentence_vae_MDR_10.0_6000.pt',
 'results_final/results1/word_dropout_066_mdr_10/models/sentence_vae_MDR_10.0_6000.pt',
 'results_final/results2/word_dropout_066_mdr_10/models/sentence_vae_MDR_10.0_6000.pt',
 'results_final/results3/word_dropout_066_mdr_10/models/sentence_vae_MDR_10.0_4500.pt']

In [None]:
# Resolve the checkpoints inside this cell rather than reading the shared
# `saved_model_files` variable, which every section reassigns: run out of
# order, the old form would silently score another experiment's models.
wd_mdr_nlls = get_approx_nlls(get_model_filenames('word_dropout_066_mdr_10'))

In [13]:
# Per-run NLL estimates for the vanilla checkpoints, printed next to the
# other experiments for side-by-side comparison.
print(vanilla_nlls)

[tensor(126.9911, device='cuda:0'), tensor(127.1313, device='cuda:0'), tensor(127.0736, device='cuda:0'), tensor(126.7324, device='cuda:0')]


In [16]:
# Per-run NLL estimates for the word-dropout checkpoints.
print(wd_nlls)

[tensor(126.3514, device='cuda:0'), tensor(127.2843, device='cuda:0'), tensor(126.4913, device='cuda:0'), tensor(126.8849, device='cuda:0')]


In [19]:
# Per-run NLL estimates for the free-bits checkpoints.
print(fb_nlls)

[tensor(121.0451, device='cuda:0'), tensor(120.7842, device='cuda:0'), tensor(120.9724, device='cuda:0'), tensor(121.1376, device='cuda:0')]


In [22]:
# Per-run NLL estimates for the MDR checkpoints.
print(mdr_nlls)

[tensor(119.8162, device='cuda:0'), tensor(119.8743, device='cuda:0'), tensor(120.0954, device='cuda:0'), tensor(119.5580, device='cuda:0')]


In [25]:
# Per-run NLL estimates for the word-dropout + free-bits checkpoints.
print(wd_fb_nlls)

[tensor(121.7690, device='cuda:0'), tensor(121.3110, device='cuda:0'), tensor(121.2281, device='cuda:0'), tensor(121.8895, device='cuda:0')]


In [28]:
# Show the per-run NLL estimates for the word-dropout + MDR checkpoints.
print(wd_mdr_nlls)

[tensor(120.9221, device='cuda:0'), tensor(121.1686, device='cuda:0'), tensor(120.5179, device='cuda:0'), tensor(121.6265, device='cuda:0')]


# Results

In [36]:
# Gather the per-run NLL estimates of every experiment into one tensor per
# experiment, keyed by the label used in the summary printout.
all_results = {
    label: torch.tensor(run_nlls)
    for label, run_nlls in [
        ("Vanilla", vanilla_nlls),
        ("Word dropout", wd_nlls),
        ("Free Bits", fb_nlls),
        ("MDR", mdr_nlls),
        ("Word dropout & Free Bits", wd_fb_nlls),
        ("Word dropout & MDR", wd_mdr_nlls),
    ]
}

In [48]:
# Report the mean and standard deviation of the NLL estimates over the runs
# of each experiment.
for name, res in all_results.items():
    # torch.std_mean returns the pair in (std, mean) order; with its default
    # settings the standard deviation is the Bessel-corrected sample estimate.
    std, mean = torch.std_mean(res)
    # Right-align the experiment label in a 30-character column.
    print(f"{name: >30} approximated NLL: \t {mean} \t +- {std}")

                       Vanilla approximated NLL: 	 126.98207092285156 	 +- 0.1761312633752823
                  Word dropout approximated NLL: 	 126.7529296875 	 +- 0.42009851336479187
                     Free Bits approximated NLL: 	 120.98483276367188 	 +- 0.14984405040740967
                           MDR approximated NLL: 	 119.83599090576172 	 +- 0.22093771398067474
      Word dropout & Free Bits approximated NLL: 	 121.54940795898438 	 +- 0.328605979681015
            Word dropout & MDR approximated NLL: 	 121.05879211425781 	 +- 0.46391791105270386
