# Inner Split Function

In [1]:
def confoundSplit(p_pos_train_z1, p_pos_train_z0, p_mix_z1, alpha_test):
    
    assert 0 <= p_pos_train_z1 <= 1
    assert 0 <= p_pos_train_z0 <= 1
    assert 0 <= p_mix_z1 <= 1
    assert alpha_test >= 0
    
    C_z = p_mix_z1
    
    p_mix_z0 = 1 - p_mix_z1
    
    # C_y = p_train(y=1) = p_train(z=0) * p_train(y=1|z=0) + p_train(z=1) * p_train(y=1|z=1) 
    # C_y = p_test(y=1) = p_test(z=0) * p_test(y=1|z=0) + p_test(z=1) * p_test(y=1|z=1)
    C_y = p_mix_z0 * p_pos_train_z0 + p_mix_z1 * p_pos_train_z1
    
    
    
    
    p_pos_test_z0 = C_y / (1 - (1-alpha_test) * C_z)
    p_pos_test_z1 = alpha_test * p_pos_test_z0
    
    
    return {"p_pos_train_z0": p_pos_train_z0,
            "p_pos_train_z1": p_pos_train_z1,
            "p_pos_train": C_y,
            "p_pos_test": C_y,
            "p_mix_z0": p_mix_z0,
            "p_mix_z1": p_mix_z1,
            "alpha_test": alpha_test,
            "p_pos_test_z0": p_pos_test_z0,
            "p_pos_test_z1": p_pos_test_z1,
            
            "C_y": C_y,
            "C_z": C_z,
           }
    

In [2]:
confoundSplit(p_pos_train_z0=0.6, p_pos_train_z1 = 0.2, p_mix_z1 = 0.5, alpha_test = 1)

{'p_pos_train_z0': 0.6,
 'p_pos_train_z1': 0.2,
 'p_pos_train': 0.4,
 'p_pos_test': 0.4,
 'p_mix_z0': 0.5,
 'p_mix_z1': 0.5,
 'alpha_test': 1,
 'p_pos_test_z0': 0.4,
 'p_pos_test_z1': 0.4,
 'C_y': 0.4,
 'C_z': 0.5}

In [3]:
confoundSplit(p_pos_train_z0=0.8, p_pos_train_z1 = 0.2, p_mix_z1 = 0.2, alpha_test = 1)

{'p_pos_train_z0': 0.8,
 'p_pos_train_z1': 0.2,
 'p_pos_train': 0.6800000000000002,
 'p_pos_test': 0.6800000000000002,
 'p_mix_z0': 0.8,
 'p_mix_z1': 0.2,
 'alpha_test': 1,
 'p_pos_test_z0': 0.6800000000000002,
 'p_pos_test_z1': 0.6800000000000002,
 'C_y': 0.6800000000000002,
 'C_z': 0.2}

In [4]:
confoundSplit(p_pos_train_z0=0.8, p_pos_train_z1 = 0.1, p_mix_z1 = 0.2, alpha_test = 1)

{'p_pos_train_z0': 0.8,
 'p_pos_train_z1': 0.1,
 'p_pos_train': 0.6600000000000001,
 'p_pos_test': 0.6600000000000001,
 'p_mix_z0': 0.8,
 'p_mix_z1': 0.2,
 'alpha_test': 1,
 'p_pos_test_z0': 0.6600000000000001,
 'p_pos_test_z1': 0.6600000000000001,
 'C_y': 0.6600000000000001,
 'C_z': 0.2}

In [5]:
confoundSplit(p_pos_train_z0=0.8, p_pos_train_z1 = 0.1, p_mix_z1 = 0.2, alpha_test = 2)

{'p_pos_train_z0': 0.8,
 'p_pos_train_z1': 0.1,
 'p_pos_train': 0.6600000000000001,
 'p_pos_test': 0.6600000000000001,
 'p_mix_z0': 0.8,
 'p_mix_z1': 0.2,
 'alpha_test': 2,
 'p_pos_test_z0': 0.5500000000000002,
 'p_pos_test_z1': 1.1000000000000003,
 'C_y': 0.6600000000000001,
 'C_z': 0.2}

# Outer Wrapper accepting two datasets and some extra parameters

In [6]:
import pandas as pd
import numpy as np

### WLS

In [7]:

df_wls = pd.read_csv("/edata/TRESTLE/testWLS.tsv", sep='\t')

df_wls_label = pd.read_csv("/edata/TRESTLE/WLS-labels.csv")

df_wls_merge = df_wls.merge(df_wls_label, left_on='file', right_on='idtlkbnk', how='inner')

df_wls_merge.rename(columns={"> 1 sd below mean for normals ages 60-79 (Tombaugh, Kozak, & Rees, 1999) -- normal cutoff = 12+ for 9-12 yrs eductation, 14+ for 13-21 yrs education":
                             "label",
                             
                            },
                    inplace=True
                   )

df_wls_merge.loc[df_wls_merge['label'] == 'y','label'] = 'Y'

condlist = [
    df_wls_merge['label'] == 'Y',
    df_wls_merge['label'] == 'N',
    df_wls_merge['label'].isna()
]
choicelist = [
    1,
    0,
    np.nan
]

df_wls_merge['label'] = np.select(condlist, choicelist)

df_wls_merge = df_wls_merge.loc[df_wls_merge['label'].notna(), :].reset_index(drop=True)

In [8]:
df_wls_merge.groupby('label', dropna=False).size()

label
0.0    1167
1.0     110
dtype: int64

### ADReSS

In [9]:
df_adress_train = pd.read_csv("/edata/ADReSS-IS2020-data/dataframes/adre_train.csv")

df_adress_test = pd.read_csv("/edata/ADReSS-IS2020-data/dataframes/adre_test.csv")

df_adress = pd.concat([df_adress_train, df_adress_test], ignore_index=True)

df_adress.rename(columns={"sentence": "text"}, inplace=True)

In [10]:
df_adress

Unnamed: 0.1,Unnamed: 0,sentence_source,label,text
0,6,S094.txt,1,oh yes a little girl and the little boy is ...
1,0,S138.txt,1,the the water's flowing on the floor and sh...
2,30,S118.txt,1,oh there's a cookie jar and a youngster with ...
3,37,S114.txt,1,mhm well the kids is robbing a cookie jar ...
4,49,S143.txt,1,well little boy reaching out for the cookie j...
...,...,...,...,...
151,43,S198.txt,1,you mean like the woman doing the dishes and ...
152,44,S180.txt,0,well the boy is taking cookies outof the cook...
153,45,S194.txt,1,well the mother has water spilling all over t...
154,46,S176.txt,1,whew do i hafta use my my my personal descr...


## Wrapper Function

In [11]:
(df_wls_merge['label'] == 0).sum()

1167

In [12]:
df_wls_merge.groupby('label', dropna=False).size()

label
0.0    1167
1.0     110
dtype: int64

In [13]:
confoundSplit(p_pos_train_z0=0.8, p_pos_train_z1 = 0.1, p_mix_z1 = 0.2, alpha_test = 2)

{'p_pos_train_z0': 0.8,
 'p_pos_train_z1': 0.1,
 'p_pos_train': 0.6600000000000001,
 'p_pos_test': 0.6600000000000001,
 'p_mix_z0': 0.8,
 'p_mix_z1': 0.2,
 'alpha_test': 2,
 'p_pos_test_z0': 0.5500000000000002,
 'p_pos_test_z1': 1.1000000000000003,
 'C_y': 0.6600000000000001,
 'C_z': 0.2}

In [14]:
import math


