# GWAS integration: TWAS and MR

## Introduction

This module provides software implementations for transcriptome-wide association analysis (TWAS) and Mendelian randomization (MR) using fine-mapped instrumental variables (IVs). It implements the MR procedure described in Zhang et al. (2020) for "causal" effect estimation and model validation, with a single gene-trait pair as the unit of analysis.

This procedure is based on the SuSiE-TWAS workflow: it assumes that xQTL fine-mapping has been performed (used for both TWAS and MR) and that molecular trait prediction weights have been pre-computed (used for TWAS). Cross-validation of TWAS weights is optional but highly recommended.

The required GWAS data are summary statistics and an LD matrix for the region of interest.

### Step 1: TWAS 

1. Extract GWAS z-scores for the region of interest and the corresponding LD matrix.
2. (Optional) Perform allele-matching QC between the LD matrix and the summary statistics.
3. Process weights: for LASSO, Elastic Net and mr.ash, take the weights as is for QTL variants that overlap with GWAS variants. SuSiE weights can be adjusted to exactly match the GWAS variants.
4. Perform the TWAS test for multiple sets of weights.
5. For each gene, filter TWAS results by keeping the best model selected by CV. Drop genes whose TWAS prediction weights do not show good evidence of predictive performance.
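
As an illustration of step 4, the summary-statistics TWAS test is commonly written as $z_{TWAS} = w^\top z / \sqrt{w^\top R w}$, where $w$ are the prediction weights, $z$ the GWAS z-scores and $R$ the LD matrix over the same variants. The sketch below is a pure-Python illustration of that standard formula, not the `pecotmr` implementation; the function name `twas_z` and the toy numbers are made up for the example.

```python
import math

def twas_z(weights, gwas_z, ld):
    """Summary-statistics TWAS z-score: w'z / sqrt(w'Rw).

    weights, gwas_z: lists of equal length, ordered by the same variants.
    ld: square LD (correlation) matrix as a list of lists, same variant order.
    """
    n = len(weights)
    if len(gwas_z) != n or len(ld) != n or any(len(row) != n for row in ld):
        raise ValueError("weights, gwas_z and ld must cover the same variants")
    numerator = sum(w * z for w, z in zip(weights, gwas_z))
    # Quadratic form w'Rw: variance of the weighted sum of z-scores
    variance = sum(weights[i] * ld[i][j] * weights[j]
                   for i in range(n) for j in range(n))
    if variance <= 0:
        raise ValueError("w'Rw must be positive; are all weights zero?")
    return numerator / math.sqrt(variance)

# Toy example: two variants in moderate LD (hypothetical values)
weights = [0.5, 0.0]
gwas_z = [3.0, 1.0]
ld = [[1.0, 0.3], [0.3, 1.0]]
z = twas_z(weights, gwas_z, ld)
```

Variants present in the GWAS but absent from the weights can be given a weight of zero, in which case they drop out of both the numerator and the denominator.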

### Step 2: MR for candidate genes

1. Limit MR to genes that show some evidence of TWAS significance AND have strong instrumental variables (as judged by fine-mapping PIP or CS).
2. Use fine-mapped xQTL with GWAS data to perform MR.
3. For multiple IVs, aggregate the individual IV estimates using a fixed-effect meta-analysis procedure.
4. Identify and exclude results with severe violations of the exclusion restriction (ER) assumption.
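
To make steps 2 and 3 concrete, the sketch below computes a Wald-ratio estimate per instrument and combines several estimates by inverse-variance (fixed-effect) weighting. It is a generic illustration under stated assumptions: the standard error uses a first-order approximation that ignores uncertainty in the xQTL effect, and the function names (`wald_ratio`, `fixed_effect_meta`) and numbers are hypothetical rather than taken from the pipeline.

```python
import math

def wald_ratio(beta_gwas, se_gwas, beta_qtl):
    """Wald ratio for one instrument.

    Assumption: first-order approximation, se = se_gwas / |beta_qtl|,
    which treats the xQTL effect as known without error.
    """
    if beta_qtl == 0:
        raise ValueError("xQTL effect must be non-zero for a valid instrument")
    return beta_gwas / beta_qtl, se_gwas / abs(beta_qtl)

def fixed_effect_meta(estimates, std_errors):
    """Inverse-variance weighted (fixed-effect) combination of estimates."""
    weights = [1.0 / se ** 2 for se in std_errors]
    total = sum(weights)
    beta = sum(w * b for w, b in zip(weights, estimates)) / total
    return beta, math.sqrt(1.0 / total)

# Two hypothetical instruments for one gene-trait pair
ivs = [
    {"beta_gwas": 0.10, "se_gwas": 0.02, "beta_qtl": 0.5},
    {"beta_gwas": 0.06, "se_gwas": 0.02, "beta_qtl": 0.25},
]
per_iv = [wald_ratio(iv["beta_gwas"], iv["se_gwas"], iv["beta_qtl"]) for iv in ivs]
meta_beta, meta_se = fixed_effect_meta([b for b, _ in per_iv], [s for _, s in per_iv])
```

The more precise instrument (smaller standard error) dominates the combined estimate, which is the defining property of fixed-effect weighting.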

## Input

### GWAS Data Input Interface (Similar to `susie_rss`)


I. **GWAS Summary Statistics Files**
- **Input**: Vector of files for one or more GWAS studies.
- **Format**: 
  - Tab-delimited files.
  - First 4 columns: `chr`, `pos`, `a0`, `a1`
  - Additional columns can be loaded using a column mapping file (see below).
- **Column Mapping files (optional)**:
  - Optional YAML file for custom column mapping.
  - Required columns: `chr`, `pos`, `a0`, `a1`, either `z` or (`betahat` and `sebetahat`).
  - Optional columns: `n`, `var_y` (relevant to fine-mapping).
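
The requirement "either `z` or (`betahat` and `sebetahat`)" amounts to the check sketched below: use `z` when present, otherwise derive it as `betahat / sebetahat`. The helper name `ensure_z` and the dictionary-per-row representation are assumptions made for illustration; the pipeline itself reads these columns through its own loaders.

```python
def ensure_z(record):
    """Return the z-score for one summary-statistics row (a dict).

    Uses `z` if present; otherwise derives it from `betahat` / `sebetahat`.
    """
    if record.get("z") is not None:
        return float(record["z"])
    if record.get("betahat") is None or record.get("sebetahat") is None:
        raise ValueError("need either 'z' or both 'betahat' and 'sebetahat'")
    se = float(record["sebetahat"])
    if se <= 0:
        raise ValueError("'sebetahat' must be positive")
    return float(record["betahat"]) / se

# Hypothetical row with effect size and standard error but no z-score
row = {"chr": 1, "pos": 12345, "a0": "A", "a1": "G",
       "betahat": 0.3, "sebetahat": 0.1}
z = ensure_z(row)
```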

II. **GWAS Summary Statistics Meta-File**: optional; helpful when many GWAS data sets are processed via the same command
- **Columns**: `study_id`, chromosome number, path to summary statistics file, optional path to column mapping file.
- **Note**: Chromosome number `0` indicates a genome-wide file.

e.g., `gwas_meta.tsv`:

```
study_id    chrom    file_path                 column_mapping_file
study1      1        gwas1.tsv.gz         column_mapping.yml
study1      2        gwas2.tsv.gz         column_mapping.yml
study2      0        gwas3.tsv.gz         column_mapping.yml
```

If both summary statistics files (I) and a metadata file (II) are specified, the union of the two is used.


III. **LD Reference Metadata File**
- **Format**: Single TSV file.
- **Contents**:
  - Columns: `chr`, `start`, `end`, path to the LD matrix, genomic build.
  - LD matrix path format: comma-separated, first entry is the LD matrix, second is the bim file.
- **Documentation**: Refer to our LD reference preparation document for detailed information (Tosin pending update).
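
The comma-separated path field can be split as sketched below. The function name `split_ld_entry` and the file names are hypothetical; only the "LD matrix first, bim file second" ordering comes from the description above.

```python
def split_ld_entry(path_field):
    """Split '<LD matrix path>,<bim path>' into its two components."""
    parts = [p.strip() for p in path_field.split(",")]
    if len(parts) != 2 or not all(parts):
        raise ValueError(f"expected '<LD matrix>,<bim file>', got: {path_field!r}")
    return {"ld_matrix": parts[0], "bim": parts[1]}

# Hypothetical file names, for illustration only
entry = split_ld_entry("chr1_1000_5000.ld.gz, chr1_1000_5000.bim")
```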

### Output of Fine-Mapping & TWAS Pipeline

I. **xQTL Weight Database files**
- Paths to one or more weight DB files, comma-delimited.

II. **xQTL Weight Database Metadata File**: optional; helpful when TWAS is run genome-wide for many regions via the same command
- **Types**: Gene-based or TAD-based.
- **Structure**: 
  - RDS format.
  - Organized hierarchically: region → condition → weight matrix.
  - Each column represents a different method.
- **Format**: columns `chrom`, `start`, `end`, `region_id`, `condition` (e.g., tissue type, QTL type) and `file_path` (comma-delimited paths to the weight DB files).

e.g., `xqtl_meta.tsv`:

```
chrom    start    end    region_id    condition    file_path
1        1000     5000   region1      cohort1:tissue1:eQTL      weight1.rds, weight2.rds
2        2000     6000   region2      cohort1:tissue1:eQTL      weight3.rds
3        3000     7000   region3      cohort1:tissue1:eQTL      weight4.rds, weight5.rds
```

## Output

TWAS (FIXME: the description below is currently incorrect.)

- Each row corresponds to a single SNP inferred as a member of a signal cluster, with columns including:
   - `snp`: SNP name.
   - `beta_eQTL`: eQTL effect.
   - `se_eQTL`: Standard error of estimated eQTL effect.
   - `beta_GWAS`: GWAS effect.
   - `se_GWAS`: Standard error of GWAS effect.
   - `cluster`: Signal cluster ID (credible sets index).
   - `pip`: SNP posterior inclusion probability (PIP).
   - `gene_id`: Gene name.


MR

-  The output includes the following columns for each gene:
    - `gene_id`: Gene name.
    - `num_cluster`: Number of credible sets of the gene.
    - `num_instruments`: Number of instruments included in the gene.
    - `spip`: Sum of PIP for credible sets of each gene.
    - `grp_beta`: Signal-level estimates, combining SNP-level estimates from member SNPs weighted by their PIPs.
    - `grp_se`: Standard error of signal-level estimates.
    - `meta`: Gene-level estimate of the causal effect, aggregating signal-level estimates using a fixed-effect meta-analysis model.
    - `se_meta`: Standard error of the gene-level estimate of the causal effect.
    - `Q`: Cochran’s Q statistic.
    - `I2`: $I^2$ statistic.
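
For reference, Cochran's Q and $I^2$ can be computed from the signal-level estimates as sketched below, using the standard definitions $Q = \sum_k w_k (\hat\beta_k - \hat\beta_{meta})^2$ with $w_k = 1/se_k^2$, and $I^2 = \max(0, (Q - df)/Q)$ with $df = K - 1$. This is a generic illustration with made-up inputs; the helper name `heterogeneity` is hypothetical, and the pipeline's own output may differ in details such as reporting $I^2$ as a percentage.

```python
def heterogeneity(estimates, std_errors):
    """Cochran's Q and I^2 for a set of signal-level estimates."""
    weights = [1.0 / se ** 2 for se in std_errors]
    pooled = sum(w * b for w, b in zip(weights, estimates)) / sum(weights)
    q = sum(w * (b - pooled) ** 2 for w, b in zip(weights, estimates))
    df = len(estimates) - 1
    # I^2: share of the variation across estimates beyond what chance would explain
    i2 = max(0.0, (q - df) / q) if q > 0 else 0.0
    return q, i2

# Three hypothetical signal-level estimates (grp_beta) and standard errors (grp_se)
q, i2 = heterogeneity([0.20, 0.24, 0.50], [0.04, 0.08, 0.05])
```

A large Q relative to its degrees of freedom, or an $I^2$ close to 1, indicates that the signal-level estimates disagree more than sampling error would suggest.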

In [None]:
[global]
parameter: cwd = path("output/")
parameter: gwas_meta_data = path()
parameter: xqtl_meta_data = path()
parameter: ld_meta_data = path()
parameter: gwas_name = []
parameter: gwas_data = []
parameter: column_mapping = []
parameter: name = f"{xqtl_meta_data:bn}.{gwas_meta_data:bn}"
parameter: container = ''
import re
parameter: entrypoint= ('micromamba run -a "" -n' + ' ' + re.sub(r'(_apptainer:latest|_docker:latest|\.sif)$', '', container.split('/')[-1])) if container else ""
parameter: job_size = 100
parameter: walltime = "5m"
parameter: mem = "8G"
parameter: numThreads = 1

import os
import pandas as pd

def adapt_file_path(file_path, reference_file):
    """
    Adapt a single file path based on its existence and a reference file's path.

    Args:
    - file_path (str): The file path to adapt.
    - reference_file (str): File path to use as a reference for adaptation.

    Returns:
    - str: Adapted file path.

    Raises:
    - FileNotFoundError: If no valid file path is found.
    """
    reference_path = os.path.dirname(reference_file)

    # Check if the file exists
    if os.path.isfile(file_path):
        return file_path

    # Check file name without path
    file_name = os.path.basename(file_path)
    if os.path.isfile(file_name):
        return file_name

    # Check file name in reference file's directory
    file_in_ref_dir = os.path.join(reference_path, file_name)
    if os.path.isfile(file_in_ref_dir):
        return file_in_ref_dir

    # Check original file path prefixed with reference file's directory
    file_prefixed = os.path.join(reference_path, file_path)
    if os.path.isfile(file_prefixed):
        return file_prefixed

    # If all checks fail, raise an error
    raise FileNotFoundError(f"No valid path found for file: {file_path}")

def group_by_region(lst, partition):
    # from itertools import accumulate
    # partition = [len(x) for x in partition]
    # Compute the cumulative sums once
    # cumsum_vector = list(accumulate(partition))
    # Use slicing based on the cumulative sums
    # return [lst[(cumsum_vector[i-1] if i > 0 else 0):cumsum_vector[i]] for i in range(len(partition))]
    return partition

In [None]:
[get_analysis_regions: shared = "regional_data"]
from collections import OrderedDict

def check_required_columns(df, required_columns):
    """Check if the required columns are present in the dataframe."""
    missing_columns = [col for col in required_columns if col not in list(df.columns)]
    if missing_columns:
        raise ValueError(f"Missing required columns: {', '.join(missing_columns)}")

def extract_regional_data(gwas_meta_data, xqtl_meta_data, gwas_name, gwas_data, column_mapping):
    """
    Extracts data from GWAS and xQTL metadata files and additional GWAS data provided. 

    Args:
    - gwas_meta_data (str): File path to the GWAS metadata file.
    - xqtl_meta_data (str): File path to the xQTL weight metadata file.
    - gwas_name (list): vector of GWAS study names.
    - gwas_data (list): vector of GWAS data.
    - column_mapping (list, optional): vector of column mapping files.

    Returns:
    - Tuple of two dictionaries:
        - GWAS Dictionary: Maps study IDs to a list containing chromosome number, 
          GWAS file path, and optional column mapping file path.
        - xQTL Dictionary: Nested dictionary with region IDs as keys.

    Raises:
    - FileNotFoundError: If any specified file path does not exist.
    - ValueError: If required columns are missing in the input files or vector lengths mismatch.
    """
    # Check vector lengths
    if len(gwas_name) != len(gwas_data):
        raise ValueError("gwas_name and gwas_data must be of equal length")
    
    if len(column_mapping)>0 and len(column_mapping) != len(gwas_name):
        raise ValueError("If column_mapping is provided, it must be of the same length as gwas_name and gwas_data")

    # Required columns for each file type
    required_gwas_columns = ['study_id', 'chrom', 'file_path']
    required_xqtl_columns = ['region_id', 'chrom', 'start', 'end', 'condition', 'file_path']
    
    # Reading the GWAS metadata file
    gwas_df = pd.read_csv(gwas_meta_data, sep="\t")
    check_required_columns(gwas_df, required_gwas_columns)
    gwas_dict = OrderedDict()

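    # Process additional GWAS data passed directly via `gwas_name` / `gwas_data`.
    # These files are treated as genome-wide, so register them for chromosomes 1-22,
    # consistent with how chrom 0 is expanded for the metadata file below.
    for name, data, mapping in zip(gwas_name, gwas_data, column_mapping or [None]*len(gwas_name)):
        gwas_dict[name] = {chrom: [data, mapping] for chrom in range(1, 23)}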

    for _, row in gwas_df.iterrows():
        file_path = row['file_path']
        mapping_file = row.get('column_mapping_file')
        
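        # Adjust paths if necessary
        file_path = adapt_file_path(file_path, gwas_meta_data)
        # The mapping column is optional; pandas yields NaN for empty cells
        if mapping_file is not None and pd.notna(mapping_file):
            mapping_file = adapt_file_path(mapping_file, gwas_meta_data)
        else:
            mapping_file = None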

        # Create or update the entry for the study_id
        if row['study_id'] not in gwas_dict:
            gwas_dict[row['study_id']] = {}

        # Expand chrom 0 to chrom 1-22 or use the specified chrom
        chrom_range = range(1, 23) if row['chrom'] == 0 else [row['chrom']]
        for chrom in chrom_range:
            if chrom in gwas_dict[row['study_id']]:
                existing_entry = gwas_dict[row['study_id']][chrom]
                raise ValueError(f"Duplicate chromosome specification for study_id {row['study_id']}, chrom {chrom}. "
                                 f"Conflicting entries: {existing_entry} and {[file_path, mapping_file]}")
            gwas_dict[row['study_id']][chrom] = [file_path, mapping_file]

    # Reading the xQTL weight metadata file
    xqtl_df = pd.read_csv(xqtl_meta_data, sep="\t")
    check_required_columns(xqtl_df, required_xqtl_columns)
    xqtl_dict = OrderedDict()
    for _, row in xqtl_df.iterrows():
        file_paths = [adapt_file_path(fp.strip(), xqtl_meta_data) for fp in row['file_path'].split(',')]  # Splitting and stripping file paths
        xqtl_dict[row['region_id']] = {"meta_info": [row['chrom'], row['start'], row['end'], row['region_id'], row['condition']],
                                       "files": file_paths}
    return gwas_dict, xqtl_dict

gwas_dict, xqtl_dict = extract_regional_data(gwas_meta_data, xqtl_meta_data, gwas_name, gwas_data, column_mapping)
regional_data = dict([("GWAS", gwas_dict), ("xQTL", xqtl_dict)])

In [None]:
[twas_mr]
depends: sos_variable("regional_data")
meta_info = [x["meta_info"] for x in regional_data['xQTL'].values()]
xqtl_files = [x["files"] for x in regional_data['xQTL'].values()]
input: xqtl_files, group_by = lambda x: group_by_region(x, xqtl_files), group_with = "meta_info"
output: f'{cwd:a}/{step_name[:-2]}/{name}.{_meta_info[3]}.twas_mr.txt'
task: trunk_workers = 1, trunk_size = job_size, walltime = walltime, mem = mem, cores = numThreads, tags = f'{step_name}_{_output:bn}'
R: expand = '${ }', stdout = f"{_output:n}.stdout", stderr = f"{_output:n}.stderr", container = container, entrypoint = entrypoint
    # we have potentially multiple weight db RDS files for each region of interest
    weight_db = c(${_input:r,})
    chrom = ${_meta_info[0]}
    start = ${_meta_info[1]} 
    end = ${_meta_info[2]}
    region = "${_meta_info[3]}"
    xqtl_conditions = c(${paths(_meta_info[4:]):r,})
    gwas_studies = c(${paths(regional_data["GWAS"].keys()):r,})
    # load gwas data file for this particular chrom
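    # each entry is [file_path, mapping_file]; keep only the summary statistics path so that gwas_files aligns with gwas_studies
    gwas_files = c(${paths([v[_meta_info[0]][0] for k, v in regional_data["GWAS"].items()]):r,})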
    library(pecotmr)
    library(data.table)  # fread
    library(dplyr)       # %>%, rename, mutate
    # Step 0: Load GWAS data for the region of interest, for each study
    # Generate the region of interest
    region_of_interest = data.frame(chrom = chrom, start = start, end = end)
    # Generate the LD_meta_file data frame as the input of load_LD_matrix function
    LD_metadata = genomic_file_paths(paste0(ld_meta_data,"/",chrom,sep=""))
    gwas_data = list()
    for (s in seq_along(gwas_studies)) {
      gwas_data[[gwas_studies[s]]] = list()
      gwas_sumstats = fread(gwas_files[s]) %>%
                 rename("pos" = "position", "chrom" = "chromosome", "A1" = "ref", "A2" = "alt") %>%
                 mutate(z = beta / se)
      # Load LD list containing LD matrix and corresponding variants
      gwas_LD_list = load_LD_matrix(LD_metadata, region_of_interest, gwas_sumstats)
      # Allele flip
      gwas_allele_flip = allele_qc(gwas_sumstats, gwas_LD_list$combined_LD_variants, match.min.prop=0.2, remove_dups=FALSE, flip=TRUE, remove=TRUE)
      # Store LD matrix and sumstats; use [[ ]] so that each object is stored whole as one list element
      gwas_data[[gwas_studies[s]]][["LD"]] = gwas_LD_list$combined_LD_matrix
      gwas_data[[gwas_studies[s]]][["sumstats"]] = gwas_allele_flip %>% mutate(variant_allele_flip = paste(chrom, pos, A1.sumstats, A2.sumstats, sep=":"))
    }
    weights = list()
    for (condition in xqtl_conditions) {
        # For each region of interest and a particular xQTL condition, we can perform TWAS for all the input studies, with all of the weights available
        # Step 1: load the weight matrix: a matrix of weights for the specified condition and region; each column holds the weights for one method. The row names should be variant names.
        weights[[condition]] = list()
        # For each condition, we extract the weights from each of the (possibly multiple) weight DB RDS files
        for (weight_db_file in seq_along(weight_db)) {
            weights_matrix = load_twas_weights(weight_db[weight_db_file], condition, variable_name_obj="variant_names", twas_weights_table = "twas_weights")
            # Step 2: for each study we perform TWAS:
            # take the overlap between weights and the gwas_data sumstats for each study. If some variants are in sumstats but not in weights, simply set their weights to zero
            for (study in gwas_studies) {
                adjusted_susie_weights = adjust_susie_weights(weight_db[weight_db_file], weights_matrix, keep_variants = gwas_data[[study]][["sumstats"]]$variant_allele_flip)
                # Step 2-1: NOT YET IMPLEMENTED (FIXME) adjust susie_weights if possible, i.e. if susie_fit exists in weight_db. This should replace the original SuSiE weights computed above
                # Step 3: perform TWAS on the weights and sumstats
            }
        }
    }