In [1]:
import itertools
import pickle
import warnings
from pathlib import Path

from scipy.stats import binned_statistic
from scipy.stats import sem

from utils.util import *



KeyboardInterrupt



In [None]:
# Datasets to compare. Each entry maps a display name to the haplotype file ("path")
# and the plotting colour ("color"). 'Real' is the reference dataset that later cells
# index explicitly, so it must stay in the dict.
models_to_test = {
    "Real": {
        "path": "resource/10K_SNP_1000G_real.hapt",
        "color": "gray"
    },
    "GAN_prev": {
        "path": "fake_genotypes_sequences/preview_sequences/10K_SNP_GAN_AG_10800Epochs.hapt",
        "color": "blue"
    },
    "RBM_prev": {
        "path": "fake_genotypes_sequences/preview_sequences/10K_RBM.hapt",
        "color": "red"
    },
    "RBM_new": {
        "path": "fake_genotypes_sequences/preview_sequences/10K_SNP_RBM_AG_1050epochs.hapt",
        "color": "orange"
    },
    "WGAN": {
        # NOTE(review): identical to the "RBM_new" path — almost certainly a copy-paste
        # slip. Point this at the actual WGAN output before trusting WGAN results.
        "path": "fake_genotypes_sequences/preview_sequences/10K_SNP_RBM_AG_1050epochs.hapt",
        "color": "purple"
    },
    "DEFAULT GS-AC-GAN": {
        "path": "fake_genotypes_sequences/new_sequences/default_gs-ac-gan/genotypes.hapt",
        "color": "brown"
    },
    "F1 WITH PENALTY GS-AC-GAN": {
        "path": "fake_genotypes_sequences/new_sequences/f1_score_test/genotypes.hapt",
        "color": "green"
    },
    # Key previously ended with a stray ", " which leaked into every plot label.
    "SUB F1 WITH PENALTY GS-AC-GAN": {
        "path": "fake_genotypes_sequences/new_sequences/f1_score_sub_pop/genotypes.hapt",
        # Was "blue", the same colour as GAN_prev, making the two indistinguishable.
        "color": "cyan"
    },
}

In [None]:
# Passed to load_analysis_data and also used as the number of LD pairs subsampled for
# the regression plots.
NUMBER_OF_SAMPLES = 500
# NOTE(review): not read by any cell visible here.
compute_AATS = True

# The output directory can be overridden via the `output_dir` environment variable.
output_dir = os.environ.get("output_dir", DEFAULT_EXPERIMENT_OUTPUT_DIR)
Path(output_dir).mkdir(exist_ok=True, parents=True)

In [None]:
# Derive per-model lookups and the colour palette from the model configuration
# (exact contents defined by utils.init_analysis_args).
(model_name_to_input_file,
 model_name_to_color,
 color_palette) = init_analysis_args(output_dir, models_to_test)

In [None]:
# Load the real and generated datasets, then per-site allele sums / frequencies and
# the mask of fixed sites.
(extra_sample_info, sample_info, datasets,
 transformations, model_keep_all_snps) = load_analysis_data(
    number_of_samples=NUMBER_OF_SAMPLES,
    model_name_to_input_file=model_name_to_input_file,
)
sum_alleles_by_position, allele_frequency, is_fixed = build_allele_frequency(datasets)

In [None]:
# For each generated model, look at the Real frequency of alleles that the model never
# produces (allele sum of 0 at that site): per-site values and their histogram.
figwi = 12  # base figure width (inches), reused by later plotting cells
generated_models = [name for name in model_name_to_input_file.keys() if name != 'Real']
l, c = len(generated_models), 2
# A single figure: the previous extra plt.figure(...) call here leaked an empty figure.
plt.figure(figsize=(figwi * c / 4, figwi * l / 4))
for row, model_name in enumerate(generated_models):
    absent_in_model = sum_alleles_by_position[model_name] == 0
    real_af_absent = allele_frequency['Real'][absent_in_model]

    plt.subplot(l, c, 2 * row + 1)
    plt.plot(real_af_absent, alpha=1, marker='.', lw=0)
    plt.ylabel("Allele frequency in Real")
    plt.title("Real frequency of alleles \n absent from {}".format(model_name))

    plt.subplot(l, c, 2 * row + 2)
    plt.hist(real_af_absent, alpha=1)
    plt.xlabel("Allele frequency in Real")
    plt.title("Hist real freq of alleles \n absent from {}".format(model_name))

