In this short demo, I will show how to use the csPCR test on a simulated dataset.

In [66]:
'''
First, we generate a dataset, we can change the Alpha_s, Alpha_t and effect parameter to change the distribution of
the generated dataset.
'''

import numpy as np

def generate(ns, nt, p,q, s, t, u, Alpha_s=1, Alpha_t = 0,effect=1, z_diff = 0.1):
    Zs_null = np.random.normal(0,0.1, (ns, q))
    Zt_null = np.random.normal(0,0.1, (nt, q))
    
    Z_source = np.hstack((np.random.normal(0, 1, (ns, p)) , Zs_null))
    Z_target = np.hstack((np.random.normal(z_diff, 1, (nt, p)) , Zt_null))
    
    X_source = Z_source[:, :p] @ u + np.random.normal(0, 1, ns)
    X_target = Z_target[:, :p] @ u + np.random.normal(0, 1, nt)

    V_source = Z_source[:, :p] @ s + Alpha_s * (X_source) + np.random.normal(0, 5, ns)
    V_target = Z_target[:, :p] @ t + Alpha_t * (X_target) + np.random.normal(0, 5, nt)
    
    # V_source = Z_source[:, :p] @ s + 2*X_source 
    # V_target = Z_target[:, :p] @ t - 2*X_target
    
    Y_source = (Z_source[:, :p].sum(axis=1))**2 + effect*V_source + np.random.normal(0, 1, ns)+ 0 * X_source
    Y_target = (Z_target[:, :p].sum(axis=1))**2 + effect*V_target + np.random.normal(0, 1, nt)+ 0 * X_target
    
    
    return Y_source.reshape(-1, 1), X_source.reshape(-1, 1), V_source.reshape(-1, 1), Z_source,\
           Y_target.reshape(-1, 1), X_target.reshape(-1, 1), V_target.reshape(-1, 1), Z_target


'''
According to the generation process, 
we write 2 functions to calculate the xz_ratio and true density ratio of the data
'''


def true_density_ratio(X, Z, V, s=np.array([-0.5, -1.0,  0.3, -0.9, -1.5 ]), t=np.array([ 1.5, -0.2 ,  0.06 , -1.4, -0.5]),\
                       p=5,q=50, Alpha_s = 1, Alpha_t = 0,z_diff = 0.1):
    ratios = []
    size = V.size
    for i in range(size):
        zs_prob = multivariate_normal.pdf(Z[i][:p], mean = 0*np.ones(p), cov= 1*np.identity(p))
        vs_prob = norm.pdf(V[i], loc=Z[i][:p]@s + Alpha_s*X[i], scale =5)
        zt_prob = multivariate_normal.pdf(Z[i][:p], mean = z_diff*np.ones(p), cov= 1*np.identity(p))
        vt_prob = norm.pdf(V[i], loc=Z[i][:p]@t + Alpha_t*X[i], scale =5)
        ratios.append((zt_prob*vt_prob)/(zs_prob*vs_prob))
    
    return np.array(ratios)

In [2]:
'''
Define some parameters for data generation
'''
nt, p,q = 2000, 5, 50
est_size = 200
ns = est_size + 500
est_ratio = 500/ns
s = np.array([-0.5, -1.0,  0.3, -0.9, -1.5])
t = np.array([ 1.5, -0.2, 0.06 , -1.4, -0.5])
u = np.array([ 0.1, -1.1,  0.4, -0.6, -0.3])


In [3]:
from csPCR_functions import *

In [4]:
Y_source, X_source, V_source, Z_source, Y_target, X_target, V_target, Z_target = generate(ns,nt, p,q, s, t, u,effect=1)

'''
Since this method is model_X framework, we assume that the X|Z,V and Z distribution are known. So we need to define three 
extra functions as the input of the test.
1. Model_X: Should be a function of the conditional model of X|Z,V
   Input: z, v values(float)
   Return: one X sample follows the X|Z,V distribution

2. E_X: Should be a function of the conditional expectaion of X|Z,V
   Input: z,v values(float)
   Return: the calculated conditional expectation E[X|Z,V]

3. xz_ratio: Should a function for calculating the (X,Z) density ratio (i.e. P_t(X,Z,V)/P_s(X,Z,V))
   Input: X, Z (ndarrays)
   
   Return: density ratio array (ndarray)
'''

def Model_X(z, v):
    # Conditional distribution of X|Z
    return z[:5] @ u + np.random.normal(0, 1, 1)

def E_X(z, v):
    # Conditional expectation of X|Z
    return z[:5] @ u



def xz_ratio(X, Z, z_diff = 0.1):
    ratios = []
    for i in range(X.shape[0]):
        zs_prob = multivariate_normal.pdf(Z[i][:p], mean = 0*np.ones(p), cov= 1*np.identity(p))
        zt_prob = multivariate_normal.pdf(Z[i][:p], mean = z_diff*np.ones(p), cov= 1*np.identity(p))
        ratios.append((zt_prob)/(zs_prob))
    
    return np.array(ratios)

'''
To use the Test function, need input of source and target X,Z,V and source Y data into the function,
Also take 3 functions mentioned above as inputs.
'''
Test(X_source, Z_source, V_source, Y_source, X_target, Z_target, V_target,\
     model_X = Model_X, E_X = E_X, xz_ratio= xz_ratio, test_size = 0.5, L=2)


[ 0.08511383 -0.83215705  0.28860972 -0.50702646 -1.58417422  0.
 -0.23657813  0.          0.38312594  0.00205537  0.          0.
 -0.1727692  -0.          0.         -0.          0.         -0.
 -0.          0.         -0.21067892  0.          0.         -0.
 -0.         -0.31827685  0.          0.          0.77119617  0.
 -0.         -0.         -0.         -0.         -0.          0.14401686
 -0.          0.41706802 -0.          0.          0.         -0.94624391
 -0.         -0.          0.9006418  -0.         -0.         -0.
 -0.         -0.          0.23625544  0.          0.          0.
 -0.          1.03409317]
