In [1]:
import numpy as np
from scipy.stats import uniform, expon, lognorm, weibull_min, chi2, t, gumbel_r

from src.learner import LinearMDCRL
from experiments.rand import rand_model
from experiments.utils import plot_hist_noise

## Create synthetic data

In [58]:
# Simulation settings: 4 joint latents (present in every domain) plus one
# domain-specific latent per domain -> 7 latents, one noise distribution each.
model_specs = {
    "nr_doms": 3,
    "joint_idx": [0,1,2,3],
    "domain_spec_idx": [[4],[5],[6]],
    # Every entry is a *frozen* scipy distribution (gumbel_r is instantiated
    # like the others instead of passing the generic distribution object).
    "noise_rvs": [uniform(), expon(scale=0.1), t(df=5), gumbel_r(),
                  lognorm(s=1), weibull_min(c=2), chi2(df=6)],
    "sample_sizes":  [5000, 5000, 5000],
    "dims": [13,14,15],  # observed dimension per domain (13+14+15 = 42 mixing rows)
    "graph_density": 0.75,
    "mixing_density": 0.9,
    "mixing_distribution": 'unif',  # unif or normal
    "indep_domain_spec": True
}

In [59]:
# Sample a random model from the specs: `data` is passed to the learner below,
# `g` is the latent graph (compared via g.to_amat() later) and `B_large` the
# ground-truth mixing matrix.
# NOTE(review): no random seed is set in this notebook, so each run may draw a
# different model — TODO confirm whether rand_model seeds internally.
data, g, B_large = rand_model(model_specs)
# TODO: can g be plotted with causaldag?

## Fit model

In [60]:
# Fit the linear multi-domain learner on the sampled data.
model = LinearMDCRL()
model.fit(data)

# Optional diagnostic (disabled): presumably plots histograms of the estimated
# independent components.
#plot_hist_noise(model.indep_comps)

In [61]:
# Presumably one entry per detected joint latent, listing the matched component
# index in each of the 3 domains — TODO confirm against LinearMDCRL.
model.joint_factors

[[1, 0, 0], [2, 3, 4], [3, 4, 3], [4, 1, 1]]

In [62]:
# Expected (sum(dims), number of latents) = (13+14+15, 4 joint + 3 domain-specific).
model.joint_mixing.shape

(42, 7)

In [63]:
# Estimated joint mixing matrix. Compared with the ground truth B_large, its
# columns appear to be permuted and sign-flipped (identifiable only up to that).
model.joint_mixing.round(2)

