In [1]:
import numpy as np
import plotly.graph_objects as go
from tqdm.auto import trange

from lda import lda
from lda.data.corpus import Corpus
from lda.data.document import Document
from lda.data.word import Word
import lda.utils as utils

In [2]:
# The four words flagged include=True; each object is reused wherever the
# same word occurs in the documents below.
please, crash, grandma, sun = (
    Word(original_form, lda_form, include=True)
    for original_form, lda_form in (
        ('Please', 'please'),
        ('crash', 'crash'),
        ('grandma', 'grandma'),
        ('sun', 'sun'),
    )
)

# "Please do not crash." -- 'do', 'not' and '.' are built with include=False,
# as fresh Word objects for this document.
doc1 = Document(words=[
    please,
    *(Word(token, token, include=False) for token in ('do', 'not')),
    crash,
    Word('.', '.', include=False),
])

# "Please do not grandma." -- same shape as doc1 with a different last word.
doc2 = Document(words=[
    please,
    *(Word(token, token, include=False) for token in ('do', 'not')),
    grandma,
    Word('.', '.', include=False),
])

# Only include=True words, with 'grandma' repeated.
doc3 = Document(words=[grandma] * 2 + [sun])

corpus = Corpus(documents=[doc1, doc2, doc3])
corpus

Corpus with 3 documents

In [3]:
def assert_numerically_ok(array):
    """Assert that ``array`` holds neither NaN nor infinite values.

    Parameters
    ----------
    array : array_like
        Numeric values to check; converted with ``np.asarray`` so plain
        lists work as well as arrays.

    Raises
    ------
    AssertionError
        If any entry is NaN (checked first) or +/-inf. The message states
        which kind of value was found and how many entries are affected.
    """
    values = np.asarray(array)

    # NaN is checked before inf, so an array containing both reports NaN.
    nan_count = int(np.isnan(values).sum())
    assert nan_count == 0, f'{nan_count} of {values.size} entries are NaN'

    inf_count = int(np.isinf(values).sum())
    assert inf_count == 0, f'{inf_count} of {values.size} entries are infinite'

