In this short demo, I will show how to use the csPCR test on a simulated dataset.

In [1]:
'''
First, we generate a dataset, we can change the Alpha_s, Alpha_t and effect parameter to change the distribution of
the generated dataset.
'''

import numpy as np

def generate(ns, nt, p,q, s, t, u, Alpha_s=1, Alpha_t = 0,effect=1, z_diff = 0.1):
    Zs_null = np.random.normal(1,0.1, (ns, q))
    Zt_null = np.random.normal(1,0.1, (nt, q))
    
    Z_source = np.hstack((np.random.normal(0, 1, (ns, p)) , Zs_null))
    Z_target = np.hstack((np.random.normal(z_diff, 1, (nt, p)) , Zt_null))
    
    X_source = Z_source[:, :p] @ u + np.random.normal(0, 1, ns)
    X_target = Z_target[:, :p] @ u + np.random.normal(0, 1, nt)

    V_source = Z_source[:, :p] @ s + Alpha_s * X_source + np.random.normal(0, 5, ns)
    V_target = Z_target[:, :p] @ t + Alpha_t * X_target + np.random.normal(0, 5, nt)
    
    # V_source = Z_source[:, :p] @ s + 2*X_source 
    # V_target = Z_target[:, :p] @ t - 2*X_target
    
    Y_source = (Z_source[:, :p].sum(axis=1))**2 + effect*V_source + np.random.normal(0, 1, ns)+ 0 * X_source
    Y_target = (Z_target[:, :p].sum(axis=1))**2 + effect*V_target + np.random.normal(0, 1, nt)+ 0 * X_target
    
    
    return Y_source.reshape(-1, 1), X_source.reshape(-1, 1), V_source.reshape(-1, 1), Z_source,\
           Y_target.reshape(-1, 1), X_target.reshape(-1, 1), V_target.reshape(-1, 1), Z_target


'''
According to the generation process, 
we write 2 functions to calculate the xz_ratio and true density ratio of the data
'''


def true_density_ratio(X, Z, V, s=np.array([-0.5, -1.0,  0.3, -0.9, -1.5 ]), t=np.array([ 1.5, -0.2 ,  0.06 , -1.4, -0.5]),\
                       p=5,q=50, Alpha_s = 1, Alpha_t = 0,z_diff = 0.1):
    ratios = []
    size = V.size
    for i in range(size):
        zs_prob = multivariate_normal.pdf(Z[i][:p], mean = 0*np.ones(p), cov= 1*np.identity(p))
        vs_prob = norm.pdf(V[i], loc=Z[i][:p]@s + Alpha_s*X[i], scale =5)
        zt_prob = multivariate_normal.pdf(Z[i][:p], mean = z_diff*np.ones(p), cov= 1*np.identity(p))
        vt_prob = norm.pdf(V[i], loc=Z[i][:p]@t + Alpha_t*X[i], scale =5)
        ratios.append((zt_prob*vt_prob)/(zs_prob*vs_prob))
    
    return np.array(ratios)

In [2]:
'''
Define some parameters for data generation
'''
nt, p,q = 2000, 5, 50
est_size = 200
ns = est_size + 500
est_ratio = 500/ns
s = np.array([-0.5, -1.0,  0.3, -0.9, -1.5])
t = np.array([ 1.5, -0.2, 0.06 , -1.4, -0.5])
u = np.array([ 0.1, -1.1,  0.4, -0.6, -0.3])


In [3]:
from csPCR_functions import *

In [4]:
Y_source, X_source, V_source, Z_source, Y_target, X_target, V_target, Z_target = generate(ns,nt, p,q, s, t, u,effect=1)

'''
Since this method is model_X framework, we assume that the X|Z,V and Z distribution are known. So we need to define three 
extra functions as the input of the test.
1. Model_X: Should be a function of the conditional model of X|Z,V
   Input: z, v values(float)
   Return: one X sample follows the X|Z,V distribution

2. E_X: Should be a function of the conditional expectaion of X|Z,V
   Input: z,v values(float)
   Return: the calculated conditional expectation E[X|Z,V]

3. xz_ratio: Should a function for calculating the (X,Z) density ratio (i.e. P_t(X,Z,V)/P_s(X,Z,V))
   Input: X, Z (ndarrays)
   
   Return: density ratio array (ndarray)
'''

def Model_X(z, v):
    # Conditional distribution of X|Z
    return z[:5] @ u + np.random.normal(0, 1, 1)

def E_X(z, v):
    # Conditional expectation of X|Z
    return z[:5] @ u



