In [2]:
import sys
import re
import pickle
from pathlib import Path
from collections import Counter
from itertools import combinations
from multiprocessing.sharedctypes import RawArray

import torch
import json
import yaml
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
from scipy.special import softmax, log_softmax
from scipy.spatial.distance import jensenshannon, cdist

sys.path.append("../soup_nuts/models/dvae/")

from dvae import data_iterator, CollapsedMultinomial, DVAE
from utils import load_sparse, compute_to

In [3]:
def load_json(path):
    with open(path) as infile:
        return json.load(infile)

def load_yaml(path):
    with open(path) as infile:
        return yaml.load(infile, Loader=yaml.FullLoader)

def load_text(path):
    with open(path) as infile:
        return [text.strip().split(" ") for text in infile]

def save_json(obj, path):
    with open(path, 'w') as outfile:
        return json.dump(obj, outfile, indent=2)

In [4]:
def topic_est_to_words(topic_words, inv_vocab, n=10):
    return [inv_vocab[idx] for idx in (-topic_words).argsort()[:n]]

### Estimate loading functions

In [5]:
_CUDA_AVAILABLE = torch.cuda.is_available()
_data_cache = {}

def load_mallet_estimates(fpath):
    """
    Load the doc-topic and topic-word estimates from our mallet output folder
    """
    topic_word = np.load(fpath / "beta.npy")
    # Load the standard mallet document-topic estimate as a numpy matrix
    with open(fpath / "doctopics.txt") as infile:
        doc_topic = np.array([
            [float(x) for x in line.strip().split("\t")[2:]]
            for line in infile
        ])
    return topic_word, doc_topic, None # TODO: is loss available in mallet?

def load_dvae_estimates(fpath): 
    """
    Loads the dvae model and gets the topic word distribution, then instantiates
    the encoder portion and does a forward pass to get the 
    """
    # get the topic word
    device = torch.device("cuda") if _CUDA_AVAILABLE else torch.device("cpu")

    state_dict = torch.load(fpath / "model.pt", map_location=device)
    beta = state_dict["params"]["decoder$$$eta_layer.weight"]
    topic_word = torch.transpose(beta, 0, 1).detach().numpy()

    # do a forward pass to get the document topics
    # first instantiate the model and load in the params
    config = load_yaml(fpath / "config.yml")
    
    dvae = DVAE(
        vocab_size=topic_word.shape[1],
        num_topics=config["num_topics"],
        alpha_prior=config["alpha_prior"],
        embeddings_dim=config["encoder_embeddings_dim"],
        hidden_dim=config["encoder_hidden_dim"],
        dropout=config["dropout"],
        cuda=_CUDA_AVAILABLE,
    )
    dvae_dict = {
        k.replace("$$$", "."): v
        for k, v in state_dict['params'].items()
    }
    dvae.load_state_dict(dvae_dict, strict=False)
    dvae.eval()
    turn_off_bn = 1 * (config["epochs_to_anneal_bn"] > 0) # 0 means use BN, > 0 means no BN

    # then load the data for the forward pass
    data_fpath = Path(config["input_dir"], config["train_path"])
    if data_fpath not in _data_cache:
        data = load_sparse(data_fpath).astype(np.float32)
        _data_cache[data_fpath] = data
    else:
        data = _data_cache[data_fpath]
    
    batch_size = config["batch_size"]
    epochs = config["num_epochs"]
    n = data.shape[0]
    train_batches = n // batch_size + 1

    # do the forward pass and collect outputs in an array
    doc_topic = np.zeros((n, config["num_topics"]), dtype=np.float32)
    losses = np.zeros(n, dtype=np.float32)
    for i, x_batch in enumerate(data_iterator(data, batch_size, train_batches)):
        x_batch = x_batch.to(device)
        doc_topic_batch = dvae.encoder(x_batch)
        doc_topic_batch = doc_topic_batch / doc_topic_batch.sum(1, keepdims=True)
        x_recon = dvae.decoder(doc_topic_batch, bn_annealing_factor=turn_off_bn)
        loss_batch = -CollapsedMultinomial(1, probs=x_recon).log_prob(x_batch)

        doc_topic[i * batch_size:(i + 1) * batch_size] = doc_topic_batch.detach().cpu().numpy().astype(np.float32)
        losses[i * batch_size:(i + 1) * batch_size] = loss_batch.detach().cpu().numpy().astype(np.float32)
    return topic_word, doc_topic, losses