plt.suptitle("Plotting allele frequency characteristics \n\n")
plt.tight_layout()
plt.savefig(os.path.join(output_dir, "RealAC_for_0fixed_sites.pdf"))

In [None]:
# Allele frequencies, generated vs Real, restricted to sites whose Real frequency is at
# most `maf` (0.2, i.e. 20%) to zoom in on rare alleles.
maf = 0.2
keep = (allele_frequency['Real'] <= maf)
c = 3
l = np.ceil(len(allele_frequency) / c)
plt.figure(figsize=(figwi, figwi * l / c))
for i, (model_name, model_af) in enumerate(allele_frequency.items()):
    plotreg(x=allele_frequency['Real'][keep],
            y=model_af[keep],
            keys=['Real', model_name],
            statname="Allele frequency",
            col=color_palette[model_name],
            ax=plt.subplot(int(l), c, i + 1))
    plt.title(f'{model_name} vs Real')
plt.suptitle(f'AF below {maf} in Real \n\n')
plt.tight_layout()
plt.savefig(os.path.join(output_dir, 'AC_generated_vs_Real_zoom.pdf'))

In [None]:
# Allele frequencies over all sites, each dataset (Real included) against Real.
c = 3
l = np.ceil(len(allele_frequency) / c)
plt.figure(figsize=(figwi, figwi * l / c))
for i, (model_name, model_af) in enumerate(allele_frequency.items()):
    ax = plt.subplot(int(l), c, i + 1)
    plotreg(x=allele_frequency['Real'],
            y=model_af,
            keys=['Real', model_name],
            statname="Allele frequency",
            col=color_palette[model_name],
            ax=ax)
    plt.title(f'Allele Frequencies {model_name} vs Real')
plt.suptitle('Allele Frequencies vs Real \n\n')
plt.tight_layout()
plt.savefig(os.path.join(output_dir, 'AC_generated_vs_Real.pdf'))

In [None]:

score_list = []
number_of_components = 6  # change to compute more PCs
method_name = "Combined PCA"
print(f'Computing {method_name} ...')

# One PCA fitted on all datasets stacked row-wise, so every model lives in the same PC space.
stacked = np.concatenate(list(datasets.values()))
pca = PCA(n_components=number_of_components)
pcs = pca.fit_transform(stacked)
pc_names = [f"PC{k}" for k in range(1, pcs.shape[1] + 1)]
pcdf = pd.DataFrame(pcs, columns=pc_names)
pcdf["label"] = extra_sample_info.label.astype('category')
plotPCAallfigs(pcdf, method_name,
               orderedCat=model_name_to_input_file.keys(),
               output_dir=output_dir,
               colpal=color_palette)
plt.suptitle("PCA Comparison")
plt.show()

In [None]:
print("* Computing and plotting LD...")
#### r^2 (squared Pearson correlation) between all pairs of SNPs, per dataset
model_names = model_name_to_input_file.keys()
hcor_snp = dict()
for model_name in model_names:
    print(model_name)
    with np.errstate(divide='ignore', invalid='ignore'):
        # Fixed sites have zero variance, so every pair involving them gets np.nan
        hcor_snp[model_name] = np.corrcoef(datasets[model_name], rowvar=False) ** 2  # r2

_, region_len, snps_on_same_chrom = get_dist(f"../{REAL_POSITION_FILE_NAME}", region_len_only=True,
                                             kept_preprocessing=model_keep_all_snps['Real'])

nbins = 50
logscale = True
bins = nbins  # bins used when stratifying by LD value in Real
binsPerDist = nbins  # bins used when stratifying by SNP distance
if logscale:
    # Log-spaced distance edges from 1 bp to region_len. The former
    # np.logspace(np.log(1), np.log(region_len), nbins) fed natural-log exponents to a
    # base-10 logspace, putting the last edge at 10**ln(region_len) (far beyond the
    # region) and leaving most bins empty.
    binsPerDist = np.geomspace(1, region_len, nbins)


