In [None]:
import os.path
from pathlib import Path

import plotly.express as px
import plotly.graph_objects as go
import plotly.subplots as sp

from utils.util import *
from scipy.stats import wasserstein_distance
%matplotlib inline

In [None]:
# Every figure below is written into this directory (created if missing).
output_dir = "population_analysis"
Path(output_dir).mkdir(parents=True, exist_ok=True)
# NOTE(review): `figwi` is not read by any visible cell (plot_allele_frequency defines its own local).
figwi = 20

In [None]:
# Global seaborn palette; `sns` presumably comes from `from utils.util import *` — verify.
color_palette = ["white", "blue", "red", "green", "yellow", "black", "pink"]
sns.set_palette(color_palette)
# Render a swatch of the active palette.
sns.palplot(sns.color_palette())

In [None]:
# Load the real 1000 Genomes haplotypes plus sample metadata (paths are relative to this notebook's folder).
real_data_1000_genome = load_real_data(hapt_genotypes_path=f"../{REAL_10K_SNP_1000G_PATH}",
                                       extra_data_path=f"../{REAL_EXTRA_DATA_PATH}")
# Drop samples whose super-population code lists several populations (contains a comma).
real_data_1000_genome = real_data_1000_genome[
    ~real_data_1000_genome['Superpopulation code'].str.contains(',', na=False)]

real_data_1000_genome

In [None]:
# SNP column names of the real data; get_relevant_columns comes from utils (its logic is not visible here).
genotypes = get_relevant_columns(input_df=real_data_1000_genome, input_columns=[])
relevant_columns = genotypes + ['Superpopulation code', 'Population code']
print(f"Number of SNPs: {len(genotypes)}")

In [None]:
# Separate the label columns from the genotype matrix.
pop = real_data_1000_genome['Superpopulation code']
sub_pop = real_data_1000_genome['Population code']
# .copy(): columns are added to this frame below; an explicit copy makes it independent of
# real_data_1000_genome and avoids pandas' SettingWithCopyWarning on the 'is_real' assignment.
real_data_1000_genome_genotypes = real_data_1000_genome[genotypes].copy()
# Re-key SNP columns as 0..N-1 so they line up with the synthetic files' positional columns.
# NOTE: this rebinds `genotypes` from SNP names to these integer positions; later cells use the integers.
genotypes = [genotype for genotype in range(real_data_1000_genome_genotypes.shape[1])]
real_data_1000_genome_genotypes.columns = genotypes

real_data_1000_genome_genotypes['is_real'] = 1
# Real genotypes + continental label, re-indexed 0..n-1.
real_data_1000_genome_genotypes_by_pop = real_data_1000_genome_genotypes.copy()
real_data_1000_genome_genotypes_by_pop['Superpopulation code'] = pop
real_data_1000_genome_genotypes_by_pop = real_data_1000_genome_genotypes_by_pop.reset_index(drop=True)

# Real genotypes + national label, re-indexed 0..n-1.
real_data_1000_genome_genotypes_by_sub_pop = real_data_1000_genome_genotypes.copy()
real_data_1000_genome_genotypes_by_sub_pop['Population code'] = sub_pop
real_data_1000_genome_genotypes_by_sub_pop = real_data_1000_genome_genotypes_by_sub_pop.reset_index(drop=True)

# Real genotypes with both labels (keeps the original sample index).
real_data_pop_ane_sub_pop = real_data_1000_genome_genotypes.copy()
real_data_pop_ane_sub_pop['Superpopulation code'] = pop
real_data_pop_ane_sub_pop['Population code'] = sub_pop
real_data_1000_genome.head()

In [None]:
# def prepare_synthetic_data(input_file_path, target_column, number_of_samples):
#     synthetic_pop_results = pd.read_csv(input_file_path, sep=' ', header=None)
#
#     category_counts = synthetic_pop_results[0].value_counts()
#     sample_counts = (category_counts / category_counts.sum() * (number_of_samples)).astype(int)
#     synthetic_pop_results = synthetic_pop_results.sample(frac=1).reset_index(drop=True)
#     # Sample rows from each category
#     synthetic_pop_results = synthetic_pop_results.groupby(0).apply(
#         lambda x: x.sample(sample_counts[x.name])).reset_index(drop=True)
#     pop = synthetic_pop_results[0]
#     synthetic_pop_results = synthetic_pop_results.drop(0, axis=1)
#     synthetic_pop_results.columns = [genotype for genotype in range(synthetic_pop_results.shape[1])]
#
#     synthetic_pop_results[target_column] = pop.str.replace('Fake_', "")
#     return synthetic_pop_results

def prepare_synthetic_data(input_file_path, target_column):
    """Load a space-separated synthetic ``.hapt`` file into a labelled genotype frame.

    Column 0 of the file holds the generated label (e.g. ``Fake_AFR``). Every other column
    is a genotype and is renumbered 0..N-1. The label, with its ``Fake_`` prefix removed,
    is stored in ``target_column``.
    """
    raw = pd.read_csv(input_file_path, sep=' ', header=None)
    labels = raw[0].str.replace('Fake_', "")
    genotype_matrix = raw.drop(columns=0)
    genotype_matrix.columns = list(range(genotype_matrix.shape[1]))
    genotype_matrix[target_column] = labels
    return genotype_matrix

def prepare_old_synthetic_data(input_file_path):
    """Load a legacy synthetic ``.hapt`` file and keep only its genotype matrix.

    Legacy files start with two label columns (0 and 1). Both are discarded and the
    remaining columns are renumbered 0..N-1.
    """
    raw = pd.read_csv(input_file_path, sep=' ', header=None)
    genotype_matrix = raw.drop(columns=[0, 1])
    genotype_matrix.columns = list(range(len(genotype_matrix.columns)))
    return genotype_matrix


In [None]:
# Synthetic samples generated per national population (labels go to 'Population code').
synthetic_sub_pop_results = prepare_synthetic_data(
    '../resource/Genome-AC-GAN By National Population genotypes.hapt', 'Population code')
synthetic_sub_pop_results

In [None]:
# Synthetic samples generated per continental super-population (labels go to 'Superpopulation code').
synthetic_pop_results = prepare_synthetic_data(
    '../resource/Genome-AC-GAN By Continental Population genotypes.hapt', 'Superpopulation code')