def load_etm_estimates(fpath):
    """
    Load the ETM estimates from a model
    """
    pass


def load_estimates(fpath, model_type):
    if model_type == "dvae":
        return load_dvae_estimates(fpath)
    if model_type == "mallet":
        return load_mallet_estimates(fpath)

### Stability metrics

In [6]:
def estimate_document_word_stability(doc_topics, topic_words, top_n=15):
    """
    Given a collection of estimates of document-topic distributions, determine
    how stable the topic assignments are by comparing the set of top words

    TODO: just uses top topic for now
    """
    runs = len(doc_topics)
    n = doc_topics[0].shape[0]
    top_words_over_runs = np.zeros((n, runs * top_n))
    probs = np.zeros((n, runs))
    
    for i, (doc_topic, topic_word) in enumerate(tqdm(zip(doc_topics, topic_words), total=runs)):
        top_words = (-topic_word).argsort()[:, :top_n]
        top_words_over_runs[:, i*top_n:(i+1)*top_n] = top_words[doc_topic.argmax(1)]
             #* (doc_topic.max(1, keepdims=True) >= min_prob)
        #)
        probs[:, i] = doc_topic.max(1)

    # https://stackoverflow.com/questions/48473056/number-of-unique-elements-per-row-in-a-numpy-array
    nunique = np.count_nonzero(np.diff(np.sort(top_words_over_runs)), axis=1) + 1
    punique = (nunique - n_topic_words) / (n_topic_words * (runs - 1))
    return nunique, punique, probs

def estimate_topic_word_stability(topic_words, top_n=15):
    """
    Given a collection of estimates of topic-word distributions, determine
    how stable the topics are by comparing the set of top words
    """
    runs = len(topic_words)
    top_words_over_runs = np.zeros((n, runs * top_n))
    probs = np.zeros((n, runs))
    pass # TODO

In [7]:
def estimate_topic_stability(topic_words, iters=1, sample_n=1, softmax_ests=False, seed=None):
    """
    Estimate the stability of topics by running pairwise comparisons
    of all runs: take the js-divergence of the topic pairs, then match each topic
    with its closest pair, per run. Repeat `iters` times to get a "pseudo-best" matching

    To speed up computation, can set `sample_pct` to use only a subset of possible combinations
    """
    np.random.seed(seed)
    num_topics = topic_words[0].shape[0]
    runs = len(topic_words)
    combins = (runs * (runs - 1)) // 2
    sample_pct = sample_n if sample_n <= 1 else sample_n / combins
    to_keep = [np.random.rand() <= sample_pct for _ in range(combins)]
    kept = sum(to_keep)
    min_dists = np.zeros((kept, num_topics))
    c = 0
    if softmax_ests:
        topic_words = [softmax(t, axis=1) for t in topic_words]
    pbar = tqdm(range(kept))
    for keep, (t_a, t_b) in zip(to_keep, (combinations(topic_words, 2))):
        if not keep:
            continue
        dists = cdist(t_a, t_b, metric='jensenshannon')

        # algorithm is greedy: we randomize every iteration to get a pseudo-best estimate
        # of the distances
        for i in range(iters):
            dists_ = dists[np.random.permutation(num_topics), :].copy()

            min_dists_i = np.zeros(num_topics)
            for k in range(num_topics):
                min_idx = dists_[k].argmin() # match this topic to its lowest pair
                min_dists_i[k] = dists_[k, min_idx] # record this minimum distance
                dists_[k+1:, min_idx] = 1. # remove this index from consideration for later topics
            if i == 0 or min_dists_i.mean() < min_dists[c].mean():
                min_dists[c] = np.sort(min_dists_i)
        c += 1
        pbar.update()
    pbar.close()
    return min_dists

def estimate_effective_topics(topic_words, iters):
    pass

### collect the runs

In [8]:
# roughly 7 GB RAM for k=100
def get_estimates_over_runs(run_paths, overlap_words, exclude_dups=False):
    doc_topics, topic_words, duplicates = [], [], [] # TODO: change to 3d tensors
    for i, (p, model_type) in enumerate(tqdm(run_paths)):
        t, d, l = load_estimates(p, model_type=model_type)
        
        # located duplicated topics
        sorted_t = np.sort((-t).argsort(axis=1)[:, :overlap_words], axis=1)
        counted_topics = Counter([tuple(t_) for t_ in sorted_t])
        if exclude_dups and max(counted_topics.values()) > 1:
            continue
        doc_topics.append(d)
        topic_words.append(t)
        duplicates.append(sum(c > 1 for c in counted_topics.values()))

    return doc_topics, topic_words, duplicates

