In [59]:
import json
from math import floor
from pathlib import Path

import numpy as np
import pandas as pd
from scipy.spatial.distance import pdist, squareform
from sklearn.cluster import AgglomerativeClustering

from soup_nuts.utils import read_json
from calculate_metrics import load_estimates, load_yaml

## Setup
Load in estimates for a given hyperparameter setting

In [2]:
# specify a dataset, vocabulary, topic-size, model setting
base_run_dir = Path("outputs/stability_coverage/wikitext-labeled/vocab_5k/k-50/dvae")
config_fpaths = list(base_run_dir.glob("**/config.yml"))

# get estimates for each run in this setting
topic_word_runs, doc_topic_runs = [], []
tw_ids, dt_ids = [], [] # need to uniquely identify objects from each run

for i, fpath in enumerate(config_fpaths):
    # softmaxes dvae & scholar topic_word (beta)
    tw, dt, _ = load_estimates(fpath.parent, model_type="dvae")

    topic_word_runs.append(tw)
    doc_topic_runs.append(dt)

    tw_ids.extend([i]*tw.shape[0])
    dt_ids.extend([i]*dt.shape[0])

# collect into matrices:
topic_word_runs = np.vstack(topic_word_runs) # [runs*k x |V|]
doc_topic_runs = np.vstack(doc_topic_runs) # [runs*n x k]

tw_ids = np.array(tw_ids) # [runs*k]
dt_ids = np.array(dt_ids) # [runs*n]

input_dir = load_yaml(config_fpaths[0])["input_dir"]
vocab = read_json(Path(input_dir, "vocab.json"))
inv_vocab = dict(zip(vocab.values(), vocab.keys()))

## Threshold setting
Inspect the top-matched topics so we can set a threshold

In [98]:
display_topic = lambda topic, inv_vocab: " ".join(inv_vocab[w] for w in topic)

def inspect_matched_topics(
    topic_word_runs,
    inv_vocab,
    dists=None, # precompute if desired
    metric="jensenshannon",
    top_n_words=15,
    log_base=4,
    limit_to_rank=None,
):
    """Get a sense for what a good distance threshold might be"""
    # get pairwise distances
    if dists is None:
        dists = squareform(pdist(topic_word_runs, metric=metric))
    # get ordered indices: from https://stackoverflow.com/a/64338853/5712749
    x, y = np.unravel_index(np.argsort(dists, axis=None), dists.shape)
    coords = np.array([x, y]).T[x!=y][::2] # remove diagnals (i,i), & every-other (since j,i=i,j)

    # display topics
    top_words = (-topic_word_runs).argsort(1)[:, :top_n_words] # top words per topic

    floorlog = lambda x: np.floor(np.log(x) / np.log(log_base))

    for rank, (i, j) in enumerate(coords[:limit_to_rank]):
        if rank == 0 or floorlog(rank) != floorlog(rank+1): # display more at beginning, less at end
            i_words = display_topic(top_words[i], inv_vocab)
            j_words = display_topic(top_words[j], inv_vocab)
            dist = dists[i, j]
            print(f">> Rank: {rank}, dist: {dist:0.3f}\n    {i_words}\n    {j_words}")

def inspect_matched_topics_around_threshold(
    topic_word_runs,
    inv_vocab,
    dists=None,
    metric="jensenshannon",
    threshold=None,
    top_n_words=15,
    window_pct=0.001,
    window_limit=None,
):
    # get pairwise distances
    if dists is None:
        dists = squareform(pdist(topic_word_runs, metric=metric))

    low, high = threshold * (1-window_pct), threshold * (1+window_pct)
    x, y = np.where((low < dists) & (dists < high))
    coords = np.array([x, y]).T[x!=y] # remove diagnals (i,i)
    coords = np.unique(coords, axis=0)

    top_words = (-topic_word_runs).argsort(1)[:, :top_n_words] # top words per topic

    for (i, j) in coords[:window_limit]:
        i_words = display_topic(top_words[i], inv_vocab)
        j_words = display_topic(top_words[j], inv_vocab)
        dist = dists[i, j]
        print(f">> {metric} dist: {dist:0.3f}\n    {i_words}\n    {j_words}")


In [99]:
# jensenshannon takes a little while to compute (and could look better, tbh?)
js_dists = squareform(pdist(topic_word_runs, "jensenshannon"))
inspect_matched_topics(topic_word_runs, inv_vocab, dists=js_dists)

>> Rank: 0, dist: 0.000
    species found studio album surface males women fall animation episode fish machine africa band female
    species found white castle crash film genus black floor usually fish males number single food
>> Rank: 3, dist: 0.000
    species found force films film forces battalion race rush regiment formula served war israel company
    species found black film white magazine oil food groups race grant glass battalion long network
>> Rank: 15, dist: 0.001
    species found force films film forces battalion race rush regiment formula served war israel company
    species found white castle crash film genus black floor usually fish males number single food
>> Rank: 63, dist: 0.003
    jackson units highway electric battery sales revenge satisfied nielsen day massachusetts governor birds painting car
    music album born home tracks match soundtrack track farm temple song episode songs officer hell
>> Rank: 255, dist: 0.006
    species found white brain round girl fl

In [100]:
# oddly, correlation seems solid
corr_dists = squareform(pdist(topic_word_runs, metric="correlation"))
inspect_matched_topics(topic_word_runs, inv_vocab, dists=corr_dists)

>> Rank: 0, dist: 0.000
    species found game white series tree match games season park ocean character song known released
    species found dutch episode paul body white ships way drug fleet brown nuclear known genus
