In [24]:
from typing import Union, Dict
from itertools import combinations
from collections import defaultdict, Counter

import numpy as np
import pandas as pd
from scipy.stats import multinomial, binom, hypergeom, bernoulli, multivariate_hypergeom
from scipy.special import comb, gammaln

In [4]:
def compute_tu(topics, n=10):
    """
    Topic uniqueness measure from
    https://www.aclweb.org/anthology/P19-1640.pdf
    """
    tu_results = []
    for topics_i in topics:
        w_counts = 0
        for w in topics_i[:n]:
            w_counts += 1 / np.sum([w in topics_j[:n] for topics_j in topics]) # count(k, l)
        tu_results.append((1 / n) * w_counts)
    return tu_results


def compute_tr(topics, n=10):
    """
    Compute topic redundancy score from 
    https://jmlr.csail.mit.edu/papers/volume20/18-569/18-569.pdf
    """
    tr_results = []
    k = len(topics)
    for i, topics_i in enumerate(topics):
        w_counts = 0
        for w in topics_i[:n]:
            w_counts += np.sum([w in topics_j[:n] for j, topics_j in enumerate(topics) if j != i]) # count(k, l)
        tr_results.append((1 / (k - 1)) * w_counts)
    return tr_results


def compute_td(topics, n=25):
    """
    Compute topic diversity from 
    https://doi.org/10.1162/tacl_a_00325
    """
    words = [w for topic in topics for w in topic[:n]]
    return len(set(words)) / len(words)


def compute_te(prob_w_given_topic, n=20, topics_sorted=None):
    """
    Compute topic exclusivity from 
    https://icml.cc/Conferences/2012/papers/113.pdf
    """
    # the should be normed (e.g., via softmax)
    assert np.allclose(prob_w_given_topic.sum(1), 1)
    num_topics = prob_w_given_topic.shape[0]
    prob_w_normed = prob_w_given_topic.sum(0)
    nonzero = prob_w_normed > 0
    exclusivity = np.zeros_like(prob_w_given_topic)
    exclusivity[:, nonzero] = prob_w_given_topic[:, nonzero] / prob_w_normed[nonzero]

    if topics_sorted is None:
        topics_sorted = np.flip(prob_w_given_topic.argsort(-1), -1)[:, :n]

    return [np.mean(exclusivity[k, topics_sorted[k, :n]]) for k in range(num_topics)]


def compute_to(topics, n=10):
    """
    Calculate topic overlap (number of unique topic pairs sharing words)
    """
    k = len(topics)
    overlaps = np.zeros((k, k), dtype=float)
    common_terms = np.zeros((k, k), dtype=float)
    words = Counter([w for topic in topics for w in topic[:n]])

    for i, t_i in enumerate(topics):
        for j, t_j in enumerate(topics[i+1:], start=i+1):
            if i != j:
                overlap_ij = set(t_i[:n]) & set(t_j[:n])
                overlaps[i, j] = len(overlap_ij) 
                common_terms[i, j] = sum(words[w] for w in overlap_ij)
    
    return overlaps.sum()


def compute_tr_weighted(topics, n=10):
    """
    Compute weighted topic redundancy score:
    each additional word from the same topic counts more
    than the previous

    however, it fails the 3 overlapping topics ==
    2 pairs of two topics test
    """
    tr_results = []
    k = len(topics)
    for i, topics_i in enumerate(topics):
        w_counts = 0
        j_counts = defaultdict(int)
        for w in topics_i[:n]:
            for j, topics_j in enumerate(topics):
                if j != i:
                    if w in topics_j[:n]:
                        j_counts[j] += 1
                        w_counts += j_counts[j]

        tr_results.append((1 / (k - 1)) * w_counts)
    return np.array(tr_results)