synthetic_pop_results

In [None]:
synthetic_pop_results.shape

In [None]:
# Stack real and synthetic samples. Synthetic rows have no 'is_real' column, so concat leaves NaN
# there and fillna(0) marks them as fake (NaN forces a float column: 1.0 real / 0.0 fake).
# NOTE: both inputs are 0-based, so the concatenated index contains duplicate labels.
real_with_fake_by_pop = pd.concat([real_data_1000_genome_genotypes_by_pop, synthetic_pop_results])
real_with_fake_by_pop['is_real'] = real_with_fake_by_pop['is_real'].fillna(0)
real_with_fake_by_pop

In [None]:
# Same combination at national-population level.
real_with_fake_by_sub_pop = pd.concat([real_data_1000_genome_genotypes_by_sub_pop, synthetic_sub_pop_results])
real_with_fake_by_sub_pop['is_real'] = real_with_fake_by_sub_pop['is_real'].fillna(0)
real_with_fake_by_sub_pop["Population code"].unique()

In [None]:
def plot_pca_real_vs_fake(df, population_col, is_real_col, pop_type, color, number_of_columns=5):
    """
    Plot, per population, a 2-D PCA of the real samples with that population's synthetic samples projected onto it.

    The PCA is fit on the real samples of each population only. The synthetic samples are
    transformed with that same fitted PCA, so both point clouds share the real data's axes.
    At most 5 rows of panels are drawn, so populations beyond 5 * number_of_columns are skipped.
    The figure is saved as ``pca2_on_{pop_type}.jpg`` under ``output_dir`` and then shown.

    Args:
    df (pd.DataFrame): The dataset containing the samples. The last two columns are assumed to be the
        label columns (``is_real_col`` and ``population_col``); every other column is a genotype.
    population_col (str): The name of the column containing the population code. Codes containing ',' are skipped.
    is_real_col (str): The name of the column containing the indicator for real (1) / fake (0) samples.
    pop_type (str): Label used in the output file name.
    color (str): Color of the synthetic points and of the title box edge.
    number_of_columns (int): Number of panel columns in the figure grid.
    """
    populations = [pop for pop in df[population_col].unique() if "," not in pop]
    if not populations:
        print("No populations to plot")
        return
    num_cols = number_of_columns
    num_rows = min(5, int(np.ceil(len(populations) / num_cols)))
    # squeeze=False keeps `axes` 2-D for every grid shape (one row and/or one column included),
    # so axes[row, col] below is always valid.
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(30, 6 * num_rows), squeeze=False)
    # fig.suptitle(f'PCA Population Comparison {pop_type} Population\n', fontsize=24, fontweight='bold')
    shown_populations = populations[:num_rows * num_cols]

    for pop_index, pop_name in enumerate(shown_populations):
        row, col = divmod(pop_index, num_cols)
        ax = axes[row, col]
        print(f"start calculating population: {pop_name}")
        # Get real and fake samples for population
        pop_df = df[df[population_col] == pop_name].reset_index(drop=True)
        real_samples = pop_df[pop_df[is_real_col] == 1].iloc[:, :-2].values
        fake_samples = pop_df[pop_df[is_real_col] == 0].iloc[:, :-2].values

        print(f"calculating PCA population: {pop_name}")
        # Fit PCA on real samples only
        pca_real = PCA(n_components=2)
        pca_real.fit(real_samples)
        pca_real_transformed = pca_real.transform(real_samples)
        ax.scatter(pca_real_transformed[:, 0], pca_real_transformed[:, 1], color='black', label='Real',
                   alpha=0.6, s=100)
        # transform() rejects an empty matrix; a population without synthetic rows shows only real points.
        if len(fake_samples):
            pca_fake_transformed = pca_real.transform(fake_samples)
            ax.scatter(pca_fake_transformed[:, 0], pca_fake_transformed[:, 1], color=color, alpha=0.5, s=100)

        title = ax.set_title(pop_name, fontsize=40, fontweight='bold', color="white")
        title.set_bbox({'facecolor': 'black', 'edgecolor': color, 'pad': 10})  # Add a box around the title
        print(f"finished calculating population: {pop_name}\n")

    # Hide grid cells that received no population instead of leaving empty framed axes.
    for unused_ax in axes.flat[len(shown_populations):]:
        unused_ax.set_visible(False)

    fig.savefig(os.path.join(output_dir, f"pca2_on_{pop_type}.jpg"), bbox_inches='tight', dpi=500)
    plt.show()


In [None]:
# Per-continent PCA panels: real samples (black) and synthetic samples (blue) in each continent's real-data PC space.
plot_pca_real_vs_fake(real_with_fake_by_pop.copy(), 'Superpopulation code', 'is_real', 'Continental', color='blue')


In [None]:
from sklearn.manifold import TSNE  # Import t-SNE