>> Rank: 3, dist: 0.000
    species found body tropical nuclear white depression hurricane names dutch long storm common children genus
    species found brown genus white common evolution music long song martin similar fruit small bodies
>> Rank: 15, dist: 0.000
    species found cemetery williams genus usually plants song security england units entertainment oil body german
    species found studio album surface males women fall animation episode fish machine africa band female
>> Rank: 63, dist: 0.000
    species found white brain round girl florida males rock age small fight new_york_times hall music
    species found game white series tree match games season park ocean character song known released
>> Rank: 255, dist: 0.033
    cambridge oxford lengths reigning uni

In [126]:
# euclidean doesn't look so good
euc_dists = squareform(pdist(topic_word_runs, metric="euclidean"))
inspect_matched_topics(topic_word_runs, inv_vocab, dists=euc_dists)

>> Rank: 0, dist: 0.000
    species found white castle crash film genus black floor usually fish males number single food
    species found studio album surface males women fall animation episode fish machine africa band female
>> Rank: 3, dist: 0.000
    species found white castle crash film genus black floor usually fish males number single food
    species found cemetery williams genus usually plants song security england units entertainment oil body german
>> Rank: 15, dist: 0.000
    century mark species women match event princess battalion company australia theatre ray work episode islands
    hall rush cattle company concerts angel queen german scotland michael battle ride care claimed germany
>> Rank: 63, dist: 0.000
    properties historic series national company bay rates rate listed countries machine debt group draft property
    jackson units highway electric battery sales revenge satisfied nielsen day massachusetts governor birds painting car
>> Rank: 255, dist: 0.001
    

In [120]:
inspect_matched_topics_around_threshold(topic_word_runs, inv_vocab, dists=js_dists, threshold=0.04, window_pct=0.005, window_limit=5)

>> jensenshannon dist: 0.040
    poem book church twitter emperor women king published god century religious children language political books
    species genus prey habitat females plants breeding males populations colonies specimens horses ecology cells eggs
>> jensenshannon dist: 0.040
    poem book church twitter emperor women king published god century religious children language political books
    murder officers fbi prisoners police sentence jury crime gay prison witnesses der attorney convicted charges
>> jensenshannon dist: 0.040
    century mark species women match event princess battalion company australia theatre ray work episode islands
    episodes characters dvd comic episode animation animated series aired season character novels volumes viewers released
>> jensenshannon dist: 0.040
    century mark species women match event princess battalion company australia theatre ray work episode islands
    comedy starred drama film actress actor emmy actors nominations episodes

## Clustering attempts

In [95]:
def ensemble_runs(
    topic_word_runs,
    tw_ids=None,
    doc_topic_runs=None, # not used yet, may need ids as well
    dt_ids=None,
    distance_threshold=None,
    dists=None,
    metric="jensenshannon",
    linkage="average",
):
    """
    Attempt at an ensembling/agglomerative clustering algorithm

    TODO:
     - once topic-words have been found, get likely doc-topic assignments
     - try using the doc-topic probabilities in the distance calculations
    """
    if dists is None:
        dists = squareform(pdist(topic_word_runs, metric=metric))
    
    ag = AgglomerativeClustering(
        n_clusters=None,
        affinity="precomputed",
        linkage=linkage,
        distance_threshold=distance_threshold,
    )
    preds = ag.fit_predict(dists)

    # put predictions into a df
    df = pd.DataFrame({"cluster": preds, "run_id": tw_ids, "topic_id": np.arange(topic_word_runs.shape[0])})
    cluster_sizes = df.groupby("cluster", as_index=False).size()
    df = df.merge(cluster_sizes, on="cluster").sort_values("size", ascending=False)

    return ag, df

def display_clusters(df, topic_word_runs, top_n_words=15, min_cluster_size=0):
    """Display the """
    top_words = (-topic_word_runs).argsort(1)[:, :top_n_words] # top words per topic
    for cluster_id, grp in df.groupby("cluster", sort=False):
        if grp["size"].min() >= min_cluster_size:
            print(f"\n>> Cluster: {cluster_id}")
            for _, row in grp.iterrows():
                topic_words = display_topic(top_words[row.topic_id], inv_vocab)
                print(f"run: {row.run_id} | {topic_words}")

In [121]:
ag, js_cluster_df = ensemble_runs(topic_word_runs, tw_ids, distance_threshold=0.05, dists=js_dists)

In [124]:
ag, corr_cluster_df = ensemble_runs(topic_word_runs, tw_ids, distance_threshold=0.4, dists=corr_dists)

In [128]:
display_clusters(corr_cluster_df, topic_word_runs, min_cluster_size=2)


>> Cluster: 0
run: 7 | spanish british french militia regiment army fort fleet ships squadron washington troops expedition battle admiral
run: 7 | match runs england test australia team matches tour scored season series ball johnson class india
run: 7 | species horses plants birds fish prey males plant genus breeding females eggs meat animals brown
run: 7 | election republican senate labour paul presidential party president governor vote campaign democratic grant johnson congress
run: 7 | ship ships guns fleet aircraft tons torpedo german admiral war cruisers turrets battleships cruiser squadron
run: 7 | creek river lake species park fish salt trail water island plants basin area mountain metres
run: 7 | building city museum library park county street avenue buildings population tower downtown district construction square
run: 7 | aircraft engine wing squadron flight nuclear engines air mission radar pilot fighter test design soviet
run: 7 | squadron aircraft wing fighter war command 