In [1]:
import torch
from torch import nn
import numpy as np
import time
import itertools
from typing import Any, DefaultDict, Dict, List, NamedTuple, Optional, Set, Tuple, Union


### Tensor Decomposition Label Model

The implementation contains a few utility functions leading up to the main tensor decomposition method (`tensor_decomp`). Afterwards, we give the label model, which includes a new `fit()` method that calls the helper method `mu_recovery` once per labeling function.

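All page, section, equation, and algorithm references below are to Anandkumar, Ge, Hsu, Kakade, and Telgarsky, "Tensor Decompositions for Learning Latent Variable Models": https://arxiv.org/pdf/1210.7559.pdf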

In [2]:
def multimap(A, V_array):
    """Compute a tensor product as a multilinear map (pg. 2778, Section 2).

    Evaluates

        B[i_1, ..., i_p] = sum over j_1..j_p of
                           A[j_1, ..., j_p] * V_1[j_1, i_1] * ... * V_p[j_p, i_p]

        Parameters
        ----------
        A
            A p-way tensor (one axis per entry of ``V_array``)
        V_array
            Sequence of p arrays to compute the tensor against. A 1-D entry
            is treated as a single column. The sequence and its entries are
            left untouched.

        Returns
        -------
        B
            Float array of shape ``(V_1.shape[1], ..., V_p.shape[1])``; when
            every entry of ``V_array`` is a vector this is all-ones in shape,
            e.g. ``(1, 1, 1)`` for p = 3.

        Raises
        ------
        ValueError
            If the order of ``A`` differs from ``len(V_array)`` or an axis
            length of ``A`` does not match the row count of its factor.

    """
    # Promote vectors to single-column matrices on local copies of the
    # references, so the caller's list is never written to.
    factors = []
    for V in V_array:
        V = np.asarray(V)
        factors.append(V[:, np.newaxis] if V.ndim == 1 else V)

    B = np.asarray(A, dtype=float)
    if B.ndim != len(factors):
        raise ValueError(
            f"A has order {B.ndim} but {len(factors)} factors were given"
        )

    # Contract the leading axis with each factor in turn. tensordot appends
    # the factor's column axis at the end, so after p contractions the axes
    # are ordered i_1, ..., i_p.
    for V in factors:
        B = np.tensordot(B, V, axes=([0], [0]))
    return B

def two_tensor_prod(w, x, y):
    """
    Weighted sum of column-wise outer products.

    Entry ``(i, j)`` of the result accumulates ``w[a] * x[i, a] * y[j, a]``
    over every component ``a``. The result is a square matrix whose side is
    ``x.shape[0]``.
    """
    side = x.shape[0]
    acc = np.zeros([side, side])

    # Component index varies slowest, then the row of x, then the row of y.
    index_triples = itertools.product(
        range(w.shape[0]), range(x.shape[0]), range(y.shape[0])
    )
    for comp, row, col in index_triples:
        acc[row, col] += w[comp] * x[row, comp] * y[col, comp]

    return acc

def three_tensor_prod(w, x, y, z):
    """
    Three-way (weighted) outer product.

    With a 0-d weight ``w`` and vectors ``x``, ``y``, ``z`` the result holds
    ``w * x[i] * y[j] * z[k]``. With a 1-d weight and matrices whose columns
    are components, entry ``(i, j, k)`` sums
    ``w[a] * x[i, a] * y[j, a] * z[k, a]`` over ``a``. The result is a cube
    whose side is ``x.shape[0]``.
    """
    side = x.shape[0]
    acc = np.zeros([side, side, side])
    cells = list(
        itertools.product(range(x.shape[0]), range(y.shape[0]), range(z.shape[0]))
    )

    # Scalar weight: a single rank-one term.
    if len(w.shape) == 0:
        for i, j, k in cells:
            acc[i, j, k] += w * x[i] * y[j] * z[k]
        return acc

    # Vector weight: one rank-one term per component, added in order.
    for comp in range(w.shape[0]):
        for i, j, k in cells:
            acc[i, j, k] += w[comp] * x[i, comp] * y[j, comp] * z[k, comp]
    return acc

def T_map(T, u):
    """ Power method base transformation (pg. 2790, equation (5))

        Computes ``t[i] = sum_{j,k} T[i, j, k] * u[j] * u[k]``.

        Parameters
        ----------
        T
            A three-way tensor, indexed over ``u.shape[0]`` along each axis
        u
            A candidate eigenvector

        Returns
        -------
        t
            Transformed candidate (not normalized)

    """
    dim = u.shape[0]
    out = np.zeros(dim)
    # ndindex walks (i, j, k) in row-major order, i.e. i slowest, k fastest.
    for i, j, k in np.ndindex(dim, dim, dim):
        out[i] += T[i, j, k] * u[j] * u[k]
    return out

def _power_iterate(T, u, n_iters):
    """Apply ``n_iters`` normalized tensor power updates u <- T(I,u,u)/||T(I,u,u)||."""
    for _ in range(n_iters):
        u = T_map(T, u)
        u /= np.linalg.norm(u)
    return u


def tensor_decomp(M2, M3, comps, n_power_iters=10, n_restarts=1000, eig_tol=1e-1):
    """Tensor Decomposition Algorithm (pg. 2795, Algorithm 1)
    This is combined with reduction (4.3.1)

    Components are extracted one at a time. Each pass whitens the current
    (deflated) ``M2``, runs the robust tensor power method on the whitened
    ``M3``, maps the result back to the original space, and deflates both
    moments. Random starting points are drawn from NumPy's global random
    state.

    Parameters
    ----------
    M2
        Symmetric matrix to aid the decomposition (not modified)
    M3
        Symmetric tensor to be decomposed (not modified)
    comps
        Number of eigencomponents to return
    n_power_iters
        Number of power iterations per restart, and again for the final
        polish of the best candidate
    n_restarts
        Number of random restarts per component (must be at least 1)
    eig_tol
        Eigenvalues of ``M2`` at or below this threshold are discarded
        before whitening

    Returns
    -------
    mu_rec
        Recovered eigenvectors (a matrix with #comps eigenvectors); columns
        for components that could not be extracted are left at zero
    lam_rec
        Recovered eigenvalues (a vector with #comps eigenvalues); entries
        for components that could not be extracted are left at zero

    Raises
    ------
    ValueError
        If ``n_restarts`` is smaller than 1.

    """
    if n_restarts < 1:
        raise ValueError("n_restarts must be at least 1")

    # Deflation subtracts from both moments; work on private copies so the
    # caller's arrays survive the call.
    M2 = np.array(M2, dtype=float)
    M3 = np.array(M3, dtype=float)

    lam_rec = np.zeros(comps)
    mu_rec = np.zeros((M2.shape[0], comps))

    for b in range(comps):
        # initial eigendecomposition used in reduction (4.3.1)
        lam, v = np.linalg.eigh(M2)
        order = lam.argsort()[::-1]
        lam, v = lam[order], v[:, order]

        # keep only the eigenvalues above the threshold
        n_eigpos = int(np.sum(lam > eig_tol))
        if n_eigpos == 0:
            # M2 only changes after a successful extraction, so every later
            # pass would see the same spectrum; stop here.
            break

        W = v[:, :n_eigpos] @ np.diag(1.0 / np.sqrt(np.abs(lam[:n_eigpos])))
        B = np.linalg.pinv(W.T)  # maps whitened vectors back to the original space
        M3_tilde = multimap(M3, [W, W, W])  # reduction complete

        # repeated restarts to find the candidate with the largest T(u, u, u)
        tau_star, u_star = None, None
        for _ in range(n_restarts):
            # randomly draw from unit sphere (step 2)
            u = np.random.multivariate_normal(np.zeros(n_eigpos), np.eye(n_eigpos))
            u /= np.linalg.norm(u)
            u = _power_iterate(M3_tilde, u, n_power_iters)

            # T(u, u, u) = <u, T(I, u, u)>; computed once per restart.
            tau = u @ T_map(M3_tilde, u)
            if u_star is None or tau > tau_star:
                tau_star, u_star = tau, u

        # N more power iterations for best eigenvector found
        u = _power_iterate(M3_tilde, u_star, n_power_iters)

        # Recovered (post-reduction) eigenvalue, T(u, u, u) as in Algorithm 1.
        # A coordinate ratio T(I,u,u)[0] / u[0] gives the same number at an
        # exact fixed point but blows up whenever u[0] is close to zero.
        lamb = u @ T_map(M3_tilde, u)
        if not np.isfinite(lamb) or lamb == 0:
            # Degenerate whitened tensor: nothing meaningful to extract, and
            # deflating with NaN/inf would poison M2 and M3.
            break

        # recover original eigenvector and eigenvalue pair
        mu_rec[:, b] = lamb * B @ u
        lam_rec[b] = 1 / lamb**2

        # deflation: remove component, repeat
        M2 -= lam_rec[b] * np.outer(mu_rec[:, b], mu_rec[:, b])
        M3 -= three_tensor_prod(
            np.array(lam_rec[b]), mu_rec[:, b], mu_rec[:, b], mu_rec[:, b]
        )

    return mu_rec, lam_rec

def lowrank(x, k):
    """Rank-``k`` approximation of ``x`` built from its ``k`` largest singular values.

    Returns an array with the same shape and dtype as ``x`` that sums the
    selected singular triplets ``s[i] * outer(u[:, i], vh[i])``.
    """
    left, sing, right_h = np.linalg.svd(x)
    # Indices of singular values from largest to smallest magnitude.
    by_magnitude = np.argsort(np.abs(sing))[::-1]

    approx = np.zeros_like(x)
    for idx in by_magnitude[:k]:
        approx += sing[idx] * np.outer(left[:, idx], right_h[idx])
    return approx

class LabelModel():
    """Label model whose conditional vote tables are fit by tensor decomposition.

    After ``fit``, ``mu_numpy`` has shape ``(m * cardinality, cardinality)``.
    Rows ``i * cardinality + v`` hold the recovered entries for labeling
    function ``i`` and vote ``v``; each column is one recovered component
    (the column order is whatever ``tensor_decomp`` returns).
    """

    def __init__(self, cardinality: int = 2, **kwargs: Any) -> None:
        # Extra keyword arguments are accepted and ignored.
        self.cardinality = cardinality

    @staticmethod
    def _one_hot(votes, r):
        """One-hot rows for votes in {-1, ..., r - 2}; abstain (-1) maps to column 0."""
        return np.eye(r)[np.asarray(votes).astype(int) + 1]

    def mu_recovery(self, triplet_idx_a, triplet_idx_b, triplet_idx_c):
        """ Recover mu for a single labeling function (index triplet_idx_c)
            Follows the multi-view models approach (Section 3.3)
            Constructs symmetric matrices / tensors from observed quantities

            The first two views are linearly mapped into the space of the
            third one, so the decomposed moments describe the third view;
            the other two labeling functions only act as helper views.

            Parameters
            ----------
            triplet_idx_a, triplet_idx_b, triplet_idx_c
                Column indices of the three labeling functions; the result
                refers to ``triplet_idx_c``

            Returns
            -------
            mu_rec
                Array of shape ``(r, cardinality)`` as returned by
                ``tensor_decomp``, where ``r`` is 2 in the unipolar case and
                ``cardinality + 1`` otherwise (row 0 is the abstain vote)

        """
        r = 2 if self.unipolar else self.cardinality + 1  # +1 for abstain
        if self.unipolar:
            # collapse to "abstained" (-1) vs "voted" (0)
            L = (self.L_train > -1).astype(int) - 1
        else:
            L = self.L_train

        # One-hot encodings of the three views, one row per sample.
        X1 = self._one_hot(L[:, triplet_idx_a], r)
        X2 = self._one_hot(L[:, triplet_idx_b], r)
        X3 = self._one_hot(L[:, triplet_idx_c], r)

        # construct the main quantities, empirical two-tensors, pg. 2785
        Ex3x2 = lowrank(X3.T @ X2 / self.n, k=self.cardinality)
        Ex1x2 = lowrank(X1.T @ X2 / self.n, k=self.cardinality)
        Ex3x1 = lowrank(X3.T @ X1 / self.n, k=self.cardinality)
        Ex2x1 = lowrank(X2.T @ X1 / self.n, k=self.cardinality)

        # The view transforms do not depend on the sample, so the
        # pseudo-inverses are computed once instead of once per row.
        to_view3_from_1 = Ex3x2 @ np.linalg.pinv(Ex1x2)
        to_view3_from_2 = Ex3x1 @ np.linalg.pinv(Ex2x1)
        x_tilde_1 = X1 @ to_view3_from_1.T
        x_tilde_2 = X2 @ to_view3_from_2.T

        # symmetrized versions, Theorem 3.6 pg. 2785
        M2 = x_tilde_1.T @ x_tilde_2 / self.n
        M3 = np.einsum("ni,nj,nk->ijk", x_tilde_1, x_tilde_2, X3) / self.n

        # comps: we should recover at most the number of cardinality terms
        mu_rec, lam_rec = tensor_decomp(M2, M3, self.cardinality)
        print(f"got mu_rec, lam_rec = {mu_rec} {lam_rec}")
        return mu_rec

    def fit(
        self,
        L_train: np.ndarray,
    ) -> None:
        """Estimate ``mu_numpy`` from a label matrix.

        Parameters
        ----------
        L_train
            ``(n, m)`` array of votes, with -1 for abstain and class ids
            ``0 .. cardinality - 1`` otherwise

        Raises
        ------
        ValueError
            If there are fewer than 3 labeling functions, or (unipolar case)
            a labeling function never votes or has fewer than two partners
            voting for the same class.
        """
        self.n, self.m = L_train.shape
        if self.m < 3:
            raise ValueError("L_train should have at least 3 labeling functions")

        self.mu_numpy = np.zeros((self.m*self.cardinality, self.cardinality))
        self.L_train = L_train

        # unipolarity affects algorithm; changes the rank of the tensor decomp
        # assume all LFs are either unipolar or all multipolar, so only the
        # first LF is inspected
        polarity = 0
        for a in range(self.cardinality):
            if np.sum(self.L_train[:, 0] == a) > 0:
                polarity += 1

        self.unipolar = polarity < 2

        # partition the LFs based on their unipolar votes
        if self.unipolar:
            self.unipolar_groups, self.unipolar_votes = {}, {}
            for a in range(self.cardinality):
                self.unipolar_groups[a] = []
                for i in range(self.m):
                    if np.sum(self.L_train[:, i] == a) > 0:
                        self.unipolar_groups[a] += [i]
                        self.unipolar_votes[i] = a

        # keep track of overlaps in order to obtain the least noisy triplet:
        # overlaps[i, j] counts samples where both LFs vote and agree
        voted = self.L_train > -1
        overlaps = np.zeros((self.m, self.m))
        for i in range(self.m):
            for j in range(self.m):
                agree = self.L_train[:, i] == self.L_train[:, j]
                overlaps[i, j] = np.sum(voted[:, i] & voted[:, j] & agree)

        for i in range(self.m):
            # select triplets
            if self.unipolar:
                if i not in self.unipolar_votes:
                    raise ValueError(
                        f"labeling function {i} never votes; cannot place it "
                        "in a unipolar group"
                    )
                idxes = self.unipolar_groups[self.unipolar_votes[i]].copy()
            else:
                idxes = list(range(self.m))
            idxes.remove(i)
            if len(idxes) < 2:
                raise ValueError(
                    f"labeling function {i} needs at least two partner "
                    f"labeling functions to form a triplet, found {len(idxes)}"
                )
            idxes = np.random.permutation(idxes)
            b, c = idxes[0], idxes[1]

            # get best overlaps
            best = overlaps[i, b] + overlaps[i, c] + overlaps[b, c]
            for k, l in itertools.combinations(idxes, 2):
                score = overlaps[i, k] + overlaps[k, l] + overlaps[i, l]
                if score > best:
                    b, c, best = k, l, score

            # the LF being estimated goes last: mu_recovery describes view 3
            mu_rec = self.mu_recovery(b, c, i)

            # set the recovered mu component (row 0 of mu_rec is abstain)
            if self.unipolar:
                self.mu_numpy[i*self.cardinality + self.unipolar_votes[i], :] += mu_rec[1, :]
            else:
                self.mu_numpy[i*self.cardinality:(i+1)*self.cardinality, :] += mu_rec[1:, :]
             


### Test recovery for synthetic mus

Generate some simple data below

In [3]:
# simple cardinality 2 case
def generate_synthetic(n, p, mu1, mu2, mu3):
    """Sample binary labels and the votes of three labeling functions.

    Parameters
    ----------
    n
        Number of samples to draw
    p
        Length-2 class prior used to draw each label
    mu1, mu2, mu3
        One ``(3, 2)`` table per labeling function; column ``y`` is the
        distribution over the votes -1, 0, 1 (in row order) given label ``y``

    Returns
    -------
    L
        ``(n, 3)`` float array of votes in {-1, 0, 1}
    y
        Length-``n`` float array of labels in {0, 1}
    """
    y, L = np.zeros(n), np.zeros((n, 3))
    vote_tables = (mu1, mu2, mu3)

    for i in range(n):
        # Draw scalars (no ``size`` argument): storing a length-1 array into
        # a single array element is deprecated in recent NumPy releases.
        label = int(np.random.choice(2, p=p))
        y[i] = label
        for lf, mu in enumerate(vote_tables):
            L[i, lf] = np.random.choice(3, p=mu[:, label]) - 1

    return L, y

In [4]:
#mu_true[0:3,:]

In [5]:
# Vote tables for three LFs (multipolar case), stacked into a single (9, 2)
# array: three rows per LF, one column per class.
mu_true = np.vstack([
    [[0.6, 0.7], [0.3, 0.1], [0.1, 0.2]],
    [[0.55, 0.6], [0.4, 0.1], [0.05, 0.3]],
    [[0.5, 0.7], [0.4, 0.05], [0.1, 0.25]],
])

n = 100
w = np.array([0.3, 0.7])  # class balance

# Split the stack back into the three per-LF tables and draw a small sample.
L, Y = generate_synthetic(n, w, *np.split(mu_true, 3))
Y

array([1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 0., 1., 0.,
       1., 0., 0., 0., 1., 1., 1., 1., 0., 1., 1., 0., 1., 1., 0., 0., 1.,
       1., 1., 1., 0., 0., 0., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 0.,
       1., 1., 1., 1., 1., 0., 1., 0., 1., 1., 1., 1., 0., 0., 0., 1., 1.,
       1., 1., 1., 1., 0., 0., 1., 1., 0., 1., 1., 1., 0., 1., 1., 1., 1.,
       0., 1., 1., 1., 1., 0., 1., 1., 0., 1., 1., 1., 1., 1., 1.])

In [8]:
# Generate samples of increasing size, fit the label model on each, and
# report how far the recovered parameters are from the generating ones.
mu_true = np.zeros((9,2))

# mu's for three LFs: rows 3*i .. 3*i+2 belong to LF i, columns are the class.
# generate_synthetic draws a vote index from column y and subtracts 1, so the
# rows within each group correspond to the votes -1, 0, 1.
# multipolar case
mu_true[0:3,:] = np.array([[0.6,0.7], [0.3, 0.1], [0.1, 0.2]])
mu_true[3:6,:] = np.array([[0.55,0.6], [0.4, 0.1], [0.05, 0.3]])
mu_true[6:9,:] = np.array([[0.5,0.7], [0.4, 0.05], [0.1, 0.25]])
# The string literal below is never used; it keeps an alternative (unipolar)
# parameter set around without executing it.
'''
# unipolar case
mu_true[0:3,:] = np.array([[0.6,0.9], [0.4, 0.1], [0.0, 0.0]])
mu_true[3:6,:] = np.array([[0.55,0.8], [0.45, 0.2], [0.0, 0.0]])
mu_true[6:9,:] = np.array([[0.7,0.95], [0.3, 0.05], [0.0, 0.0]])
'''

# class balance
w = np.array([0.3,0.7])

# sample sizes 10**4, 10**5, 10**6
for ex in range(4,7):
    n = 10**ex
    print(f"working with {n} samples")
    # NOTE(review): start timestamp; nothing in this cell reads it back, so no
    # elapsed time is ever reported.
    a = time.time()
    L, Y = generate_synthetic(n, w,  mu_true[0:3,:], mu_true[3:6,:] , mu_true[6:9,:])
    #L, Y = generate_synthetic(n, w,  mu1, mu2, mu3)

    label_model = LabelModel(cardinality=2)
    label_model.fit(L)

    # parameter recovery error on the mu's: summed absolute difference between
    # row 3*i+1 of mu_true (vote 0 of LF i) and row 2*i of mu_numpy (the first
    # of the two rows fit() writes for LF i).
    # NOTE(review): the vote-1 rows (3*i+2 vs 2*i+1) are not included, and the
    # columns are compared position by position without matching recovered
    # components to classes - confirm this is the intended metric.
    tot_param_err = 0
    for i in range(3):
        tot_param_err += np.sum(np.abs(mu_true[3*i+1,:]-label_model.mu_numpy[2*i,:]))
    #print(label_model.mu_numpy)
    print(f"for {n} samples, recovery error {tot_param_err}")

working with 10000 samples
got mu_rec, lam_rec = [[0.66998758 0.        ]
 [0.15782961 0.        ]
 [0.17192876 0.        ]] [1.00044762 0.        ]
got mu_rec, lam_rec = [[0.58007223 0.        ]
 [0.18071107 0.        ]
 [0.23731416 0.        ]] [0.99933828 0.        ]
got mu_rec, lam_rec = [[0.66474852 0.        ]
 [0.13469038 0.        ]
 [0.21421256 0.        ]] [0.95902947 0.        ]
for 10000 samples, recovery error 0.8767689350675498
