# FateZ Explain 

This notebook demonstrates how to use the explanatory methods of FateZ models.

In [1]:
import os
import torch
from torch.utils.data import DataLoader
import fatez.test as test
import fatez.model as model
import fatez.tool.JSON as JSON
import fatez.process.explainer as explainer
import fatez.process.fine_tuner as fine_tuner
import fatez.process.pre_trainer as pre_trainer
from pkg_resources import resource_filename

# Ignoring warnings because of using LazyLinear
import warnings
warnings.filterwarnings('ignore')

### Load the model config and make some fake data first.

In [2]:
# Parameters for the fake dataset
params = {
    'n_sample': 10,       # Fake samples to make
    'batch_size': 1,      # Batch size
}

# Fix the torch random seed so fake data and training are repeatable
# (assumes Faker / the trainers draw randomness from torch — TODO confirm).
SEED = 0
torch.manual_seed(SEED)

# Load built-in config file.
# NOTE(review): in a notebook __name__ is '__main__', so this relative path is
# presumably resolved against the current working directory; run the notebook
# from its own folder.
config = JSON.decode(resource_filename(
        __name__, '../../fatez/data/config/gat_bert_config.json'
    )
)

# Use the GPU when present; fall back to CPU instead of failing on machines
# without CUDA.
factory_kwargs = {
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'dtype': torch.float32,
}

# Generate Fake data
faker = test.Faker(model_config = config, **params)
train_dataloader = faker.make_data_loader()



### Now we perform pre-training without labels.


Here the trainer's `n_dim_adj` is set to `None`, so the model does NOT reconstruct the adjacency matrices.

In [28]:
# Pre-train without reconstructing adjacency matrices. Set n_dim_adj explicitly
# so the result does not depend on whether the adjacency pre-training cell
# (which overwrites this value in `config`) was executed first.
config['pre_trainer']['n_dim_adj'] = None
trainer = pre_trainer.Set(config, factory_kwargs)
report = trainer.train(train_dataloader, report_batch = True)
print(report)

        Loss
0   1.371900
1   0.768842
2   1.256326
3   1.237347
4   1.134314
5   0.577454
6   1.105542
7   1.238400
8   0.952839
9   0.915199
10  1.055816


Next, we pre-train while also reconstructing the adjacency matrices.

In [29]:
# Turn on adjacency reconstruction: n_dim_adj is taken from
# config['input_sizes'][0][1] (presumably the node count — TODO confirm).
adj_dim = config['input_sizes'][0][1]
config['pre_trainer']['n_dim_adj'] = adj_dim

trainer = pre_trainer.Set(config, factory_kwargs)
report = trainer.train(train_dataloader, report_batch = True)
print(report)

        Loss
0   1.816574
1   1.730838
2   1.741792
3   2.077016
4   1.784506
5   2.210017
6   1.868762
7   1.616021
8   2.049183
9   1.998312
10  1.889302


### Then, we fine-tune the model with class labels.

In [30]:
# Fine-tune on top of the model pre-trained above, now with class labels.
tuner = fine_tuner.Set(
    config,
    factory_kwargs,
    prev_model = trainer.model,
)
report = tuner.train(train_dataloader, report_batch = True)
print(report)

        Loss  ACC
0   1.393003  0.0
1   1.395453  0.0
2   1.484483  0.0
3   1.469457  0.0
4   1.292269  1.0
5   1.478473  0.0
6   1.389301  0.0
7   1.281845  1.0
8   1.476806  0.0
9   1.279046  1.0
10  1.394014  0.3


### Explaining the fine-tuned model as a whole

Note: to draw an overall conclusion about the contribution of a specific gene, we need to sum its importance values across every feature dimension (e.g. RNA count, peaks).

In [47]:
# Background (reference) data for the gradient explainer: one batch that holds
# all n_sample fake samples. Only the first batch is needed, so take it
# directly instead of materialising every batch.
background_loader = DataLoader(train_dataloader.dataset, batch_size = params['n_sample'])
background_data, _ = next(iter(background_loader))
background_data = [a.to(factory_kwargs['device']) for a in background_data]
explain = explainer.Gradient(tuner.model, background_data)

