In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

import glob

import seaborn as sns

# Locations of the per-gene selection tables, keyed by the name of the estimate.
s_het_paths = {
    'Weghorn-drift': ".../450k/selection_weghorn/weghorn_drift_gencode-v34.txt",
    'Cassa': ".../450k/selection_cassa/cassa_supp_table_1_gencode-v34.txt",
    'PLI': ".../450k/selection_pli/gnomad.v2.1.1.PLI_gencode-v34.txt",
    "Roulette": ".../450k/selection_roulette/s_het_roulette_gencode-v34.csv",
}

# Gene panel: a headerless table whose two columns are named here.
gene_panel = pd.read_csv(".../450k/regions/gene-panel-gencode-v34.txt", header=None)
gene_panel.columns = ['Gene name', 'Gene panel']

# Keep the untouched labels in a separate column, then collapse every label
# other than 'ID-total' into the single bucket 'AR_without_ID'.
is_id_total = gene_panel['Gene panel'] == 'ID-total'
gene_panel['Gene panel original'] = gene_panel['Gene panel'].copy()
gene_panel['Gene panel'] = gene_panel['Gene panel'].where(is_id_total, 'AR_without_ID')

gene_panel.tail(3)

In [None]:
# Attach the panel label to every gene in the Roulette s_het table; the inner
# join drops genes that are absent from either side.
kept_panels = ['ID-total', 'AR_without_ID']
panel_genes = gene_panel[gene_panel['Gene panel'].isin(kept_panels)]

roulette_s_hets = pd.read_csv(s_het_paths['Roulette'], sep='\t').merge(
    panel_genes,
    how='inner',
    left_on='gene_symbol',
    right_on='Gene name',
)
roulette_s_hets['Gene panel'].value_counts()

In [None]:
# One point per gene: s_het on the y-axis, split by gene panel on the x-axis.
# Jitter and transparency keep overlapping points distinguishable.
sns.stripplot(
    data=roulette_s_hets,
    x='Gene panel',
    y='s_het',
    jitter=True,
    alpha=0.5,
)

In [None]:
# Average s_het within each gene panel, before any sampling.
s_het_by_panel = roulette_s_hets.groupby('Gene panel')['s_het']
s_het_by_panel.mean()

In [None]:
# Build s_het-matched gene panels.
#
# Genes are grouped into s_het bins. Within every bin that contains both
# panels, the same number of genes (the size of the smaller panel in that bin)
# is drawn from 'ID-total' once and from 'AR_without_ID' `n_samples` times.
# Each AR draw is stored as its own panel 'AR_without_ID_<idx>'.
# Every draw is seeded from `random_state`, so re-running the cell reproduces
# exactly the same panels.

n_samples = 20      # number of independent AR_without_ID draws
random_state = 42   # base seed for all draws in this cell
# random_state = 1

# Hand-picked genes that are appended to the result after the bin-matched
# sampling: the ID genes once, the non-ID genes once per AR replicate.
accepted_id_list = ['EIF4A3', 'SOBP']
accepted_non_id_list = ['PRKDC']
accepted_panels = [roulette_s_hets[roulette_s_hets['Gene name'].isin(accepted_id_list)].copy()]

accepted_non_id = roulette_s_hets[roulette_s_hets['Gene name'].isin(accepted_non_id_list)]
for idx in range(n_samples):
    sample = accepted_non_id.copy()
    sample['Gene panel'] = f'AR_without_ID_{idx}'

    accepted_panels.append(sample)

accepted_panels = pd.concat(accepted_panels)

# Define bin edges: step 0.05 starting at 0, stopping before 0.9.
bin_edges = np.arange(0, 0.9, 0.05)

# Assign each row to a left-closed bin [a, b); values that fall in no bin get NaN.
roulette_s_hets['s_het_bin'] = pd.cut(roulette_s_hets['s_het'], bins=bin_edges, include_lowest=True, right=False)

sampled_data = []

# Rows without a bin (NaN) can never match a bin label, so they are dropped
# up front instead of being iterated and skipped.
for s_het_bin in roulette_s_hets['s_het_bin'].dropna().unique():
    bin_data = roulette_s_hets[roulette_s_hets['s_het_bin'] == s_het_bin].copy()

    # Matching needs both panels to be present in the bin.
    if bin_data['Gene panel'].unique().size < 2:
        continue

    panel_sizes = bin_data.groupby('Gene panel').size()
    min_size = int(panel_sizes.min())

    id_genes = bin_data[bin_data['Gene panel'] == 'ID-total']
    ar_genes = bin_data[bin_data['Gene panel'] == 'AR_without_ID']

    sampled_min_size = [id_genes.sample(min_size, random_state=random_state)]

    for idx in range(n_samples):
        # A distinct seed per replicate: the draws differ from each other
        # but are identical across runs.
        sample = ar_genes.sample(min_size, random_state=random_state + idx).copy()
        sample['Gene panel'] = f'AR_without_ID_{idx}'

        sampled_min_size.append(sample)

    sampled_data.append(pd.concat(sampled_min_size))

    print(s_het_bin, panel_sizes['ID-total'], panel_sizes['AR_without_ID'])

sampled_data = pd.concat(sampled_data)
sampled_data = pd.concat([sampled_data, accepted_panels])

