In [300]:
import itertools
import pickle
import warnings
from pathlib import Path

from scipy.stats import binned_statistic
from scipy.stats import sem
from scipy.stats import wasserstein_distance
from sklearn.metrics import r2_score, mean_squared_error

from utils.util import *
%matplotlib inline
try:
    import ot

    ot_loaded = True
except ModuleNotFoundError:
    ot_loaded = False
try:
    import statsmodels.api as sm

    sm_loaded = True
except ModuleNotFoundError:
    sm_loaded = False

UsageError: Line magic function `%` not found.


# Initialize Data

In [301]:
# Every dataset / model compared in this notebook, keyed by its display name.
#   path  : location relative to the repository root (the loaders prefix it with "../")
#   color : color used for this entry in every figure
#   type  : optional loader hint read by create_single_dataset:
#           "dataset" -> CSV with a header, "new_model" / "retrain_old_model" -> model-specific
#           column fix-ups; entries without a type are read as plain space-separated files.
models_to_data = dict()
models_to_data["Real"] = dict(path="resource/test_0.2_super_pop.csv", color="black", type="dataset")
models_to_data["Train Set"] = dict(path="resource/train_0.8_super_pop.csv", color="gray", type="dataset")
models_to_data["RBM 2023"] = dict(
    path="fake_genotypes_sequences/preview_sequences/10K_SNP_GAN_AG_10800Epochs.hapt", color="yellow")
models_to_data["WGAN 2023"] = dict(
    path="fake_genotypes_sequences/preview_sequences/10K_WGAN.hapt", color="purple")
models_to_data["GAN 2019 Retrain"] = dict(
    path="fake_genotypes_sequences/preview_sequences/old GAN retrain genotypes.hapt",
    color="brown", type="retrain_old_model")
models_to_data["Genome-AC-GAN By National Population"] = dict(
    path="resource/Genome-AC-GAN By National Population genotypes.hapt", color="green", type="new_model")
models_to_data["Genome-AC-GAN By Continental Population"] = dict(
    path="resource/Genome-AC-GAN By Continental Population genotypes.hapt", color="blue", type="new_model")

In [302]:
# Figures and CSV outputs are written to this directory. The `output_dir` environment variable
# overrides DEFAULT_EXPERIMENT_OUTPUT_DIR, and the directory (plus any missing parents) is created.
output_dir = os.environ.get("output_dir", DEFAULT_EXPERIMENT_OUTPUT_DIR)
os.makedirs(output_dir, exist_ok=True)
# Toggle for the AATS privacy computation.
compute_AATS = True

In [303]:
# One color per model, in models_to_data order. This order also becomes seaborn's default color
# cycle, so plots that pass no explicit color still follow it.
color_palette = dict((model, settings["color"]) for model, settings in models_to_data.items())
sns.set_palette(color_palette.values())

In [304]:
def load_analysis_data_agg_tests(models_to_data: dict, number_of_datasets: int):
    """
    Load the Real test set plus `number_of_datasets` random resamples of every other model.

    Parameters
    ----------
    models_to_data : dict
        Model name -> {"path", "color", optional "type"}; must contain a 'Real' entry.
    number_of_datasets : int
        Number of resampled matrices drawn per non-Real model.

    Returns
    -------
    tuple
        (extra_sample_info, sample_info, datasets, transformations, model_keep_all_snps,
         number_of_samples, full_datasets). `datasets[name]` is a list of 2-D int arrays
        (a single-element list for 'Real'); `full_datasets[name]` is the unfiltered matrix.
        `model_keep_all_snps` is returned empty and `transformations` is not applied here.
    """
    transformations = {'to_minor_encoding': False, 'min_af': 0, 'max_af': 1}
    model_keep_all_snps = dict()
    sample_info = dict()

    def to_genotype_matrix(frame):
        # columns 0 and 1 hold label / sample id; genotype columns start at label 2
        return np.array(frame.loc[:, 2:].astype(int))

    real_entry = models_to_data['Real']
    real_frame, number_of_samples = create_single_dataset(
        real_entry, f"../{real_entry['path']}", 'Real', 0, sample_info)
    datasets = {'Real': [to_genotype_matrix(real_frame)]}
    full_datasets = {'Real': to_genotype_matrix(real_frame)}
    print('Real: ', datasets['Real'][0].shape)

    other_models = ((name, cfg) for name, cfg in models_to_data.items() if name != 'Real')
    for model_name, data in other_models:
        print(f"init data from {model_name} with type {data.get('type', 'none')}")
        file_path = f"../{data['path']}"
        resampled = []
        for dataset_number in range(1, number_of_datasets + 1):
            frame, _ = create_single_dataset(data, file_path, model_name, number_of_samples, sample_info)
            resampled.append(to_genotype_matrix(frame))
            if dataset_number % 5 == 0:
                print(f"Finished init model {model_name} number {dataset_number}")
        datasets[model_name] = resampled

        # one extra, unsubsampled load; it also leaves sample_info[model_name] describing the full file
        full_frame, _ = create_single_dataset(data, file_path, model_name, number_of_samples, sample_info,
                                              filter_number_of_sequences=False)
        full_datasets[model_name] = to_genotype_matrix(full_frame)

    extra_sample_info = pd.DataFrame(np.concatenate(list(sample_info.values())), columns=['label', 'id'])
    print("Dictionary of datasets:", len(datasets))
    return extra_sample_info, sample_info, datasets, transformations, model_keep_all_snps, number_of_samples, full_datasets


# Most recently parsed input file, keyed by (file_path, is_csv_dataset). Only one entry is kept:
# load_analysis_data_agg_tests asks for the same file number_of_datasets + 1 times in a row, so caching
# the last parse removes the repeated CSV parsing while keeping at most one file in memory.
_LAST_READ_SEQUENCES = {}


def _read_sequences_file(file_path, is_csv_dataset):
    """Return a fresh copy of the parsed `file_path`, parsing again only when a different file is requested.

    NOTE: the cache is not invalidated if the file is modified on disk during the session.
    """
    cache_key = (file_path, is_csv_dataset)
    if cache_key not in _LAST_READ_SEQUENCES:
        _LAST_READ_SEQUENCES.clear()
        if is_csv_dataset:
            _LAST_READ_SEQUENCES[cache_key] = pd.read_csv(file_path)
        else:
            _LAST_READ_SEQUENCES[cache_key] = pd.read_csv(file_path, sep=' ', header=None)
    # Callers mutate the frame (rename/insert columns, overwrite labels), so never hand out the cached object.
    return _LAST_READ_SEQUENCES[cache_key].copy()