# Explain the first training batch. The returned variances can be used to
# estimate how accurate the explanation is: lower is better.
# (Named gene_shap_vars to avoid shadowing the builtin `vars`.)
input_data, _ = next(iter(train_dataloader))
input_data = [i.to(factory_kwargs['device']) for i in input_data]
gene_shap_values, gene_shap_vars = explain.shap_values(input_data, return_variances = True)

print(f'Explaining {len(gene_shap_values)} classes.')

# One explanation per model input (feature matrix, adjacency matrix).
assert len(gene_shap_values[0]) == 2

# Only the feature-matrix explanation is meaningful; the adjacency-matrix
# explanation (gene_shap_values[0][1]) lacks gradients and should not be used.
print(f'Each class has explain in shape of {gene_shap_values[0][0].shape}.')

Explaining 4 classes.
Each class has explain in shape of (1, 10, 2).


### Explaining the BERT part to analyze the importance of TFs only

Note: similarly, we want to sum up the values across embedding dimensions.

In [49]:
# Explain the BERT part on GAT outputs. Ideally gat_out would be accumulated
# over every training input; here a single gat_out is used as an example, and
# it serves both as background data and as the data being explained.
gat_out = tuner.model.get_gat_output(input_data[0], input_data[1])
explain = explainer.Gradient(tuner.model.bert_model, gat_out)

# Variances estimate explanation accuracy (lower is better); named
# regulon_shap_vars to avoid shadowing the builtin `vars`.
regulon_shap_values, regulon_shap_vars = explain.shap_values(gat_out, return_variances=True)
print(f'Explaining {len(regulon_shap_values)} classes.')
print(f'Each class has explain in shape of {regulon_shap_values[0].shape}.')

Explaining 4 classes.
Each class has explain in shape of (1, 4, 3).


### Explaining the GAT to analyze GRP importance

`grp_explain` here is based purely on the GAT attention weights.

In [53]:
# Attention-based GRP explanation for the first sample of the batch. Kept under
# its own name so it is not silently overwritten by the call below.
grp_explain_sample = tuner.model.gat.explain(input_data[0][0], input_data[1][0])
print(grp_explain_sample.shape)

# Or we can feed in matrices of ones to extract the attention weights.
# This result (grp_explain) is the one used in the following cells.
grp_explain = tuner.model.gat.explain(
    torch.ones_like(input_data[0][0]),
    torch.ones_like(input_data[1][0]),
)
print(grp_explain.shape)

torch.Size([4, 10])
torch.Size([4, 10])


The importance values of each gene or TF regulon, inferred from the Shapley values calculated above, are sufficient to obtain the importance of each GRP.

In [54]:
# Sum up Shapley values over the feature dimensions of every node (gene or TF).
# Index [0] selects the explanation for the first class.
regulon_importance_values = regulon_shap_values[0].sum(2)
gene_importance_values = gene_shap_values[0][0].sum(2)
print(regulon_importance_values.shape)
print(gene_importance_values.shape)

# Take the first sample and put it on grp_explain's device/dtype. torch.Tensor()
# would always build a CPU float32 tensor, which cannot be multiplied with a
# grp_explain that lives on the GPU.
regulon_importance = torch.as_tensor(
    regulon_importance_values[0], dtype = grp_explain.dtype, device = grp_explain.device
)
gene_importance = torch.as_tensor(
    gene_importance_values[0], dtype = grp_explain.dtype, device = grp_explain.device
)

# grp_explain was shown above as (4, 10), matching (n_regulons, n_genes).
# Per-gene importance, weighted by the importance of each regulon.
gene_grp_importance = torch.matmul(regulon_importance, grp_explain)
# Per-regulon importance, weighted by the importance of each gene.
grp_importance = torch.matmul(grp_explain, gene_importance)
print(gene_grp_importance.shape)
print(grp_importance.shape)

(1, 4)
(1, 10)
