In [1]:
# Data loading
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.metrics import roc_auc_score, f1_score
from IPython.display import clear_output

from joblib import load
from tqdm import trange
from tqdm.notebook import tqdm

# Graph dataset
from torch_geometric.loader import DataLoader
from torch_geometric.data import Dataset, Data

from sklearn.preprocessing import LabelBinarizer
from sklearn.model_selection import StratifiedKFold


# GNN Model
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GraphConv, GATConv, GATv2Conv, SAGEConv


# Sparse vector
from Sparse_vector.sparse_vector import SparseVector

In [2]:
chrom_names = [f'chr{i}' for i in list(range(1, 23)) + ['X', 'Y','M']]

features = [i[:-4] for i in os.listdir('z_dna/hg38_features/sparse/') if i.endswith('.pkl')]
groups = ['DNase-seq', 'Histone', 'RNA polymerase', 'TFs and others']
feature_names = [i for i in features]

In [3]:
def chrom_reader(chrom):
    files = sorted([i for i in os.listdir(f'z_dna/hg38_dna/') if f"{chrom}_" in i])
    return ''.join([load(f"z_dna/hg38_dna/{file}") for file in files])

In [4]:
%%time
DNA = {chrom:chrom_reader(chrom) for chrom in tqdm(chrom_names)}
#ZDNA = load('z_dna/hg38_zdna/sparse/ZDNA_shin.pkl')
#ZDNA = load('z_dna/hg38_zdna/sparse/ZDNA_cousine.pkl')

ZDNA = load('z_dna/hg38_zdna/sparse/ZDNA_cousine.pkl')

DNA_features = {feature: load(f'z_dna/hg38_features/sparse/{feature}.pkl')
                for feature in tqdm(feature_names)}

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/1946 [00:00<?, ?it/s]

CPU times: user 2min 12s, sys: 4.78 s, total: 2min 17s
Wall time: 2min 16s


In [None]:
from tqdm import tqdm
import sys

tqdm.pandas()

In [13]:
width = 100

In [14]:
import numpy as np
from itertools import product

# Функция для генерации подгрупп (subgroups)
def generate_subgroups(n):
    nucleotides = ['A', 'T', 'G', 'C']
    subgroups = []
    for i in range(1, n + 1):  # Генерация комбинаций длиной от 1 до n
        subgroups.extend([''.join(p) for p in product(nucleotides, repeat=i)])
    return subgroups

# Функция для кодирования последовательности
def encode_sequence_as_features_ndarray(n_str: str, k_str: str):
    n = len(n_str)
    k = len(k_str)
    result = np.zeros(n, dtype=int)
    
    for i in range(n - k + 1):
        if n_str[i:i+k] == k_str:
            result[i:i+k] = 1
    
    return result.tolist()

In [15]:
class GraphDataset(Dataset):
    def __init__(self, chroms, features,
                 dna_source, features_source,
                 labels, intervals,
                 transform=None, pre_transform=None, pre_filter=None):
        self.chroms = chroms
        self.features = features
        self.dna_source = dna_source
        self.features_source = features_source
        self.labels = labels
        self.intervals = intervals
        self.groups = generate_subgroups(4)
        self.k_mer = 4
        self.le = LabelBinarizer().fit(np.array([["A"], ["C"], ["T"], ["G"]]))

        self.ei = [[],[]]
        for i in range(width-1):
            self.ei[0].append(i)
            self.ei[0].append(i+1)
            self.ei[1].append(i+1)
            self.ei[1].append(i)
        super().__init__(transform, pre_transform, pre_filter)

    def len(self):
        return len(self.intervals)

    def get(self, idx):
        interval = self.intervals[idx]
        chrom = interval[0]
        begin = int(interval[1])
        end = int(interval[2])
        
        dna_OHE = []
        
        for group in self.groups:
            #print(group)
            featuress = encode_sequence_as_features_ndarray(self.dna_source[chrom][begin:end].upper(), group)
            
            dna_OHE.append(featuress)
        
        dna_OHE = list(map(list, zip(*dna_OHE)))
        dna_OHE = np.array(dna_OHE)

        feature_matr = []
        #for feature in self.features:
        #    source = self.features_source[feature]
        #    feature_matr.append(source[chrom][begin:end])

        if len(feature_matr) > 0:
            X = np.hstack((dna_OHE, np.array(feature_matr).T/1000)).astype(np.float32)
        else:
            X = dna_OHE.astype(np.float32)
        X = torch.tensor(X, dtype=torch.float)

        edge_index = torch.tensor(np.array(self.ei), dtype=torch.long)

        y = self.labels[interval[0]][interval[1]: interval[2]]
        y = torch.tensor(y, dtype=torch.int64)

        return Data(x=X.unsqueeze(0), edge_index=edge_index, y=y.unsqueeze(0))

