In [220]:
import itertools
import json
import numpy as np
from hmmlearn import hmm
from models import StandardHMM, DenseHMM
import time
from tqdm import tqdm
import ssm
from ssm.util import find_permutation

np.random.seed(2022)

In [None]:
# ssm_hmm = ssm.HMM(n, 1, observations="categorical")
# ssm_hmm.observations.__init__(K=n, D=1, C=n_features)
# ssm_hmm.fit(X_true)
#
# most_likely_states = ssm_hmm.most_likely_states(X_true)
# ssm_hmm.permute(find_permutation(Y_true, most_likely_states))
# print(ssm_hmm.transitions.transition_matrix.round(2))
#
# (ssm_hmm.most_likely_states(X_true) == Y_true).mean()

In [248]:
def prepare_params(n, v, A_stat=False):
    pi = np.random.uniform(size=n)
    pi /= pi.sum()
    if A_stat:
        A = np.repeat(pi[np.newaxis,:], n, axis=0)
    else:
        A = np.exp(np.random.uniform(0, 5, size=(n, n)))
        A /= A.sum(axis=1)[:, np.newaxis]

    B = np.exp(np.random.uniform(0, 5, size=(n,  v)))
    B /= B.sum(axis=1)[:, np.newaxis]
    return pi, A, B


def my_hmm_sampler(pi, A, B, T):
    n = pi.shape[0]
    v = B.shape[1]
    X = [np.random.choice(np.arange(n), 1, replace=True, p=pi)]
    for t in range(T - 1):
        X.append(np.random.choice(np.arange(n), 1, replace=True, p=A[X[t][0], :]))
    Y = np.concatenate([np.random.choice(np.arange(v), 1, replace=True, p=B[s[0], :]) for s in X]).reshape(-1, 1)
    return X, Y


def experiment(n, m, T, s,  l, A_stat=False):
    pi, A, B = prepare_params(n, m, A_stat)
    data = [my_hmm_sampler(pi, A, B, T) for _ in range(s)]
    X_true = np.concatenate([x[1] for x in data])
    lenghts = [len(x[1]) for x in data]
    Y_true = np.concatenate([np.concatenate(y[0]) for y in data])
    assert(np.unique(X_true).shape[0] == m)

    standard_acc = []
    standard_timer = []
    dense_acc = []
    dense_timer = []
    hmml_acc = []
    hmml_timer = []
    for _ in tqdm(range(10), desc="HMM"):
        A_init = np.exp(np.random.uniform(0, 5, size=(n, n)))
        A_init /= A_init.sum(axis=1)[:, np.newaxis]


        start = time.perf_counter()
        standardhmm = StandardHMM(n, em_iter=1000, init_params="se")
        standardhmm.transmat_ = A_init
        standardhmm.fit(X_true, lenghts)
        preds = np.concatenate([standardhmm.predict(x[1].transpose()) for x in data])
        standard_timer.append(time.perf_counter() - start)
        perm = find_permutation(preds, Y_true)
        standard_acc.append((Y_true == np.array([perm[i] for i in preds])).mean())

        start = time.perf_counter()
        densehmm = DenseHMM(n, init_params="se",  mstep_config={"l_uz":  l, "l_vw": l})
        densehmm.transmat_ = A_init
        densehmm.fit_coocs(X_true, lenghts)
        preds = np.concatenate([densehmm.predict(x[1].transpose()) for x in data])
        perm = find_permutation(preds, Y_true)
        dense_timer.append(time.perf_counter() - start)
        dense_acc.append((Y_true == np.array([perm[i] for i in preds])).mean())

        start = time.perf_counter()
        hmml = hmm.MultinomialHMM(n, n_iter=1000, init_params="se")
        hmml.transmat_ = A_init
        hmml.fit(X_true, lenghts)
        preds = np.concatenate([hmml.predict(x[1].transpose()) for x in data])
        hmml_timer.append(time.perf_counter() - start)
        perm = find_permutation(preds, Y_true)
        hmml_acc.append((Y_true == np.array([perm[i] for i in preds])).mean())

    return {"standard_acc": standard_acc,
            "standard_time":  standard_timer,
            # "ssm_acc": ssm_acc,
            # "ssm_time":  ssm_time,
            "dense_acc": dense_acc,
            "dense_time": dense_timer,
            "hmml_acc": hmml_acc,
            "hmml_time": hmml_timer,
            "pi": pi,
            "A": A,
            "B": B}


def run_experiments():
    results = []
    for  n, v, T, A_stat, l in itertools.product([2, 3, 4, 8],  [5, 10, 20], [10, 100, 1000],  [True, False],  [3,  4,  5]):
        tmp = experiment(n, v, T, 100, l, A_stat)
        results.append({**tmp, "n": n,  "v":  v, "T": T,  "A_stat": A_stat})
    return results

In [249]:
result = run_experiments()
with open("experiment_result_22-07-22.json",  "w")  as f:
    json.dump(result, f)