def _binned_ld_frame(x, y, bin_spec, model_name, logscale_flag, with_sd=False):
    """Bin `y` by `x`; return per-bin mean ('LD'), optional 'sd' and 'sem' as a DataFrame."""
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)  # so that empty bins do not raise a warning
        mean = binned_statistic(x, y, statistic='mean', bins=bin_spec)
        columns = {'bin_edges': mean.bin_edges[:-1], 'LD': mean.statistic}
        if with_sd:
            columns['sd'] = binned_statistic(x, y, statistic='std', bins=bin_spec).statistic
        columns['sem'] = binned_statistic(x, y, statistic=sem, bins=bin_spec).statistic
    columns['model_name'] = model_name
    columns['logscale'] = logscale_flag
    return pd.DataFrame(columns)


# Only sites that are SNPs in all datasets are used (intersection of non-fixed sites);
# this makes sense only if there is a correspondence between sites across datasets.
binnedLD = dict()
binnedPerDistLD = dict()
kept_snp = ~is_fixed
n_kept_snp = np.sum(kept_snp)
realdist = get_dist(f"../{REAL_POSITION_FILE_NAME}", kept_preprocessing=model_keep_all_snps['Real'],
                    kept_snp=kept_snp)[0]
# Flatten the upper triangle (diagonal included) of each r^2 matrix restricted to kept SNPs.
# Assumes realdist is flattened in the same pair order — TODO confirm against get_dist.
upper_pairs = np.triu_indices(n_kept_snp)  # computed once instead of once per dataset
flatreal = hcor_snp['Real'][np.ix_(kept_snp, kept_snp)][upper_pairs]
isnanReal = np.isnan(flatreal)

n_rows = len(hcor_snp)
plt.figure(figsize=(10, n_rows * 5))
for row, (model_name, mat) in enumerate(hcor_snp.items()):
    flathcor = mat[np.ix_(kept_snp, kept_snp)][upper_pairs]

    # LD pairs stratified by SNP distance, cut into `nbins` bins
    valid_model = ~np.isnan(flathcor)
    binnedPerDistLD[model_name] = _binned_ld_frame(realdist[valid_model], flathcor[valid_model],
                                                   binsPerDist, model_name, logscale)

    # LD pairs stratified by LD value in Real, cut into `nbins` bins (mean, std, sem per bin)
    valid_both = valid_model & ~isnanReal
    binnedLD[model_name] = _binned_ld_frame(flatreal[valid_both], flathcor[valid_both],
                                            bins, model_name, logscale, with_sd=True)

    plotregquant(x=flatreal, y=flathcor,
                 keys=['Real', model_name], statname='LD', col=color_palette[model_name],
                 step=0.05,
                 ax=plt.subplot(n_rows, 2, 2 * row + 1))
    plt.title(f'Quantiles LD {model_name} vs Real')

    # Subsample pairs valid in both datasets so the regression has a reasonable number of
    # points; random.sample raises if asked for more items than exist, hence the min().
    valid_idx = np.flatnonzero(valid_both)
    keepforplotreg = random.sample(list(valid_idx), min(NUMBER_OF_SAMPLES, valid_idx.size))
    plotreg(x=flatreal[keepforplotreg], y=flathcor[keepforplotreg],
            keys=['Real', model_name], statname='LD', col=color_palette[model_name],
            ax=plt.subplot(n_rows, 2, 2 * row + 2))
    plt.title(f'LD {model_name} vs Real')
plt.savefig(os.path.join(output_dir, "LD_generated_vs_real_intersectSNP.pdf"))

In [None]:
# Average LD per distance bin (+/- 1 sem); physical distance is only meaningful when all
# SNPs lie on the same chromosome.
if snps_on_same_chrom:
    plt.figure(figsize=(5, 5))
    for model_name, per_dist in binnedPerDistLD.items():
        plt.errorbar(per_dist.bin_edges.values, per_dist.LD.values, per_dist['sem'].values,
                     label=model_name, alpha=.65, linewidth=3)
    if logscale:
        plt.xscale('log')
    plt.title("Binned LD +/- 1 sem")
    plt.xlabel("Distance between SNPs (bp) [Left bound of distance bin]")
    plt.ylabel("Average LD in bin")
    plt.legend()
    plt.savefig(os.path.join(output_dir, "correlation_vs_dist_intersectSNP.pdf"))