def confoundSplitNumbers(df0, df1, 
                    df0_label, df1_label,
                    p_pos_train_z1, p_pos_train_z0, p_mix_z1, alpha_test, 
                    train_test_ratio = 4,
                   ):
    
    """
    df0_label, df1_label: 0/1, or True/False coded
    
    
    """
    assert df0[df0_label].isin([0,1]).all(axis=0)
    assert df1[df1_label].isin([0,1]).all(axis=0)
    
    
    mix_param_dict = confoundSplit(p_pos_train_z0=p_pos_train_z0, p_pos_train_z1 = p_pos_train_z1, p_mix_z1 = p_mix_z1, alpha_test = alpha_test)
    
    
    
    N_df0_pos = (df0[df0_label] == 1).sum()
    N_df0_neg = (df0[df0_label] == 0).sum()
    
    N_df1_pos = (df1[df1_label] == 1).sum()
    N_df1_neg = (df1[df1_label] == 0).sum()
    
    N_df0 = N_df0_pos + N_df0_neg
    N_df1 = N_df1_pos + N_df1_neg
        
    
    n_df0_test_pos = math.floor(N_df0 / (train_test_ratio + 1))
    
    
    while(n_df0_test_pos > 0):
        
        n_df0_test_neg = math.floor(n_df0_test_pos / mix_param_dict['p_pos_test_z0'] * (1-mix_param_dict['p_pos_test_z0']))
        
        
        n_df0_train_pos = math.floor((n_df0_test_pos + n_df0_test_neg) * train_test_ratio * mix_param_dict['p_pos_train_z0'])
        n_df0_train_neg = math.floor((n_df0_test_pos + n_df0_test_neg) * train_test_ratio * (1 - mix_param_dict['p_pos_train_z0']))
        
        
        
        
        n_df1_train = math.floor(mix_param_dict['C_z'] / (1 - mix_param_dict['C_z']) * (n_df0_train_pos + n_df0_train_neg))
        n_df1_train_pos = math.floor(n_df1_train * mix_param_dict['p_pos_train_z1'])
        n_df1_train_neg = math.floor(n_df1_train * (1 - mix_param_dict['p_pos_train_z1']))
        
        n_df1_test = math.floor(n_df1_train/train_test_ratio)
        n_df1_test_pos = math.floor(n_df1_test * mix_param_dict['p_pos_test_z1'])
        n_df1_test_neg = math.floor(n_df1_test * (1 - mix_param_dict['p_pos_test_z1']))
        
        
        test1 = 0 < (n_df0_train_pos + n_df0_test_pos) <= N_df0_pos
        test2 = 0 < (n_df0_train_neg + n_df0_test_neg) <= N_df0_neg
        
        test3 = 0 < (n_df1_train_pos + n_df1_test_pos) <= N_df1_pos
        test4 = 0 < (n_df1_train_neg + n_df1_test_neg) <= N_df1_neg
        
        
        test5 = 0 < n_df0_train_pos
        test6 = 0 < n_df0_train_neg
        test7 = 0 < n_df1_train_pos
        test8 = 0 < n_df1_train_neg
        
        test9 = 0 < n_df0_test_pos
        test10 = 0 < n_df0_test_neg
        test11 = 0 < n_df1_test_pos
        test12 = 0 < n_df1_test_neg
        
        
        
        if test1 and test2 and test3 and test4 and test5 and test6 and test7 and test8 and test9 and test10 and test11 and test12:
            return {"n_df0_train_pos": n_df0_train_pos,
                    "n_df0_test_pos": n_df0_test_pos,
                    "n_df0_train_neg": n_df0_train_neg,
                    "n_df0_test_neg": n_df0_test_neg,
                    
                    "n_df1_train_pos": n_df1_train_pos,
                    "n_df1_test_pos":n_df1_test_pos,
                    "n_df1_train_neg":n_df1_train_neg,
                    "n_df1_test_neg":n_df1_test_neg,
                    
                   }
        else:        
            n_df0_test_pos -= 1
        
        if n_df0_test_pos == 0:
            return None

In [15]:
confoundSplit(p_pos_train_z0=0.8, p_pos_train_z1 = 0.1, p_mix_z1 = 0.2, alpha_test = 2)

{'p_pos_train_z0': 0.8,
 'p_pos_train_z1': 0.1,
 'p_pos_train': 0.6600000000000001,
 'p_pos_test': 0.6600000000000001,
 'p_mix_z0': 0.8,
 'p_mix_z1': 0.2,
 'alpha_test': 2,
 'p_pos_test_z0': 0.5500000000000002,
 'p_pos_test_z1': 1.1000000000000003,
 'C_y': 0.6600000000000001,
 'C_z': 0.2}

In [16]:

confoundSplitNumbers(df0=df_wls_merge, df1=df_adress, 
                    df0_label='label', df1_label='label',
                    
                    p_pos_train_z0=0.8, p_pos_train_z1 = 0.1, p_mix_z1 = 0.2, alpha_test = 2,
                    
                    train_test_ratio = 5,
                   )

In [17]:

confoundSplitNumbers(df0=df_wls_merge, df1=df_adress, 
                    df0_label='label', df1_label='label',
                    
                    p_pos_train_z0=0.8, p_pos_train_z1 = 0.1, p_mix_z1 = 0.2, alpha_test = 2,
                    
                    train_test_ratio = 1,
                   )



In [18]:
np.arange(0, 1, 0.1)

array([0. , 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])

In [19]:
import itertools

In [28]:

valid_combinations = []

for combination in itertools.product([0.5], 
                                    [0.1], 
                                     np.arange(0.1, 0.999, 0.1),
                                     np.arange(0,10,0.1),
                                     [4]
                                     
                                    ):
    
    ret = confoundSplitNumbers(df0=df_wls_merge, df1=df_adress, 
                        df0_label='label', df1_label='label',

                        p_pos_train_z0=combination[0], p_pos_train_z1 = combination[1], p_mix_z1 = combination[2], alpha_test = combination[3],

                        train_test_ratio = combination[4],
                       )
    
    if ret is not None:
        valid_combinations.append(combination)
        
    

In [29]:
len(valid_combinations)

379

In [22]:
len([c for c in valid_combinations if (0.1 <= c[0] <= 0.2) and (0.6 <= c[1] <= 0.65) and (c[4] == 4)])

0

In [23]:
len([c for c in valid_combinations if (0.1 <= c[0] <= 0.3) and (0.6 <= c[1] <= 0.8) and (c[4] == 4) and (1 <= c[3] <= 1.5)])

0

In [24]:
valid_high_combinations = []

for combination in valid_combinations:

    ret = confoundSplitNumbers(df0=df_wls_merge, df1=df_adress, 
                            df0_label='label', df1_label='label',

                            p_pos_train_z0=combination[0], p_pos_train_z1 = combination[1], p_mix_z1 = combination[2], alpha_test = combination[3],

                            train_test_ratio = combination[4],
                           )
    
    if (ret['n_df0_train_pos'] >= 10) and (combination[4] == 4):
            valid_high_combinations.append(combination)
    
    
    

In [25]:
len(valid_high_combinations)

530

In [27]:
valid_high_combinations[:3]

[(0.2, 0.4, 0.1, 0.4, 4),
 (0.2, 0.4, 0.1, 0.5, 4),
 (0.2, 0.4, 0.1, 0.6000000000000001, 4)]

In [30]:

combination= (0.2, 0.6, 0.6, 4.800000000000001, 4)

ret = confoundSplitNumbers(df0=df_wls_merge, df1=df_adress, 
                            df0_label='label', df1_label='label',

                            p_pos_train_z0=combination[0], p_pos_train_z1 = combination[1], p_mix_z1 = combination[2], alpha_test = combination[3],

                            train_test_ratio = combination[4],
                           )

In [31]:
ret

{'n_df0_train_pos': 11,
 'n_df0_test_pos': 2,
 'n_df0_train_neg': 44,
 'n_df0_test_neg': 12,
 'n_df1_train_pos': 49,
 'n_df1_test_pos': 12,
 'n_df1_train_neg': 32,
 'n_df1_test_neg': 7}

