In [1]:
# Import package
import numpy as np
import random
from numpy.core.numeric import ones_like
from torch_geometric.data import Data, DataLoader
import torch
from tqdm import tqdm
from utils.debruijn import DeBruijnGraph

In [2]:
# Define tf and datapaths
TF = 'ctcf'

PATH_ = "/DATA/yogesh/encodeDream/%s/"%(TF)
# Using readlines()
file1 = open(PATH_+'processed_out/final_training.csv', 'r')
Lines = file1.readlines()

In [None]:
# Visualize the distribution of +ve samples in data
all_pos = []
window = []
for i in tqdm(range(0,len(Lines),1000000)):
    pos = 0
    for line in Lines[i:i+1000000]:
        if line.strip().split(',')[1]=='B':
            pos+=1
    all_pos.append(pos)
    window.append(str(int(i/1000000))+':'+str(int(i/1000000+10)))
    
import matplotlib.pyplot as plt
plt.bar(window, all_pos)
plt.ylabel('Number of bound datapoints')
plt.xlabel('Datapoint window (in lacs)')
plt.xticks(rotation=90)

plt.show()

In [5]:
# Sample +ve and -ve samples from data
DATAPOINTS = 200000
count_b, count_u = 0, 0
op = open(PATH_+'processed_out/balanced_training.csv',"w")
    
for line in tqdm(Lines):
    if line.strip().split(',')[1]=='B':
        count_b+=1
        op.write(line)
    
    if count_b>int(DATAPOINTS*0.2):
        break
        
for line in tqdm(Lines):
    if line.strip().split(',')[1]=='U':
        count_u+=1
        op.write(line)
    if count_u>int(DATAPOINTS*0.8):
        break

  9%|▊         | 4071330/47549176 [00:04<00:43, 1008649.20it/s]
  0%|          | 160891/47549176 [00:00<01:08, 688099.90it/s]


In [6]:
# Import balanced data
file1 = open(PATH_+'processed_out/balanced_training.csv', 'r')
Lines = file1.readlines()

In [7]:
# Prepare Data
kmer, DATALIST = 5, []
counter = 0
for line in tqdm(Lines):
    onehot_x = []
    # get DNA sequence from the dummy file 
    s = line.strip().split(',')[0]
    d=DeBruijnGraph(s,kmer)    
    for node in d.x.flatten():
        one_hot_ = d.one_hot_encode(node).flatten()
        onehot_x.append(one_hot_.tolist())

    # Arrays to pytorch tensors
    onehot_x_tensor = torch.tensor(np.array(onehot_x), dtype=torch.float)
    onehot_edge_index_tensor = torch.tensor(d.edge_index, dtype=torch.long)
    
    # get the label from the data
    if line.strip().split(',')[1]=='U':
        y_tensor = torch.tensor(0)
    else:
        y_tensor = torch.tensor(1)

    # Add tensors to torch_geometric data object
    data = Data(x=onehot_x_tensor, edge_index=onehot_edge_index_tensor, y=y_tensor)

    DATALIST.append(data)
    counter+=1
    if counter>=300000:
        break

100%|██████████| 199998/199998 [24:51<00:00, 134.12it/s]


In [8]:
print('Datapoints:', len(DATALIST))

print('x_shape:', DATALIST[1].x.shape)
print('edge_index_shape:', DATALIST[1].edge_index.shape)
print()

Datapoints: 199998
x_shape: torch.Size([94, 16])
edge_index_shape: torch.Size([2, 196])



In [9]:
temp_data = DATALIST[1]  # Get the first graph object.

print()
print(temp_data)
print('=============================================================')

# Gather some statistics about the first graph.
print(f'Number of nodes: {temp_data.num_nodes}')
print(f'Number of edges: {temp_data.num_edges}')
print(f'Average node degree: {temp_data.num_edges / temp_data.num_nodes:.2f}')
print(f'Contains self-loops: {temp_data.contains_self_loops()}')
print(f'Is undirected: {temp_data.is_undirected()}')


Data(edge_index=[2, 196], x=[94, 16], y=1)
Number of nodes: 94
Number of edges: 196
Average node degree: 2.09
Contains self-loops: True
Is undirected: False


In [10]:
# Split data
torch.manual_seed(11)
random.shuffle(DATALIST)

split_n = int(0.8*(len(DATALIST)))

train_dataset = DATALIST[:split_n]
test_dataset = DATALIST[split_n:]

print('Number of training graph:', len(train_dataset))
print('Number of testing graph:', len(test_dataset))

Number of training graph: 159998
Number of testing graph: 40000


In [11]:
#####
from torch_geometric.data import DataLoader
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=512, shuffle=True)