def create_single_dataset(data, file_path, model_name, number_of_samples, sample_info, filter_number_of_sequences=True):
    """
    Load one model's sequences as a DataFrame with the label in column 0, a sample id in column 1 and
    genotypes in columns 2.., optionally subsampled to `number_of_samples` rows.

    Parameters
    ----------
    data : dict
        Entry of models_to_data. Its optional "type" selects the parsing:
        "dataset"           -> CSV with a header, relevant columns kept, rows shuffled;
        "new_model"         -> space-separated file with the label in column 0; an "AG<i>" id column is inserted
                               and rows are sampled per label in proportion to the label counts;
        "retrain_old_model" -> space-separated genotypes only; the last column is dropped and
                               "none" label/id columns are added;
        anything else       -> space-separated file used as-is.
    file_path : str
        File to read.
    model_name : str
        Written into column 0 of every row. For 'Real' it also defines the returned sample count.
    number_of_samples : int
        Target number of rows (ignored for 'Real').
    sample_info : dict
        Updated in place: sample_info[model_name] = DataFrame with 'label' and 'ind' columns.
    filter_number_of_sequences : bool
        When False, every row of the file is kept.

    Returns
    -------
    (pandas.DataFrame, int)
        The sequences and the (possibly updated) number_of_samples.
    """
    data_type = data.get("type", "")
    if data_type == "dataset":
        model_sequences_df = _read_sequences_file(file_path, True)
        columns = get_relevant_columns(model_sequences_df, model_sequences_df.columns[:2])
        model_sequences = model_sequences_df[columns]
        columns = [int(i) for i in columns]
        model_sequences.columns = columns
        model_sequences = model_sequences.sample(frac=1).reset_index(drop=True)
    else:
        model_sequences = _read_sequences_file(file_path, False)
        if data_type == "new_model":
            # shift genotype columns right by one to make room for the id column inserted below
            model_sequences.columns = [column if column == 0 else column + 1 for column in model_sequences.columns]
            if filter_number_of_sequences:
                # Keep each label's share of the file. astype(int) floors, so the total can fall slightly
                # below number_of_samples.
                category_counts = model_sequences[0].value_counts()
                sample_counts = (category_counts / category_counts.sum() * number_of_samples).astype(int)
                model_sequences = model_sequences.sample(frac=1).reset_index(drop=True)
                # Explicit per-group sampling + concat (same groups, same order as groupby().apply()), but
                # unaffected by newer pandas dropping the grouping column inside GroupBy.apply.
                model_sequences = pd.concat(
                    [group.sample(sample_counts[label]) for label, group in model_sequences.groupby(0)]
                ).reset_index(drop=True)

            model_sequences.insert(0, 1, [f"AG{sample_id}" for sample_id in range(model_sequences.shape[0])])
        if data_type == "retrain_old_model":
            model_sequences = model_sequences.drop(columns=list(model_sequences.columns)[-1], axis=1)
            model_sequences.columns = [column + 2 for column in list(model_sequences.columns)]
            model_sequences.insert(loc=0, column=0, value="none")
            model_sequences.insert(loc=1, column=1, value='none')
    if model_name == 'Real':
        number_of_samples = len(model_sequences)

    if filter_number_of_sequences:
        if model_sequences.shape[0] > number_of_samples:
            # drop a random subset of rows (without replacement) to reach number_of_samples
            model_sequences = model_sequences.drop(
                index=np.sort(
                    np.random.choice(np.arange(model_sequences.shape[0]),
                                     size=model_sequences.shape[0] - number_of_samples,
                                     replace=False)))
    # overwrite file first column to set the label name chosen in infiles (eg GAN, etc):
    model_sequences[0] = model_name
    sample_info[model_name] = pd.DataFrame({'label': model_sequences[0], 'ind': model_sequences[1]})
    return model_sequences, number_of_samples

In [305]:
# 50 random resamples per non-Real model, plus the full (unfiltered) matrix of every model.
(extra_sample_info, sample_info, multiple_datasets, transformations,
 model_keep_all_snps, number_of_samples, full_datasets) = load_analysis_data_agg_tests(
    models_to_data, number_of_datasets=50)

Real:  (1002, 10000)
init data from Train Set with type dataset


KeyboardInterrupt: 

# PCA Tests

In [None]:
from sklearn.decomposition import PCA


def plot_pca_comparison(models):
    """
    For every non-Real model, plot the 2-D PCA of its best-matching resample next to the Real PCA.

    Each resample in `models[name]` is scored by get_best_pca_wasserstein against the Real PCA
    coordinates. The lowest-distance resample is drawn in the model's panel, and the figure is saved
    to output_dir/pca2_on_test_real.jpg.

    Parameters
    ----------
    models : dict
        Model name -> list of 2-D genotype arrays; models['Real'][0] is the reference.

    Returns
    -------
    model_to_wasserstein_dists : dict
        Model name -> list of distances, one per resample.
    all_best_sequences : dict
        Model name -> best-matching genotype array ('Real' maps to the reference itself).

    Raises
    ------
    ValueError
        If `models` has no entry besides 'Real'.
    """
    model_to_wasserstein_dists = {}
    all_best_sequences = {}
    # Extract the 'Real' model data
    real_model = models['Real'][0]
    all_best_sequences['Real'] = real_model
    # Perform PCA on the 'Real' model
    pca_real = PCA(n_components=2)
    pca_real.fit(real_model)
    pca_real_transformed = pca_real.transform(real_model)

    # Grid with up to 3 panels per row, one panel per non-Real model.
    num_models = len(models) - 1
    if num_models < 1:
        raise ValueError("`models` must contain at least one model besides 'Real'")
    num_rows = int(np.ceil(num_models / 3))
    num_cols = min(num_models, 3)
    # squeeze=False keeps `axes` 2-D even for one row or one column, so axes[row, col] is always valid.
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(15, 5 * num_rows), squeeze=False)

    # The panel position counts only plotted models, so 'Real' may appear anywhere in `models`.
    position = 0
    for model_name, model_sequences in models.items():
        if model_name == 'Real':
            continue

        print(f"start model: {model_name} get best sequences")
        all_wasserstein_dist, best_model_sequence, best_pca_transformed = \
            get_best_pca_wasserstein(model_sequences, pca_real_transformed)
        mean_all_wasserstein_dist = np.mean(all_wasserstein_dist)
        std_all_wasserstein_dist = np.std(all_wasserstein_dist)
        min_all_wasserstein_dist = np.min(all_wasserstein_dist)
        model_to_wasserstein_dists[model_name] = all_wasserstein_dist
        all_best_sequences[model_name] = best_model_sequence
        print(
            f"finished model: {model_name} get best sequences with mean: {mean_all_wasserstein_dist}, std: {std_all_wasserstein_dist}, min: {min_all_wasserstein_dist}")

        row, col = divmod(position, num_cols)
        position += 1

        # NOTE(review): the model PCA is fitted on the model itself, so its axes (and signs) need not
        # match the Real PCA axes this overlay is drawn against.
        ax = axes[row, col]
        ax.scatter(pca_real_transformed[:, 0], pca_real_transformed[:, 1], color=color_palette['Real'], alpha=0.8)
        ax.scatter(best_pca_transformed[:, 0], best_pca_transformed[:, 1], color=color_palette[model_name], alpha=0.6)
        title = "\n".join(model_name.split("By"))
        ax.set_title(title, fontsize=25, fontweight='bold')

    # Hide trailing empty panels (when the model count is not a multiple of 3).
    for unused_ax in axes.ravel()[position:]:
        unused_ax.set_visible(False)

    # Adjust the vertical spacing between rows
    plt.subplots_adjust(hspace=0.5)

    plt.savefig(os.path.join(output_dir, "pca2_on_test_real.jpg"), bbox_inches='tight', dpi=300)
    plt.show()
    return model_to_wasserstein_dists, all_best_sequences


def get_best_pca_wasserstein(model_sequences, pca_real_transformed):
    """
    Score every candidate matrix against the Real PCA coordinates and keep the closest one.

    Each candidate gets its own 2-component PCA. Its coordinates and the Real coordinates are both
    flattened (PC1 and PC2 values pooled) and compared with scipy's 1-D wasserstein_distance.

    Returns
    -------
    (list of float, array or None, array or None)
        All distances in input order, the candidate with the smallest distance (first one on ties),
        and that candidate's PCA coordinates. The last two are None if nothing beats np.inf.
    """
    real_coords = pca_real_transformed.flatten()
    distances = []
    best = (np.inf, None, None)  # (distance, candidate matrix, candidate PCA coordinates)
    for candidate in model_sequences:
        candidate_coords = PCA(n_components=2).fit(candidate).transform(candidate)
        dist = wasserstein_distance(real_coords, candidate_coords.flatten())
        distances.append(dist)
        if dist < best[0]:
            best = (dist, candidate, candidate_coords)
    _, best_candidate, best_coords = best
    return distances, best_candidate, best_coords


In [None]:
# Per-model Wasserstein distances of every resample, and each model's best-matching resample.
model_to_wasserstein_dists, all_best_sequences = plot_pca_comparison(models=multiple_datasets)

In [None]:
import scipy.stats as stats
import numpy as np
import matplotlib.pyplot as plt

# Pairwise two-sample t-tests (scipy.stats.ttest_ind, default equal-variance form) comparing the
# per-resample Wasserstein distance distributions of every pair of models.
# The matrix holds the actual p-values (1 on the diagonal). It used to store 1 - p while being named,
# plotted and saved as "P-values", which inverts how the figure reads.
model_keys = list(model_to_wasserstein_dists.keys())
model_names = ["\n".join(model_name.split("By")) for model_name in model_keys]
model_names = ["\n".join(model_name.split("Model")) for model_name in model_names]
wasserstein_distances = list(model_to_wasserstein_dists.values())