# In Usage

In [1]:
import sys, os

In [46]:
sys.path.insert(0, "..")
from src.utils import confoundSplitNumbers, confoundSplitDF
from src.data_process import load_wls_adress_AddDomain

In [47]:
df_wls_merge = load_wls_adress_AddDomain(dt="wls")
df_adress = load_wls_adress_AddDomain(dt="adress")

In [209]:
n_valid_high = 10
n_test = [150]
n_test_error = [20]

valid_high_combinations=[]
valid_full_settings=[]
for combination in itertools.product(
        [0.5],
        [0.2],
        np.arange(0.1, 0.999, 0.1),
        np.arange(0, 5, 0.05),
        [4],
        n_test,
        n_test_error,
    ):
        ret = confoundSplitNumbers(
            df0=df_wls_merge,
            df1=df_adress,
            df0_label="label",
            df1_label="label",
            p_pos_train_z0=combination[0],
            p_pos_train_z1=combination[1],
            p_mix_z1=combination[2],
            alpha_test=combination[3],
            train_test_ratio=combination[4],
            n_test=combination[5],
            n_test_error=combination[6],
        )

        if (ret is not None) and (
            ret["n_df0_train_pos"] >= n_valid_high
        ):  # valie high combos
            valid_high_combinations.append(combination)
            valid_full_settings.append(ret)

In [210]:
len(valid_high_combinations)

75

In [211]:
tmp = pd.DataFrame(valid_high_combinations)

In [212]:
tmp.head()

Unnamed: 0,0,1,2,3,4,5,6
0,0.5,0.2,0.4,1.4,4,150,20
1,0.5,0.2,0.4,1.65,4,150,20
2,0.5,0.2,0.4,1.7,4,150,20
3,0.5,0.2,0.4,1.8,4,150,20
4,0.5,0.2,0.4,1.95,4,150,20


In [213]:
# alpha_test range
tmp.iloc[:,[3]].describe()

Unnamed: 0,3
count,75.0
mean,3.458667
std,0.941906
min,1.4
25%,2.775
50%,3.5
75%,4.225
max,4.95


In [214]:
# p_mix range
tmp.iloc[:,[2]].describe()

Unnamed: 0,2
count,75.0
mean,0.417333
std,0.038108
min,0.4
25%,0.4
50%,0.4
75%,0.4
max,0.5


# Test Two Split Framework

In [2]:
import numpy as np
import pandas as pd

In [3]:
import itertools

In [4]:
import sys, os

## Framework 1

In [107]:
sys.path.insert(0, "..")
from src.utils import confoundSplit, confoundSplitNumbers, confoundSplitDF


In [6]:
np.random.seed(12)

df = pd.DataFrame({"z": np.random.binomial(1, 0.2, 2000),
                   "y": np.random.binomial(1, 0.5, 2000)
                  })

In [7]:
df.groupby(['z','y']).size()

z  y
0  0    813
   1    762
1  0    219
   1    206
dtype: int64

In [8]:
df0 = df.query("z == 0").reset_index(drop=True)
df1 = df.query("z == 1").reset_index(drop=True)

In [141]:
n_valid_high = 0
n_test = [150]
n_test_error = [5]

valid_high_combinations=[]
valid_full_settings=[]
for combination in itertools.product(
        [0.1, 0.5],
        [0.2, 0.6],
        np.arange(0.1, 0.999, 0.1),
        np.arange(0, 5, 0.05),
        [4],
        n_test,
        n_test_error,
    ):
        ret = confoundSplitNumbers(
            df0=df0,
            df1=df1,
            df0_label="y",
            df1_label="y",
            p_pos_train_z0=combination[0],
            p_pos_train_z1=combination[1],
            p_mix_z1=combination[2],
            alpha_test=combination[3],
            train_test_ratio=combination[4],
            n_test=combination[5],
            n_test_error=combination[6],
        )

        if (ret is not None) and (
            ret["n_df0_train_pos"] >= n_valid_high
        ):  # valie high combos
            valid_high_combinations.append(combination)
            valid_full_settings.append(ret)

In [142]:
valid_full_settings[1]

{'n_df0_train_pos': 56,
 'n_df0_test_pos': 16,
 'n_df0_train_neg': 504,
 'n_df0_test_neg': 124,
 'n_df1_train_pos': 12,
 'n_df1_test_pos': 1,
 'n_df1_train_neg': 49,
 'n_df1_test_neg': 13,
 'mix_param_dict': {'p_pos_train_z0': 0.1,
  'p_pos_train_z1': 0.2,
  'p_pos_train': 0.11000000000000001,
  'p_pos_test': 0.11000000000000001,
  'p_mix_z0': 0.9,
  'p_mix_z1': 0.1,
  'alpha_train': 2.0,
  'alpha_test': 0.65,
  'p_pos_test_z0': 0.1139896373056995,
  'p_pos_test_z1': 0.07409326424870467,
  'C_y': 0.11000000000000001,
  'C_z': 0.1}}

In [143]:
len(valid_high_combinations)

1128

In [13]:
df_eval_num = pd.DataFrame(valid_full_settings)

In [14]:
df_eval_num.drop('mix_param_dict', axis=1, inplace=True)

In [15]:
df_eval_params = pd.DataFrame([x['mix_param_dict'] for x in valid_full_settings])

In [16]:
sum(df_eval_params.round(4).duplicated())

0

In [17]:
df_eval = pd.concat([df_eval_num, df_eval_params.round(4)], axis=1)

In [18]:
df_eval

Unnamed: 0,n_df0_train_pos,n_df0_test_pos,n_df0_train_neg,n_df0_test_neg,n_df1_train_pos,n_df1_test_pos,n_df1_train_neg,n_df1_test_neg,p_pos_train_z0,p_pos_train_z1,p_pos_train,p_pos_test,p_mix_z0,p_mix_z1,alpha_train,alpha_test,p_pos_test_z0,p_pos_test_z1,C_y,C_z
0,54,15,489,121,12,1,48,13,0.1,0.2,0.11,0.11,0.9,0.1,2.0,1.00,0.1100,0.1100,0.11,0.1
1,54,14,489,122,12,2,48,12,0.1,0.2,0.11,0.11,0.9,0.1,2.0,1.70,0.1028,0.1748,0.11,0.1
2,54,14,489,122,12,2,48,12,0.1,0.2,0.11,0.11,0.9,0.1,2.0,1.75,0.1023,0.1791,0.11,0.1
3,54,13,489,123,12,3,48,11,0.1,0.2,0.11,0.11,0.9,0.1,2.0,2.55,0.0952,0.2429,0.11,0.1
4,54,12,489,124,12,4,48,10,0.1,0.2,0.11,0.11,0.9,0.1,2.0,3.50,0.0880,0.3080,0.11,0.1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
292,182,27,182,64,145,54,96,5,0.5,0.6,0.54,0.54,0.6,0.4,1.2,3.10,0.2935,0.9098,0.54,0.4
293,182,26,182,65,145,55,96,4,0.5,0.6,0.54,0.54,0.6,0.4,1.2,3.25,0.2842,0.9237,0.54,0.4
294,182,25,182,66,145,56,96,3,0.5,0.6,0.54,0.54,0.6,0.4,1.2,3.45,0.2727,0.9409,0.54,0.4
295,182,24,182,67,145,57,96,2,0.5,0.6,0.54,0.54,0.6,0.4,1.2,3.65,0.2621,0.9568,0.54,0.4


## Framework 2

In [56]:
df2 = pd.read_csv("./all_combo.csv")

In [57]:
df2.drop('Unnamed: 0', axis=1, inplace=True)