In [9]:
#run_dir = "../runs/outputs/url_partisanship_data"
run_dir = "../runs/outputs/full-mindf_power_law-maxdf_0.9"
dataset = 'wikitext' #'url_partisan'

mallet_paths = [
    (p.parent, "mallet")
    for p in Path(run_dir).glob("**/mallet-with-beta/**/doctopics.txt")
    if dataset in str(p) and "_run-logs" not in str(p)
]
dvae_paths = [
    (p.parent, "dvae") for p in Path(run_dir).glob("**/dvae/**/model.pt")
    if dataset in str(p) and "_run-logs" not in str(p)
]

# should be independent of the model
config = load_yaml(dvae_paths[0][0] / "config.yml")
data = load_sparse(Path(config["input_dir"], "train.dtm.npz"))
vocab = load_json(Path(config["input_dir"], "vocab.json"))
inv_vocab = {v: k for k, v in vocab.items()}

In [10]:
num_topics = sorted(set(int(re.search("k-([0-9]+)", str(p)).group(1)) for p in dvae_paths))
num_topics

[50]

In [11]:
n_topic_words = 15
overlap_words = 5

In [12]:
estimates_dvae, estimates_mallet = {}, {}
for k in num_topics:
    print(f"On k={k}")

    # dvae_paths_k = [p for p in dvae_paths if f'k-{k}/' in str(p[0])]
    # runs = len(dvae_paths_k)
    # if dvae_paths_k:
    #     doc_topics, topic_words, duplicates = get_estimates_over_runs(dvae_paths_k, overlap_words, exclude_dups=False)
    #     estimates_dvae[k] = {"doc_topics": doc_topics, "topic_words": topic_words, "duplicates": duplicates}
    
    mallet_paths_k = [p for p in mallet_paths if f'k-{k}/' in str(p[0])]
    if mallet_paths_k:
        doc_topics, topic_words, duplicates = get_estimates_over_runs(mallet_paths_k, overlap_words, exclude_dups=False)
        estimates_mallet[k] = {"doc_topics": doc_topics, "topic_words": topic_words, "duplicates": duplicates}

# with open(f"dvae-{dataset}-estimates.pkl", "wb") as outfile:
#     pickle.dump(estimates_dvae, outfile)

On k=50


  0%|          | 0/80 [00:00<?, ?it/s]

In [13]:
with open(f"dvae-{dataset}-estimates.pkl", "rb") as infile:
    estimates_dvae = pickle.load(infile)

estimates_dvae[50].pop("losses")

array([[11476.642 , 11715.563 , 11440.037 , ..., 11428.749 , 11875.249 ,
        11409.578 ],
       [11940.236 , 12202.471 , 12010.1   , ..., 11977.722 , 12136.514 ,
        12092.146 ],
       [ 9013.488 ,  9227.204 ,  9028.507 , ...,  9075.722 ,  9206.195 ,
         8975.454 ],
       ...,
       [ 4204.35  ,  4298.3613,  4001.2158, ...,  4016.9077,  4247.805 ,
         3961.6814],
       [ 2128.3608,  2228.5117,  2139.183 , ...,  2142.7175,  2265.1128,
         2120.1372],
       [ 4695.365 ,  4739.498 ,  4699.868 , ...,  4702.74  ,  4821.3086,
         4667.9683]], dtype=float32)

In [14]:
doc_word_data = {"mallet": {}, "dvae": {}}
for model, estimates in [("mallet", estimates_mallet), ("dvae", estimates_dvae)]:
    model_data = {}
    for k, est in estimates.items():
        _, topic_words, duplicates = (
            est["doc_topics"], est["topic_words"], est["duplicates"]
        )

        dt, tw = doc_topics, topic_words
        
        nunique, punique, probs = estimate_document_word_stability(dt, tw, top_n=n_topic_words)
        pct_assigned = (probs < 0.5).mean(1)
        nsummary = pd.Series(nunique).describe()
        psummary = pd.Series(punique).describe()
        asummary = pd.Series(pct_assigned).describe()
        
        # model_data.update({f"pct_unique_{k}": psummary, f"nuniuqe_{k}": nsummary, f"assigned_{k}": asummary})
        model_data.update({f"pct_unique_{k}": punique, f"nuniuqe_{k}": nunique, f"assigned_{k}": pct_assigned})
    doc_word_data[model] = model_data