def compute_corrected_tr_weighted(topics, n=10, weight=0.9):
    """
    Compute corrected topic redundancy score:
    each additional word from the same topic counts more
    than the previous, and words are downweighted
    to account for double-counting
    """
    tr_results = []
    k = len(topics)
    words = Counter([w for topic in topics for w in topic[:n]])
    i_norm = ((k - 1) * sum(i for i in range(1, n+1)))
    c_norm = (n * (k - 1))
    w_norm = (n * k) * (k - 1)
    w_c_norm = c_norm / w_norm

    for i, topics_i in enumerate(topics):
        i_counts = 0.
        c_counts = 0.
        j_counts = defaultdict(float)
        w_counts = 0.
        for w in topics_i[:n]:
            for j, topics_j in enumerate(topics):
                if j != i:
                    if w in topics_j[:n]:
                        j_counts[j] += 1
                        i_counts += j_counts[j] * weight # TODO: some weighting
                        c_counts += 1
                        w_counts += words[w]
        if c_counts == 0:
            tr_results.append(0)
        else:
            #tr_results.append((1 / (k - 1)) * (i_counts) * (c_counts / w_counts))
            #tr_results.append((i_counts / i_norm) * ((c_counts / w_counts) / w_c_norm) * 100)
            tr_results.append((i_counts / i_norm) * (c_counts / w_counts) * 100)

    return np.array(tr_results)

def compute_corrected_to(topics, n=10, multiplier=2):
    """
    A sensible overlap / redundancy measure. Words from a topic
    are only counted once per "edge"
    """
    k = len(topics)
    # create de-duplicated adjacency matrix
    # for each topic A_i, sorted by total number of overlaps:
    #    create set of sets S = {S_{ij} = A_i \cap A_j st. j=i+1,...,k}
    #    sort sets in S by their cardinality in descending order
    #    initialize a set W = {}
    #    For each S_{ij}' in S:
    #        if words are not already part of an edge, i.e., |W \cap S_{ij}'| is 0:
    #           create an edge between A_i and A_j with weight w = |S_{ij}'|
    #           augment the list of words used in an edge, W = W \cup S_{ij}'
    overlap_counts = np.zeros((k, k), dtype=int)
    overlap_dedup = np.zeros((k, k), dtype=int)
    overlap_words = {}

    # first count all the overlaps between topics
    for i, topic_i in enumerate(topics):
        for j, topic_j in enumerate(topics[i+1:], start=i+1):
            words_ij = set(topic_i[:n]) & set(topic_j[:n])
            overlap_counts[[i, j], [j, i]] = len(words_ij)
            overlap_words[frozenset([i, j])] = words_ij

    # sort topics by those with most overlaps
    sort_idx = overlap_counts.sum(0).argsort()[::-1]
    overlap_counts = overlap_counts[sort_idx, :][:, sort_idx]
    for i, counts in enumerate(overlap_counts):
        counted_words = set()
        start = i + 1
        for j in (counts[start:].argsort()[::-1] + start):
            words_ij = overlap_words[frozenset([i, j])]
            if overlap_counts[i, j] > 0 and len(counted_words & words_ij) == 0:
                overlap_dedup[i, j] = overlap_counts[i, j]
                counted_words |= words_ij

    # TODO: incorporate equivalencies / clean up the below to be neater
    # how many single word overlaps are equivalent to a full topic overlap
    increments = np.linspace(1/multiplier, n, num=n)
    redundancy = increments[overlap_dedup[overlap_dedup > 0] - 1].sum() / (n * (k - 1))
    # redundancy = sum(increments[o - 1] for o in overlap_dedup[overlap_dedup > 0]) / (n * (k - 1))

    # old redundancy = (overlap_dedup * (overlap_dedup/n)).sum() / ((n * (k - 1)))
    # overlaps = np.zeros((k, k), dtype=int)
    # for i, topic_i in enumerate(topics[:-1]):
    #     intersections = [
    #         (j, set(topic_i[:n]) & set(topic_j[:n])) for j, topic_j in enumerate(topics[i+1:], start=i+1)
    #     ]
    #     intersections = sorted(intersections, key=lambda x: len(x[1]), reverse=True)
    #     counted_words = set()
    #     for j, int_j in intersections:
    #         if len(int_j) > 0 and len(int_j & counted_words) == 0:
    #             overlaps[i, j] = len(int_j)
    #             counted_words |= int_j
    # redundancy = (overlaps * overlaps/n).sum() / ((n * (k - 1)))
    # TODO: also return overlap counts
    return redundancy