In [None]:
# binnedLD holds, for every dataset, the average LD of pairs grouped into `nbins` bins by
# their LD value in Real. Plot each dataset's average against the Real bin (left edge).
plt.figure(figsize=(10, 10))
for model_name, by_real_ld in binnedLD.items():
    plt.errorbar(by_real_ld.bin_edges.values, by_real_ld.LD.values, by_real_ld['sem'].values,
                 label=model_name, alpha=1, marker='o')
plt.title("Binned LD +/- 1 sem")
plt.xlabel("Bins (LD in Real)")
plt.ylabel("Average LD in bin")
plt.legend()
ld_prefix = 'logdist_' if logscale else ''
plt.savefig(os.path.join(output_dir, 'LD_{}bins_{}fixedremoved.pdf'.format(nbins, ld_prefix)))



In [None]:
# # Set edges of the region for which to plot LD block matrix (l=0, f='end') for full region
# # not used for now, apart from the filename
# l_bound = None
# r_bound = None
# snpcode = "fullSNP"
# mirror, diff = False, False
# outfilename = f"LD_HEATMAP_{snpcode}_bounds={l_bound}-{r_bound}_mirror={mirror}_diff={diff}.pdf"
# fig = plt.figure(figsize=(10 * len(hcor_snp), 10))
# plotLDblock(hcor_snp,
#             left=l_bound, right=r_bound,  # None, None -> takes all SNPs
#             mirror=mirror, diff=diff,
#             is_fixed=is_fixed, is_fixed_dic=is_fixed_dic,
#             suptitle_kws={'t': outfilename}
#             )
# plt.title(outfilename)
# plt.savefig(os.path.join(output_dir, outfilename))
# plt.show()
#
# print(
#     '****************************************************************\n*** Computation and plotting LD DONE. Figures saved in {} ***\n****************************************************************'.format(
#         output_dir))

In [None]:
# Manhattan (cityblock) distances between haplotypes within each dataset.
# Left panel: all distinct pairs. Right panel: each haplotype's distance to its nearest
# other haplotype in the same dataset.
dSS_dic = dict()
plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
for cat, mat in datasets.items():
    dAB = distance.cdist(mat, mat, 'cityblock')
    # Mask self-distances before the nearest-neighbour min. Use np.inf: the np.Inf alias
    # was removed in NumPy 2.0.
    np.fill_diagonal(dAB, np.inf)
    dSS_dic[cat] = dAB.min(axis=1)
    sns.kdeplot(dAB[np.triu_indices(dAB.shape[0], k=1)], linewidth=3, label=cat)  # dSS
plt.title("Pairwise distance within each dataset")
plt.legend()

plt.subplot(1, 2, 2)
for cat, min_dist in dSS_dic.items():
    sns.kdeplot(min_dist, linewidth=3, label=cat)
plt.title("Minimal pairwise distance within each dataset")
plt.legend()

plt.savefig(os.path.join(output_dir, "haplo_pairw_distrib_within.pdf"))

In [None]:
# Stack all datasets and transpose to the scikit-allel orientation (one column per
# haplotype); computeAAandDist receives the un-transposed view again as a DataFrame.
haplo = np.concatenate(list(datasets.values())).T

# Nearest-neighbour adversarial accuracy (AATS) and minimal distances, Real as reference.
ref = 'Real'
outFilePrefix = ''
print(f"Computing AATS with ref {ref}")
AA, MINDIST = computeAAandDist(
    pd.DataFrame(haplo.T),
    extra_sample_info.label,
    model_name_to_input_file.keys(),
    refCateg=ref,
    saveAllDist=True,
    output_dir=output_dir,
    outFilePrefix=outFilePrefix,
)

# Save AA and MINDIST as compressed CSV; with saveAllDist=True, computeAAandDist has
# already saved the full pairwise-distance arrays as .npz files.
for frame, stem in ((AA, 'AA'), (MINDIST, 'MINDIST')):
    frame.to_csv(os.path.join(output_dir, f'{stem}_{ref}.csv.bz2'), index=None)