def plot_tsne_real_vs_fake(df, population_col, is_real_col, pop_type, color, number_of_columns=5):
    """
    Plot, per population, a 2-D t-SNE embedding of real and synthetic samples in one shared space.

    t-SNE cannot project new points (it has no ``transform``), so for each population the real and
    synthetic samples are embedded together in a single ``fit_transform`` call and then split.
    At most 5 rows of panels are drawn. The figure is saved as ``tsne_on_{pop_type}.jpg`` under
    ``output_dir`` and then shown.

    Args:
    df (pd.DataFrame): The dataset containing the samples. The last two columns are assumed to be the
        label columns (``is_real_col`` and ``population_col``); every other column is a genotype.
    population_col (str): The name of the column containing the population code. Codes containing ',' are skipped.
    is_real_col (str): The name of the column containing the indicator for real (1) / fake (0) samples.
    pop_type (str): Label used in the output file name.
    color (str): Color of the synthetic points and of the title box edge.
    number_of_columns (int): Number of panel columns in the figure grid.
    """
    populations = [pop for pop in df[population_col].unique() if "," not in pop]
    if not populations:
        print("No populations to plot")
        return
    num_cols = number_of_columns
    num_rows = min(5, int(np.ceil(len(populations) / num_cols)))
    # squeeze=False keeps `axes` 2-D for every grid shape.
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(30, 6 * num_rows), squeeze=False)
    shown_populations = populations[:num_rows * num_cols]

    for pop_index, pop_name in enumerate(shown_populations):
        row, col = divmod(pop_index, num_cols)
        ax = axes[row, col]
        print(f"Start calculating population: {pop_name}")

        # Get real and fake samples for population
        pop_df = df[df[population_col] == pop_name].reset_index(drop=True)
        real_samples = pop_df[pop_df[is_real_col] == 1].iloc[:, :-2].values
        fake_samples = pop_df[pop_df[is_real_col] == 0].iloc[:, :-2].values

        print(f"Calculating t-SNE for population: {pop_name}")

        # Two separate fit_transform calls would produce two unrelated coordinate systems, making the
        # overlay and the distance below meaningless. Embed both sets jointly, then split the rows.
        combined = np.vstack([real_samples, fake_samples])
        tsne = TSNE(n_components=2, init='pca', learning_rate='auto', random_state=42)  # Adjust parameters as needed
        embedded = tsne.fit_transform(combined)
        tsne_real_transformed = embedded[:len(real_samples)]
        tsne_fake_transformed = embedded[len(real_samples):]
        # Crude 1-D summary: both embedding coordinates are flattened into one sample per cloud.
        if len(tsne_real_transformed) and len(tsne_fake_transformed):
            wd = wasserstein_distance(tsne_real_transformed.flatten(), tsne_fake_transformed.flatten())
        else:
            wd = float('nan')

        ax.scatter(tsne_real_transformed[:, 0], tsne_real_transformed[:, 1], color='black', label='Real',
                   alpha=0.6, s=100)
        ax.scatter(tsne_fake_transformed[:, 0], tsne_fake_transformed[:, 1], color=color,
                   label=pop_name, alpha=0.5, s=100)

        title = ax.set_title(pop_name, fontsize=40, fontweight='bold', color="white")
        title.set_bbox({'facecolor': 'black', 'edgecolor': color, 'pad': 10})  # Add a box around the title
        print(f"Finished calculating population: {pop_name} - {round(wd, 3)}\n")

    # Hide grid cells that received no population.
    for unused_ax in axes.flat[len(shown_populations):]:
        unused_ax.set_visible(False)

    fig.savefig(os.path.join(output_dir, f"tsne_on_{pop_type}.jpg"), bbox_inches='tight', dpi=500)
    plt.show()


In [None]:
# t-SNE version of the continental comparison (currently disabled).
# plot_tsne_real_vs_fake(real_with_fake_by_pop.copy(), 'Superpopulation code', 'is_real', 'Continental', color='blue')


In [None]:
# Per-national-population PCA panels (synthetic points in green).
plot_pca_real_vs_fake(real_with_fake_by_sub_pop.copy(), 'Population code', 'is_real', 'National', color='green')

In [None]:
# Interactive plotly views of real and synthetic samples, both projected onto a PCA fit on the
# real samples only, so the two figures share the same axes.
n_components = 2
pca = PCA(n_components=n_components)

is_real_rows = real_with_fake_by_pop['is_real'] == 1
is_fake_rows = real_with_fake_by_pop['is_real'] == 0
# Select genotype columns by label. Positional slicing (iloc[:, :-2]) would also pick up 'is_real'
# if the next cell, which appends a 'Superpop Label' column, had already been run.
real_genotype_values = real_with_fake_by_pop[is_real_rows][genotypes].values
pca.fit(real_genotype_values)

components = pca.transform(real_genotype_values)
fig1 = px.scatter(components, x=0, y=1,
                  color=real_with_fake_by_pop[is_real_rows]["Superpopulation code"],
                  title="PCA By Super Population (Real)")

components = pca.transform(real_with_fake_by_pop[is_fake_rows][genotypes].values)
fig2 = px.scatter(components, x=0, y=1,
                  color=real_with_fake_by_pop[is_fake_rows]["Superpopulation code"],
                  title="PCA By Super Population (Synthetic)")

# The figures were previously built but never displayed (the cell ended in an assignment).
fig1.show()
fig2.show()

In [None]:
# Side-by-side figure: real (left) and synthetic (right) samples, colored by continent.
n_components = 2
pca = PCA(n_components=n_components)

# Assign unique numeric labels to 'Superpopulation code' (the marker colorscale needs numbers).
superpop_labels = np.unique(real_with_fake_by_pop['Superpopulation code'])
label_map = {label: i for i, label in enumerate(superpop_labels)}
# NOTE: adds a column to real_with_fake_by_pop in place; later cells pass label='Superpop Label'
# to split_by_target_column and rely on this column existing.
real_with_fake_by_pop['Superpop Label'] = real_with_fake_by_pop['Superpopulation code'].map(label_map)
print("label_map:", label_map)

real_rows = real_with_fake_by_pop[real_with_fake_by_pop['is_real'] == 1]
fake_rows = real_with_fake_by_pop[real_with_fake_by_pop['is_real'] == 0]

# Fit the PCA on real samples only and project BOTH sets with it. Refitting on the synthetic
# samples would give the right panel its own (sign-ambiguous) axes, so the panels would not be
# comparable. Genotype columns are selected by label rather than by position.
real_data = real_rows[genotypes].values
pca.fit(real_data)
real_components = pca.transform(real_data)

fake_data = fake_rows[genotypes].values
fake_components = pca.transform(fake_data)

# Pin both traces to the same color domain so a given color means the same continent in both panels.
shared_color_domain = dict(cmin=0, cmax=max(len(label_map) - 1, 1), colorscale='Viridis')

# Create subplots
fig = sp.make_subplots(rows=1, cols=2,
                       subplot_titles=("PCA By Continental Population (Real)", "PCA By Continental Population (Fake)"))
fig.update_layout(width=1000)
# Add scatter plot for real components
fig.add_trace(
    go.Scatter(x=real_components[:, 0], y=real_components[:, 1],
               mode='markers',
               marker=dict(color=real_rows['Superpop Label'],
                           colorbar=dict(title='Continental'), **shared_color_domain),
               showlegend=False),
    row=1, col=1
)

# Add scatter plot for fake components
fig.add_trace(
    go.Scatter(x=fake_components[:, 0], y=fake_components[:, 1],
               mode='markers',
               marker=dict(color=fake_rows['Superpop Label'],
                           colorbar=dict(title='Continental'), **shared_color_domain),
               showlegend=False),
    row=1, col=2
)