n = len(model_names)
p_values_matrix = np.ones((n, n))  # a model compared with itself: p = 1
for i in range(n):
    for j in range(i + 1, n):
        p_value = stats.ttest_ind(wasserstein_distances[i], wasserstein_distances[j]).pvalue
        p_values_matrix[i, j] = p_values_matrix[j, i] = p_value

# Create a plot matrix of the p-values
fig, ax = plt.subplots(figsize=(15, 15))  # Increase the size of the plot
im = ax.imshow(p_values_matrix, cmap='coolwarm', vmin=0, vmax=1)
ax.set_xticks(np.arange(n))
ax.set_yticks(np.arange(n))
ax.set_xticklabels(model_names, rotation=45)
ax.set_yticklabels(model_names)
ax.set_title("Pairwise t-test p-values of PCA Wasserstein distances")

# Add numerical values (as percentages) in the matrix
for i in range(n):
    for j in range(n):
        ax.text(j, i, f'{p_values_matrix[i, j] * 100:.5f}%', ha='center', va='center', color='w', fontsize=15)

plt.colorbar(im, label="p-value")

plt.savefig(os.path.join(output_dir, "P-values wasserstein_distances"))
plt.show()


## Wasserstein Distance Plot

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Box plot of every model's per-resample PCA Wasserstein distances (model_to_wasserstein_dists from
# plot_pca_comparison). Models keep that dict's insertion order (no sorting). Boxes are filled with
# the model's color_palette color and annotated with the model's mean and standard deviation.
boxprops = dict(linewidth=3, color='black')
medianprops = dict(linewidth=3, color='black')
meanprops = dict(linewidth=3, color='black')  # matplotlib only uses this when showmeans=True


def _wasserstein_display_name(model_name):
    """Tick-label form of a model name: each "By" and each "Model" becomes a line break."""
    return "\n".join("\n".join(model_name.split("By")).split("Model"))


model_keys = list(model_to_wasserstein_dists)
sorted_models = [_wasserstein_display_name(key) for key in model_keys]  # display labels, original order
means = {label: np.mean(model_to_wasserstein_dists[key]) for key, label in zip(model_keys, sorted_models)}
stds = {label: np.std(model_to_wasserstein_dists[key]) for key, label in zip(model_keys, sorted_models)}
colors = [color_palette[key] for key in model_keys]
data = [model_to_wasserstein_dists[key] for key in model_keys]

fig, ax = plt.subplots(figsize=(18, 8))
boxplot = ax.boxplot(data, labels=sorted_models, patch_artist=True, showfliers=False, boxprops=boxprops,
                     medianprops=medianprops, meanprops=meanprops)
for box, box_color in zip(boxplot['boxes'], colors):
    box.set_facecolor(box_color)

for x_position, display_name in enumerate(sorted_models, start=1):
    mean, std = means[display_name], stds[display_name]
    # labels sit just above mean + std, except the 4th box, whose label is pulled down instead
    y_offset = -0.1 if x_position == 4 else 0.1
    ax.text(x_position, mean + std + y_offset, f"Mean: {mean:.2f}\nStd: {std:.2f}", ha='center', va='top',
            fontsize=24, color='white', fontweight='bold',
            bbox=dict(facecolor='black', edgecolor='black', boxstyle='round', pad=0.2))

for tick_label in ax.get_xticklabels() + ax.get_yticklabels():
    tick_label.set_weight('bold')
    tick_label.set_size(16)
ax.set_ylabel('Wasserstein Distance', fontweight='bold', fontsize=17)
ax.grid(True, color='black')

plt.savefig(os.path.join(output_dir, "Wasserstein Distance Comparison.jpg"), bbox_inches='tight', dpi=300)
plt.show()


## MAF Tests

In [None]:
# Unpack build_allele_frequency's three results for the full (unfiltered) datasets.
(sum_alleles_by_position,
 allele_frequency,
 is_fixed) = build_allele_frequency(full_datasets)

In [None]:
from sklearn.metrics import mean_absolute_error


def plotreg(x, y, keys, statname, col, model_name_display=None, ax=None):
    """
    Scatter y against x with a dashed y = x reference line; put Pearson r and MAE in the title.

    Parameters
    ----------
    x : array-like
        Reference values (e.g. the statistic in the Real dataset).
    y : array-like
        Compared values (e.g. the statistic in another dataset), same length as x.
    keys : sequence
        (reference name, compared name). keys[1] is the title when model_name_display is None.
    statname : str
        Statistic name such as 'Allele frequency' or 'LD'. Used for the axis labels
        ('Allele frequency' is labelled "MAF").
    col : str
        Marker color.
    model_name_display : str, optional
        Title text. It used to be required, which made the LD callers (that omit it) raise TypeError.
    ax : matplotlib Axes, optional
        Target axes; defaults to a new 1x1 subplot of the current figure.

    Returns
    -------
    r : float
        Pearson correlation coefficient between x and y.
    mae : float
        Mean absolute error between x and y.
    """
    from scipy.stats import pearsonr  # explicit, instead of relying on the star import

    if model_name_display is None:
        model_name_display = keys[1]

    lims = [np.min(x), np.max(x)]
    r, _ = pearsonr(x, y)
    mae = mean_absolute_error(x, y)
    # The former `sm.OLS(x, y).fit()` result was never used (and had endog/exog swapped), so it was removed.

    if ax is None:
        ax = plt.subplot(1, 1, 1)

    # opaque markers for small samples, translucent ones once points start to overlap
    alpha = 1 if len(x) < 100 else .6

    ax.plot(x, y, c=col, marker='o', lw=0, alpha=alpha)
    ax.plot(lims, lims, ls='--', alpha=1, c='black')  # y = x over the range of x
    title = ax.set_title(
        f'{model_name_display}\nCorrelation={round(round(r, 3) * 100, 3)}%\nMAE={round(round(mae, 3) * 100, 3)}',
        fontsize=29, fontweight="bold", y=1, color='black')

    title.set_bbox({'facecolor': 'white', 'edgecolor': "black", 'pad': 1.2})
    # Allele-frequency plots keep their "MAF" labels; other statistics (e.g. LD) are named after statname.
    axis_stat = "MAF" if statname == "Allele frequency" else statname
    ax.set_xlabel(f"{axis_stat} In Real", fontsize=28, fontweight="bold")
    ax.set_ylabel(f"{axis_stat} In Synthetic", fontsize=28, fontweight="bold")

    # Adjust vertical spacing between subplots
    plt.subplots_adjust(hspace=0.5)

    return r, mae


In [None]:
def plot_allele_frequency(allele_frequency, file_name, maf, highest=False):
    """
    One Real-vs-model allele-frequency scatter panel (via plotreg) per non-Real model, in a single row.

    Parameters
    ----------
    allele_frequency : dict
        Model name -> per-SNP frequency array; must contain 'Real'.
    file_name : str
        Output file name inside output_dir.
    maf : float
        Keep SNPs whose Real frequency is <= maf, or >= 1 - maf when `highest` is True.
    highest : bool
        Zoom on the high-frequency end instead of the low one.
    """
    if highest:
        keep = allele_frequency['Real'] >= 1 - maf
    else:
        keep = allele_frequency['Real'] <= maf

    # Panels are numbered from the compared models themselves, so 'Real' can be anywhere in the dict
    # and the grid adapts to the number of models (it was hard-coded to 6 columns).
    compared_models = [model_name for model_name in allele_frequency if model_name != 'Real']
    n_panels = max(len(compared_models), 1)
    plt.figure(figsize=(44 * n_panels / 6, 6))  # 44 x 6 inches for the original six panels
    for panel, model_name in enumerate(compared_models, start=1):
        model_name_display = model_name.replace("Population", "").replace(" By", "")
        ax = plt.subplot(1, n_panels, panel)
        plotreg(x=allele_frequency['Real'][keep], y=allele_frequency[model_name][keep],
                keys=['Real', model_name_display], statname="Allele frequency",
                col=color_palette[model_name], model_name_display=model_name_display, ax=ax)

    plt.savefig(os.path.join(output_dir, file_name), bbox_inches='tight', dpi=300)

