In [1]:
import numpy as np
from scipy.stats import uniform, expon, lognorm, weibull_min, chi2, t, gumbel_r

from src.learner import LinearMDCRL
from experiments.rand import rand_model
from experiments.utils import plot_hist_noise

## Create synthetic data

In [2]:
# Specification of the random synthetic model that is handed to `rand_model`.
# The precise meaning of every key is defined by `experiments.rand.rand_model`
# (not shown in this notebook); the remarks below are what can be read off this
# notebook and should be confirmed there.
model_specs = {
    "nr_doms": 3,
    # Indices into "noise_rvs": 0-3 are listed as joint, 4/5/6 as
    # domain-specific (one list per domain). Together they cover all seven
    # entries of "noise_rvs".
    "joint_idx": [0, 1, 2, 3],
    "domain_spec_idx": [[4], [5], [6]],
    # Every entry is a *frozen* scipy.stats distribution so that all seven
    # expose the same interface. `gumbel_r()` keeps the default loc=0, scale=1,
    # i.e. it draws exactly what the bare `gumbel_r` generator would.
    "noise_rvs": [uniform(), expon(scale=0.1), t(df=5), gumbel_r(),
                  lognorm(s=1), weibull_min(c=2), chi2(df=6)],
    # One entry per domain.
    "sample_sizes": [5000, 5000, 5000],
    # 13 + 14 + 15 = 42, which matches the row count of the mixing matrices
    # printed further down.
    "dims": [13, 14, 15],
    "graph_density": 0.75,
    "mixing_density": 0.9,
    "mixing_distribution": 'normal',  # unif or normal
    "indep_domain_spec": True
}

In [3]:
# Draw one random synthetic model from `model_specs`. `rand_model` is project
# code that is not shown here; judging only by how the results are used below,
# `data` is what the learner is fitted on, `g` is the ground-truth graph
# (it exposes `to_amat()`), and `B_large` is the ground-truth mixing matrix
# -- confirm against experiments/rand.py.
# NOTE(review): no random seed is set anywhere in view, so the outputs recorded
# in this notebook are presumably not reproduced by a fresh run.
data, g, B_large = rand_model(model_specs)
# TODO: Can we plot g with causaldag?

## Fit model

In [4]:
# Fit the learner with its default settings on the synthetic data. The cells
# below read `joint_factors` and `joint_mixing` from the fitted object.
model = LinearMDCRL()
model.fit(data)

# Optional diagnostic, currently disabled: histogram of the recovered
# independent components.
#plot_hist_noise(model.indep_comps)

In [5]:
# Inspect the factors the fitted model reports as joint. The recorded output
# has four entries of three integers each (there are three domains);
# presumably each entry holds one component index per domain -- confirm in
# src/learner.py.
model.joint_factors

[[0, 1, 0], [1, 3, 4], [2, 4, 3], [3, 0, 2]]

In [6]:
# Shape of the estimated mixing matrix. The recorded (42, 7) agrees with
# 13 + 14 + 15 rows (sum of "dims") and 4 + 3 columns (joint plus
# domain-specific indices) in `model_specs`.
model.joint_mixing.shape

(42, 7)

In [7]:
# Estimated mixing matrix, rounded for display. In the recorded output its
# columns agree with those of `B_large` (printed in the next cell) only up to
# a permutation and sign flips of the columns.
model.joint_mixing.round(2)