def compute_all_redundancies(
    beta: np.ndarray = None,
    topics_sorted: np.ndarray = None,
    n: Union[int, Dict[str, int]] = {"tu": 10, "tr": 10, "td": 25, "te": 20, "to": 10},
    print_out: bool = True,
    ):
    if isinstance(n, int):
        n = {"tu": n, "tr": n, "td": n, "te": n, "to": n}
    
    if topics_sorted is None:
        topics_sorted = np.flip(beta.argsort(-1), -1)

    metrics = {
        "tu": np.mean(compute_tu(topics_sorted, n["tu"])),
        "tr": np.mean(compute_tr(topics_sorted, n["tr"])),
        "tr_w": np.mean(compute_corrected_tr_weighted(topics_sorted, n["tr"])),
        "td": compute_td(topics_sorted, n["td"]),
        "to": np.mean(compute_corrected_to(topics_sorted, n["to"])),
    }
    if beta is not None:
        metrics["te"] = np.mean(compute_te(beta, n["te"], topics_sorted))
    if print_out:
        print(*[f"{k}: {v:0.3f}" for k, v in metrics.items()])

    return metrics

In [43]:
def print_topics(topics, transpose=True):
    cols = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
    k, n = topics.shape
    if k > len(cols):
        raise ValueError("Too many topics")
    df = pd.DataFrame(topics, index=list(cols[:k]))
    if transpose:
        df = df.T
    print(df.to_string(index=not transpose, header=transpose))

In [5]:
v = 5000
k = 20
n = 10

np.random.seed(11235)
rand_idx = lambda k: np.random.choice(k, k, replace=False)

unif_param = np.full(v, 1 / v)
biased_param = np.random.dirichlet(unif_param) + 1e-6

beta_unif = np.array([np.random.dirichlet(unif_param) for _ in range(k)])
beta_biased = np.array([np.random.dirichlet(biased_param) for _ in range(k)])

In [6]:
_ = compute_all_redundancies(beta=beta_unif, print_out=True)

tu: 0.975 tr: 0.026 tr_w: 0.022 td: 0.946 to: 0.013 te: 0.951


In [7]:
scores = compute_all_redundancies(beta=beta_biased, print_out=True)

tu: 0.485 tr: 3.553 tr_w: 0.799 td: 0.222 to: 0.428 te: 0.202


In [8]:
# some robustness checks
shuff_beta = beta_biased[rand_idx(k)]
scores == compute_all_redundancies(beta=shuff_beta, print_out=True)
shuff_topics = np.flip(shuff_beta.argsort(1), -1)[:, :n]
shuff_topics = shuff_topics[:, rand_idx(n)]
scores == compute_all_redundancies(topics_sorted=shuff_topics, beta=shuff_beta, print_out=True)

tu: 0.485 tr: 3.553 tr_w: 0.799 td: 0.222 to: 0.428 te: 0.202
tu: 0.485 tr: 3.553 tr_w: 0.799 td: 0.485 to: 0.428 te: 0.400


False

## Synthetic overlap tests
k > n will yield equality between all metrics

n >= k means that "overlap" does worse (undesirable)

In [12]:
# synthetic overlaps
n = 5
k = 6

rand_idx = lambda k: np.random.choice(k, k, replace=False)
helper_idx = int(np.power(10, np.ceil(np.log10(n)))) # helps visualize better
topics = np.array([helper_idx*i+np.arange(n) for i in range(k)])

max_idx = topics.max()
gen_beta = lambda topics: np.array([np.bincount(topic, minlength=max_idx + 1) / n for topic in topics])
beta = gen_beta(topics)

