In [None]:
# default_exp doublediff_analysis

In [None]:
import alphaquant.background_distributions as aqbg
import numpy as np


def _shared_nonnan_idx(vals_a, vals_b):
    """Return the positions at which neither ``vals_a`` nor ``vals_b`` is NaN.

    Iteration stops at the end of the shorter of the two sequences.
    """
    return [
        idx
        for idx, (val_a, val_b) in enumerate(zip(vals_a, vals_b))
        if not (np.isnan(val_a) or np.isnan(val_b))
    ]


def get_z_and_firstterm_variance(ions1, ions2, normed_c1, normed_c2, ion2diffDist, p2z, ionpair2doublediffdist):
    """Collect all comparable ion pairs between two ion groups and sum their first-term variances.

    Every ion of ``ions1`` is paired with every ion of ``ions2``. For each pair, the
    replicate indices at which both ions have a non-NaN intensity are determined
    separately for condition 1 (``normed_c1``) and condition 2 (``normed_c2``). Pairs
    without any such shared replicate in either condition are skipped. For the kept
    pairs, the ion-pair variance (as in Berchtold et al., EmpiReS) is computed from the
    per-ion background variances (``ion2background[ion].var``) and the differential
    background variances (``aqbg.get_subtracted_bg(...).var``), and summed up.

    Despite its name, this function does not compute z-values itself.

    Args:
        ions1, ions2: iterables of ion identifiers that are paired against each other.
        normed_c1, normed_c2: per-condition objects exposing ``ion2allvals``
            (ion -> replicate intensities, NaN marking a missing value) and
            ``ion2background`` (ion -> object with a ``var`` attribute).
        ion2diffDist, p2z: forwarded unchanged to ``aqbg.get_subtracted_bg``.
        ionpair2doublediffdist: currently unused.

    Returns:
        tuple ``(firstterm_variance, all_ionpairs, ion2pairs, ionpair2idx_ols)``:
        the summed variance (0 if no pair has overlapping replicates); the list of kept
        ``(ion1, ion2)`` tuples; a dict mapping each ion to the kept pairs it occurs in;
        and a dict mapping each kept pair to ``[overlapping_c1_idx, overlapping_c2_idx]``.
    """
    firstterm_variance = 0
    all_ionpairs = []
    ion2pairs = {}
    ionpair2idx_ols = {}

    for ion1 in ions1:
        ion1_c1_ints = normed_c1.ion2allvals.get(ion1)
        ion1_c2_ints = normed_c2.ion2allvals.get(ion1)
        for ion2 in ions2:
            ion2_c1_ints = normed_c1.ion2allvals.get(ion2)
            ion2_c2_ints = normed_c2.ion2allvals.get(ion2)

            # account for missing values: ion1 and ion2 values are only compared within the
            # same sample -> keep only replicate indices where both ions have an intensity
            overlapping_c1_idx = _shared_nonnan_idx(ion1_c1_ints, ion2_c1_ints)
            if not overlapping_c1_idx:
                continue
            overlapping_c2_idx = _shared_nonnan_idx(ion1_c2_ints, ion2_c2_ints)
            if not overlapping_c2_idx:
                continue
            nrep_ol_c1 = len(overlapping_c1_idx)
            nrep_ol_c2 = len(overlapping_c2_idx)

            # collect information for the later variance calculation; the pair is a tuple
            # because it is used as a dict key
            ionpair = (ion1, ion2)
            all_ionpairs.append(ionpair)
            ion2pairs.setdefault(ion1, []).append(ionpair)
            ion2pairs.setdefault(ion2, []).append(ionpair)
            ionpair2idx_ols[ionpair] = [overlapping_c1_idx, overlapping_c2_idx]

            # variances of the empirical error distributions (eed) per ion and condition,
            # and of the differential empirical error distributions (deed) per ion
            eed_ion1_c1 = normed_c1.ion2background.get(ion1).var
            eed_ion1_c2 = normed_c2.ion2background.get(ion1).var

            eed_ion2_c1 = normed_c1.ion2background.get(ion2).var
            eed_ion2_c2 = normed_c2.ion2background.get(ion2).var

            deed_ion1 = aqbg.get_subtracted_bg(ion2diffDist, normed_c1, normed_c2, ion1, p2z).var
            deed_ion2 = aqbg.get_subtracted_bg(ion2diffDist, normed_c1, normed_c2, ion2, p2z).var

            # ion-pair total variance as in Berchtold et al. (EmpiReS). Reading of the
            # formula: n_c1*n_c2 double differences, each with variance deed_ion1+deed_ion2,
            # plus covariances between double differences sharing a condition-1 (resp.
            # condition-2) replicate; dividing by deed_ion1+deed_ion2 rescales to z units.
            ddeed_var = deed_ion1 + deed_ion2
            n_fcs = nrep_ol_c1 * nrep_ol_c2
            ionpair_variance = (
                n_fcs * ddeed_var
                + n_fcs * (nrep_ol_c2 - 1) * (eed_ion1_c1 + eed_ion2_c1)
                + n_fcs * (nrep_ol_c1 - 1) * (eed_ion1_c2 + eed_ion2_c2)
            ) / ddeed_var

            firstterm_variance += ionpair_variance

    return firstterm_variance, all_ionpairs, ion2pairs, ionpair2idx_ols
    