In [16]:
np.random.seed(10)
width = 100

ints_in = []
ints_out = []

for chrm in chrom_names:
    for st in trange(0, ZDNA[chrm].shape - width, width):
        interval = [st, min(st + width, ZDNA[chrm].shape)]
        if ZDNA[chrm][interval[0]: interval[1]].any():
            ints_in.append([chrm, interval[0], interval[1]])
        else:
            ints_out.append([chrm, interval[0], interval[1]])

ints_in = np.array(ints_in)
ints_out = np.array(ints_out)[np.random.choice(range(len(ints_out)), size=len(ints_in) * 2, replace=False)]

100%|██████████████████████████████████████████████████████████████████████| 2489564/2489564 [00:42<00:00, 59128.79it/s]
100%|██████████████████████████████████████████████████████████████████████| 2421935/2421935 [00:38<00:00, 62800.02it/s]
100%|██████████████████████████████████████████████████████████████████████| 1982955/1982955 [00:30<00:00, 64041.42it/s]
100%|██████████████████████████████████████████████████████████████████████| 1902145/1902145 [00:30<00:00, 62759.60it/s]
100%|██████████████████████████████████████████████████████████████████████| 1815382/1815382 [00:29<00:00, 62437.20it/s]
100%|██████████████████████████████████████████████████████████████████████| 1708059/1708059 [00:27<00:00, 61474.06it/s]
100%|██████████████████████████████████████████████████████████████████████| 1593459/1593459 [00:24<00:00, 65748.77it/s]
100%|██████████████████████████████████████████████████████████████████████| 1451386/1451386 [00:24<00:00, 59584.86it/s]
100%|███████████████████████████

In [19]:
np.random.seed(42)
equalized = np.vstack((ints_in, ints_out))
equalized = [[inter[0], int(inter[1]), int(inter[2])] for inter in equalized]

train_inds, test_inds = next(StratifiedKFold().split(equalized, [f"{int(i < 400)}_{elem[0]}"
                                                                 for i, elem
                                                                 in enumerate(equalized)]))

train_intervals, test_intervals = [equalized[i] for i in train_inds], [equalized[i] for i in test_inds]

In [148]:
def filter_gc_cg(sequences):
    # Создаём пустой список для результатов
    filtered_sequences = []
    
    # Перебираем каждую последовательность в массиве
    for seq in sequences:
        # Проверяем, содержит ли последовательность "GC" или "CG"
        if "GC" in seq or "CG" in seq:
            # Если да, добавляем её в результат
            filtered_sequences.append(seq)
    
    # Возвращаем отфильтрованный список
    return filtered_sequences

In [149]:
max_length = 4
groups = generate_subgroups(max_length)
features_count = len(groups)

np.random.seed(42)
features_count

340

In [20]:
np.random.seed(42)
train_dataset = GraphDataset(chrom_names, feature_names,
                            DNA, DNA_features,
                            ZDNA, train_intervals)

test_dataset = GraphDataset(chrom_names, feature_names,
                           DNA, DNA_features,
                           ZDNA, test_intervals)

In [21]:
np.random.seed(42)
params = {'batch_size':32,
          'num_workers':4,
          'shuffle':True}

loader_train = DataLoader(train_dataset, **params)
loader_test = DataLoader(test_dataset, **params)