fig.write_image(os.path.join(output_dir, "Continental_pca_total.jpg"), format="jpeg")
fig.show()


In [None]:
def dataframe_to_dict(df, key_column, value_column):
    """Map each value of ``key_column`` to the ``value_column`` value on the same row.

    When a key appears on several rows, the last occurrence wins (plain dict assignment order).
    The two columns are zipped directly instead of using ``iterrows``. ``iterrows`` is much slower
    and builds one Series per row, which can upcast values (e.g. int to float in mixed-dtype frames).

    Args:
        df: source DataFrame.
        key_column: column whose values become the dict keys.
        value_column: column whose values become the dict values.

    Returns:
        dict mapping key -> value (empty for an empty frame).
    """
    return dict(zip(df[key_column], df[value_column]))

In [None]:
# Lookup: national population code -> super-population code (built from the real samples).
pop_to_super_pop = dataframe_to_dict(real_data_pop_ane_sub_pop, 'Population code', 'Superpopulation code')

In [None]:
import random
def plot_pca_with_colors(df, color_column, marker_column, file_name):
    """
    Plot a 2-component PCA of ``df``, colored by ``color_column`` and shaped by ``marker_column``.

    The PCA is fit on all rows of ``df``. The figure is written to ``output_dir/file_name``
    and then shown.

    Args:
        df: feature matrix (genotype columns only).
        color_column: per-row labels used for color. The codes 'AFR', 'AMR', 'EAS', 'EUR', 'SAS'
            get the fixed colors defined below.
        marker_column: per-row labels passed to plotly's ``symbol`` argument.
        file_name: image file name.
    """
    components = PCA(n_components=2).fit_transform(df)
    color_map = {'AFR': 'white', 'AMR': 'blue', 'EAS': 'green', 'EUR': 'yellow', 'SAS': 'red'}
    fig = px.scatter(components, x=0, y=1, color=color_column, size_max=100, color_continuous_scale='Viridis',
                     symbol=marker_column, symbol_sequence=['circle-open-dot'], color_discrete_map=color_map)

    fig.update_layout(width=800, height=600, showlegend=False)
    fig.update_xaxes(title="PCA1", showticklabels=False, showgrid=False, zeroline=True)
    fig.update_yaxes(title="PCA2", showticklabels=False, showgrid=False, zeroline=True)

    fig.write_image(os.path.join(output_dir, file_name), format="jpg", width=800, height=600)
    fig.show()


In [None]:
# Synthetic national-population samples: colored by their super-population (via pop_to_super_pop),
# with the national population code as the symbol grouping.
plot_pca_with_colors(synthetic_sub_pop_results.drop(["Population code"], axis=1),
                     synthetic_sub_pop_results["Population code"].replace(pop_to_super_pop),
                     synthetic_sub_pop_results["Population code"], "fake_pca_by_sub_pop_and_pop.jpg")

In [None]:
# Real samples: colored by super-population, grouped by national population. All genotype columns
# are used, as in the synthetic calls; columns 0 and 1 here are SNPs, not label columns, so dropping
# them would make the real PCA differ from the synthetic ones.
plot_pca_with_colors(real_data_pop_ane_sub_pop[genotypes],
                     real_data_pop_ane_sub_pop["Superpopulation code"], real_data_pop_ane_sub_pop["Population code"],
                     "real_pca_by_sub_pop_and_pop.jpg")

In [None]:
# Synthetic continental samples: the super-population code drives both color and symbol grouping.
plot_pca_with_colors(synthetic_pop_results.drop(["Superpopulation code"], axis=1),
                     synthetic_pop_results["Superpopulation code"],
                     synthetic_pop_results["Superpopulation code"], "fake_pca_by_cont.jpg")

In [None]:
print(real_data_pop_ane_sub_pop[genotypes])

In [None]:
def print_frequency_compression(current_df, target_column, genotype_columns=None):
    """
    Tabulate, per population and per Real/Fake type, the percentage of each genotype value.

    Despite the name, nothing is printed; the table is returned.

    Args:
        current_df: frame with genotype columns, an ``is_real`` column (1 = real, 0 = fake)
            and ``target_column``.
        target_column: column holding the population label. Labels containing ',' are skipped.
        genotype_columns: genotype column labels to count over. Defaults to the notebook-level
            ``genotypes`` list, which preserves the original behaviour.

    Returns:
        pd.DataFrame with one row per (population, type). It has one column per genotype value
        holding its percentage (rounded to 3 decimals), plus "Pop" and "Type" columns.
        Rows are sorted by the column for genotype value 0.
    """
    if genotype_columns is None:
        genotype_columns = genotypes
    rows = []
    for pop in current_df[target_column].unique():
        if "," in pop:
            continue
        # The population filter does not depend on is_real; compute it once per population.
        pop_df = current_df[current_df[target_column] == pop]
        for is_real in [0, 1]:
            subset = pop_df[pop_df["is_real"] == is_real][genotype_columns]
            uniques, counts = np.unique(subset, return_counts=True)
            total_calls = len(subset) * len(genotype_columns)
            percentages = {key: round(value, 3) for key, value in zip(uniques, 100 * counts / total_calls)}
            percentages["Pop"] = pop
            percentages["Type"] = "Real" if is_real == 1 else "Fake"
            rows.append(percentages)

    return pd.DataFrame(rows).sort_values(0)

In [None]:
allel_freq_df = print_frequency_compression(real_with_fake_by_pop, target_column='Superpopulation code')

fig, ax = plt.subplots(figsize=(16, 10))

# Reference line: share of genotype value 1 over all real samples (every population pooled).
real_genotypes = real_with_fake_by_sub_pop[real_with_fake_by_sub_pop["is_real"] == 1][genotypes]
uniques, counts = np.unique(real_genotypes, return_counts=True)
tmp_percentages = dict(zip(uniques, 100 * counts / (len(real_genotypes) * len(genotypes))))
# Round to 3 decimals like print_frequency_compression does for the bars. Rounding to an integer
# could shift the reference line by up to half a percentage point relative to the bars.
tmp_percentages = {key: round(values, 3) for key, values in tmp_percentages.items()}
ax.axhline(y=tmp_percentages[1], color='black', linestyle='--', linewidth=2)

