# SCRIPT TO VISUALIZE QUANTITATIVE AND BINARY REGENIE RESULTS THROUGH FOREST PLOTS

#### Initialization
##### Load packages

In [None]:
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
from matplotlib.lines import Line2D

from src.results import pheno_search, pval_stars, plot_BT, plot_QT, plot_BT_grouped_by_mask, plot_QT_grouped_by_mask

# Create ../tmp (resolved against the current working directory) if it is missing;
# an already existing directory is not an error.
(Path("..") / "tmp").resolve().mkdir(parents=True, exist_ok=True)

##### Load field_id - title dictionary

If a field_id is not in the UK Biobank Showcase dictionary (https://biobank.ndph.ox.ac.uk/showcase/schema.cgi?id=16), it should be added manually to 'Data_Dictionary_Custom.txt'.

In [None]:
def _read_field_titles(path):
    """Return the ``field_id`` and ``title`` columns of a tab-separated dictionary file.

    Lines that pandas cannot parse are skipped instead of raising.
    """
    return pd.read_csv(
        path,
        sep="\t",
        on_bad_lines="skip",
        quotechar='"',
        usecols=["field_id", "title"],
    )


# Showcase dictionary first, then the manually maintained custom dictionary.
ukb_coding = _read_field_titles("/opt/notebooks/gogoGPCR2/data/misc/Data_Dictionary_ShowCase.txt")
custom_coding = _read_field_titles("/opt/notebooks/gogoGPCR2/data/misc/Data_Dictionary_Custom.txt")

### Data
#### Binary traits

In [None]:
# Binary trait files
# Every *.regenie file in the binary-trait results folder, as a "file:" URL.
# The listing is sorted because os.listdir() returns entries in arbitrary order,
# which would otherwise make the load/concatenation order differ between runs.
files_binary = [
    f"file:/mnt/project/gogoGPCR2/Results/WGS/BT/{file}"
    for file in sorted(os.listdir("/mnt/project/gogoGPCR2/Results/WGS/BT/"))
    if file.endswith(".regenie")
]

In [None]:
# Load raw DF
# Read every result file with the same settings (space-delimited, header inferred,
# anything after a "#" treated as a comment) and tag each row with the base name of
# the file it came from; the PHENO column is parsed out of SOURCE further down.
# pd.concat raises ValueError if files_binary is empty.
df_raw_binary = pd.concat(
    [
        pd.read_csv(fp, delimiter=" ", comment="#").assign(SOURCE=os.path.basename(fp))
        for fp in files_binary
    ],
    axis=0,
    # Every file restarts its row labels at 0; renumber so labels stay unique.
    ignore_index=True,
)

In [None]:
# Fix common fields
TEST = "ADD"
AAF = "all"

# Work on a copy so the raw frame is not modified in place.
df_binary = df_raw_binary.copy()

# ID and ALLELE1 are dot-separated strings:
#   GENE = first token of ID
#   MASK = first token of ALLELE1
#   AAF  = everything after the first dot of ALLELE1
df_binary["GENE"] = df_binary["ID"].str.split(".").str[0]
df_binary["MASK"] = df_binary["ALLELE1"].str.split(".").str[0]
df_binary["AAF"] = df_binary["ALLELE1"].str.split(".", n=1).str[-1]
df_binary["TRAIT"] = "BT"

# PHENO comes from the file name in SOURCE: drop everything up to and including
# the first "_", then keep the part before the first ".".
df_binary["PHENO"] = df_binary["SOURCE"].str.partition("_")[2].str.split(".").str[0]

# Keep only rows of the selected test, then drop the columns that are no longer needed.
df_binary = df_binary.loc[df_binary["TEST"].eq(TEST)].drop(
    columns=["ID", "ALLELE0", "ALLELE1", "EXTRA", "SOURCE", "TEST"]
)

In [None]:
# Filters
bt = df_binary["TRAIT"] == "BT"

# Fix Binary Traits
# OR = exp(BETA); the upper/lower bounds are exp(BETA +/- SE).
# NOTE(review): this is a one-standard-error band, not a 95% CI (+/- 1.96 * SE) -- confirm intended.
df_binary.loc[bt, "OR"] = np.exp(df_binary.loc[bt, "BETA"])
df_binary.loc[bt, "OR_up"] = np.exp(df_binary.loc[bt, "BETA"] + df_binary.loc[bt, "SE"])
df_binary.loc[bt, "OR_low"] = np.exp(df_binary.loc[bt, "BETA"] - df_binary.loc[bt, "SE"])
# Distances from the point estimate to each bound.
df_binary.loc[bt, "OR_up_lim"] = df_binary.loc[bt, "OR_up"] - df_binary.loc[bt, "OR"]
df_binary.loc[bt, "OR_low_lim"] = df_binary.loc[bt, "OR"] - df_binary.loc[bt, "OR_low"]

# Final fixes
# Human-readable phenotype title looked up from the two dictionaries, with quotes
# and surrounding whitespace removed.
df_binary["Phenotype"] = df_binary["PHENO"].apply(
    lambda x: pheno_search(x, ukb_coding, custom_coding).replace('"', "").strip()
)
# LOG10P holds -log10(p); a float base avoids NumPy's integer-power error
# should the column ever be parsed as integers.
df_binary["pval"] = np.power(10.0, -df_binary["LOG10P"])
df_binary["pval_stars"] = df_binary["pval"].apply(pval_stars)
# Round before casting: astype(int) truncates, so a product such as 12.9999
# (frequency printed with limited precision) would otherwise become 12.
df_binary["N_pos"] = (2 * df_binary["N"] * df_binary["A1FREQ"]).round().astype(int)

# Singletons
# Drop rows whose AAF label is "singleton"; copy so later assignments do not
# target a slice of the unfiltered frame.
df_binary = df_binary.loc[df_binary["AAF"] != "singleton"].copy()

In [None]:
phenos_to_remove = []

# One row per (Phenotype, MASK): rows are sorted with AAF in descending (string)
# order and first() then takes the first non-null value of each column per group.
plt_df_binary = (
    df_binary.sort_values(by=["Phenotype", "AAF"], ascending=[True, False])
    .groupby(["Phenotype", "MASK"])
    .first()
    .reset_index()
)

# Explicit copy: columns are assigned on this frame below, and assigning into a
# .loc slice triggers SettingWithCopyWarning as soon as phenos_to_remove filters
# anything out.
plt_df_binary = plt_df_binary.loc[
    ~plt_df_binary["Phenotype"].astype(str).isin(phenos_to_remove), :
].copy()

effect = "OR"

# Mean effect per phenotype, largest first.
group_by_mean = (
    pd.DataFrame({"mean": plt_df_binary.groupby(["Phenotype"]).agg("mean", numeric_only=True)[effect]})
    .sort_values(by="mean", ascending=False)
    .reset_index()
)

sorter = group_by_mean["Phenotype"].tolist()

plt_df_binary["Phenotype"] = plt_df_binary["Phenotype"].astype("category")
# NOTE(review): the mean-effect ordering is stored in a separate "Phenotypes"
# column, while the sort below uses "Phenotype" (default category order).
# Confirm whether the plotting helpers read "Phenotypes" or whether this was
# meant to reorder "Phenotype" itself.
plt_df_binary["Phenotypes"] = plt_df_binary["Phenotype"].cat.set_categories(sorter)

plt_df_binary = plt_df_binary.sort_values(
    by=["Phenotype", "MASK"], ascending=[True, False]
).reset_index(drop=True)

phenotypes = plt_df_binary["Phenotype"].unique()

#### Quantitative traits

In [None]:
# Quantitative traits files
# Every *.regenie file in the quantitative-trait results folder, as a "file:" URL.
# The listing is sorted because os.listdir() returns entries in arbitrary order,
# which would otherwise make the load/concatenation order differ between runs.
# NOTE(review): this folder is under Results/WES, while the binary traits are read
# from Results/WGS and the quantitative plots are titled/saved as "WGS" -- confirm
# which data set is intended.
files_quantiative = [
    f"file:/mnt/project/gogoGPCR2/Results/WES/QT/{file}"
    for file in sorted(os.listdir("/mnt/project/gogoGPCR2/Results/WES/QT/"))
    if file.endswith(".regenie")
]

In [None]:
# Load raw DF
# Read every result file with the same settings (space-delimited, header inferred,
# anything after a "#" treated as a comment) and tag each row with the base name of
# the file it came from; the PHENO column is parsed out of SOURCE further down.
# pd.concat raises ValueError if files_quantiative is empty.
df_raw_quantiative = pd.concat(
    [
        pd.read_csv(fp, delimiter=" ", comment="#").assign(SOURCE=os.path.basename(fp))
        for fp in files_quantiative
    ],
    axis=0,
    # Every file restarts its row labels at 0; renumber so labels stay unique.
    ignore_index=True,
)

In [None]:
# Fix common fields
TEST = "ADD"
AAF = "all"

# Work on a copy so the raw frame is not modified in place.
df_quantiative = df_raw_quantiative.copy()

# ID and ALLELE1 are dot-separated strings:
#   GENE = first token of ID
#   MASK = first token of ALLELE1
#   AAF  = everything after the first dot of ALLELE1
df_quantiative["GENE"] = df_quantiative["ID"].str.split(".").str[0]
df_quantiative["MASK"] = df_quantiative["ALLELE1"].str.split(".").str[0]
df_quantiative["AAF"] = df_quantiative["ALLELE1"].str.split(".", n=1).str[-1]
df_quantiative["TRAIT"] = "QT"

# PHENO comes from the file name in SOURCE: drop everything up to and including
# the first "_", then keep the part before the first ".".
df_quantiative["PHENO"] = df_quantiative["SOURCE"].str.partition("_")[2].str.split(".").str[0]

# Keep only rows of the selected test, then drop the columns that are no longer needed.
df_quantiative = df_quantiative.loc[df_quantiative["TEST"].eq(TEST)].drop(
    columns=["ID", "ALLELE0", "ALLELE1", "EXTRA", "SOURCE", "TEST"]
)

In [None]:
# Filters
qt = df_quantiative["TRAIT"] == "QT"

# Fix Quantitative Traits
# Bounds are BETA +/- SE. Unlike OR_up_lim / OR_low_lim in the binary section
# (distances from the estimate), these hold the absolute bound values.
# NOTE(review): this is a one-standard-error band, not a 95% CI (+/- 1.96 * SE) -- confirm intended.
df_quantiative.loc[qt, "BETA_up_lim"] = df_quantiative.loc[qt, "BETA"] + df_quantiative.loc[qt, "SE"]
df_quantiative.loc[qt, "BETA_low_lim"] = df_quantiative.loc[qt, "BETA"] - df_quantiative.loc[qt, "SE"]

# Final fixes
# Human-readable phenotype title looked up from the two dictionaries, with quotes
# and surrounding whitespace removed.
df_quantiative["Phenotype"] = df_quantiative["PHENO"].apply(
    lambda x: pheno_search(x, ukb_coding, custom_coding).replace('"', "").strip()
)
# LOG10P holds -log10(p); a float base avoids NumPy's integer-power error
# should the column ever be parsed as integers.
df_quantiative["pval"] = np.power(10.0, -df_quantiative["LOG10P"])
df_quantiative["pval_stars"] = df_quantiative["pval"].apply(pval_stars)
# Round before casting: astype(int) truncates, so a product such as 12.9999
# (frequency printed with limited precision) would otherwise become 12.
df_quantiative["N_pos"] = (2 * df_quantiative["N"] * df_quantiative["A1FREQ"]).round().astype(int)

# Singletons
# Drop rows whose AAF label is "singleton"; copy so later assignments do not
# target a slice of the unfiltered frame.
df_quantiative = df_quantiative.loc[df_quantiative["AAF"] != "singleton"].copy()

In [None]:
phenos_to_remove_quantiative = []

# One row per (Phenotype, MASK): rows are sorted with AAF in descending (string)
# order and first() then takes the first non-null value of each column per group.
plt_df_quantiative = (
    df_quantiative.sort_values(by=["Phenotype", "AAF"], ascending=[True, False])
    .groupby(["Phenotype", "MASK"])
    .first()
    .reset_index()
)

# Explicit copy: columns are assigned on this frame below, and assigning into a
# .loc slice triggers SettingWithCopyWarning as soon as
# phenos_to_remove_quantiative filters anything out.
plt_df_quantiative = plt_df_quantiative.loc[
    ~plt_df_quantiative["Phenotype"].astype(str).isin(phenos_to_remove_quantiative), :
].copy()

effect = "BETA"

# Mean effect per phenotype, largest first.
group_by_mean = (
    pd.DataFrame({"mean": plt_df_quantiative.groupby(["Phenotype"]).agg("mean", numeric_only=True)[effect]})
    .sort_values(by="mean", ascending=False)
    .reset_index()
)

sorter = group_by_mean["Phenotype"].tolist()

plt_df_quantiative["Phenotype"] = plt_df_quantiative["Phenotype"].astype("category")
# NOTE(review): the mean-effect ordering is stored in a separate "Phenotypes"
# column, while the sort below uses "Phenotype" (default category order).
# Confirm whether the plotting helpers read "Phenotypes" or whether this was
# meant to reorder "Phenotype" itself.
plt_df_quantiative["Phenotypes"] = plt_df_quantiative["Phenotype"].cat.set_categories(sorter)

plt_df_quantiative = plt_df_quantiative.sort_values(
    by=["Phenotype", "MASK"], ascending=[True, False]
).reset_index(drop=True)

phenotypes = plt_df_quantiative["Phenotype"].unique()

### Plots

#### Binary traits

In [None]:
# Binary-trait plot via the project helper plot_BT, x-axis limits 0 to 10.
plot = plot_BT(plt_df_binary, title="GIPR Binary Traits. WGS", xlim=[0, 10])

# Save the current figure as SVG into the working directory.
plt.savefig("gipr_WGS_BT1.svg", dpi=600, bbox_inches="tight", format="svg")

In [None]:
# Binary-trait plot via the project helper plot_BT_grouped_by_mask, x-axis limits 0 to 10.
plot = plot_BT_grouped_by_mask(plt_df_binary, title="GIPR Binary Traits. WGS", xlim=[0, 10])

# Save the current figure as SVG into the working directory.
plt.savefig("gipr_WGS_BT2.svg", dpi=600, bbox_inches="tight", format="svg")

#### Quantitative traits

In [None]:
# Quantitative-trait plot via the project helper plot_QT, x-axis limits -4 to 4.
plot = plot_QT(plt_df_quantiative, title="GIPR Quantitative Traits. WGS", xlim=[-4, 4], height=30)

# Save the current figure as SVG into the working directory.
plt.savefig("gipr_WGS_QT1.svg", dpi=600, bbox_inches="tight", format="svg")

In [None]:
# Quantitative-trait plot via the project helper plot_QT_grouped_by_mask, x-axis limits -4 to 4.
plot = plot_QT_grouped_by_mask(
    plt_df_quantiative, title="GIPR Quantitative Traits. WGS", xlim=[-4, 4], height=30
)

# Save the current figure as SVG into the working directory.
plt.savefig("gipr_WGS_QT2.svg", dpi=600, bbox_inches="tight", format="svg")