# GNN Model

In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv

class GraphZSAGEConv_13L(torch.nn.Module):
    def __init__(self, top_count):
        super(GraphZSAGEConv_13L, self).__init__()
        
        self.conv1 = SAGEConv(top_count, 1024)
        self.conv2 = SAGEConv(1024, 1024)
        self.conv3 = SAGEConv(1024, 512)
        self.conv4 = SAGEConv(512, 512)
        self.conv5 = SAGEConv(512, 256)
        
        self.conv6 = SAGEConv(256, 256)
        self.conv7 = SAGEConv(256, 128)
        self.conv8 = SAGEConv(128, 128)
        self.conv9 = SAGEConv(128, 64)
        self.conv10 = SAGEConv(64, 64)
        
        self.conv11 = SAGEConv(64, 32)
        self.conv12 = SAGEConv(32, 32)
        self.conv13 = SAGEConv(32, 2)
        
        self.norm1 = torch.nn.GroupNorm(num_groups=512, num_channels=1024)
        self.norm2 = torch.nn.GroupNorm(num_groups=512, num_channels=1024)
        self.norm3 = torch.nn.GroupNorm(num_groups=256, num_channels=512)
        self.norm4 = torch.nn.GroupNorm(num_groups=256, num_channels=512)
        self.norm5 = torch.nn.GroupNorm(num_groups=128, num_channels=256)
        self.norm6 = torch.nn.GroupNorm(num_groups=128, num_channels=256)
        self.norm7 = torch.nn.GroupNorm(num_groups=64, num_channels=128)
        self.norm8 = torch.nn.GroupNorm(num_groups=64, num_channels=128)
        self.norm9 = torch.nn.GroupNorm(num_groups=32, num_channels=64)
        self.norm10 = torch.nn.GroupNorm(num_groups=32, num_channels=64)
        self.norm11 = torch.nn.GroupNorm(num_groups=16, num_channels=32)
        self.norm12 = torch.nn.GroupNorm(num_groups=16, num_channels=32)
        

    def forward(self, x):
        x = self.conv1(x, edge.to('cuda:1'))
        x = x.permute(0, 2, 1)
        x = self.norm1(x)
        x = x.permute(0, 2, 1)
        x = F.relu(x)
        #x = F.dropout(x, training=self.training)
        
        x = self.conv2(x, edge.to('cuda:1'))
        x = x.permute(0, 2, 1)
        x = self.norm2(x)
        x = x.permute(0, 2, 1)
        x = F.relu(x)
        #x = F.dropout(x, training=self.training)
        
        
        x = self.conv3(x, edge.to('cuda:1'))
        x = x.permute(0, 2, 1)
        x = self.norm3(x)
        x = x.permute(0, 2, 1)
        x = F.relu(x)
        #x = F.dropout(x, training=self.training)
        
        x = self.conv4(x, edge.to('cuda:1'))
        x = x.permute(0, 2, 1)
        x = self.norm4(x)
        x = x.permute(0, 2, 1)
        x = F.relu(x)
        #x = F.dropout(x, training=self.training)
        
        x = self.conv5(x, edge.to('cuda:1'))
        x = x.permute(0, 2, 1)
        x = self.norm5(x)
        x = x.permute(0, 2, 1)
        x = F.relu(x)
        #x = F.dropout(x, training=self.training)
        
        x = self.conv6(x, edge.to('cuda:1'))
        x = x.permute(0, 2, 1)
        x = self.norm6(x)
        x = x.permute(0, 2, 1)
        x = F.relu(x)
        #x = F.dropout(x, training=self.training)
        
        x = self.conv7(x, edge.to('cuda:1'))
        x = x.permute(0, 2, 1)
        x = self.norm7(x)
        x = x.permute(0, 2, 1)
        x = F.relu(x)
        #x = F.dropout(x, training=self.training)
        
        x = self.conv8(x, edge.to('cuda:1'))
        x = x.permute(0, 2, 1)
        x = self.norm8(x)
        x = x.permute(0, 2, 1)
        x = F.relu(x)
        #x = F.dropout(x, training=self.training)
        
        x = self.conv9(x, edge.to('cuda:1'))
        x = x.permute(0, 2, 1)
        x = self.norm9(x)
        x = x.permute(0, 2, 1)
        x = F.relu(x)
        #x = F.dropout(x, training=self.training)
        
        x = self.conv10(x, edge.to('cuda:1'))
        x = x.permute(0, 2, 1)
        x = self.norm10(x)
        x = x.permute(0, 2, 1)
        x = F.relu(x)
        #x = F.dropout(x, training=self.training)
        
        x = self.conv11(x, edge.to('cuda:1'))
        x = x.permute(0, 2, 1)
        x = self.norm11(x)
        x = x.permute(0, 2, 1)
        x = F.relu(x)
        #x = F.dropout(x, training=self.training)
        
        x = self.conv12(x, edge.to('cuda:1'))
        x = x.permute(0, 2, 1)
        x = self.norm12(x)
        x = x.permute(0, 2, 1)
        x = F.relu(x)
        #x = F.dropout(x)
        
        x = self.conv13(x, edge.cuda())

        return F.log_softmax(x, dim=-1)

