# FateZ Multiomic Pertubation Effect Prediction(?)
This notebook demonstrate how to implement Pertubation Effect Prediction method with FateZ's modules.

In [3]:
import os
import sys
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import pandas as pd
import fatez.lib as lib
import fatez.test as test
import fatez.model as model
import fatez.tool.JSON as JSON
import fatez.process as process
import fatez.process.fine_tuner as fine_tuner
import fatez.process.pre_trainer as pre_trainer
from pkg_resources import resource_filename

suppressor = process.Quiet_Mode()
print('Done import')

Done import


### Build model and make some fake data first.

In [5]:
# Parameters
params = {
    'n_sample': 10,       # Fake samples to make
    'batch_size': 2,      # Batch size
}

# Load built-in config file
config = JSON.decode(resource_filename(
        __name__, '../../fatez/data/config/gat_bert_config.json'
    )
)

device = 'cuda'
dtype = torch.float32

# Generate Fake data
faker = test.Faker(model_config = config, **params)
pertubation_dataloader = faker.make_data_loader()
result_dataloader = faker.make_data_loader()

# Make id of pertubation result the 'label' of each sample
for i,k in enumerate(pertubation_dataloader.dataset.samples):
    k.y = i
    
print('Done Fake Data')

Done Fake Data


### The model will be architecturally similar with a pretrainer

In [6]:
worker = pre_trainer.Set(config, dtype = dtype)

print('Model Set')

Model Set


### However, the training part will be littel bit different
This part is modified based on pre_trainer.Trainer.train()

In [16]:
report_batch = False
size = worker.input_sizes

net, device = worker.use_device(device)
net.train(True)
best_loss = 99
loss_all = 0
report = list()

for x,y in pertubation_dataloader:
    
    # Prepare input data as always
    input = [ele.to(device) for ele in x]
    
    # Mute some debug outputs
    suppressor.on()
    node_rec, adj_rec = net(input)
    suppressor.off()
    
    # Prepare pertubation result data using a seperate dataloader
    y = [result_dataloader.dataset.samples[ele].to(device) for ele in y]
    # Please be noted here that this script is only reconstructing TF parts
    # To reconstruct whole genome, we can certainly add an additionaly layer which takes adj_rec and node_rec to do the job.
    node_results = torch.split(
        torch.stack([ele.x for ele in y], 0),
        node_rec.shape[1],
        dim = 1
    )[0]
    adj_results = lib.get_dense_adjs(
        y, (size['n_reg'],size['n_node'],size['edge_attr'])
    )
    
    # Get total loss
    loss = worker.criterion(node_rec, node_results)
    if adj_rec is not None:
        loss += worker.criterion(adj_rec, adj_results)
    
    # Some backward stuffs here
    loss.backward()
    nn.utils.clip_grad_norm_(worker.model.parameters(), worker.max_norm)
    worker.optimizer.step()
    worker.optimizer.zero_grad()

    # Accumulate
    best_loss = min(best_loss, loss.item())
    loss_all += loss.item()

    # Some logs
    if report_batch: report.append([loss.item()])


worker.scheduler.step()
report.append([loss_all / len(pertubation_dataloader)])
report = pd.DataFrame(report)
report.columns = ['Loss', ]
print(report)

       Loss
0  2.730115


### In the case of tuning unlabeled data, which does not have pertubation results... 
We shall set another trainer using previous model.

In [24]:
tuner = pre_trainer.Set(config, prev_model = worker.model, dtype = dtype)

# Some new fake data
tuner_dataloader = faker.make_data_loader()

# And the tuning process is also based on input reconstruction as pretraining
suppressor.on()
report = tuner.train(tuner_dataloader, report_batch = False, device = device)
suppressor.off()
print(report)

       Loss
0  2.118923


### Then we shall just use worker object to make predictions.
Similar with the training block above for worker, but no need to prepare y.

In [27]:
net, device = worker.use_device(device)
net.train(True)

for x,_ in tuner_dataloader:
    
    # Prepare input data as always
    input = [ele.to(device) for ele in x]
    
    # Mute some debug outputs
    suppressor.on()
    node_rec, adj_rec = net(input)
    suppressor.off()
    print(node_rec, adj_rec)

tensor([[[-0.1218, -0.0319],
         [ 0.2913, -0.8462],
         [ 0.1863, -0.7708],
         [-0.0997, -0.1563]],

        [[-0.1467,  0.0468],
         [ 0.5391, -0.9599],
         [ 0.2706, -0.8496],
         [-0.1051, -0.0997]]], device='cuda:0', grad_fn=<ViewBackward0>) tensor([[[ 0.9759,  0.7606, -0.4116,  0.1393, -0.1631,  0.3622, -0.4710,
           0.3269, -0.7119,  1.0409],
         [ 0.6905,  0.9797,  0.8428, -0.2868,  0.7114,  0.6660, -0.4230,
           1.2142, -0.2275,  1.2271],
         [ 0.6833,  0.9137,  0.8117, -0.2809,  0.6595,  0.6044, -0.4215,
           1.2556, -0.3990,  1.1928],
         [ 0.9935,  0.8441, -0.2341,  0.0878, -0.0140,  0.4077, -0.4854,
           0.4835, -0.7113,  1.1195]],

        [[ 0.9530,  0.6944, -0.5049,  0.1638, -0.2525,  0.3264, -0.4590,
           0.2521, -0.7332,  0.9838],
         [ 0.4754,  0.8820,  0.9961, -0.3643,  0.7589,  0.7388, -0.3493,
           1.1680,  0.1863,  1.1010],
         [ 0.6238,  0.9135,  0.9208, -0.3226,  0.7265,