In [None]:
#### Distribution of haplotypic pairwise distances WITHIN each dataset
# Each dataset's distribution is compared with Real's using the Wasserstein distance.
# The arrays are the dist_<model>_<model>.npz files written by
# computeAAandDist(saveAllDist=True).
def _load_within_dist(name):
    """Return the saved within-dataset pairwise-distance array for model `name`."""
    # Context manager closes the NpzFile handle that np.load keeps open for .npz files.
    with np.load('{}/dist_{}_{}.npz'.format(output_dir, name, name)) as npz:
        return npz['dist']


model_names = model_name_to_input_file.keys()
# Load Real up front so the comparison no longer depends on 'Real' being the first key.
subsetreal = _load_within_dist('Real')
within_records = []

plt.figure(figsize=(12, 12))
plt.subplot(1, 2, 1)
for model_name in model_names:
    subset = subsetreal if model_name == 'Real' else _load_within_dist(model_name)
    sns.kdeplot(subset, linewidth=3, label='{} ({} identical pairs)'.format(model_name, (subset == 0).sum()))
    within_records.append({'stat': 'wasserstein',
                           'statistic': scs.wasserstein_distance(subsetreal, subset),
                           'label': model_name,
                           'comparaison': 'within'})
# Build the frame once instead of concatenating row by row
W = pd.DataFrame(within_records, columns=['stat', 'statistic', 'label', 'comparaison'])

plt.title("Distribution of haplotypic pairwise difference within each dataset")
plt.legend()

In [None]:
# Bare last expression: Jupyter renders the rich (HTML) table instead of print()'s plain text.
MINDIST

In [None]:
# DISTmelt = MINDIST.melt(id_vars='cat').rename(columns=str.title)
# g = sns.FacetGrid(DISTmelt, hue="Cat", height=7, col='Variable', hue_order=model_name_to_input_file.keys())
# # cut=0 : negative values have no meaning for distances; however, be aware that this might accidentally hide real peaks at zero (due to copying, for example)
# # check whether the full distribution is similar or not (next cell)
# g.map(sns.kdeplot, "Value")
# g.add_legend()
# plt.savefig(os.path.join(output_dir, "distrib_minimal_distances_cut.pdf"))
#
# DISTmelt = MINDIST.melt(id_vars='cat').rename(columns=str.title)
# g = sns.FacetGrid(DISTmelt, hue="Cat", height=7, col='Variable', hue_order=model_name_to_input_file.keys())
# g.map(sns.kdeplot, "Value")
# g.add_legend()
# plt.savefig(os.path.join(output_dir, "distrib_minimal_distances_full.pdf"))

In [None]:
# Wasserstein distance between Real's and each model's minimal-distance distributions,
# for both AATS directions ('dTS', 'dST'). Each MINDIST cell is passed to
# wasserstein_distance as a sample of distances.
# NOTE(review): this rebinds W, so the 'within' scores from the earlier cell are not kept.
mindist_methods = ['dTS', 'dST']
# Positional .iloc[0] instead of the former label lookup [0]: the Real row is not
# guaranteed to carry index label 0. Hoisted out of the per-model loop.
real_mindist = {method: MINDIST.loc[MINDIST.cat == 'Real', method].iloc[0]
                for method in mindist_methods}
W = pd.DataFrame(
    [{'stat': 'wasserstein',
      'statistic': scs.wasserstein_distance(real_mindist[method],
                                            MINDIST.loc[MINDIST.cat == model_name, method].iloc[0]),
      'label': model_name,
      'comparaison': method}
     for model_name in model_name_to_input_file.keys()
     for method in mindist_methods],
    columns=['stat', 'statistic', 'label', 'comparaison'],
)
scores = W.copy()

In [None]:
# `scores` already holds W (set in the previous cell). The former
# `scores = pd.concat([scores, W])` wrote every row twice to the CSV.
scores.to_csv(os.path.join(output_dir, "scores_pairwise_distances.csv"), index=False)