In [None]:
model = GraphZSAGEConv_13L()
model= torch.load("model_GC_5.pt")
model = model.cuda()
model.eval()

GraphZSAGEConv_v5_lin(
  (conv1): SAGEConv(390, 1800, aggr=mean)
  (conv2): SAGEConv(1800, 1650, aggr=mean)
  (conv3): SAGEConv(1650, 1500, aggr=mean)
  (conv4): SAGEConv(1500, 1350, aggr=mean)
  (conv5): SAGEConv(1350, 1200, aggr=mean)
  (conv6): SAGEConv(1200, 1050, aggr=mean)
  (conv7): SAGEConv(1050, 900, aggr=mean)
  (conv8): SAGEConv(900, 750, aggr=mean)
  (conv9): SAGEConv(750, 600, aggr=mean)
  (conv10): SAGEConv(600, 450, aggr=mean)
  (conv11): SAGEConv(450, 300, aggr=mean)
  (conv12): SAGEConv(300, 150, aggr=mean)
  (conv13): SAGEConv(150, 64, aggr=mean)
  (fc1): Linear(in_features=64, out_features=32, bias=True)
  (fc2): Linear(in_features=32, out_features=2, bias=True)
)

In [None]:
model = GraphZSAGEConv_13L()
model= torch.load("model1234_86_2.pt")
model = model.cuda()
model.eval()

GraphZSAGEConv_v5_lin(
  (conv1): SAGEConv(340, 1800, aggr=mean)
  (conv2): SAGEConv(1800, 1650, aggr=mean)
  (conv3): SAGEConv(1650, 1500, aggr=mean)
  (conv4): SAGEConv(1500, 1350, aggr=mean)
  (conv5): SAGEConv(1350, 1200, aggr=mean)
  (conv6): SAGEConv(1200, 1050, aggr=mean)
  (conv7): SAGEConv(1050, 900, aggr=mean)
  (conv8): SAGEConv(900, 750, aggr=mean)
  (conv9): SAGEConv(750, 600, aggr=mean)
  (conv10): SAGEConv(600, 450, aggr=mean)
  (conv11): SAGEConv(450, 300, aggr=mean)
  (conv12): SAGEConv(300, 150, aggr=mean)
  (conv13): SAGEConv(150, 64, aggr=mean)
  (fc1): Linear(in_features=64, out_features=32, bias=True)
  (fc2): Linear(in_features=32, out_features=2, bias=True)
)

# Captum methods: IntegratedGradients

