In [1]:
%matplotlib inline
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("../") # go to parent dir

import numpy as np
import torch
from torch import nn, optim
import matplotlib.pyplot as plt

from itertools import product

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Step 1: Generate the true class balance to be recovered

In [2]:
K = 5

# Generate the true class balance to be recovered
p_Y = np.random.random(K)
p_Y /= p_Y.sum()
p_Y

array([0.27229227, 0.30204649, 0.02238635, 0.13577854, 0.26749635])

### Step 2: Generate the true conditional probability tables (CPTs) for the LFs

Separate simple process here to keep simple (later merge this with the SPA generator).
Generate in terms of the _conditional accuracies_ (which is equivalent to the recall...):
$$
\alpha_{i,y',y} = P(\lambda_i = y' | Y = y)
$$

Note that this table should be normalized such that:
$$
\sum_{y'} \alpha_{i,y',y} = 1
$$

In [3]:
M = 25
alphas = []
for i in range(M):
    a = np.random.random((K,K))
    alphas.append( a @ np.diag(1 / a.sum(axis=0)) )
alpha = np.array(alphas)

assert np.all(np.abs(alpha.sum(axis=1) - 1) < 1e-5)
alpha[0]

array([[0.07319756, 0.02651444, 0.09248776, 0.04260887, 0.23961709],
       [0.26413011, 0.26523537, 0.32580072, 0.13382714, 0.16266584],
       [0.4914127 , 0.31934363, 0.3653866 , 0.25301543, 0.26904373],
       [0.12198711, 0.30777312, 0.10216684, 0.2963796 , 0.11455197],
       [0.04927252, 0.08113344, 0.11415808, 0.27416897, 0.21412137]])

### Step 3: Generate the _three-way_ overlaps tensor $O$ of conditionally-independent LFs

Now we can directly generate $O$.
By our conditional independence assumption, we have:
$$
P(\lambda_i = y', \lambda_j = y'' | Y = y) = \alpha_{i,y',y} \alpha_{j,y'',y}
$$

Thus we have:
$$
O_{i,j,y',y''} = \sum_y P(Y=y) \alpha_{i,y',y} \alpha_{j,y'',y}
$$

In [4]:
# Compute mask
mask = torch.ones((M,M,M,K,K,K)).byte()
for i, j, k in product(range(M), repeat=3):
    if len(set((i,j,k))) < 3:
        mask[i,j,k,:,:,:] = 0

In [5]:
%%time
O = np.einsum('aby,cdy,efy,y->acebdf', alpha, alpha, alpha, p_Y)
O = torch.from_numpy(O).float()
O[1-mask] = 0

CPU times: user 365 ms, sys: 20.5 ms, total: 386 ms
Wall time: 108 ms


In [6]:
# Compute observed labeling rates
O_l = torch.from_numpy(np.einsum('aby,y->ab', alpha, p_Y)).float()

### Step 4: Try to recover $O_\Omega = \left[ Q \otimes Q \otimes Q \right]_\Omega$

Where $\Omega$ is all the non-masked entries, and $Q = A P^{\frac13}$.

In [7]:
def get_loss(Q, O):
    
    # Main constraint: match empirical three-way overlaps matrix (entries O_ijk for i != j != k)
    loss_1 = torch.norm((O - torch.einsum('aby,cdy,efy->acebdf', [Q,Q,Q]))[mask])**2
    
    # Col-wise stochastic: \sum_y' P(\lf=y'|Y=y) = 1.0
    # loss_2 = torch.norm(torch.sum(A, 1) - 1)**2
    
    # Row-wise constraint: match observed labeling rates P(\lf=y') = \sum_y P(Y=y) P(\lf=y'|Y=y)
    # loss_3 = torch.norm(O_l - torch.einsum('aby,y->ab', [A,P]))**2
    
    # Pairwise observed: match empirical pairwise overlaps matrix (entries O_ij for i != j)
    # loss_4 = torch.norm((O_2 - torch.einsum('aby,cdy,y->acbd', [A,A,P]))[mask_2])**2
    
    # return loss_1 + loss_2 + loss_3 # + loss_4
    return loss_1

def train_model_lbfgs(Q, O, n_epochs=10, lr=1, print_every=1):
    optimizer = optim.LBFGS([Q], lr=lr)
    
    for epoch in range(n_epochs):        
        def closure():
            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass to calculate outputs
            loss = get_loss(Q, O)

            # Backward pass to calculate gradients
            loss.backward()

            # Report progress
            if epoch % print_every == 0:
                msg = f"[E:{epoch}]\tLoss: {loss.detach():.8f}"
                print(msg)
            
            return loss

        # Perform optimizer step
        optimizer.step(closure)

In [8]:
Q = nn.Parameter(torch.from_numpy(np.random.rand(M, K, K)).float()).float()
train_model_lbfgs(Q, O, n_epochs=15, print_every=5)

[E:0]	Loss: 899771.75000000
[E:0]	Loss: 882881.81250000
[E:0]	Loss: 240359.60937500
[E:0]	Loss: 112120.60156250
[E:0]	Loss: 42985.32031250
[E:0]	Loss: 17680.87500000
[E:0]	Loss: 6948.71923828
[E:0]	Loss: 2704.19873047
[E:0]	Loss: 1026.13830566
[E:0]	Loss: 386.93591309
[E:0]	Loss: 149.85525513
[E:0]	Loss: 62.92713547
[E:0]	Loss: 28.17135620
[E:0]	Loss: 12.22554111
[E:0]	Loss: 6.93255424
[E:0]	Loss: 5.35690165
[E:0]	Loss: 4.30577278
[E:0]	Loss: 3.88412380
[E:0]	Loss: 3.05224180
[E:0]	Loss: 3.87679529
[E:5]	Loss: 0.00611107
[E:5]	Loss: 0.00607124
[E:5]	Loss: 0.00602758
[E:5]	Loss: 0.00597767
[E:5]	Loss: 0.00592760
[E:5]	Loss: 0.00586420
[E:5]	Loss: 0.00579483
[E:5]	Loss: 0.00574548
[E:5]	Loss: 0.00572724
[E:5]	Loss: 0.00569788
[E:5]	Loss: 0.00567550
[E:5]	Loss: 0.00561180
[E:5]	Loss: 0.00554335
[E:5]	Loss: 0.00541598
[E:5]	Loss: 0.00536661
[E:5]	Loss: 0.00508872
[E:5]	Loss: 0.00497434
[E:5]	Loss: 0.00479429
[E:5]	Loss: 0.00456464
[E:5]	Loss: 0.00467708
[E:10]	Loss: 0.00001572
[E:10]	Loss:

In [9]:
p_Y

array([0.27229227, 0.30204649, 0.02238635, 0.13577854, 0.26749635])

In [10]:
p_Y_est = Q[3].sum(0) ** 3
p_Y_est

tensor([0.1358, 0.3020, 0.2723, 0.2675, 0.0224], grad_fn=<PowBackward0>)

In [11]:
p_Y_est.sum()

tensor(1.0000, grad_fn=<SumBackward0>)