def xz_ratio(X, Z, z_diff = 0.1):
    ratios = []
    for i in range(X.shape[0]):
        zs_prob = multivariate_normal.pdf(Z[i][:p], mean = 0*np.ones(p), cov= 1*np.identity(p))
        zt_prob = multivariate_normal.pdf(Z[i][:p], mean = z_diff*np.ones(p), cov= 1*np.identity(p))
        ratios.append((zt_prob)/(zs_prob))
    
    return np.array(ratios)

'''
To use the Test function, need input of source and target X,Z,V and source Y data into the function,
Also take 3 functions mentioned above as inputs.
'''
Test(X_source, Z_source, V_source, Y_source, X_target, Z_target, V_target,\
     model_X = Model_X, E_X = E_X, xz_ratio= xz_ratio, test_size = 0.5, L=2)


[-0.46475516 -0.61001426  0.         -0.5048131  -1.3306911   0.
  0.         -0.          0.          0.         -0.          0.
  0.         -0.          0.          0.          0.          0.
 -0.         -0.         -0.         -0.          0.          0.
  0.         -0.         -0.         -0.          0.          0.
 -0.         -0.         -0.          0.         -0.         -0.
  0.         -0.          0.         -0.         -0.         -0.
 -0.         -0.          0.          0.         -0.          0.
 -0.         -0.         -0.         -0.         -0.          0.
  0.          1.34777263]
