# FateZ Explain 

This notebook demonstrates how to use the explanatory features of the models.

In [None]:
import os
import torch
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss, L1Loss
import fatez.lib as lib
import fatez.model as model
import fatez.model.gat as gat
import fatez.model.bert as bert
import fatez.process.explainer as explainer
import fatez.process.fine_tuner as fine_tuner
import fatez.process.pre_trainer as pre_trainer

# Ignoring warnings because of using LazyLinear
import warnings
warnings.filterwarnings('ignore')


### Make some fake data and build model first.

In [None]:
# Parameters
seed = 0            # Seed for torch's global RNG (reproducible runs)
k = 10              # Equivalent to total gene number
top_k = 4           # Equivalent to TF number
n_feature = 3       # Feature matrix dimension
n_sample = 10       # Fake samples to make
batch_size = 1      # Batch size
n_class = 4         # Class number
n_bin = 100         # Deprecated
masker_ratio = 0.5  # Masking ratio before data input to BERT Encoder

# Seed torch before any random tensor is drawn, so that torch-generated
# random values are the same on every Restart & Run All.
torch.manual_seed(seed)

# Params for GAT
gat_param = {
    'd_model': n_feature,
    'en_dim': 8,
    'n_hidden': 4,
    'nhead': 2,
    'device':'cpu',
    'dtype': torch.float32,
}

# Params for BERT
# d_model follows the GAT's 'en_dim' setting.
bert_encoder_param = {
    'd_model': gat_param['en_dim'],
    'n_layer': 6,
    'nhead': 8,
    'dim_feedforward': gat_param['en_dim'],
    'dtype': torch.float32,
}

# d_model needs to be divisible by nhead: fail right here with a clear
# message instead of somewhere inside the model construction.
assert bert_encoder_param['d_model'] % bert_encoder_param['nhead'] == 0, (
    "bert_encoder_param['d_model'] must be divisible by "
    "bert_encoder_param['nhead']"
)

# Generate fake data
def make_fake_sample():
    """Return one [feature matrix, adjacency matrix] pair of random tensors."""
    feat_mat = torch.randn(k, gat_param['d_model'], dtype = torch.float32)
    adj_mat = torch.randn(top_k, k, dtype = torch.float32)
    return [feat_mat, adj_mat]

dataset = lib.FateZ_Dataset(
    samples = [make_fake_sample() for _ in range(n_sample)],
    # One random integer class label in [0, n_class) per sample
    labels = torch.empty(n_sample, dtype = torch.long).random_(n_class),
)

# Make dataloader
train_dataloader = DataLoader(dataset, shuffle = True, batch_size = batch_size)

# Build Models (construction order kept: GAT, masker, BERT encoder)
gat_model = gat.GAT(**gat_param)
masker = model.Masker(ratio = masker_ratio)
bert_encoder = bert.Encoder(**bert_encoder_param)

# Print a short summary of the fake-data configuration
for caption, amount in (
    ('Fake gene num:', k),
    ('Fake TF num:', top_k),
    ('Fake Sample Number:', n_sample),
    ('Batch Size:', batch_size),
):
    print(caption, amount)
# The trailing '\n' argument leaves a blank line after the summary
print('Class Number:', n_class, '\n')

### Now we perform pre-training with no label.

In [None]:
# Assemble the pre-training model from the GAT, masker and BERT encoder
# built earlier in the notebook.
pre_training = pre_trainer.Model(
    gat = gat_model,
    masker = masker,
    bin_pro = model.Binning_Process(n_bin = n_bin),
    bert_model = bert.Pre_Train_Model(
        bert_encoder, n_bin = n_bin, n_dim = gat_model.d_model
    )
)

# Build the loss once instead of on every iteration.
pre_train_loss = L1Loss()

# Labels are ignored during pre-training. The batch is deliberately not
# called `input`, so Python's builtin `input` is not shadowed by this cell.
for batch, _ in train_dataloader:
    output = pre_training(batch[0], batch[1])
    # Compare against the first chunk of batch[0] along dim 1, whose size
    # is output.shape[1].
    target = torch.split(batch[0], output.shape[1], dim=1)[0]
    pre_train_loss(output, target).backward()

### Then, we can move on to fine-tuning with class labels.