duplicate_terms = topics[0]
duplicate_probs = beta[0]
# worst cast scenario: all repeated
topics_all_repeats = [np.arange(n) for _ in range(k)]

# one topic copied entirely
topics_overlapping = np.vstack([duplicate_terms, topics[:k-1]])
beta_overlapping = gen_beta(topics_overlapping)

# duplicate words are distributed evenly across topics 
idx = min(n, k -1)
topics_distributed = np.copy(topics)
topics_distributed[1:idx+1, 0] = duplicate_terms[:idx]
beta_distributed = gen_beta(topics_distributed)

# single duplicate word distributed across topics
topics_redundant_word = np.copy(topics)
topics_redundant_word[0:idx+1, 0] = topics_redundant_word[0, 0]

In [13]:
topics

array([[ 0,  1,  2,  3,  4],
       [10, 11, 12, 13, 14],
       [20, 21, 22, 23, 24],
       [30, 31, 32, 33, 34],
       [40, 41, 42, 43, 44],
       [50, 51, 52, 53, 54]])

In [21]:
topics_overlapping

array([[ 0,  1,  2,  3,  4],
       [ 0,  1,  2,  3,  4],
       [10, 11, 12, 13, 14],
       [20, 21, 22, 23, 24],
       [30, 31, 32, 33, 34],
       [40, 41, 42, 43, 44]])

In [15]:
topics_distributed

array([[ 0,  1,  2,  3,  4],
       [ 0, 11, 12, 13, 14],
       [ 1, 21, 22, 23, 24],
       [ 2, 31, 32, 33, 34],
       [ 3, 41, 42, 43, 44],
       [ 4, 51, 52, 53, 54]])

In [16]:
_ = compute_all_redundancies(beta=beta, topics_sorted=topics, n=n) # best value

tu: 1.000 tr: 0.000 tr_w: 0.000 td: 1.000 to: 0.000 te: 1.000


In [17]:
_ = compute_all_redundancies(topics_sorted=topics_all_repeats, n=n) # worst value

tu: 0.167 tr: 5.000 tr_w: 15.000 td: 0.167 to: 1.000


In [18]:
_ = compute_all_redundancies(beta=beta_overlapping, topics_sorted=topics_overlapping, n=n)

tu: 0.833 tr: 0.333 tr_w: 3.000 td: 0.833 to: 0.200 te: 0.833


In [19]:
_ = compute_all_redundancies(beta=beta_distributed, topics_sorted=topics_distributed, n=n)

tu: 0.833 tr: 0.333 tr_w: 1.000 td: 0.833 to: 0.100 te: 0.833


In [207]:
_ = compute_all_redundancies(topics_sorted=topics_redundant_word, n=n)

tu: 0.952 tr: 0.262 tr_w: 0.039 td: 0.952 to: 0.025


Exploring the distributed/overlap a bit further

In [193]:
for topics_with_repeats in range(1, min(n, k - 1)+1):
    if len(duplicate_terms) % topics_with_repeats == 0:
        bucket_size = len(duplicate_terms) // topics_with_repeats
        topics_distributed_i = np.copy(topics)
        for i in range(topics_with_repeats):
            # duplicate words are distributed evenly across topics 
            topics_distributed_i[i+1, 0:bucket_size] = duplicate_terms[i*bucket_size:(i+1)*bucket_size]
        print(f"{topics_with_repeats} topic(s) with {bucket_size} repeats each")
        beta_distributed_i = gen_beta(topics_distributed_i)
        compute_all_redundancies(beta=beta_distributed_i, topics_sorted=topics_distributed_i, n=n)

1 topic(s) with 10 repeats each
tu: 0.952 tr: 0.048 tr_w: 0.214 td: 0.952 to: 0.050 te: 0.952
2 topic(s) with 5 repeats each
tu: 0.952 tr: 0.048 tr_w: 0.117 td: 0.952 to: 0.047 te: 0.952
5 topic(s) with 2 repeats each
tu: 0.952 tr: 0.048 tr_w: 0.058 td: 0.952 to: 0.039 te: 0.952
10 topic(s) with 1 repeats each
tu: 0.952 tr: 0.048 tr_w: 0.039 td: 0.952 to: 0.025 te: 0.952


