In [1]:
# Data loading
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.metrics import roc_auc_score, f1_score
from IPython.display import clear_output

from joblib import load
from tqdm import trange
from tqdm.notebook import tqdm

# Graph dataset
from torch_geometric.loader import DataLoader
from torch_geometric.data import Dataset, Data

from sklearn.preprocessing import LabelBinarizer
from sklearn.model_selection import StratifiedKFold


# GNN Model
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GraphConv, GATConv, GATv2Conv, SAGEConv


# Sparse vector
from Sparse_vector.sparse_vector import SparseVector

In [2]:
from tqdm import tqdm
import sys
import time

# Register tqdm's pandas integration (`progress_apply` etc.); no such call appears in this notebook section.
# NOTE(review): the `from tqdm import tqdm` above rebinds `tqdm` to the console progress bar,
# shadowing `tqdm.notebook.tqdm` imported in the first cell.
tqdm.pandas()

In [3]:
# Chromosomes processed by the notebook: chr1..chr22, chrX, chrY, chrM.
chrom_names = [f'chr{i}' for i in list(range(1, 23)) + ['X', 'Y', 'M']]

# Feature-track names: every '*_2.pkl' file in the sparse feature directory, without the '.pkl' suffix.
# sorted() makes the order reproducible; os.listdir returns entries in arbitrary filesystem order.
features = sorted(i[:-4] for i in os.listdir('z_dna/hg38_features/sparse/') if i.endswith('_2.pkl'))
# NOTE: `groups` is reassigned to the k-mer list before the datasets are built.
groups = ['DNase-seq', 'Histone', 'RNA polymerase', 'TFs and others']
feature_names = list(features)

In [4]:
def chrom_reader(chrom):
    """Return the DNA sequence of `chrom` by concatenating its pickled chunks.

    Chunks are the files in ``z_dna/hg38_dna/`` whose name contains ``f"{chrom}_"``.
    Each file is loaded with joblib ``load`` and the results are joined with ``''.join``,
    so each chunk is expected to be a ``str``.

    Files are ordered with a natural (numeric-aware) sort, so ``chr1_10`` follows ``chr1_9``
    instead of ``chr1_1``. For zero-padded chunk numbers this gives the same order as plain
    ``sorted``.

    NOTE(review): the substring test would also pick up files such as
    ``chr1_KI270706v1_random_*`` if the directory contains them. Confirm the naming scheme.
    """
    import re

    def _natural_key(name):
        # Split into alternating text / digit runs; compare digit runs as integers.
        return [int(tok) if tok.isdigit() else tok for tok in re.split(r'(\d+)', name)]

    files = sorted((i for i in os.listdir('z_dna/hg38_dna/') if f"{chrom}_" in i), key=_natural_key)
    return ''.join(load(f"z_dna/hg38_dna/{file}") for file in files)

In [5]:
%%time
# Per-chromosome DNA strings assembled from their pickled chunks.
DNA = {chrom:chrom_reader(chrom) for chrom in tqdm(chrom_names)}
# Z-DNA annotations; they are later passed to GraphDataset as the per-base labels.
# Alternative annotation (not used): 'z_dna/hg38_zdna/sparse/ZDNA_shin.pkl'.

ZDNA = load('z_dna/hg38_zdna/sparse/ZDNA_cousine.pkl')

# One sparse track per entry of `feature_names`.
# NOTE(review): GraphDataset.get never reads `features_source` (its feature_matr is always empty),
# so this ~50 s load does not influence the model inputs.
DNA_features = {feature: load(f'z_dna/hg38_features/sparse/{feature}.pkl')
                for feature in tqdm(feature_names)}

