In [None]:
import itertools
import random
import numpy as np
import pickle
import plot_utils as plu

##  Compute 3-point correlation 
Note: this step can take a long time to run because the code is neither optimized nor parallelized.

In [None]:
# Progress message shown before the 3-point correlation computation starts.
print("* Computing 3-point correlation for different gap values...")

In [None]:
def get_counts(haplosubset, points): 
    counts = np.unique(
        np.apply_along_axis(
            lambda x: ''.join(map(str, x[points])),
            #lambda x: ''.join([str(x[p]) for p in points]),
            0, haplosubset),
        return_counts=True)
    return(counts)

def get_frequencies(counts):
    """Turn label counts into a table of relative frequencies.

    Parameters
    ----------
    counts : sequence
        ``counts[0]`` holds the labels (strings of digits, one digit per
        position) and ``counts[1]`` the matching occurrence counts, i.e. the
        pair returned by ``get_counts``.

    Returns
    -------
    numpy.ndarray
        Array with one axis of size 2 per position, indexed by the digits of
        a label. Each entry is the label's count divided by the total count;
        combinations that were never observed stay at 0.
    """
    labels = counts[0]
    occurrences = counts[1]
    n_positions = len(labels[0])  # all labels share the length of the first one
    total = np.sum(occurrences)
    freqs = np.zeros(shape=[2] * n_positions)
    for label, occ in zip(labels, occurrences):
        index = tuple(int(digit) for digit in label)
        freqs[index] = occ / total
    return freqs

def three_points_cor(haplosubset, out='all'):
    """Compute the 3-point correlation of a SNP triplet.

    For every allele combination ``(a, b, c)`` in ``{0, 1}**3`` the value is

        f012[a,b,c] - f01[a,b]*f2[c] - f12[b,c]*f0[a] - f02[a,c]*f1[b]
        + 2*f0[a]*f1[b]*f2[c]

    where each ``f`` is the frequency table of the corresponding subset of
    positions, obtained with ``get_frequencies(get_counts(...))``.

    Parameters
    ----------
    haplosubset : 2-D array
        Positions 0, 1 and 2 along axis 0 are the three SNPs; every column is
        counted as one observation. Values must be 0 or 1, since the
        frequency tables have size 2 along each axis.
    out : {'all', 'mean', 'max'}, default 'all'
        ``'all'``  -> list of the 8 correlations, in ``itertools.product``
        order of ``(a, b, c)``;
        ``'mean'`` -> their mean;
        ``'max'``  -> the largest absolute value.

    Returns
    -------
    list of float, or float
        Depending on ``out`` (see above).

    Raises
    ------
    ValueError
        If ``out`` is not one of the supported values.
    """
    # Validate first, and *raise*: the exception object used to be returned,
    # so an invalid `out` went unnoticed by callers. Checking before the
    # computation also avoids wasting the expensive counting step.
    if out not in ('mean', 'max', 'all'):
        raise ValueError(f"out={out} not recognized")

    # Frequency tables for every non-empty subset of the three positions,
    # keyed by the concatenated position indices ('0', '01', '012', ...).
    F = dict()
    for points in [[0], [1], [2], [0, 1], [0, 2], [1, 2], [0, 1, 2]]:
        strpoints = ''.join(map(str, points))
        F[strpoints] = get_frequencies(get_counts(haplosubset, points))

    cors = [
        F['012'][a, b, c]
        - F['01'][a, b] * F['2'][c]
        - F['12'][b, c] * F['0'][a]
        - F['02'][a, c] * F['1'][b]
        + 2 * F['0'][a] * F['1'][b] * F['2'][c]
        for a, b, c in itertools.product([0, 1], repeat=3)
    ]

    if out == 'mean':
        return np.mean(cors)
    if out == 'max':
        return np.max(np.abs(cors))
    return cors

#def mult_three_point_cor(haplo, sampleinfo, cat, picked_three_points):
#    return [three_points_cor(haplo[np.ix_(snps,sampleinfo.label==cat)], out='all') for snps in picked_three_points]

In [None]:
# Fix both random generators (Python's `random` and NumPy's legacy global
# generator) so that random draws are identical from one run to the next.
# This keeps scores consistent when a new model or a new summary statistic
# is added later on.
random.seed(3)
np.random.seed(3)

In [None]:
# Compute 3-point correlation results for the different datasets and for
# different distances between SNPs.
#
# - The distance ('gap') is the number of SNPs lying between two consecutive
#   SNPs of a triplet.
# - A negative gap (-9) means the three SNPs are picked completely at random
#   (no predefined distance).
# - For each gap, 'nsamplesets' triplets are drawn once and then reused for
#   every category.
# - Triplets are drawn among the first 'min_nsnp' SNPs (size of the smallest
#   dataset), so if datasets have different numbers of SNPs the beginning of
#   the chunk is sampled slightly more.
# - Results are stored as cors_meta[gap][category] -> list (one entry per
#   triplet) of the 8 correlations returned by three_points_cor.