In [58]:
df2 = df2.rename({"n_z0_pos_train":"n_df0_train_pos",
                  "n_z0_neg_train":"n_df0_train_neg",
                  "n_z0_neg_test":"n_df0_test_neg",
                  "n_z1_pos_test":"n_df1_test_pos",
                  "n_z0_pos_test":"n_df0_test_pos",
                  "n_z1_neg_test":"n_df1_test_neg",
                  "n_z1_neg_train":"n_df1_train_neg",
                  "n_z1_pos_train":"n_df1_train_pos",
                 }, axis=1)

df2 = df2[df_eval.columns]

## Checking...

In [63]:
df_eval.iloc[:,:8].sum(axis=1)

0      753
1      753
2      753
3      753
4      753
      ... 
292    755
293    755
294    755
295    755
296    755
Length: 297, dtype: int64

In [64]:
df2.iloc[:,:8].sum(axis=1)

0       750
1       750
2       750
3       750
4       750
       ... 
1201    750
1202    750
1203    750
1204    750
1205    750
Length: 1206, dtype: int64

In [65]:
df2.dtypes

n_df0_train_pos      int64
n_df0_test_pos       int64
n_df0_train_neg      int64
n_df0_test_neg       int64
n_df1_train_pos      int64
n_df1_test_pos       int64
n_df1_train_neg      int64
n_df1_test_neg       int64
p_pos_train_z0     float64
p_pos_train_z1     float64
p_pos_train        float64
p_pos_test         float64
p_mix_z0           float64
p_mix_z1           float64
alpha_train        float64
alpha_test         float64
p_pos_test_z0      float64
p_pos_test_z1      float64
C_y                float64
C_z                float64
dtype: object

In [66]:
df2.columns

Index(['n_df0_train_pos', 'n_df0_test_pos', 'n_df0_train_neg',
       'n_df0_test_neg', 'n_df1_train_pos', 'n_df1_test_pos',
       'n_df1_train_neg', 'n_df1_test_neg', 'p_pos_train_z0', 'p_pos_train_z1',
       'p_pos_train', 'p_pos_test', 'p_mix_z0', 'p_mix_z1', 'alpha_train',
       'alpha_test', 'p_pos_test_z0', 'p_pos_test_z1', 'C_y', 'C_z'],
      dtype='object')

In [67]:
df2.duplicated()

0       False
1       False
2       False
3       False
4       False
        ...  
1201    False
1202    False
1203    False
1204    False
1205    False
Length: 1206, dtype: bool

In [68]:
df2.groupby("alpha_test").size()

alpha_test
0.05     8
0.10    12
0.15    14
0.20    14
0.25    15
        ..
4.75     9
4.80     9
4.85     9
4.90     9
4.95     9
Length: 99, dtype: int64

In [69]:
df_eval.groupby("alpha_test").size()

alpha_test
0.05    4
0.10    6
0.15    5
0.20    5
0.25    5
       ..
4.60    3
4.65    3
4.70    2
4.75    1
4.95    1
Length: 86, dtype: int64

In [75]:
df_eval.query("alpha_test==0.05")

Unnamed: 0,n_df0_train_pos,n_df0_test_pos,n_df0_train_neg,n_df0_test_neg,n_df1_train_pos,n_df1_test_pos,n_df1_train_neg,n_df1_test_neg,p_pos_train_z0,p_pos_train_z1,p_pos_train,p_pos_test,p_mix_z0,p_mix_z1,alpha_train,alpha_test,p_pos_test_z0,p_pos_test_z1,C_y,C_z
98,30,51,273,25,181,2,121,72,0.1,0.6,0.35,0.35,0.5,0.5,6.0,0.05,0.6667,0.0333,0.35,0.5
150,212,61,212,45,36,1,144,43,0.5,0.2,0.41,0.41,0.7,0.3,0.4,0.05,0.5734,0.0287,0.41,0.3
231,212,79,212,27,108,1,72,43,0.5,0.6,0.53,0.53,0.7,0.3,1.2,0.05,0.7413,0.0371,0.53,0.3
263,182,80,182,11,145,2,96,57,0.5,0.6,0.54,0.54,0.6,0.4,1.2,0.05,0.871,0.0435,0.54,0.4


In [76]:
df2.query("alpha_test==0.05")

Unnamed: 0,n_df0_train_pos,n_df0_test_pos,n_df0_train_neg,n_df0_test_neg,n_df1_train_pos,n_df1_test_pos,n_df1_train_neg,n_df1_test_neg,p_pos_train_z0,p_pos_train_z1,p_pos_train,p_pos_test,p_mix_z0,p_mix_z1,alpha_train,alpha_test,p_pos_test_z0,p_pos_test_z1,C_y,C_z
482,42,37,378,68,108,1,72,44,0.1,0.6,0.25,0.25,0.7,0.3,6.0,0.05,0.34965,0.017483,0.25,0.3
581,36,44,324,46,144,1,96,59,0.1,0.6,0.3,0.3,0.6,0.4,6.0,0.05,0.483871,0.024194,0.3,0.4
680,30,50,270,25,180,2,120,73,0.1,0.6,0.35,0.35,0.5,0.5,6.0,0.05,0.666667,0.033333,0.35,0.5
745,240,65,240,55,24,1,96,29,0.5,0.2,0.44,0.44,0.8,0.2,0.4,0.05,0.54321,0.02716,0.44,0.2
809,210,60,210,45,36,1,144,44,0.5,0.2,0.41,0.41,0.7,0.3,0.4,0.05,0.573427,0.028671,0.41,0.3
1009,240,77,240,43,72,1,48,29,0.5,0.6,0.52,0.52,0.8,0.2,1.2,0.05,0.641975,0.032099,0.52,0.2
1057,210,78,210,27,108,2,72,43,0.5,0.6,0.53,0.53,0.7,0.3,1.2,0.05,0.741259,0.037063,0.53,0.3
1116,180,78,180,12,144,3,96,57,0.5,0.6,0.54,0.54,0.6,0.4,1.2,0.05,0.870968,0.043548,0.54,0.4


In [None]:
 [0.1, 0.5],
[0.2, 0.6],
np.arange(0.1, 0.999, 0.1),
np.arange(0, 5, 0.05),
        
p_pos_train_z0=combination[0],
p_pos_train_z1=combination[1],
p_mix_z1=combination[2],
alpha_test=combination[3],

In [97]:
i = "p_mix_z1"

print(sorted(df_eval[i].unique()))
print(sorted(df2[i].unique()))

[0.1, 0.2, 0.3, 0.4, 0.5]
[0.1, 0.2, 0.3, 0.4, 0.5]


In [96]:
i = "alpha_test"

print(sorted(df_eval[i].unique()))
print(sorted(df2[i].unique()))