# Mean percentage of genotype value 1 per ('Pop', 'Type'); each pair has exactly one row in
# allel_freq_df, so the mean only reshapes the table.
grouped = allel_freq_df.groupby(['Pop', 'Type'])[1].mean()

# After unstack the columns are sorted ('Fake', 'Real'), so Fake bars are red and Real bars blue.
grouped.unstack().plot(kind='bar', ax=ax, color=['red', 'blue'], width=0.6)

# set the title and axis labels (y axis: % of genotype calls equal to 1)
ax.set_title('Bar plot of Pop vs. 1, colored by Type')
ax.set_xlabel('Pop')
ax.set_ylabel('1')

plt.show()

In [None]:
def split_by_target_column(input_df: pd.DataFrame, target_column, label=None):
    """
    Split a combined real+synthetic frame into per-population real and synthetic frames.

    The caller's frame is not modified. The previous version called ``set_index(..., inplace=True)``
    on the argument itself.

    Args:
        input_df: frame with an ``is_real`` column (1 = real, 0 = fake) and ``target_column``.
        target_column: population column. It becomes the index of every returned frame.
        label: optional extra column to drop from the returned frames (besides ``is_real``).

    Returns:
        (real_split, fake_split): dicts mapping population -> DataFrame of that population's
        real / synthetic rows.

    Side effect: prints the mean and standard deviation of the per-population row counts.
    """
    indexed = input_df.set_index(target_column)
    category_counts = indexed.index.value_counts()
    print(category_counts.mean(), category_counts.std())

    columns_to_drop = ['is_real'] if label is None else ['is_real', label]
    real_split = {}
    fake_split = {}
    for pop_name in input_df[target_column].unique():
        pop_df = indexed[indexed.index == pop_name]
        is_real = pop_df["is_real"]
        real_split[pop_name] = pop_df[is_real == 1].drop(columns=columns_to_drop)
        fake_split[pop_name] = pop_df[is_real == 0].drop(columns=columns_to_drop)
    return real_split, fake_split

In [None]:
# NOTE(review): the variable names look swapped relative to the data. `*_by_sub_pop` is split by
# 'Superpopulation code' (continental) and `*_by_super_pop` by 'Population code' (national).
# Later cells use them consistently with the data (e.g. "maf_continental.jpg", ["SAS"]),
# so only the names are misleading.
# label='Superpop Label' drops the numeric helper column added by the side-by-side PCA cell.
real_split_by_sub_pop, fake_split_by_sub_pop = split_by_target_column(input_df=real_with_fake_by_pop.copy(),
                                                                      target_column="Superpopulation code", label='Superpop Label')

In [None]:
real_split_by_super_pop, fake_split_by_super_pop = split_by_target_column(input_df=real_with_fake_by_sub_pop.copy(),
                                                                          target_column="Population code")

In [None]:
# real_split[pop_name] = pop_df[pop_df["is_real"] == 1].drop('is_real', axis=1)
# fake_split[pop_name] = pop_df[pop_df["is_real"] == 0].drop('is_real', axis=1)

# build_allele_frequency comes from utils (not visible here). It is unpacked into three values,
# presumably (allele sums, allele frequencies, fixed-site flags) — verify against utils.
# `*_by_sub_pop` inputs were split by 'Superpopulation code', `*_by_super_pop` by 'Population code'.
fake_sum_alleles_by_sub_pop, fake_allele_frequency_by_sub_pop, _ = build_allele_frequency(fake_split_by_sub_pop)
fake_sum_alleles_by_super_pop, fake_allele_frequency_by_super_pop, _ = build_allele_frequency(fake_split_by_super_pop)
real_sum_alleles_by_sub_pop, real_allele_frequency_by_sub_pop, _ = build_allele_frequency(real_split_by_sub_pop)
real_sum_alleles_by_super_pop, real_allele_frequency_by_super_pop, _ = build_allele_frequency(real_split_by_super_pop)
# len(real_split_by_super_pop.keys())

In [None]:
from sklearn.metrics import mean_absolute_error, mean_squared_error


def plotreg(x, y, keys, pop_name, col, ax=None):
    """
    Scatter real (x) vs synthetic (y) allele frequencies for one population and return their correlation.

    The panel title shows the population, the Pearson correlation (as a percentage) and the mean
    absolute error. A dashed y = x line spans the range of ``x``.

    Parameters
    ----------
    x : array-like
        Allele frequencies in the real data.
    y : array-like
        Allele frequencies in the synthetic data, same length/order as ``x``.
    keys : sequence
        Accepted for call compatibility; not used.
    pop_name : str
        Population name shown in the title.
    col : str
        Marker color.
    ax : matplotlib Axes, optional
        Target axes; a new single subplot is created when omitted.

    Returns
    -------
    float
        Pearson correlation coefficient between ``x`` and ``y``.
    """
    diagonal = [np.min(x), np.max(x)]
    r, _ = pearsonr(x, y)
    mae = mean_absolute_error(x, y)
    target_ax = plt.subplot(1, 1, 1) if ax is None else ax

    target_ax.plot(x, y, c=col, marker='o', lw=0, markersize=10)
    target_ax.plot(diagonal, diagonal, ls='--', alpha=1, c='black')
    header = f'{pop_name}\nCorrelation={round(r * 100, 2)}%\nMAE={round(mae, 4)}'
    title = target_ax.set_title(header, fontsize=65, fontweight='bold', color='black', y=1.05)
    title.set_bbox({'facecolor': (0.9, 0.9, 0.9), 'edgecolor': "black", 'pad': 1.5})
    target_ax.set_xlabel("AF In Real", fontsize=50, fontweight="bold")
    target_ax.set_ylabel("AF In Synthetic", fontsize=50, fontweight="bold")

    return r