# Final labels: text before the first '-' plus '_sampled'
# ('ID-total' -> 'ID_sampled', 'AR_without_ID_3' -> 'AR_without_ID_3_sampled').
# Double quotes outside, single quotes inside: reusing the same quote type
# inside an f-string is a SyntaxError before Python 3.12.
rename_dict = {k: f"{k.split('-')[0]}_sampled" for k in sampled_data['Gene panel'].unique()}
sampled_data['Gene panel'] = sampled_data['Gene panel'].map(rename_dict)

In [None]:
# s_het of every sampled gene, one horizontal strip per sampled panel.
sns.stripplot(
    data=sampled_data,
    x='s_het',
    y='Gene panel',
    jitter=True,
    alpha=0.5,
)
# plt.xticks(rotation=45)

In [None]:
# Mean and median s_het per sampled panel, side by side.
# Each estimator gets its own Axes: two sns.barplot calls without an explicit
# `ax` draw onto the same Axes, so the second set of bars covers the first.
fig, (ax_mean, ax_median) = plt.subplots(1, 2, figsize=(12, 6), sharey=True)

sns.barplot(data=sampled_data, y='Gene panel', x='s_het', estimator=np.mean, ax=ax_mean)
ax_mean.set_title('Mean s_het per gene panel')

sns.barplot(data=sampled_data, y='Gene panel', x='s_het', estimator=np.median, ax=ax_median)
ax_median.set_title('Median s_het per gene panel')

fig.tight_layout()

In [None]:
# Average s_het within each sampled panel.
sampled_s_het_by_panel = sampled_data.groupby('Gene panel')['s_het']
sampled_s_het_by_panel.mean()

In [None]:
# Number of rows in each sampled panel.
sampled_panel_labels = sampled_data['Gene panel']
sampled_panel_labels.value_counts()

In [None]:
# Number of sampled rows per original (pre-collapse) panel label.
original_panel_labels = sampled_data['Gene panel original']
original_panel_labels.value_counts()

In [None]:
# Number of sampled rows per original (pre-collapse) panel label.
# NOTE(review): this cell evaluates the same expression as the cell before it.
original_panel_labels = sampled_data['Gene panel original']
original_panel_labels.value_counts()

In [None]:
# Persist only the gene -> sampled-panel mapping as a comma-separated file with a header row.
sampled_panel_out = '.../450k/regions/gene-panel-gencode-v34.sampled.txt'
sampled_data.loc[:, ['Gene name', 'Gene panel']].to_csv(sampled_panel_out, index=False, sep=',')

# Look at the distributions

In [None]:
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

import glob

import seaborn as sns

from ukbb_recessive.data_collection.variants import VariantFeatures

# Helper object whose collect_rare_plps method is called further down.
variant_features = VariantFeatures()

# Input locations for the recessive variants: the glob-expanded list of
# per-chromosome cohort files and the single file holding all variants.
recessive_cohort_files = glob.glob(".../450k/RAP_output_per_chr/filtered_plps/basic/new_gene_names/new_freq/new_relatedness/chr*")
variants_paths_cfg = {
    'recessive': {
        'cohort_files': recessive_cohort_files,
        'all_variants_file': ".../450k/plp_selection/basic/new_gene_names/new_freq/new_relatedness/all_chr_total_presumable_plps_HFE_final_sorted.txt",
    },
}

# Every line of the file becomes one entry, with surrounding whitespace removed.
with open(".../450k/samples/european_non_related_no_withdrawal_to_include_450k.no_hom_comp_het.txt", 'r') as f:
    samples = [line.strip() for line in f]

print ("Number of samples:", len(samples))

In [None]:
# Load the sampled gene panel from disk and tally the genes per panel.
# The variable keeps its existing spelling because a later cell refers to it by this name.
sampled_panel_path = '.../450k/regions/gene-panel-gencode-v34.sampled.txt'
sampeld_panel = pd.read_csv(sampled_panel_path)

sampeld_panel['Gene panel'].value_counts()

In [None]:
# Select rare PLPs, then keep only the variants whose gene is in the sampled
# panel (inner join on the gene symbol).
# NOTE(review): the exact meaning of the two occurrence thresholds is defined
# by VariantFeatures.collect_rare_plps - confirm there.
recessive_paths = variants_paths_cfg['recessive']

rare_plps = (
    variant_features
    .collect_rare_plps(
        het_occurence_threshold=20,
        hom_occurence_threshold=0,
        all_plps_file=recessive_paths['all_variants_file'],
        s_het_file=s_het_paths['Roulette'],
        genes_list=None,
    )
    .merge(sampeld_panel, left_on='gene', right_on='Gene name', how='inner')
)

In [None]:
# Collapse variants to one row per gene: the first s_het and panel label seen
# for the gene, and the largest `hets` value among its variants.
per_gene_aggregations = {'s_het': 'first', 'Gene panel': 'first', 'hets': 'max'}
gene_rare_plps = rare_plps.groupby('gene').agg(per_gene_aggregations).reset_index()

In [None]:
# Per-gene s_het against the per-gene maximum `hets`, coloured by sampled panel.
# NOTE(review): both plotted columns look numeric; confirm a bar plot (rather
# than a scatter plot) is the intended view.
sns.barplot(
    data=gene_rare_plps,
    x='hets',
    y='s_het',
    hue='Gene panel',
)