In [887]:
from itertools import permutations

import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.stats import beta, uniform, expon, lognorm, weibull_min, chi2, t, gumbel_r, skewnorm

from experiments.rand import rand_model
from experiments.utils import plot_hist_noise
from src.learner import LinearMDCRL

## Create synthetic data

In [888]:
# Specification of the synthetic model; the dict is passed as-is to rand_model
# in the next cell and to score_joint_mixing further down.
# NOTE(review): the meaning of each key is defined by experiments.rand.rand_model,
# which is not visible in this notebook -- the remarks below only describe the
# values and how this notebook itself uses them.
model_specs = {
    # Number of domains; "domain_spec_idx", "sample_sizes" and "dims" each have 3 entries.
    "nr_doms": 3,
    # len() of this list is used by score_joint_mixing as the number of joint columns.
    "joint_idx": [0,1,2,3],
    # One list per domain; its length is used as that domain's column-block width.
    "domain_spec_idx": [[4],[5],[6]],
    # Seven scipy.stats distributions (presumably one per index 0-6 -- confirm).
    # gumbel_r is passed as the bare distribution object, the others are frozen.
    "noise_rvs": [beta(2,3), expon(scale=0.1), skewnorm(a=6), gumbel_r, 
                  lognorm(s=1), weibull_min(c=2), chi2(df=6)],
    "sample_sizes":  [10000, 10000, 10000],
    "dims": [10,10,10],
    "graph_density": 0.75,
    "mixing_density": 0.9,
    "mixing_distribution": 'unif',  # unif or normal
    "indep_domain_spec": True
}

In [889]:
data, g, B_large = rand_model(model_specs)
# Can we plot g with causaldag?

In [890]:
np.transpose(g.to_amat()).round(2)

array([[ 0.  ,  0.  ,  0.  ,  0.  ],
       [ 0.  ,  0.  ,  0.  ,  0.  ],
       [ 0.65, -0.69,  0.  ,  0.  ],
       [-0.34,  0.  ,  0.  ,  0.  ]])

## Fit model

In [891]:
model = LinearMDCRL(metric="1-wasserstein")
model.fit(data)

#plot_hist_noise(model.indep_comps)

In [892]:
model.nr_joint

4

In [893]:
model.joint_mixing.shape

(30, 7)

In [894]:
model.joint_mixing.round(2)