plt.figure(figsize=(1.5 * len(model_names), 6))
aa_long = (AA.drop(columns=['PrivacyLoss', 'ref'], errors='ignore')
             .melt(id_vars='cat')
             .rename(columns=str.title))
sns.barplot(x='Cat', y='Value', hue='Variable', palette=sns.color_palette('colorblind'), data=aa_long)
plt.axhline(0.5, color='black')  # reference level 0.5 (presumably the ideal AATS) — confirm
if 'Real_test' in AA.cat.values:
    # np.float was removed in NumPy 1.24; extract the scalar explicitly instead.
    real_test_aats = float(AA.loc[AA.cat == 'Real_test', 'AATS'].iloc[0])
    plt.axhline(real_test_aats, color=sns.color_palette()[0], ls='--')
plt.ylim(0, 1.1)
plt.title("Nearest Neighbor Adversarial Accuracy on training (AATS) and its components")
plt.savefig(os.path.join(output_dir, "AATS_scores.pdf"))

# Privacy loss with Real as the training set, evaluated against Test2
Test = '_Test2'
Train = ''  # means Training set is Real
dfPL = plotPrivacyLoss(Train, Test, output_dir, color_palette, model_name_to_color)

# Compute PL for the real dataset Test1
# Useful if an RBM with alternative training scheme (cf paper) is in the list of models
# Because Test1 served for initializing the RBM sampling in this case
Test = '_Test2'
Train = '_Test1'
dfPL = plotPrivacyLoss(Train, Test, output_dir, color_palette, model_name_to_color)

In [None]:
def get_counts(haplosubset, points):
    """Count the distinct allele strings found over a subset of rows.

    Every column of `haplosubset` is reduced to the string made of its values at the row
    indices `points` (e.g. "010"); identical strings are then counted.

    Returns:
        The ``(unique_strings, counts)`` tuple from ``np.unique(..., return_counts=True)``.
    """
    def _column_key(column):
        return ''.join(str(value) for value in column[points])

    keys = np.apply_along_axis(_column_key, 0, haplosubset)
    return np.unique(keys, return_counts=True)


def get_frequencies(counts):
    """Convert allele-string counts into a frequency table.

    Args:
        counts: ``(strings, counts)`` pair as produced by ``get_counts``. Every string has
            the same length L and each character must be '0' or '1' (the table has size
            2 along each axis).

    Returns:
        Array of shape ``(2,) * L`` whose entry at ``(a0, a1, ...)`` is the fraction of
        observations equal to that allele combination; unseen combinations stay 0.
    """
    haplotypes, occurrences = counts[0], counts[1]
    n_positions = len(haplotypes[0])
    total = np.sum(occurrences)
    table = np.zeros(shape=[2] * n_positions)
    for haplotype, occurrence in zip(haplotypes, occurrences):
        table[tuple(int(ch) for ch in haplotype)] = occurrence / total
    return table


def three_points_cor(haplosubset, out='all'):
    """Connected three-point correlation of a triplet of biallelic sites.

    For each allele combination (a, b, c) computes
    f_abc - f_ab*f_c - f_bc*f_a - f_ac*f_b + 2*f_a*f_b*f_c.

    Args:
        haplosubset: array with 3 rows (one per site) and one column per observation,
            holding 0/1 alleles.
        out: 'all' -> list of the 8 correlations in itertools.product([0, 1], repeat=3)
            order; 'mean' -> their mean; 'max' -> the largest absolute value.

    Raises:
        ValueError: if `out` is not 'all', 'mean' or 'max'.
    """
    # Validate first, and *raise* (the exception used to be returned as a value, after
    # all the work had been done).
    if out not in ('all', 'mean', 'max'):
        raise ValueError(f"out={out} not recognized")

    F = dict()
    for points in ([0], [1], [2], [0, 1], [0, 2], [1, 2], [0, 1, 2]):
        F[''.join(map(str, points))] = get_frequencies(get_counts(haplosubset, points))

    cors = [
        F['012'][a, b, c]
        - F['01'][a, b] * F['2'][c]
        - F['12'][b, c] * F['0'][a]
        - F['02'][a, c] * F['1'][b]
        + 2 * F['0'][a] * F['1'][b] * F['2'][c]
        for a, b, c in itertools.product([0, 1], repeat=3)
    ]
    if out == 'mean':
        return np.mean(cors)
    if out == 'max':
        return np.max(np.abs(cors))
    return cors


