# FateZ Explain 

This notebook demonstrates how to use the explanation methods of FateZ models.

In [17]:
import os
import sys
import torch
from torch.utils.data import DataLoader
import fatez.test as test
import fatez.model as model
import fatez.tool.JSON as JSON
import fatez.process as process
import fatez.process.fine_tuner as fine_tuner
import fatez.process.pre_trainer as pre_trainer
from pkg_resources import resource_filename

# Ignoring warnings because of using LazyLinear
import warnings
warnings.filterwarnings('ignore')

print('Done')

Done


### First, load the model config and generate some fake data.

In [18]:
# Parameters
params = {
    'n_sample': 10,       # Fake samples to make
    'batch_size': 1,      # Batch size
}

# Seed torch's RNG so model initialisation (and the fake data, if Faker draws
# from torch) is reproducible across runs.
# NOTE(review): if test.Faker uses numpy/random instead, seed those too — TODO confirm.
SEED = 0
torch.manual_seed(SEED)

# Load the built-in config file through the `fatez` package itself, so the
# path no longer depends on the notebook's current working directory
# (inside a notebook `__name__` is '__main__', which has no package location).
config = JSON.decode(resource_filename(
        'fatez', 'data/config/gat_bert_config.json'
    )
)

# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook also runs on machines without CUDA.
factory_kwargs = {
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'dtype': torch.float32,
}

# Generate Fake data
faker = test.Faker(model_config = config, **params)
train_dataloader = faker.make_data_loader()

print('Done')

Done


### Now we perform pre-training without labels.


Here the pre-trainer's `train_adj` option is set to `False`, so the model does NOT reconstruct the adjacency matrices during pre-training.

In [19]:
# Pre-train without reconstructing adjacency matrices. The flag is set
# explicitly rather than relying on the current contents of `config`: the
# next cell flips it to True in place, so re-running this cell afterwards
# would otherwise silently pre-train with train_adj=True.
config['pre_trainer']['train_adj'] = False
trainer = pre_trainer.Set(config, factory_kwargs)
report = trainer.train(train_dataloader, report_batch = True)
print(report)

        Loss
0   1.523740
1   2.056911
2   2.464677
3   2.160961
4   1.769112
5   2.518550
6   2.565397
7   1.717235
8   2.797816
9   2.683969
10  2.225837


Next, we pre-train again with `train_adj` set to `True`, so the model also reconstructs the adjacency matrices.

In [20]:
# Pre-train once more, now also reconstructing the adjacency matrices.
# NOTE(review): this edits the shared `config` in place, so every later cell
# (including the fine-tuner built from `config`) sees train_adj=True.
config['pre_trainer'].update(train_adj = True)
trainer = pre_trainer.Set(config, factory_kwargs)
report = trainer.train(train_dataloader, report_batch = True)
print(report)

        Loss
0   1.895170
1   1.933711
2   2.200883
3   2.626011
4   2.680342
5   1.341890
6   2.156959
7   2.802052
8   1.489988
9   2.411124
10  2.153813


### Then, we fine-tune the pre-trained model with class labels.

In [21]:
# Fine-tune with class labels, starting from the model pre-trained above.
pretrained = trainer.model
tuner = fine_tuner.Set(config, factory_kwargs, prev_model = pretrained)
report = tuner.train(train_dataloader, report_batch = True)
print(report)

        Loss  ACC
0   0.493794  1.0
1   0.446725  1.0
2   1.043058  0.0
3   1.076692  0.0
4   1.044668  0.0
5   0.486640  1.0
6   1.019905  0.0
7   0.943047  0.0
8   0.482891  1.0
9   0.489378  1.0
10  0.752680  0.5


### Explaining the model

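Three kinds of explanations are available:
1. edge_explain: an `n_reg × n_node` matrix returned by the model's `explain_batch`.
2. regulon_explain: one score per regulon, obtained by summing the absolute feature explanations over the embedding dimension.
3. node_explain: one score per node, obtained by weighting the edge explanations with the regulon explanations.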

In [25]:
# Initializing edge explain matrix and regulon explain matrix
adj_exp = torch.zeros((config['input_sizes']['n_reg'], config['input_sizes']['n_node']))
reg_exp = torch.zeros((config['input_sizes']['n_reg'], config['encoder']['d_model']))

# Number of batches from train_dataloader to explain; their explanations are
# summed. 1 keeps the demo quick.
n_explain_batches = 1

# Background data: the first batch of a loader over the same dataset with
# batch size n_sample. next(iter(...)) fetches only that batch instead of
# building a list of every batch and discarding all but the first.
bg = next(iter(DataLoader(train_dataloader.dataset, batch_size = params['n_sample'])))[0]
# Set explainer through taking input data from pseudo-dataloader
explain = tuner.model.make_explainer([a.to(factory_kwargs['device']) for a in bg])

# zip() stops as soon as range() is exhausted, so no extra batch is loaded.
for _batch_idx, (x, _label) in zip(range(n_explain_batches), train_dataloader):
    data = [a.to(factory_kwargs['device']) for a in x]
    # The third return value is not used here (named so as not to shadow the
    # builtin `vars`).
    adj_temp, reg_temp, _exp_vars = tuner.model.explain_batch(data, explain)
    # torch.as_tensor accepts numpy arrays as well as tensors; .cpu() keeps the
    # accumulator on the CPU, and the out-of-place add lets torch's type
    # promotion pick the result dtype explicitly.
    adj_exp = adj_exp + torch.as_tensor(adj_temp).cpu()

    print(f'Explaining {len(reg_temp)} classes.')

    # Only the feat mat explanation should be working
    print(f'Each class has regulon explain in shape of {reg_temp[0][0].shape}.\n')

    # Only taking explanations for class 0
    for exp in reg_temp[0]:
        reg_exp = reg_exp + torch.as_tensor(exp).abs().cpu()

# Collapse the per-feature scores into one score per regulon.
reg_exp = torch.sum(reg_exp, dim = -1)
# Node explanation: edge explanations weighted by the regulon explanations.
node_exp = torch.matmul(reg_exp, adj_exp.type(reg_exp.dtype))
print('Edge Explain:\n', adj_exp, '\n')
print('Reg Explain:\n', reg_exp, '\n')
print('Node Explain:\n', node_exp, '\n')

Explaining 2 classes.
Each class has regulon explain in shape of (4, 4).

Edge Explain:
 tensor([[0.0268, 0.0269, 0.0271, 0.0245, 0.0247, 0.0253, 0.0252, 0.0257, 0.0000,
         0.0367],
        [0.0328, 0.0317, 0.0301, 0.0304, 0.0290, 0.0296, 0.0320, 0.0267, 0.0000,
         0.0273],
        [0.0282, 0.0253, 0.0250, 0.0250, 0.0254, 0.0252, 0.0271, 0.0290, 0.0000,
         0.0285],
        [0.0255, 0.0267, 0.0266, 0.0264, 0.0292, 0.0260, 0.0326, 0.0275, 0.0000,
         0.0285]]) 

Reg Explain:
 tensor([0.0036, 0.0017, 0.0130, 0.0060], dtype=torch.float64) 

Node Explain:
 tensor([0.0007, 0.0006, 0.0006, 0.0006, 0.0006, 0.0006, 0.0007, 0.0007, 0.0000,
        0.0007], dtype=torch.float64) 

