# Summary statistics merger

## Aim

- 1. To merge multiple summary statistics files into new summary statistics files that contain the common SNPs
- 2. To handle allele flip and reverse issues during merging

## Notes
 - 1. Duplicated `indels` in the summary statistics are removed. For example, if two variants sit at position 10000 on chr1, one with `A0` = `T` and `A1` = `TC`, the other with `A0` = `TC` and `A1` = `T`, both are removed. See [this comment](https://github.com/statgenetics/UKBB_GWAS_dev/issues/81#issuecomment-1015556800) for more on the `indels` issue.
 - 2. If duplicated `chr:pos` (GWAS) or `gene:chr:pos` (TWAS) entries exist, a recursive match is run for each pair of them between the two summary statistics files (`query`, i.e. each of the inputs, and `subject`, the target file).
 - 3. Under the same `chr:pos` or `gene:chr:pos`, the variants' `A0` and `A1` are matched by the exact, flip, reverse, or flip+reverse models. When only one of them is `True`, the variant is considered matched between the two files. If it is matched by flip or flip+reverse, the sign of the `query`'s `STAT` is inverted, and the `query`'s `A0` and `A1` are set to the `subject`'s `A0` and `A1`.
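For single-nucleotide alleles, the matching rule in note 3 can be sketched as follows. This is a minimal illustration, not cugg's implementation: it assumes, consistent with the sign rule above, that `flip` swaps `A0` and `A1` while `reverse` takes the strand complement, and the helper names are hypothetical. A variant counts as matched only when exactly one model holds; a palindromic pair such as A/T matches both `exact` and `flip+reverse`, which is what makes it ambiguous. How ambiguous variants are aligned when `--keep-ambiguous` is set is not shown here.

```python
from typing import List, Optional, Tuple

# Base complement; assumes single-nucleotide alleles written with A/C/G/T
COMPLEMENT = {"A": "T", "T": "A", "C": "G", "G": "C"}


def complement(allele: str) -> str:
    """Strand complement of each base; characters other than A/C/G/T are left unchanged."""
    return "".join(COMPLEMENT.get(base, base) for base in allele.upper())


def matching_models(q_a0: str, q_a1: str, s_a0: str, s_a1: str) -> List[str]:
    """Names of the models under which the query alleles reproduce the subject alleles.

    Assumption: 'flip' swaps A0 and A1, 'reverse' takes the strand complement.
    """
    q0, q1 = q_a0.upper(), q_a1.upper()
    candidates = {
        "exact": (q0, q1),
        "flip": (q1, q0),
        "reverse": (complement(q0), complement(q1)),
        "flip+reverse": (complement(q1), complement(q0)),
    }
    target = (s_a0.upper(), s_a1.upper())
    return [name for name, alleles in candidates.items() if alleles == target]


def align_variant(q_a0: str, q_a1: str, stat: float,
                  s_a0: str, s_a1: str) -> Optional[Tuple[str, str, float]]:
    """Return (A0, A1, STAT) aligned to the subject, or None if unmatched or ambiguous."""
    models = matching_models(q_a0, q_a1, s_a0, s_a1)
    if len(models) != 1:
        # No model holds, or several hold at once (palindromic pairs such as A/T, C/G)
        return None
    # Swapping A0/A1 changes the effect allele, so the statistic changes sign
    sign = -1.0 if models[0] in ("flip", "flip+reverse") else 1.0
    return s_a0, s_a1, sign * stat
```

For example, a query G/A with `STAT` 0.5 against a subject A/G matches only by `flip` and is returned as A/G with `STAT` -0.5.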

## Pre-requisites

Make sure you install the prerequisites before running this notebook:

```
pip install cugg
```

## Input

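- `--cwd`, the path of the working directory
- `--yml_list`, the path of a tab-delimited file listing the yaml files to process (one per row, with the yaml path in the second column)
- `--sumstat_list`, the path of a tab-delimited file listing the sumstat files to convert to VCF (a `#chr` column plus one column per sumstat; the column names are used as VCF sample names)
- `--keep-ambiguous`, boolean, default False. If set, keep ambiguous alleles that cannot be resolved by flip or reverse, such as A/T or C/G; otherwise, remove them.
- `--intersect`, boolean, default False. If set, output only the SNPs shared by all input files.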

### The format of the input yaml file 

For GWAS summary statistics: `ID` is `CHR,POS,A0,A1`, which can be used as a unique label for each variant.

```
INPUT:
  - ./data/testflip/*.gz:
        ID: CHR,POS,A0,A1
        CHR: CHR
        POS: POS
        A0: REF
        A1: ALT
        SNP: SNP
        STAT: BETA
        SE: SE
        P: P
  - ./data/testflip/flip/snps500_flip.regenie.snp_stats.gz:
  
TARGET: 
  - ./data/testflip/snps500.regenie.snp_stats.gz:
        ID: CHR,POS,A0,A1
        CHR: CHR
        POS: POS
        A0: REF
        A1: ALT
        SNP: SNP
        STAT: BETA
        SE: SE
        P: P
OUTPUT: data/testflip/output/
```
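Each INPUT entry maps a path or Unix shell glob pattern to a column mapping; an entry without a mapping (such as the `snps500_flip` file above, which YAML parses as `None`) uses the default keys. A minimal sketch of expanding such a section into one mapping per file, assuming the YAML has already been loaded (e.g. with PyYAML's `yaml.safe_load`); `expand_inputs` is a hypothetical helper, not cugg's `parse_input`:

```python
import glob
from typing import Dict, List, Optional


def expand_inputs(input_section: List[dict]) -> Dict[str, Optional[dict]]:
    """Map every file matched by the INPUT patterns to its column mapping (None = default keys)."""
    files: Dict[str, Optional[dict]] = {}
    for entry in input_section:
        for pattern, mapping in entry.items():
            matches = sorted(glob.glob(pattern))
            if not matches:
                # Fail early instead of silently merging fewer files than expected
                raise FileNotFoundError(f"No file matches {pattern}")
            for path in matches:
                # Copy so that later edits (e.g. popping "ID") do not leak between files
                files[path] = dict(mapping) if mapping else None
    return files
```

Raising on an unmatched pattern is a design choice of this sketch; what the notebook does in that case depends on cugg's `parse_input`.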

For TWAS summary statistics: `ID` is `GENE,CHR,POS,A0,A1`. The `GENE` name is added because a variant can be associated with multiple genes.

```
INPUT:
  - data/twas/*.txt:
        ID: GENE,CHR,POS,A0,A1
        CHR: chrom
        POS: pos
        A0: ref
        A1: alt
        SNP: variant_id
        GENE: gene
        STAT: beta
        SE: se
        P: pval
 
  
TARGET: 
  - data/twas/DLPFC.chr6.mol_phe.cis_long_table.reformated.txt:
        ID: GENE,CHR,POS,A0,A1
        CHR: chrom
        POS: pos
        A0: ref
        A1: alt
        SNP: variant_id
        GENE: gene
        STAT: beta
        SE: se
        P: pval
OUTPUT: ../data/twas/output/
```
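The `ID` rule determines the row identifier used for matching. A minimal pandas sketch of applying a column mapping such as the TWAS one above: the mapped columns are selected and renamed to the standard keys, and the `ID` fields are joined into one identifier. The `:`-joined ID format and the name `apply_column_map` are assumptions for illustration; cugg builds its IDs with `namebyordA0_A1`, whose exact format is not documented here.

```python
import pandas as pd


def apply_column_map(df: pd.DataFrame, config: dict) -> pd.DataFrame:
    """Select and rename mapped columns, then index rows by the ID rule."""
    config = dict(config)                    # do not mutate the caller's mapping
    id_keys = config.pop("ID").split(",")   # e.g. ["GENE", "CHR", "POS", "A0", "A1"]
    missing = [col for col in config.values() if col not in df.columns]
    if missing:
        raise ValueError(f"Columns missing from the input file: {missing}")
    out = df.loc[:, list(config.values())]
    out.columns = list(config.keys())        # file column names -> standard keys
    # Hypothetical ID format: the ID fields joined with ":"
    out.index = out[id_keys].astype(str).apply(":".join, axis=1)
    out.index.name = "ID"
    return out
```

For example, a variant at `chr6:100` (A/G) tested against genes `g1` and `g2` gets the IDs `g1:chr6:100:A:G` and `g2:chr6:100:A:G`, whereas `CHR,POS,A0,A1` alone would produce the same ID twice.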

There are three parts in the input yaml file:
- INPUT
   - A list of entries (e.g. as produced by yml_generator); each entry documents a set of input files:
       - The input summary statistics files, together with their column names as shown above.
       - Input files can come from multiple directories and be in different formats. Paths may use Unix shell wildcard (glob) patterns. Each entry pairs the keys (CHR, POS, A0, A1, SNP, STAT, SE, P) with the corresponding column names in the file; if no mapping is provided, the file's column names are assumed to already be the default keys.
       - The input summary statistics file cannot have duplicated chr:pos.
       - The input summary statistics file cannot have `#` in its header.
       - `ID` in the yml is the rule for generating a unique identifier for each SNP. Its content is a combination of keys such as CHR, POS, A0, A1, SNP, etc., not the actual column names. `ID` cannot refer to an existing ID column of the original file.
- TARGET
   - The target file is a reference summary statistics file, or any file with at least `CHR`, `POS`, `A0` and `A1` columns, against which the other files are compared.
- OUTPUT
   - The path of the output directory for the new summary statistics files.

## Output
New summary statistics files in which the sign of the statistics has been corrected to be consistent across datasets. With `--intersect`, only the SNPs common to all input files are kept; otherwise, each file keeps the SNPs that match the target.
   - For each input sumstat file, a QCed version is generated.
   - Input column names are replaced by the standard keys (`CHR, POS, A0, A1, SNP, STAT, SE, P`) regardless of the input header.
   - The generated sumstat files keep the names of the input files; files whose names end in `.gz` are written gzip-compressed.

The full set of output columns depends on the column mapping provided in the yml. However, all outputs include the columns `CHR, POS, A0, A1, SNP, STAT, SE, P`.
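With `--intersect`, each matched table is restricted to the SNPs present in all of them before it is written out. A minimal pandas sketch of that step, with hypothetical helper names; as in `unify_sumstat`, files are written tab-delimited without the index, and pandas infers gzip compression from a `.gz` file name:

```python
import os
from typing import Dict

import pandas as pd


def keep_common_snps(tables: Dict[str, pd.DataFrame]) -> Dict[str, pd.DataFrame]:
    """Restrict every table to the SNP IDs shared by all tables (needs at least one table)."""
    common = set.intersection(*(set(t["SNP"]) for t in tables.values()))
    return {name: t[t["SNP"].isin(common)] for name, t in tables.items()}


def write_tables(tables: Dict[str, pd.DataFrame], output_dir: str) -> None:
    """Write each table as a tab-delimited file named after its input file."""
    os.makedirs(output_dir, exist_ok=True)
    for name, t in tables.items():
        # compression="infer" (the default) gzips files whose names end in .gz
        t.to_csv(os.path.join(output_dir, name), sep="\t", index=False)
```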

For example, when the input sumstat is from TensorQTL, the column specification is:

- GENE: Molecular trait identifier (gene).
- CHR: Variant chromosome.
- POS: Variant chromosomal position (base pairs).
- A0: Variant reference allele (A, C, T, or G).
- A1: Variant alternate allele.
- TSS_D: Distance of the SNP to the gene transcription start site (TSS).
- AF: Allele frequency of the SNP.
- MA_SAMPLES: Number of samples carrying the minor allele.
- MA_COUNT: Total number of minor alleles across individuals.
- P: Nominal P-value from linear regression.
- STAT: Slope of the linear regression.
- SE: Standard error of the slope.

When the input sumstat is from APEX, the column specification is:

- GENE: Molecular trait identifier (gene).
- CHR: Variant chromosome.
- POS: Variant chromosomal position (base pairs).
- A0: Variant reference allele (A, C, T, or G).
- A1: Variant alternate allele.
- P: Nominal P-value from linear regression.
- STAT: Slope of the linear regression.
- SE: Standard error of the slope.

## Memory usage
Merging two sumstat files with ~85,000 rows and ~5 MB in size requires about 1 GB of memory.

Merging two sumstat files with ~2,000,000 rows and ~1 GB in size requires at least 50 GB of memory.

## Example command

```
sos run ./summary_stats_merger.ipynb --cwd data --yml_list data/yml_list.txt --keep-ambiguous --intersect
```
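The command expects `--yml_list` (and, for the VCF conversion steps, `--sumstat_list`) to be tab-delimited files. Based on how the `[global]` section reads them, a minimal sketch for writing them could look like this; the `#chr`/`dir` headers follow the comment in `[global]`, and the function names and any paths you pass in are hypothetical.

```python
import csv
from typing import List, Sequence


def _write_tsv(path: str, header: List[str], rows: Sequence[Sequence[str]]) -> None:
    """Write a tab-delimited file with a header row, checking the row widths."""
    with open(path, "w", newline="") as f:
        writer = csv.writer(f, delimiter="\t", lineterminator="\n")
        writer.writerow(header)
        for row in rows:
            if len(row) != len(header):
                raise ValueError(f"Expected {len(header)} fields, got {len(row)}: {row}")
            writer.writerow(row)


def write_yml_list(path: str, rows: Sequence[Sequence[str]]) -> None:
    """--yml_list: one (chr, yml path) pair per row; [default_2] reads the yml path from the second column."""
    _write_tsv(path, ["#chr", "dir"], rows)


def write_sumstat_list(path: str, sample_names: List[str], rows: Sequence[Sequence[str]]) -> None:
    """--sumstat_list: a #chr column plus one column per sumstat; the headers after #chr become VCF sample names."""
    _write_tsv(path, ["#chr", *sample_names], rows)
```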

In [None]:
[global]
# Work directory where output will be saved
parameter: cwd = path
# Path to a tab-delimited list of yml files, with columns #chr and dir
parameter: yml_list = path
import pandas as pd
yml_path = pd.read_csv(yml_list,sep = "\t").values.tolist()
# If --keep-ambiguous is set, keep ambiguous alleles (e.g. A/T or C/G) that cannot be resolved by flip or reverse; otherwise remove them.
parameter: keep_ambiguous = False
# If --intersect is set, output only the SNPs shared by all input files.
parameter: intersect = False
# Container that provides the necessary packages
parameter: container = ""
parameter: numThreads = 1
# For cluster jobs, number of commands to run per job
parameter: job_size = 1
# Walltime
parameter: walltime = '5h'
parameter: mem = '3G'
# Path to a tab-delimited list of sumstat files to convert to VCF: a #chr column plus one column per sumstat, whose header is used as the VCF sample name
parameter: sumstat_list = path
sumstat_path = pd.read_csv(sumstat_list,sep = "\t").drop(columns="#chr").values.tolist()
name = pd.read_csv(sumstat_list,sep = "\t").drop(columns="#chr").columns.values.tolist()
# Whether to rename the chromosome names
parameter: rename = False
import time
import time

## Workflow code

In [113]:
[default_1 (export utils script)]
depends: Py_Module('cugg')
output: f'{cwd:a}/utils.py'
report: expand = '${ }', output=f'{cwd:a}/utils.py'
    import os
    import pandas as pd
    from cugg.sumstat import read_sumstat
    from cugg.utils import *
    ## To be added in cugg packages
    ## Running functions
    def read_sumstat(file, config, rename=True):
        # Try gzip first; fall back to plain text for uncompressed files
        try:
            sumstats = pd.read_csv(file, compression='gzip', header=0, sep='\t', quotechar='"')
        except OSError:
            sumstats = pd.read_csv(file, header=0, sep='\t', quotechar='"')
        ID = None
        if config is not None:
            config = dict(config)  # copy so that popping 'ID' does not alter the caller's config
            ID = config.pop('ID').split(',')
            try:
                sumstats = sumstats.loc[:, list(config.values())]
            except KeyError:
                raise ValueError('According to config_file, input summary statistics should have the following columns: %s' % list(config.values()))
            sumstats.columns = list(config.keys())
        # Drop a leading "chr" from chromosome names
        sumstats['CHR'] = sumstats.CHR.astype(str).str.replace(r'^chr', '', regex=True)
        if rename:
            sumstats['SNP'] = 'chr' + sumstats.CHR + ':' + sumstats.POS.astype(str) + '_' + sumstats.A0.astype(str) + '_' + sumstats.A1.astype(str)
        sumstats['CHR'] = sumstats.CHR.astype(int)
        sumstats['POS'] = sumstats.POS.astype(int)
        # Index rows by the ID rule, or by GENE (if present), CHR, POS, A0, A1
        if ID is None:
            ID = [c for c in ["GENE", "CHR", "POS", "A0", "A1"] if c in sumstats.columns]
        sumstats.index = namebyordA0_A1(sumstats[ID], cols=ID)
        return sumstats


    def snps_match(query,subject,keep_ambiguous=True):
        print("Total rows of query: ",query.shape[0],"Total rows of subject: ",subject.shape[0])
        if len(query.index[0].split('_')[0].split(':'))>2:
            #gene:chr:pos case
            genes_query = query.index.to_series().apply(lambda x: x.split(':')[0])
            genes_subject = subject.index.to_series().apply(lambda x: x.split(':')[0])
            query = dict(tuple(query.groupby(genes_query)))
            subject = dict(tuple(subject.groupby(genes_subject)))
            new_query, new_subject = [],[]
            for g in genes_query.unique():
                if g in query.keys() and g in subject.keys():
                    new_q,new_s = snps_match_dup(query[g],subject[g],keep_ambiguous)
                    new_query.append(new_q)
                    new_subject.append(new_s)
            new_query, new_subject = pd.concat(new_query), pd.concat(new_subject)
        else:
            #chr:pos case
            new_query, new_subject=snps_match_dup(query,subject,keep_ambiguous)
        return new_query, new_subject


    def unify_sumstat(yml,keep_ambiguous,intersect):
        #parse yaml
        yml = load_yaml(yml)
        input_dict = parse_input(yml['INPUT'])
        target_dict = yml['TARGET']
        output_path = yml['OUTPUT'][0]
        os.makedirs(output_path, exist_ok=True)
        lst_sumstats_file = [ os.path.basename(i) for i in input_dict.keys()]
        print('Total number of sumstats: ',len(lst_sumstats_file))
        if len(set(lst_sumstats_file))<len(lst_sumstats_file):
            raise Exception("There are duplicated names in {}".format(lst_sumstats_file))
        #read all sumstats
        print(input_dict)
        lst_sumstats = {os.path.basename(i):read_sumstat(i,j,) for i,j in input_dict.items()}
        nqs = []
        # Read in the reference target file: either one of the inputs
        if os.path.basename(target_dict[0]) in lst_sumstats:
            subject = check_indels(lst_sumstats[os.path.basename(target_dict[0])])
        # or a separately prepared file (GENE is kept only when present, e.g. for TWAS)
        else:
            target = read_sumstat(target_dict[0], None, False)
            subject = check_indels(target[[c for c in ["CHR", "POS", "A0", "A1", "GENE"] if c in target.columns]])
        for query in lst_sumstats.values():
            #check duplicated indels and remove them.
            query = check_indels(query)
            #under the same chr:pos or gene:chr:pos. match A0 and A1 by exact, flip, reverse, or flip+reverse.
            #if duplicated chr_pos or gene_chr_pos exist, run a recursive match for each pair of them between query and subject.
            nq,_ = snps_match(query,subject,keep_ambiguous)
            nq = nq.loc[:,~nq.columns.duplicated()] # Remove duplicated columns due to order of columns difference in subject and query
            nqs.append(nq)
        if intersect:
            #get common snps
            common_snps = set.intersection(*[set(nq.SNP) for nq in nqs])
            print('Total number of common SNPs: ',len(common_snps))
            #write out new sumstats
            for output_sumstats,nq in zip(lst_sumstats_file,nqs):
                sumstats = nq[nq.SNP.isin(common_snps)]
                sumstats.to_csv(os.path.join(output_path, output_sumstats), sep = "\t", header = True, index = False)
        else:
            for output_sumstats,nq in zip(lst_sumstats_file,nqs):
                #output match SNPs with target SNPs.
                nq.to_csv(os.path.join(output_path, output_sumstats), sep = "\t", header = True, index = False)
        print('All are done')

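For TWAS input, `snps_match` splits both tables by gene, matches variants only within genes present in both files, and concatenates the matched pieces of each side. A minimal pandas sketch of that grouping pattern, grouping on a `GENE` column rather than parsing the index; `exact_match` is a hypothetical stand-in for cugg's `snps_match_dup` that keeps only identical CHR/POS/A0/A1 and does no flip or reverse handling:

```python
import pandas as pd

KEYS = ["CHR", "POS", "A0", "A1"]


def exact_match(q: pd.DataFrame, s: pd.DataFrame):
    """Hypothetical stand-in for snps_match_dup: keep rows whose CHR/POS/A0/A1 occur in both tables."""
    q_idx = pd.MultiIndex.from_frame(q[KEYS])
    s_idx = pd.MultiIndex.from_frame(s[KEYS])
    return q[q_idx.isin(s_idx)], s[s_idx.isin(q_idx)]


def match_by_gene(query: pd.DataFrame, subject: pd.DataFrame, match_fn=exact_match):
    """Match query and subject gene by gene; genes missing from either side are dropped."""
    subject_groups = {gene: df for gene, df in subject.groupby("GENE")}
    new_q, new_s = [], []
    for gene, q in query.groupby("GENE"):
        if gene in subject_groups:
            mq, ms = match_fn(q, subject_groups[gene])
            new_q.append(mq)
            new_s.append(ms)
    if not new_q:
        # pd.concat([]) raises, so return empty tables with the original columns
        return query.iloc[0:0], subject.iloc[0:0]
    return pd.concat(new_q), pd.concat(new_s)
```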
In [2]:
[default_2 (unify sumstats)]
depends: f'{cwd:a}/utils.py'
input: for_each = "yml_path"
task: trunk_workers = 1, trunk_size = job_size, walltime = walltime, mem = mem, cores = numThreads, tags = f'{step_name}_{_output:bn}'
python: expand = '${ }', input = f'{cwd:a}/utils.py', stderr = f'{cwd:a}/{path(_yml_path[1]):bn}.stderr', stdout = f'{cwd:a}/output.stdout'
    yml = "${_yml_path[1]}"
    keep_ambiguous = ${keep_ambiguous}
    intersect = ${intersect}
    print(yml, keep_ambiguous,intersect)
    unify_sumstat(yml, keep_ambiguous,intersect)

In [None]:
[default_3, sumstat_to_vcf_1]
input: for_each = "sumstat_path"
output: [f'{path(x):an}.vcf' for x in _sumstat_path]
task: trunk_workers = 1, trunk_size = job_size, walltime = walltime, mem = mem, cores = numThreads, tags = f'{step_name}_{_output[0]:bn}'
python: expand = '${ }', stderr = f'{cwd:a}/{path(_sumstat_path[0]):bn}.stderr', stdout = f'{cwd:a}/output.stdout'
    from cugg.sumstat import ss_2_vcf
    import pandas as pd
    from sos.targets import path
    def ss_2_vcf(ss_df,name = "name"):
        ## Geno field
        df = pd.DataFrame()
        if "SNP" not in ss_df.columns:
            ss_df['SNP'] = 'chr'+ss_df.CHR.astype(str).str.strip("chr") + ':' + ss_df.POS.astype(str) + '_' + ss_df.A0.astype(str) + '_' + ss_df.A1.astype(str)
        df[['#CHROM', 'POS', 'ID', 'REF', 'ALT']] = ss_df[['CHR', 'POS', 'SNP', 'A0', 'A1']].sort_values(['CHR', 'POS'])
        ## Info field(Empty)
        df['QUAL'] = "."
        df['FILTER'] = "PASS"
        df['INFO'] = "."
        fix_header = ["SNP","A1","A0","POS","CHR","STAT","SE","P"]
        header_list = []
        if "GENE" in ss_df.columns:
            df['ID'] = ss_df['GENE'] + ":" + ss_df['SNP']
            df['INFO'] = "GENE=" + ss_df["GENE"]
            fix_header = ["GENE","SNP","A1","A0","POS","CHR","STAT","SE","P"]
            header_list = ['##INFO=<ID=GENE,Number=1,Type=String,Description="The name of genes">']
        ### Fix headers
        import time
        header = '##fileformat=VCFv4.2\n' + \
        '##FILTER=<ID=PASS,Description="All filters passed">\n' + \
        f'##fileDate={time.strftime("%Y%m%d",time.localtime())}\n'+ \
        '##FORMAT=<ID=STAT,Number=1,Type=Float,Description="Effect size estimate relative to the alternative allele">\n' + \
        '##FORMAT=<ID=SE,Number=1,Type=Float,Description="Standard error of effect size estimate">\n' + \
        '##FORMAT=<ID=P,Number=1,Type=Float,Description="The Pvalue corresponding to ES">\n' 
        ### Customized Field headers
        for x in ss_df.columns:
            if x not in fix_header:
                Prefix = f'##FORMAT=<ID={x},Number=1,Type='
                # Map the Python/NumPy type of the first value to a VCF type (int -> Integer, float -> Float, str -> String)
                Type = str(type(ss_df[x][0])).replace("<class \'","").replace("'>","").replace("numpy.","").replace("64","").capitalize().replace("Int","Integer").replace("Str","String")
                Suffix = f',Description="Customized Field {x}">'
                header_list.append(Prefix+Type+Suffix)
        ## format and sample field
        df['FORMAT'] = ":".join(["STAT","SE","P"] + ss_df.drop(fix_header,axis = 1).columns.values.tolist())
        df[f'{name}'] = ss_df['STAT'].astype(str) + ":" + ss_df['SE'].astype(str) + ":" + ss_df['P'].astype(str) + ":" + ss_df.drop(fix_header,axis = 1).astype(str).apply(":".join,axis = 1)
        ## Rearrangement
        df = df[['#CHROM', 'POS', 'ID', 'REF', 'ALT', 'QUAL', 'FILTER', 'INFO','FORMAT',f'{name}']]
        # Add headers
        header = header + "\n".join(header_list) + "\n"
        return df,header

    sumstat_path_list = ${_sumstat_path}
    name = ${name}
    for x,y in zip(sumstat_path_list,name):
        sumstats = pd.read_csv(x, sep = "\t")
        sumstats,header = ss_2_vcf(sumstats,y)
        with open(f'{path(x):an}.vcf', 'w') as f:
            f.write(header)
        sumstats.to_csv(f'{path(x):an}.vcf', sep = "\t", header = True, index = False,mode = "a")

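The customized FORMAT header lines in `ss_2_vcf` derive each field's VCF `Type` from the Python type of the column's first value. A sketch of an alternative that uses the column dtype instead (a swapped-in approach, not what `ss_2_vcf` does), mapping to the VCF 4.2 FORMAT types `Integer`, `Float` and `String`; `vcf_type` and `format_header_line` are hypothetical names:

```python
import pandas as pd
from pandas.api import types as ptypes


def vcf_type(series: pd.Series) -> str:
    """Pick a VCF FORMAT Type from a pandas column dtype."""
    if ptypes.is_bool_dtype(series):
        return "String"  # Flag is not allowed for FORMAT fields, so store booleans as text
    if ptypes.is_integer_dtype(series):
        return "Integer"
    if ptypes.is_float_dtype(series):
        return "Float"
    return "String"


def format_header_line(name: str, series: pd.Series) -> str:
    """Build one ##FORMAT header line in the same shape as ss_2_vcf's customized fields."""
    return f'##FORMAT=<ID={name},Number=1,Type={vcf_type(series)},Description="Customized Field {name}">'
```

Using the dtype avoids depending on whichever value happens to be in the first row.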
In [None]:
[default_4,sumstat_to_vcf_2]
output: f'{cwd}/{_input[0]:bn}.merged.vcf.gz'.replace(name[0],"_".join(name))
task: trunk_workers = 1, trunk_size = job_size, walltime = walltime, mem = mem, cores = numThreads, tags = f'{step_name}_{_output:bn}'
bash: expand = '${ }', stderr = f'{cwd:a}/{_output:bn}.stderr', stdout = f'{cwd:a}/{_output:bn}.stdout'
    for i in ${_input:r}; do
        bgzip -k -f $i
        tabix -p vcf -f $i.gz
    done
    bcftools merge ${" ".join([f'{str(x)}.gz' for x in _input])} --force-samples -m id  -Oz -o ${_output:a}