[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/cheninstitutecaltech/Caltech_DATASAI_Neuroscience_23/blob/main/07_20_23_day9_causal_modeling/code/solutions/exercise3.ipynb)


# Causal structure discovery: HCP resting-state fMRI data
Authors: Iman Wahle and Frederick Eberhardt

In this notebook, we will use the workflow explored in exercises 1 and 2 to 
infer a causal graph over parcels in the human brain from resting-state fMRI
activity as was done in [Dubois et al. 2017](https://www.biorxiv.org/content/10.1101/214486v1.full.pdf). The data we are working with was collected through the [Human
Connectome Project](https://www.humanconnectome.org/study/hcp-young-adult/project-protocol/resting-state-fmri) (HCP). In particular, we will be working with "Dataset 1"
described in Dubois et al., which consists of 11 files of mean-centered 
resting-state samples from distinct sets of 80 subjects. Each file includes 
5440 samples (68 per subject) of activity over 110 parcels in the brain. 
The parcellation used here is the Harvard-Oxford atlas, and the value for each 
parcel is the average activity across all voxels within the parcel.
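To make the parcel-averaging step concrete, here is an illustrative sketch of how voxel time series could be collapsed into parcel time series. The function name `parcel_average` and its inputs (a voxel-by-time array and a per-voxel parcel index) are hypothetical. The HCP files used here already contain parcel averages, so this step is not part of the exercise.

```python
import numpy as np


def parcel_average(voxel_ts, voxel_parcel, n_parcels):
    '''Average voxel time series within each parcel (hypothetical helper).

    voxel_ts     : (n_samples, n_voxels) array of voxel activity
    voxel_parcel : length-n_voxels array giving each voxel's parcel index
                   in [0, n_parcels); every parcel is assumed non-empty
    Returns an (n_samples, n_parcels) array.
    '''
    out = np.zeros((voxel_ts.shape[0], n_parcels))
    for p in range(n_parcels):
        # mean over all voxels assigned to parcel p, separately at each sample
        out[:, p] = voxel_ts[:, voxel_parcel == p].mean(axis=1)
    return out
```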

In [2]:
# colab setup
!git clone https://github.com/eberharf/fges-py.git
!pip install -q corner dill sortedcontainers gdown
import gdown
gdown.download_folder(url='https://drive.google.com/drive/folders/1vX4ZP63YTXKZSKIouwkXb-f--JNlfWoy',
                      output='data', quiet=True)

In [3]:
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
sns.set_context('talk')
import sys
sys.path.append('fges-py')
from SEMScore import *
from fges import *

## Load data

To start, write a function `load_data` that takes an integer `file_id` as
an argument and returns a matrix of shape n_samples x n_parcels.

The data is stored at 'data/HCPcombined_HO110_25_GM_Finn_noTsmooth_RL_step35_nSub80_{}.tsv',
where `{}` should be replaced by the `file_id`. TSV files can be loaded
using `np.loadtxt`. The first row in each file contains column headers and
should be skipped. Entries in the file are separated by tabs.

Load in the data for `file_id = 1`.

A list of variable names is saved in 'data/parcel_labels.npy'. Load this in as well.
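One possible sketch of the loading step follows. The extra `path_template` argument and the `load_labels` helper are hypothetical additions, included so the functions can also be pointed at other files. `ndmin=2` guarantees a 2-D result even for a single-row file.

```python
import numpy as np

# Path template from the instructions above; `{}` is replaced by the file id.
DATA_TEMPLATE = 'data/HCPcombined_HO110_25_GM_Finn_noTsmooth_RL_step35_nSub80_{}.tsv'
LABELS_PATH = 'data/parcel_labels.npy'


def load_data(file_id, path_template=DATA_TEMPLATE):
    '''Load one HCP file as an (n_samples, n_parcels) array.

    `path_template` is a hypothetical extra parameter (not part of the
    exercise spec) so the function can be reused on other locations.
    '''
    path = path_template.format(file_id)
    # skip the header row; entries are tab-separated
    return np.loadtxt(path, delimiter='\t', skiprows=1, ndmin=2)


def load_labels(path=LABELS_PATH):
    '''Load the parcel names (hypothetical helper).'''
    # allow_pickle covers the case where names were saved as an object array
    return np.load(path, allow_pickle=True)
```

With the data downloaded, `data = load_data(1)` and `parcel_labels = load_labels()` give the array and names used in the rest of the notebook.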

In [4]:
# add code here

## Initial data exploration

As before, make sure the data matches your expectations:

1. Print out the shape of the data and the names of the variables included
   - note how the parcels are ordered in the list of parcel names
2. Construct a [corner plot](https://corner.readthedocs.io/en/latest/pages/quickstart/)
   for five parcels (doing this for all parcels will take a really long time)
3. Plot the correlation matrix over all parcels
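A minimal sketch of checks 1 and 3 follows. `summarize` and `corr_zero_diag` are hypothetical helper names. The mean check relies on the description above that the files are mean-centered.

```python
import numpy as np


def summarize(data, labels):
    '''Print basic sanity checks for an (n_samples, n_parcels) array.'''
    n_samples, n_parcels = data.shape
    print(f'{n_samples} samples x {n_parcels} parcels')
    # every column should have exactly one parcel name
    assert len(labels) == n_parcels, 'label count does not match column count'
    # the files are described as mean-centered, so this should be close to 0
    print('max |column mean|:', np.abs(data.mean(axis=0)).max())


def corr_zero_diag(data):
    '''Parcel-by-parcel Pearson correlation matrix with a zeroed diagonal.'''
    corr = np.corrcoef(data, rowvar=False)  # columns are the variables
    np.fill_diagonal(corr, 0.0)  # hide the trivial self-correlations of 1
    return corr
```

In the notebook, `summarize(data, parcel_labels)` covers the first check, and `plt.imshow(corr_zero_diag(data), cmap='RdBu_r', vmin=-1, vmax=1)` covers the third. For the corner plot, `corner(data[:, :5], labels=list(parcel_labels[:5]))` is one option.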

In [None]:
# inspect data shape and variable labels
# confirm that the number of samples and number of parcels match what we
# expect, and that the data matrix is formatted as (n_samples, n_parcels)

# add code here

In [None]:
# make a corner plot of the first few parcels
from corner import corner

# add code here

In [None]:
# Plot the correlation matrix between all 110 parcels (zero out the diagonal). 

# add code here

What structure do you see in the correlation matrix?

> Answer here

## Run the FGES algorithm

Since we are now working with many more variables (110 parcels), we will use an
alternative to the PC algorithm called Fast Greedy Equivalence Search (FGES), 
which is designed to scale to large numbers of variables.
The implementation we will use can be found [here](https://github.com/eberharf/fges-py),
and details about the algorithm can be found in Ramsey et al. 2016. 
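FGES greedily adds and removes edges to maximize a score. To show how a sparsity (penalty discount) parameter trades fit against the number of edges, here is a BIC-style local score for one node given a candidate parent set. The form `-n * log(residual variance) - s * k * log(n)` is an assumption for illustration: `fges-py`'s `SEMBicScore` may use different constants. `local_bic_score` is a hypothetical name.

```python
import numpy as np


def local_bic_score(data, child, parents, penalty_discount=1.0):
    '''BIC-style score of `child` given `parents` (illustrative sketch).

    Assumed form: -n * log(residual variance) - penalty_discount * k * log(n),
    with k = number of parents. Data are assumed mean-centered (no intercept).
    '''
    n = data.shape[0]
    y = data[:, child]
    if parents:
        X = data[:, list(parents)]
        # least-squares regression of the child on its candidate parents
        beta, *_ = np.linalg.lstsq(X, y, rcond=None)
        resid = y - X @ beta
    else:
        resid = y
    resid_var = np.mean(resid ** 2)
    return -n * np.log(resid_var) - penalty_discount * len(parents) * np.log(n)
```

A true parent raises the score. A large penalty discount (such as the `s=8` used below) makes it much harder for a weakly informative extra parent to be accepted, which is why larger values give sparser graphs.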

The following function builds an FGES object and uses it to infer edges among our
variables from the data provided.

In [8]:
def infer_edges(data, s=8):
    '''
    Arguments:
        data : an n_samples x n_nodes array
        s : sparsity parameter for FGES (default = 8 as was used in Dubois et al.)
    Returns:
        edges : a list of tuples, where each tuple (i,j) represents an edge 
                found between node i and node j
        fges_result : dict of results from fges.search() (needed for estimating
                      the correlation matrix later on)
    '''

    # FGES takes a score function that depends on the data and a user-determined
    # sparsity level (penalty discount)
    score = SEMBicScore(penalty_discount=s, dataset=data)

    # run FGES
    fges = FGES(range(data.shape[1]), score, filename=data)
    fges_result = fges.search()
    edges = fges_result['graph'].edges()
    return edges, fges_result

edges, fges_result = infer_edges(data)

Instead of specifying the PDAG as an n_nodes x n_nodes matrix
like those we worked with in the previous exercises, this package specifies 
adjacencies as a list of tuples (`edges`), where each tuple $(i,j)$ included in
the list indicates a directed edge from $i$ into $j$. An undirected edge is
represented by including both $(i,j)$ and $(j,i)$ in the list.
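Under this convention, the directed and undirected edges can be separated with a short helper. `split_edges` is a hypothetical name, and each undirected pair is stored once as `(min, max)`.

```python
def split_edges(edges):
    '''Separate an edge list into undirected and directed edges.

    (i, j) alone means i -> j; both (i, j) and (j, i) together mean i - j.
    '''
    edge_set = set(edges)
    undirected, directed = set(), set()
    for i, j in edge_set:
        if (j, i) in edge_set:
            undirected.add((min(i, j), max(i, j)))  # record each pair once
        else:
            directed.add((i, j))
    return undirected, directed
```

For example, `undirected, directed = split_edges(edges)` followed by `len(undirected), len(directed)` shows how much of the FGES output is oriented.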

For now, we are just interested in node adjacencies (without orientation 
information). Write a function `fges_edges_to_mat` that takes as input:

- `edges` : a list of edge tuples
- `n_nodes` : the total number of variables in our graph

The function should return an `n_nodes` x `n_nodes` numpy array, where entries
$(i,j)$ and $(j,i)$ are both set to 1 if there is an edge between node $i$ and
node $j$ and are 0 otherwise.

Use your function to convert the `edges` list constructed above to an array 
`adj_mat`. Visualize the resulting array using `plt.imshow` (make sure to 
label the axes with the parcel names).
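One way to fill in `fges_edges_to_mat` is sketched below. It assumes that node labels are the integer column indices `0..n_nodes-1`, which holds here because `infer_edges` passes `range(data.shape[1])` to FGES.

```python
import numpy as np


def fges_edges_to_mat(edges, n_nodes):
    '''Symmetric 0/1 adjacency matrix that ignores edge orientation.'''
    adj = np.zeros((n_nodes, n_nodes), dtype=int)
    for i, j in edges:
        # mark both entries: we only care whether i and j are adjacent
        adj[i, j] = 1
        adj[j, i] = 1
    return adj
```

To label the plot, follow `plt.imshow(adj_mat, cmap='Greys')` with `plt.xticks(range(len(parcel_labels)), parcel_labels, rotation=90, fontsize=4)` and the matching `plt.yticks` call. With 110 parcels, a large figure size helps.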

In [9]:
# convert list of edges that fges returns to an adjacency matrix where entries
# (i,j) and (j,i) are 1 if there is an edge between node i and node j and 0 otherwise
def fges_edges_to_mat(edges, n_nodes):
    
    # add code here
    
    pass

adj_mat = fges_edges_to_mat(edges, data.shape[1])

Visualize the adjacency matrix found by the algorithm. What structure do you see? 

In [None]:
# visualize the adjacency matrix found by the algorithm. What structure do you see?

# add code here

In the previous workflows, we had to 1) convert the PDAG to a DAG,
2) estimate the connection weights and residual variances, and 3) compute
the correlation matrix implied by the estimated graph. `fges-py` provides
a class called `SemEstimator` that does all of this for us. The following function
uses it and returns a numpy array `est_corr`, the correlation matrix implied
by the inferred graph.
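To illustrate step 3, here is how a linear Gaussian SEM with weights and residual variances implies a correlation matrix. The convention `x = B x + e`, with `B[child, parent]` holding the weight of `parent -> child` and independent residuals with variances `resid_var`, is an assumption for this sketch. It is not `SemEstimator`'s internal code, and `sem_implied_corr` is a hypothetical name.

```python
import numpy as np


def sem_implied_corr(B, resid_var):
    '''Correlation matrix implied by x = B x + e (illustrative sketch).

    B[child, parent] : weight of the edge parent -> child (acyclic graph)
    resid_var        : variances of the independent residuals e
    '''
    n = B.shape[0]
    inv = np.linalg.inv(np.eye(n) - B)  # x = (I - B)^{-1} e
    cov = inv @ np.diag(resid_var) @ inv.T  # Sigma = (I - B)^{-1} Omega (I - B)^{-T}
    std = np.sqrt(np.diag(cov))
    return cov / np.outer(std, std)  # rescale covariance to correlation
```

For a single edge `0 -> 1` with weight 2 and unit residual variances, the implied covariance is `[[1, 2], [2, 5]]`, so the implied correlation is `2 / sqrt(5)`.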

In [None]:
from SemEstimator import SemEstimator

def estimate_corr(data, fges_result):
    '''
    Arguments:
        data : an n_samples x n_nodes numpy array
        fges_result : a dict of results returned by fges.search()
    Returns:
        est_corr : an n_nodes x n_nodes numpy array estimated correlation matrix
    '''
    sem_est = SemEstimator(data, sparsity=4)

    # provide the estimator with the pattern (PDAG) found above; SemEstimator
    # converts it to a DAG before estimating weights
    sem_est.pattern = fges_result['graph']

    # estimate the weights and residuals
    sem_est.estimate()

    # get covariance matrix from SemEstimator
    est_cov = sem_est.graph_cov

    # compute correlation matrix from covariance matrix
    std_devs = np.sqrt(np.diag(est_cov))
    est_corr = est_cov / np.outer(std_devs, std_devs)
    return est_corr

est_corr = estimate_corr(data, fges_result)

Visualize the estimated correlation matrix (remember to zero out the diagonal)
and compare to the correlation matrix computed from the data. 

In [None]:
# add code here

Quantify how closely the estimated correlation matrix matches the one computed
from the data. To do this, vectorize the elements below the diagonal in the true
and estimated matrices and compute the Pearson correlation between the two vectors.

`np.corrcoef` and `np.tril_indices` (with `k=-1` to exclude the diagonal) may be useful here.
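A minimal sketch of this comparison follows, using a hypothetical helper `lower_tri_similarity`. Excluding the diagonal matters because it is identically 1 (or 0 once zeroed) and would otherwise inflate the agreement.

```python
import numpy as np


def lower_tri_similarity(corr_a, corr_b):
    '''Pearson correlation between the strictly lower-triangular entries.'''
    rows, cols = np.tril_indices(corr_a.shape[0], k=-1)  # k=-1 skips the diagonal
    vec_a = corr_a[rows, cols]
    vec_b = corr_b[rows, cols]
    return np.corrcoef(vec_a, vec_b)[0, 1]
```

In the notebook, this would be called as `lower_tri_similarity(np.corrcoef(data, rowvar=False), est_corr)`.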

In [None]:
# add code here

## Compare graphs found from different subject groups

Now that we have performed this analysis for samples from one file (one set
of 80 subjects), we can compare the results of this pipeline across various
data subsets. One option is to compare results across the data files included
here (where each file corresponds to a different set of 80 subjects). Another
option is to compare results across subsets of samples from the file we have
worked with so far. Try one or both of these approaches, and repurpose the functions
above to write a for loop that, for each sample set:

1. loads in the data
2. runs FGES to get a list of inferred edges
3. converts the list of edges to an adjacency matrix

Store the adjacency matrix found on each loop iteration in a list.
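Below is a sketch of the within-file option. It assumes that rows are stored subject by subject in consecutive blocks of 68 samples, so check this before relying on it. Splitting by whole subjects keeps each subject's samples together. `split_by_subject` and `adjacency_across_sets` are hypothetical names, and the pipeline functions are passed in as arguments.

```python
import numpy as np


def split_by_subject(data, samples_per_subject=68, n_groups=4):
    '''Split rows into groups of whole subjects (assumes subject-ordered rows).'''
    n_subjects = data.shape[0] // samples_per_subject
    subsets = []
    for group in np.array_split(np.arange(n_subjects), n_groups):
        if len(group) == 0:
            continue  # more groups requested than subjects available
        start = group[0] * samples_per_subject
        stop = (group[-1] + 1) * samples_per_subject
        subsets.append(data[start:stop])
    return subsets


def adjacency_across_sets(sample_sets, infer_fn, to_mat_fn):
    '''Run edge inference on each sample set and collect adjacency matrices.'''
    adj_mats = []
    for subset in sample_sets:
        edges, _ = infer_fn(subset)  # e.g. infer_edges defined above
        adj_mats.append(to_mat_fn(edges, subset.shape[1]))
    return adj_mats
```

Usage: `adj_mats = adjacency_across_sets(split_by_subject(data), infer_edges, fges_edges_to_mat)`. To compare files instead, pass `[load_data(i) for i in file_ids]`, where `file_ids` lists whichever ids are present in `data/`.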

In [14]:
# add code here

Compute the average adjacency matrix and visualize it with `plt.imshow` to see
how often each edge is found from data across sample sets.
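Averaging 0/1 adjacency matrices gives, for each pair of parcels, the fraction of sample sets in which that edge was found. The sketch below also lists the most consistently recovered edges. Both helper names are hypothetical.

```python
import numpy as np


def edge_frequency(adj_mats):
    '''Fraction of sample sets in which each edge was found.'''
    return np.mean(np.stack(adj_mats), axis=0)


def most_consistent_edges(freq, labels, top=10):
    '''The `top` most frequently found edges as (label_i, label_j, frequency).'''
    rows, cols = np.triu_indices(freq.shape[0], k=1)  # each undirected pair once
    order = np.argsort(freq[rows, cols])[::-1][:top]
    return [(labels[rows[k]], labels[cols[k]], float(freq[rows[k], cols[k]]))
            for k in order]
```

`plt.imshow(edge_frequency(adj_mats), vmin=0, vmax=1)` then shows edges found in every set at the top of the color scale.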

In [None]:
# add code here