In [40]:
import glob
import pandas as pd
import itertools
from scipy import stats

### Get a list of TFs, controls, and case sample names

In [41]:
product_files = sorted(glob.glob('**/*product.parquet', recursive=True))
combo_files = sorted(glob.glob('**/*combinations.parquet', recursive=True))

df = pd.read_parquet(product_files[0])
ctrls = df.query('sample_1.str.contains("HV")', engine='python')
ctrls = ctrls['sample_1'].unique().tolist()
tfs = df['TF'].unique().tolist()


all_case_samples = []
for i in range(len(product_files)):
    df = pd.read_parquet(product_files[i])
    samples = df['sample_1'].unique().tolist()
    case_samples = [x for x in samples if x not in ctrls]
    all_case_samples.extend(case_samples)

### KS functions

In [42]:
def get_distance_lists(grouped_cases, ctrls):
    all_group_distances = []
    for group in grouped_cases:
        group_distances = []
        for sample in group:
            if sample == "OMD1_PreRT":
                case = 'OMD001'
            elif sample in HN:
                case = sample[:5]
            elif sample in OMD:
                case = sample[:6]
            else:
                case = sample

            sample_distances = []
            product_df = pd.read_parquet(f'{case}/{case}_product.parquet')
            for tf in tfs:
                distances = product_df.query(f"TF  == '{tf}' & sample_1 == '{sample}' & sample_2 in {ctrls} ")['distance'].to_numpy()
                sample_distances.extend(distances)
            group_distances.extend(sample_distances)
        all_group_distances.append(group_distances)

        
    all_control_distnaces = []
    for ctrl in ctrls:
        ctrl_distances = []
        for tf in tfs:
            other_ctrls = [x for x in ctrls if x != ctrl]
            combinations_df = pd.read_parquet(f'{case}/{case}_combinations.parquet') # which case not important - all have same ctrl-ctrl distances
            distances = product_df.query(f"TF  == '{tf}' & sample_1 == '{ctrl}' & sample_2 in {other_ctrls} ")['distance'].to_numpy()
            ctrl_distances.extend(distances)
        all_control_distnaces.extend(ctrl_distances)

    all_group_distances.append(all_control_distnaces)
        
    return all_group_distances
        
        

In [43]:
def get_ks_df(group_names, grouped_cases, ctrls):
    
    all_group_distances = get_distance_lists(grouped_cases, ctrls)
    
    name_combos = list(itertools.combinations(group_names, 2))
    distance_combos = list(itertools.combinations(all_group_distances, 2))
    
    compared_groups = []
    group_pvalues = []
    group_Dstats = []
    for i in range(len(name_combos)):
        paired_groups = list(name_combos[i])
        paired_arrays = distance_combos[i]
        g1_array = paired_arrays[0]
        g2_array = paired_arrays[1]
        less = [['HN', 'CTRL'], ['OMD', 'CTRL'], ['PV', 'CTRL']]
        if paired_groups in less:
            D, p =stats.ks_2samp(g1_array, g2_array, alternative='less')
        else:
            D, p =stats.ks_2samp(g1_array, g2_array, alternative='greater')

        correction =  len(name_combos)
        adj_p = p * correction
        group_pvalues.append(p)
        group_Dstats.append(D)
        compared_groups.append(paired_groups)

    group_ks_lists = []
    for i in range(len(compared_groups)):
        group_ks_list = [compared_groups[i][0], compared_groups[i][1], group_Dstats[i], group_pvalues[i]]
        group_ks_lists.append(group_ks_list)

    group_ks_df = pd.DataFrame(group_ks_lists, columns = ['Group 1', 'Group 2', 'D Statistic', 'P-value'])
    
    return group_ks_df

### Group results

#### Baseline

In [44]:
HN = sorted([sample for sample in all_case_samples if sample.endswith('BL') and sample.startswith('HN')] \
         +  [sample for sample in all_case_samples if sample.endswith('Pre') and sample.startswith('HN')] )
OMD = sorted([sample for sample in all_case_samples if sample.endswith('BL') and sample.startswith('OMD')] \
         +  ['OMD1_PreRT'] )
PV = sorted([sample for sample in all_case_samples if sample.startswith('PV')])
grouped_cases = [HN, OMD, PV]
group_names = ["HN", "OMD", "PV", "CTRL"]

group_ks_df = get_ks_df(group_names, grouped_cases, ctrls)
group_ks_df

Unnamed: 0,Group 1,Group 2,D Statistic,P-value
0,HN,OMD,0.154519,0.0002718289
1,HN,PV,0.29191,1.522483e-14
2,HN,CTRL,0.293489,8.014207e-13
3,OMD,PV,0.204446,1.790156e-07
4,OMD,CTRL,0.426628,1.127487e-26
5,PV,CTRL,0.556973,8.838014e-49


#### Treatment 1

In [45]:
HN = sorted([sample for sample in all_case_samples if sample.endswith('d8')]) 
OMD = sorted([sample for sample in all_case_samples if sample.endswith('d2')])
grouped_cases = [HN, OMD]
group_names = ["HN", "OMD", "CTRL"]

group_ks_df = get_ks_df(group_names, grouped_cases, ctrls)
group_ks_df

Unnamed: 0,Group 1,Group 2,D Statistic,P-value
0,HN,OMD,0.28863,2.69696e-13
1,HN,CTRL,0.490768,1.973806e-35
2,OMD,CTRL,0.439261,2.7523540000000002e-28


#### Treatment 2

In [46]:
HN = sorted([sample for sample in all_case_samples if sample.endswith('d49')] \
            + [sample for sample in all_case_samples if sample.endswith('d43')] \
            + [sample for sample in all_case_samples if sample.endswith('D43')])
OMD = sorted([sample for sample in all_case_samples if sample.endswith('d5')] \
            + [sample for sample in all_case_samples if sample.endswith('d10')])
grouped_cases = [HN, OMD]
group_names = ["HN", "OMD", "CTRL"]

group_ks_df = get_ks_df(group_names, grouped_cases, ctrls)
group_ks_df

Unnamed: 0,Group 1,Group 2,D Statistic,P-value
0,HN,OMD,0.212828,1.624726e-07
1,HN,CTRL,0.590865,3.945306e-52
2,OMD,CTRL,0.690962,4.291952e-73


#### Follow-up 1

In [47]:
HN = sorted([sample for sample in all_case_samples if sample.endswith('3m') and sample.startswith('HN')])
OMD = sorted([sample for sample in all_case_samples if sample.endswith('3m') and sample.startswith('OMD')])
grouped_cases = [HN, OMD]
group_names = ["HN", "OMD", "CTRL"]

group_ks_df = get_ks_df(group_names, grouped_cases, ctrls)
group_ks_df

Unnamed: 0,Group 1,Group 2,D Statistic,P-value
0,HN,OMD,0.005831,0.9884226
1,HN,CTRL,0.422741,3.446825e-26
2,OMD,CTRL,0.344995,1.6101610000000002e-17