doc_word_data = pd.concat({k: pd.DataFrame(v) for k, v in doc_word_data.items()})

  0%|          | 0/80 [00:00<?, ?it/s]

  0%|          | 0/80 [00:00<?, ?it/s]

In [36]:
doc_word_data.to_csv(f"document_word_stability-{dataset}-20220428.csv")
doc_word_data

Unnamed: 0,Unnamed: 1,pct_unique_50,nuniuqe_50,assigned_50
mallet,0,0.012658,30,0.2375
mallet,1,0.070042,98,0.9875
mallet,2,0.042194,65,0.9500
mallet,3,0.025316,45,0.2875
mallet,4,0.022785,42,0.0000
...,...,...,...,...
dvae,28467,0.716456,864,0.0125
dvae,28468,0.735021,886,0.0000
dvae,28469,0.731646,882,0.0000
dvae,28470,0.735021,886,0.0000


In [25]:
topic_word_data = {"mallet": {}, "dvae": {}}
for model, estimates in [("mallet", estimates_mallet), ("dvae", estimates_dvae)]:
    for k, est in estimates.items():
        _, topic_words, duplicates = (
            est["doc_topics"], est["topic_words"], est["duplicates"]
        )

        tw = topic_words
        # TODO: dvae may take longer since it should be float64?
        dists = estimate_topic_stability(tw, softmax_ests=model=='dvae', iters=2, sample_n=250, seed=42)
        summary = pd.Series(dists.flatten()).describe()
        topic_word_data[model][f"min_dists_{k}"] = dists.mean(1) #summary

topic_word_data = pd.concat({k: pd.DataFrame(v) for k, v in topic_word_data.items()})

  0%|          | 0/274 [00:00<?, ?it/s]

  0%|          | 0/273 [00:00<?, ?it/s]

In [27]:
topic_word_data.to_csv(f"topic_word_stability-{dataset}-20220428.csv")
topic_word_data

Unnamed: 0,Unnamed: 1,min_dists_50
mallet,0,0.349347
mallet,1,0.325461
mallet,2,0.307134
mallet,3,0.306972
mallet,4,0.315830
...,...,...
dvae,268,0.568489
dvae,269,0.482378
dvae,270,0.502758
dvae,271,0.142594


In [None]:
# TODO: run over mallet as well
data = {}

for k, est in estimates_dvae.items():
    doc_topics, topic_words, duplicates, losses = (
        est["doc_topics"], est["topic_words"], est["duplicates"], est["losses"]
    )
    for exclude in [False, True]:
        dt, tw = doc_topics, topic_words
        if exclude:
            dt, tw = exclude_dupes(dt, duplicates), exclude_dupes(tw, duplicates)
        nunique, punique, probs = estimate_document_word_stability(dt, tw, top_n=n_topic_words)
        pct_assigned = (probs < 0.5).mean(1)
        nsummary = pd.Series(nunique).describe()
        psummary = pd.Series(punique).describe()
        asummary = pd.Series(pct_assigned).describe()
        
        data.update({f"pct_unique_{k}": psummary, f"nuniuqe_{k}": nsummary, f"assigned_{k}": asummary})

data = pd.DataFrame(data)

In [None]:
data.to_csv("stability_summary-dvae-url_partisan-20220428.csv")

Unnamed: 0,a,b
count,176377.0,176377.0
mean,0.749403,0.438181
std,0.175117,0.137653
min,0.142857,0.15
25%,0.619048,0.326667
50%,0.761905,0.436667
75%,0.904762,0.54
max,1.0,0.886667


In [None]:
data

In [111]:
data

Unnamed: 0,Unnamed: 1,min_dists_50
mallet,count,750.0
mallet,mean,0.307692
mallet,std,0.178525
mallet,min,0.038807
mallet,25%,0.162107
mallet,50%,0.266771
mallet,75%,0.424182
mallet,max,0.73995
dvae,count,850.0
dvae,mean,0.34619