In [27]:
topics_with_repeats

1

In [178]:
if k >= (1+2*len(duplicate_terms)):
    idx = len(duplicate_terms)
    topics_distributed_double = np.copy(topics)
    topics_distributed_double[1:(2*idx)+1, 0] = np.concatenate([duplicate_terms, duplicate_terms])
    print("Overlap: ")
    _ = compute_all_redundancies(topics_sorted=topics_overlapping, n=n)
    print("\nDist: ")
    _ = compute_all_redundancies(topics_sorted=topics_distributed, n=n)
    print("\nDist double: ")
    _ = compute_all_redundancies(topics_sorted=topics_distributed_double, n=n)

Overlap: 
tu: 0.952 tr: 0.048 tr_w: 0.214 td: 0.952 to: 0.050

Dist: 
tu: 0.952 tr: 0.048 tr_w: 0.039 td: 0.952 to: 0.025

Dist double: 
tu: 0.905 tr: 0.143 tr_w: 0.078 td: 0.905 to: 0.050


Other synthetic examples of overlap.

Four completely overlapping topics should score worst

But two pairs of two identical topics vs.
three identical topics should be the same:
in both cases, you have effectively k-2 total topics

In [179]:
t0 = topics[0]
t1 = topics[1]

three_of_a_kind = np.vstack([t0, t0, topics[:k-2]])
two_pair = np.vstack([t0, t0, t1, t1, topics[4:]])

four_of_a_kind = np.vstack([t0, t0, t0, topics[:k-3]])
full_house = np.vstack([t0, t0, t0, t1, t1, topics[5:]])
assert four_of_a_kind.shape == three_of_a_kind.shape == two_pair.shape == full_house.shape == (k, n)

In [165]:
_ = compute_all_redundancies(topics_sorted=three_of_a_kind, n=n)

tu: 0.900 tr: 0.158 tr_w: 0.474 td: 0.900 to: 0.105


In [166]:
_ = compute_all_redundancies(topics_sorted=two_pair, n=n) # should be == to three_of_a_kind

tu: 0.900 tr: 0.105 tr_w: 0.474 td: 0.900 to: 0.105


In [167]:
_ = compute_all_redundancies(topics_sorted=four_of_a_kind, n=n) # should be "worse" than three_of_a_kind

tu: 0.850 tr: 0.316 tr_w: 0.711 td: 0.850 to: 0.158


In [168]:
_ = compute_all_redundancies(topics_sorted=full_house, n=n) # should be == four_of_a_kind

tu: 0.850 tr: 0.211 tr_w: 0.711 td: 0.850 to: 0.158


This is another similar artificial test. I think this too should be the same, but it's not clear this will hold true for metrics that meet other criteria

In [147]:
semi_three_of_a_kind = [
    [ 1,  2,  3,  4],
    [ 1,  2,  3,  4],
    [ 1,  2, 23, 24],
    [31, 32, 33, 34],
]

semi_two_pair = [
    [ 1,  2,  3,  4],
    [ 1,  2,  3,  4],
    [31, 32, 23, 24],
    [31, 32, 33, 34],
]

_ = compute_all_redundancies(topics_sorted=semi_three_of_a_kind, n=4)
_ = compute_all_redundancies(topics_sorted=semi_two_pair, n=4)

tu: 0.625 tr: 1.333 tr_w: 8.812 td: 0.625 to: 0.472
tu: 0.625 tr: 1.000 tr_w: 9.750 td: 0.625 to: 0.472


Here we do a similar thing but topics overlap only partially (7 words)