In [4]:
def assert_lda(seed, num_topics=2, num_iterations=32):
    """Fit LDA on the toy corpus under a fixed seed and check its invariants.

    Uses the notebook-level names ``corpus``, ``doc1``/``doc2``/``doc3`` and
    the four include=True words ``please``, ``crash``, ``grandma``, ``sun``.

    Parameters
    ----------
    seed
        Passed to ``utils.np_seed``, which wraps the ``lda`` call.
    num_topics : int, default 2
        Forwarded to ``lda``; also the expected size of ``alpha``, ``beta``,
        and of the topic axis of every ``phi`` and ``gamma``.
    num_iterations : int, default 32
        Forwarded to ``lda``.

    Raises
    ------
    AssertionError
        If any invariant is violated. Messages raised directly by this
        function name the seed and the failed check.
    """
    vocabulary = {please, crash, grandma, sun}
    documents = {doc1, doc2, doc3}
    where = f'seed {seed}'

    # NOTE(review): an earlier comment here singled out seed 62 as the failing
    # one; the recorded sweep output appears to abort with NaN lower bounds
    # earlier than that -- confirm which seeds actually break.
    with utils.np_seed(seed):
        params, lower_bound_evol = lda(
            corpus, num_topics=num_topics, num_iterations=num_iterations
        )
    assert set(params.keys()) == {'alpha', 'beta', 'phis', 'gammas'}, (
        f'{where}: unexpected parameter names {sorted(params.keys())}'
    )

    # alpha: one strictly positive, finite entry per topic.
    alpha = params['alpha']
    assert alpha.shape == (num_topics,), (
        f'{where}: alpha has shape {alpha.shape}, expected {(num_topics,)}'
    )
    assert_numerically_ok(alpha)
    assert np.all(alpha > 0), f'{where}: alpha has non-positive entries: {alpha}'

    # beta: per topic, a mapping over exactly the included words that sums to 1.
    beta = params['beta']
    assert len(beta) == num_topics, (
        f'{where}: beta has {len(beta)} rows, expected {num_topics}'
    )
    for topic, beta_row in enumerate(beta):
        assert len(beta_row) == len(vocabulary), (
            f'{where}: beta[{topic}] has {len(beta_row)} words, '
            f'expected {len(vocabulary)}'
        )
        assert_numerically_ok(list(beta_row.values()))
        assert set(beta_row.keys()) == vocabulary, (
            f'{where}: beta[{topic}] is keyed by unexpected words'
        )
        beta_total = sum(beta_row.values())
        assert np.isclose(beta_total, 1), (
            f'{where}: beta[{topic}] sums to {beta_total}, expected 1'
        )

    # phis: per document, an array of shape (len(document), num_topics) whose
    # rows each sum to 1.
    phis = params['phis']
    assert len(phis) == len(documents), (
        f'{where}: phis has {len(phis)} entries, expected {len(documents)}'
    )
    assert set(phis.keys()) == documents, (
        f'{where}: phis is keyed by unexpected documents'
    )
    for document, phi_row in phis.items():
        expected_shape = (len(document), num_topics)
        assert phi_row.shape == expected_shape, (
            f'{where}: phi for {document} has shape {phi_row.shape}, '
            f'expected {expected_shape}'
        )
        assert_numerically_ok(phi_row)
        assert np.all(np.isclose(phi_row.sum(axis=1), 1)), (
            f'{where}: phi rows for {document} do not sum to 1: '
            f'{phi_row.sum(axis=1)}'
        )

    # gammas: per document, positive and never below alpha.
    gammas = params['gammas']
    assert len(gammas) == len(documents), (
        f'{where}: gammas has {len(gammas)} entries, expected {len(documents)}'
    )
    assert set(gammas.keys()) == documents, (
        f'{where}: gammas is keyed by unexpected documents'
    )
    for document, gamma_row in gammas.items():
        assert gamma_row.shape == (num_topics,), (
            f'{where}: gamma for {document} has shape {gamma_row.shape}, '
            f'expected {(num_topics,)}'
        )
        assert_numerically_ok(gamma_row)
        assert np.all(gamma_row > 0), (
            f'{where}: gamma for {document} has non-positive entries: {gamma_row}'
        )
        assert np.all(gamma_row >= alpha), (
            f'{where}: gamma for {document} drops below alpha: '
            f'{gamma_row} vs {alpha}'
        )

    # The recorded lower bound must never decrease between consecutive entries.
    # np.asarray lets a plain list be sliced and subtracted; a NaN entry makes
    # the comparison False and therefore fails here as well.
    bounds = np.asarray(lower_bound_evol)
    bound_steps = bounds[1:] - bounds[:-1]
    assert np.all(bound_steps >= 0), (
        f'{where}: lower bound is not non-decreasing (or contains NaN): {bounds}'
    )

In [5]:
# Check one fixed seed on its own before sweeping over many seeds.
assert_lda(seed=9)

