In [130]:
import itertools
import math
import sys
from timeit import default_timer as timer

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

sys.path.append("..")
import helper.helper as h
from notears.notears.notears import linear
from notears.notears.notears import utils

#### Exhaustive Kernel

#### Quickly Evaluate Loss Function

In [114]:
def loss(W):
    """Return the squared residual of ``X - X @ W``, averaged over the rows of X.

    Reads the module-level data matrix ``X``; ``W`` must be conformable with
    ``X @ W`` and give a result of the same shape as ``X``.
    """
    # Residual between the data and its linear reconstruction.
    residual = X - X @ W

    # Squared Frobenius norm of the residual, scaled by the number of rows.
    n_rows = X.shape[0]
    return 1 / n_rows * np.sum(residual ** 2)

#### Do OLS constrained to P, use Kernelized $\Psi$

In [31]:
def PK_OLS(Psi, P):
    """Least-squares coefficients restricted to the ordering ``P``.

    For the variable ``p`` at position ``i`` of ``P``, the coefficients of the
    variables that come after it in ``P`` are obtained by solving the linear
    system ``Psi[later, later] @ w = Psi[later, p]`` and are stored in
    ``W_hat[later, p]``. The last variable in ``P`` has no successors, so its
    column stays zero.

    Parameters
    ----------
    Psi : array_like, shape (n, n)
        Square matrix of inner products (the callers in this notebook pass
        ``X.T @ X``).
    P : array_like of int, shape (n,)
        Ordering of the variable indices ``0 .. n-1``.

    Returns
    -------
    numpy.ndarray, shape (n, n)
        Coefficient matrix; entries not covered by the ordering are zero.

    Raises
    ------
    numpy.linalg.LinAlgError
        If one of the sub-blocks of ``Psi`` is singular.
    """
    Psi = np.asarray(Psi)
    P = np.asarray(P)

    # Size the result from Psi itself instead of relying on a global ``n``.
    n = Psi.shape[0]
    W_hat = np.zeros((n, n))

    for i, p in enumerate(P[:-1]):
        later = P[i + 1:]
        # Solve the linear system directly rather than forming an explicit inverse.
        W_hat[later, p] = np.linalg.solve(Psi[np.ix_(later, later)], Psi[later, p])

    return W_hat

#### Exhaustively try all possible permutations $P$.

In [40]:
def exhaustive_search_K(X):
    """Try every ordering of the columns of ``X`` and keep the best fit.

    For each of the ``n!`` permutations, ``PK_OLS`` computes the coefficient
    matrix restricted to that ordering, and the one with the smallest mean
    squared residual on ``X`` is returned.

    Parameters
    ----------
    X : numpy.ndarray, shape (T, n)
        Data matrix with one row per sample.

    Returns
    -------
    numpy.ndarray, shape (n, n)
        Coefficient matrix with the lowest loss among all orderings.
    """
    # Kernel shared by every permutation.
    Psi = X.T.dot(X)

    n_samples, n = np.shape(X)
    W_best = np.identity(n)
    MSE_best = np.inf

    # math.factorial is used because the np.math alias is not available in
    # current NumPy releases.
    for perm in tqdm(itertools.permutations(range(n)), total=math.factorial(n)):

        # OLS solution under this ordering
        W = PK_OLS(Psi, np.array(perm))

        # Score against the X passed to this function, not a global one.
        MSE_W = 1 / n_samples * ((X - X @ W) ** 2).sum()

        # remember the best one
        if MSE_W < MSE_best:
            MSE_best = MSE_W
            W_best = W

    return W_best

#### Randomly Search for permutations until $\texttt{checks} > \texttt{max_checks}$ or $\texttt{time} > \texttt{max_time}$.

In [169]:
def random_search_K(X, max_checks, max_time):
    """Evaluate random orderings of the columns of ``X`` within a budget.

    Random permutations are drawn with ``np.random.permutation``; for each one
    ``PK_OLS`` gives the coefficient matrix restricted to that ordering, and
    the matrix with the smallest mean squared residual on ``X`` is kept.
    The time budget and the number of permutations tried are printed.

    Parameters
    ----------
    X : numpy.ndarray, shape (T, n)
        Data matrix with one row per sample.
    max_checks : int or float
        The loop runs while ``checks <= max_checks``, so up to
        ``max_checks + 1`` permutations are evaluated.
    max_time : float
        Time budget in seconds, measured with ``timeit.default_timer``.

    Returns
    -------
    numpy.ndarray, shape (n, n)
        Best coefficient matrix found (the identity if nothing was evaluated).
    """
    print(f"Trying random permutations for {round(max_time, 2)} seconds.")

    # start timer
    start = timer()

    # compute kernel
    Psi = X.T.dot(X)

    # initialize values for search
    n_samples = X.shape[0]
    n, checks = np.shape(Psi)[0], 0
    W_best, MSE_best = np.identity(n), np.inf

    # keep sampling until either budget is exhausted
    while checks <= max_checks and timer() - start <= max_time:

        # OLS solution under a uniformly random ordering of 0..n-1
        W = PK_OLS(Psi, np.random.permutation(n))

        # Score against the X passed to this function, not a global one.
        MSE_W = 1 / n_samples * ((X - X @ W) ** 2).sum()

        # if better, switch to this one
        if MSE_W < MSE_best:
            MSE_best = MSE_W
            W_best = W

        # increase number of checks
        checks += 1

    # report number of permutations tried
    print(f"Number of permutations tried: {checks}.")

    # return best W found
    return W_best