In [None]:
def plot_allele_frequency_fake_vs_real(real_input, fake_input, color, output_file_name, num_cols=5):
    """
    Grid of real-vs-synthetic allele frequency scatter plots, one panel per population.

    Args:
        real_input: mapping population -> real allele frequencies (x axis).
        fake_input: mapping population -> synthetic allele frequencies (y axis). Panels follow its order.
        color: marker color passed to ``plotreg``.
        output_file_name: image file written under ``output_dir``.
        num_cols: panels per row (default 5).
    """
    # Ceil division so every population gets a panel. Floor division gave 0 rows (an invalid
    # subplot grid) for fewer than num_cols populations and silently dropped the remainder otherwise.
    num_rows = max(1, -(-len(real_input) // num_cols))
    plt.figure(figsize=(10 * num_cols, 10 * num_rows))

    for i, (pop_name, fake_allele_frequency) in enumerate(fake_input.items()):
        # The grid is sized from real_input; ignore extra synthetic groups that don't fit.
        if i >= num_cols * num_rows:
            break
        ax = plt.subplot(num_rows, num_cols, i + 1)
        plotreg(x=real_input[pop_name], y=fake_allele_frequency,
                keys=['Real', "Synthetic"], pop_name=pop_name,
                col=color, ax=ax)
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, output_file_name), bbox_inches='tight', dpi=200)


In [None]:
# Continental level (the `*_by_sub_pop` dicts were split by 'Superpopulation code').
plot_allele_frequency_fake_vs_real(real_input=real_allele_frequency_by_sub_pop,
                                   fake_input=fake_allele_frequency_by_sub_pop,
                                   color="blue", output_file_name="maf_continental.jpg")

In [None]:
# National level (the `*_by_super_pop` dicts were split by 'Population code').
plot_allele_frequency_fake_vs_real(real_input=real_allele_frequency_by_super_pop,
                                   fake_input=fake_allele_frequency_by_super_pop,
                                   color="green", output_file_name="maf_national.jpg")

In [None]:
print(real_allele_frequency_by_sub_pop["SAS"])

In [None]:

import random
RANDOM_POSITIONS = 1000
def get_allele_frequency_plot(real_frequency, fake_frequency, output_file_name, seed=None):
    """
    Line plots of real vs synthetic allele frequency over a random subset of SNPs, one row per population.

    Up to 5 populations are drawn. When more than 5 are available, only the fixed showcase set
    ['KHV', 'ACB', 'ESN', 'CDX', 'GBR'] is considered. Each panel title reports the 1-D Wasserstein
    distance between the real and synthetic frequencies of the sampled SNPs.

    Args:
        real_frequency: DataFrame indexed by SNP position with one column per population (real data).
        fake_frequency: DataFrame with the same layout for the synthetic data.
        output_file_name: image file written under ``output_dir``.
        seed: optional seed making the population/SNP sampling reproducible. With None the global
            ``random`` module state is used.
    """
    rng = random if seed is None else random.Random(seed)

    population_names = real_frequency.columns.tolist()
    if len(population_names) > 5:
        population_names = [pop for pop in population_names if pop in ['KHV', 'ACB', 'ESN', 'CDX', 'GBR']]
    population_names = rng.sample(population_names, min(len(population_names), 5))
    if not population_names:
        print("No populations to plot")
        return
    snp_positions = real_frequency.index.tolist()
    # Cap the sample size: random.sample raises ValueError when asked for more items than exist.
    random_snp_positions = rng.sample(snp_positions, min(RANDOM_POSITIONS, len(snp_positions)))

    # squeeze=False keeps `axes` 2-D (n_populations x 1), even for a single population.
    fig, axes = plt.subplots(figsize=(60, 60), nrows=len(population_names), sharey=True, squeeze=False)

    for i, population_name in enumerate(population_names):
        ax = axes[i, 0]
        real_pop_df = real_frequency[population_name]
        fake_pop_df = fake_frequency[population_name]
        real_frequency_filtered = real_pop_df[real_pop_df.index.isin(random_snp_positions)]
        fake_frequency_filtered = fake_pop_df[fake_pop_df.index.isin(random_snp_positions)]

        # Synthetic in red underneath, real in black on top.
        ax.plot(fake_frequency_filtered.index, fake_frequency_filtered, label='Synthetic', color='red', linestyle='-', linewidth=15)
        ax.plot(real_frequency_filtered.index, real_frequency_filtered, label='Real', color="black", linestyle='-', alpha=0.8, linewidth=15)

        wasserstein_distance_value = wasserstein_distance(real_frequency_filtered, fake_frequency_filtered)
        title = ax.set_title(
            f"{population_name}: Wasserstein Distance = {round(wasserstein_distance_value, 3)},  Number Of SNPs = {len(real_frequency_filtered)}",
            fontsize=100, fontweight='bold', color='white', y=1.05)  # white title text...
        title.set_bbox({'facecolor': 'black', 'edgecolor': "black", 'pad': 2})  # ...on a black box

        ax.set_xlabel('SNP', fontsize=70, fontweight='bold')
        ax.set_ylabel('MAF', fontsize=70, fontweight='bold')
        # ax.legend(loc='upper left', fontsize=40)
        ax.grid(True, color='gray')
        # Dashed vertical separators every 1000 SNP positions.
        for j in range(0, int(max(real_frequency_filtered.index)), 1000):
            ax.axvline(x=j, color='black', linestyle='--', alpha=0.8, linewidth=10)
        ax.tick_params(axis='x', labelsize=60)
        ax.tick_params(axis='y', labelsize=60)
    plt.tight_layout()

    # Save the figure
    plt.savefig(os.path.join(output_dir, output_file_name), bbox_inches='tight', dpi=300)


In [None]:
# The continental split was made with label='Superpop Label', which already drops that column,
# so it may not appear as a row here. errors='ignore' turns the drop into a no-op instead of a KeyError.
get_allele_frequency_plot(real_frequency=pd.DataFrame(real_allele_frequency_by_sub_pop).drop('Superpop Label', axis=0, errors='ignore'),
                          fake_frequency=pd.DataFrame(fake_allele_frequency_by_sub_pop).drop('Superpop Label', axis=0, errors='ignore'),
                          output_file_name="allele_frequency_by_continental.jpg")

In [None]:
# National level; this split was made without a label column, so there is nothing extra to drop.
get_allele_frequency_plot(real_frequency=pd.DataFrame(real_allele_frequency_by_super_pop),
                          fake_frequency=pd.DataFrame(fake_allele_frequency_by_super_pop),
                          output_file_name="allele_frequency_by_National.jpg")

In [None]:
real_data_pop_ane_sub_pop["Population code"]