# Seed both RNGs so the same SNP triplets (and the same real individuals, when
# subsampling is needed) are drawn on every run; this keeps the scores consistent
# when a new model or a new summary statistic is added.
np.random.seed(3)
random.seed(3)

# 3-point correlations for every dataset at several distances between SNPs.
# gap_vec: separation (in number of SNPs) between consecutive SNPs of a triplet;
# a negative gap (-9) means the three SNPs are chosen completely at random.
# For each gap, `nsamplesets` triplets are drawn and evaluated on every dataset.
# If datasets have different numbers of SNPs, triplets are drawn among the first
# `min_nsnp` SNPs only (slightly more sampling at the beginning of the chunk).
gap_vec = [1, 4, 16, 64, 256, 512, 1024, -9]
nsamplesets = 1000
min_nsnp = min(dat.shape[1] for dat in datasets.values())
cors_meta = dict()
for gap in gap_vec:
    print(f'\n gap={gap} SNPs', end=' ')
    if gap < 0:
        # pick 3 random SNPs
        picked_three_points = [random.sample(range(min_nsnp), 3) for _ in range(nsamplesets)]
    else:
        # pick SNPs start, start + step, start + 2 * step (separated by `gap` SNPs)
        step = gap + 1
        n_starts = min_nsnp - 2 * step  # number of valid start positions
        if n_starts < 1:
            # Not enough SNPs for this gap. Checked explicitly: the former bare `except:`
            # also swallowed unrelated errors, including KeyboardInterrupt.
            print('(skipped: not enough SNPs)', end=' ')
            continue
        picked_three_points = [np.asarray(random.sample(range(n_starts), 1)) + [0, step, 2 * step]
                               for _ in range(nsamplesets)]
    cors = dict()
    for model_name in model_name_to_input_file.keys():
        print(model_name, end=' ')
        # datasets[m][:, snps] keeps the 3 SNP columns; .T gives the 3-row layout
        # three_points_cor expects.
        cors[model_name] = [three_points_cor(datasets[model_name][:, snps].T, out='all')
                            for snps in picked_three_points]
    cors_meta[gap] = cors.copy()

with open(os.path.join(output_dir, "3pointcorr.pkl"), "wb") as outfile:
    pickle.dump(cors_meta, outfile)


def _plot_three_point_correlations(cors_by_gap, figsize, axis_lim=None):
    """Scatter each generated model's 3-point correlations against Real, one panel per gap.

    If `axis_lim` is given, both axes of every panel are fixed to that (low, high) range.
    """
    plt.figure(figsize=figsize)
    n_cols = int(np.ceil(len(cors_by_gap) / 2))
    for panel, (gap, cors) in enumerate(cors_by_gap.items(), start=1):
        ax = plt.subplot(2, n_cols, panel)
        real = list(np.array(cors['Real']).flat)
        for key, val in cors.items():
            if key == 'Real':
                continue
            plotreg(x=real, y=list(np.array(val).flat), keys=['Real', key],
                    statname='Correlation', col=color_palette[key], ax=ax)
            if axis_lim is not None:
                ax.set_xlim(axis_lim)
                ax.set_ylim(axis_lim)
        if gap < 0:
            plt.title('3-point corr for random SNPs')
        else:
            plt.title(f'3-point corr for SNPs sep. by {gap} SNPs')
        plt.legend(fontsize='small')
    plt.tight_layout()


# Plot 3-point correlation results
_plot_three_point_correlations(cors_meta, figsize=(2 * len(cors_meta), 7))
plt.savefig(os.path.join(output_dir, '3point_correlations.jpg'), dpi=300)  # can pick one of the format

# Same plot with axes limits fixed to (-0.1, 0.1) for the sake of comparison
_plot_three_point_correlations(cors_meta, figsize=(4 * len(cors_meta), 14), axis_lim=(-.1, .1))
plt.savefig(os.path.join(output_dir, '3point_correlations_fixlim.pdf'), dpi=300)