[0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1.0, 1.05, 1.1, 1.15, 1.2, 1.25, 1.3, 1.35, 1.4, 1.45, 1.5, 1.55, 1.6, 1.65, 1.7, 1.75, 1.8, 1.85, 1.9, 1.95, 2.05, 2.1, 2.15, 2.2, 2.25, 2.3, 2.35, 2.4, 2.45, 2.5, 2.55, 2.6, 2.65, 2.7, 2.75, 2.8, 2.85, 2.9, 3.0, 3.05, 3.1, 3.15, 3.2, 3.25, 3.3, 3.35, 3.4, 3.45, 3.5, 3.55, 3.6, 3.65, 3.75, 3.8, 3.85, 3.95, 4.0, 4.1, 4.2, 4.25, 4.3, 4.5, 4.6, 4.65, 4.7, 4.75, 4.95]
[0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6000000000000001, 0.65, 0.7000000000000001, 0.75, 0.8, 0.8500000000000001, 0.9, 0.95, 1.0, 1.05, 1.1, 1.15, 1.2000000000000002, 1.25, 1.3, 1.35, 1.4, 1.4500000000000002, 1.5, 1.55, 1.6, 1.65, 1.7000000000000002, 1.75, 1.8, 1.85, 1.9, 1.95, 2.0, 2.0500000000000003, 2.1, 2.15, 2.2, 2.25, 2.3000000000000003, 2.35, 2.4000000000000004, 2.45, 2.5, 2.5500000000000003, 2.6, 2.6500000000000004, 2.7, 2.75, 2.8000000000000003, 2.85, 2.9000000000000004, 2.95, 3.0, 3.050

In [99]:
df_eval.query("(alpha_test==0.2) & (p_pos_train_z0 == 0.1) & (p_pos_train_z1 == 0.6)").sort_values("n_df0_train_pos")

Unnamed: 0,n_df0_train_pos,n_df0_test_pos,n_df0_train_neg,n_df0_test_neg,n_df1_train_pos,n_df1_test_pos,n_df1_train_neg,n_df1_test_neg,p_pos_train_z0,p_pos_train_z1,p_pos_train,p_pos_test,p_mix_z0,p_mix_z1,alpha_train,alpha_test,p_pos_test_z0,p_pos_test_z1,C_y,C_z
55,42,35,381,71,108,2,72,42,0.1,0.6,0.25,0.25,0.7,0.3,6.0,0.2,0.3289,0.0658,0.25,0.3
37,48,29,435,92,72,1,48,28,0.1,0.6,0.2,0.2,0.8,0.2,6.0,0.2,0.2381,0.0476,0.2,0.2


In [100]:
df2.query("(alpha_test==0.2) & (p_pos_train_z0 == 0.1) & (p_pos_train_z1 == 0.6)").sort_values("n_df0_train_pos")

Unnamed: 0,n_df0_train_pos,n_df0_test_pos,n_df0_train_neg,n_df0_test_neg,n_df1_train_pos,n_df1_test_pos,n_df1_train_neg,n_df1_test_neg,p_pos_train_z0,p_pos_train_z1,p_pos_train,p_pos_test,p_mix_z0,p_mix_z1,alpha_train,alpha_test,p_pos_test_z0,p_pos_test_z1,C_y,C_z
683,30,44,270,31,180,9,120,66,0.1,0.6,0.35,0.35,0.5,0.5,6.0,0.2,0.583333,0.116667,0.35,0.5
584,36,40,324,50,144,5,96,55,0.1,0.6,0.3,0.3,0.6,0.4,6.0,0.2,0.441176,0.088235,0.3,0.4
485,42,35,378,70,108,3,72,42,0.1,0.6,0.25,0.25,0.7,0.3,6.0,0.2,0.328947,0.065789,0.25,0.3
386,48,29,432,91,72,1,48,29,0.1,0.6,0.2,0.2,0.8,0.2,6.0,0.2,0.238095,0.047619,0.2,0.2


In [104]:
for combination in itertools.product(
        [0.1, 0.5],
        [0.2, 0.6],
        np.arange(0.1, 0.999, 0.1),
        np.arange(0, 5, 0.05),
        [4],
        n_test,
        n_test_error,
    ):
    if combination[0] == 0.1 and combination[1] == 0.6 and combination[2] ==  0.4 and combination[3]==0.2:
        print(combination)

(0.1, 0.6, 0.4, 0.2, 4, 150, 0)


In [None]:
confoundSplitDF

In [110]:
combination

(0.5, 0.6, 0.9, 4.95, 4, 150, 0)

In [133]:
combination = (0.1, 0.6, 0.5, 0.2, 4, 150, 0)
confoundSplit(p_pos_train_z1=combination[1], p_pos_train_z0=combination[0], p_mix_z1=combination[2], alpha_test=combination[3])

{'p_pos_train_z0': 0.1,
 'p_pos_train_z1': 0.6,
 'p_pos_train': 0.35,
 'p_pos_test': 0.35,
 'p_mix_z0': 0.5,
 'p_mix_z1': 0.5,
 'alpha_train': 5.999999999999999,
 'alpha_test': 0.2,
 'p_pos_test_z0': 0.5833333333333334,
 'p_pos_test_z1': 0.11666666666666668,
 'C_y': 0.35,
 'C_z': 0.5}

In [134]:


ret = confoundSplitNumbers(
            df0=df0,
            df1=df1,
            df0_label="y",
            df1_label="y",
            p_pos_train_z0=combination[0],
            p_pos_train_z1=combination[1],
            p_mix_z1=combination[2],
            alpha_test=combination[3],
            train_test_ratio=combination[4],
            n_test=combination[5],
            n_test_error=combination[6],
        )

In [135]:
ret

In [139]:
import math


def test(
    df0,
    df1,
    df0_label,
    df1_label,
    p_pos_train_z1,
    p_pos_train_z0,
    p_mix_z1,
    alpha_test,
    train_test_ratio=4,
    n_test=None,  # set the number for tests
    n_test_error = 0,
):

    """
    df0_label, df1_label: 0/1, or True/False coded


    """
    assert df0[df0_label].isin([0, 1]).all(axis=0)
    assert df1[df1_label].isin([0, 1]).all(axis=0)

    mix_param_dict = confoundSplit(
        p_pos_train_z0=p_pos_train_z0,
        p_pos_train_z1=p_pos_train_z1,
        p_mix_z1=p_mix_z1,
        alpha_test=alpha_test,
    )

    N_df0_pos = (df0[df0_label] == 1).sum()
    N_df0_neg = (df0[df0_label] == 0).sum()

    N_df1_pos = (df1[df1_label] == 1).sum()
    N_df1_neg = (df1[df1_label] == 0).sum()

    N_df0 = N_df0_pos + N_df0_neg
    N_df1 = N_df1_pos + N_df1_neg

    n_df0_test_pos = math.floor(N_df0 / (train_test_ratio + 1))
    
    while n_df0_test_pos > 0:
        
        n_df0_test_neg = math.floor(
            n_df0_test_pos
            / mix_param_dict["p_pos_test_z0"]
            * (1 - mix_param_dict["p_pos_test_z0"])
        )

        n_df0_train_pos = math.floor(
            (n_df0_test_pos + n_df0_test_neg)
            * train_test_ratio
            * mix_param_dict["p_pos_train_z0"]
        )
        n_df0_train_neg = math.floor(
            (n_df0_test_pos + n_df0_test_neg)
            * train_test_ratio
            * (1 - mix_param_dict["p_pos_train_z0"])
        )

        n_df1_train = math.floor(
            mix_param_dict["C_z"]
            / (1 - mix_param_dict["C_z"])
            * (n_df0_train_pos + n_df0_train_neg)
        )
        n_df1_train_pos = math.floor(n_df1_train * mix_param_dict["p_pos_train_z1"])
        n_df1_train_neg = math.floor(
            n_df1_train * (1 - mix_param_dict["p_pos_train_z1"])
        )

        n_df1_test = math.floor(n_df1_train / train_test_ratio)
        n_df1_test_pos = math.floor(n_df1_test * mix_param_dict["p_pos_test_z1"])
        n_df1_test_neg = math.floor(n_df1_test * (1 - mix_param_dict["p_pos_test_z1"]))

        test1 = 0 < (n_df0_train_pos + n_df0_test_pos) <= N_df0_pos
        test2 = 0 < (n_df0_train_neg + n_df0_test_neg) <= N_df0_neg

        test3 = 0 < (n_df1_train_pos + n_df1_test_pos) <= N_df1_pos
        test4 = 0 < (n_df1_train_neg + n_df1_test_neg) <= N_df1_neg

        test5 = 0 < n_df0_train_pos
        test6 = 0 < n_df0_train_neg
        test7 = 0 < n_df1_train_pos
        test8 = 0 < n_df1_train_neg

        test9 = 0 < n_df0_test_pos
        test10 = 0 < n_df0_test_neg
        test11 = 0 < n_df1_test_pos
        test12 = 0 < n_df1_test_neg

        tester_positive_number = (
            test1
            and test2
            and test3
            and test4
            and test5
            and test6
            and test7
            and test8
            and test9
            and test10
            and test11
            and test12
        )

        tester_n_test = (
            n_df0_test_pos + n_df0_test_neg + n_df1_test_pos + n_df1_test_neg
        )

        
        if n_df0_test_pos == 44:
            ret = {
                "n_df0_train_pos": n_df0_train_pos,
                "n_df0_test_pos": n_df0_test_pos,
                "n_df0_train_neg": n_df0_train_neg,
                "n_df0_test_neg": n_df0_test_neg,
                "n_df1_train_pos": n_df1_train_pos,
                "n_df1_test_pos": n_df1_test_pos,
                "n_df1_train_neg": n_df1_train_neg,
                "n_df1_test_neg": n_df1_test_neg,
                "mix_param_dict": mix_param_dict,
            }
            breakpoint()
            return ret
        if tester_positive_number:

            
            if n_test is None:
                return ret
            else:
                if (n_test - n_test_error) <= tester_n_test <= (n_test + n_test_error):
                    return ret
                else:
                    n_df0_test_pos -= 1
        else:
            n_df0_test_pos -= 1

        if n_df0_test_pos == 0:
            return None


In [140]:


ret = test(
            df0=df0,
            df1=df1,
            df0_label="y",
            df1_label="y",
            p_pos_train_z0=combination[0],
            p_pos_train_z1=combination[1],
            p_mix_z1=combination[2],
            alpha_test=combination[3],
            train_test_ratio=combination[4],
            n_test=combination[5],
            n_test_error=combination[6],
        )

> [0;32m<ipython-input-139-42caeba6592c>[0m(126)[0;36mtest[0;34m()[0m
[0;32m    124 [0;31m            }
[0m[0;32m    125 [0;31m            [0mbreakpoint[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 126 [0;31m            [0;32mreturn[0m [0mret[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    127 [0;31m        [0;32mif[0m [0mtester_positive_number[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    128 [0;31m[0;34m[0m[0m
[0m


ipdb>  tester_positive_number


True


ipdb>  tester_n_test


149


ipdb>  c


In [138]:
ret

{'n_df0_train_pos': 30,
 'n_df0_test_pos': 44,
 'n_df0_train_neg': 270,
 'n_df0_test_neg': 31,
 'n_df1_train_pos': 180,
 'n_df1_test_pos': 8,
 'n_df1_train_neg': 120,
 'n_df1_test_neg': 66,
 'mix_param_dict': {'p_pos_train_z0': 0.1,
  'p_pos_train_z1': 0.6,
  'p_pos_train': 0.35,
  'p_pos_test': 0.35,
  'p_mix_z0': 0.5,
  'p_mix_z1': 0.5,
  'alpha_train': 5.999999999999999,
  'alpha_test': 0.2,
  'p_pos_test_z0': 0.5833333333333334,
  'p_pos_test_z1': 0.11666666666666668,
  'C_y': 0.35,
  'C_z': 0.5}}

In [115]:
df2.query("alpha_test == 4.800000000000001")

Unnamed: 0,n_df0_train_pos,n_df0_test_pos,n_df0_train_neg,n_df0_test_neg,n_df1_train_pos,n_df1_test_pos,n_df1_train_neg,n_df1_test_neg,p_pos_train_z0,p_pos_train_z1,p_pos_train,p_pos_test,p_mix_z0,p_mix_z1,alpha_train,alpha_test,p_pos_test_z0,p_pos_test_z1,C_y,C_z
90,54,11,486,124,12,6,48,9,0.1,0.2,0.11,0.11,0.9,0.1,2.0,4.8,0.07971,0.382609,0.11,0.1
187,48,8,432,112,24,10,96,20,0.1,0.2,0.12,0.12,0.8,0.2,2.0,4.8,0.068182,0.327273,0.12,0.2
285,42,6,378,99,36,13,144,32,0.1,0.2,0.13,0.13,0.7,0.3,2.0,4.8,0.060748,0.291589,0.13,0.3
380,54,15,486,120,36,8,24,7,0.1,0.6,0.15,0.15,0.9,0.1,6.0,4.8,0.108696,0.521739,0.15,0.1
478,48,14,432,106,72,16,48,14,0.1,0.6,0.2,0.2,0.8,0.2,6.0,4.8,0.113636,0.545455,0.2,0.2
577,42,12,378,93,108,25,72,20,0.1,0.6,0.25,0.25,0.7,0.3,6.0,4.8,0.116822,0.560748,0.25,0.3
676,36,11,324,79,144,34,96,26,0.1,0.6,0.3,0.3,0.6,0.4,6.0,4.8,0.119048,0.571429,0.3,0.4
904,210,20,210,85,36,41,144,4,0.5,0.2,0.41,0.41,0.7,0.3,0.4,4.8,0.191589,0.919626,0.41,0.3
964,180,14,180,76,48,43,192,17,0.5,0.2,0.38,0.38,0.6,0.4,0.4,4.8,0.150794,0.72381,0.38,0.4


In [117]:
df2.query("n_df0_train_pos >= 10")

Unnamed: 0,n_df0_train_pos,n_df0_test_pos,n_df0_train_neg,n_df0_test_neg,n_df1_train_pos,n_df1_test_pos,n_df1_train_neg,n_df1_test_neg,p_pos_train_z0,p_pos_train_z1,p_pos_train,p_pos_test,p_mix_z0,p_mix_z1,alpha_train,alpha_test,p_pos_test_z0,p_pos_test_z1,C_y,C_z
0,54,16,486,119,12,1,48,14,0.1,0.2,0.11,0.11,0.9,0.1,2.0,0.30,0.118280,0.035484,0.11,0.1
1,54,16,486,119,12,1,48,14,0.1,0.2,0.11,0.11,0.9,0.1,2.0,0.35,0.117647,0.041176,0.11,0.1
2,54,16,486,119,12,1,48,14,0.1,0.2,0.11,0.11,0.9,0.1,2.0,0.40,0.117021,0.046809,0.11,0.1
3,54,16,486,119,12,1,48,14,0.1,0.2,0.11,0.11,0.9,0.1,2.0,0.45,0.116402,0.052381,0.11,0.1
4,54,16,486,119,12,1,48,14,0.1,0.2,0.11,0.11,0.9,0.1,2.0,0.50,0.115789,0.057895,0.11,0.1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1201,150,66,150,9,180,17,120,58,0.5,0.6,0.55,0.55,0.5,0.5,1.2,0.25,0.880000,0.220000,0.55,0.5
1202,150,63,150,12,180,19,120,56,0.5,0.6,0.55,0.55,0.5,0.5,1.2,0.30,0.846154,0.253846,0.55,0.5
1203,150,61,150,14,180,21,120,54,0.5,0.6,0.55,0.55,0.5,0.5,1.2,0.35,0.814815,0.285185,0.55,0.5
1204,150,59,150,16,180,24,120,51,0.5,0.6,0.55,0.55,0.5,0.5,1.2,0.40,0.785714,0.314286,0.55,0.5


# Multi-level variables - one confounder

## Inner Stat Function

In [16]:
import itertools

In [17]:

p_train_y_given_z = [[0.1, 0.1, 0.8], 
                     [0.2, 0,   0.8], 
                     [0.7, 0.3, 0  ],
                     [0.3, 0.3, 0.4]
                    ]
p_test_y_given_z = [[0.0, 0.7, 0.3], 
                     [0.1, 0.4, 0.5], 
                     [0.7, 0.3, 0  ],
                     [0.2, 0.4, 0.4]
                    ]
p_z = [0.2, 0.2, 0.3, 0.3]

In [19]:
if isinstance(p_train_y_given_z, list):
    p_train_y_given_z = np.array(p_train_y_given_z)
    
if isinstance(p_test_y_given_z, list):
    p_test_y_given_z = np.array(p_test_y_given_z)
    
if isinstance(p_z, list):
    p_z = np.array(p_z)
    

In [20]:
p_train_y_given_z

array([[0.1, 0.1, 0.8],
       [0.2, 0. , 0.8],
       [0.7, 0.3, 0. ],
       [0.3, 0.3, 0.4]])

In [21]:
p_test_y_given_z

array([[0. , 0.7, 0.3],
       [0.1, 0.4, 0.5],
       [0.7, 0.3, 0. ],
       [0.2, 0.4, 0.4]])

In [22]:
# quality check
assert np.all(p_train_y_given_z.sum(axis=1) == 1)
assert np.all(p_test_y_given_z.sum(axis=1) == 1)
assert np.all(p_z.sum() == 1)

assert p_train_y_given_z.shape == p_test_y_given_z.shape

## Generate Numbers

In [1]:
import random
import numpy as np
import pandas as pd

### Toy Dataset

In [2]:
random.seed(2)
z = random.choices(population=range(4), weights=[0.3,0.2,0.1,0.4], k=1000)

In [3]:
np.unique(z, return_counts=True)

(array([0, 1, 2, 3]), array([287, 209, 102, 402]))

In [4]:
random.seed(2)

choices = []
for i in z:
    if i == 0:
        weights = [0.2,0.5,0.3]
    elif i == 1:
        weights = [0.3,0.3,0.4]
    elif i == 2:
        weights = [0.2,0.8,0]
    elif i == 3:
        weights = [0.1,0.7,0.2]
    else:
        os.error()
    
    _choices = random.choices(population=range(3), weights=weights, k=10)
    choices.append(_choices)

In [5]:
df = pd.DataFrame({"z":np.repeat(z, 10),
                   "y":np.array(choices).flatten()
                  })

In [6]:
df.groupby("z").size()

z
0    2870
1    2090
2    1020
3    4020
dtype: int64

In [7]:
df.groupby(["z","y"]).size()

z  y
0  0     584
   1    1431
   2     855
1  0     609
   1     645
   2     836
2  0     235
   1     785
3  0     390
   1    2844
   2     786
dtype: int64

In [8]:
zC = ['Hosp_A','Hosp_B','Hosp_C', 'Hosp_D']
yC = ['none','benign', 'malignant']

In [9]:
z_map = {idx: it for idx,it in enumerate(zC)}
y_map = {idx: it for idx,it in enumerate(yC)}

In [10]:
df['zCategorical'] = df['z'].map(z_map)
df['yCategorical'] = df['y'].map(y_map)

In [11]:
df.groupby(["zCategorical","yCategorical"]).size()

zCategorical  yCategorical
Hosp_A        benign          1431
              malignant        855
              none             584
Hosp_B        benign           645
              malignant        836
              none             609
Hosp_C        benign           785
              none             235
Hosp_D        benign          2844
              malignant        786
              none             390
dtype: int64

In [34]:
n_zC = len(zC)
n_yC = len(yC)

In [35]:
assert p_train_y_given_z.shape == (n_zC, n_yC)

## Split

In [36]:
n_test = 150
n_error = 0

n_train = n_test * 4  # train:test ratio

In [37]:
p_test_y_given_z

array([[0. , 0.7, 0.3],
       [0.1, 0.4, 0.5],
       [0.7, 0.3, 0. ],
       [0.2, 0.4, 0.4]])

In [38]:
p_z

array([0.2, 0.2, 0.3, 0.3])

In [39]:
(p_z * n_test).round(0)

array([30., 30., 45., 45.])

In [40]:
(p_z * n_test).round(0).repeat(n_yC).reshape(n_zC, n_yC)

array([[30., 30., 30.],
       [30., 30., 30.],
       [45., 45., 45.],
       [45., 45., 45.]])

In [41]:
(p_z * n_test).round(0).repeat(3).reshape(4,3)

array([[30., 30., 30.],
       [30., 30., 30.],
       [45., 45., 45.],
       [45., 45., 45.]])

In [42]:
p_test_y_given_z

array([[0. , 0.7, 0.3],
       [0.1, 0.4, 0.5],
       [0.7, 0.3, 0. ],
       [0.2, 0.4, 0.4]])

In [43]:
n_test_y_given_z = (p_test_y_given_z * (p_z * n_test).round(0).repeat(n_yC).reshape(n_zC, n_yC)).round(0)

In [44]:
n_train_y_given_z = (p_train_y_given_z * (p_z * n_train).round(0).repeat(n_yC).reshape(n_zC, n_yC)).round(0)

In [45]:
n_train_y_given_z

array([[ 12.,  12.,  96.],
       [ 24.,   0.,  96.],
       [126.,  54.,   0.],
       [ 54.,  54.,  72.]])

In [46]:
n_test_y_given_z

array([[ 0., 21.,  9.],
       [ 3., 12., 15.],
       [31., 14.,  0.],
       [ 9., 18., 18.]])

In [47]:
assert (n_test - n_error) <= n_test_y_given_z.sum() <= (n_test + n_error)

In [51]:
full_df_list = []
full_shape_ls = []
for iz, iy in itertools.product(zC, yC):
    _df = df[(df['zCategorical'] == iz) & (df['yCategorical'] == iy)]
    
    full_shape_ls.append(len(_df))
    full_df_list.append(_df)

In [52]:
np.array(full_shape_ls).reshape(n_zC, n_yC)

array([[ 584, 1431,  855],
       [ 609,  645,  836],
       [ 235,  785,    0],
       [ 390, 2844,  786]])

In [53]:
assert np.all((n_test_y_given_z + n_train_y_given_z) <= np.array(full_shape_ls).reshape(n_zC, n_yC))

In [54]:
n_test_y_given_z

array([[ 0., 21.,  9.],
       [ 3., 12., 15.],
       [31., 14.,  0.],
       [ 9., 18., 18.]])

In [55]:
n_test_y_given_z.flatten()[7]

14.0

In [56]:
from sklearn.model_selection import train_test_split

In [57]:
full_df_list[0]['yCategorical'].unique().tolist()

['none']

In [58]:
n_test_y_given_z

array([[ 0., 21.,  9.],
       [ 3., 12., 15.],
       [31., 14.,  0.],
       [ 9., 18., 18.]])

In [59]:
p_test_y_given_z

array([[0. , 0.7, 0.3],
       [0.1, 0.4, 0.5],
       [0.7, 0.3, 0. ],
       [0.2, 0.4, 0.4]])

In [60]:
seed = 193
df_collect = []
for idx, _df in enumerate(full_df_list):
    n_needed_train = int(n_train_y_given_z.flatten()[idx])
    n_needed_test =  int(n_test_y_given_z.flatten()[idx])
    
    if (n_needed_test == 0) or (n_needed_train == 0):
        _df_train = _df_test = None
    else:
        
        _df_train, _df_test = train_test_split(_df,
                                               train_size=n_needed_train,
                                               test_size=n_needed_test,
                                               shuffle = True, random_state=seed
                                              )
    
    _ret = {"df_train":_df_train, 
            "df_test":_df_test, 
            "y":_df["yCategorical"].unique().tolist(), 
            "z":_df["zCategorical"].unique().tolist()
           }
    
    df_collect.append(_ret)
    

In [61]:
df_collect[1]['df_train']

Unnamed: 0,z,y,zCategorical,yCategorical
6441,0,1,Hosp_A,benign
6076,0,1,Hosp_A,benign
190,0,1,Hosp_A,benign
2807,0,1,Hosp_A,benign
5327,0,1,Hosp_A,benign
2109,0,1,Hosp_A,benign
7037,0,1,Hosp_A,benign
5396,0,1,Hosp_A,benign
3441,0,1,Hosp_A,benign
2640,0,1,Hosp_A,benign


In [62]:
df_collect[1]['df_test']

Unnamed: 0,z,y,zCategorical,yCategorical
2658,0,1,Hosp_A,benign
8179,0,1,Hosp_A,benign
1798,0,1,Hosp_A,benign
8345,0,1,Hosp_A,benign
2225,0,1,Hosp_A,benign
1794,0,1,Hosp_A,benign
9296,0,1,Hosp_A,benign
4913,0,1,Hosp_A,benign
8080,0,1,Hosp_A,benign
857,0,1,Hosp_A,benign


## Assemble Function

In [12]:

p_train_y_given_z = [[0.1, 0.1, 0.8], 
                     [0.2, 0,   0.8], 
                     [0.7, 0.3, 0  ],
                     [0.3, 0.3, 0.4]
                    ]
p_test_y_given_z = [[0.0, 0.7, 0.3], 
                     [0.1, 0.4, 0.5], 
                     [0.7, 0.3, 0  ],
                     [0.2, 0.4, 0.4]
                    ]
p_z = [0.2, 0.2, 0.3, 0.3]

In [13]:
zC = ['Hosp_A','Hosp_B','Hosp_C', 'Hosp_D']
yC = ['none','benign', 'malignant']

In [17]:
from sklearn.model_selection import train_test_split


def confoundSplitDFMultiLevel(df, zC, yC, p_train_y_given_z, p_test_y_given_z, p_z, 
                              n_test=100, n_error=0,
                              train_test_ratio=4, seed=2671):
    
    
    # convert to np.ndarray
    if isinstance(p_train_y_given_z, list):
        p_train_y_given_z = np.array(p_train_y_given_z)

    if isinstance(p_test_y_given_z, list):
        p_test_y_given_z = np.array(p_test_y_given_z)

    if isinstance(p_z, list):
        p_z = np.array(p_z)
    
    
    # quality check
    assert np.all(p_train_y_given_z.sum(axis=1) == 1)
    assert np.all(p_test_y_given_z.sum(axis=1) == 1)
    assert np.all(p_z.sum() == 1)

    assert p_train_y_given_z.shape == p_test_y_given_z.shape

    # get number of categories for Y and Z
    n_zC = len(zC)
    n_yC = len(yC)

    assert p_train_y_given_z.shape == (n_zC, n_yC)
    
    
    
    # calculate number for training and testing sets given probabilities
    n_train = n_test * train_test_ratio  # train:test ratio
    
    n_train_y_given_z = (p_train_y_given_z * (p_z * n_train).round(0).repeat(n_yC).reshape(n_zC, n_yC)).round(0)
    n_test_y_given_z  = (p_test_y_given_z  * (p_z * n_test ).round(0).repeat(n_yC).reshape(n_zC, n_yC)).round(0)
    
    TESTER_n_test = (n_test - n_error) <= n_test_y_given_z.sum() <= (n_test + n_error)
    if not TESTER_n_test:
        return None
    
    
    # calculate how many examples the original data have, for each z,y combination
    # and create sub-df's for each combination
    full_df_list = []
    full_shape_ls = []
    for iz, iy in itertools.product(zC, yC):
        _df = df[(df['zCategorical'] == iz) & (df['yCategorical'] == iy)]

        full_shape_ls.append(len(_df))
        full_df_list.append(_df)

    TESTER_full_shape = np.all((n_test_y_given_z + n_train_y_given_z) <= np.array(full_shape_ls).reshape(n_zC, n_yC))
    if not TESTER_full_shape:
        return None
    
    
    # iterate through every sub-df, get train-test split
    df_collect = []
    for idx, _df in enumerate(full_df_list):
        n_needed_train = int(n_train_y_given_z.flatten()[idx])
        n_needed_test =  int(n_test_y_given_z.flatten()[idx])

        if (n_needed_test == 0) or (n_needed_train == 0):
            _df_train = _df_test = None
        else:

            _df_train, _df_test = train_test_split(_df,
                                                   train_size=n_needed_train,
                                                   test_size=n_needed_test,
                                                   shuffle=True, random_state=seed
                                                  )

        _ret = {"df_train":_df_train, 
                "df_test":_df_test, 
                "y":_df["yCategorical"].unique().tolist(), 
                "z":_df["zCategorical"].unique().tolist()
               }

        df_collect.append(_ret)

    return df_collect
    

In [18]:
df_ret = confoundSplitDFMultiLevel(df=df, zC=zC, yC=yC, p_train_y_given_z=p_train_y_given_z, p_test_y_given_z=p_test_y_given_z, p_z=p_z, 
                              n_test=150, n_error=0,
                              train_test_ratio=4, seed=193)

In [71]:
len(df_collect)

12

In [19]:
len(df_ret)

12

In [None]:
df_collect[2]

{'df_train':       z  y zCategorical yCategorical
 4756  0  2       Hosp_A    malignant
 1849  0  2       Hosp_A    malignant
 8084  0  2       Hosp_A    malignant
 850   0  2       Hosp_A    malignant
 5845  0  2       Hosp_A    malignant
 ...  .. ..          ...          ...
 1046  0  2       Hosp_A    malignant
 9117  0  2       Hosp_A    malignant
 7739  0  2       Hosp_A    malignant
 6447  0  2       Hosp_A    malignant
 1504  0  2       Hosp_A    malignant
 
 [96 rows x 4 columns],
 'df_test':       z  y zCategorical yCategorical
 39    0  2       Hosp_A    malignant
 7569  0  2       Hosp_A    malignant
 7678  0  2       Hosp_A    malignant
 1004  0  2       Hosp_A    malignant
 3551  0  2       Hosp_A    malignant
 7560  0  2       Hosp_A    malignant
 9734  0  2       Hosp_A    malignant
 6378  0  2       Hosp_A    malignant
 6699  0  2       Hosp_A    malignant,
 'y': ['malignant'],
 'z': ['Hosp_A']}

In [77]:
df_ret[2]

{'df_train':       z  y zCategorical yCategorical
 4756  0  2       Hosp_A    malignant
 1849  0  2       Hosp_A    malignant
 8084  0  2       Hosp_A    malignant
 850   0  2       Hosp_A    malignant
 5845  0  2       Hosp_A    malignant
 ...  .. ..          ...          ...
 1046  0  2       Hosp_A    malignant
 9117  0  2       Hosp_A    malignant
 7739  0  2       Hosp_A    malignant
 6447  0  2       Hosp_A    malignant
 1504  0  2       Hosp_A    malignant
 
 [96 rows x 4 columns],
 'df_test':       z  y zCategorical yCategorical
 39    0  2       Hosp_A    malignant
 7569  0  2       Hosp_A    malignant
 7678  0  2       Hosp_A    malignant
 1004  0  2       Hosp_A    malignant
 3551  0  2       Hosp_A    malignant
 7560  0  2       Hosp_A    malignant
 9734  0  2       Hosp_A    malignant
 6378  0  2       Hosp_A    malignant
 6699  0  2       Hosp_A    malignant,
 'y': ['malignant'],
 'z': ['Hosp_A']}

In [82]:
i['df_train']

In [86]:
df_ret[0]['df_train']

In [88]:
for i in df_ret:
    if i['df_train'] is not None:
        print(len(i['df_train']), len(i['df_test']))

12 21
96 9
24 3
96 15
126 31
54 14
54 9
54 18
72 18


In [89]:
for i in df_collect:
    if i['df_train'] is not None:
        print(len(i['df_train']), len(i['df_test']))

12 21
96 9
24 3
96 15
126 31
54 14
54 9
54 18
72 18