In [17]:
!pip install captum


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.1.2[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3 -m pip install --upgrade pip[0m


In [18]:
import captum
from captum.attr import IntegratedGradients, GradientShap, LayerGradCam, LRP
from captum.attr import visualization as viz
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap
from torch_geometric.explain import Explainer, CaptumExplainer

In [17]:


explainer = Explainer(
    model=model,
    algorithm=CaptumExplainer('IntegratedGradients'),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='node',
        return_type='probs',
    ),
)

In [43]:
%%time
mean_1 = np.zeros(features_count, dtype=float)
cnt= 0
cc = 0

for dt in tqdm(loader_test):
    x, edge, y = dt.x.cuda(), dt.edge_index.cuda(), dt.y.cuda().long()
    valid_edges = (edge < width).all(dim=0)
    edge = edge[:, valid_edges]

    output = model(x, edge.squeeze())
    pred = torch.argmax(output, dim=-1)

    # find True Positive indices
    idxs = []
    for i in range(width):
        if pred[0][i] == y[0][i] and y[0][i] == 1:
            idxs.append(i)

    torch.cuda.empty_cache()
    explanation = explainer(x.squeeze(), edge)
    #explanation.visualize_feature_importance(top_k=10)
    
    node_mask = explanation.node_mask

    if node_mask[idxs, :].shape != (0, features_count):
        node_mask = torch.mean(node_mask[idxs, :], dim=0)
        node_mask = np.array(node_mask.cpu())
        mean_1 += node_mask
        cnt += 1


print('done interpretation')

  0%|          | 0/9041 [00:00<?, ?it/s]

done interpretation
CPU times: user 2h 53min 13s, sys: 35.3 s, total: 2h 53min 49s
Wall time: 2h 54min 5s


In [None]:
mean = mean_1 / cnt
print(mean.shape)
print(mean)

In [None]:
torch.save(torch.from_numpy(mean), 'mean_GraphZSAGEConv_13L_1234_IG.pt')

# salency

In [19]:
from torch_geometric.explain import Explainer, CaptumExplainer

explainer = Explainer(
    model=model,
    algorithm=CaptumExplainer('Saliency'),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='node',
        return_type='probs',
    ),
)

In [20]:
np.random.seed(42)
params = {'batch_size':1,
          'num_workers':4,
          'shuffle':True}

loader_train = DataLoader(train_dataset, **params)
loader_test = DataLoader(test_dataset, **params)

In [21]:
%%time
mean_1 = np.zeros(features_count, dtype=float)
cnt= 0
cc = 0

for dt in tqdm(loader_test):
    x, edge, y = dt.x.cuda(), dt.edge_index.cuda(), dt.y.cuda().long()
    valid_edges = (edge < width).all(dim=0)
    edge = edge[:, valid_edges]

    output = model(x, edge.squeeze())
    pred = torch.argmax(output, dim=-1)

    # find True Positive indices
    idxs = []
    for i in range(width):
        if pred[0][i] == y[0][i] and y[0][i] == 1:
            idxs.append(i)

    torch.cuda.empty_cache()
    explanation = explainer(x.squeeze(), edge)
    #explanation.visualize_feature_importance(top_k=10)
    
    node_mask = explanation.node_mask

    if node_mask[idxs, :].shape != (0, features_count):
        node_mask = torch.mean(node_mask[idxs, :], dim=0)
        node_mask = np.array(node_mask.cpu())
        mean_1 += node_mask
        cnt += 1


print('done interpretation')

  0%|          | 0/9041 [00:00<?, ?it/s]



done interpretation
CPU times: user 8min 31s, sys: 30.9 s, total: 9min 2s
Wall time: 9min 10s


In [22]:
mean = mean_1 / cnt
print(mean.shape)

(390,)


In [None]:
torch.save(torch.from_numpy(mean), 'mean_GraphZSAGEConv_13L_Saliency_TP.pt')

# IxG

In [24]:
from torch_geometric.explain import Explainer, CaptumExplainer

explainer = Explainer(
    model=model,
    algorithm=CaptumExplainer('InputXGradient'),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='node',
        return_type='probs',
    ),
)

In [25]:
np.random.seed(42)
params = {'batch_size':1,
          'num_workers':4,
          'shuffle':True}

loader_train = DataLoader(train_dataset, **params)
loader_test = DataLoader(test_dataset, **params)

