In [1]:
!git clone https://github.com/vladislareon/Sparse_vector
!git clone https://github.com/vladislareon/z_dna

Cloning into 'Sparse_vector'...
remote: Enumerating objects: 8, done.[K
remote: Counting objects: 100% (8/8), done.[K
remote: Compressing objects: 100% (5/5), done.[K
remote: Total 8 (delta 0), reused 0 (delta 0), pack-reused 0[K
Unpacking objects: 100% (8/8), done.
Cloning into 'z_dna'...
remote: Enumerating objects: 2052, done.[K
remote: Counting objects: 100% (31/31), done.[K
remote: Compressing objects: 100% (25/25), done.[K
remote: Total 2052 (delta 8), reused 0 (delta 0), pack-reused 2021[K
Receiving objects: 100% (2052/2052), 1.75 GiB | 10.06 MiB/s, done.
Resolving deltas: 100% (8/8), done.
Checking out files: 100% (2024/2024), done.


In [22]:
!git clone https://github.com/vladislareon/Interpretation

Cloning into 'Interpretation'...
remote: Enumerating objects: 37, done.[K
remote: Counting objects: 100% (37/37), done.[K
remote: Compressing objects: 100% (35/35), done.[K
remote: Total 37 (delta 14), reused 2 (delta 0), pack-reused 0[K
Unpacking objects: 100% (37/37), done.


In [1]:
# Data loading
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from sklearn.metrics import roc_auc_score, f1_score
from IPython.display import clear_output

from joblib import load
from tqdm import trange
from tqdm.notebook import tqdm


# Graph dataset
from torch_geometric.loader import DataLoader
from torch_geometric.data import Dataset, Data

from sklearn.preprocessing import LabelBinarizer
from sklearn.model_selection import StratifiedKFold


# GNN Model
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, GraphConv, GATConv, GATv2Conv, SAGEConv


# Sparse vector
from Sparse_vector.sparse_vector import SparseVector

In [2]:
chrom_names = [f'chr{i}' for i in list(range(1, 23)) + ['X', 'Y','M']]

features = [i[:-4] for i in os.listdir('z_dna/hg38_features/sparse/') if i.endswith('.pkl')]
groups = ['DNase-seq', 'Histone', 'RNA polymerase', 'TFs and others']
feature_names = [i for i in features]

In [3]:
def chrom_reader(chrom):
    files = sorted([i for i in os.listdir(f'z_dna/hg38_dna/') if f"{chrom}_" in i])
    return ''.join([load(f"z_dna/hg38_dna/{file}") for file in files])

In [4]:
%%time
DNA = {chrom:chrom_reader(chrom) for chrom in tqdm(chrom_names)}
#ZDNA = load('z_dna/hg38_zdna/sparse/ZDNA_shin.pkl')
#ZDNA = load('z_dna/hg38_zdna/sparse/ZDNA_cousine.pkl')

ZDNA = load('z_dna/hg38_zdna/sparse/ZDNA_cousine.pkl')

DNA_features = {feature: load(f'z_dna/hg38_features/sparse/{feature}.pkl')
                for feature in tqdm(feature_names)}

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/1946 [00:00<?, ?it/s]

CPU times: user 1min 32s, sys: 4.18 s, total: 1min 36s
Wall time: 1min 36s


In [5]:
width = 100

In [6]:
class GraphDataset(Dataset):
    def __init__(self, chroms, features,
                 dna_source, features_source,
                 labels, intervals,
                 transform=None, pre_transform=None, pre_filter=None):
        self.chroms = chroms
        self.features = features
        self.dna_source = dna_source
        self.features_source = features_source
        self.labels = labels
        self.intervals = intervals
        self.le = LabelBinarizer().fit(np.array([["A"], ["C"], ["T"], ["G"]]))

        self.ei = [[],[]]
        for i in range(width-1):
            self.ei[0].append(i)
            self.ei[0].append(i+1)
            self.ei[1].append(i+1)
            self.ei[1].append(i)
        super().__init__(transform, pre_transform, pre_filter)

    def len(self):
        return len(self.intervals)

    def get(self, idx):
        interval = self.intervals[idx]
        chrom = interval[0]
        begin = int(interval[1])
        end = int(interval[2])
        dna_OHE = self.le.transform(list(self.dna_source[chrom][begin:end].upper()))

        feature_matr = []
        for feature in self.features:
            source = self.features_source[feature]
            feature_matr.append(source[chrom][begin:end])

        if len(feature_matr) > 0:
            X = np.hstack((dna_OHE, np.array(feature_matr).T/1000)).astype(np.float32)
        else:
            X = dna_OHE.astype(np.float32)
        X = torch.tensor(X, dtype=torch.float)

        edge_index = torch.tensor(np.array(self.ei), dtype=torch.long)

        y = self.labels[interval[0]][interval[1]: interval[2]]
        y = torch.tensor(y, dtype=torch.int64)

        return Data(x=X.unsqueeze(0), edge_index=edge_index, y=y.unsqueeze(0))

In [7]:
np.random.seed(10)

ints_in = []
ints_out = []

for chrm in chrom_names:
    for st in trange(0, ZDNA[chrm].shape - width, width):
        interval = [st, min(st + width, ZDNA[chrm].shape)]
        if ZDNA[chrm][interval[0]: interval[1]].any():
            ints_in.append([chrm, interval[0], interval[1]])
        else:
            ints_out.append([chrm, interval[0], interval[1]])

ints_in = np.array(ints_in)
ints_out = np.array(ints_out)[np.random.choice(range(len(ints_out)), size=len(ints_in) * 2, replace=False)]

100%|█████████████████████████████████████████████████████████████████████████████████████████| 2489564/2489564 [00:36<00:00, 68861.41it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████| 2421935/2421935 [00:33<00:00, 71425.22it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████| 1982955/1982955 [00:28<00:00, 69682.75it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████| 1902145/1902145 [00:26<00:00, 70926.84it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████| 1815382/1815382 [00:25<00:00, 72339.66it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████| 1708059/1708059 [00:25<00:00, 68040.11it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████| 1593459/1593459 [00:21<00:00, 72790.66it/s]
100%|███████████████

In [8]:
np.random.seed(42)
equalized = ints_in
equalized = [[inter[0], int(inter[1]), int(inter[2])] for inter in equalized]

train_inds, test_inds = next(StratifiedKFold().split(equalized, [f"{int(i < 400)}_{elem[0]}"
                                                                 for i, elem
                                                                 in enumerate(equalized)]))

train_intervals, test_intervals = [equalized[i] for i in train_inds], [equalized[i] for i in test_inds]

In [9]:
np.random.seed(42)
train_dataset = GraphDataset(chrom_names, feature_names,
                            DNA, DNA_features,
                            ZDNA, train_intervals)

test_dataset = GraphDataset(chrom_names, feature_names,
                           DNA, DNA_features,
                           ZDNA, test_intervals)

In [10]:
np.random.seed(42)
params = {'batch_size':1,
          'num_workers':4,
          'shuffle':True}

loader_train = DataLoader(train_dataset, **params)
loader_test = DataLoader(test_dataset, **params)

# GNN Model

In [11]:
class GraphZSAGEConv_v5_lin(torch.nn.Module):
    def __init__(self):
        super(GraphZSAGEConv_v5_lin, self).__init__()
        self.conv1 = SAGEConv(1950, 1800)
        self.conv2 = SAGEConv(1800, 1650)
        self.conv3 = SAGEConv(1650, 1500)
        self.conv4 = SAGEConv(1500, 1350)
        self.conv5 = SAGEConv(1350, 1200)
        self.conv6 = SAGEConv(1200, 1050)
        self.conv7 = SAGEConv(1050, 900)
        self.conv8 = SAGEConv(900, 750)
        self.conv9 = SAGEConv(750, 600)
        self.conv10 = SAGEConv(600, 450)
        self.conv11 = SAGEConv(450, 300)
        self.conv12 = SAGEConv(300, 150)
        self.conv13 = SAGEConv(150, 64)

        self.fc1 = torch.nn.Linear(64, 32)
        self.fc2 = torch.nn.Linear(32, 2)

    def forward(self, x, edge):
        x = self.conv1(x, edge)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        
        x = self.conv2(x, edge)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        
        x = self.conv3(x, edge)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        
        x = self.conv4(x, edge)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        
        x = self.conv5(x, edge)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        
        x = self.conv6(x, edge)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        
        x = self.conv7(x, edge)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        
        x = self.conv8(x, edge)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        
        x = self.conv9(x, edge)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        
        x = self.conv10(x, edge)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        
        x = self.conv11(x, edge)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        
        x = self.conv12(x, edge)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        
        x = self.conv13(x, edge)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        
        x = F.relu(self.fc1(x))
        x = self.fc2(x)

        return F.log_softmax(x, dim=-1)


In [12]:
model = GraphZSAGEConv_v5_lin()
model.load_state_dict(torch.load("Cousine_GraphZSAGEConv_v5_lin_F1=77.75_epoch=17.pt"))
model = model.cuda()
model.eval()

GraphZSAGEConv_v5_lin(
  (conv1): SAGEConv(1950, 1800, aggr=mean)
  (conv2): SAGEConv(1800, 1650, aggr=mean)
  (conv3): SAGEConv(1650, 1500, aggr=mean)
  (conv4): SAGEConv(1500, 1350, aggr=mean)
  (conv5): SAGEConv(1350, 1200, aggr=mean)
  (conv6): SAGEConv(1200, 1050, aggr=mean)
  (conv7): SAGEConv(1050, 900, aggr=mean)
  (conv8): SAGEConv(900, 750, aggr=mean)
  (conv9): SAGEConv(750, 600, aggr=mean)
  (conv10): SAGEConv(600, 450, aggr=mean)
  (conv11): SAGEConv(450, 300, aggr=mean)
  (conv12): SAGEConv(300, 150, aggr=mean)
  (conv13): SAGEConv(150, 64, aggr=mean)
  (fc1): Linear(in_features=64, out_features=32, bias=True)
  (fc2): Linear(in_features=32, out_features=2, bias=True)
)

# Captum methods: GuidedBackprop

In [13]:
import captum
from captum.attr import IntegratedGradients, GradientShap, LayerGradCam, LRP
from captum.attr import visualization as viz
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

In [14]:
from torch_geometric.explain import Explainer, CaptumExplainer

explainer = Explainer(
    model=model,
    algorithm=CaptumExplainer('GuidedBackprop'),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='node',
        return_type='probs',
    ),
)

In [15]:
%%time
mean_1 = np.zeros(1950, dtype=float)
cnt= 0

for dt in tqdm(loader_test):
    x, edge, y = dt.x.cuda(), dt.edge_index.cuda(), dt.y.cuda().long()
    valid_edges = (edge < width).all(dim=0)
    edge = edge[:, valid_edges]

    output = model(x, edge.squeeze())
    pred = torch.argmax(output, dim=-1)

    # find True Positive indices
    idxs = []
    for i in range(width):
        if pred[0][i] == y[0][i] and y[0][i] == 1:
            idxs.append(i)

    #torch.cuda.empty_cache()
    explanation = explainer(x.squeeze(), edge)
    #explanation.visualize_feature_importance(top_k=10)
    
    node_mask = explanation.node_mask

    if node_mask[idxs, :].shape != (0, 1950):
        node_mask = torch.mean(node_mask[idxs, :], dim=0)
        node_mask = node_mask.cpu().detach().numpy()
        mean_1 += node_mask
        cnt += 1


print('done interpretation')

  0%|          | 0/9041 [00:00<?, ?it/s]



done interpretation
CPU times: user 9min 39s, sys: 1min 2s, total: 10min 42s
Wall time: 10min 47s


In [16]:
mean = mean_1 / cnt
print(mean.shape)
print(mean)

(1950,)
[-3.75372952e-01  5.90027887e-02  3.73165539e-02 ... -5.61160422e-12
  1.03455687e-02 -5.15065494e-04]


In [17]:
torch.save(torch.from_numpy(mean), 'mean_GraphZSAGEConv_v5_lin_GuidedBackprop.pt')