In [None]:
# Number of real samples per national population.
sample_count_per_population = real_data_pop_ane_sub_pop["Population code"].value_counts()
print(sample_count_per_population)

In [None]:
# Real-sample counts per (population, super-population) combination, with their full names.
grouped_df = (
    real_data_1000_genome
    .groupby(["Population code", "Population name", "Superpopulation code", "Superpopulation name"])
    .size()
    .reset_index(name="Superpopulation code count")
)

# Show a handful of populations of interest.
selected_codes = ["MSL", "FIN", "PJL", "BEB", "CLM"]
print(grouped_df[grouped_df["Population code"].isin(selected_codes)])

In [None]:
# Real-sample counts at both levels: national populations and continental super-populations.
population_counts = real_data_1000_genome["Population code"].value_counts().reset_index()
population_counts.columns = ["Population code", "Population code count"]

superpopulation_counts = real_data_1000_genome["Superpopulation code"].value_counts().reset_index()
superpopulation_counts.columns = ["Superpopulation code", "Superpopulation code count"]

# Population and super-population codes are expected to be disjoint, so the outer merge stacks the
# two tables: each row has only one of the two count columns filled.
combined_counts = population_counts.merge(superpopulation_counts, how="outer", left_on="Population code",
                                          right_on="Superpopulation code")
# Carry the super-population codes into the shared code column before dropping theirs.
# Otherwise the super-population rows would lose their label (Population code would be NaN).
combined_counts["Population code"] = combined_counts["Population code"].fillna(
    combined_counts["Superpopulation code"])
combined_counts = combined_counts.drop("Superpopulation code", axis=1)

# Print the resulting DataFrame
print(combined_counts)


In [None]:
from sklearn.manifold import TSNE  # Import t-SNE

def calculate_2d_wasserstein_distance(d1, d2):
    """
    Approximate the Wasserstein-1 (earth mover's) distance between two 2-D point clouds.

    Points are matched one-to-one by solving the linear assignment problem on the pairwise
    Euclidean distance matrix, and the result is the mean cost of the matched pairs. For two
    clouds of equal size with uniform weights this is the exact W1 distance. With unequal sizes,
    only min(len(d1), len(d2)) points are matched, so the value is an approximation.

    Args:
        d1: array-like of shape (n, k).
        d2: array-like of shape (m, k).

    Returns:
        float: mean matched distance (0.0 when either cloud is empty).
    """
    # Imported locally so the function does not depend on a later notebook cell having run.
    from scipy.optimize import linear_sum_assignment
    from scipy.spatial.distance import cdist

    distance_matrix = cdist(d1, d2)

    # Solve the linear assignment problem to find the optimal transport plan
    row_indices, col_indices = linear_sum_assignment(distance_matrix)
    if len(row_indices) == 0:
        return 0.0
    # Mean, not sum: the sum grows with the number of points (n * W1), so it was not comparable
    # across datasets of different sizes.
    return float(distance_matrix[row_indices, col_indices].mean())

def tsne_real_vs_fake(dfs, population_col, is_real_col, pop_type, color, number_of_columns=5):
    """
    Compare the synthetic samples of several datasets (e.g. checkpoints) against the real data in one PCA space.

    Despite the name, the embedding is PCA; t-SNE is not used. A 2-component PCA is fit on the real
    samples of the first dataset. Every dataset's synthetic samples are projected with that same
    PCA, so all panels share the real data's axes. Each panel title reports the matched-pair
    distance from ``calculate_2d_wasserstein_distance``. At most 5 rows of panels are drawn.

    Args:
    dfs (list[dict[str, pd.DataFrame]]): one ``{name: frame}`` dict per dataset. Each frame holds the
        genotype columns followed by two label columns (``is_real_col`` and the population column).
        All frames are assumed to contain the same real samples (the reference comes from the first).
    population_col (str): name of the population column (kept for call compatibility; unused).
    is_real_col (str): column with 1 for real and 0 for synthetic rows.
    pop_type (str): used in the output file name.
    color (str): color of the synthetic points and of the title box edges.
    number_of_columns (int): panels per row.
    """
    if not dfs:
        print("No datasets to plot")
        return
    num_cols = number_of_columns
    num_rows = min(5, int(np.ceil(len(dfs) / num_cols)))
    # squeeze=False keeps `axes` 2-D for every grid shape.
    fig, axes = plt.subplots(num_rows, num_cols, figsize=(30, 6 * num_rows), squeeze=False)

    reference_df = next(iter(dfs[0].values()))
    real_samples = reference_df[reference_df[is_real_col] == 1].iloc[:, :-2].values
    pca_real = PCA(n_components=2)
    pca_real.fit(real_samples)
    real_transformed = pca_real.transform(real_samples)

    for df_index, name_to_df in enumerate(dfs):
        row, col = divmod(df_index, num_cols)
        if row >= num_rows:
            break
        ax = axes[row, col]
        for df_name, df in name_to_df.items():
            fake_samples = df[df[is_real_col] == 0].iloc[:, :-2].values
            if len(fake_samples) == 0:
                continue
            # Project with the real-data PCA instead of refitting on the synthetic samples. A separate
            # fit gives a different, sign-ambiguous coordinate system, which would make the overlay and
            # the distance incomparable across datasets.
            fake_transformed = pca_real.transform(fake_samples)
            wd = calculate_2d_wasserstein_distance(real_transformed, fake_transformed)

            ax.scatter(real_transformed[:, 0], real_transformed[:, 1], color='black', label='Real',
                       alpha=0.6, s=100)
            ax.scatter(fake_transformed[:, 0], fake_transformed[:, 1], color=color, alpha=0.5, s=100)

            title = ax.set_title(f"{df_name} - wd: {round(wd, 3)}", fontsize=40, fontweight='bold', color="white")
            title.set_bbox({'facecolor': 'black', 'edgecolor': color, 'pad': 10})  # Add a box around the title

    # Hide grid cells that received no dataset.
    for unused_ax in axes.flat[min(len(dfs), num_rows * num_cols):]:
        unused_ax.set_visible(False)

    # NOTE(review): same file name as plot_tsne_real_vs_fake's output, so one overwrites the other.
    fig.savefig(os.path.join(output_dir, f"tsne_on_{pop_type}.jpg"), bbox_inches='tight', dpi=500)
    plt.show()


