In [1]:
from data_gen_utils import *
import sys
sys.path.append("../") # go to parent dir

import numpy as np
import torch
import matplotlib.pyplot as plt
import mpmath

from synthetic.generate import SingleTaskTreeDepsGenerator
from metal.label_model import LabelModel, CliqueTree
from metal.label_model.utils import (
    compute_mu,
    compute_covariance,
    compute_inv_covariance,
    print_matrix,
    visualize_matrix
)

  return f(*args, **kwds)


In [2]:
# Experiment configuration.
M = 6       # first DataGenerator argument; dg.L prints with 6 rows below
n = 10_000  # second DataGenerator argument
K = 1       # forwarded to compute_mu and used to build np.full(K, 1/K) later on
dg = DataGenerator(M, n)

In [3]:
# Have the generator build its O matrix with k=3.
# NOTE(review): the run below reports three dependency pairs; presumably k
# controls that count — confirm in DataGenerator.generate_O.
dg.generate_O(k=3)

In [4]:
# Show the dependency pairs held by the generator.
print(dg.deps)

[(0, 1), (2, 3), (4, 5)]


In [5]:
# Show dg.deps_no_diag. Nothing is removed in this cell; it only prints.
# For this run the list is identical to dg.deps
# (presumably deps with diagonal/self pairs dropped — confirm in DataGenerator).
print(dg.deps_no_diag)

[(0, 1), (2, 3), (4, 5)]


In [6]:
print(dg.L)

# Binarize the label matrix: strictly positive entries become 1.0 and every
# other entry (the -1 votes printed above, and any zeros) becomes 0.0.
# A single vectorized comparison replaces the former element-by-element
# Python loop; the result is a new float array, dg.L itself is not modified.
w, h = dg.L.shape  # kept because later cells may still read w and h
L = (np.asarray(dg.L) > 0).astype(float)

[[-1.  1. -1. ...  1.  1. -1.]
 [-1.  1. -1. ...  1.  1. -1.]
 [ 1.  1. -1. ...  1.  1.  1.]
 [-1.  1. -1. ...  1. -1. -1.]
 [-1.  1. -1. ...  1. -1. -1.]
 [-1.  1. -1. ...  1.  1. -1.]]


In [7]:
# Inspect the binarized 0/1 matrix next to the dependency pairs.
print(L)
print(dg.deps_no_diag)

[[0. 1. 0. ... 1. 1. 0.]
 [0. 1. 0. ... 1. 1. 0.]
 [1. 1. 0. ... 1. 1. 1.]
 [0. 1. 0. ... 1. 0. 0.]
 [0. 1. 0. ... 1. 0. 0.]
 [0. 1. 0. ... 1. 1. 0.]]
[(0, 1), (2, 3), (4, 5)]


In [8]:
# Create the label model with k=1 and set two config flags on it.
# NOTE(review): 'higher_order_cliques' is set to False here, but the train()
# call in the following cell passes higher_order_cliques=True — confirm which
# value LabelModel actually uses.
lm = LabelModel(k=1)
lm.config['higher_order_cliques'] = False
lm.config['verbose'] = True

In [9]:
# Fit the label model on the transposed label matrix, using the dependency
# pairs reported by the generator. One keyword per line keeps the
# hyper-parameters easy to scan and to diff.
lm.train(
    L=dg.L.T,
    deps=dg.deps_no_diag,
    all_unary_cliques=True,
    higher_order_cliques=True,
    n_epochs=50000,
    print_every=5000,
    lr=1e-6,
    l2=0.1,
    O_inv_prec=1024,
)

I GOT HERE
I GOT HERE 2
I GOT HERE 3
INIT 1
GET CLIQUE TREE 1
GET CLIQUE TREE 2
GET CLIQUE TREE 3
GET CLIQUE TREE 4
GET CLIQUE TREE 5
INIT 2
CLIQUE TREE 1
START INDEX:  9
I GOT HERE 4
L AUG SAHPE:  (10000, 9)
ALL UNARY CLIQUES
HIGHER ORDER CLIQUES
(9, 9)
[[0.4998 0.2885 0.3099 0.2368 0.3244 0.2472 0.2243 0.2885 0.2081]
 [0.2885 0.3146 0.2224 0.189  0.2395 0.201  0.1816 0.2885 0.1672]
 [0.3099 0.2224 0.4995 0.2641 0.3052 0.2249 0.2038 0.204  0.2641]
 [0.2368 0.189  0.2641 0.2988 0.2298 0.1931 0.1749 0.1735 0.2641]
 [0.3244 0.2395 0.3052 0.2298 0.4977 0.2857 0.2857 0.2181 0.2034]
 [0.2472 0.201  0.2249 0.1931 0.2857 0.3158 0.2857 0.1839 0.1702]
 [0.2243 0.1816 0.2038 0.1749 0.2857 0.2857 0.2857 0.166  0.1549]
 [0.2885 0.2885 0.204  0.1735 0.2181 0.1839 0.166  0.2885 0.1531]
 [0.2081 0.1672 0.2641 0.2641 0.2034 0.1702 0.1549 0.1531 0.2641]]