In [None]:
# Every SNP (Real frequency <= 1).
plot_allele_frequency(allele_frequency, file_name='total_allele_frequency.jpg', maf=1)

In [None]:
# Zoom on SNPs whose Real frequency is <= 0.2.
plot_allele_frequency(allele_frequency, file_name='zoom_lowest_total_allele_frequency.jpg', maf=0.2)

In [None]:
# Zoom on SNPs whose Real frequency is >= 1 - 0.2.
plot_allele_frequency(allele_frequency, file_name='zoom_highest_total_allele_frequency.jpg', maf=0.2,
                      highest=True)

# Initialize Data With Preview Loader

In [None]:
def load_analysis_data_for_preview_tests(model_name_to_input_file: dict):
    """
    Load one genotype matrix per model (passed through `datatransform`) for the LD and AATS analyses.

    Every model except 'Real' is randomly subsampled, without replacement, to the number of rows of the
    'Real' entry. 'Real' should therefore come first in `model_name_to_input_file`; models seen before
    it are left unsubsampled. Previously every CSV overwrote the sample count, so the Train Set rather
    than Real set the size of all later models, and a hapt file read before any CSV lost all its rows.

    Parameters
    ----------
    model_name_to_input_file : dict
        Model name -> {"path", ...}; paths are resolved relative to "../".

    Returns
    -------
    tuple
        (extra_sample_info, sample_info, datasets, transformations, model_keep_all_snps,
         number_of_samples). datasets[name] is the transformed 2-D int array and
        model_keep_all_snps[name] is the second value returned by datatransform.
    """
    transformations = {'to_minor_encoding': False, 'min_af': 0, 'max_af': 1}

    datasets, model_keep_all_snps, sample_info = dict(), dict(), dict()
    number_of_samples = 0
    for model_name, data in model_name_to_input_file.items():
        file_path = f"../{data['path']}"
        print(model_name, "loaded from", file_path)
        if file_path.endswith('.csv'):
            model_sequences = pd.read_csv(file_path)
            columns = get_relevant_columns(model_sequences, model_sequences.columns[:2])
            model_sequences = model_sequences[columns]
            columns = [int(i) for i in columns]
            model_sequences.columns = columns
        else:
            model_sequences = pd.read_csv(file_path, sep=' ', header=None)
            if 'Genome-AC-GAN' in model_name:
                # make room for, then insert, an "AG<i>" sample-id column at label 1
                model_sequences.columns = [column if column == 0 else column + 1 for column in model_sequences.columns]
                model_sequences.insert(0, 1, [f"AG{sample_id}" for sample_id in range(model_sequences.shape[0])])
            if model_sequences.shape[1] == 808:  # special case for a specific file that had an extra empty column
                model_sequences = model_sequences.drop(columns=model_sequences.columns[-1])
            if 'GAN 2019 Retrain' in model_name:
                model_sequences = model_sequences.drop(columns=list(model_sequences.columns)[-1], axis=1)
                model_sequences.columns = [column + 2 for column in list(model_sequences.columns)]
                model_sequences.insert(loc=0, column=0, value="none")
                model_sequences.insert(loc=1, column=1, value='none')

        if model_name == 'Real':
            number_of_samples = len(model_sequences)
        elif 0 < number_of_samples < model_sequences.shape[0]:
            # drop a random subset of rows (without replacement) to match the Real sample count
            model_sequences = model_sequences.drop(
                index=np.sort(
                    np.random.choice(np.arange(model_sequences.shape[0]),
                                     size=model_sequences.shape[0] - number_of_samples,
                                     replace=False))
            )
        # overwrite file first column to set the label name chosen in infiles (eg GAN, etc):
        model_sequences[0] = model_name
        sample_info[model_name] = pd.DataFrame({'label': model_sequences[0], 'ind': model_sequences[1]})
        datasets[model_name] = np.array(model_sequences.loc[:, 2:].astype(int))

        # transformations can be maf filtering, recoding into major=0/minor=1 format
        if transformations is not None:
            datasets[model_name], model_keep_all_snps[model_name] = datatransform(datasets[model_name],
                                                                                  **transformations)
        print(model_name, datasets[model_name].shape)
    extra_sample_info = pd.DataFrame(np.concatenate(list(sample_info.values())), columns=['label', 'id'])
    print("Dictionary of datasets:", len(datasets))
    return extra_sample_info, sample_info, datasets, transformations, model_keep_all_snps, number_of_samples


In [None]:
# One (transformed) matrix per model; these `datasets` feed the LD, pairwise-distance and AATS cells.
(extra_sample_info, sample_info, datasets, transformations,
 model_keep_all_snps, number_of_samples) = load_analysis_data_for_preview_tests(
    model_name_to_input_file=models_to_data)

In [None]:
# Recomputed on the preview-loader datasets; the LD cell below reads this is_fixed.
(sum_alleles_by_position,
 allele_frequency,
 is_fixed) = build_allele_frequency(datasets)


# LD Tests

In [None]:
print("* Computing and plotting LD...")
#### Compute r^2 between all pairs of SNPs for each generated/real dataset

model_names = models_to_data.keys()
hcor_snp = dict()
for i, model_name in enumerate(model_names):
    print(model_name)
    with np.errstate(divide='ignore', invalid='ignore'):
        # Catch warnings due to fixed sites in dataset (the correlation value will be np.nan for pairs involving these sites)
        hcor_snp[model_name] = np.corrcoef(datasets[model_name], rowvar=False) ** 2  # r2

_, region_len, snps_on_same_chrom = get_dist(f"../{REAL_POSITION_FILE_NAME}", region_len_only=True,
                                             kept_preprocessing=model_keep_all_snps['Real'])

nbins = 100
logscale = True
bins = nbins
binsPerDist = nbins
if logscale:
    # Log-spaced distance bins from 1 bp to region_len. np.logspace expects base-10 *exponents*, so the
    # former np.logspace(np.log(1), np.log(region_len), nbins) ended at 10**ln(region_len), far beyond
    # region_len. np.geomspace takes the endpoints directly.
    binsPerDist = np.geomspace(1, region_len, nbins)

# Compute LD binned by distance
# Take only sites that are SNPs in all datasets (intersect)
# (eg intersection of SNPs in Real, SNPs in GAN, SNPs in RBM etc)
# -> Makes sense only if there is a correspondence between sites

binnedLD = dict()
binnedPerDistLD = dict()
kept_snp = ~is_fixed
n_kept_snp = np.sum(kept_snp)
realdist = get_dist(f"../{REAL_POSITION_FILE_NAME}", kept_preprocessing=model_keep_all_snps['Real'],
                    kept_snp=kept_snp)[0]
mat = hcor_snp['Real']
# filter and flatten (upper triangle, diagonal included)
flatreal = (mat[np.ix_(kept_snp, kept_snp)])[np.triu_indices(n_kept_snp)]
isnanReal = np.isnan(flatreal)
i = 1
plt.figure(figsize=(10, len(hcor_snp) * 5))