[ 1.40785096 -0.07549534  0.03347084 -1.25541563 -0.50811321 -0.
 -0.         -0.         -0.         -0.          0.         -0.
 -0.          0.          0.          0.          0.         -0.
 -0.         -0.39110859  0.40899386 -0.46210654  0.         -0.
  0.          0.         -0.49831032 -0.38082118  0.         -0.07678866
 -0.01493081  0.         -0.          0.

0.5050562241016008

In [22]:
Test_pe_true_dr(X_source, Z_source, V_source, Y_source, X_target, Z_target, V_target,\
     model_X = Model_X, E_X = E_X,  true_dr = true_density_ratio, L=2)

[0.32747974 0.32747974]
[378.61366593 348.84699698]


0.3808304903747278

In [None]:
ns = 1000
Y_source, X_source, V_source, Z_source, Y_target, X_target, V_target, Z_target = generate(ns,nt, p,q, s, t, u, Alpha_s=1, Alpha_t = 0,effect=0, z_diff = 1)

Test_pe(X_source, Z_source, V_source, Y_source, X_target, Z_target, V_target,\
     model_X = Model_X, E_X = E_X, xz_ratio = xz_ratio, test_size=est_ratio)

In [71]:
c= 0
nt, p,q = 1000, 10, 10
est_size = 10
ns = est_size + 200
est_ratio = 200/ns
#ns = 500

s = np.array([-1, -0.5, 0, 1, 1.5,-1, -0.5, 0, 1, 1.5])
t = np.array([ -1, -1, 0.5 , 0.5, 1,-1, -1, 0.5 , 0.5, 1])
t = (t+3*s)
u = np.array([ 0, -1, 0.5, -0.5, 1,0, -1, 0.5, -0.5, 1])
# ns = 500
def true_density_ratio(X, Z, V, s=s, t=t,\
                       p=p,q=q, Alpha_s = 1, Alpha_t = 0,z_diff = 0.1):
    ratios = []
    size = V.size
    for i in range(size):
        zs_prob = multivariate_normal.pdf(Z[i][:p], mean = 0*np.ones(p), cov= 1*np.identity(p))
        vs_prob = norm.pdf(V[i], loc=Z[i][:p]@s + Alpha_s*X[i], scale =5)
        zt_prob = multivariate_normal.pdf(Z[i][:p], mean = z_diff*np.ones(p), cov= 1*np.identity(p))
        vt_prob = norm.pdf(V[i], loc=Z[i][:p]@t + Alpha_t*X[i], scale =5)
        ratios.append((zt_prob*vt_prob)/(zs_prob*vs_prob))
    
    return np.array(ratios)

def Model_X(z, v):
    # Conditional distribution of X|Z
    return z[:p] @ u + np.random.normal(0, 1, 1)

def E_X(z, v):
    # Conditional expectation of X|Z
    return z[:p] @ u



def xz_ratio(X, Z, z_diff = 0.1):
    ratios = []
    for i in range(X.shape[0]):
        zs_prob = multivariate_normal.pdf(Z[i][:p], mean = 0*np.ones(p), cov= 1*np.identity(p))
        zt_prob = multivariate_normal.pdf(Z[i][:p], mean = z_diff*np.ones(p), cov= 1*np.identity(p))
        ratios.append((zt_prob)/(zs_prob))
    
    return np.array(ratios)

for i in trange(100):
    Y_source, X_source, V_source, Z_source, Y_target, X_target, V_target, Z_target = generate(ns,nt, p,q, s, t, u,Alpha_s=1, Alpha_t = 0,effect=1, z_diff = 0.1)

    # p_value = Test(X_source, Z_source, V_source, Y_source, X_target, Z_target, V_target,\
    # model_X = Model_X, E_X = E_X, xz_ratio = xz_ratio, test_size=est_ratio, L=3)
    p_value = Test_pe_true_dr(X_source, Z_source, V_source, Y_source, X_target, Z_target, V_target,\
    model_X = Model_X, E_X = E_X,  true_dr = true_density_ratio, L=3)
    print(p_value)
    if p_value < 0.05:
        c += 1

  1%|          | 1/100 [00:00<00:14,  7.07it/s]

[-0.87093    -0.          0.          0.97616334  1.20432746 -0.65816751
 -0.         -0.          1.54227953  0.24405705 -0.         -0.
  0.          0.          0.          0.          0.         -0.
 -0.          0.          1.49849258]
[-4.19770035 -2.22307032  0.47660964  3.24498817  5.49543668 -4.07616599
 -2.48937256  0.43993892  3.41512701  5.43051849 -0.         -0.
 -0.          0.         -0.          0.         -0.          0.
  0.          0.         -0.        ]
max dr: 22.862684218338025
0.9506221241049787
[-0.85257072  0.18933585  0.26173273  0.99056831  0.61438989 -1.02096183
  0.         -0.1929962   0.9827152   1.44240215  0.         -0.
 -0.          0.         -0.          0.         -0.          0.
 -0.          0.          1.54867541]
[-4.02536117 -2.15701527  0.22821365  3.46630095  5.49489533 -4.04121028
 -2.11482725  0.34058074  3.6282141   5.1157334   0.54866056 -2.49282464
  0.          0.         -0.          0.50571958 -1.51218282  0.
  0.09303225 -0.    

  3%|▎         | 3/100 [00:00<00:12,  7.85it/s]

0.08559433132869454
[-1.22356197 -0.78666814 -0.20979949  1.04496906  1.53320587 -1.21348447
 -0.81314539 -0.58001548  1.31618905  0.          0.         -0.
 -0.          0.          0.         -0.          0.          0.
 -0.         -0.          1.14032073]
[-3.97489902 -2.31431701  0.89501656  3.53045899  5.54815596 -3.99225958
 -2.46738238  0.56510361  3.41279459  5.56485511 -0.          0.
 -0.          0.         -0.         -0.         -0.         -0.
 -0.          1.35656652 -0.        ]
max dr: 7.199030731404195
0.3306538089569092
[-0.69978663 -0.39117247  0.          1.16764952  0.51604405 -1.21457745
 -0.78491836 -0.          1.03138271  1.23261145 -0.         -0.
 -0.         -0.         -0.         -0.          0.          0.
 -0.         -0.          1.33566669]
[-4.15009525 -2.26717184  0.30717327  3.71779715  5.04334945 -4.02745064
 -2.24042924  0.25079157  3.58126509  5.16135825 -0.          0.
 -0.         -0.         -0.         -0.          0.         -0.
 -0.     

  5%|▌         | 5/100 [00:00<00:11,  8.06it/s]

0.07567201958339231
[-1.248225    0.52091782 -0.58396479  1.31310558  0.42679517 -1.22739624
  0.22820785  0.1307873   1.83682964  1.13787741 -0.         -0.
  1.71800615 -4.70539491  0.         -0.         -0.          0.
 -0.          0.          2.02092863]
[-3.68437101 -2.7549394   0.61981524  3.53675285  5.51596328 -4.20265025
 -2.3427565   0.46128515  3.70066201  5.37745037 -0.         -0.
  0.         -0.          0.          0.         -0.         -0.
  0.          0.         -0.        ]
max dr: 16.68028243273856
0.34510600738902875
[-0.20389571 -0.         -0.          0.4431462   0.         -0.01799114
  0.          0.07648361  0.          0.47394083 -0.          0.
 -0.          0.          0.         -0.         -0.          0.
  0.          0.          1.54936114]
[-3.9679687  -2.42161739  0.80780832  3.60538584  5.72816524 -4.24552065
 -2.74345954  0.58881617  3.24173911  5.78324114 -0.         -0.
 -0.         -0.          0.          0.         -0.         -0.
  0.    

  7%|▋         | 7/100 [00:00<00:11,  8.04it/s]

0.3770653825143835
[ 0.         -0.59192934 -1.00364312  0.50848637  2.92036214 -1.11960111
 -1.0048555   0.          1.1939205   1.29959473 -0.          0.
 -0.         -0.         -0.         -0.          0.          0.
  0.         -0.          0.74238362]
[-4.09542571 -2.59489915  0.26511994  3.42167644  5.68073418 -4.31885472
 -2.5793313   0.46097029  3.60425521  5.51734224 -1.97015559  0.
  0.          0.         -2.24604582 -1.73302515  0.          0.47904737
  0.         -0.42267945 -0.1312754 ]
max dr: 7.826591411267531
0.031310330368144346
[-0.17754146 -0.          0.          0.81684841  0.3261885  -0.60171244
 -0.          0.          1.05620136  0.27576803 -0.         -0.
 -0.          0.          0.          0.          0.         -0.
  0.         -0.          1.88751474]
[-3.79332253 -2.15384821  0.49175432  3.88649035  5.2235226  -3.82790564
 -2.46999062  0.34896197  3.62391645  5.20248288  0.          0.
 -0.          0.         -0.          0.         -0.         -0.


  9%|▉         | 9/100 [00:01<00:11,  8.11it/s]

0.015375141662321568
[-0.8172235  -0.60585358  0.47433993  0.74420657  2.1209137  -0.35693408
  0.10599889  0.44746098  0.64792679  1.7396515  -0.         -0.40221756
  5.69419762 -0.57421032 -0.          7.67416088  4.15872567 -4.29338825
 -0.         -0.          0.8082004 ]
[-3.7528124  -2.28092791  0.19408873  3.4891276   4.98422248 -4.17971476
 -2.19438341  0.53299109  3.5540641   5.0928236  -0.          0.36897997
  0.         -0.         -0.          0.          0.          0.
 -0.         -0.          0.15878499]
max dr: 14.468023736456379
0.192499641368164
[-0.         -0.          0.          0.          0.         -0.
  0.         -0.          0.          0.03263721  0.         -0.
  0.          0.         -0.         -0.          0.         -0.
  0.          0.          1.75197372]
[-3.73194136 -2.55689221  0.39411941  3.53474829  5.8403761  -3.91521852
 -2.48993757  0.4634902   3.59082031  5.65678985 -0.         -0.
 -0.         -0.          0.         -0.         -0.     

 11%|█         | 11/100 [00:01<00:10,  8.15it/s]

0.19450473463423745
[-0.          0.         -0.         -0.          0.40864272 -0.
 -0.         -0.          0.49368044  0.         -0.         -0.
 -0.         -0.         -0.          0.          0.          0.
  0.          0.          1.33214253]
[-4.14624652 -2.27404148  0.10292151  3.63241171  5.38366259 -4.17536981
 -2.28243329  0.65160122  3.5325968   5.36783801  0.          0.
 -0.22098919  0.          0.         -0.          0.          0.
 -0.         -0.          0.07978203]
max dr: 155.89193225038596
0.29989088934620434
[-0.         -0.         -0.          0.50530785  0.08043244 -1.10908912
 -0.         -0.07395313  1.11496701  0.83550448  0.          0.
 -0.         -0.         -0.          0.         -0.          0.
  0.         -0.          1.26312908]
[-3.89502982 -2.35116403  0.44235371  3.60979697  5.17680151 -4.12360577
 -2.56259248  0.32553806  3.54050602  5.27785063 -0.         -0.
  0.          0.          0.         -0.          0.         -0.
  0.          0

 13%|█▎        | 13/100 [00:01<00:10,  8.20it/s]

0.12871537962743573
[-1.10231139 -0.14590097  0.          0.29109217  1.53385883 -0.25693561
 -0.         -0.35343417  0.50317974  1.87056625  0.          0.
 -0.         -0.          0.         -0.         -0.          0.
 -0.         -0.          1.05379706]
[-4.03311711 -2.36893138  0.27780252  3.57778858  5.53374464 -4.17603063
 -2.45483936  0.70485716  3.30531102  5.22530831  0.         -0.
  0.          0.         -0.         -0.          0.         -0.
 -0.         -0.          0.05141737]
max dr: 47.87439220088677
0.4057314361470974
[-1.14355817  0.          0.01238861  0.63370683  1.17777502 -1.08448009
 -0.7367909   0.82888773  0.16079878  1.0288025  -0.         -0.
 -0.         -0.         -0.         -0.         -0.         -0.
  0.          0.          0.85239448]
[-4.11599746 -2.01704225  0.69568917  3.6911857   5.13584998 -3.95655481
 -2.17291009  0.57656197  3.70978776  5.26959051 -0.0470808   0.
  0.          0.          0.          0.          0.          0.
  0.     

 15%|█▌        | 15/100 [00:01<00:10,  8.20it/s]

0.5588557909865803
[-1.73293142  0.         -0.67006855  0.56829413  0.83594417 -1.64279788
 -1.52306737  0.48969307  0.44724309  0.89012444  1.15877561  0.
  0.         -0.         -0.          1.84207814  0.          0.
  0.          0.          1.07829405]
[-3.751737   -2.57484132  0.4534903   3.43440417  5.61354579 -3.83561123
 -2.27696004  0.51232558  3.43213677  5.45215441  0.         -0.
 -0.          0.         -0.          0.         -0.         -0.
 -0.          0.          0.009889  ]
max dr: 17.806095491396615
0.8719785638873466
[-1.95861671 -0.          0.19756876  1.03559568  0.14216863 -0.89210961
  0.         -0.05532585  1.12525093  0.07298493  0.          0.
 -0.         -0.         -0.         -0.          0.          0.
 -0.          0.          1.57367204]
[-3.94247278 -2.44548967  0.35609148  3.51525578  5.32815017 -3.84478415
 -2.29528836  0.38204399  3.43152809  5.31673717  0.          0.
  0.          0.          0.          0.          0.         -0.
 -0.     

 17%|█▋        | 17/100 [00:02<00:10,  8.10it/s]

0.37911597759291404
[-1.39021579 -0.57714206  0.34228355  0.55227226  2.02610279 -0.51820166
 -0.92687862 -0.          0.74166982  1.48517792 -0.          0.
  0.         -0.         -0.          0.          0.         -0.
 -0.          0.          0.77470791]
[-4.13889113 -2.51838478  0.50633429  3.62197549  5.56442806 -3.88963735
 -2.25385288  0.32055502  3.16526183  5.5209964   0.          0.
  0.          0.         -0.          0.          0.          0.
  0.          0.          0.        ]
max dr: 20.12217545973335
0.9276547611924849
[-0.65232514 -1.63692747 -0.02522424 -0.2186034   3.63574777 -0.26721304
 -1.39402355 -0.          0.49821256  2.85054806  0.         -0.
  0.         -0.         -0.          0.          4.17045907 -0.22141199
 -0.          0.          0.10064274]
[-3.95441744 -2.32699667  0.17768387  3.10099573  5.27296931 -4.21964491
 -2.54710721  0.34828044  3.27016865  5.38884743  0.          2.07072859
  0.8346652   0.         -0.          0.          0.      

 19%|█▉        | 19/100 [00:02<00:10,  7.97it/s]

0.028883757309528302
[-0.74566745 -0.         -0.51318132  0.55662619  0.80530184 -0.41023062
 -0.73433535  0.04719197  0.99884901  0.55047507 -0.         -0.
  0.          0.          0.          0.         -0.          0.
 -0.         -0.          1.17805902]
[-4.16226073 -2.43005073  0.6429498   3.27076045  5.45861256 -4.2160687
 -2.22310482  0.69450379  3.49034442  5.3904354   0.         -0.
 -0.          0.          0.         -0.          0.          0.
  0.         -0.         -0.        ]
max dr: 20.734984143315398
0.29209660043810615
[-1.53368422  0.49122819  0.88337657 -0.05269201  0.72931232 -1.39526603
 -0.         -0.25400882  1.96250714  0.88145393 -0.          4.60762759
 -0.         -0.         -0.         -0.         -0.          0.
  0.         -0.          1.28192855]


 21%|██        | 21/100 [00:02<00:10,  7.22it/s]

[-4.07457728 -2.20214275  0.71829394  3.49455747  5.55245502 -4.15859944
 -2.73386198  0.54201485  3.3239901   5.56860595 -0.         -0.
  0.          0.          0.          0.02019921  0.         -0.
  0.         -0.          0.08792191]
max dr: 46.710309556445345
0.2521024743012885
[-0.58264025  0.          0.17122415  0.71746085  1.55206184 -0.91856873
 -0.40158327 -0.82806326  0.32048599  0.49953947 -0.          0.
  0.          0.         -0.          0.          0.         -0.
 -0.         -0.          1.03399203]
[-4.16757006 -2.20981698  0.48704949  3.35616437  5.35292807 -3.99595308
 -2.34958794  0.32116594  3.47666506  5.51097694  0.          0.
  0.         -0.          0.         -0.         -0.         -0.
 -0.         -0.          0.22925803]
max dr: 18.17680331135656
0.615640926190847


 22%|██▏       | 22/100 [00:02<00:10,  7.38it/s]

[-0.74920626  0.19515662  0.33364725  1.69024114  0.36221312 -0.80170499
 -0.44982313  0.01769877  0.73471587  0.61134996 -0.         -0.
  0.          0.         -0.         -0.         -0.          0.
 -0.          0.          1.59189324]
[-4.08812101 -2.60981118  0.67169614  3.44338972  5.79150284 -4.06339815
 -2.32354779  0.54420385  3.24573615  5.27454456  0.          0.
  0.          0.          0.          0.          0.         -0.
 -0.          0.          0.        ]
max dr: 32.768709491981525
0.3244170575683256
[ 0.         -0.         -0.          0.859364    0.         -0.
 -0.         -0.          0.28576381  0.69460818  0.          0.
  0.         -0.          0.         -0.          0.         -0.
  0.         -0.          1.80656104]
[-4.08669073 -2.4538059   0.18128145  3.64576272  5.14401124 -3.99478226
 -2.23441867  0.38433876  3.72242066  5.31000826  0.          0.
 -0.          0.         -0.          0.          0.         -0.
 -0.          0.          0.23419461

 24%|██▍       | 24/100 [00:03<00:09,  7.73it/s]

0.49576781081100896
[-0.06241648  0.         -0.          0.66520497  0.15354454 -0.80949132
 -0.         -0.          0.          1.00652543 -0.          0.
  0.          0.         -0.         -0.         -0.         -0.
  0.         -0.          1.42312636]
[-4.17119063 -2.39210205  0.39597752  3.53786622  5.51551401 -3.95803153
 -2.53446793  0.33439029  3.4803917   5.49690991  0.         -0.
 -0.          0.          0.          0.         -0.          0.
  0.          0.         -0.        ]
max dr: 16.618159342972312
0.7149636854906658
[-0.90699405  0.         -0.35957786  1.11812981  0.76823722 -0.67399688
 -0.54997511  0.25096563  1.02806856  0.56404903 -0.         -0.
  0.          0.          0.          0.         -0.          0.
 -0.          0.          1.23718852]
[-3.61114545 -2.34678774  0.32270402  3.54986986  5.63487183 -3.84765541
 -2.00293784  0.45290255  3.58233688  5.0691817  -0.64886039  0.
 -0.65348232  0.         -0.         -0.         -1.9754048  -1.6141542
 

 26%|██▌       | 26/100 [00:03<00:09,  7.94it/s]

0.3194537778620913
[-0.68638585 -0.35443572 -0.53613044  1.21694054  1.28864084 -0.54977064
 -0.12064255  0.04836631  0.62325023  0.86743051 -1.9475469  -3.23766581
  0.         -0.         -4.77789945 -0.          0.          0.
 -0.          6.53899298  1.5350254 ]
[-3.96013885e+00 -2.68328401e+00  6.85350918e-01  3.39748864e+00
  5.23190202e+00 -3.98865547e+00 -2.48029266e+00  5.81088608e-01
  3.41231552e+00  5.50565110e+00  0.00000000e+00 -0.00000000e+00
 -0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00 -0.00000000e+00  0.00000000e+00
  2.72932305e-03]
max dr: 8.752268432368696
0.3007769744792824
[-1.20590084 -0.04731408 -0.61093951  0.58257488  0.64635076 -0.97102004
 -0.          0.08665738  1.48699007  0.44258565  0.          0.
 -0.          0.         -0.          0.          0.          0.
  0.         -0.          1.58013516]
[-4.14348719 -2.03539355  0.35718183  3.53125356  5.31217336 -4.13902778
 -2.40734141  0.13419782  3.62

 28%|██▊       | 28/100 [00:03<00:08,  8.05it/s]

0.5093391488218173
[-1.11896998  0.         -0.44793567  2.06185256  1.04938773 -0.91923829
  0.13502422 -0.32500844  1.03116506  1.28640499 -0.          0.
 -0.          0.         -0.         -0.          0.          0.
  0.          0.          1.77295679]
[-3.88493958 -1.88462596  0.27058296  3.41119889  4.84194642 -4.06760285
 -2.15198125  0.51531629  3.28252067  4.92633457  0.          0.
  0.         -0.         -0.         -0.         -0.          0.
 -0.         -0.          0.21086943]
max dr: 50.27674543316785
0.4395462658910022
[-0.5170412  -0.67933116 -0.          0.89099899  1.45479203 -0.12011044
 -1.12272842 -0.26101877  0.57294816  1.58901641  0.          0.
  0.          0.         -0.         -0.          0.          0.
  0.          0.          0.73658489]
[-4.04093222 -2.56390926  0.56660928  3.84615815  5.04503301 -4.15850587
 -2.4512675   0.54980665  3.67610939  5.48211073  0.27128895 -0.
  0.10354948  0.         -0.          1.17397575 -0.          0.
 -0.      

 30%|███       | 30/100 [00:03<00:08,  8.07it/s]

0.8534778127749709
[-0.         -0.5248965  -0.          1.19004846  0.66664495 -0.9092708
  0.         -0.4426148   1.14585676  0.29161331 -0.          0.
 -0.         -0.          0.          0.         -0.          0.
 -0.          0.          1.43338673]
[-3.88507493e+00 -2.28630602e+00  4.34612755e-01  3.49422933e+00
  5.55611755e+00 -4.16058939e+00 -2.09278742e+00  4.52597472e-01
  3.44573363e+00  5.31519933e+00  0.00000000e+00 -0.00000000e+00
 -0.00000000e+00  0.00000000e+00 -1.43855790e+00 -5.15961138e-03
  0.00000000e+00 -0.00000000e+00  2.76929067e+00  1.74859914e+00
  1.79092421e-01]
max dr: 20.77668751361571
0.14646059195753014
[-1.38348250e+00 -6.91627105e-01  1.07854696e-03  5.23810091e-01
  2.32676513e+00 -1.80200276e+00 -1.01502377e+00  1.29294902e+00
  0.00000000e+00  1.18089386e+00 -0.00000000e+00 -0.00000000e+00
 -0.00000000e+00 -0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00 -0.00000000e+00
  1.85685364e-01]
[-3.73954

 32%|███▏      | 32/100 [00:04<00:08,  8.13it/s]

0.2611008162940702
[-0.86628822 -0.40828662 -0.27160967  1.06772056  0.39127078 -0.79607342
  0.          0.          0.7920916   0.         -0.          0.
 -0.         -0.          0.          0.          0.         -0.
  0.          0.          1.7161442 ]
[-4.01957282 -2.39435416  0.53543965  3.82404134  5.11638379 -4.34156225
 -2.34494118  0.4630157   3.40162654  5.50394681 -0.6077935  -0.
 -0.         -1.45238896  0.          0.          0.          0.
  0.         -0.          0.10033603]
max dr: 117.13575925126267
0.31893542193049396
[-1.08861602 -0.          0.46481     1.18273576  1.20711852 -0.55431087
 -0.77057837  0.24506807  0.55968074  0.71639701  0.         -0.
 -1.6903683  -0.         -0.         -0.          0.         -0.
  0.          0.          0.91491587]
[-4.01736773 -2.68437143  0.6693507   3.50878853  5.1988374  -4.18631496
 -2.36312545  0.4683875   3.83766312  5.68968106  0.          0.
 -0.         -0.          0.          0.          0.         -0.
  0.    

 34%|███▍      | 34/100 [00:04<00:08,  8.18it/s]

0.6525703368443612
[-0.86640997  0.01412356 -0.20785058  1.46268826  1.23458172 -0.7569245
 -0.02224789  0.46108238  0.7480188   0.4099461   0.         -0.
 -0.          0.         -0.          0.          0.         -0.
  0.          0.          1.45028742]
[-4.14846208 -2.35865975  0.22543037  3.42609622  5.27395097 -4.15711385
 -1.93783323  0.67075946  3.43065623  5.36569679  0.          0.
 -0.         -0.         -0.          0.         -0.         -0.
 -0.          0.          0.19204037]
max dr: 4.568595220393785
0.5895428273661948
[-6.83083290e-05 -0.00000000e+00 -3.55259318e-01  9.80150306e-01
  5.50203141e-01 -7.25591068e-01 -0.00000000e+00  0.00000000e+00
  9.98936317e-01  1.59090811e+00  0.00000000e+00 -0.00000000e+00
 -0.00000000e+00 -0.00000000e+00 -0.00000000e+00  0.00000000e+00
  0.00000000e+00 -0.00000000e+00  0.00000000e+00  0.00000000e+00
  1.26500817e+00]
[-3.62733725 -2.30249319  0.17934105  3.48316945  5.32335653 -4.06638182
 -2.23085069  0.67586256  3.5539295   5

 36%|███▌      | 36/100 [00:04<00:07,  8.20it/s]

0.36949676488027283
[-0.36859761  0.24329769 -0.27049072  0.98648706  0.73745022 -0.90710792
 -0.13212173  0.3246741   0.76129488  0.76431042  0.          0.
  0.         -0.          0.         -0.         -0.         -0.
  0.         -0.          1.65189414]
[-3.93493282 -2.49068622  0.52801269  3.55785164  5.71289644 -3.87418917
 -2.35310055  0.2615778   3.66069401  5.33535011 -0.         -0.
 -0.         -0.         -0.         -0.         -0.         -0.
  0.          0.          0.02367754]
max dr: 12.507017499376753
0.6738360951169906
[-0.47278603 -0.         -0.          1.17212224  1.65524179 -0.15713729
 -0.         -0.07770201  1.50996747  1.27757844  0.          0.
  0.          0.         -0.         -0.          0.         -0.
  0.         -0.          1.35263415]
[-3.8368325  -2.76772862  0.40439386  3.3390648   5.71468392 -4.21585656
 -2.45977893  0.37822447  3.46538274  5.5925776   0.          0.
 -0.         -0.         -0.         -0.          0.         -0.
  0.    

 38%|███▊      | 38/100 [00:04<00:07,  8.21it/s]

0.22671364430961105
[-0.04754035 -0.37906193  0.38232513  1.49554047  1.38512412 -1.30310273
 -0.83762803  0.41398694  0.83658172  1.28223021  0.          0.
  0.         -3.48173339  0.         -0.         -0.         -0.
 -0.          0.          0.93303373]
[-3.98493619 -2.20886538  0.5049086   3.83135033  5.14134014 -3.92940492
 -2.41859501  0.45508961  3.67075263  5.26841358  0.          0.
  0.          0.          0.         -0.          0.         -0.
 -0.         -0.          0.26341556]
max dr: 8.229285901600779
0.6283552771248744
[-0.04869905 -0.80716294  0.          0.55101215  1.98161817 -0.40550941
 -0.          0.45102265  0.17774706  1.62292615  0.          0.
  0.         -0.         -0.          0.         -0.          0.
 -0.          0.          0.99306989]
[-3.94267756 -1.70756337  0.08960123  3.74739005  5.08915842 -3.62035478
 -1.96859642  0.47935115  3.81888144  5.07282234 -0.         -0.
 -0.         -0.         -0.         -0.         -0.          0.
  0.     

 40%|████      | 40/100 [00:05<00:07,  8.23it/s]

0.12537500259452572
[-0.99487728 -0.          0.          0.46762377  1.21852432 -0.6566274
 -0.24586817 -0.          0.71354465  1.66793738 -0.         -0.
 -0.          0.          0.         -0.         -0.          0.
 -0.          0.          1.2954226 ]
[-3.8202644  -2.37751633  0.63066894  3.29018541  5.51762653 -3.84731397
 -2.56748354  0.6165298   3.61888273  5.24707347 -0.          0.93536372
  0.          2.04322151  0.42468063 -0.         -0.         -0.
 -0.          0.          0.12720286]
max dr: 10.743022228474697
0.35567027623881065
[-0.71287277 -0.31573585  0.30883483  0.82058516  2.46282692 -0.98321163
 -1.32555008  0.0207236  -0.41561738  1.1561015   1.7962499   0.
  0.         -0.         -0.         -0.          0.          6.26518889
  0.          0.          0.3138968 ]
[-3.79634804 -2.37603913  0.39439717  3.1277252   5.33180118 -3.77328362
 -2.44666936  0.58494632  3.5257217   5.181537   -0.          0.
 -0.          0.          0.         -0.          0.     

 42%|████▏     | 42/100 [00:05<00:07,  8.21it/s]

0.2610580981931965
[ 0.          0.          0.          0.98420863  0.56986041 -0.97234068
 -0.         -0.34591656  1.309698    1.01724961 -0.         -0.
 -0.         -0.         -0.         -0.          0.         -0.
  0.         -0.          1.45226919]
[-3.81874124 -2.49762727  0.51710567  3.60336082  5.52259451 -4.05267363
 -2.35244188  0.63008806  3.333997    5.16031677  0.          0.
  0.          0.          0.         -0.         -0.         -0.
  0.         -0.          0.02189771]
max dr: 28.952183703913356
0.6450187071253122
[-0.75531243 -0.21311369 -0.59801686  0.79099119  1.06581318 -0.90107683
 -0.20200059  0.08331361  0.7885909   0.97272729 -0.          0.
  0.          0.          0.          0.          0.          0.
 -0.          0.          1.19382827]
[-3.89611245 -2.44117598  0.27098807  3.4822319   5.16539751 -3.90357317
 -2.33174106  0.58924748  3.7021826   5.28194896  0.         -0.
 -0.          0.         -0.          0.         -0.          0.
  0.     

 44%|████▍     | 44/100 [00:05<00:06,  8.24it/s]

0.2871110836990304
[-0.72553989 -0.54351666 -0.28008883  0.4614598   1.98201318 -1.23548283
 -0.05307622  0.3051152   0.70729137  1.66942606 -0.          0.
  0.         -0.         -0.         -0.          0.          0.
 -0.         -0.          0.76470327]
[-3.95319477 -1.9620866   0.38135007  3.65560768  5.28896739 -4.16405036
 -2.24422654  0.31446837  3.55916679  5.28386123  0.          0.
 -0.         -0.          0.          0.          0.          0.
  0.         -0.          0.30323459]
max dr: 6.789957394229587
0.10144483608258048
[-0.29756464  0.26011692  0.          0.70225596  2.1237905  -0.35622966
 -0.73904746 -0.          1.1777063   1.07027509 -0.         -0.
 -0.          0.          0.          0.          0.          0.
  0.         -0.          0.99036183]
[-3.92581338 -2.18827491  0.48177861  3.60268412  5.2119873  -4.02808359
 -2.63933522  0.43066478  3.31222644  5.52114665 -0.         -0.
 -0.         -0.          1.63575192  0.         -1.2051906  -0.
 -1.62368

 46%|████▌     | 46/100 [00:05<00:06,  8.26it/s]

0.11713925770643752
[-0.69031534 -0.         -0.31400372  0.68255291  0.49396403 -0.
  0.         -0.43939382  0.97647085  1.05185013  0.          0.
  0.          0.         -0.          0.          0.          0.
 -0.         -0.          1.17129342]
[-3.99647103 -2.44982429  0.60724043  3.36322893  5.46347002 -4.03478451
 -2.60085757  0.51808541  3.58716773  5.42394878  0.82928818 -0.93513519
  0.          0.27690796 -0.          0.         -0.00923745  0.
 -0.         -0.93520955 -0.        ]
max dr: 20.441190958798273
0.5895899857130207
[-1.50219041 -0.60576728 -0.05433679  2.38107044  1.40448921 -1.0650353
 -0.41598522 -1.01080142  1.06323332  0.73017118 10.74280379  1.62383353
  0.          7.20539931  0.         -4.21510413  0.         -0.
  6.64426413  3.28502865  1.56646585]
[-3.87993047 -2.20705649  0.12263782  3.66659866  5.45911065 -4.1941278
 -2.09495149  0.26545182  3.97757624  4.78624856 -0.         -0.35336592
  0.         -2.21541232 -0.         -0.14718088 -0.5861479

 48%|████▊     | 48/100 [00:05<00:06,  8.24it/s]

0.04587596441816044
[-0.62837356 -0.         -0.          0.47354245  1.06034651 -1.25035275
 -0.3079866   0.06207468  0.74901188  1.61053101 -0.         -0.
  0.          0.         -0.         -0.          0.          0.
  0.         -0.          1.10644808]
[-3.76668592 -1.67975213  0.56886077  3.94684962  4.96648104 -3.90370547
 -2.20455139  0.28529462  3.88619292  5.02429444  0.96465651 -0.13402541
  1.16056406  0.         -0.         -0.          0.          0.
  0.          0.          0.48640754]
max dr: 8.524305809254765
0.5839875021557521
[-0.75774122  0.61023048  0.          0.30148786  0.65811336 -0.91114177
 -0.33379618 -0.04606956  1.10136171  0.         -0.         -0.
 -0.          0.          0.         -0.         -0.          0.
 -0.         -0.          1.22056373]
[-4.02867362 -2.09154962  0.61715186  3.57244142  5.40927468 -3.994724
 -2.71190667  0.37511286  3.30796437  5.44995126  0.         -0.
 -0.          0.          0.         -0.          0.         -0.
 -0

 50%|█████     | 50/100 [00:06<00:06,  8.16it/s]

0.469897609687178
[-0.3107425  -0.40218576 -0.          0.45654011  0.84948867 -1.09351568
 -0.5196002   0.          1.38981238  1.85422683  0.         -0.
 -0.         -0.          0.         -0.          0.          0.
  0.         -0.          1.23438471]
[-3.92570371 -2.23227231  0.47993424  3.52236646  5.28708591 -4.01039126
 -2.25732767  0.19893775  3.28198496  5.20419965  0.          0.
  0.         -0.         -0.92716363  0.          0.          0.
  0.01971699 -1.52499037  0.34225326]
max dr: 4.84876566744509
0.14904784877999988
[-0.          0.         -0.23097828  0.86344721  1.07147959 -0.57601125
 -0.23565559 -0.69318634  1.14062507  0.83147262 -0.          0.
  0.          0.         -0.         -0.          0.          0.
 -0.          0.          1.79221508]
[-4.16360227 -2.16023166  0.49106863  3.74138113  5.16406528 -4.20328754
 -2.41535622  0.52865989  3.54251866  5.26045328  0.         -0.
  0.          0.          0.         -0.          0.          0.
 -0.       

 52%|█████▏    | 52/100 [00:06<00:05,  8.16it/s]

0.7463589417985335
[-0.61164284 -0.30977892 -0.78735429  1.6400881   1.37885214  0.18405738
 -0.52787973  0.50595499  1.63079016 -0.         -0.          0.
  0.          0.          0.          0.         -0.          0.
 -0.          0.          1.79318112]
[-3.89122605 -2.50792334  0.253847    3.50741607  5.50586317 -3.90058011
 -2.59249137  0.19766808  3.44977184  5.24394364 -0.          0.6055786
 -0.         -0.          0.          0.         -0.          0.
 -0.         -0.          0.00833283]
max dr: 15.480054976010727
0.4032644760768236
[-1.36211381 -0.          0.          0.62675792  0.68565209 -1.07348039
 -0.46872208  0.00887236  0.18895176  1.35122275  0.         -0.
  0.          0.         -0.          0.         -0.         -0.
  0.         -0.          1.12344919]
[-3.78604567 -2.35392123  0.51950446  3.59750776  5.08970381 -3.72544029
 -2.59890738  0.6486279   3.60362425  5.53653764  0.          0.
 -0.          0.79369845  0.         -0.         -1.1697588   0.
 -

 54%|█████▍    | 54/100 [00:06<00:05,  8.17it/s]

0.32294611462597245
[-0.49791343  0.          0.          0.36764357  1.20012053 -2.49301815
 -0.         -0.          1.05877633  1.05309109 -0.          0.
 -0.         -0.         -0.         -0.          0.         -0.
 -0.          0.          1.00759947]
[-4.04545522 -2.19679626  0.25783444  3.71762592  4.62631128 -3.80847247
 -2.1659612   0.46626831  3.61930093  5.05128116 -0.          0.
 -0.          0.          0.         -0.         -0.          0.
 -0.          0.          0.39328407]
max dr: 6.380415741129334
0.7815950494557377
[-0.01916597 -0.         -0.          0.          0.76127179  0.
  0.         -0.          1.06162471  0.66360422  0.         -0.
 -0.          0.          0.          0.         -0.          0.
  0.          0.          1.16094902]
[-3.88637785 -2.41505111  0.50617888  3.40742053  5.56806707 -4.15674752
 -2.37721611  0.44745278  3.10821029  5.86125584 -0.          0.
 -0.          0.         -0.          0.          0.          0.
  0.         -0. 

 56%|█████▌    | 56/100 [00:06<00:05,  8.15it/s]

0.6753031133179735
[-0.93643076 -0.          0.          0.28805128  1.32660593 -0.94920151
 -0.2420536  -0.57835057  1.29577197  1.13761924 -0.          0.
 -0.          0.          0.          0.          0.          0.
 -0.          0.          1.18517627]
[-3.93096741 -2.46557589  0.65513439  3.61898695  5.3445382  -4.09356255
 -2.44554417  0.31462932  3.04325775  5.11137683 -0.         -0.
 -0.          0.         -0.          0.          0.         -0.
  0.         -0.          0.148083  ]
max dr: 13.562560515674177
0.5675243034808889
[-1.25807163 -2.13180867  0.66077253  0.2347191   3.07120542 -1.18107344
 -2.18040296 -0.          0.74570484  2.67789905 -0.          0.
 -0.          0.          0.         -0.         -0.         -0.6298249
  3.71850098  0.          0.24778513]
[-4.16394179 -2.31020794  0.50479253  3.52087605  5.61109463 -3.98852794
 -2.55175997  0.39716543  3.47109498  5.44101703 -0.         -0.
  0.         -0.          0.         -0.          0.         -0.
 -

 58%|█████▊    | 58/100 [00:07<00:05,  8.18it/s]

0.2263956878407941
[-0.01708199 -0.58107964 -0.86234731  0.97776454  1.14152656 -1.48568096
  0.6518071  -0.23815532  0.4411258   1.77296856  0.         -0.
 -0.          0.          0.         -0.         -0.          0.
  2.76814937 -0.          1.51629378]
[-3.90618141 -2.78419722  0.55928166  3.49504567  5.28479839 -4.20986568
 -2.28216018  0.82778903  3.55496366  5.6213376  -0.         -0.
  0.         -0.          0.         -0.         -0.         -0.
 -0.77230473  0.         -0.01366296]
max dr: 30.699455506226442
0.17899608442330195
[-0.         -0.14165995  0.          1.27026884  0.20079653 -1.14833573
  0.44158734 -0.13138779  0.94354069  0.15505889  0.         -0.
 -0.         -0.         -0.          0.          0.          0.
  0.          0.          1.33690162]
[-3.85835895 -1.88530767  0.38286797  3.59918947  4.96157939 -3.92523402
 -2.31660771  0.05822807  3.53821104  5.25731131 -0.          0.
  0.          0.          1.40096512 -0.         -0.         -0.
  0.    

 60%|██████    | 60/100 [00:07<00:04,  8.21it/s]

0.15665465161296432
[-0.38326465 -0.         -0.          1.0431958   0.83003858 -0.2892601
 -0.          0.          0.          1.73215547 -0.         -0.
 -0.          0.          0.         -0.          0.          0.
 -0.         -0.          1.3812976 ]
[-4.20746584 -2.07456576  0.33513923  3.62770472  5.09314237 -3.95388994
 -2.13675181  0.          3.43751231  5.18649451 -0.         -0.
  0.          0.         -0.          0.         -0.          0.
  0.          0.          0.36512635]
max dr: 51.37957064136837
0.46852731317304697
[-0.         -0.19521508 -0.3321596   0.          0.51460672 -0.
 -0.          0.          1.4880195   0.61879421 -0.         -0.
  0.         -0.          0.          0.          0.         -0.
 -0.         -0.          1.77210192]
[-3.9652097  -2.61460862  0.48732165  3.48388609  5.6392925  -4.06861879
 -2.63023979  0.53571601  3.5150143   5.47959105 -0.          0.
  0.         -0.          0.         -0.         -0.         -0.
  0.          0. 

 62%|██████▏   | 62/100 [00:07<00:04,  8.21it/s]

0.34847266724468784
[-0.76099775 -0.          0.31786398  1.50903865  0.06050304 -1.05844817
 -0.40022543 -0.          1.26661796  0.48193595 -0.         -0.
  0.          0.         -0.          0.         -0.          0.
 -0.         -0.          1.48264256]
[-3.87632499 -2.23425865  0.76651629  3.59373279  5.17967013 -3.81340204
 -2.61456113  0.22007708  3.27851048  5.34349128 -2.38071792 -0.95746772
 -0.58404412  0.83209503 -0.         -0.          0.89632588  0.
 -0.         -0.17095728  0.04295009]
max dr: 6.659938937735449
0.3843230915833761
[-0.30062712 -0.         -0.82603915  0.75648749  0.56111636 -0.68757108
 -0.24677885  0.          1.64725553  1.45051879 -0.         -0.
 -0.         -0.         -0.         -0.         -0.          0.
 -0.         -0.          1.50742513]
[-3.93400642 -2.40772704  0.09074122  3.59901894  5.28586089 -3.94816483
 -2.08094602  0.33681116  3.96874487  5.09962614 -0.         -0.
  0.         -0.         -0.          0.         -0.         -0.
 

 64%|██████▍   | 64/100 [00:07<00:04,  8.17it/s]

0.0759119875575549
[-2.07098135  0.         -0.11275146  1.70899689  0.54472622 -0.83794606
 -0.78901172  0.          1.05705992  0.87658651 -0.         -0.
  0.         -0.          0.         -0.          0.         -0.
  0.         -0.          1.52199464]
[-3.72589513 -2.3749357   0.4061598   3.3248211   5.15385039 -3.68345471
 -2.77846776  0.28149169  3.25619373  5.21762905 -0.         -0.
 -0.          0.          0.          0.         -0.         -0.
 -0.          0.          0.01160232]
max dr: 15.69287878426807
0.6746393338412047
[-1.48283844  0.68734822 -0.          1.7930707   0.73103905 -0.49596415
 -1.56689221 -0.21518353  1.22417616  0.25501192 -0.          0.28531291
  0.         -0.         13.76118019  0.          0.          0.
 -0.         -0.          1.33957092]
[-4.12768043 -2.19247015  0.07504183  3.58624475  5.37007574 -4.15403048
 -2.45856627  0.2254423   3.59279688  5.24795243  0.         -1.06495441
 -0.          0.          0.         -0.         -0.       

 66%|██████▌   | 66/100 [00:08<00:04,  8.20it/s]

0.42466067684676934
[-0.99781219 -0.35371945 -0.          2.07190589  1.47540458 -0.
 -0.         -0.          0.43583827  0.69896586  0.         -0.
  0.         -0.          0.          0.          0.          0.
 -0.         -0.          1.21054022]
[-3.94514144 -2.54614345  0.64377845  3.17979236  5.43459097 -3.76602144
 -2.16451297  0.35445235  3.25414147  5.6418716  -0.          0.
 -0.          1.3465418  -0.          0.         -0.57066452  0.
  0.          0.         -0.        ]
max dr: 65.36620795363352
0.30463705699527566
[-1.20210401 -1.17405674  0.281719    0.55522709  0.79100154 -1.58722606
 -0.81614898  0.85825625  0.          2.65713319 -0.          0.
 -0.          0.         -0.          0.          0.          0.
 -0.          0.          0.03846506]
[-3.74969716 -2.09718595  0.30358812  3.87090203  5.16557094 -3.96010941
 -2.09490005  0.30121049  3.64165552  5.13320979 -0.          0.
  0.          0.         -2.12280203  0.         -0.          0.
 -0.         -0.

 68%|██████▊   | 68/100 [00:08<00:03,  8.17it/s]

0.10692887917570515
[-0.69504047 -0.57501279 -0.07483634  0.64154982  1.76030212 -1.04486852
 -0.         -0.71691492  1.00319257 -0.          0.          0.
  0.         -0.         -0.          0.         -0.          0.
 -0.          0.          1.07451351]
[-4.04291896 -2.67613839  0.71398474  3.21924317  5.60339377 -4.21912922
 -2.7775151   0.76674012  3.01360012  5.79062988  1.77085223 -0.
  0.          0.         -0.          0.         -1.39141488  0.
  0.          0.         -0.25559637]
max dr: 23.09244014770133
0.6202271826367229
[-0.70620169 -0.39832646 -0.41412933  1.10724424  1.64995305 -1.4795763
 -1.19457216  0.76871374  0.85111297  2.14216215 -0.         -0.
 -0.         -0.          0.         -0.          0.          0.
 -0.          0.          0.94831916]
[-3.9537841  -2.53343932  0.3020695   3.45614044  5.52143549 -3.93640091
 -2.59818911  0.42917582  3.57242501  5.32436587  0.         -0.
  0.          0.         -0.         -0.          0.          0.
  0.      

 70%|███████   | 70/100 [00:08<00:03,  8.16it/s]

0.9050420735329356
[-1.21407878 -0.37026982 -0.          0.43013965  1.70830863 -1.37150696
  0.         -0.47191627  0.85869808  0.67728642 -0.          0.
 -0.          0.         -0.          0.         -0.         -0.
 -0.         -0.          1.12422097]
[-4.04427512 -2.03025936  0.25319061  3.77658326  5.00697697 -3.89516595
 -2.16560821  0.52880407  3.5637043   5.2976746   0.         -0.
 -0.         -0.         -0.         -0.          0.          0.
 -0.          0.          0.18002831]
max dr: 3.316133638693808
0.0032053706869545895
[-1.78408678 -0.32162936  1.14449558  0.21812299  2.06924506 -0.60705664
 -0.34603373  0.20163677  1.13876167  1.09886834  0.         -0.
  0.         -0.          0.         -0.          0.         -0.
 -5.30862615 -0.          0.33196717]
[-3.95478034 -2.4832369   0.43864749  3.27847723  5.07316333 -3.94928025
 -2.47773537  0.53139105  3.5879368   5.55213365  0.          0.
  0.         -0.          0.          0.77539852 -0.          0.
  0.   

 72%|███████▏  | 72/100 [00:08<00:03,  8.14it/s]

0.5250005582000543
[-1.16215046  0.37534621  0.88732257  0.23154832  0.89566508 -1.83693979
  0.78324423  0.66229332  0.59352519  1.15252761  4.04056309  0.4019742
  0.         -7.67729805  1.80286674  0.         -0.          7.65430747
 -4.45146489 -0.          1.17564742]
[-4.17918315 -2.2968506   0.52870976  3.84497126  5.03666011 -3.86149503
 -2.29146353  0.26826496  3.38569449  5.09512621 -0.         -0.
  0.         -0.         -0.          0.          0.         -0.22432358
 -0.         -0.          0.21641571]
max dr: 25.223440981848313
0.6852191478650025
[-2.28767307 -1.52396214  0.53550414  1.1642293   2.45917471 -0.82052115
 -0.32013253  0.75922388  0.4293394   1.54355876  0.         -0.
 -0.          0.         -0.         -0.         -0.         -0.
 -0.          0.          0.45942638]
[-4.11382533 -2.26599958  0.46970687  3.26951585  5.46324519 -3.90620381
 -2.18327204  0.42751423  3.71718912  5.41844517 -0.         -0.
 -0.         -0.15168906 -0.         -0.          0

 74%|███████▍  | 74/100 [00:09<00:03,  8.20it/s]

0.1508819282095898
[-0.78657252 -0.         -0.          0.9921906   2.75756954 -1.24827417
  0.          0.          1.03915294  0.52036486  0.          0.
  0.         -0.         -0.          0.          0.          0.
 -0.          0.          0.91150052]
[-3.80078377 -2.4752639   0.64433611  3.40693612  5.29889598 -3.92827536
 -2.45002734  0.61932425  3.51836275  5.34051328 -0.         -0.
  0.         -0.         -0.          0.          0.          0.
 -0.         -0.          0.06512083]
max dr: 25.515270175631414
0.184983882741781
[-1.41478524 -0.31880938 -0.          0.7073073   1.65410592 -0.58105456
 -0.97824783  0.          1.33540692  1.6225342  -0.         -0.
  0.          0.          0.          0.         -0.          0.
  0.         -0.          0.69206417]
[-3.98270877 -2.28159962  0.47800773  3.50396181  5.5993458  -4.29391921
 -2.61868979  0.25428146  3.5184484   5.3900497   0.          0.
 -0.          0.         -0.82227295  0.          0.97943312 -0.
 -0.      

 76%|███████▌  | 76/100 [00:09<00:02,  8.22it/s]

0.18782148944811428
[-0.79990987 -0.          0.42981541  0.          0.56415806 -0.93275502
  0.         -0.00723109  1.22574844  0.50839973  0.         -0.
  0.         -0.         -0.          0.         -0.          0.
 -0.          0.          1.51558567]
[-3.93772247 -2.16419082  0.26112988  3.61007139  5.30567159 -4.05441501
 -2.12188322  0.14150013  3.62085135  5.02124025  2.68990033 -0.10801172
 -0.         -0.35940343 -0.         -0.          0.         -0.
 -0.          0.          0.32191413]
max dr: 6.466806544218636
0.16325188072233376
[-0.         -0.          0.77288051  0.47207124  0.7936538  -0.99539332
  0.08954414  0.2794067   0.75923967  1.41899958 -0.          0.
  0.          0.          0.          0.          0.         -0.
 -0.         -0.          1.48755769]
[-3.85714356 -2.26039009  0.31458328  3.53444692  5.39159766 -4.32864931
 -2.27613502  0.46586996  3.67374318  5.10305835 -0.          1.26451744
 -0.         -0.31038239 -0.         -0.         -0.     

 78%|███████▊  | 78/100 [00:09<00:02,  8.27it/s]

0.46110704274455583
[-0.72549897 -0.          0.          0.65734322  2.29632687 -0.96273466
 -0.86114802 -0.01447339  0.8023661   1.46112309  0.          0.
  0.          0.          0.          0.         -0.          0.
 -0.         -0.          0.99846169]
[-3.75350062 -2.75023426  0.69758134  3.38336839  5.52849855 -3.82517464
 -2.35286816  0.31786567  3.35299312  5.71760589 -0.         -0.
  0.         -0.         -0.          0.          0.         -0.
  0.          0.          0.        ]
max dr: 16.966891325263155
0.6436647108667053
[-0.48880979 -0.28821214  0.          0.72894736  1.73426217 -1.46482841
 -0.84415375 -0.          0.19408453  1.34340482  0.          0.
  0.          0.         -0.          0.          0.          0.
  0.          0.          1.0379749 ]
[-3.93400089 -2.83838833  0.62938938  3.24716355  5.49219403 -3.92601934
 -3.27207813  0.49929265  3.30102126  5.44946985 -0.          0.
  2.26690219  0.         -2.08283567  0.84127595 -0.          0.
  1.9513

 80%|████████  | 80/100 [00:09<00:02,  8.18it/s]

0.8174364686944924
[-1.16018728 -1.15515051  1.24063785  0.98257108  2.84561549 -0.98404053
 -1.68117657 -0.24065904  0.31335267  1.5604707   0.          8.55680013
 -0.          0.          6.7464827  -0.         -0.         -0.
  3.25392462 -0.          0.31450645]
[-4.06140268 -2.33487865  0.65399465  3.30256628  5.2717093  -4.02831826
 -2.61800595  0.5848448   3.53977811  5.28356205  0.          0.
 -0.         -0.          0.         -0.         -0.          0.
  0.         -0.         -0.        ]
max dr: 21.47781796331378
0.703770438677242
[ 0.56337354 -1.44816353 -0.         -0.41804765  2.50073034 -1.55169119
 -0.6473604  -0.          0.52895281  2.14362614  0.         -0.
 -0.          0.          0.          0.         -0.         -0.
  0.         -0.          0.2314025 ]
[-3.71331825 -2.55544929  0.51944545  3.45326667  5.5321579  -4.29574155
 -2.26125875  0.65105182  3.70183878  5.04025288 -1.18703438  0.
  0.         -0.         -0.         -0.          0.32080615 -0.
 -0

 82%|████████▏ | 82/100 [00:10<00:02,  8.20it/s]

0.4814923560963992
[-1.36671665 -0.03738954 -0.44353936  1.32890814  1.89123043 -0.62310057
 -0.          0.          0.12922327  0.87983638 -0.         -0.
  0.         -0.          0.         -0.         -0.          0.
  0.          0.          1.01470087]
[-4.01333286 -2.33813453  0.48645107  3.60660369  5.17669468 -3.88719375
 -2.49146315  0.22633608  3.43910355  5.75938118  0.          0.6510902
 -0.         -0.         -0.          0.43767163  0.          0.
 -0.         -0.42772912 -0.02393732]
max dr: 6.213990588756544
0.15661035041716032
[-0.58680769  0.         -0.80378315  2.04463989  0.99741836 -1.73539777
 -1.05729966 -0.81600373  1.0966208   0.445727    0.          0.
  0.          0.         -0.         -0.         -0.         -0.
 -0.         -0.          1.48152009]
[-4.04898255 -2.05191912  0.30855048  3.45041682  5.3932755  -3.68422085
 -2.22236037  0.16204877  3.72544878  5.23124944  0.         -0.
  0.          0.         -0.         -0.          0.         -0.
  

 84%|████████▍ | 84/100 [00:10<00:01,  8.21it/s]

0.4352161113124535
[-0.18214008 -0.60145017 -0.          1.15646745  1.79152904 -0.89414409
  0.         -0.05036418  0.55190455  1.38505581  0.         -0.
 -0.         -0.         -0.         -0.          0.          0.
 -0.          0.          1.26793141]
[-3.67131612 -2.34602581  0.40186871  3.44858045  5.2376928  -3.93646561
 -2.39461892  0.58081066  3.33073429  5.27195741  0.         -0.
 -0.          0.          0.          0.          0.         -0.
  0.         -0.          0.08943611]
max dr: 9.207085217188064
0.6208714742087547
[-1.9095476  -1.13862055 -0.02632688  0.40031312  1.50383433 -0.62908587
  0.          0.          0.89576358  1.1833406  -0.         -0.
  0.         -0.          0.         -0.         -0.          0.
 -0.          0.          1.04619073]
[-4.07897582 -2.42989843  0.70171789  3.47687717  5.67637857 -4.10586757
 -2.59330262  0.39329023  3.40785166  5.75653221  0.          0.84626593
  0.20773758  0.          0.         -0.         -0.         -0.
  

 86%|████████▌ | 86/100 [00:10<00:01,  8.22it/s]

0.37199565584692573
[ -1.46842992  -0.75720275   0.57694161   1.40836497   1.67430217
  -1.57282766  -0.55787443  -1.06807271   0.90270019   2.37498065
  -8.12099996 -10.73984065  -1.56480359   1.06651742   0.
   8.52891725   2.84615645   0.          -7.38240542  -0.61691774
   0.8196337 ]
[-3.81324764 -2.4368648   0.43612271  3.55687244  5.25404969 -3.79479667
 -2.44948021  0.34426053  3.45237133  5.15638806 -0.         -0.
 -0.77899704  0.          0.89632649  0.          0.         -0.
  0.         -0.          0.07633462]
max dr: 56.76060474146235
0.3835804270226848
[-0.51833547  0.58616079 -0.80411289  1.41146501  0.67711386 -1.22879527
 -0.46905393 -0.40288169  1.41582915  0.15798253  0.         -0.
  3.07739647  0.          0.          0.         -0.          0.
 -5.78973653 -0.          1.36149878]
[-3.79533073 -2.58028765  0.55924793  3.6058373   5.29172307 -4.37006236
 -2.27975027  0.33816912  3.91377539  5.22184959  1.2410421  -0.
  3.04803442  0.         -0.29040443 -0.1993

 88%|████████▊ | 88/100 [00:10<00:01,  8.18it/s]

0.34781433737522083
[-1.08724146 -0.          0.          1.18413732  0.02308546 -0.53656715
 -0.         -1.41640838  0.98535804  0.27666263  0.          0.
 -0.          0.         -0.          0.          0.         -0.
 -0.         -0.          1.61493272]
[-3.89282627 -2.15159079  0.15729515  3.80559254  5.10638912 -3.86404066
 -2.26843318  0.09604018  3.32645873  5.27994142 -0.         -0.
 -0.         -0.         -0.          0.         -0.          0.
  0.          0.          0.251761  ]
max dr: 4.916560214704964
0.38410813367534813
[ -0.53161786  -0.28975662   0.44882967   1.1709757    2.86323854
  -0.58255087   0.06857279   1.60078515   0.52384017   1.46909242
  -0.           0.80248449   0.          -4.20725184 -10.80276819
  -9.66011835  -6.21804328   4.569829    -0.           2.89742043
   0.93443719]
[-3.79109577 -2.16482835  0.57555144  3.29430359  5.26428954 -4.16062455
 -1.92661206  0.40964611  3.29046268  5.39537737  0.         -0.
 -0.          0.          0.       

 90%|█████████ | 90/100 [00:11<00:01,  8.16it/s]

0.16718563991129676
[-0.95793371  0.          0.          1.05966385  1.20926413 -0.
 -0.          0.          0.          0.67370294 -0.         -0.
  0.         -0.         -0.         -0.         -0.         -0.
  0.         -0.          1.43482049]
[-4.15240213 -2.53554632  0.27330069  3.77696524  5.37641172 -3.66012928
 -2.8171885   0.35445551  3.33200763  5.0248034  -1.8764778  -0.
 -0.          0.          0.          0.         -0.82390392  0.
 -1.99624146  1.17000393  0.04696952]
max dr: 49.273515032645555
0.32062588246290735
[-0.76370407 -1.20396601 -0.03815044  1.52820793  2.03912269  0.06820031
 -1.06728092 -0.69145283  0.18731685  1.57725565 -2.71253138  0.
 -0.         -0.         -0.          0.          0.         -0.
 -0.          0.          0.98288017]
[-3.95464486 -2.18147677  0.17904396  3.71977072  5.41510486 -3.93390743
 -2.35316238  0.41830223  3.54446481  5.40183007  0.          0.
  0.         -0.          0.         -0.          0.         -0.
  0.         -0

 92%|█████████▏| 92/100 [00:11<00:00,  8.21it/s]

0.48984396037502953
[-0.          0.         -0.11767479  0.54369616  0.3669091  -0.
 -0.          0.          1.43401037  0.05313836  0.         -0.
 -0.         -0.          0.         -0.          0.         -0.
  0.          0.          1.57044855]
[-4.01681033 -1.92483098  0.33527827  3.46155844  5.30831442 -4.19747623
 -2.07891706  0.19941245  3.399565    5.41027255  0.          0.
 -0.          0.          0.          0.         -0.          0.
 -0.          0.          0.26557275]
max dr: 16.04158431919441
0.5394248050680317
[-0.84663556 -0.          0.          0.5554978   2.29258452 -0.07978297
 -1.5843828  -0.03494782  1.36651997  1.20220849 -0.         -0.
  0.         -0.         -0.          0.         -0.          0.
  0.          0.          0.70327729]
[-4.06437282 -2.53207535  1.00298886  3.63673475  5.72390126 -4.06140789
 -2.50359499  0.43297393  3.46216668  5.98858934  0.         -0.
  0.          0.         -0.          0.         -0.         -0.
  0.         -0. 

 94%|█████████▍| 94/100 [00:11<00:00,  8.14it/s]

0.06242628903923231
[-0.55932831  0.35138823 -0.75725669  1.06592062  1.11142287 -1.08381557
 -0.48133201 -0.89842259  1.48125312  1.11504663  0.         -0.
  1.78929719 -0.          0.         -3.42935348  0.          0.
 -5.76189624  5.16904726  1.56961694]
[-3.5868607  -2.46772481  0.2470588   3.29591793  5.55681963 -4.07983681
 -2.42134212  0.35800708  3.53362387  5.50232278  0.          0.
 -0.          0.          0.          0.         -0.          0.
 -0.          0.         -0.        ]
max dr: 10.772440207991483
0.15950766961985385
[-0.44536098 -0.30232182  0.          0.82447181  0.78340095 -1.08388251
  0.47831749 -0.48802498  1.8035265   0.00220884 -0.          0.
  0.         -0.         -0.          0.          0.         -0.
 -0.         -0.          1.61796517]
[-4.05153253 -2.33589783  0.77179183  3.6788869   5.72065352 -3.89185262
 -2.4462785   0.29167146  3.22586645  5.33271421  0.         -0.
  0.          0.          0.         -0.         -0.          0.
 -0.   

 96%|█████████▌| 96/100 [00:11<00:00,  8.17it/s]

0.2760963614993237
[-0.42386855 -0.49348094 -0.          1.23697035  1.38192101 -0.71567561
 -0.         -0.          0.36176871  1.76894054  0.         -0.
  0.          0.         -0.          0.         -0.         -0.
  0.          0.          1.13564655]
[-4.01499503 -2.73858502  0.68475686  3.65514355  5.33675412 -4.13489705
 -2.31146102  0.28037038  3.57664587  5.44048341 -0.         -0.82706973
  0.         -0.          0.          0.         -0.          0.
 -0.         -1.83905737  0.0277337 ]
max dr: 6.705546007450942
0.21804226886814826
[-0.49042767  0.          0.07703466  0.59511239  0.46075093 -0.59937222
  0.          0.          1.25191196  0.         -0.          0.
  0.          0.         -0.         -0.          0.          0.
  0.          0.          1.83531316]
[-3.71949304 -2.43440621  0.22048728  3.4455151   5.42195786 -3.43605725
 -2.35766997  0.42885235  3.44753519  5.46029315 -0.         -0.
  0.         -0.          0.         -0.          0.          0.
 

 98%|█████████▊| 98/100 [00:12<00:00,  8.17it/s]

0.8314947455613325
[-0.39103019 -0.         -0.          1.1352752   0.40776433 -0.82532391
  0.          0.          1.05255546  0.98265742  0.         -0.
  0.          0.          0.          0.          0.          0.
 -0.          0.          1.36584576]
[-3.92660538 -2.48686537  0.55209484  3.54702118  5.47724559 -3.9564347
 -2.34445359  0.78691855  3.57824294  5.38861354 -0.          0.
 -0.         -0.         -0.          0.          0.         -0.
 -0.          0.         -0.10201102]
max dr: 12.043258598843604
0.34465508643274345
[-0.2729133  -0.         -0.02219904  0.49418257  0.67901457 -0.7953807
 -0.51040313 -0.48737595  0.57595791  0.92815634  0.         -0.
  0.         -0.         -0.         -0.          0.         -0.
 -0.         -0.          1.4773036 ]
[-4.14907840e+00 -2.46750927e+00  6.02874441e-01  3.30115956e+00
  5.59069797e+00 -3.87876569e+00 -2.22076514e+00  5.72288261e-01
  3.52816879e+00  5.61055702e+00  6.21429839e-01  6.51404621e-01
 -1.31312482e+00  

100%|██████████| 100/100 [00:12<00:00,  8.10it/s]

0.3094734034270331
[-0.71547311  0.          0.3034285   1.64302686  0.         -0.
 -0.         -0.10678959  1.23558281  0.84670689  0.         -0.
 -0.         -0.          0.          0.         -0.         -0.
 -0.          0.          1.32220278]
[-3.81219497 -2.4456239   0.44525427  3.61458358  5.13916846 -4.05262524
 -2.12071695  0.          3.61190758  5.31214992 -0.          0.
  0.          0.          0.         -0.          0.          0.
  0.         -2.4135895   0.26266692]
max dr: 16.773998477998624
0.2692371621757512





In [72]:
c/100

0.05

## Check density ratio scale

In [55]:
'''
Define some parameters for data generation
'''
nt, p,q = 1000, 10, 100
est_size = 500
ns = est_size + 200
est_ratio = 200/ns


s = np.array([-1, -0.5, 0, 1, 1.5,-1, -0.5, 0, 1, 1.5])
t = np.array([ -1, -1, 0.5 , 0.5, 1,-1, -1, 0.5 , 0.5, 1])
t = (t+3*s)
u = np.array([ 0, -1, 0.5, -0.5, 1,0, -1, 0.5, -0.5, 1])

In [60]:
def xz_ratio(X, Z, z_diff = 1):
    ratios = []
    for i in range(X.shape[0]):
        zs_prob = multivariate_normal.pdf(Z[i][:p], mean = 0*np.ones(p), cov= 1*np.identity(p))
        zt_prob = multivariate_normal.pdf(Z[i][:p], mean = z_diff*np.ones(p), cov= 1*np.identity(p))
        ratios.append((zt_prob)/(zs_prob))
    
    return np.array(ratios)
Y_source, X_source, V_source, Z_source, Y_target, X_target, V_target, Z_target = generate(ns,nt, p,q, s, t, u,Alpha_s=0, Alpha_t = 100,effect=1, z_diff = 1)

est_v_dr,X_s,Z_s,V_s,Y_s = est_v_ratio(X_source, Z_source, V_source, Y_source, X_target, Z_target, V_target, test_size=est_ratio)
est_dr = xz_ratio(X_s, Z_s) * est_v_dr
true_dr = true_density_ratio(X_s, Z_s, V_s, s, t, p, q)

[-1.36211716 -0.24089191 -0.1545009   1.36262841  1.24278281 -0.7649584
 -0.17888822  0.37845853  0.6361632   1.48072965 -0.         -0.
 -0.         -0.         -0.         -0.         -0.          0.
  0.         -0.          0.         -0.10135055 -0.         -0.
  0.         -0.07143916  0.          0.          0.          0.
  0.         -0.         -0.         -0.         -0.         -0.
 -0.          0.         -0.         -0.          0.          0.
 -0.          0.          0.          0.          0.          0.
 -0.         -0.         -0.         -0.          0.         -0.
 -0.          0.         -0.         -0.          0.         -0.
  0.          0.46859665 -0.         -0.         -0.         -1.58146136
  0.         -0.         -0.          1.0756933  -0.          0.
 -0.         -0.         -0.         -0.         -0.         -0.
 -0.          0.         -0.          0.         -0.          0.
  0.         -0.         -0.15910225  0.          0.         -1.54104467
 -

In [61]:
print(max(true_dr))
print(np.mean(abs((true_dr - est_dr.reshape(-1,1)))))

[13.07777322]
0.8244227702728405


In [12]:
sum = 0
for i in range(200):
    sum+=I(Bin_pvalue(V_source[i], X_source[i], Z_source[i], V_source[i], Model_X, E_X, L=3, K=20), 3)*True_d[i]
sum

array([0.])

In [13]:
sum = 0
for i in range(1000):
    sum+=I(Bin_pvalue(V_target[i], X_target[i], Z_target[i], V_target[i],Model_X,E_X, L=3, K=20), 3)
sum*200/1000

0.0

In [41]:
WV, statistic, a,b,c,g = PCRtest_Powen(Y_source, X_source, Z_source, V_source, X_target, Z_target, V_target, model_X = Model_X, E_X = E_X, L=3, K=20, density_ratio = True_d)

# Check

In [189]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import norm, chi2, multivariate_normal
from sklearn import linear_model
from sklearn.linear_model import LassoCV
from sklearn.linear_model import Lasso
from sklearn.model_selection import train_test_split
from tqdm import trange
from densratio import densratio
from numpy import linalg as la
import momentchi2 as mchi
from sklearn.linear_model import LinearRegression


# Function for estimate the conditional model of V|X,Z
from sklearn.linear_model import ElasticNetCV


def est_v_ratio(X_s, Z_s, V_s, Y_s, X_t, Z_t, V_t, test_size=0.5):
    '''
    Input:
    the X,Z,V data on both source and target (ndarray)
    Return:
    v_ratio_test: the density ratio for V|X,Z(ndarray)
    X_s_test,Z_s_test, V_s_test, Y_s_test :the (X,Z,V,Y)s data used for the test(ndarray)
    '''
    # Train-test split for source domain
    X_s_train, X_s_test, Z_s_train, Z_s_test, V_s_train, V_s_test, Y_s_train, Y_s_test = train_test_split(
        X_s, Z_s, V_s, Y_s, test_size=test_size, random_state=42
    )
   
    # Concatenate the X and Z data
    D_s_train = np.concatenate((Z_s_train, X_s_train), axis=1)
    model_s = ElasticNetCV(cv=5)
    model_s.fit(D_s_train, V_s_train.ravel())
    
    # Estimate the variance of the V|X,Z model for testing data
    D_s_test = np.concatenate((Z_s_test, X_s_test), axis=1)
    V_pred_s_test = model_s.predict(D_s_test)
    residual_s_test = V_s_test.ravel() - V_pred_s_test
    est_var_s_test = np.var(residual_s_test)

    # Estimate the V probability for each sample in the testing set
    V_s_prob_test = norm.pdf(V_s_test.ravel(), loc=V_pred_s_test, scale=np.sqrt(est_var_s_test))

    # No need to test train split, use all samples in the target domain
    X_t_train, X_t_test, Z_t_train, Z_t_test, V_t_train, V_t_test = train_test_split(
        X_t, Z_t, V_t, test_size=0.01, random_state=42
    )

    # Train the conditional model V|X,Z on target
    D_t_train = np.concatenate((Z_t_train, X_t_train), axis=1)
    model_t = ElasticNetCV(cv=5)
    model_t.fit(D_t_train, V_t_train.ravel())

    # Estimate the variance of the V|X,Z model for testing data
    D_t_test = np.concatenate((Z_t_test, X_t_test), axis=1)
    V_pred_t_test = model_t.predict(D_t_test)
    residual_t_test = V_t_test.ravel() - V_pred_t_test
    est_var_t_test = np.var(residual_t_test)

    V_pred_st_test = model_t.predict(D_s_test)

    V_t_prob_test = norm.pdf(V_s_test.ravel(), loc=V_pred_st_test, scale=np.sqrt(est_var_t_test))

    v_ratio_test = V_t_prob_test / V_s_prob_test

    return v_ratio_test, X_s_test, Z_s_test, V_s_test, Y_s_test

# Function for CRT statistic calculation for each sample
def T_statistic(y, x, z, v, E_X):
    '''
    Input:
    - y, x, z, v: Sample data
    - E_X: Expectation function E_X(z, v)

    Return:
    - Test statistic for the sample
    '''
    d_x = E_X(z,v)
    
    # Return the test statistic
    return y*x

# Function for ranking the pvalues of all conterfeits and assign the sample the bin index
def Bin_pvalue(y, x, z, v, model_X, E_X, L, K):
    '''
    Input:
    - y, x, z, v: Sample data
    - model_X: Model for generating X
    - E_X: Expectation function E_X(z, v)
    - L: Number of bins
    - K: Number of counterfeits per bin

    Return:
    - Bin index for the sample
    '''
        
    # The total number of bins
    M = L * K - 1
    cnt = 0
    
    # Calculate the test statistic of current sample
    t_stat = T_statistic(y, x, z, v, E_X)
    
    # Generate M counterfeits
    for i in range(M):
        x_ = model_X(z, v)
        if t_stat > T_statistic(y, x_, z, v, E_X):
            cnt=cnt+1
    # Find the bin index for the current sample 
    return cnt // K


# The main function for csPCR test
def PCRtest( Y, X, Z, V,model_X, E_X,density_ratio, L, K,covariate_shift):
    '''
    Input:
    - Y, X, Z, V: Data arrays
    - model_X: Model for generating X
    - E_X: Expectation function E_X(z, v)
    - density_ratio: Density ratio for V|X,Z
    - L: Number of bins
    - K: Number of counterfeits per bin
    - covariate_shift: Boolean indicating whether to consider covariate shift

    Return:
    - W: Array of weights in each bin
    - Test statistic for csPCR test
    '''
    n = Y.size
    # initialize the weight in each bin
    W = np.array([0.0]*L)

    # Loop over all samples
    for j in range(n):
        y, x, z, v = Y[j], X[j], Z[j], V[j]
        
        # With Covariate shift
        if covariate_shift == True:
            ind = Bin_pvalue(y, x, z, v,model_X, E_X, L, K)
            W[ind] += density_ratio[j]
           
        # Normal PCR test
        if covariate_shift == False:
            W[Bin_pvalue(y, x, z, v, L, K, model_X, E_X)] += 1
    
    # Return the weights and the test statistic for csPCR test
    return W, L/n * np.dot(W - n/L, W - n/L)


# Function for generating the covariance matrix of the test statistic distribution
def generate_cov_matrix(Y, X, Z, V, model_X, E_X, density_ratio,L, K):
    '''
    Input:
    - Y, X, Z, V: Data arrays
    - model_X: Model for generating X
    - E_X: Expectation function E_X(z, v)
    - density_ratio: Density ratio for V|X,Z
    - L: Number of bins
    - K: Number of counterfeits per bin

    Return:
    - Covariance matrix for the test statistic distribution
    '''
    
    n = Y.size
    diag = np.array([0.0]*L)
    
    # Loop over all samples and add corresponding weights
    for j in range(n):
        y, x, z, v = Y[j], X[j], Z[j], V[j]
        diag[Bin_pvalue(y, x, z, v,model_X, E_X,L, K)] += (density_ratio[j]**2)
        
    diag = L*(diag/n)- 1/L
     # Fill all entries with 1/L
    covariance_matrix = np.full((L, L), -1/L)
    
    # Set diagonal entries to 1 - 1/L^2
    np.fill_diagonal(covariance_matrix, diag) 
    
    # Return the 
    return covariance_matrix


import scipy.stats as stats


# Calculate chi-squared p-value
def chi_squared_p_value(chi_squared_statistic, df):
    '''
    Input:
    - chi_squared_statistic: Observed chi-squared test statistic
    - df: Degrees of freedom

    Return:
    - Calculated p-value
    '''

    p_value = 1 - stats.chi2.cdf(chi_squared_statistic, df)
    return p_value


# Calculate the normal quadratic form p-value
def moment_chi_pvalue(statistic, cov1):
    '''
    Input:
    - statistic: Test statistic
    - cov1: Covariance matrix

    Return:
    - Calculated p-value using momentchi2 library
    '''
    weight = la.eigh(cov1)[0]

    p_value = 1-mchi.hbe(coeff=abs(weight), x=statistic)
    
    return p_value


# Function for the testing procedure
def Test(X_source, Z_source, V_source, Y_source, X_target, Z_target, V_target, model_X, E_X, xz_ratio, L=3, K=20, test_size = 0.5):
    '''
    Input:
    - X_source, Z_source, V_source, Y_source: Source domain data
    - X_target, Z_target, V_target: Target domain data
    - model_X: Model for generating X
    - E_X: Expectation function E_X(z, v)
    - xz_ratio: Function for ratio X|Z
    - L: Number of bins
    - K: Number of counterfeits per bin

    Return:
    - p_value: Resulting p-value from the csPCR test
    '''
    
    # Estimate the density ratio by the V|X,Z conditional model using Lasso
    v_dr, X_source, Z_source, V_source, Y_source = est_v_ratio(X_source, Z_source, V_source,Y_source, X_target, Z_target, V_target)
    
    # Calculate the xz_ratio by the given function
    xz_dr = xz_ratio(X_source,Z_source)
    # Calculate the estimated density ratio
    est_dr = v_dr * xz_dr
    
    print('max dr: ' + str(max(est_dr)))
    # Estimate the covariance matrix for p-value calculation
    cov1 = generate_cov_matrix(Y_source, X_source, Z_source,V_source,model_X, E_X, L = L, K = K, density_ratio = est_dr)
    
    # Get the csPCR test statistic
    w, statistic = PCRtest(Y_source, X_source, Z_source,V_source,model_X, E_X, L = L, K = K, covariate_shift = True, density_ratio = est_dr)
    
    #print(w)
    # Call moment chi function to get the final p-value for the test
    p_value = moment_chi_pvalue(statistic, cov1)
    
    return p_value



def Test_true_dr(X_source, Z_source, V_source, Y_source, X_target, Z_target, V_target, model_X, E_X,L=3, K=20, true_dr = None):
    '''
    Input:
    - X_source, Z_source, V_source, Y_source: Source domain data
    - X_target, Z_target, V_target: Target domain data
    - model_X: Model for generating X
    - E_X: Expectation function E_X(z, v)
    - xz_ratio: Function for ratio X|Z
    - L: Number of bins
    - K: Number of counterfeits per bin

    Return:
    - p_value: Resulting p-value from the csPCR test
    '''
    
    # Calculate the true density ratio
    est_dr = true_dr(X_source, Z_source, V_source)
    
    print(max(est_dr))
    # Estimate the covariance matrix for p-value calculation
    cov1 = generate_cov_matrix(Y_source, X_source, Z_source,V_source,model_X, E_X, L = L, K = K, density_ratio = est_dr)
    
    # Get the csPCR test statistic
    w, statistic = PCRtest(Y_source, X_source, Z_source,V_source,model_X, E_X, L = L, K = K, covariate_shift = True, density_ratio = est_dr)
    # print(statistic)
    # Call moment chi function to get the final p-value for the test
    p_value = moment_chi_pvalue(statistic, cov1)
    
    return p_value
    
    



# Function for power enhancement version PCR test
def PCRtest_Powen(Y, X, Z, V, X_, Z_, V_, model_X, E_X, L, K, density_ratio):

    y_ind, v_ind, c = [], [], []
    W = np.array([0.0]*L)
    ns, nt = V.size, V_.size
    
    g_lst = np.zeros(L)
        
    for j in range(ns):
        y, x, z, v = Y[j], X[j], Z[j], V[j]
        ind_y = Bin_pvalue(y, x, z, v, model_X, E_X, L, K)
        ind_v = Bin_pvalue(v, x, z, v, model_X, E_X, L, K)
        y_ind.append(ind_y)
        v_ind.append(ind_v)
    
    y_ind = np.array(y_ind)
    v_ind = np.array(v_ind)
        
    density_ratio=np.array(density_ratio).ravel()
    for l in range(L):
        a = np.array([1 if x == l else 0 for x in y_ind])
        b = np.array([1 if x == l else 0 for x in v_ind])
        a_d = a-(a@density_ratio.T)/density_ratio.sum()
        b_d = b-(b@density_ratio.T)/density_ratio.sum()

        g_lst[l] = ((density_ratio*a_d)@b_d.T)/((density_ratio*b_d)@b_d.T)
    
    print(g_lst)

        
    for j in range(nt):
        x_, z_, v_ = X_[j], Z_[j], V_[j]
        ind_v_ = Bin_pvalue(v_, x_, z_, v_,model_X, E_X, L, K)
        W[ind_v_] += (ns/nt)*g_lst[ind_v_]
        c.append(ind_v_)

    c = np.array(c)
    for j in range(ns):
        W[y_ind[j]] += density_ratio[j]
        W[v_ind[j]] -= density_ratio[j]*g_lst[v_ind[j]]   

    return W, L/ns * np.dot(W - ns/L, W - ns/L),y_ind, v_ind, c, g_lst


def I(a, b):
    if a == b:
        return 1
    else:
        return 0
    

# Generate covariance matrix for the power enhancement version
def generate_cov_matrix_powen(ind_y_source, ind_v_source, ind_v_target ,g_lst, L, K, density_ratio):
    ns = ind_y_source.size
    nt = ind_v_target.size
    
    ad = []
    num_row = ns + nt
    num_col = L
    for l in range(L):
        row = []
        for s in range(ns):
            row.append(density_ratio[s]*(I(l, ind_y_source[s]) - g_lst[l]*I(l, ind_v_source[s])))
        for t in range(nt):
            row.append(ns/nt*g_lst[l]*I(l, ind_v_target[t]))
        ad.append(row)
    ad = np.array(ad)
    
    cov_matrix = np.cov(ad, rowvar=True)
    return cov_matrix*L/ns*(ns+nt)


# Function for power enhancement implementation
def Test_pe(X_source, Z_source, V_source, Y_source, X_target, Z_target, V_target, model_X, E_X, xz_ratio, L=3, K=20, test_size = 0.5):
    '''
    Input:
    - X_source, Z_source, V_source, Y_source: Source domain data
    - X_target, Z_target, V_target: Target domain data
    - model_X: Model for generating X
    - E_X: Expectation function E_X(z, v)
    - xz_ratio: Function for ratio X|Z
    - L: Number of bins
    - K: Number of counterfeits per bin

    Return:
    - p_value: Resulting p-value from the csPCR test
    '''
    
    # Estimate the density ratio by the V|X,Z conditional model using Lasso
    v_dr, X_source, Z_source, V_source, Y_source = est_v_ratio(X_source, Z_source, V_source,Y_source, X_target, Z_target, V_target, test_size = test_size)
    # Calculate the xz_ratio by the given function
    xz_dr = xz_ratio(X_source,Z_source)
    # Calculate the estimated density ratio
    est_dr = v_dr * xz_dr
    
    
    #print(max(est_dr))
    WV, statistic, a,b,c,g = PCRtest_Powen(Y_source, X_source, Z_source, V_source, X_target, Z_target, V_target, model_X, E_X, L, K, est_dr)
    print(WV)
    cov = generate_cov_matrix_powen(a, b, c, g, L, K, density_ratio = est_dr)

    p_value = moment_chi_pvalue(statistic, cov)

    return p_value



def Test_pe_true_dr(X_source, Z_source, V_source, Y_source, X_target, Z_target, V_target, model_X, E_X,L=3, K=20, true_dr = None):
    '''
    Input:
    - X_source, Z_source, V_source, Y_source: Source domain data
    - X_target, Z_target, V_target: Target domain data
    - model_X: Model for generating X
    - E_X: Expectation function E_X(z, v)
    - xz_ratio: Function for ratio X|Z
    - L: Number of bins
    - K: Number of counterfeits per bin

    Return:
    - p_value: Resulting p-value from the csPCR test
    '''
    
    # Calculate the true density ratio
    est_dr = true_dr(X_source, Z_source, V_source).reshape(-1)
    
    #print(est_dr)
    
    WV, statistic, a,b,c,g = PCRtest_Powen(Y_source, X_source, Z_source, V_source, X_target, Z_target, V_target, model_X, E_X, L, K, est_dr)
    print(WV)
    cov = generate_cov_matrix_powen(a, b, c, g, L, K, density_ratio = est_dr)

    p_value = moment_chi_pvalue(statistic, cov)
    
    return p_value
    