In [165]:
import numpy as np
from scipy.stats import beta, uniform, expon, lognorm, weibull_min, chi2, t, gumbel_r

from src.learner import LinearMDCRL
from experiments.rand import rand_model
from experiments.utils import plot_hist_noise

## Create synthetic data

In [166]:
model_specs = {
    "nr_doms": 3,
    "joint_idx": [0,1,2,3],
    "domain_spec_idx": [[4],[5],[6]],
    "noise_rvs": [beta(2,3), expon(scale=0.1), t(df=5), gumbel_r, 
                  lognorm(s=1), weibull_min(c=2), chi2(df=6)],
    "sample_sizes":  [10000, 10000, 10000],
    "dims": [10,10,10],
    "graph_density": 0.75,
    "mixing_density": 0.9,
    "mixing_distribution": 'unif',  # unif or normal
    "indep_domain_spec": True
}

In [167]:
data, g, B_large = rand_model(model_specs)
# Can we plot g with causaldag?

In [168]:
np.transpose(g.to_amat()).round(2)

array([[ 0.  ,  0.  ,  0.  ,  0.  ],
       [ 0.46,  0.  ,  0.  ,  0.  ],
       [-0.87,  0.  ,  0.  ,  0.  ],
       [-0.41,  0.37,  0.  ,  0.  ]])

## Fit model

In [169]:
model = LinearMDCRL()
model.fit(data)

#plot_hist_noise(model.indep_comps)

In [170]:
model.joint_factors

[[1, 2, 4], [2, 3, 1], [3, 4, 3], [4, 1, 2]]

In [171]:
model.joint_mixing.shape

(30, 7)

In [172]:
model.joint_mixing.round(2)

