In [None]:

import pandas as pd
import numpy as np
import itertools
import submitit
import matplotlib.pyplot as plt
import admix_genet_cor
import admix
from os.path import join
import os
import glob
import scipy

SUPP_TABLE_URL = "https://www.dropbox.com/s/jck2mhjby2ur55j/supp_tables.xlsx?dl=1"
ROOT_DIR = (
    "/u/project/pasaniuc/pasaniucdata/admixture/projects/PAGE-QC/01-dataset/out/aframr"
)
PFILE_DIR = join(ROOT_DIR, "imputed")

# Trait names and their display names come from the "trait-info" sheet of the
# supplementary tables (downloaded from SUPP_TABLE_URL). A glob over
# ROOT_DIR/pheno/*.tsv used to be computed here as well, but it was
# immediately overwritten by this list and never used.
trait_info = pd.read_excel(SUPP_TABLE_URL, sheet_name="trait-info")
trait_list = trait_info["trait"].values
dict_trait_display_name = dict(zip(trait_info["trait"], trait_info["display-name"]))
GRM_DIR = "/u/scratch/k/kangchen/admix-grm/rho-model"

CLUMP_DIR = (
    "/u/project/pasaniuc/pasaniucdata/admixture/projects/PAGE-QC/02-gwas/out/aframr"
)

# One row per (SNP set, heritability model, MAF cutoff) GRM configuration.
df_params = pd.DataFrame(
    list(
        itertools.product(
            ["hm3", "imputed"],
            ["mafukb", "gcta"],
            [0.005, 0.05],
        )
    ),
    columns=[
        "snpset",
        "hermodel",
        "maf",
    ],
)

# GRM sub-directory name, e.g. "hm3.mafukb.005": the MAF is written without
# its leading "0." (str(0.005)[2:] == "005").
df_params["grm_prefix"] = df_params.apply(
    lambda p: f"{p.snpset}.{p.hermodel}.{str(p.maf)[2:]}",
    axis=1,
)