[ 1.30276693e+00  9.07644578e-05  8.60992510e-02 -1.28917402e+00
 -3.55035739e-01  0.00000000e+00 -0.00000000e+00  0.00000000e+00
 -0.00000000e+00 -0.00000000e+00  0.00000000e+00  0.00000000e+00
  0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00
 -0.00000000e+00 -0.00000000e+00  0.00000000e+00  0.00000000e+00
 -0.00000000e+00 -0.00000000e+00 -0.00000000e+00  0.00000000e+00

0.6078025885989145

In [22]:
Test_pe_true_dr(X_source, Z_source, V_source, Y_source, X_target, Z_target, V_target,\
     model_X = Model_X, E_X = E_X,  true_dr = true_density_ratio, L=2)

[0.32747974 0.32747974]
[378.61366593 348.84699698]


0.3808304903747278

In [None]:
ns = 1000
Y_source, X_source, V_source, Z_source, Y_target, X_target, V_target, Z_target = generate(ns,nt, p,q, s, t, u, Alpha_s=1, Alpha_t = 0,effect=0, z_diff = 1)

Test_pe(X_source, Z_source, V_source, Y_source, X_target, Z_target, V_target,\
     model_X = Model_X, E_X = E_X, xz_ratio = xz_ratio, test_size=est_ratio)

In [None]:
c= 0
nt, p,q = 1000, 10, 50
est_size = 10
ns = est_size + 500
est_ratio = 500/ns
#ns = 500

s = np.array([-1, -0.5, 0, 1, 1.5,-1, -0.5, 0, 1, 1.5])
t = np.array([ -1, -1, 0.5 , 0.5, 1,-1, -1, 0.5 , 0.5, 1])
# t = s
u = np.array([ 0, -1, 0.5, -0.5, 1,0, -1, 0.5, -0.5, 1])
# ns = 1000
def true_density_ratio(X, Z, V, s=s, t=t,\
                       p=p,q=q, Alpha_s = 1, Alpha_t = 0,z_diff = 0.1):
    ratios = []
    size = V.size
    for i in range(size):
        zs_prob = multivariate_normal.pdf(Z[i][:p], mean = 0*np.ones(p), cov= 1*np.identity(p))
        vs_prob = norm.pdf(V[i], loc=Z[i][:p]@s + Alpha_s*X[i], scale =5)
        zt_prob = multivariate_normal.pdf(Z[i][:p], mean = z_diff*np.ones(p), cov= 1*np.identity(p))
        vt_prob = norm.pdf(V[i], loc=Z[i][:p]@t + Alpha_t*X[i], scale =5)
        ratios.append((zt_prob*vt_prob)/(zs_prob*vs_prob))
    
    return np.array(ratios)
    # return np.ones(size)

def Model_X(z, v):
    # Conditional distribution of X|Z
    return z[:p] @ u + np.random.normal(0, 1, 1)

def E_X(z, v):
    # Conditional expectation of X|Z
    return z[:p] @ u



def xz_ratio(X, Z, z_diff = 0.1):
    ratios = []
    for i in range(X.shape[0]):
        zs_prob = multivariate_normal.pdf(Z[i][:p], mean = 0*np.ones(p), cov= 1*np.identity(p))
        zt_prob = multivariate_normal.pdf(Z[i][:p], mean = z_diff*np.ones(p), cov= 1*np.identity(p))
        ratios.append((zt_prob)/(zs_prob))
    
    return np.array(ratios)

for i in trange(100):
    Y_source, X_source, V_source, Z_source, Y_target, X_target, V_target, Z_target = generate(ns,nt, p,q, s, t, u,Alpha_s=0.5, Alpha_t = 0,effect=1, z_diff = 0.1)

    p_value = Test(X_source, Z_source, V_source, Y_source, X_target, Z_target, V_target,\
    model_X = Model_X, E_X = E_X, xz_ratio = xz_ratio, test_size=est_ratio, L=3)
    # p_value = Test_true_dr(X_source, Z_source, V_source, Y_source, X_target, Z_target, V_target,\
    # model_X = Model_X, E_X = E_X,  true_dr = true_density_ratio, L=3)
    print(p_value)
    if p_value < 0.05:
        c += 1

  0%|          | 0/100 [00:00<?, ?it/s]

[-1.01505273 -0.28273852  0.          1.6549759   1.04995034 -0.89191725
 -0.22921369 -0.26683925  0.80214977  1.07900173  0.          0.
  0.          0.          0.          0.          0.         -0.
  0.         -0.          0.         -0.         -0.          0.
  0.          0.         -0.          0.          0.         -0.
  0.         -0.         -0.         -0.         -0.          0.
  0.         -0.         -0.          0.         -0.          0.
  0.          0.         -0.         -0.          0.          0.
 -0.         -0.         -0.         -0.          0.         -0.
 -0.         -0.         -0.          0.          0.          0.
  0.4221434 ]
[-1.21932691 -0.7026528   0.27544487  0.54384219  0.47993908 -0.9837403
 -0.66799648  0.48027997  0.58216357  0.32090738 -0.          0.
  0.          0.          0.          0.          0.          0.
  0.          0.         -0.          0.          0.          0.
  0.         -0.         -0.         -0.         -0.         

  1%|          | 1/100 [00:00<00:53,  1.83it/s]

0.22158522052194785
[-1.14328494 -0.55888713 -0.          0.41331879  1.42856524 -0.63505015
 -0.64418358 -0.26938015  0.39003962  1.58514501 -0.          0.
  0.         -0.         -0.         -0.          0.          0.
 -0.          0.          0.         -0.         -0.          0.
 -0.          0.         -0.          0.         -0.         -0.
 -0.          0.          0.          0.          0.         -0.
 -0.          0.         -0.         -0.         -0.         -0.
  0.         -0.          0.         -0.          0.          0.
  0.          0.          0.         -0.          0.         -0.
 -0.         -0.         -2.42572123  0.         -0.          0.
  0.39961424]
[-1.20853461 -0.97332314  0.53446197  0.60954572  0.8083097  -1.06647189
 -0.65304156  0.50377633  0.59102736  0.80816306 -0.          0.
 -0.          0.          0.         -0.          0.          0.
  0.          0.         -0.         -0.          0.         -0.
 -0.         -0.          0.          0.

  2%|▏         | 2/100 [00:01<01:01,  1.58it/s]

0.9103475376332273
[-1.39896344 -0.52281644  0.12563809  0.64891479  1.11783918 -0.77911072
 -0.74466385 -0.05649814  1.1678793   1.85670526  0.         -0.
  0.          0.          0.          0.         -0.          0.
  0.          0.          0.         -0.          0.          0.
 -0.         -0.          0.         -0.         -0.         -0.
  0.         -0.          0.         -0.          0.         -0.
  0.         -0.         -0.          0.          0.          0.
 -0.          0.          0.          0.          0.          0.
  0.          0.         -0.          0.         -0.         -0.
 -0.         -0.         -0.          0.          0.          0.
  0.232812  ]
[-0.92475896 -0.92426587  0.25537155  0.39500329  0.45779669 -0.66937809
 -0.96690454  0.31042456  0.21505358  0.72288655  0.          0.
  0.          0.         -0.         -0.          0.         -0.
  0.         -0.53031559 -0.         -0.          0.          0.
 -0.         -0.          0.          0. 

  3%|▎         | 3/100 [00:02<01:08,  1.42it/s]

0.07800180829341408
[-0.95962022 -0.69200967  0.60460171  0.78401082  1.23850223 -0.81706409
 -0.82321279  0.          0.2551684   1.72574307  0.         -0.
 -0.         -0.         -0.          0.          0.          0.
  0.          0.         -0.         -0.          0.         -0.
 -0.          0.          0.         -1.86840488 -0.          0.
 -0.          0.         -0.          0.          0.          0.
  0.          0.         -0.          0.          0.         -0.
 -0.         -0.          0.         -0.          0.         -0.
 -0.          0.         -0.          0.         -0.          0.
 -0.         -0.          0.          0.         -0.         -0.
  0.30315194]
[-0.80964251 -1.31985432  0.41907048  0.53787261  1.17666278 -0.96735353
 -1.15643953  0.46099368  0.64991736  0.968955   -0.          0.
  0.         -0.          0.          0.          0.         -0.
 -0.         -0.          0.         -0.         -0.          0.
  0.         -0.          0.         -0.

  4%|▍         | 4/100 [00:02<01:09,  1.39it/s]

0.6389199352210865
[-0.95347389 -0.38558588 -0.10929033 -0.02376198  1.40831885 -1.18381711
 -0.84167014  0.3469524   0.68745792  1.64106171  0.         -0.
 -0.         -0.         -0.         -0.         -0.          0.
 -0.         -0.         -0.          0.         -0.         -0.
 -0.         -0.         -0.          0.         -0.         -0.
  0.         -0.          0.          0.          0.         -0.
 -0.         -0.         -0.          0.          0.          0.
  0.          0.          0.         -0.         -0.          0.
 -0.         -0.         -0.         -0.         -0.         -0.
 -0.          0.         -0.         -0.          0.         -0.
  0.25733947]
[-0.91975638 -1.06293931  0.78481585  0.23882133  1.14653968 -1.20283372
 -0.7972425   0.71455259  0.46901221  0.99738529  0.15003902  0.
  0.          0.          1.73250502  0.          0.42284393  0.
  0.3657464   0.          0.          0.          0.         -0.
  0.02149911 -0.         -0.         -0. 

  5%|▌         | 5/100 [00:03<01:02,  1.52it/s]

0.3395653728125507
[-0.38836325 -0.03882123  0.          1.33816646  0.42754503 -1.04596106
  0.         -0.38153642  1.71060179  0.907732   -0.         -0.
  0.         -0.          0.         -0.          0.          0.
  0.         -0.         -0.         -0.         -0.         -0.
 -0.         -0.          0.          0.         -0.          0.
 -0.         -0.         -0.          0.         -0.         -0.
  0.          0.          0.          0.         -0.         -0.
  0.         -0.         -0.          0.         -0.         -0.
  0.         -0.          0.         -0.          0.          0.
  0.          0.         -0.          0.         -0.          0.
  1.01049743]
[-1.17074209 -0.48140796  0.42444777  0.45955869  0.4801025  -1.17293189
 -0.5490857   0.31329561  0.52333016  0.57187028  0.         -0.
 -0.          0.         -0.          0.         -0.         -0.
  0.         -0.         -0.          0.         -0.          0.
 -0.          0.          0.          0. 

  6%|▌         | 6/100 [00:03<00:56,  1.67it/s]

In [35]:
c/100

0.06

# PCR test

In [53]:

ns,nt, p,q = 500,1000, 10, 100

s = np.array([-1, -0.5, 0, 1, 1.5,-1, -0.5, 0, 1, 1.5])
t = np.array([ -1, -1, 0.5 , 0.5, 1,-1, -1, 0.5 , 0.5, 1])
# t = (t+3*s)
u = np.array([ 0, -1, 0.5, -0.5, 1,0, -1, 0.5, -0.5, 1])

def Model_X(z, v):
    # Conditional distribution of X|Z
    return z[:p] @ u + np.random.normal(0, 1, 1)

def E_X(z, v):
    # Conditional expectation of X|Z
    return z[:p] @ u


Y_source, X_source, V_source, Z_source, Y_target, X_target, V_target, Z_target = generate(ns,nt, p,q, s, t, u,Alpha_s=2, Alpha_t = 0,effect=1, z_diff = 0.1)
_,stat = PCRtest(Y_source, X_source, Z_source, V_source, model_X= Model_X, E_X = E_X, density_ratio=None,L=3, K=20, covariate_shift=False) 
chi_squared_p_value(stat, 2)

0.02965869327658377

## Check density ratio scale

In [55]:
'''
Define some parameters for data generation
'''
nt, p,q = 1000, 10, 100
est_size = 500
ns = est_size + 200
est_ratio = 200/ns


s = np.array([-1, -0.5, 0, 1, 1.5,-1, -0.5, 0, 1, 1.5])
t = np.array([ -1, -1, 0.5 , 0.5, 1,-1, -1, 0.5 , 0.5, 1])
t = (t+3*s)
u = np.array([ 0, -1, 0.5, -0.5, 1,0, -1, 0.5, -0.5, 1])

In [60]:
def xz_ratio(X, Z, z_diff = 1):
    ratios = []
    for i in range(X.shape[0]):
        zs_prob = multivariate_normal.pdf(Z[i][:p], mean = 0*np.ones(p), cov= 1*np.identity(p))
        zt_prob = multivariate_normal.pdf(Z[i][:p], mean = z_diff*np.ones(p), cov= 1*np.identity(p))
        ratios.append((zt_prob)/(zs_prob))
    
    return np.array(ratios)
Y_source, X_source, V_source, Z_source, Y_target, X_target, V_target, Z_target = generate(ns,nt, p,q, s, t, u,Alpha_s=0, Alpha_t = 100,effect=1, z_diff = 1)

est_v_dr,X_s,Z_s,V_s,Y_s = est_v_ratio(X_source, Z_source, V_source, Y_source, X_target, Z_target, V_target, test_size=est_ratio)
est_dr = xz_ratio(X_s, Z_s) * est_v_dr
true_dr = true_density_ratio(X_s, Z_s, V_s, s, t, p, q)

[-1.36211716 -0.24089191 -0.1545009   1.36262841  1.24278281 -0.7649584
 -0.17888822  0.37845853  0.6361632   1.48072965 -0.         -0.
 -0.         -0.         -0.         -0.         -0.          0.
  0.         -0.          0.         -0.10135055 -0.         -0.
  0.         -0.07143916  0.          0.          0.          0.
  0.         -0.         -0.         -0.         -0.         -0.
 -0.          0.         -0.         -0.          0.          0.
 -0.          0.          0.          0.          0.          0.
 -0.         -0.         -0.         -0.          0.         -0.
 -0.          0.         -0.         -0.          0.         -0.
  0.          0.46859665 -0.         -0.         -0.         -1.58146136
  0.         -0.         -0.          1.0756933  -0.          0.
 -0.         -0.         -0.         -0.         -0.         -0.
 -0.          0.         -0.          0.         -0.          0.
  0.         -0.         -0.15910225  0.          0.         -1.54104467
 -

In [61]:
print(max(true_dr))
print(np.mean(abs((true_dr - est_dr.reshape(-1,1)))))

[13.07777322]
0.8244227702728405


In [12]:
sum = 0
for i in range(200):
    sum+=I(Bin_pvalue(V_source[i], X_source[i], Z_source[i], V_source[i], Model_X, E_X, L=3, K=20), 3)*True_d[i]
sum

array([0.])

In [13]:
sum = 0
for i in range(1000):
    sum+=I(Bin_pvalue(V_target[i], X_target[i], Z_target[i], V_target[i],Model_X,E_X, L=3, K=20), 3)
sum*200/1000

0.0

In [41]:
WV, statistic, a,b,c,g = PCRtest_Powen(Y_source, X_source, Z_source, V_source, X_target, Z_target, V_target, model_X = Model_X, E_X = E_X, L=3, K=20, density_ratio = True_d)

# Check

In [189]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import norm, chi2, multivariate_normal
from sklearn import linear_model
from sklearn.linear_model import LassoCV
from sklearn.linear_model import Lasso
from sklearn.model_selection import train_test_split
from tqdm import trange
from densratio import densratio
from numpy import linalg as la
import momentchi2 as mchi
from sklearn.linear_model import LinearRegression


# Function for estimate the conditional model of V|X,Z
from sklearn.linear_model import ElasticNetCV


def est_v_ratio(X_s, Z_s, V_s, Y_s, X_t, Z_t, V_t, test_size=0.5):
    '''
    Input:
    the X,Z,V data on both source and target (ndarray)
    Return:
    v_ratio_test: the density ratio for V|X,Z(ndarray)
    X_s_test,Z_s_test, V_s_test, Y_s_test :the (X,Z,V,Y)s data used for the test(ndarray)
    '''
    # Train-test split for source domain
    X_s_train, X_s_test, Z_s_train, Z_s_test, V_s_train, V_s_test, Y_s_train, Y_s_test = train_test_split(
        X_s, Z_s, V_s, Y_s, test_size=test_size, random_state=42
    )
   
    # Concatenate the X and Z data
    D_s_train = np.concatenate((Z_s_train, X_s_train), axis=1)
    model_s = ElasticNetCV(cv=5)
    model_s.fit(D_s_train, V_s_train.ravel())
    
    # Estimate the variance of the V|X,Z model for testing data
    D_s_test = np.concatenate((Z_s_test, X_s_test), axis=1)
    V_pred_s_test = model_s.predict(D_s_test)
    residual_s_test = V_s_test.ravel() - V_pred_s_test
    est_var_s_test = np.var(residual_s_test)

    # Estimate the V probability for each sample in the testing set
    V_s_prob_test = norm.pdf(V_s_test.ravel(), loc=V_pred_s_test, scale=np.sqrt(est_var_s_test))

    # No need to test train split, use all samples in the target domain
    X_t_train, X_t_test, Z_t_train, Z_t_test, V_t_train, V_t_test = train_test_split(
        X_t, Z_t, V_t, test_size=0.01, random_state=42
    )

    # Train the conditional model V|X,Z on target
    D_t_train = np.concatenate((Z_t_train, X_t_train), axis=1)
    model_t = ElasticNetCV(cv=5)
    model_t.fit(D_t_train, V_t_train.ravel())

    # Estimate the variance of the V|X,Z model for testing data
    D_t_test = np.concatenate((Z_t_test, X_t_test), axis=1)
    V_pred_t_test = model_t.predict(D_t_test)
    residual_t_test = V_t_test.ravel() - V_pred_t_test
    est_var_t_test = np.var(residual_t_test)

    V_pred_st_test = model_t.predict(D_s_test)

    V_t_prob_test = norm.pdf(V_s_test.ravel(), loc=V_pred_st_test, scale=np.sqrt(est_var_t_test))

    v_ratio_test = V_t_prob_test / V_s_prob_test

    return v_ratio_test, X_s_test, Z_s_test, V_s_test, Y_s_test

# Function for CRT statistic calculation for each sample
def T_statistic(y, x, z, v, E_X):
    '''
    Input:
    - y, x, z, v: Sample data
    - E_X: Expectation function E_X(z, v)

    Return:
    - Test statistic for the sample
    '''
    d_x = E_X(z,v)
    
    # Return the test statistic
    return y*x

# Function for ranking the pvalues of all conterfeits and assign the sample the bin index
def Bin_pvalue(y, x, z, v, model_X, E_X, L, K):
    '''
    Input:
    - y, x, z, v: Sample data
    - model_X: Model for generating X
    - E_X: Expectation function E_X(z, v)
    - L: Number of bins
    - K: Number of counterfeits per bin

    Return:
    - Bin index for the sample
    '''
        
    # The total number of bins
    M = L * K - 1
    cnt = 0
    
    # Calculate the test statistic of current sample
    t_stat = T_statistic(y, x, z, v, E_X)
    
    # Generate M counterfeits
    for i in range(M):
        x_ = model_X(z, v)
        if t_stat > T_statistic(y, x_, z, v, E_X):
            cnt=cnt+1
    # Find the bin index for the current sample 
    return cnt // K


# The main function for csPCR test
def PCRtest( Y, X, Z, V,model_X, E_X,density_ratio, L, K,covariate_shift):
    '''
    Input:
    - Y, X, Z, V: Data arrays
    - model_X: Model for generating X
    - E_X: Expectation function E_X(z, v)
    - density_ratio: Density ratio for V|X,Z
    - L: Number of bins
    - K: Number of counterfeits per bin
    - covariate_shift: Boolean indicating whether to consider covariate shift

    Return:
    - W: Array of weights in each bin
    - Test statistic for csPCR test
    '''
    n = Y.size
    # initialize the weight in each bin
    W = np.array([0.0]*L)

    # Loop over all samples
    for j in range(n):
        y, x, z, v = Y[j], X[j], Z[j], V[j]
        
        # With Covariate shift
        if covariate_shift == True:
            ind = Bin_pvalue(y, x, z, v,model_X, E_X, L, K)
            W[ind] += density_ratio[j]
           
        # Normal PCR test
        if covariate_shift == False:
            W[Bin_pvalue(y, x, z, v, L, K, model_X, E_X)] += 1
    
    # Return the weights and the test statistic for csPCR test
    return W, L/n * np.dot(W - n/L, W - n/L)


# Function for generating the covariance matrix of the test statistic distribution
def generate_cov_matrix(Y, X, Z, V, model_X, E_X, density_ratio,L, K):
    '''
    Input:
    - Y, X, Z, V: Data arrays
    - model_X: Model for generating X
    - E_X: Expectation function E_X(z, v)
    - density_ratio: Density ratio for V|X,Z
    - L: Number of bins
    - K: Number of counterfeits per bin

    Return:
    - Covariance matrix for the test statistic distribution
    '''
    
    n = Y.size
    diag = np.array([0.0]*L)
    
    # Loop over all samples and add corresponding weights
    for j in range(n):
        y, x, z, v = Y[j], X[j], Z[j], V[j]
        diag[Bin_pvalue(y, x, z, v,model_X, E_X,L, K)] += (density_ratio[j]**2)
        
    diag = L*(diag/n)- 1/L
     # Fill all entries with 1/L
    covariance_matrix = np.full((L, L), -1/L)
    
    # Set diagonal entries to 1 - 1/L^2
    np.fill_diagonal(covariance_matrix, diag) 
    
    # Return the 
    return covariance_matrix


import scipy.stats as stats


# Calculate chi-squared p-value
def chi_squared_p_value(chi_squared_statistic, df):
    '''
    Input:
    - chi_squared_statistic: Observed chi-squared test statistic
    - df: Degrees of freedom

    Return:
    - Calculated p-value
    '''

    p_value = 1 - stats.chi2.cdf(chi_squared_statistic, df)
    return p_value


# Calculate the normal quadratic form p-value
def moment_chi_pvalue(statistic, cov1):
    '''
    Input:
    - statistic: Test statistic
    - cov1: Covariance matrix

    Return:
    - Calculated p-value using momentchi2 library
    '''
    weight = la.eigh(cov1)[0]

    p_value = 1-mchi.hbe(coeff=abs(weight), x=statistic)
    
    return p_value


# Function for the testing procedure
def Test(X_source, Z_source, V_source, Y_source, X_target, Z_target, V_target, model_X, E_X, xz_ratio, L=3, K=20, test_size = 0.5):
    '''
    Input:
    - X_source, Z_source, V_source, Y_source: Source domain data
    - X_target, Z_target, V_target: Target domain data
    - model_X: Model for generating X
    - E_X: Expectation function E_X(z, v)
    - xz_ratio: Function for ratio X|Z
    - L: Number of bins
    - K: Number of counterfeits per bin

    Return:
    - p_value: Resulting p-value from the csPCR test
    '''
    
    # Estimate the density ratio by the V|X,Z conditional model using Lasso
    v_dr, X_source, Z_source, V_source, Y_source = est_v_ratio(X_source, Z_source, V_source,Y_source, X_target, Z_target, V_target)
    
    # Calculate the xz_ratio by the given function
    xz_dr = xz_ratio(X_source,Z_source)
    # Calculate the estimated density ratio
    est_dr = v_dr * xz_dr
    
    print('max dr: ' + str(max(est_dr)))
    # Estimate the covariance matrix for p-value calculation
    cov1 = generate_cov_matrix(Y_source, X_source, Z_source,V_source,model_X, E_X, L = L, K = K, density_ratio = est_dr)
    
    # Get the csPCR test statistic
    w, statistic = PCRtest(Y_source, X_source, Z_source,V_source,model_X, E_X, L = L, K = K, covariate_shift = True, density_ratio = est_dr)
    
    #print(w)
    # Call moment chi function to get the final p-value for the test
    p_value = moment_chi_pvalue(statistic, cov1)
    
    return p_value



def Test_true_dr(X_source, Z_source, V_source, Y_source, X_target, Z_target, V_target, model_X, E_X,L=3, K=20, true_dr = None):
    '''
    Input:
    - X_source, Z_source, V_source, Y_source: Source domain data
    - X_target, Z_target, V_target: Target domain data
    - model_X: Model for generating X
    - E_X: Expectation function E_X(z, v)
    - xz_ratio: Function for ratio X|Z
    - L: Number of bins
    - K: Number of counterfeits per bin

    Return:
    - p_value: Resulting p-value from the csPCR test
    '''
    
    # Calculate the true density ratio
    est_dr = true_dr(X_source, Z_source, V_source)
    
    print(max(est_dr))
    # Estimate the covariance matrix for p-value calculation
    cov1 = generate_cov_matrix(Y_source, X_source, Z_source,V_source,model_X, E_X, L = L, K = K, density_ratio = est_dr)
    
    # Get the csPCR test statistic
    w, statistic = PCRtest(Y_source, X_source, Z_source,V_source,model_X, E_X, L = L, K = K, covariate_shift = True, density_ratio = est_dr)
    # print(statistic)
    # Call moment chi function to get the final p-value for the test
    p_value = moment_chi_pvalue(statistic, cov1)
    
    return p_value
    
    



# Function for power enhancement version PCR test
def PCRtest_Powen(Y, X, Z, V, X_, Z_, V_, model_X, E_X, L, K, density_ratio):

    y_ind, v_ind, c = [], [], []
    W = np.array([0.0]*L)
    ns, nt = V.size, V_.size
    
    g_lst = np.zeros(L)
        
    for j in range(ns):
        y, x, z, v = Y[j], X[j], Z[j], V[j]
        ind_y = Bin_pvalue(y, x, z, v, model_X, E_X, L, K)
        ind_v = Bin_pvalue(v, x, z, v, model_X, E_X, L, K)
        y_ind.append(ind_y)
        v_ind.append(ind_v)
    
    y_ind = np.array(y_ind)
    v_ind = np.array(v_ind)
        
    density_ratio=np.array(density_ratio).ravel()
    for l in range(L):
        a = np.array([1 if x == l else 0 for x in y_ind])
        b = np.array([1 if x == l else 0 for x in v_ind])
        a_d = a-(a@density_ratio.T)/density_ratio.sum()
        b_d = b-(b@density_ratio.T)/density_ratio.sum()

        g_lst[l] = ((density_ratio*a_d)@b_d.T)/((density_ratio*b_d)@b_d.T)
    
    print(g_lst)

        
    for j in range(nt):
        x_, z_, v_ = X_[j], Z_[j], V_[j]
        ind_v_ = Bin_pvalue(v_, x_, z_, v_,model_X, E_X, L, K)
        W[ind_v_] += (ns/nt)*g_lst[ind_v_]
        c.append(ind_v_)

    c = np.array(c)
    for j in range(ns):
        W[y_ind[j]] += density_ratio[j]
        W[v_ind[j]] -= density_ratio[j]*g_lst[v_ind[j]]   

    return W, L/ns * np.dot(W - ns/L, W - ns/L),y_ind, v_ind, c, g_lst


def I(a, b):
    if a == b:
        return 1
    else:
        return 0
    

# Generate covariance matrix for the power enhancement version
def generate_cov_matrix_powen(ind_y_source, ind_v_source, ind_v_target ,g_lst, L, K, density_ratio):
    ns = ind_y_source.size
    nt = ind_v_target.size
    
    ad = []
    num_row = ns + nt
    num_col = L
    for l in range(L):
        row = []
        for s in range(ns):
            row.append(density_ratio[s]*(I(l, ind_y_source[s]) - g_lst[l]*I(l, ind_v_source[s])))
        for t in range(nt):
            row.append(ns/nt*g_lst[l]*I(l, ind_v_target[t]))
        ad.append(row)
    ad = np.array(ad)
    
    cov_matrix = np.cov(ad, rowvar=True)
    return cov_matrix*L/ns*(ns+nt)


# Function for power enhancement implementation
def Test_pe(X_source, Z_source, V_source, Y_source, X_target, Z_target, V_target, model_X, E_X, xz_ratio, L=3, K=20, test_size = 0.5):
    '''
    Input:
    - X_source, Z_source, V_source, Y_source: Source domain data
    - X_target, Z_target, V_target: Target domain data
    - model_X: Model for generating X
    - E_X: Expectation function E_X(z, v)
    - xz_ratio: Function for ratio X|Z
    - L: Number of bins
    - K: Number of counterfeits per bin

    Return:
    - p_value: Resulting p-value from the csPCR test
    '''
    
    # Estimate the density ratio by the V|X,Z conditional model using Lasso
    v_dr, X_source, Z_source, V_source, Y_source = est_v_ratio(X_source, Z_source, V_source,Y_source, X_target, Z_target, V_target, test_size = test_size)
    # Calculate the xz_ratio by the given function
    xz_dr = xz_ratio(X_source,Z_source)
    # Calculate the estimated density ratio
    est_dr = v_dr * xz_dr
    
    
    #print(max(est_dr))
    WV, statistic, a,b,c,g = PCRtest_Powen(Y_source, X_source, Z_source, V_source, X_target, Z_target, V_target, model_X, E_X, L, K, est_dr)
    print(WV)
    cov = generate_cov_matrix_powen(a, b, c, g, L, K, density_ratio = est_dr)

    p_value = moment_chi_pvalue(statistic, cov)

    return p_value



def Test_pe_true_dr(X_source, Z_source, V_source, Y_source, X_target, Z_target, V_target, model_X, E_X,L=3, K=20, true_dr = None):
    '''
    Input:
    - X_source, Z_source, V_source, Y_source: Source domain data
    - X_target, Z_target, V_target: Target domain data
    - model_X: Model for generating X
    - E_X: Expectation function E_X(z, v)
    - xz_ratio: Function for ratio X|Z
    - L: Number of bins
    - K: Number of counterfeits per bin

    Return:
    - p_value: Resulting p-value from the csPCR test
    '''
    
    # Calculate the true density ratio
    est_dr = true_dr(X_source, Z_source, V_source).reshape(-1)
    
    #print(est_dr)
    
    WV, statistic, a,b,c,g = PCRtest_Powen(Y_source, X_source, Z_source, V_source, X_target, Z_target, V_target, model_X, E_X, L, K, est_dr)
    print(WV)
    cov = generate_cov_matrix_powen(a, b, c, g, L, K, density_ratio = est_dr)

    p_value = moment_chi_pvalue(statistic, cov)
    
    return p_value
    