array([[ 0.4 , -0.01,  0.21, -0.99,  0.  ,  0.  ,  0.  ],
       [-0.46,  0.  ,  0.69,  0.51,  0.34,  0.  ,  0.  ],
       [-0.01,  0.01, -0.6 ,  0.01, -0.72,  0.  ,  0.  ],
       [-0.76,  0.57,  0.38, -0.35,  0.81,  0.  ,  0.  ],
       [ 0.37,  0.01, -0.2 ,  0.01, -0.96,  0.  ,  0.  ],
       [ 0.13,  0.26,  0.62, -0.31,  0.  ,  0.  ,  0.  ],
       [ 0.77,  0.  , -1.03,  0.03, -0.77,  0.  ,  0.  ],
       [ 0.16, -0.96,  0.38, -0.4 , -0.91,  0.  ,  0.  ],
       [-0.93,  0.37,  0.24,  0.54,  0.01,  0.  ,  0.  ],
       [ 0.  , -0.42,  0.37,  0.  ,  0.71,  0.  ,  0.  ],
       [ 0.26, -0.46,  0.4 , -0.7 ,  0.  , -0.81,  0.  ],
       [-0.  ,  0.  , -0.7 ,  0.02,  0.  , -0.44,  0.  ],
       [ 0.9 , -0.99, -1.57, -0.37,  0.  ,  0.04,  0.  ],
       [ 0.97, -0.04, -0.4 , -0.02,  0.  ,  0.94,  0.  ],
       [ 0.01,  0.98,  0.78,  0.02,  0.  , -0.96,  0.  ],
       [ 0.26,  0.  ,  0.14, -0.7 ,  0.  , -0.62,  0.  ],
       [-0.94,  0.72, -0.06,  0.5 ,  0.  ,  0.03,  0.  ],
       [-0.76,

In [173]:
B_large.round(2)

array([[-0.24,  0.37,  0.  ,  0.99,  0.  ,  0.  ,  0.  ],
       [-0.66, -0.45,  0.  , -0.53,  0.33,  0.  ,  0.  ],
       [ 0.59,  0.  ,  0.  ,  0.  , -0.7 ,  0.  ,  0.  ],
       [-0.36, -0.78, -0.57,  0.32,  0.79,  0.  ,  0.  ],
       [ 0.17,  0.37,  0.  ,  0.  , -0.94,  0.  ,  0.  ],
       [-0.63,  0.11, -0.27,  0.3 ,  0.  ,  0.  ,  0.  ],
       [ 1.  ,  0.78,  0.  ,  0.  , -0.75,  0.  ,  0.  ],
       [-0.41,  0.15,  0.95,  0.41, -0.9 ,  0.  ,  0.  ],
       [-0.21, -0.91, -0.35, -0.56,  0.  ,  0.  ,  0.  ],
       [-0.35,  0.  ,  0.41,  0.  ,  0.69,  0.  ,  0.  ],
       [-0.44,  0.27, -0.47,  0.73,  0.  ,  0.76,  0.  ],
       [ 0.67,  0.  ,  0.  ,  0.  ,  0.  ,  0.47,  0.  ],
       [ 1.55,  0.93, -0.95,  0.39,  0.  ,  0.  ,  0.  ],
       [ 0.45,  0.98,  0.  ,  0.  ,  0.  , -0.92,  0.  ],
       [-0.82,  0.  ,  0.94,  0.  ,  0.  ,  0.95,  0.  ],
       [-0.17,  0.26,  0.  ,  0.72,  0.  ,  0.59,  0.  ],
       [ 0.07, -0.95,  0.7 , -0.51,  0.  ,  0.  ,  0.  ],
       [ 0.14,

## Recover graph (in infinite data limit)

In [174]:
B = B_large[:,:4]
B = B[:,[2,0,3,1]]  # add a permutation

In [175]:
def get_row_ranks(matrix):
    d = matrix.shape[0]
    R = np.zeros((d,d))
    for i in range(d):
        for j in range(i+1,d):
            R[i,j] = np.linalg.matrix_rank(matrix[[i,j],:])
    return R

In [176]:
# Find pure children of all (different) latent nodes
R = get_row_ranks(B) 
pos_rows = np.unique(np.where(R==1)[0])

# Remove duplicates
to_remove_idx = np.where(get_row_ranks(B[pos_rows,:])==1)[1]
mask = np.ones(len(pos_rows), dtype=bool)
mask[to_remove_idx] = False
pos_rows = pos_rows[mask]
B = B[pos_rows,:]
B

array([[ 0.        , -0.23616439,  0.99292499,  0.36633738],
       [ 0.        ,  0.58562377,  0.        ,  0.        ],
       [ 0.        ,  0.17103669,  0.        ,  0.36977767],
       [ 0.40650182, -0.35402682,  0.        ,  0.        ]])

In [177]:
# Remove permutation indetermancies of rows and columns
def single_support_rows(matrix):
    d, l = matrix.shape
    row_ids = []
    max_ids = []
    for i in range(d):
        m = np.abs(matrix[i,:]).argmax()
        mask = np.ones(l, dtype=bool)
        mask[m] = False
        if np.sum(np.abs(matrix[i,mask])) == 0:
            row_ids.append(i)
            max_ids.append(m)
    return row_ids, max_ids

def search_order(matrix):
    d = matrix.shape[0]
    order_rows = []
    order_cols = []
    original_idx_rows = np.arange(d)  
    original_idx_cols = np.arange(d)  
    while 0 < matrix.shape[0]:
        # Find rows with all but one element equal to zero
        row_ids, max_ids = single_support_rows(matrix)
        if len(row_ids) == 0:
            break
        target_row = row_ids[0]
        target_col = max_ids[0]
        # Append index to order
        order_rows.append(original_idx_rows[target_row])
        order_cols.append(original_idx_cols[target_col])
        original_idx_rows = np.delete(original_idx_rows, target_row)
        original_idx_cols = np.delete(original_idx_cols, target_col)
        # Remove the row and the column from B
        row_mask = np.delete(np.arange(matrix.shape[0]), target_row)
        col_mask = np.delete(np.arange(matrix.shape[1]), target_col)
        matrix = matrix[row_mask,:][:, col_mask]
    if len(order_rows) != d:
            return None, None
    return order_rows, order_cols

In [178]:
B

array([[ 0.        , -0.23616439,  0.99292499,  0.36633738],
       [ 0.        ,  0.58562377,  0.        ,  0.        ],
       [ 0.        ,  0.17103669,  0.        ,  0.36977767],
       [ 0.40650182, -0.35402682,  0.        ,  0.        ]])

In [179]:
order_rows, order_cols = search_order(B)
if order_rows is not None:
    B = B[order_rows, :][:, order_cols]

In [180]:
# Remove scaling indetermancy and solve for A
np.eye(B.shape[0]) - np.linalg.inv(np.matmul(np.diag(1/np.diag(B)), B)).round(2)

array([[ 0.  ,  0.  ,  0.  ,  0.  ],
       [ 0.46,  0.  ,  0.  ,  0.  ],
       [-0.41,  0.37,  0.  ,  0.  ],
       [-0.87,  0.  ,  0.  ,  0.  ]])

In [181]:
# Check
np.transpose(g.to_amat()).round(2)

array([[ 0.  ,  0.  ,  0.  ,  0.  ],
       [ 0.46,  0.  ,  0.  ,  0.  ],
       [-0.87,  0.  ,  0.  ,  0.  ],
       [-0.41,  0.37,  0.  ,  0.  ]])

In [182]:
## okay up to relabeling that preserves causal order, etc

## Recover noisy graph

In [183]:
l = len(model.joint_factors)
B_hat = model.joint_mixing[:,:l]#B_large[:,:l]#

In [184]:
def score_rows(matrix):
    d = matrix.shape[0]
    R = np.full(shape=(d,d),fill_value=.0)
    for i in range(d):
        for j in range(i+1,d):
            u, s, v = np.linalg.svd(matrix[[i,j],:])
            #s2 = np.sort(s) ** 2
            #R[i,j] = s2[-1] / s2.sum()
            R[i,j] = 1/(min(s)+0000000.1)
    return R

In [185]:
# Having 2*l candidates for the pure children, we have to make sure that they are pure children from different nodes

# Checks: 
# - 8 unique children
# - the rank between two children from different nodes should be 2 (p(W)<level)
# - If a two-children tuple does not satisfy one of the above checks: 
#    Remove it from candidate list. Add two-children tuple with next highest p(W) to candidate list

In [186]:
def get_duplicates(cand_ids, ord_tup):
    l = len(cand_ids)
    ids0 = ord_tup[0][cand_ids]
    ids1 = ord_tup[1][cand_ids]
    uniques = np.unique(np.concatenate((ids0, ids1), axis=0))
    if len(uniques)==(2*l):
        return None
    else:
        for i in range(l):
            for j in range(i+1,l):
                if (ids0[i]==ids0[j]) or (ids0[i]==ids1[j]):
                    return cand_ids[j]

In [187]:
def get_low_rank(cand_ids, ord_tup, R, level=0.95):
    ids0 = ord_tup[0][cand_ids]
    ids1 = ord_tup[1][cand_ids]
    #print(ids0,  ids1)
    for i in range(l):
        for j in range(i+1,l):
            if (R[ids0[i],ids0[j]] > level) or (R[ids0[i],ids1[j]] > level):
                #print(i,j,R[ids0[i],ids0[j]],R[ids0[i],ids1[j]])
                return cand_ids[i], cand_ids[j]
    return None        

In [188]:
def update_cand_ids(cand_ids, to_remove):
    current_max = max(cand_ids)
    cand_ids.remove(to_remove)
    cand_ids.append(current_max+1)
    return(cand_ids)

In [189]:
def get_pure_children(B):
    
    # Score all pairs of rows
    R = score_rows(B)
    d = R.shape[0]
    nr_tuples = int(d * (d-1) / 2)
    
    ord_tup = np.unravel_index(R.ravel().argsort(), R.shape)
    ord_tup = (np.flip(ord_tup[0][nr_tuples:]), 
               np.flip(ord_tup[1][nr_tuples:]))
    
    # Make R symmetric now (important)
    for i in range(d):
        for j in range(i+1,d):
            R[j,i]=R[i,j]
            
    # Choose rows that maximize thw "within-tuple" score and at the same time have a low "inter-tuple" score 
    cand_ids = list(np.arange(l))
    while max(cand_ids) < nr_tuples:
        to_remove = get_duplicates(cand_ids, ord_tup)
        if to_remove is not None:
            cand_ids = update_cand_ids(cand_ids, to_remove)
            continue
        to_remove = get_low_rank(cand_ids, ord_tup, R, level=10)
        if to_remove is not None:
            temp_cands = cand_ids.copy()
            temp_cands = update_cand_ids(temp_cands, to_remove[0])
            if get_low_rank(temp_cands, ord_tup, R, level=10) is None:
                cand_ids = update_cand_ids(cand_ids, to_remove[0])
            else:
                cand_ids = update_cand_ids(cand_ids, to_remove[1])
            continue
        break   
    pure_children_rows = ord_tup[0][cand_ids]
    pure_children_rows
    
    return B[pure_children_rows,:]  

In [190]:
B_star = get_pure_children(B)
B_star.round(2)

array([[ 0.59,  0.  ,  0.  ,  0.  ],
       [ 0.17,  0.37,  0.  ,  0.  ],
       [-0.24,  0.37,  0.99,  0.  ],
       [-0.35,  0.  ,  0.  ,  0.41]])

In [191]:
def l2_without_max(x):
    m = np.abs(x).argmax()
    mask = np.ones(len(x), dtype=bool)
    mask[m] = False
    return np.linalg.norm(x[mask])

def search_order_noisy(matrix):
    d = matrix.shape[0]
    order_rows = []
    order_cols = []
    original_idx_rows = np.arange(d)  
    original_idx_cols = np.arange(d)  
    while 0 < matrix.shape[0]:
        # Find row with lowest l2 norm where all entries but the maximum are considered
        target_row = np.apply_along_axis(l2_without_max, 1, matrix).argmin()
        target_col = abs(matrix[target_row,:]).argmax()
        # Append index to order
        order_rows.append(original_idx_rows[target_row])
        order_cols.append(original_idx_cols[target_col])
        original_idx_rows = np.delete(original_idx_rows, target_row)
        original_idx_cols = np.delete(original_idx_cols, target_col)
        # Remove the row and the column from B
        row_mask = np.delete(np.arange(matrix.shape[0]), target_row)
        col_mask = np.delete(np.arange(matrix.shape[1]), target_col)
        matrix = matrix[row_mask,:][:, col_mask]
    if len(order_rows) != d:
            return None, None
    return order_rows, order_cols

In [192]:
order_rows, order_cols = search_order_noisy(B_star)
if order_rows is not None:
    B_star = B_star[order_rows, :][:, order_cols]
B_star.round(2)

array([[ 0.59,  0.  ,  0.  ,  0.  ],
       [ 0.17,  0.37,  0.  ,  0.  ],
       [-0.24,  0.37,  0.99,  0.  ],
       [-0.35,  0.  ,  0.  ,  0.41]])

In [193]:
#B_star[np.triu_indices(4,k=1)] = 0
#B_star

In [194]:
np.diag(B_star)

array([0.58562377, 0.36977767, 0.99292499, 0.40650182])

In [195]:
# Remove sign indetermancy from columns
B_star = np.matmul(B_star, np.diag(np.sign(np.diag(B_star))))

# Remove scaling indetermancy from rows  and solve for A
B_star = np.matmul(np.diag(1/np.diag(B_star)), B_star)

# Solve for A
A = (np.eye(B_star.shape[0]) - np.linalg.inv(B_star))
A.round(2)

array([[ 0.  ,  0.  ,  0.  ,  0.  ],
       [ 0.46,  0.  ,  0.  ,  0.  ],
       [-0.41,  0.37,  0.  ,  0.  ],
       [-0.87,  0.  ,  0.  , -0.  ]])

In [196]:
# Check
np.transpose(g.to_amat()).round(2)

array([[ 0.  ,  0.  ,  0.  ,  0.  ],
       [ 0.46,  0.  ,  0.  ,  0.  ],
       [-0.87,  0.  ,  0.  ,  0.  ],
       [-0.41,  0.37,  0.  ,  0.  ]])

TODO: 
- Check code for graph recovery on noisy data
- Clean code and add it to LinearMDCRL class
- Implement a routine to check the solution of recovering the joint mixing matrix (involves permutations)
- Implement a routine to check the causal order/ the parameter matrix of the graph
- Write code for simulations

Issues:
- With the current implementation it can happen that G has zero rows...