def calc_per_peppair_z(overlapping_c1_idx, overlapping_c2_idx, ion1_c1_ints, ion1_c2_ints, ion2_c1_ints, ion2_c2_ints, ddeed_ion1_ion2):
    """Return the z-scores of all double differences of an ion pair.

    For every combination of a condition-1 replicate index ``idx1`` (from
    ``overlapping_c1_idx``) and a condition-2 replicate index ``idx2`` (from
    ``overlapping_c2_idx``), the fold change of each ion is formed by subtraction
    (presumably on log-scale intensities; confirm with caller). The difference of the
    two fold changes is converted to a z-score via
    ``ddeed_ion1_ion2.calc_zscore_from_fc``.

    Returns:
        list of z-scores ordered by ``idx1`` first, then ``idx2``; empty if either
        index list is empty.
    """
    zvals = []
    for idx1 in overlapping_c1_idx:
        for idx2 in overlapping_c2_idx:
            fc_ion1 = ion1_c1_ints[idx1] - ion1_c2_ints[idx2]
            fc_ion2 = ion2_c1_ints[idx1] - ion2_c2_ints[idx2]
            fcfc_ion12 = fc_ion1 - fc_ion2
            zvals.append(ddeed_ion1_ion2.calc_zscore_from_fc(fcfc_ion12))
    return zvals




def calculate_pairpair_overlap_factor(all_ionpairs, ion2pairs, ionpair2idx_ols, normed_c1, normed_c2, ion2diffdist, p2z):
    """Sum the overlap terms between ion pairs that share an ion.

    Ion pairs are visited in the order of ``all_ionpairs``. For each ion of a pair,
    the pair is compared with every other not-yet-processed pair containing that ion,
    so two pairs sharing an ion are compared once. For such a comparison, the shared
    replicate indices per condition (from ``ionpair2idx_ols``) are counted and
    weighted with the shared ion's background variance in that condition. The result
    is divided by the product of two differential background SDs.

    Args:
        all_ionpairs: ion pairs (hashable, e.g. tuples) to process.
        ion2pairs: ion -> list of pairs containing that ion. Not modified.
        ionpair2idx_ols: pair -> ``[overlapping_c1_idx, overlapping_c2_idx]``.
        normed_c1, normed_c2: per-condition objects exposing ``ion2background``
            (ion -> object with a ``var`` attribute).
        ion2diffdist, p2z: forwarded unchanged to ``aqbg.get_subtracted_bg``.

    Returns:
        The accumulated second-term variance (0 if no two pairs share an ion).
    """
    secondterm_variance = 0

    # Remove each pair from its ions' lists once processed (so every two pairs sharing an
    # ion are compared only once), but on private copies so the caller's ion2pairs stays intact.
    remaining_pairs = {ion: list(pairs) for ion, pairs in ion2pairs.items()}

    for ionpair in all_ionpairs:
        for ion in ionpair:
            compare_pairs = remaining_pairs[ion]
            compare_pairs.remove(ionpair)
            if not compare_pairs:
                continue

            # quantities that depend only on (ionpair, ion), not on the compared pair
            idxs_ionpair = ionpair2idx_ols.get(ionpair)
            deed1 = aqbg.get_subtracted_bg(ion2diffdist, normed_c1, normed_c2, ion, p2z)
            eed_ion_c1 = normed_c1.ion2background.get(ion)
            eed_ion_c2 = normed_c2.ion2background.get(ion)

            for comp_ionpair in compare_pairs:
                # NOTE(review): this equals the shared ion itself whenever the shared ion is
                # the first element of comp_ionpair, but the other ion when it is the second;
                # confirm which ion the normalisation is meant to use.
                comp_ion = comp_ionpair[0]
                idxs_comp_ionpair = ionpair2idx_ols.get(comp_ionpair)

                # number of replicate indices both pairs use, per condition
                n_sameidx_first = len(set(idxs_ionpair[0]).intersection(idxs_comp_ionpair[0]))
                n_sameidx_second = len(set(idxs_ionpair[1]).intersection(idxs_comp_ionpair[1]))

                deed2 = aqbg.get_subtracted_bg(ion2diffdist, normed_c1, normed_c2, comp_ion, p2z)
                correlation_normfact = deed1.SD * deed2.SD

                var_overlap = (
                    len(idxs_ionpair[1]) * len(idxs_comp_ionpair[1]) * n_sameidx_first * eed_ion_c1.var
                    + len(idxs_ionpair[0]) * len(idxs_comp_ionpair[0]) * n_sameidx_second * eed_ion_c2.var
                )
                secondterm_variance += var_overlap / correlation_normfact

    return secondterm_variance