In [None]:
# Fine-tuning reuses the GAT and BERT encoder from the previous steps;
# the BERT model is now configured with n_class classes.
bin_process = model.Binning_Process(n_bin = n_bin)
fine_tune_bert = bert.Fine_Tune_Model(
    bert_encoder, n_hidden = 2, n_class = n_class
)
fine_tuning = fine_tuner.Model(
    gat = gat_model,
    bin_pro = bin_process,
    bert_model = fine_tune_bert,
)

classification_loss = CrossEntropyLoss()

# NOTE: `input` and `label` keep these exact names on purpose: the cells
# further down explain the model on the last batch left in `input`.
for input, label in train_dataloader:
    output = fine_tuning(input[0], input[1])
    loss = classification_loss(output, label)
    loss.backward()

### To explain the fine-tuning model in general.

Note: to draw an overall conclusion about the contribution of a specific gene, we would need to sum up the importance values across every feature dimension (RNA-count, peaks).

In [None]:
# Get background data: only the first batch is needed, so fetch just that
# one instead of materializing every batch in a list. With
# batch_size = n_sample this batch holds up to n_sample samples.
background_data = next(iter(DataLoader(dataset, batch_size = n_sample)))[0]
explain = explainer.Gradient(fine_tuning, background_data)

# `input` is the last batch left over from the fine-tuning loop.
# The variances can be used to estimate how accurate the explanation would
# be: the lower the better. They are stored as `gene_shap_variances`
# (not `vars`) so the builtin `vars` is not shadowed.
gene_shap_values, gene_shap_variances = explain.shap_values(
    input, return_variances = True
)
print(f'Explaining {len(gene_shap_values)} classes.')

# Each class explanation has one entry per model input: 2 inputs
assert len(gene_shap_values[0]) == 2

# The feat mat explanation is the one expected to work
print(gene_shap_values[0][0].shape)

# The adj mat explanation is NOT expected to work, since it lacks gradient
print(gene_shap_values[0][1].shape)


### To explain the BERT part for analyzing importances of TFs only.

Note: similarly, we would want to sum up values across embed dimensions.

In [None]:
# We also should accumulate gat_out for every trained input.
# Here only 1 gat_out is made as an example; it serves both as the
# background data and as the data being explained.
gat_out = fine_tuning.get_gat_output(input[0], input[1])
explain = explainer.Gradient(fine_tuning.bert_model, gat_out)

# Variances are stored as `regulon_shap_variances` (not `vars`) so the
# builtin `vars` is not shadowed.
regulon_shap_values, regulon_shap_variances = explain.shap_values(
    gat_out, return_variances=True
)
print(f'Explaining {len(regulon_shap_values)} classes.')

# Now we only have one input.
print(regulon_shap_values[0].shape)

### To explain the GAT for analyzing GRP importances.

The grp_explain here is purely based on the GAT attention weights.

In [None]:
# `input` still holds the last fine-tuning batch; pick element 0 of each
# of its two parts (feature matrix and adjacency matrix).
feat_mat, adj_mat = input[0][0], input[1][0]
grp_explain = gat_model.explain(feat_mat, adj_mat)
print(grp_explain.shape)

# Or we can feed in matrices with ones to extract attention weights.
ones_feat_mat = torch.ones_like(feat_mat)
ones_adj_mat = torch.ones_like(adj_mat)
grp_explain = gat_model.explain(ones_feat_mat, ones_adj_mat)
print(grp_explain.shape)


Utilizing the importance values of each gene or TF regulon inferred from the Shapley values calculated above would be sufficient to obtain the importance of each GRP.

In [None]:
# Sum up Shapley values of each feature (axis 2) for every node (gene or TF).
# Index 0 selects the first of the explained classes.
regulon_importance_values = regulon_shap_values[0].sum(2)
gene_importance_values = gene_shap_values[0][0].sum(2)
print(regulon_importance_values.shape)
print(gene_importance_values.shape)

# Both products are kept under their own names; previously the first one
# was overwritten by the second and therefore lost.
# NOTE(review): confirm whether these two results are meant to be combined
# into a single GRP importance.
grp_importance_from_regulons = torch.matmul(
    torch.Tensor(regulon_importance_values[0]), grp_explain
)
grp_importance_from_genes = torch.matmul(
    grp_explain, torch.Tensor(gene_importance_values[0])
)

# `grp_importance` ends up with the same value as before (the gene-based
# product), so anything reading it afterwards is unaffected.
grp_importance = grp_importance_from_genes