In [148]:
three_of_a_kind_partial = np.copy(topics)
three_of_a_kind_partial[:, :n-3] = three_of_a_kind[:, :n-3]
two_pair_partial = np.copy(topics)
two_pair_partial[:, :n-3] = two_pair[:, :n-3]

four_of_a_kind_partial = np.copy(topics)
four_of_a_kind_partial[:, :n-3] = four_of_a_kind[:, :n-3]
full_house_partial = np.copy(topics)
full_house_partial[:, :n-3] = full_house[:, :n-3]

In [149]:
_ = compute_all_redundancies(topics_sorted=three_of_a_kind_partial, n=n)
_ = compute_all_redundancies(topics_sorted=two_pair_partial, n=n)

_ = compute_all_redundancies(topics_sorted=four_of_a_kind_partial, n=n)
_ = compute_all_redundancies(topics_sorted=full_house_partial, n=n)

tu: 0.927 tr: 0.109 tr_w: 0.327 td: 0.927 to: 0.065
tu: 0.927 tr: 0.073 tr_w: 0.327 td: 0.927 to: 0.065
tu: 0.891 tr: 0.218 tr_w: 0.491 td: 0.891 to: 0.098
tu: 0.891 tr: 0.145 tr_w: 0.491 td: 0.891 to: 0.098


In [150]:
# TODO: 
# consider when four_of_a_kind_partial should be less
# than tree_of_a_kind (full)

# for n = 10, 30 total redundant words
# NB that four-of-a-kind with j word overlap is same as five-of-a-kind with j-1 (?)
five_of_a_kind = np.vstack([np.vstack([t0, t0, t0, t0, topics[:k-4]])])
five_of_a_kind_partial_6 = np.copy(topics)
five_of_a_kind_partial_6[:, :n-4] = five_of_a_kind[:, :n-4]
_ = compute_all_redundancies(topics_sorted=three_of_a_kind, n=n)
_ = compute_all_redundancies(topics_sorted=four_of_a_kind_partial, n=n)
_ = compute_all_redundancies(topics_sorted=five_of_a_kind_partial_6, n=n)

tu: 0.818 tr: 0.273 tr_w: 1.636 td: 0.818 to: 0.200
tu: 0.891 tr: 0.218 tr_w: 0.491 td: 0.891 to: 0.098
tu: 0.927 tr: 0.182 tr_w: 0.218 td: 0.927 to: 0.040


Now, consider sets of n//2 repeating words:

1 - n//2 words from one topic are found in one other topic

2 - the same set of n//2 words from one topic appear in two other topics

3 - two different sets of n//2 words from one topic appear once each in two other topics

We want (larger is better): 1 > 3 >= 2 (it's not a dealbreaker if 2 and 3, but not preferable)

In [162]:
mid_idx = n // 2
topics_half_overlap = np.copy(topics)
topics_half_overlap[1, :mid_idx] = topics_half_overlap[0, :mid_idx]

topics_half_overlap_2x = np.copy(topics)
topics_half_overlap_2x[1, :mid_idx] = topics_half_overlap_2x[0, :mid_idx]
topics_half_overlap_2x[2, :mid_idx] = topics_half_overlap_2x[0, :mid_idx]

topics_seperate_half_overlap = np.copy(topics)
topics_seperate_half_overlap[1, :mid_idx] = topics_seperate_half_overlap[0, :mid_idx]
topics_seperate_half_overlap[2, :mid_idx] = topics_seperate_half_overlap[0, mid_idx:]

In [163]:
_ = compute_all_redundancies(topics_sorted=topics_half_overlap, n=n)
_ = compute_all_redundancies(topics_sorted=topics_half_overlap_2x, n=n)
_ = compute_all_redundancies(topics_sorted=topics_seperate_half_overlap, n=n)


tu: 0.975 tr: 0.026 tr_w: 0.065 td: 0.975 to: 0.025
tu: 0.950 tr: 0.079 tr_w: 0.129 td: 0.950 to: 0.050
tu: 0.950 tr: 0.053 tr_w: 0.129 td: 0.950 to: 0.050