#### Generate SEM data

In [204]:
# Seed the generators through the notears helper so this cell is repeatable.
utils.set_random_seed(7)

# Simulation settings, all passed on to the notears `utils` helpers below.
T = 500
n = 10
s0 = 15
graph_type = 'ER'
sem_type = 'gauss'

# Simulate a DAG, attach parameters to it, then sample T rows from the linear SEM.
W_true = utils.simulate_parameter(
    utils.simulate_dag(n, s0, graph_type)
)
X = utils.simulate_linear_sem(W_true, T, sem_type)

print(np.round(W_true, 2))

[[ 0.    0.   -0.58  0.    0.   -0.74  0.    0.    0.    0.  ]
 [ 0.    0.    0.    0.    0.    0.    0.    0.    0.    0.  ]
 [ 0.    1.07  0.    0.    0.    0.    0.    0.    0.    0.  ]
 [ 0.    0.    0.    0.    0.    0.    0.    0.    0.    0.  ]
 [ 0.   -0.77  0.    1.48  0.    0.    0.    0.    0.   -0.76]
 [ 0.    0.    1.45  0.    0.    0.    0.    0.    0.    0.  ]
 [ 1.19  0.    0.    0.    0.    0.    0.    0.    0.    0.  ]
 [ 0.    0.    0.    0.    0.    0.    0.    0.    0.    0.65]
 [-1.93  0.    1.63  0.    0.    0.    0.    0.    0.   -1.8 ]
 [ 0.    0.    1.84 -0.51  0.    1.98  0.    0.    0.    0.  ]]


#### Run NOTEARS and time how long it takes

In [198]:
# Time one NOTEARS run; the elapsed time is stored in `max_time`.
start = timer()
W_N, h_val = linear.notears_linear(
    X,
    lambda1=0.0,
    loss_type="l2",
    w_threshold=1e-3,
    verbose=False,
)
max_time = timer() - start

print(f"Found solution in {round(max_time, 2)} seconds.")

Found solution in 20.28 seconds.


#### Run Random Search for the same amount of time (Exhaustive Search is commented out)

In [202]:
# Random search with `max_time` as its time budget and 1e6 as its check limit.
W_R = random_search_K(X, max_checks=1e6, max_time=max_time)

# Exhaustive search is disabled; uncomment both lines to run and report it.
# W_E = exhaustive_search_K(X)
# print(round(loss(W_E), 3))

# Loss of the random-search solution, followed by the NOTEARS solution.
print(round(loss(W_R), 3))
print(round(loss(W_N), 3))

Trying random permutations for 20.28 seconds.
Number of permutations tried: 10272.
9.747
9.749


#### Apply the same threshold to both solutions and compare the results

In [203]:
# Random-search solution: zero every entry with magnitude <= 0.30, then pass
# the result to h.score together with W_true.
W_RT = np.where(np.abs(W_R) <= 0.30, 0, W_R)
h.score(X, W_RT, W_true, is_sem=True);
print()

# Exhaustive-search counterpart, disabled along with the search itself.
# W_ET = W_E.copy()
# W_ET[np.abs(W_ET) <= 0.30] = 0
# h.score(X, W_ET, W_true, is_sem = True);
# print()

# NOTEARS solution: same 0.30 cutoff, same scoring call.
W_NT = np.where(np.abs(W_N) <= 0.30, 0, W_N)
h.score(X, W_NT, W_true, is_sem=True);

True Positive Rate: 1.0.
True Negative Rate: 1.0.
False Prediction Rate: 0.0
Accuracy: 1.0.
R-Squared: 0.947
Mean Squared Error: 10.051

True Positive Rate: 1.0.
True Negative Rate: 1.0.
False Prediction Rate: 0.0
Accuracy: 1.0.
R-Squared: 0.947
Mean Squared Error: 10.083
