In [16]:
from torch_geometric.data import DataLoader
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch_geometric.nn import global_max_pool as gmp
from torch.nn import Sequential, Linear, ReLU
from torch_geometric.nn import GCNConv, GATConv, GATv2Conv, SAGEConv, GINEConv, GINConv

In [2]:
print(torch.__version__)
print(torch.version.cuda)

1.11.0
10.2


In [3]:
device = torch.device("cuda:2")
# device = torch.device("cpu")
device

device(type='cuda', index=2)

In [4]:
from utils_data import TestbedDataset
from models import GATNet_E, GATNet, GCNNet, GATv2Net, SAGENet, GINNet, GINENet

In [5]:
model = GATNet()
model_path = 'root_folder/root_013/models/model_GAT-EP300-SW801010_GDSC.model'

# model = GCNNet()
# model_path = 'root_folder/root_003/models/model_GCN-EP300-SW801010_GDSC.model'
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)

GATNet(
  (gcn1): GATConv(334, 334, heads=10)
  (gcn2): GATConv(3340, 128, heads=1)
  (fc_g1): Linear(in_features=128, out_features=128, bias=True)
  (conv_xt_1): Conv1d(1, 32, kernel_size=(8,), stride=(1,))
  (pool_xt_1): MaxPool1d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
  (conv_xt_2): Conv1d(32, 64, kernel_size=(8,), stride=(1,))
  (pool_xt_2): MaxPool1d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
  (conv_xt_3): Conv1d(64, 128, kernel_size=(8,), stride=(1,))
  (pool_xt_3): MaxPool1d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
  (fc1_xt): Linear(in_features=2944, out_features=128, bias=True)
  (fc1): Linear(in_features=256, out_features=1024, bias=True)
  (fc2): Linear(in_features=1024, out_features=128, bias=True)
  (out): Linear(in_features=128, out_features=1, bias=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.5, inplace=False)
)

## DrugNet
extract only the drug-gnn part from the pretrained model

In [65]:
class DrugNet(GATNet):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        
    def arguments_read(self, *args, **kwargs):

        data: Batch = kwargs.get('data') or None

        if not data:
            if not args:
                assert 'x' in kwargs
                assert 'edge_index' in kwargs
                x, edge_index = kwargs['x'], kwargs['edge_index'],
                batch = kwargs.get('batch')
                if batch is None:
                    batch = torch.zeros(kwargs['x'].shape[0], dtype=torch.int64, device=x.device)
            elif len(args) == 2:
                x, edge_index = args[0], args[1]
                batch = torch.zeros(args[0].shape[0], dtype=torch.int64, device=x.device)
            elif len(args) == 3:
                x, edge_index, batch = args[0], args[1], args[2]
            else:
                raise ValueError(f"forward's args should take 2 or 3 arguments but got {len(args)}")
        else:
            x, edge_index, batch = data.x, data.edge_index, data.batch

        return x, edge_index, batch

    def forward(self, *args, **kwargs):
        x, edge_index, batch_drug = self.arguments_read(*args, **kwargs)
        
        x = self.dropout(x)
        # x = F.dropout(x, p=0.2, training=self.training)
        x = F.elu(self.gcn1(x, edge_index))
        # x = F.dropout(x, p=0.2, training=self.training)
        
        x = self.dropout(x)
        x = self.gcn2(x, edge_index)
        x = self.relu(x)
        
#         batch_drug = torch.zeros(x.shape[0], dtype=torch.long).to(device)
        x = gmp(x, batch_drug)          # global max pooling
        x = self.fc_g1(x)
        x = self.relu(x)
        
        return x

In [66]:
drug_model = DrugNet()
drug_model.to(device)
drug_model

DrugNet(
  (gcn1): GATConv(334, 334, heads=10)
  (gcn2): GATConv(3340, 128, heads=1)
  (fc_g1): Linear(in_features=128, out_features=128, bias=True)
  (conv_xt_1): Conv1d(1, 32, kernel_size=(8,), stride=(1,))
  (pool_xt_1): MaxPool1d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
  (conv_xt_2): Conv1d(32, 64, kernel_size=(8,), stride=(1,))
  (pool_xt_2): MaxPool1d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
  (conv_xt_3): Conv1d(64, 128, kernel_size=(8,), stride=(1,))
  (pool_xt_3): MaxPool1d(kernel_size=3, stride=3, padding=0, dilation=1, ceil_mode=False)
  (fc1_xt): Linear(in_features=2944, out_features=128, bias=True)
  (fc1): Linear(in_features=256, out_features=1024, bias=True)
  (fc2): Linear(in_features=1024, out_features=128, bias=True)
  (out): Linear(in_features=128, out_features=1, bias=True)
  (relu): ReLU()
  (dropout): Dropout(p=0.5, inplace=False)
)

