# MASH analysis pipeline with data-driven prior matrices

In this notebook, we use the MASH prior, referred to as the [mixture_prior](https://github.com/cumc/xqtl-pipeline/blob/6c637645ce16aee2aa7dc86bbc334fb6bb66b9d9/code/multivariate/MASH/mixture_prior.ipynb#L4), computed in a previous step. Our objective is to conduct a multivariate analysis under the MASH model: we first fit the model and then compute the posteriors for our variables of interest.

## Minimal working example
Input data for the MWE: `/home/rf2872/Work/Multivariate/MASH/MWE/output/MWE.rds`


## Multivariate analysis with [prior](https://github.com/cumc/xqtl-pipeline/blob/d43590b2da112aab357447e4ba931d95bc464cb5/code/multivariate/MASH/mixture_prior.ipynb) from [MWE](https://github.com/cumc/xqtl-pipeline/blob/d43590b2da112aab357447e4ba931d95bc464cb5/code/multivariate/MASH/mixture_prior.ipynb)

In [6]:
# Step 2 (mash_fit): estimate V ("mle"), fit the MASH model with the prior
# from the previous step, and compute posteriors for the "strong" set
sos run pipeline/mash_fit.ipynb mash \
    --cwd MWE_udr \
    --data output/MWE.rds \
    --output_prefix MWE_udr \
    --vhat mle \
    --container /mnt/vast/hpc/csg/containers_xqtl/stephenslab.sif

INFO: Running [32mmash_1[0m: Fit MASH mixture model (time estimate: <15min for 70K by 49 matrix)
INFO: Running [32mvhat_mle[0m: V estimate: "mle" method
INFO: [32mvhat_mle[0m is [32mcompleted[0m.
INFO: [32mvhat_mle[0m output:   [32m/mnt/vast/hpc/csg/rf2872/Work/Multivariate/MASH/MWE/MWE_udr/MWE_udr.EZ.V_mle.rds[0m
INFO: [32mmash_1[0m is [32mcompleted[0m.
INFO: [32mmash_1[0m output:   [32m/mnt/vast/hpc/csg/rf2872/Work/Multivariate/MASH/MWE/MWE_udr/MWE_udr.EZ.V_mle.mash_model.rds[0m
INFO: Running [32mmash_2[0m: Compute posterior for the "strong" set of data as in Urbut et al 2017. This is optional because most of the time we want to apply the MASH model learned on much larger data-set.
INFO: [32mmash_2[0m is [32mcompleted[0m.
INFO: [32mmash_2[0m output:   [32m/mnt/vast/hpc/csg/rf2872/Work/Multivariate/MASH/MWE/MWE_udr/MWE_udr.EZ.posterior.rds[0m
INFO: Workflow mash (ID=wc5d41041963cfe84) is executed successfully with 3 completed steps.


### Global parameter settings

In [2]:
[global]
# Working directory for all output files
parameter: cwd = path('./mashr_flashr_workflow_output')
# Input summary statistics data
parameter: data = path("fastqtl_to_mash_output/FastQTLSumStats.mash.rds")
# Prefix of output files. If not specified, it is derived from the file name of `data`
# (without extension). If it is specified, for example, `--output-prefix AnalysisResults`,
# output files are saved as `{cwd}/AnalysisResults*`.
parameter: output_prefix = ''
parameter: output_suffix = 'all'
# Exchangeable effect (EE) or exchangeable z-scores (EZ)
parameter: effect_model = 'EZ'
# Identifier of $\hat{V}$ estimate file
# Options are "identity", "simple", "mle", "vhat_corshrink_xcondition", "vhat_simple_specific"
parameter: vhat = 'simple'
# Options are "ed" and "udr"
parameter: stat_algo = "ed"
parameter: mixture_components = ['flash', 'flash_nonneg', 'pca',"canonical"]
# Container image passed to the R actions (empty string by default)
parameter: container = ""
# Command prefix that runs scripts inside the container's micromamba environment, which is
# named after the container file without its `.sif` extension; empty for non-`.sif` containers.
# This must be a plain string: wrapping the expression in `{...}` would make it a Python set.
parameter: entrypoint = ('micromamba run -a "" -n' + ' ' + container.split('/')[-1][:-4]) if container.endswith('.sif') else ''
data = data.absolute()
cwd = cwd.absolute()
if len(output_prefix) == 0:
    output_prefix = f"{data:bn}"
# File targets shared by the steps below
prior_data = file_target(f"{cwd:a}/{output_prefix}.{effect_model}.prior.rds")
vhat_data = file_target(f"{cwd:a}/{output_prefix}.{effect_model}.V_{vhat}.rds")
mash_model = file_target(f"{cwd:a}/{output_prefix}.{effect_model}.V_{vhat}.mash_model.rds")

def sort_uniq(seq):
    """Return the items of ``seq`` as a list with duplicates removed.

    The first occurrence of each item is kept, in its original position;
    despite the name, no sorting is performed. Items must be hashable.
    """
    return list(dict.fromkeys(seq))

### Command interface

In [8]:
# V estimate: "mle" method
[vhat_mle]
# number of samples to use
parameter: n_subset = 6000
# maximum number of iterations
parameter: max_iter = 6

input: data, prior_data
output: f'{vhat_data:nn}.V_mle.rds'
task: trunk_workers = 1, walltime = '36h', trunk_size = 1, mem = '4G', cores = 1, tags = f'{_output:bn}'
R: expand = "${ }", workdir = cwd, stderr = f"{_output:n}.stderr", stdout = f"{_output:n}.stdout", container = container, entrypoint=entrypoint
    library(mashr)
    dat = readRDS(${_input[0]:r})
    # Reproducible random subset of at most n_subset rows of the "random" set.
    # sample.int(n, k) draws the same indices as sample(1:n, k) for a given seed,
    # but is also well-defined when there are no rows.
    set.seed(1)
    n_random = nrow(dat$random.b)
    subset_idx = sample.int(n_random, min(${n_subset}, n_random))
    # drop = FALSE keeps a one-row subset as a matrix instead of collapsing it to a vector
    random_data = mash_set_data(dat$random.b[subset_idx, , drop = FALSE], dat$random.s[subset_idx, , drop = FALSE], alpha=${1 if effect_model == 'EZ' else 0}, zero_Bhat_Shat_reset = 1E3)
    # Estimate the residual correlation matrix V by EM, given the prior covariance list U
    vhatprior = mash_estimate_corr_em(random_data, readRDS(${_input[1]:r})$U, max_iter = ${max_iter})
    vhat = vhatprior$V
    saveRDS(vhat, ${_output:r})


## `mashr` mixture model fitting


In [9]:
# Fit MASH mixture model (time estimate: <15min for 70K by 49 matrix)
[mash_1]
# Value passed to mash(outputlevel = ...); controls how much mash computes and returns
parameter: outputlevel = 1
input: data, vhat_data, prior_data
output: mash_model

task: trunk_workers = 1, walltime = '36h', trunk_size = 1, mem = '4G', cores = 1, tags = f'{_output:bn}'
R: expand = "${ }", workdir = cwd, stderr = f"{_output:n}.stderr", stdout = f"{_output:n}.stdout", container = container, entrypoint=entrypoint
    library(mashr)
    # Load summary statistics, the V-hat estimate and the prior covariance list
    sumstats = readRDS(${_input[0]:r})
    V_hat = readRDS(${_input[1]:r})
    prior_U = readRDS(${_input[2]:r})$U
    # Fit the mixture model on the "random" set of effects
    random_data = mash_set_data(sumstats$random.b, Shat = sumstats$random.s, alpha = ${1 if effect_model == 'EZ' else 0}, V = V_hat, zero_Bhat_Shat_reset = 1E3)
    fitted_model = mash(random_data, Ulist = prior_U, outputlevel = ${outputlevel})
    saveRDS(fitted_model, ${_output:r})

### Optional posterior computations

Optionally, also compute posteriors for the "strong" set in the MASH input data.

In [10]:
# Compute posterior for the "strong" set of data as in Urbut et al 2017.
# This is optional because most of the time we want to apply the 
# MASH model learned on much larger data-set.
[mash_2]
# Whether to compute posteriors for the "strong" set (default True);
# use --no-compute-posterior to disable this
parameter: compute_posterior = True
skip_if(not compute_posterior)

input: data, vhat_data, mash_model
output: f"{cwd:a}/{output_prefix}.{effect_model}.posterior.rds"

task: trunk_workers = 1, walltime = '36h', trunk_size = 1, mem = '4G', cores = 1, tags = f'{_output:bn}'
R: expand = "${ }", workdir = cwd, stderr = f"{_output:n}.stderr", stdout = f"{_output:n}.stdout", container = container, entrypoint=entrypoint
    library(mashr)
    # Model fitted by mash_1
    fitted_model = readRDS(${_input[2]:ar})
    # Build the "strong" set data object with the same V-hat used for fitting
    sumstats = readRDS(${_input[0]:r})
    V_hat = readRDS(${_input[1]:r})
    strong_data = mash_set_data(sumstats$strong.b, Shat = sumstats$strong.s, alpha = ${1 if effect_model == 'EZ' else 0}, V = V_hat, zero_Bhat_Shat_reset = 1E3)
    saveRDS(mash_compute_posterior_matrices(fitted_model, strong_data), ${_output:r})