for model_name, mat in hcor_snp.items():
    flathcor = (mat[np.ix_(kept_snp, kept_snp)])[np.triu_indices(n_kept_snp)]
    isnan = np.isnan(flathcor)
    curr_dist = realdist

    # For each dataset LD pairs are stratified by SNP distance and cut into 'nbins' bins
    # bin per SNP distance
    ld = binned_statistic(curr_dist[~isnan], flathcor[~isnan], statistic='mean', bins=binsPerDist)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)  # so that empty bins do not raise a warning
        binnedPerDistLD[model_name] = pd.DataFrame({'bin_edges': ld.bin_edges[:-1],
                                                    'LD': ld.statistic,
                                                    'sem': binned_statistic(curr_dist[~isnan], flathcor[~isnan],
                                                                            statistic=sem,
                                                                            bins=binsPerDist).statistic,
                                                    'model_name': model_name, 'logscale': logscale})

    # For each dataset LD pairs are stratified by LD values in Real and cut into 'nbins' bins
    # binnedLD contains the average, std of LD values in each bin
    isnan = np.isnan(flathcor) | np.isnan(flatreal)
    ld = binned_statistic(flatreal[~isnan], flathcor[~isnan], statistic='mean', bins=bins)
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)  # so that empty bins do not raise a warning
        binnedLD[model_name] = pd.DataFrame({'bin_edges': ld.bin_edges[:-1],
                                             'LD': ld.statistic,
                                             'sd': binned_statistic(flatreal[~isnan], flathcor[~isnan],
                                                                    statistic='std',
                                                                    bins=bins).statistic,
                                             'sem': binned_statistic(flatreal[~isnan], flathcor[~isnan],
                                                                     statistic=sem,
                                                                     bins=bins).statistic,
                                             'model_name': model_name, 'logscale': logscale})

    # Plotting quantiles ?
    plotregquant(x=flatreal, y=flathcor,
                 keys=['Real', model_name], statname='LD', col=color_palette[model_name],
                 step=0.05,
                 ax=plt.subplot(len(hcor_snp), 2, i))
    i += 1
    plt.title(f'Quantiles LD {model_name} vs Real')

    # removing nan values and subsampling before doing the regression to have a reasonable number of points
    isnanInter = isnanReal | isnan
    valid_pairs = list(np.where(~isnanInter)[0])
    # random.sample raises ValueError when asked for more items than exist
    keepforplotreg = random.sample(valid_pairs, min(number_of_samples, len(valid_pairs)))
    # model_name_display is a required parameter of plotreg, so pass it explicitly
    plotreg(x=flatreal[keepforplotreg], y=flathcor[keepforplotreg],
            keys=['Real', model_name], statname='LD', col=color_palette[model_name],
            model_name_display=model_name, ax=plt.subplot(len(hcor_snp), 2, i))
    i += 1
    plt.title(f'LD {model_name} vs Real')
plt.savefig(os.path.join(output_dir, "LD_generated_vs_real_intersectSNP.pdf"))

In [None]:
import pandas as pd
from scipy.stats import binned_statistic, sem
import warnings


def compute_and_plot_ld(real_data, synthetic_data, output_dir):
    """
    Function version of the "LD Tests" cell: r^2 between SNP pairs, binned by SNP distance and by the
    Real LD value, plus a quantile plot and a scatter plot per model, saved to
    output_dir/LD_generated_vs_real_intersectSNP.pdf.

    NOTE(review): not called in the visible notebook and its inputs are ambiguous. `real_data` is
    indexed by model name for the correlation matrices *and* passed as `kept_preprocessing` to get_dist,
    while `synthetic_data` only supplies the model names (and so must include 'Real'). Confirm the
    intended types before relying on it. It also reads the notebook globals `color_palette` and
    `number_of_samples`.

    Returns
    -------
    (dict, dict)
        binnedLD (LD stratified by Real LD) and binnedPerDistLD (LD stratified by distance), keyed by
        model name. The function previously returned None and kept these local.
    """
    model_names = synthetic_data.keys()
    hcor_snp = dict()

    for model_name in model_names:
        with np.errstate(divide='ignore', invalid='ignore'):
            # NOTE(review): reads real_data[...] while iterating synthetic_data's keys -- confirm intent
            hcor_snp[model_name] = np.corrcoef(real_data[model_name], rowvar=False) ** 2  # r2

    _, region_len, snps_on_same_chrom = get_dist(f"../{REAL_POSITION_FILE_NAME}", region_len_only=True,
                                                 kept_preprocessing=real_data)

    nbins = 100
    logscale = True
    bins = nbins
    binsPerDist = nbins
    if logscale:
        # np.logspace takes base-10 exponents, so logspace(np.log(1), np.log(region_len)) overshot region_len;
        # geomspace takes the endpoints directly.
        binsPerDist = np.geomspace(1, region_len, nbins)

    binnedLD = dict()
    binnedPerDistLD = dict()
    realdist = get_dist(f"../{REAL_POSITION_FILE_NAME}", kept_preprocessing=real_data,
                        kept_snp='all')[0]
    # realdist covers all SNP pairs (kept_snp='all'), so the LD matrices are not filtered either; the
    # previous code referenced undefined kept_snp / n_kept_snp (notebook globals at best).
    n_snps = hcor_snp['Real'].shape[0]
    upper_triangle = np.triu_indices(n_snps)
    flatreal = hcor_snp['Real'][upper_triangle]
    isnanReal = np.isnan(flatreal)
    i = 1

    plt.figure(figsize=(10, len(hcor_snp) * 5))

    for model_name, mat in hcor_snp.items():
        flathcor = mat[upper_triangle]
        isnan = np.isnan(flathcor)
        curr_dist = realdist

        # LD pairs stratified by SNP distance
        ld = binned_statistic(curr_dist[~isnan], flathcor[~isnan], statistic='mean', bins=binsPerDist)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=RuntimeWarning)  # empty bins
            binnedPerDistLD[model_name] = pd.DataFrame({'bin_edges': ld.bin_edges[:-1],
                                                        'LD': ld.statistic,
                                                        'sem': binned_statistic(curr_dist[~isnan], flathcor[~isnan],
                                                                                statistic=sem,
                                                                                bins=binsPerDist).statistic,
                                                        'model_name': model_name, 'logscale': logscale})

        # LD pairs stratified by their LD value in Real
        isnan = np.isnan(flathcor) | np.isnan(flatreal)
        ld = binned_statistic(flatreal[~isnan], flathcor[~isnan], statistic='mean', bins=bins)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=RuntimeWarning)  # empty bins
            binnedLD[model_name] = pd.DataFrame({'bin_edges': ld.bin_edges[:-1],
                                                 'LD': ld.statistic,
                                                 'sd': binned_statistic(flatreal[~isnan], flathcor[~isnan],
                                                                        statistic='std',
                                                                        bins=bins).statistic,
                                                 'sem': binned_statistic(flatreal[~isnan], flathcor[~isnan],
                                                                         statistic=sem,
                                                                         bins=bins).statistic,
                                                 'model_name': model_name, 'logscale': logscale})

        # Plotting quantiles ?
        plotregquant(x=flatreal, y=flathcor,
                     keys=['Real', model_name], statname='LD', col=color_palette[model_name],
                     step=0.05,
                     ax=plt.subplot(len(hcor_snp), 2, i))
        i += 1
        plt.title(f'Quantiles LD {model_name} vs Real')

        # removing nan values and subsampling before doing the regression to have a reasonable number of points
        isnanInter = isnanReal | isnan
        valid_pairs = list(np.where(~isnanInter)[0])
        keepforplotreg = random.sample(valid_pairs, min(number_of_samples, len(valid_pairs)))
        plotreg(x=flatreal[keepforplotreg], y=flathcor[keepforplotreg],
                keys=['Real', model_name], statname='LD', col=color_palette[model_name],
                model_name_display=model_name, ax=plt.subplot(len(hcor_snp), 2, i))
        i += 1
        plt.title(f'LD {model_name} vs Real')
    plt.savefig(os.path.join(output_dir, "LD_generated_vs_real_intersectSNP.pdf"))
    return binnedLD, binnedPerDistLD


# LD-by-Distance Plots and AATS Privacy Tests

In [None]:

import numpy as np
import matplotlib.pyplot as plt

# Average LD per distance bin (+/- 1 sem) for every model, with each model's RMSE and R^2 against Real
# computed over the distance bins populated in BOTH Real and that model.
plt.figure(figsize=(15, 10))
line_styles = ['solid', 'dashdot', 'dotted']
real_bld = binnedPerDistLD['Real'].LD.values
for index, (model_name, bld) in enumerate(binnedPerDistLD.items()):
    line_style = line_styles[index % len(line_styles)]
    model_ld = bld.LD.values
    # One joint mask: masking Real and the model separately could give arrays of different lengths
    # (ValueError) or silently pair up different distance bins.
    both_valid = ~np.isnan(real_bld) & ~np.isnan(model_ld)
    r2 = round(r2_score(real_bld[both_valid], model_ld[both_valid]), 3)
    rmse = round(np.sqrt(mean_squared_error(real_bld[both_valid], model_ld[both_valid])), 3)
    # mathtext ignores plain spaces, so escape them to keep multi-word names readable in bold
    bold_name = model_name.replace(' ', r'\ ')
    plt.errorbar(
        bld.bin_edges.values, bld.LD.values, bld['sem'].values,
        label=r"$\mathbf{" + bold_name + "}$  RMSE = " + str(rmse) + ", R-squared = " + str(r2),
        alpha=0.8, linewidth=3, linestyle=line_style
    )