In [26]:
%%time
mean_1 = np.zeros(features_count, dtype=float)
cnt= 0
cc = 0

for dt in tqdm(loader_test):
    x, edge, y = dt.x.cuda(), dt.edge_index.cuda(), dt.y.cuda().long()
    valid_edges = (edge < width).all(dim=0)
    edge = edge[:, valid_edges]

    output = model(x, edge.squeeze())
    pred = torch.argmax(output, dim=-1)

    # find True Positive indices
    idxs = []
    for i in range(width):
        if pred[0][i] == y[0][i] and y[0][i] == 1:
            idxs.append(i)

    #torch.cuda.empty_cache()
    explanation = explainer(x.squeeze(), edge)
    #explanation.visualize_feature_importance(top_k=10)
    
    node_mask = explanation.node_mask

    if node_mask[idxs, :].shape != (0, features_count):
        node_mask = torch.mean(node_mask[idxs, :], dim=0)
        node_mask = node_mask.cpu().detach().numpy()
        mean_1 += node_mask
        cnt += 1


print('done interpretation')

  0%|          | 0/9041 [00:00<?, ?it/s]

done interpretation
CPU times: user 8min 30s, sys: 26.9 s, total: 8min 57s
Wall time: 9min 4s


In [27]:
mean = mean_1 / cnt
print(mean.shape)

(390,)


In [None]:
torch.save(torch.from_numpy(mean), 'mean_GraphZSAGEConv_13L_IxG_TP.pt')

# Deconvolution

In [29]:
from torch_geometric.explain import Explainer, CaptumExplainer

explainer = Explainer(
    model=model,
    algorithm=CaptumExplainer('Deconvolution'),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='node',
        return_type='probs',
    ),
)

In [30]:
np.random.seed(42)
params = {'batch_size':1,
          'num_workers':4,
          'shuffle':True}

loader_train = DataLoader(train_dataset, **params)
loader_test = DataLoader(test_dataset, **params)

In [31]:
%%time
mean_1 = np.zeros(features_count, dtype=float)
cnt= 0

for dt in tqdm(loader_test):
    x, edge, y = dt.x.cuda(), dt.edge_index.cuda(), dt.y.cuda().long()
    valid_edges = (edge < width).all(dim=0)
    edge = edge[:, valid_edges]

    output = model(x, edge.squeeze())
    pred = torch.argmax(output, dim=-1)

    # find True Positive indices
    idxs = []
    for i in range(width):
        if pred[0][i] == y[0][i] and y[0][i] == 1:
            idxs.append(i)

    #torch.cuda.empty_cache()
    explanation = explainer(x.squeeze(), edge)
    #explanation.visualize_feature_importance(top_k=10)
    
    node_mask = explanation.node_mask

    if node_mask[idxs, :].shape != (0, features_count):
        node_mask = torch.mean(node_mask[idxs, :], dim=0)
        node_mask = node_mask.cpu().detach().numpy()
        mean_1 += node_mask
        cnt += 1


print('done interpretation')

  0%|          | 0/9041 [00:00<?, ?it/s]



done interpretation
CPU times: user 8min 29s, sys: 26.3 s, total: 8min 56s
Wall time: 9min 4s


In [32]:
mean = mean_1 / cnt
print(mean.shape)

(390,)


In [None]:
torch.save(torch.from_numpy(mean), 'mean_GraphZSAGEConv_13L_GC_Deconvolution_TP.pt')

# GuidedBackprop

In [34]:
explainer = Explainer(
    model=model,
    algorithm=CaptumExplainer('GuidedBackprop'),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='node',
        return_type='probs',
    ),
)

In [35]:
np.random.seed(42)
params = {'batch_size':1,
          'num_workers':4,
          'shuffle':True}

loader_train = DataLoader(train_dataset, **params)
loader_test = DataLoader(test_dataset, **params)

In [36]:
%%time
mean_1 = np.zeros(features_count, dtype=float)
cnt= 0