{'alpha': array([0.01037415, 0.50187459]),
 'beta': [{Word(original_form='grandma', lda_form='grandma', include=True): 0.1435070382328889,
           Word(original_form='Please', lda_form='please', include=True): 0.5006432587935381,
           Word(original_form='sun', lda_form='sun', include=True): 0.2207055707081681,
           Word(original_form='crash', lda_form='crash', include=True): 0.13514413226540492},
          {Word(original_form='grandma', lda_form='grandma', include=True): 0.07668503255089015,
           Word(original_form='Please', lda_form='please', include=True): 0.3817921321387416,
           Word(original_form='sun', lda_form='sun', include=True): 0.31518777548144816,
           Word(original_form='crash', lda_form='crash', include=True): 0.22633505982892008}],
 'gammas': {Document including 3 words: array([0.01037418, 3.50187457]),
            Document including 2 words: array([0.01037416, 2.50187458]),
            Document including 2 words: array([0.01037415, 2.501

  0%|          | 0/32 [00:00<?, ?it/s]

Document including 2 words lower bound is -3.20519034948164
Document including 2 words lower bound is -2.1065780608135043
Document including 3 words lower bound is -3.6492022858869575
Document including 2 words lower bound is -3.2024425367127094
Document including 2 words lower bound is -2.103830248044602
Document including 3 words lower bound is -3.6457468469985486
Document including 2 words lower bound is -3.201326080565906
Document including 2 words lower bound is -2.1027137918977985
Document including 3 words lower bound is -3.6442680097626834
Document including 2 words lower bound is -3.2007195706430878
Document including 2 words lower bound is -2.1021072819749804
Document including 3 words lower bound is -3.6434417652668003
Document including 2 words lower bound is -3.2003385232692096
Document including 2 words lower bound is -2.1017262346011023
Document including 3 words lower bound is -3.6429134076693543
Document including 2 words lower bound is -3.2000769456876412
Document inc

In [6]:
# Sweep the LDA invariants over seeds 0..99. The sweep still stops at the
# first failing seed, but the re-raised error names that seed so the failure
# can be reproduced with a single assert_lda(<seed>) call.
for seed in trange(100):
    try:
        assert_lda(seed)
    except AssertionError as error:
        raise AssertionError(f'assert_lda failed for seed {seed}: {error}') from error

  0%|          | 0/100 [00:00<?, ?it/s]

{'alpha': array([0.5488135 , 0.71518937]),
 'beta': [{Word(original_form='grandma', lda_form='grandma', include=True): 0.19107688284120386,
           Word(original_form='Please', lda_form='please', include=True): 0.2718584733852522,
           Word(original_form='sun', lda_form='sun', include=True): 0.29131130808112,
           Word(original_form='crash', lda_form='crash', include=True): 0.24575333569242383},
          {Word(original_form='grandma', lda_form='grandma', include=True): 0.3600506427818107,
           Word(original_form='Please', lda_form='please', include=True): 0.16349449532149288,
           Word(original_form='sun', lda_form='sun', include=True): 0.14326419052519523,
           Word(original_form='crash', lda_form='crash', include=True): 0.3331906713715011}],
 'gammas': {Document including 3 words: array([1.14408964, 3.11991323]),
            Document including 2 words: array([1.45520026, 1.80880261]),
            Document including 2 words: array([1.14956005, 2.11444

  0%|          | 0/32 [00:00<?, ?it/s]

Document including 2 words lower bound is -3.407267263475272
Document including 2 words lower bound is -2.367020032270053
Document including 3 words lower bound is -3.942798090043405
Document including 2 words lower bound is -3.29854405824643
Document including 2 words lower bound is -2.2871997718840427
Document including 3 words lower bound is -3.793567973717638
Document including 2 words lower bound is -3.2448566836414052
Document including 2 words lower bound is -2.2543742294609865
Document including 3 words lower bound is -3.698195821389484
Document including 2 words lower bound is -3.2143607063800324
Document including 2 words lower bound is -2.2395282048322755
Document including 3 words lower bound is -3.6255129367959182
Document including 2 words lower bound is -3.1953598536507934
Document including 2 words lower bound is -2.233500212021612
Document including 3 words lower bound is -3.5668243885275404
Document including 2 words lower bound is -3.1815754886849597
Document includi

  0%|          | 0/32 [00:00<?, ?it/s]

Document including 2 words lower bound is -3.3790818179214925
Document including 2 words lower bound is -2.30720459447018
Document including 3 words lower bound is -3.9287371051061943
Document including 2 words lower bound is -3.3228757893876884
Document including 2 words lower bound is -2.2163177382050043
Document including 3 words lower bound is -3.809585190390847
Document including 2 words lower bound is -3.3090202524893195
Document including 2 words lower bound is -2.169210299305817
Document including 3 words lower bound is -3.7373986147399427
Document including 2 words lower bound is -3.3188536507374122
Document including 2 words lower bound is -2.1371062358108395
Document including 3 words lower bound is -3.686223473186113
Document including 2 words lower bound is -3.3346684591683733
Document including 2 words lower bound is -2.1150982552883884
Document including 3 words lower bound is -3.6512067889964412
Document including 2 words lower bound is -3.347927190887998
Document inclu

  0%|          | 0/32 [00:00<?, ?it/s]

Document including 2 words lower bound is -3.215278534907895
Document including 2 words lower bound is -2.116666246239774
Document including 3 words lower bound is -3.662625766642684
Document including 2 words lower bound is -3.208196888220442
Document including 2 words lower bound is -2.1095845995523206
Document including 3 words lower bound is -3.653738159720632
Document including 2 words lower bound is -3.205353566495617
Document including 2 words lower bound is -2.1067412778274957
Document including 3 words lower bound is -3.649975614970856
Document including 2 words lower bound is -3.2038169161045005
Document including 2 words lower bound is -2.105204627436386
Document including 3 words lower bound is -3.6478834633467088
Document including 2 words lower bound is -3.20285434405087
Document including 2 words lower bound is -2.1042420553827483
Document including 3 words lower bound is -3.6465492846628464
Document including 2 words lower bound is -3.202194836482377
Document including 

  0%|          | 0/32 [00:00<?, ?it/s]

Document including 2 words lower bound is -2.8517781777166484
Document including 2 words lower bound is -2.5611003757759825
Document including 3 words lower bound is -3.8215438658258996
Document including 2 words lower bound is -2.7736414387818176
Document including 2 words lower bound is -2.524838636881717
Document including 3 words lower bound is -3.468630170170888
Document including 2 words lower bound is -2.748914045051574
Document including 2 words lower bound is -2.5266129972883564
Document including 3 words lower bound is -3.3049493663708187
Document including 2 words lower bound is -2.7315037843872285
Document including 2 words lower bound is -2.538547408021601
Document including 3 words lower bound is -3.2239083864286826
Document including 2 words lower bound is -2.709310781619098
Document including 2 words lower bound is -2.5531770999989702
Document including 3 words lower bound is -3.1840490666218573
Document including 2 words lower bound is -2.6865398885990324
Document incl

  0%|          | 0/32 [00:00<?, ?it/s]

Document including 2 words lower bound is -3.3774335159610267
Document including 2 words lower bound is -2.330138945932318
Document including 3 words lower bound is -3.9477660241578585
Document including 2 words lower bound is -3.2903015014598695
Document including 2 words lower bound is -2.248727196727491
Document including 3 words lower bound is -3.8561428855896986
Document including 2 words lower bound is -3.2448627852818372
Document including 2 words lower bound is -2.20666889929036
Document including 3 words lower bound is -3.813116579763804
Document including 2 words lower bound is -3.2163763933498983
Document including 2 words lower bound is -2.1806516299724756
Document including 3 words lower bound is -3.791015231709051
Document including 2 words lower bound is -3.197383219623104
Document including 2 words lower bound is -2.1630764595207608
Document including 3 words lower bound is -3.7787186966063073
Document including 2 words lower bound is -3.184545555238503
Document includi

  0%|          | 0/32 [00:00<?, ?it/s]

Document including 2 words lower bound is -2.875158592269494
Document including 2 words lower bound is -2.3031095185836663
Document including 3 words lower bound is -4.107757002579093
Document including 2 words lower bound is -2.6944769017296726
Document including 2 words lower bound is -2.3982754150343535
Document including 3 words lower bound is -3.8488421446912593
Document including 2 words lower bound is -2.6539911136331678
Document including 2 words lower bound is -2.5206307337534923
Document including 3 words lower bound is -3.6050612260791572
Document including 2 words lower bound is -2.654439250307832
Document including 2 words lower bound is -2.599132300025861
Document including 3 words lower bound is -3.4430667668035153
Document including 2 words lower bound is -2.6568247573004835
Document including 2 words lower bound is -2.631887465284324
Document including 3 words lower bound is -3.3447562481259934
Document including 2 words lower bound is -2.652218576916159
Document inclu

  0%|          | 0/32 [00:00<?, ?it/s]

Document including 2 words lower bound is -3.68572819668403
Document including 2 words lower bound is -2.297969394501009
Document including 3 words lower bound is -4.293073630235876
Document including 2 words lower bound is -3.486824323021275
Document including 2 words lower bound is -2.236070994036539
Document including 3 words lower bound is -4.0271750461142295
Document including 2 words lower bound is -3.3925211613027795
Document including 2 words lower bound is -2.1957291463591764
Document including 3 words lower bound is -3.9044633897516543
Document including 2 words lower bound is -3.341104674595426
Document including 2 words lower bound is -2.1716462608156086
Document including 3 words lower bound is -3.8383103871713007
Document including 2 words lower bound is -3.3092604607370593
Document including 2 words lower bound is -2.156645637258963
Document including 3 words lower bound is -3.7977897395273974
Document including 2 words lower bound is -3.2876678797992027
Document includi

  0%|          | 0/32 [00:00<?, ?it/s]

Document including 2 words lower bound is -3.2421565073206042
Document including 2 words lower bound is -2.1435428716709057
Document including 3 words lower bound is -3.699004930864136
Document including 2 words lower bound is -3.224845650427255
Document including 2 words lower bound is -2.126233156394205
Document including 3 words lower bound is -3.677023714151785
Document including 2 words lower bound is -3.2173979708010823
Document including 2 words lower bound is -2.11878557027024
Document including 3 words lower bound is -3.6671108540955455
Document including 2 words lower bound is -3.2132473930467427
Document including 2 words lower bound is -2.114635033377498
Document including 3 words lower bound is -3.661440139705841
Document including 2 words lower bound is -3.2106011084488966
Document including 2 words lower bound is -2.111988770430351
Document including 3 words lower bound is -3.6577637228652304
Document including 2 words lower bound is -3.2087670041657606
Document includin

  0%|          | 0/32 [00:00<?, ?it/s]

Document including 2 words lower bound is -3.0432117130555736
Document including 2 words lower bound is -2.4289769663690626
Document including 3 words lower bound is -3.6360882886477937
Document including 2 words lower bound is -2.941935342241417
Document including 2 words lower bound is -2.426521898210202
Document including 3 words lower bound is -3.4274858119831504
Document including 2 words lower bound is -2.858021077871938
Document including 2 words lower bound is -2.4341013389090804
Document including 3 words lower bound is -3.340632613313452
Document including 2 words lower bound is -2.8001403244781
Document including 2 words lower bound is -2.449163571476123
Document including 3 words lower bound is -3.2917222201360645
Document including 2 words lower bound is -2.7597186573555743
Document including 2 words lower bound is -2.468368095102153
Document including 3 words lower bound is -3.257942473768438
Document including 2 words lower bound is -2.7299423900349247
Document including

  0%|          | 0/32 [00:00<?, ?it/s]

Document including 2 words lower bound is -3.20519034948164
Document including 2 words lower bound is -2.1065780608135043
Document including 3 words lower bound is -3.6492022858869575
Document including 2 words lower bound is -3.2024425367127094
Document including 2 words lower bound is -2.103830248044602
Document including 3 words lower bound is -3.6457468469985486
Document including 2 words lower bound is -3.201326080565906
Document including 2 words lower bound is -2.1027137918977985
Document including 3 words lower bound is -3.6442680097626834
Document including 2 words lower bound is -3.2007195706430878
Document including 2 words lower bound is -2.1021072819749804
Document including 3 words lower bound is -3.6434417652668003
Document including 2 words lower bound is -3.2003385232692096
Document including 2 words lower bound is -2.1017262346011023
Document including 3 words lower bound is -3.6429134076693543
Document including 2 words lower bound is -3.2000769456876412
Document inc

  0%|          | 0/32 [00:00<?, ?it/s]

Document including 2 words lower bound is -3.2106288833599237
Document including 2 words lower bound is -2.1120165946918092
Document including 3 words lower bound is -3.656577043615883
Document including 2 words lower bound is -3.2058400972731746
Document including 2 words lower bound is -2.10722780860506
Document including 3 words lower bound is -3.6505022367481246
Document including 2 words lower bound is -3.203792087355936
Document including 2 words lower bound is -2.1051797986878213
Document including 3 words lower bound is -3.6477777040447066
Document including 2 words lower bound is -3.202653785193341
Document including 2 words lower bound is -2.1040414965252197
Document including 3 words lower bound is -3.6462229613804595
Document including 2 words lower bound is -3.2019291522548556
Document including 2 words lower bound is -2.103316863586741
Document including 3 words lower bound is -3.645216441722397
Document including 2 words lower bound is -3.2014274126166242
Document includ

  0%|          | 0/32 [00:00<?, ?it/s]

Document including 2 words lower bound is nan
Document including 2 words lower bound is nan
Document including 3 words lower bound is nan
Document including 2 words lower bound is nan
Document including 2 words lower bound is nan
Document including 3 words lower bound is nan


AssertionError: 