# FateZ Explain 

This notebook demonstrates how to use the explanatory methods of FateZ models.

In [1]:
import os
import torch
from torch.utils.data import DataLoader
import fatez.test as test
import fatez.model as model
import fatez.tool.JSON as JSON
import fatez.process.explainer as explainer
import fatez.process.fine_tuner as fine_tuner
import fatez.process.pre_trainer as pre_trainer
from pkg_resources import resource_filename

# Ignoring warnings because of using LazyLinear
import warnings
warnings.filterwarnings('ignore')

### Build model and make some fake data first.

In [2]:
# Parameters
params = {
    'k': 10,              # Equivalent to total gene number
    'top_k': 4,           # Equivalent to TF number
    'n_features': 3,      # Feature matrix dimension
    'n_sample': 10,       # Fake samples to make
    'batch_size': 1,      # Batch size
    'n_class': 4,         # Class number
}

# Seed torch's RNG so the fake data, the model initialisation and therefore the
# reported losses are reproducible on Restart & Run All.
# NOTE(review): assumes Faker and the models draw from torch's RNG; seed numpy /
# `random` as well if they do not.
SEED = 0
torch.manual_seed(SEED)

# Load built-in config file.
# NOTE(review): `__name__` is '__main__' in a notebook, so this relative path is
# presumably resolved against the current working directory; run the notebook
# from its own folder.
config = JSON.decode(resource_filename(
        __name__, '../../fatez/data/config/test_config.json'
    )
)

# Adjust parameters according to data dims
config['gat']['params']['d_model'] = params['n_features']
config['fine_tuner']['n_class'] = params['n_class']
factory_kwargs = {'device': 'cpu', 'dtype': torch.float32,}

# Generate Fake data
faker = test.Faker(model_config = config, **params)
train_dataloader = faker.make_data_loader()



### Now we perform pre-training without labels.


Here the pre-trainer's `n_dim_adj` is set to None, so the model does NOT reconstruct the adjacency matrices.

In [3]:
# Pre-train on features only. With `n_dim_adj` set to None the pre-trainer does
# not reconstruct adjacency matrices. Set it explicitly instead of relying on the
# config-file default, because the next cell mutates `config`. Re-running this cell
# afterwards would otherwise silently reconstruct adjacencies too.
config['pre_trainer']['n_dim_adj'] = None
trainer = pre_trainer.Set(config, factory_kwargs)
pt_loss = trainer.train(train_dataloader)
print(f'Pre-Trainer total loss:{pt_loss}\n')

Pre-Trainer total loss:1.1127713918685913



We can also pre-train while reconstructing the adjacency matrices, by setting `n_dim_adj` to the total gene number `k`.

In [4]:
# Tie the adjacency-reconstruction dimension to the total gene number `k`,
# then pre-train again and report the total loss.
n_dim_adj = params['k']
config['pre_trainer']['n_dim_adj'] = n_dim_adj
trainer = pre_trainer.Set(config, factory_kwargs)
pt_loss = trainer.train(train_dataloader)
loss_msg = f'Pre-Trainer total loss:{pt_loss}\n'
print(loss_msg)

Pre-Trainer total loss:1.924447774887085



### Then, we move on to the fine-tuning part with class labels.

In [7]:
# Build the fine-tuner on top of the pre-trained GAT, encoder and rep embedder.
pretrained = trainer.model
tuner = fine_tuner.Tuner(
    gat = pretrained.gat,
    encoder = pretrained.bert_model.encoder,
    rep_embedder = pretrained.bert_model.rep_embedder,
    **config['fine_tuner'],
    **factory_kwargs,
)

# Forward + backward pass over every labelled batch.
# NOTE(review): no optimizer step or zero_grad is performed, so the weights are
# not updated and gradients accumulate in `.grad` across batches.
# `input` / `label` are left bound to the last batch on purpose; the explanation
# cells below reuse `input`.
criterion = torch.nn.CrossEntropyLoss()
for input, label in train_dataloader:
    output = tuner.model(input[0], input[1])
    loss = criterion(output, label)
    loss.backward()

### Explaining the fine-tuned model in general

Note: to draw an overall conclusion about the contribution of a specific gene, we need to sum its importance values across every feature dimension (e.g., RNA count, peaks).