for dt in tqdm(loader_test):
    x, edge, y = dt.x.cuda(), dt.edge_index.cuda(), dt.y.cuda().long()
    valid_edges = (edge < width).all(dim=0)
    edge = edge[:, valid_edges]

    output = model(x, edge.squeeze())
    pred = torch.argmax(output, dim=-1)

    # find True Positive indices
    idxs = []
    for i in range(width):
        if pred[0][i] == y[0][i] and y[0][i] == 1:
            idxs.append(i)

    #torch.cuda.empty_cache()
    explanation = explainer(x.squeeze(), edge)
    #explanation.visualize_feature_importance(top_k=10)
    
    node_mask = explanation.node_mask

    if node_mask[idxs, :].shape != (0, features_count):
        node_mask = torch.mean(node_mask[idxs, :], dim=0)
        node_mask = node_mask.cpu().detach().numpy()
        mean_1 += node_mask
        cnt += 1


print('done interpretation')

  0%|          | 0/9041 [00:00<?, ?it/s]

done interpretation
CPU times: user 8min 38s, sys: 30.3 s, total: 9min 8s
Wall time: 9min 15s


In [37]:
mean = mean_1 / cnt
print(mean.shape)


(390,)


In [None]:
torch.save(torch.from_numpy(mean), 'mean_GraphZSAGEConv_13L_GC_GuidedBackprop_TP.pt')

# GNN_explainer

In [39]:
import captum
from captum.attr import IntegratedGradients, GradientShap, LayerGradCam, LRP
from captum.attr import visualization as viz
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

In [40]:
from torch_geometric.explain import Explainer, GNNExplainer

explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=50),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='node',
        return_type='probs',
    ),
)

In [41]:
%%time
mean_1 = np.zeros(features_count, dtype=float)
cnt= 0

for dt in tqdm(loader_test):
    x, edge, y = dt.x.cuda(), dt.edge_index.cuda(), dt.y.cuda().long()
    valid_edges = (edge < width).all(dim=0)
    edge = edge[:, valid_edges]

    output = model(x, edge.squeeze())
    pred = torch.argmax(output, dim=-1)

    # find True Positive indices
    idxs = []
    for i in range(width):
        if pred[0][i] == y[0][i] and y[0][i] == 1:
            idxs.append(i)

    #torch.cuda.empty_cache()
    explanation = explainer(x.squeeze(), edge)
    #explanation.visualize_feature_importance(top_k=10)
    
    node_mask = explanation.node_mask

    if node_mask[idxs, :].shape != (0, features_count):
        node_mask = torch.mean(node_mask[idxs, :], dim=0)
        node_mask = np.array(node_mask.cpu())
        mean_1 += node_mask
        cnt += 1


print('done interpretation')

  0%|          | 0/9041 [00:00<?, ?it/s]

done interpretation
CPU times: user 3h 18min 47s, sys: 1min 32s, total: 3h 20min 19s
Wall time: 3h 20min 15s


In [42]:
mean = mean_1 / cnt
print(mean.shape)


(390,)


In [None]:
torch.save(torch.from_numpy(mean), 'mean_GraphZSAGEConv_13L_GC_GNN_explainer_TP.pt')

# 5-меры с GC

In [None]:
#datа_IG = torch.load("mean_GraphZSAGEConv_v5_lin_5_GC_IG.pt")
datа_Salency = torch.load("mean_GraphZSAGEConv_13L_GC_Saliency_TP.pt")
datа_IxG = torch.load("mean_GraphZSAGEConv_13L_GC_IxG_TP.pt")
datа_Deconvolution = torch.load("mean_GraphZSAGEConv_13L_GC_Deconvolution_TP.pt")
datа_GuidedBackprop = torch.load("mean_GraphZSAGEConv_13L_GC_GuidedBackprop_TP.pt")
datа_GNN_explainer = torch.load("mean_GraphZSAGEConv_13L_GC_GNN_explainer_TP.pt")

In [8]:
df = pd.DataFrame({'Feature': groups,
                   #'Impact_IG': datа_IG,
                   'Impact_Saliency' : datа_Salency,
                   'Impact_IxG' : datа_IxG,
                   'Impact_Deconvolution' : datа_Deconvolution,
                   'Impact_GuidedBackprop' : datа_GuidedBackprop,
                  'Impact_GNN_explainer' : datа_GNN_explainer
                  })