In [6]:
branch_folder = "root_folder/root_013"
dataset = 'GDSC'
test_data = TestbedDataset(root=branch_folder, dataset=dataset+'_test_mix')

Pre-processed data found: root_folder/root_013/processed/GDSC_test_mix.pt, loading ...


In [19]:
test_batch = 1
test_loader = DataLoader(test_data, batch_size=test_batch, shuffle=False)

In [20]:
one = next(iter(test_loader))
one

DataBatch(x=[32, 334], edge_index=[2, 74], y=[1], edge_features=[74, 4], smiles=[1], drug_name=[1], cell_line_name=[1], target=[1, 735], c_size=[1], batch=[32], ptr=[2])

when batch size is 1,
batch_drug is a tensor of zeros with shape = x.shape[0]

In [33]:
# for data in test_loader:
#     print(data.batch)

In [52]:
data = one.to(device)
# out, _ = model(data.x, data.edge_index, data.target, data.batch, data.edge_features)
drug_out = drug_model(data)

In [53]:
drug_out.shape

torch.Size([1, 128])

In [34]:
from dig.xgraph.method import SubgraphX

In [68]:
explainer = SubgraphX(drug_model, num_classes=128, device=device, explain_graph=True, reward_method='gnn_score')

In [69]:
# x, x_cell_mut, edge_index, batch_drug, edge_feat = data.x, data.target, data.edge_index, data.batch, data.edge_features
# kwargs = {
#     "target": x_cell_mut,
#     "batch": batch_drug,
#     "edge_feat": edge_feat
# }

# _, explanation_results, related_preds = explainer(x, edge_index)

In [71]:
x, edge_index = data.x, data.edge_index
node_num = x.shape[0]    # use total node number as max_nodes makes the searching extremely slow!
_, explanation_results, related_preds = explainer(x, edge_index, max_nodes=node_num)

In [74]:
len(explanation_results)

128

In [78]:
len(explanation_results[1])

1931

In [79]:
explanation_results[1]

[{'data': DataBatch(x=[32, 334], edge_index=[2, 74], batch=[32], ptr=[2]),
  'coalition': [10, 11, 12],
  'ori_graph': <networkx.classes.graph.Graph at 0x7f96ff1e38e0>,
  'W': 0,
  'N': 0,
  'P': 0.00800465140491724},
 {'data': DataBatch(x=[32, 334], edge_index=[2, 74], batch=[32], ptr=[2]),
  'coalition': [2, 3, 4, 5, 6, 9, 10, 11, 12, 13, 19, 20],
  'ori_graph': <networkx.classes.graph.Graph at 0x7f96ff1e38e0>,
  'W': 0.007714778184890747,
  'N': 1,
  'P': 0.007884478196501732},
 {'data': DataBatch(x=[32, 334], edge_index=[2, 74], batch=[32], ptr=[2]),
  'coalition': [0, 1, 2, 5, 6, 9, 10, 11, 12, 13, 19, 20],
  'ori_graph': <networkx.classes.graph.Graph at 0x7f96ff1e38e0>,
  'W': 0.0077341836877167225,
  'N': 1,
  'P': 0.007884478196501732},
 {'data': DataBatch(x=[32, 334], edge_index=[2, 74], batch=[32], ptr=[2]),
  'coalition': [0, 1, 2, 5, 6, 9, 10, 11, 12, 19, 20],
  'ori_graph': <networkx.classes.graph.Graph at 0x7f96ff1e38e0>,
  'W': 0,
  'N': 0,
  'P': 0.007865823805332184},


## 

In [None]:
class CustomizedNet(GATNet):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        
    def arguments_read(self, *args, **kwargs):

        data: Batch = kwargs.get('data') or None

        if not data:
            if not args:
                assert 'x' in kwargs
                assert 'edge_index' in kwargs
                x, edge_index = kwargs['x'], kwargs['edge_index'],
                batch = kwargs.get('batch')
                if batch is None:
                    batch = torch.zeros(kwargs['x'].shape[0], dtype=torch.int64, device=x.device)
            elif len(args) == 2:
                x, edge_index = args[0], args[1]
                batch = torch.zeros(args[0].shape[0], dtype=torch.int64, device=x.device)
            elif len(args) == 3:
                x, edge_index, batch = args[0], args[1], args[2]
            else:
                raise ValueError(f"forward's args should take 2 or 3 arguments but got {len(args)}")
        else:
            x, edge_index, batch = data.x, data.edge_index, data.batch

        return x, edge_index, batch