Computing O^{-1}...
L AUG SAHPE:  (10000, 9)
ALL UNARY CLIQUES
HIGHER ORDER CLIQUES
O unnorm   [4998.0  2885.0  3099.0  2368.0  3244.0  2472.0  2243.0

In [10]:
# Call lm._set_constants on the label matrix, then build the reference mu from
# the augmented label matrix and the labels dg.Y, using a uniform
# np.full(K, 1/K) vector as the last compute_mu argument.
lm._set_constants(dg.L.T)
mu = compute_mu(lm._get_augmented_label_matrix(dg.L.T), dg.Y.T, K, np.full(K, 1/K))

# Test against the true parameter values: first over every entry of the
# learned mu (all rows of the augmented matrix).
mu_est = lm.mu.detach().numpy()
print(mu_est)
print(f"Average absolute error: {np.mean(np.abs(mu_est - mu))}")

# Then over the first M entries only, compared with the generator's own mu.
# M replaces the former literal 6 so this slice follows the
# DataGenerator(M, n) configuration instead of silently going stale.
# NOTE(review): assumes dg.mu broadcasts element-wise against mu_est[:M]
# (e.g. shape (M, 1)) — confirm, otherwise the mean is taken over a broadcast grid.
mu_est_sm = mu_est[:M]
mu_true = dg.mu
print(f"Average absolute error: {np.mean(np.abs(mu_est_sm - mu_true))}")

L AUG SAHPE:  (10000, 9)
ALL UNARY CLIQUES
HIGHER ORDER CLIQUES
[[0.5592211 ]
 [0.4006851 ]
 [0.5318669 ]
 [0.3888934 ]
 [0.5676346 ]
 [0.42128956]
 [0.39408115]
 [0.34634444]
 [0.34323525]]
Average absolute error: 0.2100596065610008
Average absolute error: 0.11663740215301516


In [11]:
def search_params(lr_params, l2_params, L, deps, mu_true,
                  n_sources=6, n_epochs=50000, print_every=5000):
    """Grid-search the learning rate and L2 value used to train a LabelModel.

    A fresh ``LabelModel(k=1)`` is trained for every ``(lr, l2)`` pair. After
    training, the first ``n_sources`` entries of the learned ``mu`` are
    compared with ``mu_true`` and the mean absolute error is printed and
    recorded.

    Args:
        lr_params: iterable of learning rates to try (outer loop).
        l2_params: iterable of L2 values to try (inner loop); it is traversed
            once per learning rate, so pass a re-iterable such as a list.
        L: label matrix forwarded to ``LabelModel.train`` and
            ``LabelModel._set_constants``.
        deps: dependency pairs forwarded to ``LabelModel.train``.
        mu_true: reference parameters; must broadcast against
            ``mu_est[:n_sources]``.
        n_sources: number of leading entries of the learned ``mu`` to compare.
            Defaults to 6, the value that was previously hard-coded.
        n_epochs: number of training epochs per configuration (default 50000,
            the previously hard-coded value).
        print_every: logging interval forwarded to ``train`` (default 5000,
            the previously hard-coded value).

    Returns:
        dict mapping each ``(lr, l2)`` tuple to its mean absolute error.
    """
    results_dict = {}
    for lr in lr_params:
        for l2 in l2_params:
            lm = LabelModel(k=1)
            # NOTE(review): the config flag is False while train() below is
            # given higher_order_cliques=True — confirm which one takes effect.
            lm.config['higher_order_cliques'] = False
            lm.config['verbose'] = True

            lm.train(
                L=L,
                deps=deps,
                all_unary_cliques=True,
                higher_order_cliques=True,
                n_epochs=n_epochs,
                print_every=print_every,
                lr=lr,
                l2=l2,
                O_inv_prec=1024,
            )

            lm._set_constants(L)
            mu_est = lm.mu.detach().numpy()

            # Compute the metric once so the printed and stored values can
            # never diverge.
            mu_est_sm = mu_est[:n_sources]
            avg_abs_err = np.mean(np.abs(mu_est_sm - mu_true))
            print(f"Average absolute error: {avg_abs_err}")
            results_dict[(lr, l2)] = avg_abs_err
    return results_dict

In [12]:
# Sweep 4 learning rates x 3 L2 values; every combination is a full training
# run inside search_params, so this cell is slow.
search_res = search_params(
    lr_params=[1e-8, 1e-7, 1e-6, 1e-5],
    l2_params=[0, 0.1, 0.2],
    L=dg.L.T,
    deps=dg.deps_no_diag,
    mu_true=dg.mu,
)

I GOT HERE
I GOT HERE 2
I GOT HERE 3
INIT 1
GET CLIQUE TREE 1
GET CLIQUE TREE 2
GET CLIQUE TREE 3
GET CLIQUE TREE 4
GET CLIQUE TREE 5
INIT 2
CLIQUE TREE 1
START INDEX:  9
I GOT HERE 4
L AUG SAHPE:  (10000, 9)
ALL UNARY CLIQUES
HIGHER ORDER CLIQUES
(9, 9)
[[0.4998 0.2885 0.3099 0.2368 0.3244 0.2472 0.2243 0.2885 0.2081]
 [0.2885 0.3146 0.2224 0.189  0.2395 0.201  0.1816 0.2885 0.1672]
 [0.3099 0.2224 0.4995 0.2641 0.3052 0.2249 0.2038 0.204  0.2641]
 [0.2368 0.189  0.2641 0.2988 0.2298 0.1931 0.1749 0.1735 0.2641]
 [0.3244 0.2395 0.3052 0.2298 0.4977 0.2857 0.2857 0.2181 0.2034]
 [0.2472 0.201  0.2249 0.1931 0.2857 0.3158 0.2857 0.1839 0.1702]
 [0.2243 0.1816 0.2038 0.1749 0.2857 0.2857 0.2857 0.166  0.1549]
 [0.2885 0.2885 0.204  0.1735 0.2181 0.1839 0.166  0.2885 0.1531]
 [0.2081 0.1672 0.2641 0.2641 0.2034 0.1702 0.1549 0.1531 0.2641]]
Computing O^{-1}...
L AUG SAHPE:  (10000, 9)
ALL UNARY CLIQUES
HIGHER ORDER CLIQUES
O unnorm   [4998.0  2885.0  3099.0  2368.0  3244.0  2472.0  2243.0

[Epoch 11000] Loss: 24.709484
[Epoch 12000] Loss: 24.529804
[Epoch 13000] Loss: 24.351599
[Epoch 14000] Loss: 24.175398
[Epoch 15000] Loss: 24.000193
[Epoch 16000] Loss: 23.825985
[Epoch 17000] Loss: 23.656216
[Epoch 18000] Loss: 23.488647
[Epoch 19000] Loss: 23.322557
[Epoch 20000] Loss: 23.159206
[Epoch 21000] Loss: 22.996990
[Epoch 22000] Loss: 22.835669
[Epoch 23000] Loss: 22.675251
[Epoch 24000] Loss: 22.516064
[Epoch 25000] Loss: 22.358868
[Epoch 26000] Loss: 22.206434
[Epoch 27000] Loss: 22.055223
[Epoch 28000] Loss: 21.904957
[Epoch 29000] Loss: 21.756050
[Epoch 30000] Loss: 21.607939
[Epoch 31000] Loss: 21.460779
[Epoch 32000] Loss: 21.315912
[Epoch 33000] Loss: 21.172890
[Epoch 34000] Loss: 21.031158
[Epoch 35000] Loss: 20.890253
[Epoch 36000] Loss: 20.752766
[Epoch 37000] Loss: 20.615992
[Epoch 38000] Loss: 20.480452
[Epoch 39000] Loss: 20.346531
[Epoch 40000] Loss: 20.213730
[Epoch 41000] Loss: 20.082296
[Epoch 42000] Loss: 19.951916
[Epoch 43000] Loss: 19.822586
[Epoch 440

[Epoch 24000] Loss: 47.140713
[Epoch 25000] Loss: 45.178799
[Epoch 26000] Loss: 43.292202
[Epoch 27000] Loss: 41.480431
[Epoch 28000] Loss: 39.744484
[Epoch 29000] Loss: 38.080231
[Epoch 30000] Loss: 36.486584
[Epoch 31000] Loss: 34.963924
[Epoch 32000] Loss: 33.510208
[Epoch 33000] Loss: 32.122761
[Epoch 34000] Loss: 30.800802
[Epoch 35000] Loss: 29.543221
[Epoch 36000] Loss: 28.346935
[Epoch 37000] Loss: 27.210602
[Epoch 38000] Loss: 26.132212
[Epoch 39000] Loss: 25.109831
[Epoch 40000] Loss: 24.141449
[Epoch 41000] Loss: 23.224850
[Epoch 42000] Loss: 22.358112
[Epoch 43000] Loss: 21.539097
[Epoch 44000] Loss: 20.765676
[Epoch 45000] Loss: 20.035957
[Epoch 46000] Loss: 19.347834
[Epoch 47000] Loss: 18.699251
[Epoch 48000] Loss: 18.088417
[Epoch 49000] Loss: 17.513191
Estimating \mu...
[Epoch 0] Loss: 21.315523
[Epoch 1000] Loss: 20.287907
[Epoch 2000] Loss: 19.340706
[Epoch 3000] Loss: 18.473360
[Epoch 4000] Loss: 17.677042
[Epoch 5000] Loss: 16.944067
[Epoch 6000] Loss: 16.267794
[E

[Epoch 1000] Loss: 163.276215
[Epoch 2000] Loss: 161.306549
[Epoch 3000] Loss: 159.397949
[Epoch 4000] Loss: 157.538162
[Epoch 5000] Loss: 155.717224
[Epoch 6000] Loss: 153.925400
[Epoch 7000] Loss: 152.154297
[Epoch 8000] Loss: 150.396545
[Epoch 9000] Loss: 148.645187
[Epoch 10000] Loss: 146.894119
[Epoch 11000] Loss: 145.137604
[Epoch 12000] Loss: 143.370712
[Epoch 13000] Loss: 141.588882
[Epoch 14000] Loss: 139.788101
[Epoch 15000] Loss: 137.964691
[Epoch 16000] Loss: 136.115295
[Epoch 17000] Loss: 134.237350
[Epoch 18000] Loss: 132.328339
[Epoch 19000] Loss: 130.385986
[Epoch 20000] Loss: 128.408920
[Epoch 21000] Loss: 126.396248
[Epoch 22000] Loss: 124.346581
[Epoch 23000] Loss: 122.259697
[Epoch 24000] Loss: 120.134766
[Epoch 25000] Loss: 117.973503
[Epoch 26000] Loss: 115.776184
[Epoch 27000] Loss: 113.543442
[Epoch 28000] Loss: 111.276276
[Epoch 29000] Loss: 108.976158
[Epoch 30000] Loss: 106.646225
[Epoch 31000] Loss: 104.288956
[Epoch 32000] Loss: 101.906731
[Epoch 33000] Los

[Epoch 1000] Loss: 140.470673
[Epoch 2000] Loss: 114.660782
[Epoch 3000] Loss: 81.681046
[Epoch 4000] Loss: 49.600975
[Epoch 5000] Loss: 27.768671
[Epoch 6000] Loss: 17.117510
[Epoch 7000] Loss: 12.915236
[Epoch 8000] Loss: 11.349799
[Epoch 9000] Loss: 10.727670
[Epoch 10000] Loss: 10.448638
[Epoch 11000] Loss: 10.307049
[Epoch 12000] Loss: 10.227310
[Epoch 13000] Loss: 10.178443
[Epoch 14000] Loss: 10.146449
[Epoch 15000] Loss: 10.124404
[Epoch 16000] Loss: 10.108603
[Epoch 17000] Loss: 10.096969
[Epoch 18000] Loss: 10.088205
[Epoch 19000] Loss: 10.081504
[Epoch 20000] Loss: 10.076336
[Epoch 21000] Loss: 10.072303
[Epoch 22000] Loss: 10.069139
[Epoch 23000] Loss: 10.066656
[Epoch 24000] Loss: 10.064693
[Epoch 25000] Loss: 10.063140
[Epoch 26000] Loss: 10.061905
[Epoch 27000] Loss: 10.060921
[Epoch 28000] Loss: 10.060143
[Epoch 29000] Loss: 10.059522
[Epoch 30000] Loss: 10.059024
[Epoch 31000] Loss: 10.058627
[Epoch 32000] Loss: 10.058308
[Epoch 33000] Loss: 10.058060
[Epoch 34000] Los

[Epoch 1000] Loss: 26.524721
[Epoch 2000] Loss: 8.739755
[Epoch 3000] Loss: 8.492949
[Epoch 4000] Loss: 8.470882
[Epoch 5000] Loss: 8.468443
[Epoch 6000] Loss: 8.468159
[Epoch 7000] Loss: 8.468125
[Epoch 8000] Loss: 8.468122
[Epoch 9000] Loss: 8.468122
[Epoch 10000] Loss: 8.468122
[Epoch 11000] Loss: 8.468122
[Epoch 12000] Loss: 8.468122
[Epoch 13000] Loss: 8.468122
[Epoch 14000] Loss: 8.468122
[Epoch 15000] Loss: 8.468122
[Epoch 16000] Loss: 8.468122
[Epoch 17000] Loss: 8.468122
[Epoch 18000] Loss: 8.468122
[Epoch 19000] Loss: 8.468122
[Epoch 20000] Loss: 8.468122
[Epoch 21000] Loss: 8.468122
[Epoch 22000] Loss: 8.468122
[Epoch 23000] Loss: 8.468122
[Epoch 24000] Loss: 8.468122
[Epoch 25000] Loss: 8.468122
[Epoch 26000] Loss: 8.468122
[Epoch 27000] Loss: 8.468122
[Epoch 28000] Loss: 8.468122
[Epoch 29000] Loss: 8.468122
[Epoch 30000] Loss: 8.468122
[Epoch 31000] Loss: 8.468122
[Epoch 32000] Loss: 8.468122
[Epoch 33000] Loss: 8.468122
[Epoch 34000] Loss: 8.468122
[Epoch 35000] Loss: 8.

[Epoch 1000] Loss: 15.468861
[Epoch 2000] Loss: 11.684264
[Epoch 3000] Loss: 11.639183
[Epoch 4000] Loss: 11.635347
[Epoch 5000] Loss: 11.634979
[Epoch 6000] Loss: 11.634944
[Epoch 7000] Loss: 11.634941
[Epoch 8000] Loss: 11.634941
[Epoch 9000] Loss: 11.634941
[Epoch 10000] Loss: 11.634941
[Epoch 11000] Loss: 11.634941
[Epoch 12000] Loss: 11.634941
[Epoch 13000] Loss: 11.634941
[Epoch 14000] Loss: 11.634941
[Epoch 15000] Loss: 11.634941
[Epoch 16000] Loss: 11.634941
[Epoch 17000] Loss: 11.634941
[Epoch 18000] Loss: 11.634941
[Epoch 19000] Loss: 11.634941
[Epoch 20000] Loss: 11.634941
[Epoch 21000] Loss: 11.634941
[Epoch 22000] Loss: 11.634941
[Epoch 23000] Loss: 11.634941
[Epoch 24000] Loss: 11.634941
[Epoch 25000] Loss: 11.634941
[Epoch 26000] Loss: 11.634941
[Epoch 27000] Loss: 11.634941
[Epoch 28000] Loss: 11.634941
[Epoch 29000] Loss: 11.634941
[Epoch 30000] Loss: 11.634941
[Epoch 31000] Loss: 11.634941
[Epoch 32000] Loss: 11.634941
[Epoch 33000] Loss: 11.634941
[Epoch 34000] Loss:

In [14]:
# Mean absolute error recorded by search_params for every (lr, l2) pair tried.
print(search_res)

{(1e-08, 0): 0.1721224686053064, (1e-08, 0.1): 0.2677678901089563, (1e-08, 0.2): 0.24661573973761663, (1e-07, 0): 0.28300384945207174, (1e-07, 0.1): 0.14493701347510018, (1e-07, 0.2): 0.2973745848966969, (1e-06, 0): 0.1129008631123437, (1e-06, 0.1): 0.11561035883426668, (1e-06, 0.2): 0.11723962264855706, (1e-05, 0): 0.11548006538814971, (1e-05, 0.1): 0.11554513213634493, (1e-05, 0.2): 0.11560258935822383}


### Changing learning rate

In [13]:
# Baseline: run solveMatrixCompletionWithMu on the generator's O, Oinv and
# deps, and report its error against mu_true with the same metric as above.
# NOTE(review): this star import sits mid-notebook, and the cell relies on
# mu_true, which is assigned in an earlier cell (mu_true = dg.mu) — confirm the
# notebook still runs top-to-bottom on a fresh kernel.
from amc_utils import *
O = dg.O
O_inv = dg.Oinv
real_deps = dg.deps
mu_rec = solveMatrixCompletionWithMu(O_inv,O,real_deps)
print(f"Average absolute error: {np.mean(np.abs(mu_rec - mu_true))}")

Average absolute error: 0.04020939583686347