array([[-1.24,  0.15,  0.93, -0.84,  0.28,  0.  ,  0.  ],
       [ 0.63,  0.61, -0.91, -0.03,  0.41,  0.  ,  0.  ],
       [-0.8 , -0.02, -0.02, -0.02, -0.34,  0.  ,  0.  ],
       [-0.01, -0.19, -0.01,  0.62,  0.65,  0.  ,  0.  ],
       [-0.46, -0.68,  0.65,  0.82, -0.6 ,  0.  ,  0.  ],
       [ 0.33, -0.09, -0.95, -0.26,  0.58,  0.  ,  0.  ],
       [ 0.79, -0.16, -0.72,  0.95, -0.29,  0.  ,  0.  ],
       [-0.14, -1.9 ,  0.95,  1.08, -0.39,  0.  ,  0.  ],
       [-0.02,  0.82,  0.  , -0.03,  0.44,  0.  ,  0.  ],
       [-0.64, -1.19,  0.98, -0.73, -0.87,  0.  ,  0.  ],
       [ 0.75,  0.61, -0.51,  0.94,  0.  ,  0.62,  0.  ],
       [-1.17, -0.91,  0.91, -1.02,  0.  , -0.16,  0.  ],
       [ 0.37,  0.33, -0.49, -0.05,  0.  ,  0.66,  0.  ],
       [-0.45, -0.01, -0.02,  0.03,  0.  , -0.25,  0.  ],
       [ 1.07, -1.01, -0.27,  0.99,  0.  , -0.7 ,  0.  ],
       [-0.01,  0.89, -0.04,  0.02,  0.  , -0.01,  0.  ],
       [ 0.62,  0.75, -0.23,  0.34,  0.  ,  1.02,  0.  ],
       [-0.  ,

In [895]:
model.A.round(1)

array([[ 0. , -0. , -0. ,  0. ],
       [ 0. , -0. , -0. , -0. ],
       [ 0. , -0.3, -0. ,  0. ],
       [-0.7,  0.6, -0. ,  0. ]])

In [896]:
# Check
A = np.transpose(g.to_amat()).round(2)
A

array([[ 0.  ,  0.  ,  0.  ,  0.  ],
       [ 0.  ,  0.  ,  0.  ,  0.  ],
       [ 0.65, -0.69,  0.  ,  0.  ],
       [-0.34,  0.  ,  0.  ,  0.  ]])

# Score joint mixing matrix

In [897]:
def score_up_to_signed_perm(mat_hat, mat):
    """Align the columns of ``mat_hat`` with ``mat`` up to permutation and sign.

    Two steps:

    1. Choose the column permutation that minimises the Frobenius distance
       between the element-wise absolute values of both matrices. Because the
       squared Frobenius norm is a sum over columns, this is a linear
       assignment problem and is solved exactly without enumerating all
       ``ncols!`` permutations.
    2. For every aligned column, flip its sign if the negated column is
       strictly closer to the reference column.

    Parameters
    ----------
    mat_hat : ndarray of shape (d, k_hat)
        Estimated matrix. Only its first ``k`` columns are considered.
        The input is not modified.
    mat : ndarray of shape (d, k)
        Reference matrix.

    Returns
    -------
    tuple
        ``(score, aligned)`` where ``aligned`` is the permuted and sign-corrected
        copy of ``mat_hat`` and ``score`` is ``np.linalg.norm(aligned - mat)``.

    Raises
    ------
    ValueError
        If the shapes are incompatible or a column distance is not finite
        (e.g. NaN entries), in which case no best permutation exists.
    """
    mat_hat = np.asarray(mat_hat)
    mat = np.asarray(mat)
    ncols = mat.shape[1]
    if mat_hat.shape[0] != mat.shape[0] or mat_hat.shape[1] < ncols:
        raise ValueError(
            "mat_hat must have the same number of rows as mat and at least as many columns"
        )

    abs_hat = np.abs(mat_hat[:, :ncols])
    abs_ref = np.abs(mat)
    # cost[i, j] = squared distance between |mat[:, i]| and |mat_hat[:, j]|.
    cost = ((abs_ref[:, :, None] - abs_hat[:, None, :]) ** 2).sum(axis=0)
    if not np.isfinite(cost).all():
        raise ValueError("column distances are not finite; cannot determine a best permutation")

    # For a square cost matrix the row indices come back as 0..k-1, so the
    # column indices are exactly the permutation to apply to mat_hat.
    _, best_perm = linear_sum_assignment(cost)

    # Fancy indexing returns a copy, so the caller's array stays untouched.
    aligned = mat_hat[:, best_perm]
    flip = np.linalg.norm(aligned + mat, axis=0) < np.linalg.norm(aligned - mat, axis=0)
    aligned[:, flip] = -aligned[:, flip]
    return (np.linalg.norm(aligned - mat), aligned)

# This is a two-step approach:
# first look for the best column permutation by comparing absolute values,
# then check for each column whether the column itself or its negation
# (the column multiplied by -1) is the better match.

In [898]:
def score_joint_mixing(B_hat, B, m):
    """Score an estimated joint mixing matrix against the reference, block by block.

    The columns are split into consecutive blocks: first ``len(m["joint_idx"])``
    joint columns, then one block per domain whose width is
    ``len(m["domain_spec_idx"][i])``. Each block of ``B_hat`` is aligned to the
    matching block of ``B`` with ``score_up_to_signed_perm`` (permutation and
    sign are only resolved within a block).

    Parameters
    ----------
    B_hat : ndarray
        Estimated joint mixing matrix; it is copied, not modified.
    B : ndarray
        Reference joint mixing matrix.
    m : dict
        Model specification providing "joint_idx", "nr_doms" and "domain_spec_idx".

    Returns
    -------
    tuple
        ``(final_score, B_perm)``: the Frobenius norm of ``B_perm - B`` and the
        block-wise aligned copy of ``B_hat``.
    """
    # Widths of the column blocks in the order they appear in the matrix.
    block_widths = [len(m["joint_idx"])]
    block_widths += [len(m["domain_spec_idx"][dom]) for dom in range(m["nr_doms"])]

    aligned = B_hat.copy()
    start = 0
    for width in block_widths:
        cols = slice(start, start + width)
        aligned[:, cols] = score_up_to_signed_perm(B_hat[:, cols], B[:, cols])[1]
        start += width

    return (np.linalg.norm(aligned - B), aligned)

In [899]:
B_hat = model.joint_mixing
B = B_large

In [900]:
score, B_perm = score_joint_mixing(B_hat, B, model_specs)

In [901]:
score

0.45943630933029733

In [902]:
(B_perm - B).round(2)

array([[-0.06, -0.01,  0.03, -0.02, -0.01,  0.  , -0.  ],
       [ 0.  ,  0.01, -0.02, -0.03,  0.  ,  0.  , -0.  ],
       [-0.02, -0.01,  0.02, -0.02, -0.01,  0.  , -0.  ],
       [ 0.02,  0.01,  0.01,  0.02,  0.01,  0.  , -0.  ],
       [ 0.03, -0.  ,  0.03,  0.04,  0.  ,  0.  , -0.  ],
       [-0.01, -0.01, -0.02, -0.01, -0.  ,  0.  , -0.  ],
       [ 0.05,  0.01, -0.01,  0.02,  0.01,  0.  , -0.  ],
       [ 0.06, -0.02,  0.03,  0.1 ,  0.01,  0.  , -0.  ],
       [-0.01,  0.02, -0.  , -0.03,  0.  ,  0.  , -0.  ],
       [-0.03, -0.04,  0.01,  0.04, -0.01,  0.  , -0.  ],
       [-0.02, -0.  , -0.07, -0.05, -0.  ,  0.14, -0.  ],
       [ 0.05, -0.  ,  0.02, -0.03, -0.  , -0.16, -0.  ],
       [-0.01, -0.01, -0.04, -0.05, -0.  ,  0.03, -0.  ],
       [-0.01,  0.01,  0.02,  0.03, -0.  ,  0.01, -0.  ],
       [-0.05, -0.  , -0.03,  0.06, -0.  ,  0.13, -0.  ],
       [ 0.01,  0.01,  0.04,  0.02, -0.  , -0.01, -0.  ],
       [ 0.01, -0.01, -0.07, -0.1 , -0.  ,  0.05, -0.  ],
       [-0.  ,

In [903]:
B.round(2)

array([[ 0.21,  1.25, -0.96, -0.82, -0.27,  0.  ,  0.  ],
       [ 0.6 , -0.64,  0.93,  0.  , -0.41,  0.  ,  0.  ],
       [ 0.  ,  0.8 ,  0.  ,  0.  ,  0.35,  0.  ,  0.  ],
       [-0.21,  0.  ,  0.  ,  0.6 , -0.66,  0.  ,  0.  ],
       [-0.71,  0.47, -0.68,  0.78,  0.6 ,  0.  ,  0.  ],
       [-0.08, -0.32,  0.96, -0.26, -0.58,  0.  ,  0.  ],
       [-0.21, -0.8 ,  0.74,  0.93,  0.28,  0.  ,  0.  ],
       [-1.95,  0.17, -0.98,  0.97,  0.38,  0.  ,  0.  ],
       [ 0.82,  0.  ,  0.  ,  0.  , -0.44,  0.  ,  0.  ],
       [-1.16,  0.68, -0.99, -0.77,  0.89,  0.  ,  0.  ],
       [ 0.63, -0.75,  0.58,  0.99,  0.  ,  0.48,  0.  ],
       [-0.96,  1.17, -0.93, -0.99,  0.  ,  0.  ,  0.  ],
       [ 0.35, -0.37,  0.53,  0.  ,  0.  ,  0.63,  0.  ],
       [ 0.  ,  0.45,  0.  ,  0.  ,  0.  , -0.26,  0.  ],
       [-0.96, -1.07,  0.3 ,  0.93,  0.  , -0.83,  0.  ],
       [ 0.89,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ],
       [ 0.74, -0.61,  0.3 ,  0.44,  0.  ,  0.98,  0.  ],
       [-0.09,

In [904]:
B_perm.round(2)

array([[ 0.15,  1.24, -0.93, -0.84, -0.28,  0.  , -0.  ],
       [ 0.61, -0.63,  0.91, -0.03, -0.41,  0.  , -0.  ],
       [-0.02,  0.8 ,  0.02, -0.02,  0.34,  0.  , -0.  ],
       [-0.19,  0.01,  0.01,  0.62, -0.65,  0.  , -0.  ],
       [-0.68,  0.46, -0.65,  0.82,  0.6 ,  0.  , -0.  ],
       [-0.09, -0.33,  0.95, -0.26, -0.58,  0.  , -0.  ],
       [-0.16, -0.79,  0.72,  0.95,  0.29,  0.  , -0.  ],
       [-1.9 ,  0.14, -0.95,  1.08,  0.39,  0.  , -0.  ],
       [ 0.82,  0.02, -0.  , -0.03, -0.44,  0.  , -0.  ],
       [-1.19,  0.64, -0.98, -0.73,  0.87,  0.  , -0.  ],
       [ 0.61, -0.75,  0.51,  0.94, -0.  ,  0.62, -0.  ],
       [-0.91,  1.17, -0.91, -1.02, -0.  , -0.16, -0.  ],
       [ 0.33, -0.37,  0.49, -0.05, -0.  ,  0.66, -0.  ],
       [-0.01,  0.45,  0.02,  0.03, -0.  , -0.25, -0.  ],
       [-1.01, -1.07,  0.27,  0.99, -0.  , -0.7 , -0.  ],
       [ 0.89,  0.01,  0.04,  0.02, -0.  , -0.01, -0.  ],
       [ 0.75, -0.62,  0.23,  0.34, -0.  ,  1.02, -0.  ],
       [-0.09,

# Score parameter matrix of latent graph

In [905]:
def permutations_respecting_graph(A):
    """List all permutations that keep every edge of ``A`` in order.

    Only the strictly lower-triangular part of ``A`` is inspected: a non-zero
    ``A[i, j]`` with ``i > j`` is treated as an edge that requires
    ``perm[j] < perm[i]``. Entries on or above the diagonal are ignored.

    Parameters
    ----------
    A : ndarray of shape (l, l)
        Parameter matrix of the graph.

    Returns
    -------
    list of tuple
        The admissible permutations of ``range(l)``, in the order produced by
        ``itertools.permutations``.
    """
    size = A.shape[0]
    # (lower index, higher index) pairs for every non-zero below the diagonal.
    constraints = [
        (col, row)
        for col in range(size)
        for row in range(col + 1, size)
        if A[row, col] != 0
    ]
    return [
        order
        for order in permutations(range(size))
        if all(order[before] < order[after] for before, after in constraints)
    ]

# Attention: transposed mat
def get_permutation_matrix(perm: list):
    """Return the integer 0/1 matrix ``P`` with ``P[perm[k], k] == 1``.

    This equals the transpose of the matrix that has a one at ``(k, perm[k])``.

    Parameters
    ----------
    perm : sequence of int
        Permutation of ``range(len(perm))``.

    Returns
    -------
    ndarray of shape (len(perm), len(perm)) with integer dtype.
    """
    size = len(perm)
    matrix = np.zeros((size, size), dtype=int)
    # Writing at (perm[k], k) directly gives the already-transposed layout.
    matrix[list(perm), np.arange(size)] = 1
    return matrix

In [906]:
def score_graph_param_matrix(A_hat, A):
    """Score an estimated graph parameter matrix up to admissible relabelings.

    For every permutation returned by ``permutations_respecting_graph(A)`` the
    estimate is relabeled as ``P @ A_hat @ P.T`` (with ``P`` from
    ``get_permutation_matrix``) and compared to ``A`` in Frobenius norm.
    The first relabeling with the smallest error is kept.

    Parameters
    ----------
    A_hat : ndarray of shape (l, l)
        Estimated parameter matrix.
    A : ndarray of shape (l, l)
        Reference parameter matrix; it also determines which permutations
        are admissible.

    Returns
    -------
    tuple
        ``(min_error, A_hat_perm)`` for the best relabeling.

    Raises
    ------
    ValueError
        If no relabeling yields a comparable (non-NaN) error, so that no best
        solution exists.
    """
    best_solution = None
    min_error = float('inf')
    for perm in permutations_respecting_graph(A):
        P = get_permutation_matrix(perm)
        A_hat_perm = P @ A_hat @ P.T
        error = np.linalg.norm(A_hat_perm - A)
        # Strict comparison keeps the first minimiser and skips NaN errors.
        if error < min_error:
            min_error = error
            best_solution = (min_error, A_hat_perm)
    if best_solution is None:
        raise ValueError("no admissible permutation produced a finite error")
    return best_solution

In [907]:
score_graph_param_matrix(model.A, np.transpose(g.to_amat()))[0]

0.10054815031637314

In [908]:
score_graph_param_matrix(model.A, np.transpose(g.to_amat()))[1].round(2)

array([[-0.01,  0.03, -0.  , -0.04],
       [-0.05,  0.02,  0.03, -0.02],
       [ 0.64, -0.7 ,  0.02, -0.02],
       [-0.31,  0.04,  0.02, -0.01]])

In [909]:
# Check
A = np.transpose(g.to_amat()).round(2)
A

array([[ 0.  ,  0.  ,  0.  ,  0.  ],
       [ 0.  ,  0.  ,  0.  ,  0.  ],
       [ 0.65, -0.69,  0.  ,  0.  ],
       [-0.34,  0.  ,  0.  ,  0.  ]])

TODO: 
- Implement a routine to check the solution of recovering the joint mixing matrix (involves permutations)
- Implement a routine to check the causal order/ the parameter matrix of the graph
- Write code for simulations

Issues:
- With the current implementation it can happen that G has zero rows...
- Should be also possible to do first step with not allowing for pure children...
- We need that effects on pure children are positive to get exact matrix?! (otherwise "just" correct structure)

## Old: Recover graph (in infinite data limit)

In [270]:
B = B_large[:,:4]
B = B[:,[2,0,3,1]]  # add a permutation

In [271]:
def get_row_ranks(matrix):
    """Compute the rank of every pair of rows of ``matrix``.

    Parameters
    ----------
    matrix : ndarray of shape (d, l)

    Returns
    -------
    ndarray of shape (d, d)
        Entry ``[i, j]`` with ``i < j`` holds
        ``np.linalg.matrix_rank(matrix[[i, j], :])``; the diagonal and the
        lower triangle stay zero.
    """
    n_rows = matrix.shape[0]
    ranks = np.zeros((n_rows, n_rows))
    # Visit each unordered pair once via the strictly upper-triangular indices.
    for first, second in zip(*np.triu_indices(n_rows, k=1)):
        ranks[first, second] = np.linalg.matrix_rank(matrix[[first, second], :])
    return ranks

In [272]:
# Find pure children of all (different) latent nodes
# R[i, j] (i < j) is the rank of the row pair (i, j) of B; a rank of 1 means
# the two rows are linearly dependent.
R = get_row_ranks(B) 
# Keep the first index i of every rank-1 pair (i, j); R is only filled above
# the diagonal, so the second index j of a pair is not collected here.
pos_rows = np.unique(np.where(R==1)[0])

# Remove duplicates
# Among the selected rows, drop the second member j of any pair that is still
# rank 1, so that no two remaining rows are multiples of each other.
to_remove_idx = np.where(get_row_ranks(B[pos_rows,:])==1)[1]
mask = np.ones(len(pos_rows), dtype=bool)
mask[to_remove_idx] = False
pos_rows = pos_rows[mask]
# NOTE: this rebinds B to the selected rows, so re-running the cell without
# re-running the cell that defines B operates on the already reduced matrix.
B = B[pos_rows,:]
B

array([[ 0.        , -0.53299559,  0.        ,  0.70059211],
       [ 0.21788435, -0.02307926,  0.5478787 ,  0.50780665],
       [ 0.        ,  0.42496325,  0.        ,  0.        ],
       [ 0.46636378, -0.25776784,  0.        ,  0.        ]])

In [273]:
# Remove permutation indeterminacies of rows and columns
def single_support_rows(matrix):
    """Find the rows of ``matrix`` with at most one non-zero entry.

    Parameters
    ----------
    matrix : ndarray of shape (d, l)

    Returns
    -------
    tuple of lists
        ``(row_ids, max_ids)``: the indices of rows whose entries are all
        exactly zero except possibly the one of largest absolute value, and
        for each such row the column index of that largest entry.
    """
    d, l = matrix.shape
    row_ids = []
    max_ids = []
    for i in range(d):
        abs_row = np.abs(matrix[i, :])
        m = abs_row.argmax()
        mask = np.ones(l, dtype=bool)
        mask[m] = False
        # Exact comparison: the row qualifies only if all other entries are 0.
        if np.sum(abs_row[mask]) == 0:
            row_ids.append(i)
            max_ids.append(m)
    return row_ids, max_ids

def search_order(matrix):
    """Order rows and columns by repeatedly peeling off a single-support row.

    In each step the first row with a single non-zero entry is selected
    together with the column of that entry; both are recorded (as indices into
    the original matrix) and removed, and the search continues on the rest.

    Parameters
    ----------
    matrix : ndarray of shape (d, l)
        Row and column indices are tracked separately, so the matrix does not
        have to be square.

    Returns
    -------
    tuple
        ``(order_rows, order_cols)`` as lists of original indices, or
        ``(None, None)`` if not every row could be ordered (no single-support
        row left, or the columns ran out before the rows).
    """
    nr_rows, nr_cols = matrix.shape
    order_rows = []
    order_cols = []
    original_idx_rows = np.arange(nr_rows)
    original_idx_cols = np.arange(nr_cols)
    while matrix.shape[0] > 0 and matrix.shape[1] > 0:
        # Find rows with all but one element equal to zero
        row_ids, max_ids = single_support_rows(matrix)
        if len(row_ids) == 0:
            break
        target_row = row_ids[0]
        target_col = max_ids[0]
        # Append index to order
        order_rows.append(original_idx_rows[target_row])
        order_cols.append(original_idx_cols[target_col])
        original_idx_rows = np.delete(original_idx_rows, target_row)
        original_idx_cols = np.delete(original_idx_cols, target_col)
        # Remove the selected row and column from the working matrix
        matrix = np.delete(np.delete(matrix, target_row, axis=0), target_col, axis=1)
    if len(order_rows) != nr_rows:
        return None, None
    return order_rows, order_cols

In [274]:
B

array([[ 0.        , -0.53299559,  0.        ,  0.70059211],
       [ 0.21788435, -0.02307926,  0.5478787 ,  0.50780665],
       [ 0.        ,  0.42496325,  0.        ,  0.        ],
       [ 0.46636378, -0.25776784,  0.        ,  0.        ]])

In [275]:
order_rows, order_cols = search_order(B)
if order_rows is not None:
    B = B[order_rows, :][:, order_cols]

In [276]:
# Remove scaling indeterminacy and solve for A
np.eye(B.shape[0]) - np.linalg.inv(np.matmul(np.diag(1/np.diag(B)), B)).round(2)

array([[ 0.  ,  0.  ,  0.  ,  0.  ],
       [-0.76,  0.  ,  0.  ,  0.  ],
       [-0.55,  0.  ,  0.  ,  0.  ],
       [ 0.88,  0.93,  0.4 ,  0.  ]])

In [277]:
# Check
np.transpose(g.to_amat()).round(2)

array([[ 0.  ,  0.  ,  0.  ,  0.  ],
       [-0.76,  0.  ,  0.  ,  0.  ],
       [-0.55,  0.  ,  0.  ,  0.  ],
       [ 0.88,  0.93,  0.4 ,  0.  ]])

In [278]:
## okay up to relabeling that preserves causal order, etc

## Old: Recover noisy graph

In [279]:
l = len(model.joint_factors)
B_hat = model.joint_mixing[:,:l]#B_large[:,:l]#

In [280]:
def score_rows(matrix, eps=0.1):
    """Score every pair of rows by how close the pair is to being rank one.

    For each pair ``i < j`` the singular values of the 2-row sub-matrix are
    computed and the score is ``1 / (smallest singular value + eps)``; a high
    score means the two rows are (nearly) linearly dependent.

    Parameters
    ----------
    matrix : ndarray of shape (d, l)
    eps : float, default 0.1
        Regulariser added to the smallest singular value. Since singular
        values are non-negative, scores are bounded above by ``1 / eps``
        (10 with the default).
        NOTE(review): the default reproduces the original literal
        ``0000000.1`` (which is 0.1); a much smaller value such as 1e-7 may
        have been intended -- confirm.

    Returns
    -------
    ndarray of shape (d, d)
        Scores in the strict upper triangle; all other entries are zero.
    """
    d = matrix.shape[0]
    R = np.zeros((d, d))
    for i in range(d):
        for j in range(i + 1, d):
            # Only the singular values are needed, not the singular vectors.
            s = np.linalg.svd(matrix[[i, j], :], compute_uv=False)
            R[i, j] = 1 / (s.min() + eps)
    return R

In [281]:
# Having 2*l candidates for the pure children, we have to make sure that they are pure children from different nodes

# Checks: 
# - 8 unique children
# - the rank between two children from different nodes should be 2 (p(W)<level)
# - If a two-children tuple does not satisfy one of the above checks: 
#    Remove it from candidate list. Add two-children tuple with next highest p(W) to candidate list

In [282]:
def get_duplicates(cand_ids, ord_tup):
    """Find a candidate tuple that shares a row index with an earlier candidate.

    Parameters
    ----------
    cand_ids : list of int
        Positions into ``ord_tup`` of the currently selected row tuples.
    ord_tup : tuple of two ndarrays
        ``ord_tup[0][k]`` and ``ord_tup[1][k]`` are the two row indices of the
        k-th tuple.

    Returns
    -------
    The entry of ``cand_ids`` belonging to the later tuple of the first pair
    of candidates that have a row index in common, or ``None`` if all
    ``2 * len(cand_ids)`` row indices are distinct (or no overlapping pair of
    tuples is found).
    """
    l = len(cand_ids)
    ids0 = ord_tup[0][cand_ids]
    ids1 = ord_tup[1][cand_ids]
    uniques = np.unique(np.concatenate((ids0, ids1), axis=0))
    if len(uniques) == (2 * l):
        return None
    for i in range(l):
        members_i = (ids0[i], ids1[i])
        for j in range(i + 1, l):
            # Compare both members of tuple i with both members of tuple j;
            # checking only ids0[i] would miss overlaps through ids1[i].
            if ids0[j] in members_i or ids1[j] in members_i:
                return cand_ids[j]
    return None

In [283]:
def get_low_rank(cand_ids, ord_tup, R, level=0.95):
    """Find two candidate tuples whose rows score above ``level`` against each other.

    For every pair of candidates ``i < j`` the first row of tuple ``i`` is
    compared with both rows of tuple ``j`` via the score matrix ``R``.

    Parameters
    ----------
    cand_ids : list of int
        Positions into ``ord_tup`` of the currently selected row tuples.
    ord_tup : tuple of two ndarrays
        Row indices of the ordered tuples (first and second member).
    R : ndarray
        Pairwise row score matrix, indexed as ``R[row_a, row_b]``.
    level : float, default 0.95
        Scores strictly greater than this value count as a conflict.

    Returns
    -------
    ``(cand_ids[i], cand_ids[j])`` for the first conflicting pair, or ``None``.
    """
    ids0 = ord_tup[0][cand_ids]
    ids1 = ord_tup[1][cand_ids]
    # Iterate over the candidates actually passed in rather than relying on a
    # notebook-level variable for the count.
    nr_cands = len(cand_ids)
    for i in range(nr_cands):
        for j in range(i + 1, nr_cands):
            if (R[ids0[i], ids0[j]] > level) or (R[ids0[i], ids1[j]] > level):
                return cand_ids[i], cand_ids[j]
    return None

In [284]:
def update_cand_ids(cand_ids, to_remove):
    """Swap one candidate for the next unused position.

    Removes ``to_remove`` from ``cand_ids`` and appends one more than the
    largest id the list held before the removal. The list is modified in
    place and also returned.
    """
    successor = max(cand_ids) + 1
    cand_ids.remove(to_remove)
    cand_ids.append(successor)
    return cand_ids

In [285]:
def get_pure_children(B, level=10):
    """Select one row of ``B`` per latent column as a pure-child candidate.

    All pairs of rows are scored with ``score_rows`` and sorted by decreasing
    score. Starting from the ``B.shape[1]`` best pairs, a pair is replaced by
    the next best unused pair whenever it shares a row with another selected
    pair (``get_duplicates``) or scores above ``level`` against another
    selected pair (``get_low_rank``). The first row of every remaining pair is
    returned.

    Parameters
    ----------
    B : ndarray of shape (d, l)
        Mixing matrix with one column per latent node.
    level : float, default 10
        Threshold handed to ``get_low_rank``.
        NOTE(review): ``score_rows`` as written in this notebook bounds its
        scores by 1 / 0.1 = 10, so the default threshold can never be
        exceeded and the low-rank check is effectively inactive -- confirm the
        intended scale.

    Returns
    -------
    ndarray
        The selected rows of ``B`` (one per candidate pair).
    """
    # Score all pairs of rows
    R = score_rows(B)
    d = R.shape[0]
    nr_tuples = int(d * (d - 1) / 2)
    # One candidate pair per latent column; taken from B itself instead of a
    # notebook-level variable.
    nr_latent = B.shape[1]

    # Pairs ordered by decreasing score (positions 0 .. nr_tuples-1).
    ord_tup = np.unravel_index(R.ravel().argsort(), R.shape)
    ord_tup = (np.flip(ord_tup[0][nr_tuples:]),
               np.flip(ord_tup[1][nr_tuples:]))

    # Make R symmetric now (important): score_rows only fills the strict upper
    # triangle, so adding the transpose mirrors it without changing values.
    R = R + R.T

    # Choose rows that maximize the "within-tuple" score and at the same time
    # have a low "inter-tuple" score
    cand_ids = list(np.arange(nr_latent))
    while max(cand_ids) < nr_tuples:
        to_remove = get_duplicates(cand_ids, ord_tup)
        if to_remove is not None:
            cand_ids = update_cand_ids(cand_ids, to_remove)
            continue
        to_remove = get_low_rank(cand_ids, ord_tup, R, level=level)
        if to_remove is not None:
            # Try dropping the first pair of the conflict; if a conflict
            # remains, drop the second pair instead.
            temp_cands = cand_ids.copy()
            temp_cands = update_cand_ids(temp_cands, to_remove[0])
            if get_low_rank(temp_cands, ord_tup, R, level=level) is None:
                cand_ids = update_cand_ids(cand_ids, to_remove[0])
            else:
                cand_ids = update_cand_ids(cand_ids, to_remove[1])
            continue
        break
    pure_children_rows = ord_tup[0][cand_ids]

    return B[pure_children_rows, :]

In [286]:
B_star = get_pure_children(B_hat)
B_star.round(2)

array([[-0.71,  0.02, -0.02,  0.52],
       [ 0.01, -0.02,  0.01, -0.43],
       [-0.02,  0.03,  0.45,  0.26],
       [-0.52,  0.56,  0.19, -0.01]])

In [287]:
def l2_without_max(x):
    """Return the Euclidean norm of ``x`` after dropping its largest-magnitude entry."""
    m = np.abs(x).argmax()
    mask = np.ones(len(x), dtype=bool)
    mask[m] = False
    return np.linalg.norm(x[mask])

def search_order_noisy(matrix):
    """Noise-tolerant variant of the row/column ordering search.

    In each step the row whose entries, apart from its largest-magnitude one,
    have the smallest Euclidean norm is selected together with the column of
    that largest entry; both are recorded (as indices into the original
    matrix) and removed before the next step.

    Parameters
    ----------
    matrix : ndarray of shape (d, l)
        Row and column indices are tracked separately, so the matrix does not
        have to be square.

    Returns
    -------
    tuple
        ``(order_rows, order_cols)`` as lists of original indices, or
        ``(None, None)`` if the columns ran out before every row was ordered.
    """
    nr_rows, nr_cols = matrix.shape
    order_rows = []
    order_cols = []
    original_idx_rows = np.arange(nr_rows)
    original_idx_cols = np.arange(nr_cols)
    while matrix.shape[0] > 0 and matrix.shape[1] > 0:
        # Find row with lowest l2 norm where all entries but the maximum are considered
        target_row = np.apply_along_axis(l2_without_max, 1, matrix).argmin()
        target_col = np.abs(matrix[target_row, :]).argmax()
        # Append index to order
        order_rows.append(original_idx_rows[target_row])
        order_cols.append(original_idx_cols[target_col])
        original_idx_rows = np.delete(original_idx_rows, target_row)
        original_idx_cols = np.delete(original_idx_cols, target_col)
        # Remove the selected row and column from the working matrix
        matrix = np.delete(np.delete(matrix, target_row, axis=0), target_col, axis=1)
    if len(order_rows) != nr_rows:
        return None, None
    return order_rows, order_cols

In [288]:
order_rows, order_cols = search_order_noisy(B_star)
if order_rows is not None:
    B_star = B_star[order_rows, :][:, order_cols]
B_star.round(2)

array([[-0.43,  0.01,  0.01, -0.02],
       [ 0.52, -0.71, -0.02,  0.02],
       [ 0.26, -0.02,  0.45,  0.03],
       [-0.01, -0.52,  0.19,  0.56]])

In [291]:
# Remove sign indeterminacy from columns
B_star = np.matmul(B_star, np.diag(np.sign(np.diag(B_star))))

# Remove scaling indeterminacy from rows (rescale each row so the diagonal is 1)
B_star = np.matmul(np.diag(1/np.diag(B_star)), B_star)

# Solve for A
A = (np.eye(B_star.shape[0]) - np.linalg.inv(B_star))
A.round(2)

array([[ 0.03,  0.02,  0.03, -0.04],
       [-0.75, -0.01, -0.01,  0.  ],
       [-0.6 , -0.01, -0.01,  0.05],
       [ 0.91,  0.93,  0.35, -0.02]])

In [292]:
# Check
np.transpose(g.to_amat()).round(2)

array([[ 0.  ,  0.  ,  0.  ,  0.  ],
       [-0.76,  0.  ,  0.  ,  0.  ],
       [-0.55,  0.  ,  0.  ,  0.  ],
       [ 0.88,  0.93,  0.4 ,  0.  ]])