def submit_gcta_estimate(
    grm_prefix, trait, duffy_covar=True, clump_h2_threshold=0.004
):
    """Fit GCTA REML for one trait against every rho-specific GRM of a GRM set.

    Builds phenotype and covariate tables for ``trait`` from the chr1 dataset
    and ``ROOT_DIR/pheno/{trait}.tsv``. It optionally adds two kinds of
    covariates:

    * the local-ancestry count at the SNP closest to chr1:159204893;
    * genotypes of clumped GWAS variants passing an expected-h2 cutoff.

    Every phenotype/covariate column is then quantile-normalized.
    ``admix.tools.gcta.reml`` is run once per rho in 0, 0.05, ..., 1.0, using
    the GRM ``GRM_DIR/{grm_prefix}/rho{pct}``. Results go to
    ``out/gcta-estimate/{trait}-{grm_prefix}/rho{pct}.hsq``. A rho whose
    ``.hsq`` file already exists is skipped, so re-running the function
    resumes an interrupted run.

    Parameters
    ----------
    grm_prefix : str
        Sub-directory of ``GRM_DIR`` holding the rho-specific GRMs.
    trait : str
        Trait name. Must match the phenotype file name and its phenotype
        column.
    duffy_covar : bool
        If True, add the local-ancestry count near chr1:159204893 as the
        covariate ``duffy_lanc``.
    clump_h2_threshold : float
        Clumped variants are kept when their p-value is below the chi2(1)
        tail at ``n_indiv * clump_h2_threshold``. Kept variants are added as
        genotype covariates. The default 0.004 (0.4%) is the previously
        hard-coded value.

    Raises
    ------
    ValueError
        If the SNP chosen for the Duffy covariate is not on chromosome 1.
    """
    # `import scipy` alone does not expose `scipy.stats` on SciPy versions
    # without lazy submodule loading, so import the submodule explicitly.
    import scipy.stats

    # --- phenotype and base covariates ------------------------------------
    dset = admix.io.read_dataset(
        join(PFILE_DIR, "chr1"),
        n_anc=2,
    )
    df_trait = pd.read_csv(join(ROOT_DIR, f"pheno/{trait}.tsv"), sep="\t", index_col=0)
    df_trait.index = df_trait.index.astype(str)

    # keep only individuals that appear in the phenotype file
    dset = dset[:, dset.indiv.index.isin(df_trait.index)]
    dset.append_indiv_info(df_trait)

    # assumes column 0 of the phenotype file is the trait and every remaining
    # column is a covariate
    covar_cols = df_trait.columns[1:]

    df_pheno = dset.indiv[[trait]].copy()
    df_covar = dset.indiv[covar_cols].copy()
    df_covar = admix.data.convert_dummy(df_covar)

    # --- Duffy-region local ancestry --------------------------------------
    if duffy_covar:
        # SNP closest to chr1:159204893 (presumably the Duffy-null variant
        # rs2814778; TODO confirm the genome build of POS)
        duffy_snp_loc = np.argmin(np.abs(dset.snp.POS - 159204893))
        if dset.snp.CHROM.iloc[duffy_snp_loc] != 1:
            raise ValueError(
                "SNP closest to the Duffy position is not on chromosome 1"
            )
        # presumably lanc is (snp, indiv, haplotype); summing axes 0 and 2
        # leaves one count per individual (TODO confirm)
        duffy_lanc = dset[duffy_snp_loc].lanc.sum(axis=[0, 2]).compute()
        df_covar["duffy_lanc"] = duffy_lanc

    # --- clumped GWAS variants as covariates ------------------------------
    clump_path = join(CLUMP_DIR, trait, "PLINK.imputed.clumped")
    with open(clump_path) as f:
        n_clump_lines = sum(1 for _ in f)
    # More than one line means there is at least one clumped variant past
    # the header. total_wbc_cnt is deliberately excluded from this
    # adjustment (the reason is not recorded here).
    if n_clump_lines > 1 and trait != "total_wbc_cnt":
        df_clump = pd.read_csv(clump_path, sep=r"\s+")
        # Treat n_indiv * h2 as the expected chi2(1) statistic of a variant
        # explaining h2. Keep variants whose p-value is below the tail
        # probability at clump_h2_threshold.
        p_threshold = scipy.stats.chi2.sf(dset.n_indiv * clump_h2_threshold, 1)
        df_clump = df_clump[df_clump.P < p_threshold]
        if len(df_clump) > 0:
            print(f"{len(df_clump)} SNPs with expected h2 > {clump_h2_threshold:.1%}")
            df_clump_geno = []
            for chrom, df_chrom_clump in df_clump.groupby("CHR"):
                dset_chrom = admix.io.read_dataset(
                    join(PFILE_DIR, f"chr{chrom}"),
                    n_anc=2,
                )
                dset_chrom = dset_chrom[
                    df_chrom_clump.SNP.values, dset.indiv.index.values
                ]
                # Genotypes summed over axis 2 and transposed into one row
                # per individual and one column per SNP.
                df_clump_geno.append(
                    pd.DataFrame(
                        dset_chrom.geno.sum(axis=2).T.compute(),
                        columns=dset_chrom.snp.index.values,
                        index=dset_chrom.indiv.index.values,
                    )
                )
            df_clump_geno = pd.concat(df_clump_geno, axis=1)
            print(df_clump_geno)
            df_covar = pd.merge(
                df_covar, df_clump_geno, left_index=True, right_index=True
            )

    # --- normalization and GCTA input tables ------------------------------
    for col in df_pheno.columns:
        df_pheno[col] = admix.data.quantile_normalize(df_pheno[col])

    for col in df_covar.columns:
        df_covar[col] = admix.data.quantile_normalize(df_covar[col])

    # fill remaining missing covariate values with the column mean
    df_covar.fillna(df_covar.mean(), inplace=True)

    # prepend FID/IID columns, both set to the individual ID
    df_id = pd.DataFrame(
        {"FID": df_pheno.index.values, "IID": df_pheno.index.values},
        index=df_pheno.index.values,
    )
    df_pheno = pd.merge(df_id, df_pheno, left_index=True, right_index=True)
    df_covar = pd.merge(df_id, df_covar, left_index=True, right_index=True)

    out_dir = f"out/gcta-estimate/{trait}-{grm_prefix}"
    os.makedirs(out_dir, exist_ok=True)

    # --- fit each rho -------------------------------------------------------
    # rho runs from 0 to 1 in steps of 0.05. Iterating integer percentages
    # gives GRM and output names ("rho0", "rho5", ..., "rho100") that do not
    # depend on truncating the float rho * 100.
    for rho_pct in range(0, 101, 5):
        grm = join(GRM_DIR, f"{grm_prefix}/rho{rho_pct}")
        out_prefix = join(out_dir, f"rho{rho_pct}")
        # skip fits that already finished so that a re-run resumes
        if not os.path.exists(out_prefix + ".hsq"):
            admix.tools.gcta.reml(
                grm_path=grm,
                df_pheno=df_pheno,
                df_covar=df_covar,
                out_prefix=out_prefix,
                n_thread=4,
            )

# One GCTA job per (GRM set, trait) combination.
df_params = pd.DataFrame(
    list(itertools.product(df_params.grm_prefix.unique(), trait_list)),
    columns=["grm_prefix", "trait"],
)

# Keep only combinations whose final fit (rho100.hsq) has not been written.
# submit_gcta_estimate fits rho in increasing order, so this is the last
# output of a run. An explicit boolean array keeps the mask well-formed even
# when the grid is empty.
is_todo = np.array(
    [
        not os.path.exists(f"out/gcta-estimate/{trait}-{grm_prefix}/rho100.hsq")
        for grm_prefix, trait in zip(df_params.grm_prefix, df_params.trait)
    ],
    dtype=bool,
)
df_todo_params = df_params[is_todo]
df_todo_params

executor = submitit.SgeExecutor(folder="./submitit-logs")

executor.update_parameters(
    time_min=700,
    memory_g=20,
    #     queue="highp",
    setup=[
        "export PATH=~/project-pasaniuc/software/miniconda3/bin:$PATH",
        "export PYTHONNOUSERSITE=True",
    ],
)

# Array job: task i runs submit_gcta_estimate(grm_prefix[i], trait[i]).
jobs = executor.map_array(
    submit_gcta_estimate,
    df_todo_params.grm_prefix,
    df_todo_params.trait,
)
