# MASH analysis pipeline with posterior computation


## Compute MASH posteriors

In the GTEx V6 paper we assumed one eQTL per gene and applied the model learned above to those SNPs. Under that assumption, the input data for posterior calculation will be the `dat$strong.*` matrices.
It is a fairly straightforward procedure as shown in [this vignette](https://stephenslab.github.io/mashr/articles/eQTL_outline.html).

It is often more interesting, however, to apply MASH to a given list of eQTLs, e.g. those identified by fine-mapping. In the GTEx V8 analysis we obtain such gene-SNP pairs from DAP-G fine-mapping. See [this notebook](https://stephenslab.github.io/gtex-eqtls/analysis/Independent_eQTL_Results.html) for how the input data is prepared. The workflow below takes a number of input chunks (each chunk is a list of matrices `dat$Bhat` and `dat$Shat`) 
and computes the posterior for each chunk. Provided the input is supplied as chunks, posterior computation for all gene-SNP pairs can therefore run in parallel.


```
JOB_OPT="-c midway2.yml -q midway2"
DATA_DIR=/project/compbio/GTEx_eQTL/independent_eQTL
sos run workflows/mashr_flashr_workflow.ipynb posterior \
    $JOB_OPT \
    --posterior-input $DATA_DIR/DAPG_pip_gt_0.01-AllTissues/DAPG_pip_gt_0.01-AllTissues.*.rds \
                      $DATA_DIR/ConditionalAnalysis_AllTissues/ConditionalAnalysis_AllTissues.*.rds
```
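
To illustrate the chunking idea, here is a minimal Python sketch of how a list of input files can be split into batches of `per_chunk` files, roughly what the `group_by = per_chunk` input option does in the `[posterior_1]` step further down. The function name `chunk_files` and the file names are hypothetical and only serve the illustration.

```python
def chunk_files(files, per_chunk):
    """Split a list of input files into consecutive batches of size per_chunk."""
    if per_chunk < 1:
        raise ValueError("per_chunk must be a positive integer")
    return [files[i:i + per_chunk] for i in range(0, len(files), per_chunk)]


# Hypothetical input chunks; the real ones come from --posterior-input
files = [f"chunk_{i}.rds" for i in range(1, 6)]
batches = chunk_files(files, per_chunk=2)
for index, batch in enumerate(batches, start=1):
    print(f"batch {index}: {batch}")
```

Each batch can then be submitted as an independent job, since the posterior for one chunk does not depend on any other chunk.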

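Group recipe: a plain-text file passed via `--grouping_recipe`, with one group per line and the members of a group separated by commas. Add one line per additional group. Every line is read as a group, so the file should not contain comment lines.
```
Ast,Exc
Mic,MiGA_GFM,MiGA_GTS
Oli,DLPFC_pQTL
```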
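
As a rough illustration of how such a recipe is turned into a grouping vector, the Python sketch below assigns each condition the index of the recipe line it appears on, and 0 if it appears on none. This mirrors the logic of the R code in the contrast step, but the function name `grouping_from_recipe`, the skipping of blank lines, and the error on unknown names are assumptions added for this example.

```python
def grouping_from_recipe(recipe_lines, cells):
    """Map each condition to a group index; 0 means the condition is ungrouped."""
    grouping = {cell: 0 for cell in cells}
    group_index = 0
    for line in recipe_lines:
        names = [name.strip() for name in line.split(",") if name.strip()]
        if not names:
            continue  # assumption: blank lines are ignored
        group_index += 1
        for name in names:
            if name not in grouping:
                # assumption: unknown names are treated as an error here
                raise ValueError(f"{name!r} is in the recipe but not in cells")
            grouping[name] = group_index
    return grouping


cells = ["Ast", "Exc", "Inh", "Mic", "OPC", "Oli",
         "DLPFC_pQTL", "MiGA_GFM", "MiGA_GTS"]
recipe = ["Ast,Exc", "Mic,MiGA_GFM,MiGA_GTS", "Oli,DLPFC_pQTL"]
grouping = grouping_from_recipe(recipe, cells)
print(grouping)
```

Conditions that share a non-zero index are treated as populations of the same group when the contrast weights are built.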

### example

##### posterior: works

In [None]:
#no slice
sos run /home/rf2872/codes/xqtl-pipeline/code/multivariate/MASH/mash_posterior.ipynb posterior  \
    --container /mnt/vast/hpc/csg/containers/xqtl_archive/stephenslab.sif \
    --analysis-units <(cat  MWE.list| cut -f 2) \
    --cwd MWE_udr  \
    --output_prefix  MWE_udr \
    --vhat mle \
    --mash_model MWE_udr/MWE_udr.EZ.V_mle.mash_model.rds -n

In [None]:
#slice
sos run /home/rf2872/codes/xqtl-pipeline/code/multivariate/MASH/mash_posterior.ipynb posterior  \
    --container /mnt/vast/hpc/csg/containers/xqtl_archive/stephenslab.sif \
    --analysis-units <(cat  MWE.list| cut -f 2) \
    --cwd MWE_udr  \
    --output_prefix  MWE_udr \
    --vhat mle \
    --mash_model MWE_udr/MWE_udr.EZ.V_mle.mash_model.rds \
    --slice_method True

##### contrast analysis: works 

In [None]:
# 0 group
sos run /home/rf2872/codes/xqtl-pipeline/code/multivariate/MASH/mash_posterior.ipynb mash_posterior_contrast   \
    --posterior_file test.list \
    --sum_file /mnt/vast/hpc/csg/rf2872/Work/Multivariate/MASH/From_SuSiE/2023.5_new/ep_MiGA/output/ROSMAP_Pseudo_eQTL_DLPFC_pQTL_MiGA.merged_rds.list  \
    --cwd test \
    --cells "Ast","Exc","Inh","Mic","OPC","Oli","DLPFC_pQTL","MiGA_GFM","MiGA_GTS","MiGA_SVZ","MiGA_THA" 

In [None]:
# 1 group
sos run /home/rf2872/codes/xqtl-pipeline/code/multivariate/MASH/mash_posterior.ipynb mash_posterior_contrast   \
    --posterior_file test.list \
    --sum_file /mnt/vast/hpc/csg/rf2872/Work/Multivariate/MASH/From_SuSiE/2023.5_new/ep_MiGA/output/ROSMAP_Pseudo_eQTL_DLPFC_pQTL_MiGA.merged_rds.list  \
    --cwd test \
    --cells "Ast","Exc","Inh","Mic","OPC","Oli","DLPFC_pQTL","MiGA_GFM","MiGA_GTS","MiGA_SVZ","MiGA_THA" \
    --group1 "Mic","MiGA_GFM","MiGA_GTS","MiGA_SVZ","MiGA_THA" 

In [None]:
# 2 group
sos run /home/rf2872/codes/xqtl-pipeline/code/multivariate/MASH/mash_posterior.ipynb mash_posterior_contrast   \
    --posterior_file test.list \
    --sum_file /mnt/vast/hpc/csg/rf2872/Work/Multivariate/MASH/From_SuSiE/2023.5_new/ep_MiGA/output/ROSMAP_Pseudo_eQTL_DLPFC_pQTL_MiGA.merged_rds.list  \
    --cwd test \
    --cells "Ast","Exc","Inh","Mic","OPC","Oli","DLPFC_pQTL","MiGA_GFM","MiGA_GTS","MiGA_SVZ","MiGA_THA" \
    --group1 "Mic","MiGA_GFM","MiGA_GTS","MiGA_SVZ","MiGA_THA" \
    --group2 "Ast","Exc"

In [None]:
# recipe with 2 group
sos run /home/rf2872/codes/xqtl-pipeline/code/multivariate/MASH/mash_posterior.ipynb mash_posterior_contrast   \
    --posterior_file test.list \
    --sum_file /mnt/vast/hpc/csg/rf2872/Work/Multivariate/MASH/From_SuSiE/2023.5_new/ep_MiGA/output/ROSMAP_Pseudo_eQTL_DLPFC_pQTL_MiGA.merged_rds.list  \
    --cwd test \
    --cells "Ast","Exc","Inh","Mic","OPC","Oli","DLPFC_pQTL","MiGA_GFM","MiGA_GTS","MiGA_SVZ","MiGA_THA" \
    --grouping_recipe recipe

In [None]:
[global]
import os
# Work directory & output directory
parameter: cwd = path('./')
# The filename prefix for output data
parameter: name="test"
parameter: cells = ["Ast","Exc","Inh","Mic","OPC","Oli","DLPFC_pQTL"]#order is important
parameter: group1 = []
parameter: group2 = []
parameter: group3 = []
parameter: job_size = 1
parameter: container = ''
parameter: table_name = ""
parameter: posterior_file_list = "w"
#parameter: orig_file = path
parameter: orig_file_list = "/mnt/vast/hpc/csg/rf2872/Work/MASH_test_csg/output/ALL_Ast_End_Exc_Inh_Mic_OPC_Oli.merged_rds.list"
#parameter:contrast_dir = "/home/rf2872/Work/MASH_test_csg/MASH_6_celltypes_Dan/contrast/"
##  conditions can be excluded if needs arise. If nothing to exclude keep the default 0
parameter: datadir = ""
import pandas as pd
#parameter: analysis_units = path
# handle N = per_chunk data-set in one job
parameter: per_chunk = 1
###add for test
parameter: output_prefix = ''
parameter: output_suffix = 'all'
# Exchangeable effect (EE) or exchangeable z-scores (EZ)
parameter: effect_model = 'EZ'
# Identifier of $\hat{V}$ estimate file
# Options are "identity", "simple", "mle", "vhat_corshrink_xcondition", "vhat_simple_specific"
parameter: vhat = 'simple'
parameter: data = path("fastqtl_to_mash_output/FastQTLSumStats.mash.rds")
data = data.absolute()
cwd = cwd.absolute()
if len(output_prefix) == 0:
    output_prefix = f"{data:bn}"
vhat_data = file_target(f"{cwd:a}/{output_prefix}.{effect_model}.V_{vhat}.rds")
mash_model = file_target(f"{cwd:a}/{output_prefix}.{effect_model}.V_{vhat}.mash_model.rds")
posterior_list = file_target(f"{cwd:a}/{output_prefix}.{effect_model}.{output_suffix}.posterior_list")

def sort_uniq(seq):
    seen = set()
    return [x for x in seq if not (x in seen or seen.add(x))]

### Posterior results

1. The `[posterior]` step produces a number of serialized R objects `*.batch_*.posterior.rds`, which can be loaded into R via `readRDS()`. The data is split into batches so that the computation can be spread across multiple cluster nodes. The contents should be self-explanatory; please let me know if they are not.
2. Other posterior-related files are:
    1. `*.batch_*.yaml`: gene-SNP pairs of interest, identified elsewhere (e.g. by fine-mapping analysis).
    2. `*.batch_*.rds`: the univariate summary statistics for the gene-SNP pairs in `*.batch_*.yaml`, extracted and saved as input to the `[posterior]` step.
    3. `*.batch_*.stdout`: a log that documents SNPs found in the fine-mapping results but not in the original `fastqtl` output.

## Slice Posterior 

We take all 13K genes. For genes with missing conditions, we drop the rows and columns that correspond to those conditions from the prior model before computing posteriors.
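
The slicing idea can be sketched in a few lines of plain Python. This is a simplified illustration, not the workflow's R implementation: it drops conditions (columns) that are missing for every SNP, drops SNPs (rows) that still contain a missing value, and subsets one prior covariance matrix to the surviving conditions. The function name `slice_missing` and the toy numbers are assumptions.

```python
import math


def is_missing(value):
    """Treat None and NaN as missing."""
    return value is None or (isinstance(value, float) and math.isnan(value))


def slice_missing(bhat, conditions, prior_cov):
    """Drop all-missing conditions, then incomplete SNP rows, and subset the prior."""
    keep_cols = [j for j in range(len(conditions))
                 if any(not is_missing(row[j]) for row in bhat)]
    keep_rows = [i for i, row in enumerate(bhat)
                 if all(not is_missing(row[j]) for j in keep_cols)]
    bhat_sliced = [[bhat[i][j] for j in keep_cols] for i in keep_rows]
    cov_sliced = [[prior_cov[a][b] for b in keep_cols] for a in keep_cols]
    kept_conditions = [conditions[j] for j in keep_cols]
    return bhat_sliced, kept_conditions, cov_sliced, keep_rows


nan = float("nan")
conditions = ["Ast", "Exc", "Mic"]
bhat = [[0.1, nan, 0.3],   # Exc is missing for every SNP of this gene
        [0.2, nan, nan],   # this SNP also lacks Mic, so the row is dropped
        [0.4, nan, 0.5]]
prior_cov = [[1.0, 0.5, 0.2],
             [0.5, 1.0, 0.3],
             [0.2, 0.3, 1.0]]
bhat_s, kept, cov_s, rows = slice_missing(bhat, conditions, prior_cov)
print(kept, rows, cov_s)
```

In the workflow the same subsetting has to be applied to every covariance matrix in the fitted prior, and to the residual correlation matrix `vhat`, so that all dimensions agree.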

In [None]:
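# Compute posteriors, optionally slicing out conditions with missing (NA) data. NaN in bhat is set to 0, NaN/Inf in sbhat is set to 1E3, and posterior covariances are returned (output_posterior_cov = TRUE)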
[posterior_1]
parameter: analysis_units = path
regions = [x.replace("\"","").strip().split() for x in open(analysis_units).readlines() if x.strip() and not x.strip().startswith('#')]
parameter: mash_model = path()
parameter: posterior_input = [path(x[0]) for x in regions]
parameter: posterior_vhat_files = paths()
# eg, if data is saved in R list as data$strong, then
# when you specify `--data-table-name strong` it will read the data as
# readRDS('{_input:r}')$strong
parameter: data_table_name = ''
parameter: bhat_table_name = 'bhat'
parameter: shat_table_name = 'sbhat'
parameter: per_chunk = '100'
## Column indices of conditions to exclude if the need arises. Pass 0 to exclude nothing; note that the default below excludes conditions 1 and 3
parameter: exclude_condition = ["1","3"]
parameter: output_prefix = "ROSMAP"
parameter: effect_model = 'EZ'
# Options are "identity", "simple", "mle", "vhat_corshrink_xcondition", "vhat_simple_specific"
parameter: vhat = 'simple'
parameter: slice_method = False 
skip_if(len(posterior_input) == 0, msg = "No posterior input data to compute on. Please specify it using --posterior-input.")
fail_if(len(posterior_vhat_files) > 1 and len(posterior_vhat_files) != len(posterior_input), msg = "length of --posterior-input and --posterior-vhat-files do not agree.")
for p in posterior_input:
    fail_if(not p.is_file(), msg = f'Cannot find posterior input file ``{p}``')
vhat_data = file_target(f"{cwd:a}/{output_prefix}.{effect_model}.V_{vhat}.rds")
mash_model = file_target(f"{cwd:a}/{output_prefix}.{effect_model}.V_{vhat}.mash_model.rds")
input: posterior_input, group_by = per_chunk
output: f"{cwd}/cache/mash_output_list_{_index+1}"
task: trunk_workers = 1, walltime = '20h', trunk_size = 1, mem = '20G', cores = 1, tags = f'{_output:bn}'
R: expand = "${ }", workdir = cwd, stderr = f"{_output:n}.stderr", stdout = f"{_output:n}.stdout", container = container
    library(mashr)
    library(dplyr)
    library(stringr)
    #library(ttt)
    handle_nan_etc = function(x) {
      x$bhat[which(is.nan(x$bhat))] = 0
      x$sbhat[which(is.nan(x$sbhat) | is.infinite(x$sbhat))] = 1E3
      return(x)
    }
    # Slice matrices
    slice_and_update_data <- function(data, vhat, snps, samples) {
        data$bhat <- data$bhat[snps, samples] %>% as.matrix
        data$sbhat <- data$sbhat[snps, samples] %>% as.matrix
        data$Z <- data$Z[snps, samples] %>% as.matrix
        vhat <- vhat[samples, samples] %>% as.matrix

        # Filter SNPs and update column names
        data$snp <- data$snp[data$snp %in% snps]
        colnames(data$bhat) <- colnames(data$sbhat) <- colnames(data$Z) <- colnames(vhat) <- samples

        return(list(data = data, vhat = vhat))
    }
  
    # Remove covariance matrices that are not needed
    remove_unnecessary_cov_matrices <- function(cov, all_samples, samples) {
      unwanted_samples <- setdiff(all.samples, samples)
      for (d in names(cov)) {
        if (d %in% unwanted_samples || d %in% paste0("ED_", unwanted_samples)) {
          cov[[d]] <- NULL
        }
      }
      return(cov)
    }

    # Update or adjust the covariance matrices
    adjust_cov_matrices <- function(cov, samples) {
      for (d in names(cov)) {
        if (d %in% samples) {
          cov[[d]] <- matrix(0, length(samples), length(samples))
          cov[[d]][which(samples == d), which(samples == d)] <- 1
        } else if (d == "identity") {
          cov[[d]] <- matrix(0, length(samples), length(samples))
          cov[[d]][1, 1] <- 1  
        } else if (is.null(colnames(cov[[d]]))) {
          cov[[d]] <- cov[[d]][1:length(samples), 1:length(samples)]
        } else {
          cov[[d]] <- cov[[d]][samples, samples]
        }
        cov[[d]] <- as.matrix(cov[[d]])
      }
      return(cov)
    }

    # Main function to update the covariance in the MASH model
    update_mash_model_cov <- function(mash_model, all_samples, samples) {
      cov <- mash_model$fitted_g$Ulist

      # Remove matrices that are not required
      cov <- remove_unnecessary_cov_matrices(cov, all_samples, samples)

      # Update or reshape the covariance matrices
      cov <- adjust_cov_matrices(cov, samples)

      # Update the covariance matrices in the model
      mash_model$fitted_g$Ulist <- cov

      # Update the 'pi' attribute of the model
      unwanted_samples <- setdiff(all.samples, samples)
      for (s in unwanted_samples) {
        mash_model$fitted_g$pi <- mash_model$fitted_g$pi[-grep(s, names(mash_model$fitted_g$pi))]
      }

      return(mash_model)
    }
  
    #slice = TRUE  # Set this to either TRUE or FALSE

    outlist = data.frame()
    for (f in c(${_input:r,})) try({

      data = readRDS(f)${('$' + data_table_name) if data_table_name else ''}
      data <- handle_nan_etc(data)

      if(c(${",".join(exclude_condition)})[1] > 0 ){
        message(paste("Excluding condition ${exclude_condition} from the analysis"))
        data$bhat = data$bhat[,-c(${",".join(exclude_condition)})]
        data$sbhat = data$sbhat[,-c(${",".join(exclude_condition)})]
        data$Z = data$Z[,-c(${",".join(exclude_condition)})]
      }

      vhat = readRDS("${vhat_data if len(posterior_vhat_files) == 0 else posterior_vhat_files[_index]}")
      mash_model <- readRDS("${mash_model}")
  
      slice_method <- ${'TRUE' if slice_method else 'FALSE'}
      if(slice_method){
        # All additional operations from the second script go here

        all.samples <- colnames(data$bhat)
        all.snps <- rownames(data$bhat)    

        #remove the rows and cols containing NA
        na.test <- data$bhat %>% as.data.frame %>% select_if(~any(!is.na(.))) %>% na.omit %>% as.matrix

        #recording meaningful rows and cols
        samples <- colnames(na.test)
        snps <- rownames(na.test)

        if(length(all.snps)!=length(snps) | length(all.samples)!=length(samples)){
            # slice data matrix; the helper returns both the sliced data and the sliced vhat
            sliced <- slice_and_update_data(data, vhat, snps, samples)
            data <- sliced$data
            vhat <- sliced$vhat

            if(length(all.samples)!=length(samples)){
                ##slice the prior
                mash_model <- update_mash_model_cov(mash_model, all.samples, samples)
            }
        }
      }

      mash_data = mash_set_data(data$${bhat_table_name}, Shat=data$${shat_table_name}, alpha=${1 if effect_model == 'EZ' else 0}, V=vhat, zero_Bhat_Shat_reset = 1E3)
      mash_output = mash_compute_posterior_matrices(mash_model, mash_data, output_posterior_cov=TRUE)
      mash_output$snps = data$snps
      samplename <- str_split(f, "/", simplify = T) %>% .[length(.)] %>% gsub('.rds', '', .)
      saveRDS(mash_output, paste0("${_output:d}", "/", samplename, ".posterior.rds"))
      outlist <- rbind(outlist, paste0("${_output:d}", "/", samplename, ".posterior.rds"))

    })
    write.table(outlist, ${_output:r}, col.names=F, row.names=F, quote=F)


In [None]:
[*posterior_2]
input: group_by = "all"
output:f"{cwd}/mash_output_list_{output_suffix}"
bash: expand ='${ }', workdir = cwd, stderr = f"{_output:n}.stderr", stdout = f"{_output:n}.stdout"
     cd ${_input[0]:d}
     cat mash_output_list_*[0-9] >> posterior_file_list
     awk -F 'cis_long_table.' '{print $2}' posterior_file_list| awk -F '.posterior.rds' '{print $1}'|paste - posterior_file_list > ${_output:r}
     rm posterior_file_list


## Posterior contrast
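
For orientation, the contrast step computes, for a weight vector $c$, posterior mean $m$ and posterior covariance $V$, the contrast estimate $c^\top m$, its standard error $\sqrt{c^\top V c}$, and a two-sided normal p-value. The Python sketch below reproduces that arithmetic for a single SNP with made-up numbers, and builds the ungrouped deviation weights (1 for the target condition, $-1/(n-1)$ for the others). Function names are hypothetical.

```python
import math


def deviation_weights(n_conditions, target):
    """Weights contrasting one condition against the mean of all the others."""
    return [1.0 if k == target else -1.0 / (n_conditions - 1)
            for k in range(n_conditions)]


def contrast_stats(weights, post_mean, post_cov):
    """Return (estimate, standard error, two-sided p-value) of a linear contrast."""
    n = len(weights)
    diff = sum(w * m for w, m in zip(weights, post_mean))
    var = sum(weights[i] * post_cov[i][j] * weights[j]
              for i in range(n) for j in range(n))
    se = math.sqrt(var)
    # erfc(z / sqrt(2)) equals 2 * (1 - Phi(z)) for the standard normal CDF Phi
    p_value = math.erfc(abs(diff) / se / math.sqrt(2))
    return diff, se, p_value


# Made-up posterior summaries for one SNP in two conditions
post_mean = [0.5, 0.1]
post_cov = [[0.04, 0.01],
            [0.01, 0.04]]
diff, se, p_value = contrast_stats([1.0, -1.0], post_mean, post_cov)
print(diff, se, p_value)
print(deviation_weights(3, 0))
```

Because the posterior covariance enters through $c^\top V c$, the `[posterior]` step must be run with `output_posterior_cov=TRUE` for this to work.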

In [2]:
# perform mash posterior contrast for sliced data
[mash_posterior_contrast_1]
parameter: grouping_recipe = ""
parameter: posterior_file = path
parameter: sum_file = path

# Extract data from posterior_file
paths_posterior = [x.replace("\"","").strip().split()[1] for x in open(posterior_file).readlines() if x.strip() and not x.strip().startswith('#')]

# Create a dictionary from sum_file for quick lookup
dict_sum = dict([(x.replace("\"","").strip().split()[0], x.replace("\"","").strip().split()[1]) for x in open(sum_file).readlines() if x.strip() and not x.strip().startswith('#')])

# Use genes from posterior_file to fetch corresponding paths from sum_file
paths_sum = [dict_sum[x.replace("\"","").strip().split()[0]] for x in open(posterior_file).readlines() if x.strip() and not x.strip().startswith('#')]

input: paths_posterior, paired_with='paths_sum', group_by=1
output: f"{cwd}/{_input:bnn}_posterior_contrast.rds"
task: trunk_workers = 1, trunk_size = job_size, walltime = '24h',  mem = '10G', tags = f'{_output:bn}'
R: expand = "${ }",stderr = f'{_output:n}.stderr', stdout = f'{_output:n}.stdout' 
    # Load necessary libraries
    library(mashr)
    library(RhpcBLASctl)
    library(magrittr)
    library(tidyverse)
    #library(ttt)

    # Set number of threads for BLAS operations
    blas_set_num_threads(1)

    # Create a function for pairwise contrast columns
    MakePairwiseContrastCols <- function(contrast_left, orig_vector) {
        orig_vector[contrast_left[1]] <- 1
        orig_vector[contrast_left[2]] <- -1
        orig_vector
    }

    # Function to fit contrast data
    FitContrast <- function(index, orig_mean, posterior_mean, posterior_vcov) {
            population_names <- colnames(posterior_mean) %>% str_remove_all("BETA_")

            orig_mean_vector <- orig_mean[index,]
            names(orig_mean_vector) <- population_names
            orig_mean_nonzero <- as.vector(orig_mean_vector != 0)
            orig_mean_tested <- names(orig_mean_vector[orig_mean_nonzero])
            n_populations <- length(orig_mean_tested)

            pairwise_vector <- rep(0, n_populations)
            names(pairwise_vector) <- orig_mean_tested

            grouping <- grouping_all[orig_mean_tested]
            if (n_populations > 1) {
                if (n_populations > 2) {
                    #####1. deviation contrast
                    deviation_contrasts <- rep(-1, n_populations^2) %>% matrix(nrow = n_populations, ncol = n_populations)
                    diag(deviation_contrasts) <- n_populations - 1
                    rownames(deviation_contrasts) <- orig_mean_tested
                    colnames(deviation_contrasts) <- orig_mean_tested
                    deviation_contrasts_tested <- deviation_contrasts[, orig_mean_tested]
                    
                    unique_groups <- unique(grouping)
                    for (grp in unique_groups[unique_groups > 0]) {
                       #same celltype (e.g. MIC) with different populations would get 1/n for their weight, 
                        diag(deviation_contrasts_tested)[grouping == grp] <- (n_populations - 1) / length(grouping[grouping == grp])
                        deviation_contrasts_tested[grouping == grp, grouping == grp] <- (n_populations - 1) / length(grouping[grouping == grp])
                    }
                    
                    colnames(deviation_contrasts_tested) %<>% str_c("_deviation")

                    ####2. pairwise contrast
                    two_combn <- combn(orig_mean_tested, m = 2)
                    pairwise_names <- apply(two_combn, 2, str_c, collapse = "_vs_")
                    pairwise_contrast <- apply(two_combn, 2, MakePairwiseContrastCols, pairwise_vector)
                
                    colnames(pairwise_contrast) <- pairwise_names
                    
                    # Create a new matrix to store the adjusted values
                    pairwise_contrast_new <- pairwise_contrast

                    # Loop through each column to achieve the following, e.g.:
                    # microglia populations would get 1/n_Mic for their weight, 
                    # and Mic vs Mic would still be 1 vs -1 to estimate the internal difference among microglia datasets
                    for (col in colnames(pairwise_contrast)) {
                      # Split column names to get group names
                      groups <- strsplit(col, "_vs_")[[1]]

                      # Get the grouping values for the two groups
                      group_values <- grouping[names(grouping) %in% groups]

                      # Identify groups with non-zero grouping values
                      relevant_groups <- names(group_values[group_values > 0])

                      # Check if there are multiple distinct groups
                      if (length(unique(group_values)) > 1 && length(relevant_groups) > 0) {
                        distinct_groups <- unique(group_values[group_values > 0])

                        for (distinct_grp in distinct_groups) {
                          # Identify rows belonging to the current group
                          rows_in_group <- names(grouping[grouping == distinct_grp])

                          # Adjust the pairwise_contrast values for each row in the group
                          pairwise_contrast_new[rows_in_group, col] <- pairwise_contrast[rows_in_group[rows_in_group %in% groups], col] / length(rows_in_group)
                        }
                      }
                    }

                    # Replace the original matrix with the new one
                    pairwise_contrast <- pairwise_contrast_new

                    #### 3. combine them 
                    contrast_design <- cbind(deviation_contrasts_tested / (n_populations - 1), pairwise_contrast)

                } else {
                    pairwise_vector[orig_mean_tested[1]] <- 1
                    pairwise_vector[orig_mean_tested[2]] <- -1
                    contrast_design <- as.matrix(pairwise_vector)
                    colnames(contrast_design) <- str_c(orig_mean_tested[1], "_vs_", orig_mean_tested[2])
                }

                posterior_mean_subset <- posterior_mean[index,]
                posterior_mean_subset2 <- posterior_mean_subset[orig_mean_tested]
                posterior_vcov_subset <- posterior_vcov[,,index]
                posterior_vcov_subset2 <- posterior_vcov_subset[orig_mean_tested,orig_mean_tested]

                contrast_diff <- t(contrast_design) %*% posterior_mean_subset2
                contrast_vcov <- t(contrast_design) %*% posterior_vcov_subset2 %*% contrast_design
                contrast_se <- diag(contrast_vcov) %>% sqrt

                contrast_p <- 2 * (1 - pnorm(abs(contrast_diff) / contrast_se))

                contrast_diff_df <- t(contrast_diff) %>% as_tibble
                colnames(contrast_diff_df) %<>% str_c("mean_contrast_", .)
                contrast_se_df <- t(contrast_se) %>% as_tibble
                colnames(contrast_se_df) %<>% str_c("se_contrast_", .)
                contrast_p_df <- t(contrast_p) %>% as_tibble
                colnames(contrast_p_df) %<>% str_c("p_contrast_", .)

                contrast_df <- bind_cols(contrast_diff_df, contrast_se_df, contrast_p_df)
            } else {
                # Fewer than two populations were tested, so no contrast can be formed; return an NA placeholder row
                contrast_vector <- rep(NA, length(population_names))
                names(contrast_vector) <- str_c("mean_contrast_", population_names, "_deviation")
                contrast_df <- t(contrast_vector) %>% as_tibble
            }
            contrast_df
        }

    if(length("${cells}") > 0){
        # All the cells
        cells <- c("${",".join(cells)}") %>% str_split(., ",", simplify = TRUE) %>% as.character 

        # Automatically set grouping categories based on the recipe; use 0 for cell types without multiple populations
        grouping_all <- rep(0, length(cells))
        names(grouping_all) <- cells
        
  
        # Read groupings from the recipe
        if(length("${group1}") > 0){
            cell_groups <- list(
              ${"group1 = c(" + ", ".join(["'" + item + "'" for item in group1]) + ")" if len(group1) > 0 else ""} 
              ${", group2 = c(" + ", ".join(["'" + item + "'" for item in group2]) + ")" if len(group2) > 0 else ""} 
              ${", group3 = c(" + ", ".join(["'" + item + "'" for item in group3]) + ")" if len(group3) > 0 else ""}
            )
            if(!is.null(cell_groups)) {
              cell_groups <- map(cell_groups, ~str_split(.x, ",", simplify = TRUE) %>% as.character())
            }
        }
  
        if("${grouping_recipe}" != ""){
            cell_groups <- readLines("${grouping_recipe}")
            cell_groups <- lapply(cell_groups, function(g) strsplit(g, ",")[[1]])
        }

        if(!is.null(cell_groups)){
            for(i in seq_along(cell_groups)) {
              grouping_all[cell_groups[[i]]] <- i
            }
        }
    }

    
    # Read the data files
    orig_data <- read_rds("${_paths_sum[0]}")$bhat
    posterior_data <- read_rds("${_input}")
    posterior_mean <- posterior_data$PosteriorMean
    posterior_cov <- posterior_data$PosteriorCov

    # Align data and clean-up NaN values
    orig_data <- orig_data[, colnames(posterior_mean), drop = FALSE]
    orig_data[which(is.nan(orig_data))] <- 0 # Placeholder for NaNs

    # Apply the FitContrast function and consolidate results
    contrast_result <- map(1:nrow(posterior_mean), FitContrast, orig_data, posterior_mean, posterior_cov) %>% bind_rows %>%
        select(matches("mean_contrast.*deviation"), matches("mean_contrast.*_vs_"), 
               matches("se_contrast.*deviation"), matches("se_contrast.*_vs_"), 
               matches("p_contrast.*deviation"), matches("p_contrast.*_vs_"))
    rownames(contrast_result) <- rownames(posterior_mean)

    write_rds(contrast_result,  ${_output:r})

In [1]:
# merge the contrast data with slice data
[mash_posterior_contrast_2]
input: group_by = "all"
output: f"{cwd}/posterior_sum.csv"
task: trunk_workers = 1, trunk_size = job_size, walltime = '24h',  mem = '10G', tags = f'{_output:bn}'  

R: expand = "${ }",stderr = f'{_output:n}.stderr', stdout = f'{_output:n}.stdout' 
        library(dplyr)
        library(tidyverse)
        library(ggnewscale)
        
        all.list <- stringr::str_split("${_input}", " ", simplify = T)
       
        p_cut <- 1E-05

        cells<-c("${",".join(cells)}")%>%str_split(.,",",simplify = T)%>%as.character #ggnewscale cannot use a specified order, I can not find a good way to order them by category for now
        conditions <-combn(cells, m = 2) %>%apply(., 2, str_c, collapse = "_vs_") 
        

        df <- matrix(ncol = length(conditions), nrow = 4) %>% as.data.frame()
        colnames(df) <- conditions
        rownames(df) <- c( "n_sig_snp", "n_snp","n_sig_feature","n_all_feature")
        for (con in conditions) {
            n.all.sig.snp <- n.all.snp <- n.all.sig.feature <- n.all.feature <- 0

            for (i in 1:length(all.list)) {
                print(i)
                tmp <- readRDS(all.list[i])
                p.mtx <- tmp %>% select(matches("p_contrast.*_vs_"))
                p.mtx.con <- p.mtx %>% select(matches(con))
                n.sig.snp <- sum(p.mtx.con < p_cut)
                # print(n.sig.snp)
                if(ncol(p.mtx.con)>0){
                    p.mtx.con<-na.omit(p.mtx.con)
                    n.sig.snp <- sum(p.mtx.con < p_cut)
                    n.snp <- nrow(p.mtx.con)
                } else {
                    n.sig.snp <- n.snp <-0
                }
                # print(n.snp)
                n.sig.feature <- ifelse(n.sig.snp > 0, 1, 0)
                n.feature<- ifelse (n.snp > 0 , 1, 0)

                n.all.sig.snp <- n.sig.snp + n.all.sig.snp
                n.all.snp <- n.all.snp + n.snp
                n.all.sig.feature <- n.all.sig.feature + n.sig.feature
                n.all.feature <- n.all.feature + n.feature
            }
            df[, con] <- c( n.all.sig.snp,n.all.snp, n.all.sig.feature,n.all.feature)
        }
        write.csv(df, "${_output}")
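
The tallying performed by the merge step above can be summarized with a small Python sketch. For one pairwise contrast it counts, across features (one file per feature), the SNPs with a non-missing p-value, the SNPs below the cutoff, and the features that have at least one tested or significant SNP. The function name `tally_contrast` and the p-values are made up for illustration; the cutoff matches `p_cut` in the R code.

```python
def tally_contrast(p_values_by_feature, p_cut=1e-5):
    """Count significant SNPs and features for one contrast across all features."""
    totals = {"n_sig_snp": 0, "n_snp": 0, "n_sig_feature": 0, "n_all_feature": 0}
    for p_values in p_values_by_feature:
        observed = [p for p in p_values if p is not None]  # drop missing p-values
        n_sig = sum(1 for p in observed if p < p_cut)
        totals["n_sig_snp"] += n_sig
        totals["n_snp"] += len(observed)
        totals["n_sig_feature"] += 1 if n_sig > 0 else 0
        totals["n_all_feature"] += 1 if observed else 0
    return totals


# Made-up p-values for one contrast in three features; None marks a missing value
p_values_by_feature = [[1e-7, 0.2, None], [0.5, 0.03], [None, None]]
totals = tally_contrast(p_values_by_feature)
print(totals)
```

The plotting step then divides `n_sig_snp` by `n_snp` and `n_sig_feature` by `n_all_feature` to obtain the two ratios shown in the heatmap.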

In [1]:
# plot contrast result
[mash_posterior_contrast_3, posterior_cntrast_plot]
input: group_by = "all"
output: f"{cwd}/posterior_sum.png"
task: trunk_workers = 1, trunk_size = job_size, walltime = '24h',  mem = '10G', tags = f'{_output:bn}'  

R: expand = "${ }",stderr = f'{_output:n}.stderr', stdout = f'{_output:n}.stdout' 
        library(dplyr)
        library(tidyverse)
        library(ggnewscale)
        df <- read.csv("${_input:a}",row.names=1)
        colnames(df)<-gsub("DLPFC_","",colnames(df))
        for (i in 1:ncol(df)) {
            con1 <- stringr::str_split(colnames(df)[i], "_vs_", simplify = T)[, 1]
            con2 <- stringr::str_split(colnames(df)[i], "_vs_", simplify = T)[, 2]

            if (con1 > con2) {
                new.name <- paste0(con2, "_vs_", con1)
                colnames(df)[i] <- new.name
            }
        }

        ## summarize with approach 1: snp-feature pair
        snp.ratio <- df["n_sig_snp", ] / df["n_snp", ]
        snp.ratio <- snp.ratio %>%
            t() %>%
            as.data.frame()
        colnames(snp.ratio) <- "ratio"
        snp.ratio$group <- "snp"

        ## summarize with approach 2: feature
        fet.ratio <- df["n_sig_feature", ] / df["n_all_feature", ]
        fet.ratio <- fet.ratio %>%
            t() %>%
            as.data.frame()
        rownames(fet.ratio) <- paste0(stringr::str_split(rownames(fet.ratio), "_vs_", simplify = T)[, 2], "_vs_", stringr::str_split(rownames(fet.ratio), "_vs_", simplify = T)[, 1])
        colnames(fet.ratio) <- "ratio"
        fet.ratio$group <- "feature"
        ratio <- rbind(snp.ratio, fet.ratio)

        ## Add the diagonal entries below to make the matrix symmetric

        cons <- rownames(ratio) %>%
            str_split(., "_vs_", simplify = T) %>%
            .[, 1] %>%
            unique()
        for (i in 1:length(cons)) {
            new.name <- paste0(cons[i], "_vs_", cons[i])
            ratio[new.name, ] <- 0
        }

        ratio$con1 <- stringr::str_split(rownames(ratio), "_vs_", simplify = T)[, 1]
        ratio$con2 <- stringr::str_split(rownames(ratio), "_vs_", simplify = T)[, 2]

        ## prepare for the plot, score1 is for snp-feature pair, score2 is for feature only
        ratio$score1 <- ratio$score2 <- 0
        ratio$score1[ratio$group == "snp"] <- ratio$ratio[ratio$group == "snp"]
        ratio$score2[ratio$group == "feature"] <- ratio$ratio[ratio$group == "feature"]
        ratio$label <- paste0(round(ratio$ratio, 4) * 100, "%")
        ratio$label[ratio$group == 0] <- NA

        # plot
        num_cols <- length(cons)
        height <- width <- 4 + num_cols * 0.5 
        ggplot(ratio[ratio$group == "snp", ], aes(x = con1, y = con2)) +
            geom_tile(aes(fill = score1)) +
            scale_fill_gradient2("SNP_Feature pair",
                low = "#762A83", mid = "white", high = "#1B7837"
            ) +
            new_scale("fill") +
            geom_tile(aes(fill = score2), data = subset(ratio, group != "snp")) +
            scale_fill_gradient2("Feature",
                low = "#1B7837", mid = "white", high = "#762A83"
            ) +
            geom_text(data = ratio, aes(label = label)) +
            theme_bw()
        #geom_text(data=ratio, aes(label = label, color = factor(group))) +theme_bw()
        #ggsave(gsub(".csv",".png",filename))

        ggsave("${_output}",width = width, height = height)