for step, data in enumerate(train_loader):
    print(f'Step {step+1}:')
    print('=====')
    print(f'Number of graphs in current batch: {data.num_graphs}')
    print(data)
    print()

Step 1:
=====
Number of graphs in current batch: 512
Batch(batch=[60305], edge_index=[2, 100352], ptr=[513], x=[60305, 16], y=[512])

Step 2:
=====
Number of graphs in current batch: 512
Batch(batch=[60956], edge_index=[2, 100352], ptr=[513], x=[60956, 16], y=[512])

Step 3:
=====
Number of graphs in current batch: 512
Batch(batch=[61022], edge_index=[2, 100352], ptr=[513], x=[61022, 16], y=[512])

Step 4:
=====
Number of graphs in current batch: 512
Batch(batch=[61339], edge_index=[2, 100352], ptr=[513], x=[61339, 16], y=[512])

Step 5:
=====
Number of graphs in current batch: 512
Batch(batch=[60464], edge_index=[2, 100352], ptr=[513], x=[60464, 16], y=[512])

Step 6:
=====
Number of graphs in current batch: 512
Batch(batch=[61139], edge_index=[2, 100352], ptr=[513], x=[61139, 16], y=[512])

Step 7:
=====
Number of graphs in current batch: 512
Batch(batch=[60846], edge_index=[2, 100352], ptr=[513], x=[60846, 16], y=[512])