100%|███████████████████████████████████████████████████████████████████████████████████| 25/25 [00:04<00:00,  6.11it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 1946/1946 [00:52<00:00, 36.91it/s]

CPU times: user 52.3 s, sys: 4.84 s, total: 57.1 s
Wall time: 56.8 s





In [6]:
# Window length in bp used to tile chromosomes into fixed-size intervals (re-set to the same value in the tiling cell).
width = 100

In [7]:
import numpy as np
from itertools import product

def generate_fix_n_subgroups(n):
    """Return all 4**n nucleotide strings of length exactly `n` over 'ATGC'.

    The order is itertools.product order: the rightmost letter varies fastest,
    with letters ranked A, T, G, C.
    """
    return list(map(''.join, product('ATGC', repeat=n)))

def generate_subgroups(n):
    """Return every nucleotide string of length 1..n over 'ATGC'.

    Shorter strings come first. Within one length the order is itertools.product order
    (letters ranked A, T, G, C). The total count is 4 + 16 + ... + 4**n,
    e.g. 1364 for n=5.
    """
    return [''.join(letters)
            for size in range(1, n + 1)
            for letters in product('ATGC', repeat=size)]

def encode_sequence_as_features_ndarray(n_str: str, k_str: str):
    """Flag every position of `n_str` that lies inside an occurrence of `k_str`.

    Occurrences may overlap. Returns a list of ``len(n_str)`` Python ints (0 or 1).
    The list is all zeros when `k_str` is empty or longer than `n_str`.
    """
    seq_len, kmer_len = len(n_str), len(k_str)
    covered = np.zeros(seq_len, dtype=int)
    hits = (pos for pos in range(seq_len - kmer_len + 1) if n_str.startswith(k_str, pos))
    for pos in hits:
        covered[pos:pos + kmer_len] = 1
    return covered.tolist()

In [8]:
from torch_geometric.data import Dataset, Data
class GraphDataset(Dataset):
    """Per-window path-graph dataset over genomic intervals.

    Item ``idx`` is built from ``intervals[idx] = (chrom, begin, end)``:

    * one node per base of ``dna_source[chrom][begin:end]`` (upper-cased);
    * node features: one 0/1 column per k-mer in ``groups``, set where the base lies
      inside an occurrence of that k-mer (see ``encode_sequence_as_features_ndarray``);
    * edges: neighbouring bases linked in both directions (a path graph);
    * ``y``: ``labels[chrom][begin:end]`` as int64.

    ``x`` and ``y`` are returned with a leading axis of size 1.

    Args:
        chroms: chromosome names (stored only).
        features: extra feature-track names (stored only; not used by ``get``).
        dna_source: mapping chrom -> DNA string.
        features_source: mapping feature name -> track (stored only; not used by ``get``).
        labels: mapping chrom -> per-base labels, sliceable as ``[begin:end]``.
        intervals: sequence of ``(chrom, begin, end)``; ``begin``/``end`` are passed through ``int``.
        k_mer: stored only; the encoded k-mers are exactly ``groups``.
        groups: k-mer strings to encode (default: the four single nucleotides).
        transform, pre_transform, pre_filter: forwarded to the PyG ``Dataset``.
    """

    def __init__(self, chroms, features,
                 dna_source, features_source,
                 labels, intervals, k_mer=1, groups=['A','T','G','C'],
                 transform=None, pre_transform=None, pre_filter=None):
        self.chroms = chroms
        self.features = features
        self.dna_source = dna_source
        self.features_source = features_source
        self.labels = labels
        self.intervals = intervals

        self.k_mer = k_mer
        self.groups = groups

        # Pass the hooks by keyword. PyG Dataset.__init__'s first positional parameter is
        # `root`, so a positional call would shift transform -> root, pre_transform -> transform
        # and pre_filter -> pre_transform.
        super().__init__(root=None, transform=transform,
                         pre_transform=pre_transform, pre_filter=pre_filter)

    def len(self):
        return len(self.intervals)

    @staticmethod
    def _chain_edge_index(num_nodes):
        """Return the (2, 2*(num_nodes-1)) long tensor of an undirected path 0-1-...-(num_nodes-1).

        Column order is (0->1), (1->0), (1->2), (2->1), ...
        """
        left = np.arange(max(num_nodes - 1, 0))
        right = left + 1
        src = np.stack([left, right], axis=1).ravel()
        dst = np.stack([right, left], axis=1).ravel()
        return torch.tensor(np.stack([src, dst]), dtype=torch.long)

    def get(self, idx):
        interval = self.intervals[idx]
        chrom = interval[0]
        begin = int(interval[1])
        end = int(interval[2])

        # Upper-case the window once, not once per k-mer (1364 k-mers in this notebook).
        sequence = self.dna_source[chrom][begin:end].upper()

        # (len(groups), num_nodes) indicator rows -> transpose to (num_nodes, len(groups)).
        dna_OHE = np.array([encode_sequence_as_features_ndarray(sequence, group)
                            for group in self.groups]).T

        # Extra feature tracks are not wired in yet: feature_matr is always empty, so
        # `features` / `features_source` never reach X. Kept as the hook for adding them.
        feature_matr = []
        if len(feature_matr) > 0:
            X = np.hstack((dna_OHE, np.array(feature_matr).T / 1000)).astype(np.float32)
        else:
            X = dna_OHE.astype(np.float32)
        X = torch.tensor(X, dtype=torch.float)

        # Build the graph from the actual window length rather than the notebook-global `width`.
        edge_index = self._chain_edge_index(len(sequence))

        y = torch.tensor(self.labels[chrom][begin:end], dtype=torch.int64)

        return Data(x=X.unsqueeze(0), edge_index=edge_index, y=y.unsqueeze(0))

In [9]:
# Tile each chromosome into consecutive, non-overlapping `width`-bp windows. Windows
# overlapping any Z-DNA annotation go to ints_in (positives), the rest to ints_out
# (negatives). Negatives are then subsampled without replacement to 2x the positives.
np.random.seed(10)
width = 100

ints_in = []
ints_out = []

for chrm in chrom_names:
    chrom_len = ZDNA[chrm].shape  # SparseVector length, used as an int here
    # `+ 1` keeps the final window when chrom_len is an exact multiple of width
    # (a bound of `chrom_len - width` dropped it). A trailing partial window
    # shorter than `width` is still skipped.
    for st in trange(0, chrom_len - width + 1, width):
        interval = [st, st + width]
        if ZDNA[chrm][interval[0]: interval[1]].any():
            ints_in.append([chrm, interval[0], interval[1]])
        else:
            ints_out.append([chrm, interval[0], interval[1]])

# np.array over mixed [str, int, int] rows yields a string array; coordinates are re-cast to int later.
ints_in = np.array(ints_in)
ints_out = np.array(ints_out)[np.random.choice(range(len(ints_out)), size=len(ints_in) * 2, replace=False)]

100%|██████████████████████████████████████████████████████████████████████| 2489564/2489564 [00:34<00:00, 71973.87it/s]
100%|██████████████████████████████████████████████████████████████████████| 2421935/2421935 [00:32<00:00, 73930.02it/s]
100%|██████████████████████████████████████████████████████████████████████| 1982955/1982955 [00:26<00:00, 74496.02it/s]
100%|██████████████████████████████████████████████████████████████████████| 1902145/1902145 [00:25<00:00, 73785.41it/s]
100%|██████████████████████████████████████████████████████████████████████| 1815382/1815382 [00:24<00:00, 72735.30it/s]
100%|██████████████████████████████████████████████████████████████████████| 1708059/1708059 [00:22<00:00, 76424.60it/s]
100%|██████████████████████████████████████████████████████████████████████| 1593459/1593459 [00:22<00:00, 71818.01it/s]
100%|██████████████████████████████████████████████████████████████████████| 1451386/1451386 [00:19<00:00, 76112.07it/s]
100%|███████████████████████████

In [10]:
from sklearn.model_selection import StratifiedShuffleSplit
np.random.seed(42)
# Positives stacked over negatives. Rows of the stacked string array get their coordinates re-cast to int.
equalized = [[row[0], int(row[1]), int(row[2])] for row in np.vstack((ints_in, ints_out))]

labels = np.array([1] * len(ints_in) + [0] * len(ints_out))
chromes = [row[0] for row in equalized]
# Stratify jointly on class and chromosome so both splits keep each chromosome's class mix.
strat_labels = np.array([f"{label}_{chrom}" for label, chrom in zip(labels, chromes)])
print(strat_labels)

sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
train_inds, test_inds = next(sss.split(equalized, strat_labels))

train_intervals = [equalized[i] for i in train_inds]
test_intervals = [equalized[i] for i in test_inds]

['1_chr1' '1_chr1' '1_chr1' ... '0_chr18' '0_chr8' '0_chr6']


In [11]:
np.random.seed(42)

# Node features: one indicator column per k-mer of length 1..max_k_mer_size.
max_k_mer_size = 5
groups = generate_subgroups(max_k_mer_size)
feature_count = len(groups)

train_dataset = GraphDataset(chrom_names, feature_names, DNA, DNA_features, ZDNA,
                             intervals=train_intervals, k_mer=max_k_mer_size, groups=groups)

test_dataset = GraphDataset(chrom_names, feature_names, DNA, DNA_features, ZDNA,
                            intervals=test_intervals, k_mer=max_k_mer_size, groups=groups)

# NOTE: DataLoader shuffling draws from torch's RNG, so this numpy seed does not control it.
np.random.seed(42)
params = dict(batch_size=32, num_workers=4, shuffle=True)

loader_train = DataLoader(train_dataset, **params)
loader_test = DataLoader(test_dataset, **params)

In [12]:
def set_random_seed(seed):
    """Seed torch's CPU RNG and every CUDA device's RNG, and make cuDNN deterministic.

    numpy and Python's ``random`` module are not seeded here.
    """
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

In [13]:
# Seed torch (CPU + CUDA) and force deterministic cuDNN before the model is built below.
set_random_seed(42)

# GNN Model

In [14]:
from torch import nn
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score, f1_score, average_precision_score, recall_score
from IPython.display import clear_output

class GraphZSAGEConv_13L(torch.nn.Module):
    """13-layer GraphSAGE node classifier.

    Twelve blocks of SAGEConv -> GroupNorm (2 channels per group) -> ReLU, then a final
    SAGEConv to 2 outputs followed by log-softmax over the last axis.

    Sub-modules are registered as ``conv1``..``conv13`` and then ``norm1``..``norm12``, so
    state_dict keys and their order match existing checkpoints.

    Args:
        feature_count: number of input features per node.
    """

    def __init__(self, feature_count):
        super().__init__()
        # Channel sizes along the stack; conv_i maps sizes[i-1] -> sizes[i].
        sizes = (feature_count, 1024, 1024, 512, 512, 256, 256, 128, 128, 64, 64, 32, 32, 2)
        for layer_no in range(1, len(sizes)):
            setattr(self, f'conv{layer_no}', SAGEConv(sizes[layer_no - 1], sizes[layer_no]))
        # Every conv except the last is followed by a GroupNorm with groups of two channels.
        for layer_no in range(1, len(sizes) - 1):
            channels = sizes[layer_no]
            setattr(self, f'norm{layer_no}',
                    torch.nn.GroupNorm(num_groups=channels // 2, num_channels=channels))

    def forward(self, x, edge):
        # Accept an unbatched (nodes, features) tensor by adding a leading batch axis.
        if x.dim() == 2:
            x = x.unsqueeze(0)

        for layer_no in range(1, 13):
            x = getattr(self, f'conv{layer_no}')(x, edge)
            # GroupNorm normalises axis 1, so move channels there and back.
            x = getattr(self, f'norm{layer_no}')(x.permute(0, 2, 1)).permute(0, 2, 1)
            x = F.relu(x)

        x = self.conv13(x, edge)
        return F.log_softmax(x, dim=-1)

In [15]:
model = GraphZSAGEConv_13L(feature_count)
# map_location='cpu' loads the checkpoint regardless of which GPU it was saved from. The
# parameters end up on CPU either way; the attribution cells move the model to CUDA explicitly.
# NOTE(review): despite "11L" in its name, this checkpoint loads strictly into the 13-conv model.
model.load_state_dict(torch.load("GraphZSAGEConv_11L_1_5_k_mers.pt", map_location='cpu'))
model.eval()

GraphZSAGEConv_13L(
  (conv1): SAGEConv(1364, 1024, aggr=mean)
  (conv2): SAGEConv(1024, 1024, aggr=mean)
  (conv3): SAGEConv(1024, 512, aggr=mean)
  (conv4): SAGEConv(512, 512, aggr=mean)
  (conv5): SAGEConv(512, 256, aggr=mean)
  (conv6): SAGEConv(256, 256, aggr=mean)
  (conv7): SAGEConv(256, 128, aggr=mean)
  (conv8): SAGEConv(128, 128, aggr=mean)
  (conv9): SAGEConv(128, 64, aggr=mean)
  (conv10): SAGEConv(64, 64, aggr=mean)
  (conv11): SAGEConv(64, 32, aggr=mean)
  (conv12): SAGEConv(32, 32, aggr=mean)
  (conv13): SAGEConv(32, 2, aggr=mean)
  (norm1): GroupNorm(512, 1024, eps=1e-05, affine=True)
  (norm2): GroupNorm(512, 1024, eps=1e-05, affine=True)
  (norm3): GroupNorm(256, 512, eps=1e-05, affine=True)
  (norm4): GroupNorm(256, 512, eps=1e-05, affine=True)
  (norm5): GroupNorm(128, 256, eps=1e-05, affine=True)
  (norm6): GroupNorm(128, 256, eps=1e-05, affine=True)
  (norm7): GroupNorm(64, 128, eps=1e-05, affine=True)
  (norm8): GroupNorm(64, 128, eps=1e-05, affine=True)
  (norm9

# Captum methods: IntegratedGradients

In [16]:
!pip install captum


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.1.2[0m[39;49m -> [0m[32;49m25.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3 -m pip install --upgrade pip[0m


In [65]:
import captum
from captum.attr import IntegratedGradients, GradientShap, LayerGradCam, LRP
from captum.attr import visualization as viz
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from torch_geometric.explain import Explainer, CaptumExplainer

In [27]:
# Integrated Gradients (Captum) through PyG's Explainer. node_mask_type='attributes' gives one
# attribution per (node, feature); the loop below indexes node_mask as [node, feature].
# An edge mask is also requested ('object'), but only node_mask is used afterwards.
# NOTE(review): `wrappedmodel` is only created further down (cell In[18]). On a fresh
# top-to-bottom run this cell raises NameError; move the WrappedModel cells above it.
# NOTE(review): WrappedModel returns a (batch, nodes) tensor of class-1 *log*-probabilities, while the
# config declares node-level multiclass 'probs' of shape (nodes, classes). Confirm that the target
# PyG derives from this output is the intended one.
explainer = Explainer(
    model=wrappedmodel,
    algorithm=CaptumExplainer('IntegratedGradients'),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='node',
        return_type='probs',
    ),
)

# batch_size=1 so each explainer call sees exactly one window graph.
# loader_train is re-created but not used by the attribution loop.
np.random.seed(42)
params = {'batch_size':1,
          'num_workers':4,
          'shuffle':True}

loader_train = DataLoader(train_dataset, **params)
loader_test = DataLoader(test_dataset, **params)

In [28]:
# Number of node features = number of k-mer indicator columns (same value as `feature_count`).
features_count = len(groups)
features_count

1364

In [29]:
%%time

mean_1 = np.zeros(features_count, dtype=float)
cnt= 0
cc = 0

device = torch.device(f'cuda:{1}')
with torch.cuda.device(device):
    model = model.to('cuda:1')
    for dt in tqdm(loader_test):
        
        x, edge, y = dt.x.to('cuda:1'), dt.edge_index.to('cuda:1'), dt.y.to('cuda:1').long()
        
        output = model(x, edge)
        pred = torch.argmax(output, dim=-1)
        
        idxs = []
        for i in range(width):
            if pred[0][i] == y[0][i] and y[0][i] == 1:
                idxs.append(i)
        
        torch.cuda.empty_cache()
        explanation = explainer(x.squeeze(), edge, target=1)
        node_mask = explanation.node_mask

        if node_mask[idxs, :].shape != (0, features_count):
            node_mask = torch.mean(node_mask[idxs, :], dim=0)
            node_mask = np.array(node_mask.cpu())
            mean_1 += node_mask
            cnt += 1

print('done interpretation')

100%|██████████████████████████████████████████████████████████████████████████| 27121/27121 [10:46:28<00:00,  1.43s/it]

done interpretation
CPU times: user 10h 43min 49s, sys: 1min 45s, total: 10h 45min 34s
Wall time: 10h 46min 28s





In [30]:
# Average of the per-window attribution means (each contributing window weighted equally).
assert cnt > 0, 'no test window had a true-positive node; nothing to average'
mean = mean_1 / cnt
print(mean.shape)
print(mean)

(1364,)
[ 1.32936813e-03  7.77668408e-04  2.92216314e-03 ...  1.21453108e-04
 -1.73796021e-03  3.75329942e-05]


In [31]:
# Display the averaged Integrated-Gradients attribution vector.
mean

array([ 1.32936813e-03,  7.77668408e-04,  2.92216314e-03, ...,
        1.21453108e-04, -1.73796021e-03,  3.75329942e-05])

In [32]:
# Persist the averaged IG attributions as a float64 tensor of length features_count.
torch.save(torch.from_numpy(mean), 'GraphZSAGEConv_13L_1_5_k_mers_IG.pt')

# Captum methods: Saliency

In [17]:
class WrappedModel(torch.nn.Module):
    """Adapter that exposes only class index 1 of `base_model`'s output.

    `base_model(x, edge_index)` is expected to return a tensor whose third axis indexes
    classes (for GraphZSAGEConv_13L: (batch, nodes, 2) log-probabilities). `forward`
    returns the slice at index 1 of that axis.
    """

    def __init__(self, base_model):
        super().__init__()
        self.base_model = base_model

    def forward(self, x, edge_index):
        scores = self.base_model(x, edge_index)
        return scores.select(2, 1)

In [18]:
# Wrapper shared by all explainers below; it holds `model` itself, so moving `model` to a device moves it too.
wrappedmodel = WrappedModel(model)

In [19]:
from torch_geometric.explain import Explainer, CaptumExplainer

# Captum Saliency (input-gradient magnitudes) through PyG's Explainer. node_mask_type='attributes'
# gives one attribution per (node, feature); only node_mask is used afterwards.
# NOTE(review): the wrapped model returns (batch, nodes) class-1 log-probabilities, while the config
# declares node-level multiclass 'probs'. Confirm that the derived target is the intended one.
explainer = Explainer(
    model=wrappedmodel,
    algorithm=CaptumExplainer('Saliency'),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='node',
        return_type='probs',
    ),
)

In [20]:
# batch_size=1 so each explainer call sees exactly one window graph.
# loader_train is re-created but not used by the attribution loop.
np.random.seed(42)
params = {'batch_size':1,
          'num_workers':4,
          'shuffle':True}

loader_train = DataLoader(train_dataset, **params)
loader_test = DataLoader(test_dataset, **params)

In [22]:
# Number of node features = number of k-mer indicator columns (same value as `feature_count`).
features_count = len(groups)
features_count

1364

In [23]:
%%time

mean_1 = np.zeros(features_count, dtype=float)
cnt= 0
cc = 0

device = torch.device(f'cuda:{1}')
with torch.cuda.device(device):
    model = model.to('cuda:1')
    for dt in tqdm(loader_test):
        
        x, edge, y = dt.x.to('cuda:1'), dt.edge_index.to('cuda:1'), dt.y.to('cuda:1').long()
        
        output = model(x, edge)
        pred = torch.argmax(output, dim=-1)
        
        idxs = []
        for i in range(width):
            if pred[0][i] == y[0][i] and y[0][i] == 1:
                idxs.append(i)
        
        torch.cuda.empty_cache()
        explanation = explainer(x.squeeze(), edge, target=1)
        node_mask = explanation.node_mask

        if node_mask[idxs, :].shape != (0, features_count):
            node_mask = torch.mean(node_mask[idxs, :], dim=0)
            node_mask = np.array(node_mask.cpu())
            mean_1 += node_mask
            cnt += 1

print('done interpretation')

100%|█████████████████████████████████████████████████████████████████████████████| 27121/27121 [28:05<00:00, 16.09it/s]

done interpretation
CPU times: user 26min 48s, sys: 49.6 s, total: 27min 37s
Wall time: 28min 5s





In [24]:
# Average of the per-window attribution means (each contributing window weighted equally).
assert cnt > 0, 'no test window had a true-positive node; nothing to average'
mean = mean_1 / cnt
print(mean.shape)

(1364,)


In [25]:
# Display the averaged Saliency attribution vector.
mean

array([0.00147995, 0.00129851, 0.00237362, ..., 0.00142435, 0.00493159,
       0.00426171])

In [26]:
# Persist the averaged Saliency attributions as a float64 tensor of length features_count.
torch.save(torch.from_numpy(mean), 'GraphZSAGEConv_13L_1_5_k_mers_Saliency.pt')

In [138]:
# Reload the saved Saliency vector as a check.
# NOTE(review): the values shown differ from the `mean` displayed in cell In[25]. This cell ran much
# later (In[138]), so the file was presumably overwritten by another run; re-check before relying on it.
torch.load('GraphZSAGEConv_13L_1_5_k_mers_Saliency.pt')

tensor([1.7642e-02, 1.9714e-02, 2.2653e-02,  ..., 2.6891e-09, 5.5811e-05,
        1.6170e-03], dtype=torch.float64)

# Captum methods: InputXGradient (IxG)

In [52]:
from torch_geometric.explain import Explainer, CaptumExplainer

# Captum InputXGradient through PyG's Explainer. node_mask_type='attributes' gives one attribution
# per (node, feature); only node_mask is used afterwards.
# NOTE(review): the wrapped model returns (batch, nodes) class-1 log-probabilities, while the config
# declares node-level multiclass 'probs'. Confirm that the derived target is the intended one.
explainer = Explainer(
    model=wrappedmodel,
    algorithm=CaptumExplainer('InputXGradient'),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='node',
        return_type='probs',
    ),
)

In [53]:
# batch_size=1 so each explainer call sees exactly one window graph.
# loader_train is re-created but not used by the attribution loop.
np.random.seed(42)
params = {'batch_size':1,
          'num_workers':4,
          'shuffle':True}

loader_train = DataLoader(train_dataset, **params)
loader_test = DataLoader(test_dataset, **params)

In [54]:
%%time

mean_1 = np.zeros(features_count, dtype=float)
cnt= 0
cc = 0

device = torch.device(f'cuda:{1}')
with torch.cuda.device(device):
    model = model.to('cuda:1')
   
    for dt in tqdm(loader_test):
        
        x, edge, y = dt.x.to('cuda:1'), dt.edge_index.to('cuda:1'), dt.y.to('cuda:1').long()
     
        output = model(x, edge)
        pred = torch.argmax(output, dim=-1)
      
        idxs = []
        for i in range(width):
            if pred[0][i] == y[0][i] and y[0][i] == 1:
                idxs.append(i)
      
        torch.cuda.empty_cache()
        explanation = explainer(x.squeeze(), edge)
        node_mask = explanation.node_mask
        
        if node_mask[idxs, :].shape != (0, features_count):
            node_mask = torch.mean(node_mask[idxs, :], dim=0)
            node_mask = np.array(node_mask.detach().cpu())
            mean_1 += node_mask
            cnt += 1

print('done interpretation')

100%|█████████████████████████████████████████████████████████████████████████████| 27121/27121 [40:28<00:00, 11.17it/s]

done interpretation
CPU times: user 35min 32s, sys: 4min 31s, total: 40min 3s
Wall time: 40min 28s





In [55]:
# Average of the per-window attribution means (each contributing window weighted equally).
assert cnt > 0, 'no test window had a true-positive node; nothing to average'
mean = mean_1 / cnt
print(mean.shape)

(1364,)


In [56]:
# Display the averaged InputXGradient attribution vector.
mean

array([-1.44478283e-05, -4.28186405e-05, -2.65626381e-04, ...,
       -3.17674658e-06, -9.58518565e-05, -2.63277797e-05])

In [57]:
# Persist the averaged InputXGradient attributions as a float64 tensor of length features_count.
torch.save(torch.from_numpy(mean), 'GraphZSAGEConv_13L_1_5_k_mers_IxG.pt')

# PyG GNNExplainer

In [58]:
import captum
from captum.attr import IntegratedGradients, GradientShap, LayerGradCam, LRP
from captum.attr import visualization as viz
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

In [59]:
from torch_geometric.explain import Explainer, GNNExplainer

# PyG GNNExplainer (50 optimisation epochs per call). node_mask_type='attributes' gives one mask
# value per (node, feature); only node_mask is used afterwards.
# NOTE(review): the wrapped model returns (batch, nodes) class-1 log-probabilities, while the config
# declares node-level multiclass 'probs'. Confirm that the derived target is the intended one.
explainer = Explainer(
    model=wrappedmodel,
    algorithm=GNNExplainer(epochs=50),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='node',
        return_type='probs',
    ),
)
# batch_size=1 so each explainer call sees exactly one window graph (the next cell repeats this setup).
np.random.seed(42)
params = {'batch_size':1,
          'num_workers':4,
          'shuffle':True}

loader_train = DataLoader(train_dataset, **params)
loader_test = DataLoader(test_dataset, **params)

In [60]:
# Same loader setup as the end of the previous cell (duplicate); loader_train is not used below.
np.random.seed(42)
params = {'batch_size':1,
          'num_workers':4,
          'shuffle':True}

loader_train = DataLoader(train_dataset, **params)
loader_test = DataLoader(test_dataset, **params)

In [61]:
%%time

mean_1 = np.zeros(features_count, dtype=float)
cnt= 0
cc = 0

device = torch.device(f'cuda:{1}')
with torch.cuda.device(device):
    model = model.to('cuda:1')
   
    for dt in tqdm(loader_test):
        
        x, edge, y = dt.x.to('cuda:1'), dt.edge_index.to('cuda:1'), dt.y.to('cuda:1').long()
     
        output = model(x, edge)
        pred = torch.argmax(output, dim=-1)
      
        idxs = []
        for i in range(width):
            if pred[0][i] == y[0][i] and y[0][i] == 1:
                idxs.append(i)
      
        torch.cuda.empty_cache()
        explanation = explainer(x.squeeze(), edge)
        node_mask = explanation.node_mask
        
        if node_mask[idxs, :].shape != (0, features_count):
            node_mask = torch.mean(node_mask[idxs, :], dim=0)
            node_mask = np.array(node_mask.detach().cpu())
            mean_1 += node_mask
            cnt += 1

print('done interpretation')

100%|██████████████████████████████████████████████████████████████████████████| 27121/27121 [12:04:24<00:00,  1.60s/it]

done interpretation
CPU times: user 11h 45min 40s, sys: 18min 31s, total: 12h 4min 12s
Wall time: 12h 4min 24s





In [62]:
# Average of the per-window mask means (each contributing window weighted equally).
assert cnt > 0, 'no test window had a true-positive node; nothing to average'
mean = mean_1 / cnt
print(mean.shape)


(1364,)


In [63]:
# Display the averaged GNNExplainer feature-mask vector.
mean

array([0.05584951, 0.05864221, 0.18252613, ..., 0.00302195, 0.01032195,
       0.00502422])

In [64]:
# Persist the averaged GNNExplainer feature mask as a float64 tensor of length features_count.
torch.save(torch.from_numpy(mean), 'GraphZSAGEConv_13L_1_5_k_mers_GNN_explainer.pt')