gap_vec = [1,4,16,64,256,512,1024,-9]
nsamplesets = 1000
min_nsnp = min([dat.shape[1] for dat in datasets.values()])
cors_meta = dict()
for gap in gap_vec:
    print(f'\n gap={gap} SNPs', end=' ')
    if gap < 0:
        # pick 3 random SNPs
        picked_three_points = [random.sample(range(min_nsnp), 3) for _ in range(nsamplesets)]
    else:
        # pick 3 SNPs, each separated from the next one by 'gap' SNPs
        step = gap + 1
        n_starts = min_nsnp - 2*step  # number of admissible positions for the first SNP
        if n_starts < 1:
            # Not enough SNPs for this gap: skip it explicitly. A bare
            # `except:` used to do this, but it also hid any unrelated error.
            print('(skipped: not enough SNPs)', end=' ')
            continue
        picked_three_points = [np.asarray(random.sample(range(n_starts), 1)) + [0, step, 2*step]
                               for _ in range(nsamplesets)]

    cors = dict()
    for cat in infiles.keys():
        print(cat, end=' ')
        # select the 3 SNP columns and transpose so that SNPs lie along axis 0
        cors[cat] = [three_points_cor(datasets[cat][:, snps].T, out='all') for snps in picked_three_points]

    cors_meta[gap] = cors

In [None]:
# print(cors_meta)

In [None]:
# Serialize the nested results (gap -> category -> correlations) to disk.
with open(outDir+"3pointcorr.pkl", mode="wb") as outfile:
    pickle.dump(cors_meta, outfile)

In [None]:
# Plot 3-point correlation results: one panel per gap, arranged on 2 rows.
# In each panel, the correlations of every non-'Real' category are plotted
# against those of the 'Real' category with plu.plotreg.

# plt.subplot expects integer grid dimensions; np.ceil returns a float, which
# recent matplotlib versions reject, hence the explicit int().
ncols = int(np.ceil(len(cors_meta)/2))

plt.figure(figsize=(2*len(cors_meta),7))
for i,gap in enumerate(cors_meta.keys()):
    ax = plt.subplot(2, ncols, i+1)
    cors = cors_meta[gap]
    # flatten (triplets x 8 correlations) into a single list of points
    real = list(np.array(cors['Real']).flat)
    for key, val in cors.items():
        if key=='Real': continue
        val = list(np.array(val).flat)
        plu.plotreg(x=real, y=val, keys=['Real',key],
                    statname='Correlation', col=colpal[key], ax=ax)
    # title/legend are set on `ax` explicitly rather than on the "current" axes
    if gap<0:
        ax.set_title('3-point corr for random SNPs')
    else:
        ax.set_title(f'3-point corr for SNPs sep. by {gap} SNPs')

    ax.legend(fontsize='small')
plt.tight_layout()
plt.savefig(outDir+'3point_correlations.jpg',dpi=300) # can pick one of the format
plt.savefig(outDir+'3point_correlations.png',dpi=300)
plt.savefig(outDir+'3point_correlations.pdf',dpi=300)

In [None]:
# Same plot with axes limits fixed to (-0.1, 0.1) in every panel, for the sake
# of comparison across gaps.

# plt.subplot expects integer grid dimensions; np.ceil returns a float, which
# recent matplotlib versions reject, hence the explicit int().
ncols = int(np.ceil(len(cors_meta)/2))

plt.figure(figsize=(4*len(cors_meta),14))
for i,gap in enumerate(cors_meta.keys()):
    ax = plt.subplot(2, ncols, i+1)
    cors = cors_meta[gap]
    # flatten (triplets x 8 correlations) into a single list of points
    real = list(np.array(cors['Real']).flat)
    for key, val in cors.items():
        if key=='Real': continue
        val = list(np.array(val).flat)
        plu.plotreg(x=real, y=val, keys=['Real',key],
                    statname='Correlation', col=colpal[key], ax=ax)
        # limits are re-applied after every plotreg call, as in the original
        # (presumably because plotreg may rescale the axes -- TODO confirm)
        ax.set_xlim((-.1,.1))
        ax.set_ylim((-.1,.1))

    # title/legend are set on `ax` explicitly rather than on the "current" axes
    if gap<0:
        ax.set_title('3-point corr for random SNPs')
    else:
        ax.set_title(f'3-point corr for SNPs sep. by {gap} SNPs')

    ax.legend(fontsize='small')
plt.tight_layout()

plt.savefig(outDir+'3point_correlations_fixlim.pdf',dpi=300)
plt.savefig(outDir+'3point_correlations_fixlim.png',dpi=300)
plt.savefig(outDir+'3point_correlations_fixlim.jpg',dpi=300)

In [None]:
# Closing banner: tells the user the step is finished and where figures are.
print(
    (
        '************************************************************************\n'
        '*** Computation and plotting 3-point cor DONE. Figures saved in {} ***\n'
        '************************************************************************'
    ).format(outDir)
)