if logscale:
    plt.xscale('log')
plt.xlabel("Distance between SNPs (bp) [Left bound of distance bin]", fontsize=15)
plt.ylabel("Average LD in bin", fontsize=15)
plt.legend(fontsize='x-large', loc="upper right")

plt.savefig(os.path.join(output_dir, "correlation_vs_dist_intersectSNP.jpg"), bbox_inches='tight', dpi=500)


In [None]:
import matplotlib.pyplot as plt

# Average LD per distance bin (+/- 1 sem) for every model, each drawn in its color_palette color.
fig, ax = plt.subplots(figsize=(10, 10))
for model_name, bld in binnedPerDistLD.items():
    ax.errorbar(bld.bin_edges.values, bld.LD.values, bld['sem'].values, label=model_name, alpha=.65,
                linewidth=3, color=color_palette[model_name])
ax.set_title("Binned LD +/- 1 sem")
ax.set_xlabel("Distance between SNPs (bp) [Left bound of distance bin]")
ax.set_ylabel("Average LD in bin")
ax.legend()
fig.savefig(os.path.join(output_dir, "correlation_vs_dist_intersectSNP.pdf"))

In [None]:
import matplotlib.pyplot as plt

# binnedLD: LD pairs of every dataset were stratified by their LD value in Real and
# cut into bins, and each entry holds the average LD per bin. Plotting these against
# the Real bin edges shows generated average LD as a function of real average LD.
plt.figure(figsize=(10, 10))
decay_ax = plt.gca()
for model_label, binned in binnedLD.items():
    decay_ax.errorbar(binned.bin_edges.values, binned.LD.values, binned['sem'].values,
                      label=model_label, alpha=0.8, marker='o')

decay_ax.set_title("Binned LD +/- 1 sem")
decay_ax.set_xlabel("Bins (LD in Real)")
decay_ax.set_ylabel("Average LD in bin")
decay_ax.legend()
plt.savefig(os.path.join(output_dir, 'LD decay.jpg'))



In [None]:
# Haplotype pairwise Manhattan ("cityblock") distances within each dataset.
# Left panel: all distinct pairs.
# Right panel: each haplotype's distance to its nearest other haplotype
# in the same dataset, stored per dataset in dSS_dic.
dSS_dic = dict()
plt.figure(figsize=(10, 4))
plt.subplot(1, 2, 1)
for cat, mat in datasets.items():
    dAB = distance.cdist(mat, mat, 'cityblock')
    # +inf on the diagonal keeps each haplotype's zero self-distance out of the row minimum.
    # Use np.inf: the capitalised np.Inf alias was removed in NumPy 2.0.
    np.fill_diagonal(dAB, np.inf)
    dSS_dic[cat] = dAB.min(axis=1)
    # upper triangle with k=1: every unordered pair exactly once, diagonal excluded
    sns.kdeplot(dAB[np.triu_indices(dAB.shape[0], k=1)], linewidth=3, label=cat)  # dSS
plt.title("Pairwise distance within each dataset")
plt.legend(fontsize='x-large')

plt.subplot(1, 2, 2)
for cat, min_dist in dSS_dic.items():
    sns.kdeplot(min_dist, linewidth=3, label=cat)
plt.title("Minimal pairwise distance within each dataset")
plt.legend(fontsize='x-large')

plt.savefig(os.path.join(output_dir, "haplo_pairw_distrib_within.pdf"), bbox_inches='tight', dpi=300)

In [None]:
# Stack every dataset's rows into one matrix, then transpose it into the
# orientation used by scikit-allel (kept as `haplo` for later cells).
haplo = np.concatenate(list(datasets.values())).T

ref = 'Real'
outFilePrefix = ''
print("Computing AATS with ref " + ref)

# computeAAandDist receives the untransposed matrix (haplo.T) together with the
# per-row labels. With saveAllDist=True it also writes the pairwise distance
# arrays as .npz files into output_dir; later cells load those files.
AA, MINDIST = computeAAandDist(
    pd.DataFrame(haplo.T),
    extra_sample_info.label,
    models_to_data.keys(),
    refCateg=ref,
    saveAllDist=True,
    output_dir=output_dir,
    outFilePrefix=outFilePrefix,
)

# Persist both result tables as bz2-compressed CSV files without the index.
for table_name, table in (('AA', AA), ('MINDIST', MINDIST)):
    table.to_csv(os.path.join(output_dir, f'{table_name}_{ref}.csv.bz2'), index=None)

In [None]:
#### Distribution WITHIN each model's dataset
# For every model:
#   - load the within-dataset pairwise distances saved by computeAAandDist,
#   - plot their density,
#   - score them against Real's with the Wasserstein distance.
model_names = models_to_data.keys()

# Load Real's distribution up front so the comparison does not depend on 'Real'
# being the first key of models_to_data.
# Each npz file is opened with `with` so its handle is closed after reading.
with np.load(os.path.join(output_dir, 'dist_Real_Real.npz')) as npz:
    subsetreal = npz['dist']

within_rows = []
plt.figure(figsize=(14, 5))
for model_name in model_names:
    with np.load(os.path.join(output_dir, f'dist_{model_name}_{model_name}.npz')) as npz:
        subset = npz['dist']
    sns.kdeplot(subset, linewidth=3, label=model_name)
    within_rows.append({'stat': 'wasserstein',
                        'statistic': wasserstein_distance(subsetreal, subset),
                        'label': model_name,
                        'comparaison': 'within'})

# Build the table once (concat inside the loop is quadratic and loses dtypes).
W = pd.DataFrame(within_rows, columns=['stat', 'statistic', 'label', 'comparaison'])

plt.legend(loc='upper left', fontsize='x-large')
plt.savefig(os.path.join(output_dir, "distribution_haplotypic_pairwise_diff.jpg"), bbox_inches='tight', dpi=300)

In [None]:
#### Distribution WITHIN model_namesories
W = pd.DataFrame(columns=['stat', 'statistic', 'label', 'comparaison'])

plt.figure(figsize=(24, 12))
plt.subplot(1, 2, 1)
model_names = models_to_data.keys()
for i, model_name in enumerate(model_names):
    subset = (np.load('{}/dist_{}_{}.npz'.format(output_dir, model_name, model_name)))['dist']
    if model_names == 'Real':
        subsetreal = subset
    sns.distplot(subset, hist=False, kde=True,
                 kde_kws={'linewidth': 3},  #'bw':.02
                 label='{} ({} identical pairs)'.format(model_names, (subset == 0).sum()))

    sc = scs.wasserstein_distance(subsetreal, subset)
    W = pd.concat([W, pd.DataFrame(
        [{'stat': 'wasserstein', 'statistic': sc, 'pvalue': None, 'label': model_name, 'comparaison': 'between'}])],
                  ignore_index=True)

plt.title("Distribution of haplotypic pairwise difference within each dataset")
plt.legend()
#plt.savefig(outDir+"haplo_pairw_distrib_within_{}_simplify.pdf".format("-".join(categ)))
subsetreal = None

#### Distribution BETWEEN categories
plt.subplot(1, 2, 2)
model_names = models_to_data.keys()
for i, model_name in enumerate(model_names):
    subset = (np.load('{}/dist_{}_{}.npz'.format(output_dir, model_name, model_name)))['dist']
    if model_name == 'Real':
        subsetreal = subset
    sns.distplot(subset, hist=False, kde=True,
                 kde_kws={'linewidth': 3},  #'bw':.02
                 label='{} vs {} ({} identical pairs)'.format(model_name, 'Real', (subset == 0).sum()))

    sc = scs.wasserstein_distance(subsetreal, subset)
    W = pd.concat([W, pd.DataFrame(
        [{'stat': 'wasserstein', 'statistic': sc, 'pvalue': None, 'label': model_name, 'comparaison': 'between'}])])

plt.title("Distribution of haplotypic pairwise difference between datasets")
plt.legend()
plt.savefig(os.path.join(output_dir, "haplo_pairw_distrib.pdf"))