Step 8:
=====
Number of graphs in current batch: 512
Batch(bat

Step 63:
=====
Number of graphs in current batch: 512
Batch(batch=[60786], edge_index=[2, 100352], ptr=[513], x=[60786, 16], y=[512])

Step 64:
=====
Number of graphs in current batch: 512
Batch(batch=[60727], edge_index=[2, 100352], ptr=[513], x=[60727, 16], y=[512])

Step 65:
=====
Number of graphs in current batch: 512
Batch(batch=[60720], edge_index=[2, 100352], ptr=[513], x=[60720, 16], y=[512])

Step 66:
=====
Number of graphs in current batch: 512
Batch(batch=[60586], edge_index=[2, 100352], ptr=[513], x=[60586, 16], y=[512])

Step 67:
=====
Number of graphs in current batch: 512
Batch(batch=[61119], edge_index=[2, 100352], ptr=[513], x=[61119, 16], y=[512])

Step 68:
=====
Number of graphs in current batch: 512
Batch(batch=[60707], edge_index=[2, 100352], ptr=[513], x=[60707, 16], y=[512])

Step 69:
=====
Number of graphs in current batch: 512
Batch(batch=[60496], edge_index=[2, 100352], ptr=[513], x=[60496, 16], y=[512])

Step 70:
=====
Number of graphs in current batch: 512
B

Step 126:
=====
Number of graphs in current batch: 512
Batch(batch=[60762], edge_index=[2, 100352], ptr=[513], x=[60762, 16], y=[512])

Step 127:
=====
Number of graphs in current batch: 512
Batch(batch=[60508], edge_index=[2, 100352], ptr=[513], x=[60508, 16], y=[512])

Step 128:
=====
Number of graphs in current batch: 512
Batch(batch=[60979], edge_index=[2, 100352], ptr=[513], x=[60979, 16], y=[512])

Step 129:
=====
Number of graphs in current batch: 512
Batch(batch=[60361], edge_index=[2, 100352], ptr=[513], x=[60361, 16], y=[512])

Step 130:
=====
Number of graphs in current batch: 512
Batch(batch=[60363], edge_index=[2, 100352], ptr=[513], x=[60363, 16], y=[512])

Step 131:
=====
Number of graphs in current batch: 512
Batch(batch=[61243], edge_index=[2, 100352], ptr=[513], x=[61243, 16], y=[512])

Step 132:
=====
Number of graphs in current batch: 512
Batch(batch=[60755], edge_index=[2, 100352], ptr=[513], x=[60755, 16], y=[512])

Step 133:
=====
Number of graphs in current batc

Step 187:
=====
Number of graphs in current batch: 512
Batch(batch=[60450], edge_index=[2, 100352], ptr=[513], x=[60450, 16], y=[512])

Step 188:
=====
Number of graphs in current batch: 512
Batch(batch=[60496], edge_index=[2, 100352], ptr=[513], x=[60496, 16], y=[512])

Step 189:
=====
Number of graphs in current batch: 512
Batch(batch=[61213], edge_index=[2, 100352], ptr=[513], x=[61213, 16], y=[512])

Step 190:
=====
Number of graphs in current batch: 512
Batch(batch=[60636], edge_index=[2, 100352], ptr=[513], x=[60636, 16], y=[512])

Step 191:
=====
Number of graphs in current batch: 512
Batch(batch=[60837], edge_index=[2, 100352], ptr=[513], x=[60837, 16], y=[512])

Step 192:
=====
Number of graphs in current batch: 512
Batch(batch=[60163], edge_index=[2, 100352], ptr=[513], x=[60163, 16], y=[512])

Step 193:
=====
Number of graphs in current batch: 512
Batch(batch=[60655], edge_index=[2, 100352], ptr=[513], x=[60655, 16], y=[512])

Step 194:
=====
Number of graphs in current batc

Step 248:
=====
Number of graphs in current batch: 512
Batch(batch=[60952], edge_index=[2, 100352], ptr=[513], x=[60952, 16], y=[512])

Step 249:
=====
Number of graphs in current batch: 512
Batch(batch=[60833], edge_index=[2, 100352], ptr=[513], x=[60833, 16], y=[512])

Step 250:
=====
Number of graphs in current batch: 512
Batch(batch=[61364], edge_index=[2, 100352], ptr=[513], x=[61364, 16], y=[512])

Step 251:
=====
Number of graphs in current batch: 512
Batch(batch=[60651], edge_index=[2, 100352], ptr=[513], x=[60651, 16], y=[512])

Step 252:
=====
Number of graphs in current batch: 512
Batch(batch=[61006], edge_index=[2, 100352], ptr=[513], x=[61006, 16], y=[512])

Step 253:
=====
Number of graphs in current batch: 512
Batch(batch=[60679], edge_index=[2, 100352], ptr=[513], x=[60679, 16], y=[512])

Step 254:
=====
Number of graphs in current batch: 512
Batch(batch=[60469], edge_index=[2, 100352], ptr=[513], x=[60469, 16], y=[512])

Step 255:
=====
Number of graphs in current batc

Step 311:
=====
Number of graphs in current batch: 512
Batch(batch=[60146], edge_index=[2, 100352], ptr=[513], x=[60146, 16], y=[512])

Step 312:
=====
Number of graphs in current batch: 512
Batch(batch=[60251], edge_index=[2, 100352], ptr=[513], x=[60251, 16], y=[512])

Step 313:
=====
Number of graphs in current batch: 254
Batch(batch=[30450], edge_index=[2, 49784], ptr=[255], x=[30450, 16], y=[254])



In [12]:
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv,GENConv
from torch_geometric.nn import global_mean_pool

class GCN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(GCN, self).__init__()
        torch.manual_seed(5)
        self.conv1 = GCNConv(16, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        #self.conv1 = GENConv(12, hidden_channels)
        #self.conv2 = GENConv(hidden_channels, hidden_channels)
        #self.conv3 = GENConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, 2)
        
    def forward(self, x, edge_index, batch):
        # Get node embeeding
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv3(x, edge_index)
        
        # Readout alayer
        x = global_mean_pool(x, batch)
        
        # Out layer
        x = F.dropout(x, p=0.4, training=self.training)
        x = self.lin(x)
        return x
    
model = GCN(hidden_channels=256)
print(model)

GCN(
  (conv1): GCNConv(16, 256)
  (conv2): GCNConv(256, 256)
  (conv3): GCNConv(256, 256)
  (lin): Linear(in_features=256, out_features=2, bias=True)
)


In [13]:
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
print(device)
torch.cuda.set_device(1)
print('Current cuda device ID:',torch.cuda.current_device())
print('Current cuda device name:', torch.cuda.get_device_name())

cuda:1
Current cuda device ID: 1
Current cuda device name: Tesla V100-PCIE-32GB


In [14]:
# Train/test
model = GCN(hidden_channels=256)
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_func = torch.nn.CrossEntropyLoss()

def to_device(data, device):
    return data.to(device, non_blocking=True)

def train():
    model.train()
    for data in train_loader:
        a = to_device(data.x, device)
        b = to_device(data.edge_index, device)
        c = to_device(data.batch, device)
        d = to_device(data.y, device)
        
        out = model(a, b, c)
        #print(out)
        
        loss = loss_func(out, d)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    
def test(loader):
    model.eval()
    correct = 0
    aupr, auroc, loss=[], [], []
    for data in loader:
        a = to_device(data.x, device)
        b = to_device(data.edge_index, device)
        c = to_device(data.batch, device)
        d = to_device(data.y, device)
        
        out = model(a, b, c)
        prob = F.softmax(out, dim=1)
        
        # AUPR
        precision, recall, thresholds = precision_recall_curve(d.detach().cpu().clone().numpy(), prob[:,1].detach().cpu().clone().numpy())
        aupr.append(auc(recall, precision))
        
        # AUROC
        try:
            auroc.append(roc_auc_score(d.detach().cpu().clone().numpy(), prob[:,1].detach().cpu().clone().numpy()))
        except:
            pass
        
        # LOSS
        loss_ = loss_func(out, d)
        loss.append(loss_.detach().cpu().clone().numpy())
        
        # print(prob)
        
        pred = out.argmax(dim=1)
        correct += int((pred==d).sum())
        
    return correct/len(loader.dataset), np.nanmean(aupr), np.nanmean(auroc), np.nanmean(loss)

In [15]:
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import auc
from sklearn.metrics import roc_auc_score

In [None]:
np.seterr(divide='ignore', invalid='ignore')
train_acc_, test_acc_, train_aupr_, test_aupr_, train_auroc_, test_auroc_ = [], [], [], [], [], []
for epoch in range(1, 200):
    train()
    train_acc, train_aupr, train_auroc, train_loss = test(train_loader)
    test_acc, test_aupr, test_auroc, test_loss = test(test_loader)
    
    train_acc_.append(train_acc)
    test_acc_.append(test_acc)
    train_aupr_.append(train_aupr)
    test_aupr_.append(test_aupr)
    train_auroc_.append(train_auroc)
    test_auroc_.append(test_auroc)
    
    print(f'Epoch:{epoch:03d}')
    print(f'Train::ACC:{train_acc} AUPR:{train_aupr} AUROC:{train_auroc} LOSS:{train_loss}')
    print(f'Test ::ACC:{test_acc} AUPR:{test_aupr} AUROC:{test_auroc} LOSS:{test_loss}')
    print()
    
    # print(f'Epoch:{epoch:03d},Train acc:{train_acc:.4f},Train aupr:{train_aupr:.4f},Train auroc:{train_auroc:.4f},Test acc:{test_acc:.4f},Test aupr:{test_aupr:.4f},Test auroc:{test_auroc:.4f}')

Epoch:001
Train::ACC:0.8048788109851374 AUPR:0.4032429947993151 AUROC:0.6823279152425826 LOSS:0.4666190445423126
Test ::ACC:0.8036 AUPR:0.4148729165290732 AUROC:0.6900809023889612 LOSS:0.4666746258735657

Epoch:002
Train::ACC:0.8090101126264079 AUPR:0.43054676549749843 AUROC:0.7126942526328717 LOSS:0.45545417070388794
Test ::ACC:0.807925 AUPR:0.44145313081040477 AUROC:0.7181316770344256 LOSS:0.45561710000038147

Epoch:003
Train::ACC:0.8152289403617545 AUPR:0.44322016885153953 AUROC:0.7249053401516338 LOSS:0.44834104180336
Test ::ACC:0.813375 AUPR:0.45407406480061163 AUROC:0.7297990238626457 LOSS:0.4478108882904053

Epoch:004
Train::ACC:0.8124414055175689 AUPR:0.46096189314502156 AUROC:0.737788480179833 LOSS:0.44289109110832214
Test ::ACC:0.8112 AUPR:0.4642574768386935 AUROC:0.7378559813494865 LOSS:0.4450170695781708

Epoch:005
Train::ACC:0.8167977099713747 AUPR:0.4733536541179277 AUROC:0.7466119570255808 LOSS:0.4326411187648773
Test ::ACC:0.815325 AUPR:0.479475158143198 AUROC:0.7497497

Epoch:041
Train::ACC:0.8354416930211628 AUPR:0.5821501269556487 AUROC:0.8135186649307397 LOSS:0.3875111937522888
Test ::ACC:0.825775 AUPR:0.5301394560593802 AUROC:0.7820044484047062 LOSS:0.41348299384117126

Epoch:042
Train::ACC:0.8386667333341666 AUPR:0.5841610745790669 AUROC:0.815739661873105 LOSS:0.38943225145339966
Test ::ACC:0.8251 AUPR:0.5319805657926772 AUROC:0.7809665986954223 LOSS:0.4151720404624939

Epoch:043
Train::ACC:0.8346479330991637 AUPR:0.5843925255627519 AUROC:0.8152413096013961 LOSS:0.3861038088798523
Test ::ACC:0.8257 AUPR:0.5321472155889433 AUROC:0.782168919718267 LOSS:0.4114278554916382

Epoch:044
Train::ACC:0.8388917361467019 AUPR:0.5878498802318995 AUROC:0.8180091730427315 LOSS:0.3852083086967468
Test ::ACC:0.82425 AUPR:0.5338277498769882 AUROC:0.7816362026487401 LOSS:0.4134116470813751

Epoch:045
Train::ACC:0.8374854685683571 AUPR:0.5876511804832736 AUROC:0.8180619014739156 LOSS:0.3822564482688904
Test ::ACC:0.8269 AUPR:0.5331184561753333 AUROC:0.78332081279636

In [None]:
import matplotlib.pyplot as plt
plt.plot(train_acc_)
plt.plot(test_acc_)
plt.plot(train_aupr_)
plt.show()