In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Load the two PheWAS lookups and stack them into one table of SNP–trait
# associations sharing a common column schema.
PHEWAS_COLUMNS = ['snp', 'PMID', 'Trait_category', 'Trait', 'P', 'Domain']

df1 = pd.read_csv('results/PheWAS/gwas_catalog/phewas_gwas_catalog_refined.csv')[PHEWAS_COLUMNS]

# NOTE(review): 'gwas_altas' looks like a typo for 'gwas_atlas', but it must match
# the directory/file names on disk, so it is left unchanged here — confirm.
# This source names its p-value column 'P-value'; align it with the catalog's 'P'.
df2 = (
    pd.read_csv('results/PheWAS/gwas_altas/gwas_altas_sig.csv')
    .rename(columns={'P-value': 'P'})
    [PHEWAS_COLUMNS]
)

# Order matters: catalog rows are placed first, so a later keep='first'
# de-duplication prefers them over rows from the second source.
df = pd.concat([df1, df2], ignore_index=True)

In [None]:
# One row per (SNP, study, trait category): repeated hits of the same SNP in the
# same study within one category are collapsed, keeping the first occurrence
# (rows from the first concatenated source therefore take precedence).
dedup_keys = ['snp', 'PMID', 'Trait_category']
df = df.loc[~df.duplicated(subset=dedup_keys, keep='first')]

In [None]:
# Count associations (rows remaining after de-duplication) per trait category,
# most frequent first.
# The column names are set explicitly because value_counts().reset_index() names
# them differently across pandas versions ('index'/'Trait_category' before 2.0,
# 'Trait_category'/'count' from 2.0); downstream code relies on the latter.
trait_counts = (
    df['Trait_category']
    .value_counts()
    .rename_axis('Trait_category')
    .reset_index(name='count')
)
print(trait_counts)
trait_counts.to_csv('results/PheWAS/trait_counts.csv', index=False)

In [None]:
# Keep only trait categories with at least MIN_ASSOCIATIONS associations
# (i.e. 3 or more), restricting the association table to those categories for plotting.
MIN_ASSOCIATIONS = 3
traits = trait_counts.loc[trait_counts['count'] >= MIN_ASSOCIATIONS, 'Trait_category'].tolist()
df_filtered = df[df['Trait_category'].isin(traits)]

In [None]:
# Horizontal bar chart of association counts per trait category, with categories
# ordered from most to fewest associations; saved as a PDF.
category_order = df_filtered['Trait_category'].value_counts().index

fig, ax = plt.subplots(figsize=(7, 10))
sns.countplot(data=df_filtered, y='Trait_category', order=category_order, ax=ax, color='#e76d2a')

ax.set_xlabel('Number of associations', fontsize=18, fontweight='bold')
# Clear the y-axis label seaborn derives from the column name; the category
# names already appear as tick labels.
ax.set_ylabel('', fontsize=18)
ax.tick_params(axis='both', labelsize=15)
ax.grid(False)
ax.set_title('', fontsize=18, fontweight='bold')

# White (opaque, not transparent) axes background.
ax.set_facecolor('white')

# Black, 3-pt spines on all sides, then hide the top and right ones.
for spine in ax.spines.values():
    spine.set_color('black')
    spine.set_linewidth(3)
for side in ('top', 'right'):
    ax.spines[side].set_visible(False)

# Use the explicit Figure API rather than pyplot's "current figure" state.
fig.tight_layout()
fig.savefig('results/PheWAS/trait_counts_plot.pdf', dpi=300, bbox_inches='tight')