#df['Impact_IG'] = np.abs(df['Impact_IG'])
df['Impact_Saliency'] = np.abs(df['Impact_Saliency'])
df['Impact_IxG'] = np.abs(df['Impact_IxG'])
df['Impact_Deconvolution'] = np.abs(df['Impact_Deconvolution'])
df['Impact_GuidedBackprop'] = np.abs(df['Impact_GuidedBackprop'])
df['Impact_GNN_explainer'] = np.abs(df['Impact_GNN_explainer'])

p_deviation = pd.DataFrame() # сюда будем собирать процентные средние

for column in df.columns:
    if column == 'Feature':
        continue
    
    mean = df[column].mean()
    p_deviation[f'{column}_p_deviation'] = np.abs((((df[column] - mean) / mean) * 100)) # считаем процентное среднее
    
p_deviation['mean_deviation'] = p_deviation.mean(axis=1)
p_deviation['Feature'] = df['Feature']
features_range = p_deviation[['Feature','mean_deviation']].sort_values(by='mean_deviation', ascending=False)

features_range.to_csv('result_5_GC_TP.csv')

In [None]:
features_range

# 6-меры с GC

In [9]:
#datа_IG = torch.load("mean_GraphZSAGEConv_v5_lin_5_GC_IG.pt")
datа_Salency = torch.load("mean_GraphZSAGEConv_v5_lin_5_GC_Saliency.pt")
datа_IxG = torch.load("mean_GraphZSAGEConv_v5_lin_5_GC_IxG.pt")
datа_Deconvolution = torch.load("mean_GraphZSAGEConv_v5_lin_5_GC_Deconvolution.pt")
datа_GuidedBackprop = torch.load("mean_GraphZSAGEConv_v5_lin_5_GC_GuidedBackprop.pt")
datа_GNN_explainer = torch.load("mean_GraphZSAGEConv_v5_lin_5_GC_GNN_explainer.pt")

In [10]:
df = pd.DataFrame({'Feature': groups,
                   #'Impact_IG': datа_IG,
                   'Impact_Saliency' : datа_Salency,
                   'Impact_IxG' : datа_IxG,
                   'Impact_Deconvolution' : datа_Deconvolution,
                   'Impact_GuidedBackprop' : datа_GuidedBackprop,
                  #'Impact_GNN_explainer' : datа_GNN_explainer
                  })

#df['Impact_IG'] = np.abs(df['Impact_IG'])
df['Impact_Saliency'] = np.abs(df['Impact_Saliency'])
df['Impact_IxG'] = np.abs(df['Impact_IxG'])
df['Impact_Deconvolution'] = np.abs(df['Impact_Deconvolution'])
df['Impact_GuidedBackprop'] = np.abs(df['Impact_GuidedBackprop'])
#df['Impact_GNN_explainer'] = np.abs(df['Impact_GNN_explainer'])

p_deviation = pd.DataFrame() # сюда будем собирать процентные средние

for column in df.columns:
    if column == 'Feature':
        continue
    
    mean = df[column].mean()
    p_deviation[f'{column}_p_deviation'] = (((df[column] - mean) / mean) * 100) # считаем процентное среднее
    
p_deviation['mean_deviation'] = p_deviation.mean(axis=1)
p_deviation['Feature'] = df['Feature']
features_range = p_deviation[['Feature','mean_deviation']].sort_values(by='mean_deviation', ascending=False)

features_range.to_csv('result_6_GC.csv')

In [11]:
features_range

Unnamed: 0,Feature,mean_deviation
360,CGCGC,535.473142
255,GCGCG,521.918843
221,GCACA,444.203700
359,CGCGG,428.336893
326,CGTGT,400.402222
...,...,...
23,ATCGA,-92.988782
90,TACGC,-94.519388
289,CTTGC,-95.395234
144,TCGGT,-96.247785
