# FateZ Explain 

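This notebook demonstrates how to use the explanatory methods of FateZ models. We first pre-train and fine-tune a model on fake data, then explain it at the gene, TF-regulon, and GRP level.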

In [8]:
import os
import sys
import torch
from torch.utils.data import DataLoader
import fatez.test as test
import fatez.model as model
import fatez.tool.JSON as JSON
import fatez.process as process
import fatez.process.explainer as explainer
import fatez.process.fine_tuner as fine_tuner
import fatez.process.pre_trainer as pre_trainer
from pkg_resources import resource_filename

# Ignoring warnings because of using LazyLinear
import warnings
warnings.filterwarnings('ignore')

print('Done')

Done


### First, load the model config and generate some fake data.

In [2]:
# Parameters
params = {
    'n_sample': 10,       # Fake samples to make
    'batch_size': 1,      # Batch size
}

# Seed torch's RNG so that fake data and training are reproducible on re-run.
# NOTE(review): assumes Faker draws from torch's RNG — confirm (it may use numpy/random).
torch.manual_seed(0)

# Load the built-in config file shipped inside the fatez package.
# Resolving against the package (instead of __name__, which is '__main__' in a
# notebook) avoids depending on the kernel's current working directory.
config = JSON.decode(
    resource_filename('fatez', 'data/config/gat_bert_config.json')
)

# Fall back to CPU so the notebook also runs on machines without a GPU.
factory_kwargs = {
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'dtype': torch.float32,
}

# Generate Fake data
faker = test.Faker(model_config=config, **params)
train_dataloader = faker.make_data_loader()

print('Done')

Done


### Now we pre-train without labels.


Here the trainer's `train_adj` option is set to False, so the model does NOT reconstruct the adjacency matrices.

In [3]:
# Pre-train WITHOUT reconstructing adjacency matrices. Set the flag explicitly
# so this cell gives the same result even when it is re-run after the next
# cell, which flips `train_adj` to True on the shared `config` in place.
config['pre_trainer']['train_adj'] = False
trainer = pre_trainer.Set(config, factory_kwargs)
report = trainer.train(train_dataloader, report_batch=True)
print(report)

        Loss
0   2.855018
1   2.470851
2   2.365319
3   2.289390
4   1.723159
5   2.777436
6   2.480390
7   1.580967
8   1.582995
9   1.561940
10  2.168747


Next, we pre-train while also reconstructing the adjacency matrices (`train_adj` set to True).

In [4]:
# Turn on adjacency reconstruction for this run. This mutates the shared
# `config`, which the fine-tuning cell below also reuses.
pretrainer_settings = config['pre_trainer']
pretrainer_settings['train_adj'] = True

trainer = pre_trainer.Set(config, factory_kwargs)
report = trainer.train(train_dataloader, report_batch=True)
print(report)

        Loss
0   2.429936
1   2.488689
2   2.209337
3   1.618044
4   2.654165
5   2.571504
6   1.685288
7   1.725671
8   1.828922
9   2.634892
10  2.184645


### Then, we fine-tune the model with class labels.

In [5]:
# Fine-tune with class labels, starting from the model pre-trained above.
pretrained_model = trainer.model
tuner = fine_tuner.Set(config, factory_kwargs, prev_model=pretrained_model)
report = tuner.train(train_dataloader, report_batch=True, quiet=True)
print(report)

        Loss  ACC
0   1.121826  0.0
1   1.119385  0.0
2   0.394212  1.0
3   0.394363  1.0
4   0.393310  1.0
5   1.119084  0.0
6   0.394333  1.0
7   1.118813  0.0
8   1.120818  0.0
9   0.394569  1.0
10  0.757071  0.5


### Explaining the fine-tuned model as a whole

Note: to draw an overall conclusion about the contribution of a specific gene, we need to sum its importance values across every feature dimension (RNA count, peaks).

In [9]:
# Explain the fine-tuned model end to end.
# Debug print-outs are suppressed while explaining. The `finally` guarantees
# stdout is restored even if explanation raises; otherwise every later cell
# in this kernel session would be silenced too.
suppressor = process.Quiet_Mode()
suppressor.on()
try:
    # Background data: the first batch of a loader holding `n_sample` samples
    # per batch, moved to the model's device.
    # NOTE(review): unlike `input_data` below, background tensors are not
    # converted with `.to_dense()` — confirm this is intended.
    background_loader = DataLoader(
        train_dataloader.dataset, batch_size=params['n_sample']
    )
    background_batch, _ = next(iter(background_loader))
    background_data = [a.to(factory_kwargs['device']) for a in background_batch]
    gene_explainer = explainer.Gradient(tuner.model, background_data)

    # Explain only the first training batch.
    # `gene_shap_vars` can be used to estimate how accurate the explanation
    # is: lower is better. (Not named `vars`, to avoid shadowing the builtin.)
    input_data, _ = next(iter(train_dataloader))
    input_data = [
        i.to(factory_kwargs['device']).to_dense() for i in input_data
    ]
    gene_shap_values, gene_shap_vars = gene_explainer.shap_values(
        input_data, return_variances=True
    )
finally:
    suppressor.off()

print(f'Explaining {len(gene_shap_values)} classes.')

# The model has 2 inputs, so each class gets one explanation per input.
assert len(gene_shap_values[0]) == 2

# Only the feat mat explanation should be working
print(f'Each class has explain in shape of {gene_shap_values[0][0].shape}.')

# The adj mat explanation (gene_shap_values[0][1]) should NOT be used, since
# no gradient reaches that input.

Explaining 2 classes.
Each class has explain in shape of (1, 10, 2).


### Explaining the BERT part to analyze the importance of TFs only

Note: similarly, we would sum the values across the embedding dimensions.

In [10]:
# Explain the BERT part on its own, using the GAT output as its input.
# We also should accumulate gat_out for every trained input;
# here a single gat_out is made for illustration.
gat_out = tuner.model.get_gat_output(input_data[0], input_data[1])

# NOTE(review): the explained input is also the only background reference.
# If Gradient follows the expected-gradients formulation, (input - background)
# is zero and attributions may degenerate — prefer gat_outs from many samples.
regulon_explainer = explainer.Gradient(tuner.model.bert_model, gat_out)

# `regulon_shap_vars` estimates explanation accuracy (lower is better); it is
# not named `vars`, to avoid shadowing the builtin.
regulon_shap_values, regulon_shap_vars = regulon_explainer.shap_values(
    gat_out, return_variances=True
)
print(f'Explaining {len(regulon_shap_values)} classes.')
print(f'Each class has explain in shape of {regulon_shap_values[0].shape}.')

Explaining 2 classes.
Each class has explain in shape of (1, 4, 3).


### Explaining the GAT to analyze GRP importance

The `grp_explain` here is based purely on the GAT attention weights.

In [11]:
# Attention-based GRP explanation for the first sample of the batch:
# input_data[0] holds the feature matrices, input_data[1] the adjacency ones.
first_feat, first_adj = input_data[0][0], input_data[1][0]
grp_explain = tuner.model.gat.explain(first_feat, first_adj)
print(grp_explain.shape)

# Alternatively, feed all-ones matrices of the same shapes to read out the
# attention weights alone.
ones_feat = torch.ones_like(first_feat)
ones_adj = torch.ones_like(first_adj)
grp_explain = tuner.model.gat.explain(ones_feat, ones_adj)
print(grp_explain.shape)

torch.Size([4, 10])
torch.Size([4, 10])


The importance values of each gene or TF regulon, inferred from the Shapley values calculated above, are sufficient to obtain the importance of each GRP.

In [12]:
# Sum up Shapley values over the feature dimensions (axis 2) of every node
# (gene or TF), using the explanations for class 0.
regulon_importance_values = regulon_shap_values[0].sum(2)
gene_importance_values = gene_shap_values[0][0].sum(2)
print(regulon_importance_values.shape)
print(gene_importance_values.shape)

grp_explain = grp_explain.to(factory_kwargs['device'])
# `[0]` selects the first sample. torch.as_tensor converts the arrays straight
# onto the model's device/dtype (replaces the legacy torch.Tensor(...).to()).
regulon_importance = torch.as_tensor(regulon_importance_values[0], **factory_kwargs)
gene_importance = torch.as_tensor(gene_importance_values[0], **factory_kwargs)

# Weight the attention matrix from either side. The two products are
# different quantities, so keep both instead of overwriting the first.
# (regulons,) @ (regulons, genes) -> one value per gene
grp_importance_by_gene = torch.matmul(regulon_importance, grp_explain)
# (regulons, genes) @ (genes,) -> one value per regulon
grp_importance_by_regulon = torch.matmul(grp_explain, gene_importance)
# `grp_importance` keeps its previous meaning (the gene-weighted product).
grp_importance = grp_importance_by_regulon
print(grp_importance_by_gene.shape)
print(grp_importance_by_regulon.shape)

(1, 4)
(1, 10)