In [8]:
# Background (reference) data for the explainer: the first batch of a loader that
# serves `n_sample` samples at once, keeping only the model inputs (labels dropped).
background_loader = DataLoader(train_dataloader.dataset, batch_size = params['n_sample'])
background_data, _ = next(iter(background_loader))
explain = explainer.Gradient(tuner.model, background_data)

# The variances estimate how accurate the explanation is: lower is better.
# (Bound to a descriptive name so the builtin `vars` is not shadowed.)
gene_shap_values, gene_shap_vars = explain.shap_values(input, return_variances = True)
print(f'Explaining {len(gene_shap_values)} classes.')

# Having 2 inputs (feature matrix and adjacency matrix)
assert len(gene_shap_values[0]) == 2

# Only the feat mat explanation should be working
print(f'Each class has explain in shape of {gene_shap_values[0][0].shape}.')

# The adj mat explanation (gene_shap_values[0][1]) should NOT be working,
# since the adjacency input lacks gradient.

Explaining 4 classes.
Each class has explain in shape of (1, 10, 3).


### Explaining the BERT part to analyze the importance of TFs only

Note: similarly, we would want to sum the values across the embedding dimensions.

In [9]:
# The explained sample: GAT output of the last batch (`input`).
gat_out = tuner.model.get_gat_output(input[0], input[1])

# Background (baseline) data: GAT outputs accumulated over the whole background set.
# Using `gat_out` itself as the only background would compare the sample against
# itself. With an expected-gradients style explainer every attribution then
# collapses to ~0, because (x - baseline) = 0.
# NOTE(review): assumes explainer.Gradient wraps SHAP's GradientExplainer — confirm.
background_gat_out = tuner.model.get_gat_output(background_data[0], background_data[1])
explain = explainer.Gradient(tuner.model.bert_model, background_gat_out)

# The variances estimate how accurate the explanation is: lower is better.
regulon_shap_values, regulon_shap_vars = explain.shap_values(gat_out, return_variances = True)
print(f'Explaining {len(regulon_shap_values)} classes.')
print(f'Each class has explain in shape of {regulon_shap_values[0].shape}.')

Explaining 4 classes.
Each class has explain in shape of (1, 4, 3).


### Explaining the GAT to analyze GRP importances

The grp_explain here is purely based on the GAT attention weights.

In [10]:
# Attention-based GRP explanation for the first sample of the last batch:
# its feature matrix and its adjacency matrix.
feat_mat, adj_mat = input[0][0], input[1][0]
grp_explain = tuner.model.gat.explain(feat_mat, adj_mat)
print(grp_explain.shape)

# Alternatively, feed all-ones matrices of the same shapes to extract the
# attention weights. This second result is the one kept in `grp_explain`.
ones_feat, ones_adj = torch.ones_like(feat_mat), torch.ones_like(adj_mat)
grp_explain = tuner.model.gat.explain(ones_feat, ones_adj)
print(grp_explain.shape)

torch.Size([4, 10])
torch.Size([4, 10])


The importance values of each gene and each TF regulon, inferred from the Shapley values calculated above, are sufficient to obtain the importance of each GRP.

In [11]:
# Sum up Shapley values over the feature dimension for every node (TF or gene),
# using the explanation of class 0.
regulon_importance_values = regulon_shap_values[0].sum(2)   # (1, n_TF)
gene_importance_values = gene_shap_values[0][0].sum(2)      # (1, n_gene)
print(regulon_importance_values.shape)
print(gene_importance_values.shape)

# Convert to tensors matching the attention matrix. `torch.Tensor(...)` would
# force float32 on CPU regardless of grp_explain's dtype/device.
regulon_importance = torch.as_tensor(
    regulon_importance_values[0], dtype = grp_explain.dtype, device = grp_explain.device
)
gene_importance = torch.as_tensor(
    gene_importance_values[0], dtype = grp_explain.dtype, device = grp_explain.device
)

# grp_explain is (TF x gene). Propagate TF importances through it: one value per gene.
grp_importance_via_regulons = torch.matmul(regulon_importance, grp_explain)
# Propagate gene importances back onto TFs: one value per TF.
# Kept under its own name; previously the product above was overwritten by this one.
grp_importance = torch.matmul(grp_explain, gene_importance)

(1, 4)
(1, 10)