HMM:   0%|                                                                                                                                                                          | 0/10 [00:00<?, ?it/s]

0.20235856643234928
0.053939594127372134
0.04690433490996932
0.0406934301685718
0.03336657687176476
0.03069774025192095
0.03003195774254358
0.02979450270023473
0.029687173338327304
0.0296255843632079


HMM:  10%|████████████████▏                                                                                                                                                 | 1/10 [00:33<05:04, 33.86s/it]

0.04512350601379992
0.014762474376040328
0.0069644784019549785
0.004302200418771004
0.003571564524894417
0.0034739793572691616
0.0034998113945890258
0.0035333342831578643
0.0035844001341636136
0.0036624148834138934


HMM:  20%|████████████████████████████████▍                                                                                                                                 | 2/10 [01:08<04:32, 34.05s/it]

0.1972514787521108
0.12252647581932401
0.048229217242516796
0.006220090553376371
0.005585169660128122
0.0060053114721734776
0.006115295729514705
0.0060832983732341354
0.006005763616652883
0.005912589987429742


HMM:  30%|████████████████████████████████████████████████▌                                                                                                                 | 3/10 [01:42<03:58, 34.06s/it]

0.1315919587570004
0.04300558412054491
0.027498160805133343
0.016284305605237272
0.008574771484931035
0.005882578547320635
0.0049322385282688785
0.004520831804576473
0.004324762742548756
0.004223608638037702


HMM:  40%|████████████████████████████████████████████████████████████████▊                                                                                                 | 4/10 [02:17<03:28, 34.75s/it]

0.18857238492109527
0.038630471029434635
0.0343858022078314
0.031617242607847325
0.029673458005374513
0.028798412856573346
0.028424464043500908
0.028175974366673493
0.027932872515083404
0.027660232335527273


HMM:  50%|█████████████████████████████████████████████████████████████████████████████████                                                                                 | 5/10 [02:52<02:53, 34.62s/it]

0.2568858373885619
0.025788032003362278
0.024689435171368557
0.022248501381250622
0.01769905586374042
0.011569457911611807
0.007303324124262249
0.005516829698914518
0.004698972921812307
0.004190112582507422


HMM:  60%|█████████████████████████████████████████████████████████████████████████████████████████████████▏                                                                | 6/10 [03:25<02:16, 34.23s/it]

0.03622408279441812
0.014927187028952224
0.010247168823193764
0.008657030225672893
0.008100459655934068
0.007874796646809084
0.007747397974801344
0.007633477389407362
0.007502884646081007
0.007347444917278283


HMM:  70%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████▍                                                | 7/10 [03:59<01:41, 33.99s/it]

0.0640635834120254
0.0437265533032421
0.038937250487974924
0.03540435524548484
0.03164342223812075
0.027503281307357096
0.023029067945542735
0.018226512557996486
0.013382952048284965
0.009406940943126366


HMM:  80%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌                                | 8/10 [04:33<01:07, 34.00s/it]

0.18411012935306975
0.04157192221847166
0.0304920763620747
0.018484422355626418
0.011308786685208913
0.008228328809571884
0.006968271400628219
0.006191103818686073
0.005616624988193396
0.005225687570888993


HMM:  90%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▊                | 9/10 [05:07<00:34, 34.11s/it]

0.30854883946552303
0.15397397634052956
0.03019149611413962
0.013709848806634535
0.009067207231412116
0.006877119852628459
0.00589915964045131
0.005456207395112094
0.005230278231821815
0.005094978363994758


HMM: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 10/10 [05:43<00:00, 34.39s/it]


{'standard_acc': [0.58, 0.59, 0.62, 0.57, 0.55, 0.54, 0.49, 0.59, 0.57, 0.57],
 'standard_time': [23.75400410295697,
  24.110998097981792,
  23.848815653007478,
  24.314705462020356,
  23.425166082044598,
  23.61325781099731,
  23.402705926971976,
  23.651125258009415,
  23.950039413000923,
  23.670056675968226],
 'dense_acc': [0.33, 0.43, 0.41, 0.4, 0.35, 0.55, 0.41, 0.49, 0.47, 0.48],
 'dense_time': [9.981763622025028,
  9.901767604984343,
  10.11854359798599,
  11.41895427799318,
  10.84140859998297,
  9.703426597989164,
  9.966126161045395,
  10.19770968699595,
  10.275200691015925,
  12.456446995958686],
 'hmml_acc': [0.47, 0.44, 0.52, 0.41, 0.44, 0.41, 0.38, 0.48, 0.44, 0.45],
 'hmml_time': [0.12795152503531426,
  0.16973707097349688,
  0.09941593604162335,
  0.07689894904615358,
  0.1073035430163145,
  0.16037668898934498,
  0.12198578100651503,
  0.1589150340296328,
  0.14459453499875963,
  0.11098437098553404],
 'pi': array([0.00548345, 0.20786194, 0.01437976, 0.37892492, 0.39