scores = pd.concat([W])

print(W)

In [None]:
# Plain (uncompressed) CSV copy of the MINDIST table, index included.
MINDIST.to_csv(path_or_buf=os.path.join(output_dir, "MINDIST.csv"))

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns


def plot_score_distributions(df):
    """Plot the dTS, dST and dSS distance distributions of every model side by side.

    NOTE: a later cell redefines ``plot_score_distributions(df, score_name)``;
    once that cell has run, this three-panel version is shadowed.

    Parameters
    ----------
    df : pandas.DataFrame
        Must have a 'cat' column (model/dataset name) plus 'dTS', 'dST' and 'dSS'
        columns. Cells are flattened with ``DataFrame.explode``, so they are
        presumably list-like collections of distances (confirm against computeAAandDist).
    """
    fig, axes = plt.subplots(1, 3, figsize=(20, 15))
    score_types = ['dTS', 'dST', 'dSS']

    # One fixed colour per model, shared by all three panels.
    models = df['cat'].unique()
    model_color_dict = dict(zip(models, sns.color_palette('Set1', n_colors=len(models))))

    for ax, score_type in zip(axes, score_types):
        # explode gives one row per distance but leaves an object-dtype column;
        # cast back to float so seaborn treats the values as numeric.
        data_long = df[['cat', score_type]].explode(score_type).reset_index(drop=True)
        data_long[score_type] = data_long[score_type].astype(float)

        for model in models:
            sns.histplot(data=data_long[data_long['cat'] == model], x=score_type, element='step',
                         stat='density', common_norm=False, fill=False, kde=True,
                         ax=ax, color=model_color_dict[model], label=model)

        ax.set_title(f'Distribution of {score_type}')
        ax.set_xlabel('Score')
        ax.set_ylabel('Density')
        ax.legend(title='Model', loc='upper right')

    # Adjust the spacing between subplots so titles and labels do not overlap.
    fig.tight_layout()
    plt.show()


In [None]:
# Three-panel (dTS / dST / dSS) view using the single-argument definition above.
# It must run before the next cell, which redefines plot_score_distributions(df, score_name).
plot_score_distributions(MINDIST)


In [None]:
def plot_score_distributions(df, score_name):
    """KDE of one distance score per model, with Real drawn last as a filled black curve.

    Redefines (and shadows) the earlier three-panel ``plot_score_distributions(df)``.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a 'cat' column (model name) and the ``score_name`` column.
        Cells of that column are flattened with ``DataFrame.explode``.
    score_name : str
        Column to plot, e.g. 'dST', 'dTS' or 'dSS'.

    Side effects: applies the seaborn "whitegrid" style globally, saves
    ``{score_name}_DISTRIBUTIONS.jpg`` into ``output_dir`` and shows the figure.
    """
    flattened_df = df[['cat', score_name]].explode(score_name).reset_index(drop=True)
    # explode leaves object dtype; make the scores explicitly numeric for the KDE.
    flattened_df[score_name] = flattened_df[score_name].astype(float)

    # Draw Real last so its filled reference curve sits on top.
    # Tolerate a frame without Real instead of raising from list.remove.
    unique_models = [m for m in flattened_df["cat"].unique() if m != 'Real']
    if (flattened_df["cat"] == 'Real').any():
        unique_models.append('Real')

    sns.set(style="whitegrid")
    plt.figure(figsize=(12, 8))

    for model_name in unique_models:
        model_rows = flattened_df[flattened_df["cat"] == model_name]
        if model_name == 'Real':
            sns.kdeplot(data=model_rows, x=score_name, label=model_name,
                        fill=True, common_norm=False, alpha=0.6, color='black')
        else:
            # Use the notebook-wide model colour (None -> seaborn's next colour).
            # The global palette is read, not mutated.
            sns.kdeplot(data=model_rows, x=score_name, label=model_name,
                        common_norm=False, alpha=1, linewidth=5,
                        color=color_palette.get(model_name))

    plt.xlabel(score_name)
    plt.ylabel("Density")
    plt.legend(title="Model Name", loc="upper right")
    plt.title(f"Distribution of {score_name} by Model")
    plt.savefig(os.path.join(output_dir, score_name + "_DISTRIBUTIONS.jpg"))
    plt.show()




In [None]:
# One figure per distance column of MINDIST, drawn with the two-argument
# plot_score_distributions; each call saves <score>_DISTRIBUTIONS.jpg.
plot_score_distributions(MINDIST, 'dST')

In [None]:
# dTS column of MINDIST
plot_score_distributions(MINDIST, 'dTS')

In [None]:
# dSS column of MINDIST (nearest-neighbour distance within the same dataset, cf. dSS_dic above)
plot_score_distributions(MINDIST, 'dSS')

In [None]:
# Wasserstein distance between each model's MINDIST distribution and Real's,
# for each of the three distance columns.
mindist_methods = ['dTS', 'dST', 'dSS']
# .iloc[0] is positional: the former label-based [0] only worked while Real's row had index 0.
real_by_method = {method: MINDIST.loc[MINDIST.cat == 'Real', method].iloc[0]
                  for method in mindist_methods}

mindist_rows = []
for model_name in models_to_data.keys():
    for method in mindist_methods:
        sc = wasserstein_distance(real_by_method[method],
                                  MINDIST.loc[MINDIST.cat == model_name, method].iloc[0])
        mindist_rows.append({'stat': 'wasserstein', 'statistic': sc,
                             'label': model_name, 'comparaison': method})
W = pd.DataFrame(mindist_rows, columns=['stat', 'statistic', 'label', 'comparaison'])
# `scores` is intentionally NOT overwritten here. The next cell appends W to the
# pairwise-distance scores; overwriting made that cell contain W twice.

In [None]:
# Combine the pairwise-distance scores with the MINDIST Wasserstein scores and save them.
scores = pd.concat([scores, W], ignore_index=True)
scores.to_csv(os.path.join(output_dir, "scores_pairwise_distances.csv"), index=False)

# Width scales with the number of models. Use models_to_data directly rather than
# `model_names`, which is defined in another cell.
plt.figure(figsize=(1.5 * len(models_to_data), 6))

# Long format: one row per (cat, AA component); rename title-cases the column names.
aa_long = (AA.drop(columns=['PrivacyLoss', 'ref'], errors='ignore')
           .melt(id_vars='cat')
           .rename(columns=str.title))
sns.barplot(x='Cat', y='Value', hue='Variable', palette=sns.color_palette('colorblind'), data=aa_long)
plt.axhline(0.5, color='black')
if 'Real_test' in AA.cat.values:
    # np.float was removed in NumPy 1.24; take the scalar explicitly and use the builtin float.
    real_test_aats = float(AA.loc[AA.cat == 'Real_test', 'AATS'].iloc[0])
    plt.axhline(real_test_aats, color=sns.color_palette()[0], ls='--')
plt.ylim(0, 1.1)
plt.title("Nearest Neighbor Adversarial Accuracy on training (AATS) and its components")
plt.savefig(os.path.join(output_dir, "AATS_scores.pdf"))

Test = '_Test2'
Train = ''  # means Training set is Real
dfPL = plotPrivacyLoss(Train, Test, output_dir, color_palette, color_palette)

Test = '_Test2'
Train = '_Test1'
dfPL = plotPrivacyLoss(Train, Test, output_dir, color_palette, color_palette)

In [None]:

def plot_3corr(x, y, keys, statname, col, ax=None):
    """
    Scatter y against x with a y = x reference line and return their Pearson correlation.

    Parameters
    ----------
    x : array-like of float
        Statistic values in the reference dataset ``keys[0]`` (horizontal axis).
    y : array-like of float
        The same statistic in dataset ``keys[1]``, paired element-wise with ``x``.
    keys : sequence of two str
        ``[reference_name, compared_name]``, used in the axis labels and legend entry.
    statname : str
        'Allele frequency', 'LD', '3 point correlation', etc.; used in the axis labels.
    col : str
        Matplotlib colour code for the markers.
    ax : matplotlib Axes, optional
        Axes to draw on; defaults to ``plt.subplot(1, 1, 1)``.

    Returns
    -------
    float
        Pearson correlation coefficient between ``x`` and ``y``.
    """
    r, _ = pearsonr(x, y)
    # The y = x line spans both variables so it crosses the whole point cloud.
    # (The former unused statsmodels OLS fit, with endog/exog swapped, was removed.)
    lims = [min(np.min(x), np.min(y)), max(np.max(x), np.max(y))]
    if ax is None:
        ax = plt.subplot(1, 1, 1)
    # Large clouds get some transparency so overlapping points stay visible.
    alpha = 1 if len(x) < 100 else .6
    ax.plot(x, y, label=f"{keys[1]}: cor={round(r, 2)}", c=col, marker='o', lw=0, alpha=alpha)
    ax.plot(lims, lims, ls='--', alpha=1, c='black')
    ax.set_xlabel(f'{statname} in {keys[0]}')
    ax.set_ylabel(f'{statname} in {keys[1]}')

    return r

# Three-Point Correlation Test

In [None]:
# Datasets compared in the 3-point correlation test: Real plus two synthetic models.
reduced_dataset = {
    dataset_name: datasets[dataset_name]
    for dataset_name in ('Real', 'GAN 2019 Retrain', 'Genome-AC-GAN By Continental Population')
}

In [None]:
def get_counts(haplosubset, points):
    """Count the distinct allele patterns over the SNP rows ``points``.

    Parameters
    ----------
    haplosubset : numpy.ndarray
        2-D array whose rows are SNPs and whose columns are haplotypes. Each value
        must render via ``str`` as a single '0'/'1' character (e.g. 0/1 integers),
        otherwise ``get_frequencies`` cannot parse the patterns.
    points : list of int
        Row indices (SNPs) forming the pattern.

    Returns
    -------
    tuple(numpy.ndarray of str, numpy.ndarray of int)
        Output of ``np.unique(..., return_counts=True)``: the distinct patterns
        (e.g. '010') and how many columns carry each one.
    """
    patterns = np.apply_along_axis(
        lambda column: ''.join(map(str, column[points])),
        0, haplosubset)
    return np.unique(patterns, return_counts=True)


def get_frequencies(counts):
    """Convert pattern counts from ``get_counts`` into a joint frequency table.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(2,) * l`` (``l`` = pattern length); entry ``[a, b, ...]``
        is the fraction of haplotypes carrying that allele combination.
    """
    patterns, pattern_counts = counts
    n_haplotypes = np.sum(pattern_counts)
    freqs = np.zeros(shape=[2] * len(patterns[0]))
    for pattern, count in zip(patterns, pattern_counts):
        freqs[tuple(int(allele) for allele in pattern)] = count / n_haplotypes
    return freqs


def three_points_cor(haplosubset, out='all'):
    """Three-point correlation of a SNP triplet for every allele combination.

    For each (a, b, c) in {0, 1}^3 this computes
        F012[a,b,c] - F01[a,b]*F2[c] - F12[b,c]*F0[a] - F02[a,c]*F1[b] + 2*F0[a]*F1[b]*F2[c]
    where Fxy... are the (joint) allele frequencies of SNP rows 0, 1 and 2.

    Parameters
    ----------
    haplosubset : numpy.ndarray
        Array with the three SNPs as rows 0-2 and haplotypes as columns
        (callers pass ``data[:, snps].T``).
    out : {'all', 'mean', 'max'}
        'all' returns the list of 8 values, 'mean' their mean and 'max' their
        largest absolute value.

    Raises
    ------
    ValueError
        If ``out`` is not one of the accepted values. The former version returned
        the exception object instead of raising it.
    """
    if out not in ('all', 'mean', 'max'):
        raise ValueError(f"out={out} not recognized")

    F = dict()
    for points in ([0], [1], [2], [0, 1], [0, 2], [1, 2], [0, 1, 2]):
        F[''.join(map(str, points))] = get_frequencies(get_counts(haplosubset, points))

    cors = [
        F['012'][a, b, c] - F['01'][a, b] * F['2'][c] - F['12'][b, c] * F['0'][a] - F['02'][a, c] * F['1'][b] + 2 *
        F['0'][a] * F['1'][b] * F['2'][c] for a, b, c in itertools.product(*[[0, 1]] * 3)]
    if out == 'mean':
        return np.mean(cors)
    if out == 'max':
        return np.max(np.abs(cors))
    return cors


# def mult_three_point_cor(haplo, extra_sample_info, model_name, picked_three_points):
#    return [three_points_cor(haplo[np.ix_(snps,extra_sample_info.label==model_name)], out='all') for snps in picked_three_points]

import random  # explicit import: do not rely on the star import to provide `random`

# Set the seeds so that the same real individuals are subsampled (when needed),
# keeping scores consistent when a new model or a new summary statistic is added.
np.random.seed(3)
random.seed(3)

# Compute 3-point correlations for each dataset at several distances between SNPs.
# - gap: distance between the SNPs of a triplet, in number of SNPs.
# - A gap of -9 means triplets are chosen completely at random (no predefined distance).
# - For each gap, 'nsamplesets' triplets are drawn and shared by all datasets.
# - If datasets have different numbers of SNPs, sampling is restricted to the first
#   min_nsnp SNPs, which slightly favours the beginning of the chunk.
gap_vec = [1, 4, 16, 64, 256, 512, 1024, -9]
nsamplesets = 1000
min_nsnp = min(dat.shape[1] for dat in reduced_dataset.values())
cors_meta = dict()
for gap in gap_vec:
    print(f'\n gap={gap} SNPs', end=' ')
    if gap < 0:
        # pick 3 random snps
        picked_three_points = [random.sample(range(min_nsnp), 3) for _ in range(nsamplesets)]
    else:
        # pick 3 successive snps separated by 'gap' SNPs
        step = gap + 1
        try:
            picked_three_points = [np.asarray(random.sample(range(min_nsnp - 2 * step), 1)) + [0, step, 2 * step]
                                   for _ in range(nsamplesets)]
        except ValueError:
            # random.sample raises ValueError when min_nsnp - 2 * step <= 0,
            # i.e. there are not enough SNPs for this gap: skip it.
            continue

    cors = dict()
    for model_name in reduced_dataset.keys():
        print(model_name, end=' ')
        cors[model_name] = [three_points_cor(reduced_dataset[model_name][:, snps].T, out='all')
                            for snps in picked_three_points]
    cors_meta[gap] = cors.copy()

with open(os.path.join(output_dir, "3pointcorr.pkl"), "wb") as outfile:
    pickle.dump(cors_meta, outfile)

# One panel per gap on a 2-row grid: every synthetic model's correlations vs Real's.
n_cols = int(np.ceil(len(cors_meta) / 2))
plt.figure(figsize=(7 * len(cors_meta), 20))
plt.subplots_adjust(hspace=0.5)
for i, gap in enumerate(cors_meta.keys()):
    ax = plt.subplot(2, n_cols, i + 1)
    cors = cors_meta[gap]
    real = list(np.array(cors['Real']).flat)
    model_to_corr = {}
    for key, val in cors.items():
        if key == 'Real':
            continue
        val = list(np.array(val).flat)
        model_to_corr[key] = plot_3corr(x=real, y=val, keys=['Real', key],
                                        statname='Correlation', col=color_palette[key], ax=ax)
    # Shared axis cosmetics, applied once per panel after all models are drawn.
    ax.set_ylabel('Correlation In Synthetic', fontsize=30)
    ax.set_xlabel('Correlation In Real', fontsize=30)
    ax.set_xlim((-.1, .1))
    ax.set_ylim((-.1, .1))

    corr_size = str(gap) if gap > 0 else "Random"
    title_lines = [f"3point Correlation By {corr_size} SNPs"]
    for model_name, corr_value in model_to_corr.items():
        # Shortened model name for the title (previously computed but never used).
        model_name_display = model_name.replace("Population", "").replace("By ", "").strip()
        title_lines.append(f"{model_name_display}:{corr_value * 100: .1f}%")
    ax.set_title("\n".join(title_lines), fontsize=29, y=1.05, fontweight='bold')


plt.savefig(os.path.join(output_dir, '3point_correlations_fixlim.jpg'), bbox_inches='tight', dpi=300)