In [1]:
%matplotlib inline
import argparse
import glob
import logging
import os
import pickle
import random
import re
import csv
from typing import Dict, List, Tuple
import numpy as np
from scipy import stats
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter, defaultdict
import yaml
from experiment import *
from models import *
from experiment import VAEXperiment

  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


In [2]:
def generate_vis(influence_dir, output_dir, model_config, num_of_test_ex, num_of_top_train_ex_for_test_ex):
    agg_influence_dict = pickle.load(open(os.path.join(influence_dir, 'agg_influence_dict.pkl'), "rb"))
    config = yaml.safe_load(open(model_config, 'r'))
    model = vae_models[config['model_params']['name']](**config['model_params'])
    experiment = VAEXperiment(model, config['exp_params'])
    train_dataloader = experiment.train_sequential_dataloader()
    test_dataloader = experiment.test_dataloader()[0]
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    # dump test
    for test_idx, test_batch in enumerate(test_dataloader):
        if test_idx >= num_of_test_ex:
            break
        assert len(test_batch[0]) == 1 # check whether only one image is passed in
        test_input, test_label = test_batch
        vutils.save_image(test_input, os.path.join(output_dir, f"test_{test_idx}.png"), normalize=False, nrow=1)
        vutils.save_image(test_input, os.path.join(output_dir, f"normed_test_{test_idx}.png"), normalize=True, nrow=1)
    # dump train
    for test_idx, test_batch in enumerate(test_dataloader):
        if test_idx >= num_of_test_ex:
            break
        assert len(test_batch[0]) == 1 # check whether only one image is passed in
        inf_sorted_idx = list(np.argsort(agg_influence_dict[test_idx][0]))
        inf_sorted_idx.reverse()
        train_pic_list = [None] * num_of_top_train_ex_for_test_ex
        top_isi_i_dict = {isi: _i for _i, isi in list(enumerate(inf_sorted_idx))[:num_of_top_train_ex_for_test_ex]}
        cnt_filled = 0
        for train_idx, train_batch in enumerate(train_dataloader):
            if cnt_filled >= num_of_top_train_ex_for_test_ex:
                break
            assert len(train_batch[0]) == 1 # check whether only one image is passed in
            if train_idx in top_isi_i_dict:
                train_pic_list[top_isi_i_dict[train_idx]] = train_batch[0]
    #             print(agg_influence_dict[test_idx][0][train_idx])
                cnt_filled += 1
        vutils.save_image(torch.cat(train_pic_list, 0), os.path.join(output_dir, f"influential_to_test_{test_idx}.png"), normalize=False, nrow=10)
        vutils.save_image(torch.cat(train_pic_list, 0), os.path.join(output_dir, f"normed_influential_to_{test_idx}.png"), normalize=True, nrow=10)

In [5]:
influence_dir = "vanilla_vae_dotprod_IF/"
output_dir = "analysis_vanilla_vae_dotprod/"
model_config = "configs/test_vae.yaml"
num_of_test_ex = 100
num_of_top_train_ex_for_test_ex = 100

In [6]:
%%time
generate_vis(influence_dir, output_dir, model_config, num_of_test_ex, num_of_top_train_ex_for_test_ex)

Files already downloaded and verified
Files already downloaded and verified
CPU times: user 3h 23min 48s, sys: 2min 24s, total: 3h 26min 13s
Wall time: 3h 25min 56s


In [None]:
agg_influence_dict = pickle.load(open(os.path.join(influence_dir, 'agg_influence_dict.pkl'), "rb"))
sns.distplot(agg_influence_dict[0][0])
plt.show()