array([[ 0.37, -0.26, -0.85, -0.36,  0.  ,  0.  ,  0.  ],
       [-0.  , -0.01,  0.67,  0.  , -0.01,  0.  ,  0.  ],
       [ 0.01,  0.8 ,  0.06,  0.69, -0.85,  0.  ,  0.  ],
       [ 0.  ,  0.01, -0.54,  0.56, -0.91,  0.  ,  0.  ],
       [-0.01,  0.7 , -0.  ,  1.35,  0.91,  0.  ,  0.  ],
       [-0.55,  0.74,  2.24, -0.74, -0.53,  0.  ,  0.  ],
       [ 0.88, -0.28, -1.3 , -0.14, -0.42,  0.  ,  0.  ],
       [ 0.68,  0.48,  0.78,  0.2 ,  0.3 ,  0.  ,  0.  ],
       [-0.68,  0.01,  0.68, -0.48, -0.29,  0.  ,  0.  ],
       [-0.81, -0.74,  0.52, -1.75, -0.61,  0.  ,  0.  ],
       [-0.44,  0.57,  0.61, -0.54, -0.63,  0.  ,  0.  ],
       [ 0.31, -0.75, -1.5 ,  0.06, -0.84,  0.  ,  0.  ],
       [ 0.98, -0.87, -1.95,  0.88, -0.86,  0.  ,  0.  ],
       [ 0.36, -0.49, -0.85,  0.89,  0.  ,  1.03,  0.  ],
       [-0.59, -0.46,  0.53, -1.24,  0.  , -0.72,  0.  ],
       [-0.95,  0.51,  0.6 ,  0.11,  0.  ,  0.76,  0.  ],
       [ 0.08, -0.01,  0.73, -0.75,  0.  , -0.71,  0.  ],
       [-0.74,

In [64]:
# Ground-truth mixing matrix returned by rand_model.
B_large.round(2)

array([[-0.85, -0.36,  0.27,  0.37,  0.  ,  0.  ,  0.  ],
       [ 0.68,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ],
       [ 0.06,  0.67, -0.79,  0.  , -0.83,  0.  ,  0.  ],
       [-0.53,  0.55,  0.  ,  0.  , -0.9 ,  0.  ,  0.  ],
       [-0.01,  1.35, -0.72,  0.  ,  0.86,  0.  ,  0.  ],
       [ 2.23, -0.75, -0.73, -0.56, -0.46,  0.  ,  0.  ],
       [-1.29, -0.14,  0.28,  0.87, -0.42,  0.  ,  0.  ],
       [ 0.78,  0.19, -0.5 ,  0.68,  0.3 ,  0.  ,  0.  ],
       [ 0.68, -0.48,  0.  , -0.68, -0.26,  0.  ,  0.  ],
       [ 0.52, -1.73,  0.77, -0.81, -0.55,  0.  ,  0.  ],
       [ 0.59, -0.55, -0.55, -0.45, -0.59,  0.  ,  0.  ],
       [-1.47,  0.06,  0.76,  0.3 , -0.84,  0.  ,  0.  ],
       [-1.91,  0.87,  0.86,  0.98, -0.88,  0.  ,  0.  ],
       [-0.84,  0.88,  0.49,  0.47,  0.  , -0.97,  0.  ],
       [ 0.51, -1.21,  0.44, -0.69,  0.  ,  0.63,  0.  ],
       [ 0.58,  0.14, -0.48, -0.93,  0.  , -0.82,  0.  ],
       [ 0.73, -0.75,  0.  ,  0.  ,  0.  ,  0.69,  0.  ],
       [ 0.82,

## Recover graph (in infinite data limit)

In [65]:
# Keep only the columns of the joint latents (joint_idx = [0, 1, 2, 3]) and
# shuffle them with a fixed permutation to mimic the column-permutation
# indeterminacy of the estimated mixing matrix.
B = B_large[:, :4][:, [2, 0, 3, 1]]

In [66]:
def get_row_ranks(matrix):
    """Rank of every pair of rows of ``matrix``.

    Returns a square float array ``ranks`` with ``ranks[i, j]`` equal to the
    matrix rank of rows ``i`` and ``j`` stacked together, for ``i < j``.
    The diagonal and the lower triangle are left at 0.
    """
    n_rows = matrix.shape[0]
    ranks = np.zeros((n_rows, n_rows))
    upper_a, upper_b = np.triu_indices(n_rows, k=1)
    for a, b in zip(upper_a, upper_b):
        ranks[a, b] = np.linalg.matrix_rank(matrix[[a, b], :])
    return ranks

In [67]:
# Find pure children of all (different) latent nodes: a row that has rank 1
# together with some later row is linearly dependent on it.
R = get_row_ranks(B)
pos_rows = np.unique(np.nonzero(R == 1)[0])

# Remove duplicates: among the candidates, drop the second row of every
# rank-1 pair so that each latent keeps a single representative.
to_remove_idx = np.nonzero(get_row_ranks(B[pos_rows, :]) == 1)[1]
mask = np.full(len(pos_rows), True)
mask[to_remove_idx] = False
pos_rows = pos_rows[mask]
B = B[pos_rows, :]
B

array([[ 0.        ,  0.6763524 ,  0.        ,  0.        ],
       [-0.79331175,  0.05577118,  0.        ,  0.6699705 ],
       [ 0.        , -0.53428523,  0.        ,  0.54870805],
       [ 0.        ,  0.67546024, -0.68068828, -0.47825125]])

In [68]:
# Remove the permutation indeterminacies of the rows and columns
def single_support_rows(matrix, atol=0.0):
    """Find the rows of ``matrix`` that have exactly one non-negligible entry.

    Args:
        matrix: 2-D array.
        atol: entries with absolute value <= ``atol`` count as zero. The default
            0.0 requires exact zeros, which suits the noiseless matrix.

    Returns:
        (row_ids, max_ids): parallel lists holding the index of each
        single-support row and the column of its only non-negligible entry.
    """
    row_ids = []
    max_ids = []
    for i, abs_row in enumerate(np.abs(np.asarray(matrix))):
        m = abs_row.argmax()
        # An all-zero row has no support at all. Without this check argmax would
        # report column 0 and the row would look like a pure child of column 0.
        if abs_row[m] <= atol:
            continue
        if np.all(np.delete(abs_row, m) <= atol):
            row_ids.append(i)
            max_ids.append(m)
    return row_ids, max_ids

def search_order(matrix):
    """Greedily peel off single-support rows of an exact (noiseless) square matrix.

    At every step the first row reported by ``single_support_rows`` is taken,
    together with the column of its largest-magnitude entry. Both are recorded as
    indices into the original matrix and removed, and the search continues on the
    reduced matrix.

    Returns:
        (order_rows, order_cols): lists of original row and column indices in the
        order they were peeled, or (None, None) if some step found no
        single-support row.
    """
    n_rows = matrix.shape[0]
    row_labels = np.arange(n_rows)
    col_labels = np.arange(n_rows)
    order_rows, order_cols = [], []
    while matrix.shape[0] > 0:
        candidate_rows, candidate_cols = single_support_rows(matrix)
        if not candidate_rows:
            break
        r, c = candidate_rows[0], candidate_cols[0]
        order_rows.append(row_labels[r])
        order_cols.append(col_labels[c])
        row_labels = np.delete(row_labels, r)
        col_labels = np.delete(col_labels, c)
        # Continue on the matrix without row r and column c
        matrix = np.delete(np.delete(matrix, r, axis=0), c, axis=1)
    if len(order_rows) != n_rows:
        return None, None
    return order_rows, order_cols

In [69]:
order_rows, order_cols = search_order(B)
if order_rows is not None:
    # Permute rows and columns together into the peeled (lower-triangular) order.
    B = B[np.ix_(order_rows, order_cols)]

In [70]:
# Remove the scaling indeterminacy (rescale each row so the diagonal is 1) and solve for A
np.eye(B.shape[0]) - np.linalg.inv(np.diag(1 / np.diag(B)) @ B).round(2)

array([[ 0.  ,  0.  ,  0.  ,  0.  ],
       [-0.97,  0.  ,  0.  ,  0.  ],
       [-0.89, -0.84,  0.  ,  0.  ],
       [-0.31,  0.7 ,  0.  ,  0.  ]])

In [71]:
# Check
# Ground truth for comparison: transposed adjacency matrix of the sampled graph g.
np.transpose(g.to_amat()).round(2)

array([[ 0.  ,  0.  ,  0.  ,  0.  ],
       [-0.97,  0.  ,  0.  ,  0.  ],
       [-0.89, -0.84,  0.  ,  0.  ],
       [-0.31,  0.7 ,  0.  ,  0.  ]])

In [72]:
## Matches the ground truth up to a relabeling of the latent nodes that preserves the causal order, etc.

**TODO:**
- Implement a routine to check the recovered joint mixing matrix.
- How can graph recovery be implemented for real data? (difficult)
- With the current implementation, G can end up with zero rows.

## Recover noisy graph

In [225]:
# Number of joint latents found by the model (later cells also read this global `l`).
# The first l columns of the estimated mixing matrix are taken as its joint part;
# in the output above, the remaining columns look domain-specific.
l = len(model.joint_factors)
B_hat = model.joint_mixing[:,:l]

In [226]:
def score_rows(matrix):
    """Score how close every pair of rows of ``matrix`` is to being rank 1.

    For rows ``i < j`` the score is the share of the largest squared singular
    value of the 2-row submatrix: 1.0 means the rows are exactly linearly
    dependent, and 0.5 means two orthogonal rows of equal norm.
    Only the strict upper triangle is filled; the diagonal and lower triangle
    stay 0.

    Returns:
        Square float array of pair scores.
    """
    d = matrix.shape[0]
    R = np.zeros((d, d))
    for i in range(d):
        for j in range(i + 1, d):
            s2 = np.linalg.svd(matrix[[i, j], :], compute_uv=False) ** 2
            total = s2.sum()
            # Two all-zero rows would give 0/0 = nan. argsort places nan last, so
            # after flipping it would outrank every genuine pair; score it 0 instead.
            R[i, j] = s2.max() / total if total > 0 else 0.0
    return R

In [227]:
R = score_rows(B_hat)

d = R.shape[0]
nr_tuples = d * (d - 1) // 2

# score_rows fills only the strict upper triangle; the rest are 0 placeholders.
# Rank exactly those nr_tuples row pairs, highest score (most rank-1-like) first.
# Previously, slicing the full flattened sort with [nr_tuples:] also kept d placeholder pairs.
upper_rows, upper_cols = np.triu_indices(d, k=1)
by_score = np.argsort(R[upper_rows, upper_cols])[::-1]
ord_tup = (upper_rows[by_score], upper_cols[by_score])

In [218]:
# We have 2*l candidate pure children (l tuples of two rows each). We have to make sure that they are pure children of different latent nodes.

# Checks:
# - 2*l (here 8) distinct children, i.e. no row appears in two tuples
# - two children of different nodes should have rank 2 (pair score p(W) < level)
# - If a two-children tuple fails one of the above checks:
#    remove it from the candidate list and add the two-children tuple with the next-highest p(W)

In [219]:
def get_duplicates(cand_ids, ord_tup):
    """Find a candidate tuple that shares a child row with an earlier candidate.

    Args:
        cand_ids: indices into ``ord_tup`` of the current candidate tuples.
        ord_tup: pair of arrays (first rows, second rows) of the ranked tuples.

    Returns:
        The position ``j`` within ``cand_ids`` of the later tuple in the first
        overlapping pair (i < j), or None if all children are distinct.
    """
    ids0 = np.asarray(ord_tup[0])[cand_ids]
    ids1 = np.asarray(ord_tup[1])[cand_ids]
    n_cand = len(cand_ids)
    for i in range(n_cand):
        children_i = {ids0[i], ids1[i]}
        for j in range(i + 1, n_cand):
            # Compare both children of tuple i with both children of tuple j.
            if children_i & {ids0[j], ids1[j]}:
                return j
    return None

In [220]:
def _pair_score(R, a, b):
    """Score of the row pair (a, b) from the upper-triangular score matrix ``R``."""
    return R[a, b] if a < b else R[b, a]


def get_low_rank(cand_ids, ord_tup, R, level=0.95):
    """Find a candidate tuple whose children look rank 1 with those of an earlier tuple.

    Children of *different* latent nodes should span rank 2, i.e. their pair
    score should not exceed ``level``. For every pair of candidate tuples
    (i < j), all four cross pairs of their rows are checked.

    Args:
        cand_ids: indices into ``ord_tup`` of the current candidate tuples.
        ord_tup: pair of arrays (first rows, second rows) of the ranked tuples.
        R: pair-score matrix as produced by ``score_rows``. Only ``R[i, j]`` with
            ``i < j`` is filled, so lookups are symmetrized.
        level: score above which a cross pair counts as rank 1.

    Returns:
        The position ``j`` within ``cand_ids`` of the first offending tuple, or
        None if every cross pair passes.
    """
    ids0 = np.asarray(ord_tup[0])[cand_ids]
    ids1 = np.asarray(ord_tup[1])[cand_ids]
    n_cand = len(cand_ids)
    for i in range(n_cand):
        for j in range(i + 1, n_cand):
            for a in (ids0[i], ids1[i]):
                for b in (ids0[j], ids1[j]):
                    if _pair_score(R, a, b) > level:
                        return j
    return None

In [221]:
def update_cand_ids(cand_ids, to_remove, maxim=None):
    """Replace one candidate tuple with the next-best unused tuple.

    Args:
        cand_ids: indices into ``ord_tup`` of the current candidates (array or list).
        to_remove: *position* within ``cand_ids`` of the tuple to drop, as returned
            by ``get_duplicates`` / ``get_low_rank``.
        maxim: optional number of available tuples. Replacement indices must stay
            below it.

    Returns:
        np.ndarray of the updated candidate indices, with the new index appended last.

    Raises:
        ValueError: if ``maxim`` is given and no further tuple is available.
    """
    cand_ids = np.asarray(cand_ids)
    # Tuples are ranked by score, so the next unused one follows the current maximum.
    next_id = cand_ids.max() + 1
    if maxim is not None and next_id >= maxim:
        raise ValueError("no candidate tuples left to replace the rejected one")
    return np.append(np.delete(cand_ids, to_remove), next_id)

In [222]:
# Start from the l best-scoring tuples and keep swapping out offending ones
# until all children are distinct and the tuples are mutually rank 2.
cand_ids = np.arange(l)

while cand_ids.max() < nr_tuples:
    to_remove = get_duplicates(cand_ids, ord_tup)
    if to_remove is None:
        to_remove = get_low_rank(cand_ids, ord_tup, R, level=0.95)
    if to_remove is None:
        break
    cand_ids = update_cand_ids(cand_ids, to_remove, maxim=nr_tuples)

In [223]:
# One representative row per selected tuple: the first row index of each candidate pair.
pure_children_rows = ord_tup[0][cand_ids]
pure_children_rows

array([1, 2, 8, 3])

In [230]:
# Restrict the estimated joint mixing matrix to the selected pure-children rows.
B_star = B_hat[pure_children_rows,:]
B_star

array([[-0.00093562, -0.00989883,  0.67478332,  0.00213392],
       [ 0.01051012,  0.79948927,  0.06119898,  0.69093332],
       [-0.67994694,  0.01304396,  0.6823765 , -0.48023394],
       [ 0.00452868,  0.007207  , -0.54421053,  0.5616751 ]])

In [238]:
def l2_without_max(x):
    """L2 norm of the 1-D array ``x`` after dropping its largest-magnitude entry.

    On ties, the first largest-magnitude entry is the one dropped.
    """
    largest = np.argmax(np.abs(x))
    return np.linalg.norm(np.delete(x, largest))

def search_order_noisy(matrix):
    """Greedy row/column ordering for a noisy, approximately single-support-peelable matrix.

    At every step, each remaining row is scored by the L2 norm of its entries
    excluding its largest-magnitude one. The row with the smallest score (the row
    closest to having a single support entry) is taken, together with the column
    of its largest-magnitude entry. Both are recorded as indices into the original
    matrix and removed before the next step.

    Returns:
        (order_rows, order_cols): lists of original indices, or (None, None) if
        fewer than ``matrix.shape[0]`` rows were ordered.
    """
    n_rows = matrix.shape[0]
    row_labels = np.arange(n_rows)
    col_labels = np.arange(n_rows)
    order_rows, order_cols = [], []
    while matrix.shape[0] > 0:
        off_max_norms = [np.linalg.norm(np.delete(row, np.abs(row).argmax()))
                         for row in matrix]
        r = np.argmin(off_max_norms)
        c = np.abs(matrix[r, :]).argmax()
        order_rows.append(row_labels[r])
        order_cols.append(col_labels[c])
        row_labels = np.delete(row_labels, r)
        col_labels = np.delete(col_labels, c)
        # Continue on the matrix without row r and column c
        matrix = np.delete(np.delete(matrix, r, axis=0), c, axis=1)
    if len(order_rows) != n_rows:
        return None, None
    return order_rows, order_cols

In [242]:
order_rows, order_cols = search_order_noisy(B_star)
if order_rows is not None:
    # Permute rows and columns together into the greedily peeled order
    # (approximately lower triangular).
    B_star = B_star[np.ix_(order_rows, order_cols)]
B_star

array([[ 0.67478332,  0.00213392, -0.00989883, -0.00093562],
       [-0.54421053,  0.5616751 ,  0.007207  ,  0.00452868],
       [ 0.06119898,  0.69093332,  0.79948927,  0.01051012],
       [ 0.6823765 , -0.48023394,  0.01304396, -0.67994694]])

In [243]:
# Remove the scaling indeterminacy (rescale each row so the diagonal is 1) and solve for A
np.eye(B_star.shape[0]) - np.linalg.inv(np.diag(1 / np.diag(B_star)) @ B_star).round(2)

array([[ 0.02,  0.02, -0.01,  0.  ],
       [-0.96,  0.  ,  0.  ,  0.01],
       [ 0.91,  0.85,  0.  ,  0.01],
       [-0.29,  0.74, -0.03, -0.01]])

In [244]:
# Check
# Ground truth for comparison: transposed adjacency matrix of the sampled graph g.
np.transpose(g.to_amat()).round(2)

array([[ 0.  ,  0.  ,  0.  ,  0.  ],
       [-0.97,  0.  ,  0.  ,  0.  ],
       [-0.89, -0.84,  0.  ,  0.  ],
       [-0.31,  0.7 ,  0.  ,  0.  ]])