array([[ 1.97, -1.38, -2.82,  2.01,  1.25,  0.  ,  0.  ],
       [ 0.02, -0.03, -0.82,  0.02,  0.03,  0.  ,  0.  ],
       [-0.12,  0.13,  0.02, -0.16,  0.09,  0.  ,  0.  ],
       [ 1.59,  0.04, -0.39,  0.98, -0.46,  0.  ,  0.  ],
       [-0.04,  0.01,  0.33, -0.77,  0.44,  0.  ,  0.  ],
       [-0.68, -0.85, -0.62,  2.05,  0.65,  0.  ,  0.  ],
       [-0.4 , -0.01, -0.46, -0.24, -0.76,  0.  ,  0.  ],
       [-1.04,  0.27,  1.12, -0.83,  1.09,  0.  ,  0.  ],
       [-0.49, -0.  , -0.6 , -2.24,  0.22,  0.  ,  0.  ],
       [-0.48, -0.6 , -0.33,  0.48, -0.81,  0.  ,  0.  ],
       [-1.62,  0.61,  1.01, -0.78,  0.63,  0.  ,  0.  ],
       [-2.01,  0.08,  0.42, -1.16,  0.4 ,  0.  ,  0.  ],
       [-0.93,  2.41,  1.15, -2.08, -0.29,  0.  ,  0.  ],
       [-0.04, -0.51,  1.14,  0.87,  0.  ,  0.51,  0.  ],
       [ 3.68,  1.95,  0.22, -1.08,  0.  ,  0.2 ,  0.  ],
       [-0.12,  0.06,  1.03, -0.05,  0.  ,  2.26,  0.  ],
       [ 1.  ,  1.49, -0.1 , -1.77,  0.  ,  0.5 ,  0.  ],
       [-0.47,

In [8]:
# Mixing matrix returned by `rand_model` (the reference the estimate above is
# compared with), rounded for display.
B_large.round(2)

array([[-2.96,  1.98, -1.83, -1.29, -1.04,  0.  ,  0.  ],
       [-0.82,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ,  0.  ],
       [ 0.03, -0.16,  0.11,  0.13, -0.09,  0.  ,  0.  ],
       [-0.46,  0.98, -1.55,  0.  ,  0.47,  0.  ,  0.  ],
       [ 0.35, -0.75,  0.  ,  0.  , -0.44,  0.  ,  0.  ],
       [-0.69,  2.02,  0.78, -0.77, -0.6 ,  0.  ,  0.  ],
       [-0.45, -0.26,  0.41,  0.  ,  0.73,  0.  ,  0.  ],
       [ 1.18, -0.81,  0.98,  0.28, -1.1 ,  0.  ,  0.  ],
       [-0.52, -2.24,  0.42,  0.01, -0.2 ,  0.  ,  0.  ],
       [-0.36,  0.47,  0.53, -0.58,  0.78,  0.  ,  0.  ],
       [ 1.1 , -0.78,  1.56,  0.63, -0.67,  0.  ,  0.  ],
       [ 0.51, -1.17,  1.97,  0.12, -0.43,  0.  ,  0.  ],
       [ 1.31, -2.08,  0.78,  2.35,  0.21,  0.  ,  0.  ],
       [-1.15,  0.91, -0.03, -0.5 ,  0.  , -0.49,  0.  ],
       [-0.2 , -1.13,  3.77,  1.8 ,  0.  , -0.15,  0.  ],
       [-1.04,  0.  ,  0.  ,  0.  ,  0.  , -2.27,  0.  ],
       [ 0.13, -1.83,  1.09,  1.39,  0.  , -0.54,  0.  ],
       [-0.12,

## Recover graph (in infinite data limit)

In [9]:
# Keep only the first four columns of the ground-truth mixing matrix (indices
# 0-3 are the ones listed under "joint_idx" in `model_specs`) and shuffle them
# into the order 2, 0, 3, 1, so that the recovery steps below have a column
# permutation to undo.
B = B_large[:, [2, 0, 3, 1]]

In [10]:
def get_row_ranks(matrix):
    """Rank of every two-row submatrix of ``matrix``.

    Parameters
    ----------
    matrix : numpy.ndarray, shape (d, l)

    Returns
    -------
    numpy.ndarray, shape (d, d)
        Float array whose entry ``[i, j]`` with ``i < j`` is
        ``np.linalg.matrix_rank`` of rows ``i`` and ``j`` stacked together.
        The diagonal and the lower triangle are left at zero.
    """
    n_rows = matrix.shape[0]
    pair_ranks = np.zeros((n_rows, n_rows))
    # Walk over the strict upper triangle only: each unordered pair once.
    upper_i, upper_j = np.triu_indices(n_rows, k=1)
    for first, second in zip(upper_i, upper_j):
        row_pair = matrix[[first, second], :]
        pair_ranks[first, second] = np.linalg.matrix_rank(row_pair)
    return pair_ranks

In [11]:
# Find pure children of all (different) latent nodes.
# R[i, j] == 1 (i < j) means rows i and j of B together have rank 1, i.e. they
# are proportional. For every such pair the smaller row index i is kept.
# NOTE(review): a pair made of an all-zero row and a non-zero row also has
# rank 1, so zero rows of B can end up in the selection.
R = get_row_ranks(B) 
pos_rows = np.unique(np.where(R==1)[0])

# Remove duplicates: when two of the selected rows are themselves
# proportional, drop the one with the larger position so that only one
# representative per direction is left.
to_remove_idx = np.where(get_row_ranks(B[pos_rows,:])==1)[1]
mask = np.ones(len(pos_rows), dtype=bool)
mask[to_remove_idx] = False
pos_rows = pos_rows[mask]
# NOTE: B is overwritten with the selected rows, so this cell is not
# idempotent -- re-create B (cell above) before running it again.
B = B[pos_rows,:]
B

array([[ 0.        , -0.8209847 ,  0.        ,  0.        ],
       [ 0.11096599,  0.03280899,  0.13252922, -0.16448962],
       [-1.55144494, -0.46064926,  0.        ,  0.98116152],
       [ 0.        ,  0.3534866 ,  0.        , -0.75291002]])

In [12]:
# Remove permutation indeterminacies of rows and columns
def single_support_rows(matrix, tol=0.0):
    """Find the rows of ``matrix`` that have exactly one non-zero entry.

    An all-zero row has an empty support and is therefore *not* reported.
    (Reporting it with column 0 would later put a zero on the diagonal of the
    reordered matrix and break the ``1 / np.diag(B)`` rescaling.)

    Parameters
    ----------
    matrix : array_like, shape (d, l)
        Matrix whose rows are inspected.
    tol : float, optional
        Entries with absolute value ``<= tol`` are treated as zero. The
        default ``0.0`` is the exact-zero test.

    Returns
    -------
    row_ids : list of int
        Indices of the rows with a single non-zero entry, in increasing order.
    max_ids : list of int
        For each reported row, the column index of its non-zero entry.

    Raises
    ------
    ValueError
        If ``matrix`` is not two-dimensional.
    """
    magnitudes = np.abs(matrix)
    if magnitudes.ndim != 2:
        raise ValueError("matrix must be two-dimensional")
    row_ids = []
    max_ids = []
    for i, row in enumerate(magnitudes):
        # Exactly one entry above the tolerance; that entry is then
        # necessarily the row's largest one, so argmax locates it.
        if (row > tol).sum() == 1:
            row_ids.append(i)
            max_ids.append(row.argmax())
    return row_ids, max_ids

def search_order(matrix):
    """Greedily find row/column orders by peeling off single-support rows.

    At every step the first row reported by ``single_support_rows`` is taken
    together with the column of its largest absolute entry; both are recorded
    (as indices into the *original* matrix) and removed, and the search
    continues on the remaining submatrix.

    Parameters
    ----------
    matrix : numpy.ndarray, shape (d, l)

    Returns
    -------
    order_rows, order_cols : list, list
        Original row and column indices in the order they were peeled off.
        Both are ``None`` if not every row could be placed (no single-support
        row left, or the columns ran out first).
    """
    n_rows, n_cols = matrix.shape
    order_rows = []
    order_cols = []
    # Labels of the rows/columns still present, in terms of the input matrix.
    # The column labels must span the column count, not the row count,
    # otherwise non-square input is mislabelled or indexed out of range.
    remaining_rows = np.arange(n_rows)
    remaining_cols = np.arange(n_cols)
    # Stop as soon as either dimension is exhausted; a matrix without columns
    # has no single-support rows to look for.
    while matrix.shape[0] > 0 and matrix.shape[1] > 0:
        # Find rows with all but one element equal to zero
        row_ids, max_ids = single_support_rows(matrix)
        if len(row_ids) == 0:
            break
        target_row = row_ids[0]
        target_col = max_ids[0]
        # Append index to order
        order_rows.append(remaining_rows[target_row])
        order_cols.append(remaining_cols[target_col])
        remaining_rows = np.delete(remaining_rows, target_row)
        remaining_cols = np.delete(remaining_cols, target_col)
        # Remove the row and the column from the working matrix
        keep_rows = np.delete(np.arange(matrix.shape[0]), target_row)
        keep_cols = np.delete(np.arange(matrix.shape[1]), target_col)
        matrix = matrix[keep_rows, :][:, keep_cols]
    if len(order_rows) != n_rows:
        return None, None
    return order_rows, order_cols

In [13]:
# Reorder B with the row/column orders found by the search. When the search
# fails (it returns None) B is left as it is.
order_rows, order_cols = search_order(B)
if order_rows is not None:
    B = B[np.ix_(order_rows, order_cols)]

In [14]:
# Remove scaling indeterminacy and solve for A
# Step 1: rescale every row of B so that its diagonal entry becomes 1.
row_scaling = np.diag(1 / np.diag(B))
B_unit_diag = row_scaling @ B
# Step 2: A = I - inv(rescaled B); only the inverse is rounded for display.
np.eye(len(B)) - np.linalg.inv(B_unit_diag).round(2)



array([[ 0.  ,  0.  ,  0.  ,  0.  ],
       [-0.47,  0.  ,  0.  ,  0.  ],
       [ 0.  , -0.63,  0.  ,  0.  ],
       [-0.34, -0.71,  0.84,  0.  ]])

In [15]:
# Check: print the transposed matrix returned by `g.to_amat()` for comparison
# with the matrix recovered in the previous cell (the two recorded outputs are
# identical). Presumably `to_amat()` is the ground-truth weighted adjacency
# matrix and the transpose aligns its orientation with A -- confirm against
# the graph class used by `rand_model`.
np.transpose(g.to_amat()).round(2)

array([[ 0.  ,  0.  ,  0.  ,  0.  ],
       [-0.47,  0.  ,  0.  ,  0.  ],
       [ 0.  , -0.63,  0.  ,  0.  ],
       [-0.34, -0.71,  0.84,  0.  ]])

In [16]:
## okay up to relabeling that preserves causal order, etc

TODO: 
- Implement a routine to check the solution of recovering the joint mixing matrix 
- How to implement graph recovery for real data? (DIFFICULT)
- With the current implementation it can happen that G has zero rows...

## Recover noisy graph