In [None]:
# Load the synthetic output of each saved checkpoint (800, 850, ..., 1050) and pair it with the real data.
full_datasets = {}
dfs = []
for idx in range(800, 1100, 50):
    tmp_df = prepare_synthetic_data(
        f'../fake_genotypes_sequences/new_sequences/super_pop_wd/{idx}_genotypes.hapt', 'Superpopulation code')
    tmp_df_with_real = pd.concat([real_data_1000_genome_genotypes_by_pop, tmp_df])
    tmp_df_with_real['is_real'] = tmp_df_with_real['is_real'].fillna(0)
    dfs.append({str(idx): tmp_df_with_real})
    # Keep each checkpoint's synthetic genotype matrix (every column but the trailing label) for the
    # allele-frequency comparison. plot_allele_frequency draws one panel per non-'Real' key, so
    # without these entries its figures were empty.
    full_datasets[str(idx)] = tmp_df.iloc[:, :-1].values

In [None]:
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist
# One panel per checkpoint, each compared against the real samples.
tsne_real_vs_fake(dfs, 'Superpopulation code', 'is_real', 'Continental', color='blue')


In [None]:
# Reference real genotype matrix for the allele-frequency comparison. It is taken straight from the
# real data (genotype columns only; the last two columns are 'is_real' and the label), so this cell
# does not depend on the loop variable `idx` leaking from an earlier cell or on an extra synthetic file.
full_datasets["Real"] = real_data_1000_genome_genotypes_by_pop.iloc[:, :-2].values

In [None]:
# Allele statistics (from utils) for every entry of full_datasets: the 'Real' reference plus any model entries.
sum_alleles_by_position, allele_frequency, is_fixed = build_allele_frequency(full_datasets)

In [None]:
from sklearn.metrics import mean_absolute_error, mean_squared_error


def plotreg(x, y, keys, statname, col, model_name_display, ax=None):
    """
    Scatter real (x) vs synthetic (y) minor allele frequencies against a dashed y = x reference line.

    NOTE(review): this redefines the earlier ``plotreg(x, y, keys, pop_name, col, ax=None)``. Cells
    above that call it with ``pop_name=`` raise TypeError if re-run after this cell.

    Parameters
    ----------
    x : array-like
        MAF values in the real data.
    y : array-like
        MAF values in the synthetic data, same length/order as ``x``.
    keys : sequence
        Unused; kept for call compatibility.
    statname : str
        Unused; kept for call compatibility (e.g. 'Allele frequency').
    col : str
        Marker color.
    model_name_display : str
        Text shown as the first line of the panel title.
    ax : matplotlib Axes, optional
        Target axes; a new single subplot is created when omitted.

    Returns
    -------
    r : float
        Pearson correlation coefficient between x and y.
    mae : float
        Mean absolute error between x and y.
    """
    lims = [np.min(x), np.max(x)]
    r, _ = pearsonr(x, y)
    mae = mean_absolute_error(x, y)

    if ax is None:
        ax = plt.subplot(1, 1, 1)

    # Points are drawn once; they are translucent when there are many, so dense regions stay readable.
    # (The previous version plotted everything twice and fitted an unused OLS regression.)
    alpha = 1 if len(x) < 100 else .6
    ax.plot(x, y, c=col, marker='o', lw=0, alpha=alpha)
    ax.plot(lims, lims, ls='--', alpha=1, c='black')
    title = ax.set_title(
        f'{model_name_display}\nCorrelation={round(r * 100, 2)}%\nMAE={round(mae, 4)}',
        fontsize=35, fontweight="bold", y=1, color='black')
    title.set_bbox({'facecolor': (0.9, 0.9, 0.9), 'edgecolor': "black", 'pad': 1.5})
    ax.set_xlabel("MAF In Real", fontsize=25, fontweight="bold")
    ax.set_ylabel("MAF In Synthetic", fontsize=25, fontweight="bold")

    return r, mae


def plot_allele_frequency(allele_frequency, file_name, maf, highest=False):
    """
    Plot real vs synthetic allele frequency for every model in ``allele_frequency`` except 'Real'.

    Args:
        allele_frequency: mapping model name -> per-SNP frequencies. It must contain a 'Real' entry,
            which is the x axis of every panel.
        file_name: image file written under ``output_dir``.
        maf: frequency threshold. SNPs with Real frequency <= maf are kept; with ``highest=True``,
            SNPs with Real frequency >= 1 - maf are kept instead.
        highest: select the high-frequency tail instead of the low one.
    """
    columns_per_row = 20
    model_names = [name for name in allele_frequency if name != 'Real']
    num_rows = max(1, -(-len(model_names) // columns_per_row))
    plt.figure(figsize=(150, 6 * num_rows))

    if highest:
        keep = (allele_frequency['Real'] >= 1 - maf)
    else:
        keep = (allele_frequency['Real'] <= maf)

    for panel_index, model_name in enumerate(model_names):
        val = allele_frequency[model_name]
        model_name_display = model_name.replace("Population", "").replace(" By", "").replace("Genome-AC-GAN",
                                                                                             "Genome-AC-GAN\n")
        # One fixed grid for all panels. The old per-panel row count (i // 20 + 1) placed panels past
        # the 20th on top of the first row.
        ax = plt.subplot(num_rows, columns_per_row, panel_index + 1)
        plotreg(x=allele_frequency['Real'][keep], y=val[keep],
                keys=['Real', model_name_display], statname="Allele frequency",
                col="blue", model_name_display=model_name_display, ax=ax)

    plt.savefig(os.path.join(output_dir, file_name), bbox_inches='tight', dpi=300)

In [None]:
# maf=1 keeps every SNP (frequencies are at most 1).
plot_allele_frequency(allele_frequency, 'total_allele_frequency.jpg', 1)


In [None]:
# Low-frequency tail: SNPs whose real frequency is <= 0.2.
plot_allele_frequency(allele_frequency, 'zoom_lowest_total_allele_frequency.jpg', 0.2)

In [None]:
# High-frequency tail: SNPs whose real frequency is >= 1 - 0.2 = 0.8.
plot_allele_frequency(allele_frequency, 'zoom_highest_total_allele_frequency.jpg', 0.2, highest=True)
