In [1]:
# Import package
import numpy as np
import random
from numpy.core.numeric import ones_like
from torch_geometric.data import Data, DataLoader
import torch
from tqdm import tqdm
from utils.debruijn import DeBruijnGraph

In [2]:
# Define tf and datapaths
TF = 'ctcf'

PATH_ = "/DATA/yogesh/encodeDream/%s/"%(TF)
# Using readlines()
file1 = open(PATH_+'processed_out/final_training.csv', 'r')
Lines = file1.readlines()

In [3]:
# Prepare Data
kmer, DATALIST = 4, []
for line in tqdm(Lines):
    onehot_x = []
    # get DNA sequence from the dummy file 
    s = line.strip().split(',')[0]
    d=DeBruijnGraph(s,kmer)    
    for node in d.x.flatten():
        one_hot_ = d.one_hot_encode(node).flatten()
        onehot_x.append(one_hot_.tolist())

    # Arrays to pytorch tensors
    onehot_x_tensor = torch.tensor(np.array(onehot_x), dtype=torch.float)
    onehot_edge_index_tensor = torch.tensor(d.edge_index, dtype=torch.long)
    
    # get the label from the data
    if line.strip().split(',')[1]=='U':
        y_tensor = torch.tensor(0)
    else:
        y_tensor = torch.tensor(1)

    # Add tensors to torch_geometric data object
    data = Data(x=onehot_x_tensor, edge_index=onehot_edge_index_tensor, y=y_tensor)

    DATALIST.append(data)

print('Datapoints:', len(DATALIST))

print('x_shape:', DATALIST[1].x.shape)
print('edge_index_shape:', DATALIST[1].edge_index.shape)
print()

  0%|          | 145294/47549176 [15:51<86:16:19, 152.63it/s] 


KeyboardInterrupt: 

In [4]:
temp_data = DATALIST[1]  # Get the first graph object.

print()
print(temp_data)
print('=============================================================')

# Gather some statistics about the first graph.
print(f'Number of nodes: {temp_data.num_nodes}')
print(f'Number of edges: {temp_data.num_edges}')
print(f'Average node degree: {temp_data.num_edges / temp_data.num_nodes:.2f}')
print(f'Contains self-loops: {temp_data.contains_self_loops()}')
print(f'Is undirected: {temp_data.is_undirected()}')


Data(edge_index=[2, 197], x=[60, 12], y=0)
Number of nodes: 60
Number of edges: 197
Average node degree: 3.28
Contains self-loops: True
Is undirected: False


In [5]:
# Split data
torch.manual_seed(11)
random.shuffle(DATALIST)

split_n = int(0.8*(len(DATALIST)))

train_dataset = DATALIST[:split_n]
test_dataset = DATALIST[split_n:]

print('Number of training graph:', len(train_dataset))
print('Number of testing graph:', len(test_dataset))

Number of training graph: 116235
Number of testing graph: 29059


In [6]:
#####
from torch_geometric.data import DataLoader
train_loader = DataLoader(train_dataset, batch_size=512, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=512, shuffle=True)

for step, data in enumerate(train_loader):
    print(f'Step {step+1}:')
    print('=====')
    print(f'Number of graphs in current batch: {data.num_graphs}')
    print(data)
    print()

Step 1:
=====
Number of graphs in current batch: 512
Batch(batch=[27705], edge_index=[2, 100864], ptr=[513], x=[27705, 12], y=[512])

Step 2:
=====
Number of graphs in current batch: 512
Batch(batch=[27968], edge_index=[2, 100864], ptr=[513], x=[27968, 12], y=[512])

Step 3:
=====
Number of graphs in current batch: 512
Batch(batch=[27735], edge_index=[2, 100864], ptr=[513], x=[27735, 12], y=[512])

Step 4:
=====
Number of graphs in current batch: 512
Batch(batch=[27783], edge_index=[2, 100864], ptr=[513], x=[27783, 12], y=[512])

Step 5:
=====
Number of graphs in current batch: 512
Batch(batch=[27776], edge_index=[2, 100864], ptr=[513], x=[27776, 12], y=[512])

Step 6:
=====
Number of graphs in current batch: 512
Batch(batch=[28039], edge_index=[2, 100864], ptr=[513], x=[28039, 12], y=[512])

Step 7:
=====
Number of graphs in current batch: 512
Batch(batch=[27661], edge_index=[2, 100864], ptr=[513], x=[27661, 12], y=[512])

Step 8:
=====
Number of graphs in current batch: 512
Batch(bat

Step 67:
=====
Number of graphs in current batch: 512
Batch(batch=[28007], edge_index=[2, 100864], ptr=[513], x=[28007, 12], y=[512])

Step 68:
=====
Number of graphs in current batch: 512
Batch(batch=[27926], edge_index=[2, 100864], ptr=[513], x=[27926, 12], y=[512])

Step 69:
=====
Number of graphs in current batch: 512
Batch(batch=[27868], edge_index=[2, 100864], ptr=[513], x=[27868, 12], y=[512])

Step 70:
=====
Number of graphs in current batch: 512
Batch(batch=[27859], edge_index=[2, 100864], ptr=[513], x=[27859, 12], y=[512])

Step 71:
=====
Number of graphs in current batch: 512
Batch(batch=[27861], edge_index=[2, 100864], ptr=[513], x=[27861, 12], y=[512])

Step 72:
=====
Number of graphs in current batch: 512
Batch(batch=[27939], edge_index=[2, 100864], ptr=[513], x=[27939, 12], y=[512])

Step 73:
=====
Number of graphs in current batch: 512
Batch(batch=[27866], edge_index=[2, 100864], ptr=[513], x=[27866, 12], y=[512])

Step 74:
=====
Number of graphs in current batch: 512
B

Step 131:
=====
Number of graphs in current batch: 512
Batch(batch=[28036], edge_index=[2, 100864], ptr=[513], x=[28036, 12], y=[512])

Step 132:
=====
Number of graphs in current batch: 512
Batch(batch=[27745], edge_index=[2, 100864], ptr=[513], x=[27745, 12], y=[512])

Step 133:
=====
Number of graphs in current batch: 512
Batch(batch=[27997], edge_index=[2, 100864], ptr=[513], x=[27997, 12], y=[512])

Step 134:
=====
Number of graphs in current batch: 512
Batch(batch=[27934], edge_index=[2, 100864], ptr=[513], x=[27934, 12], y=[512])

Step 135:
=====
Number of graphs in current batch: 512
Batch(batch=[27948], edge_index=[2, 100864], ptr=[513], x=[27948, 12], y=[512])

Step 136:
=====
Number of graphs in current batch: 512
Batch(batch=[27915], edge_index=[2, 100864], ptr=[513], x=[27915, 12], y=[512])

Step 137:
=====
Number of graphs in current batch: 512
Batch(batch=[27715], edge_index=[2, 100864], ptr=[513], x=[27715, 12], y=[512])

Step 138:
=====
Number of graphs in current batc

Step 194:
=====
Number of graphs in current batch: 512
Batch(batch=[27827], edge_index=[2, 100864], ptr=[513], x=[27827, 12], y=[512])

Step 195:
=====
Number of graphs in current batch: 512
Batch(batch=[27776], edge_index=[2, 100864], ptr=[513], x=[27776, 12], y=[512])

Step 196:
=====
Number of graphs in current batch: 512
Batch(batch=[27951], edge_index=[2, 100864], ptr=[513], x=[27951, 12], y=[512])

Step 197:
=====
Number of graphs in current batch: 512
Batch(batch=[27768], edge_index=[2, 100864], ptr=[513], x=[27768, 12], y=[512])

Step 198:
=====
Number of graphs in current batch: 512
Batch(batch=[27529], edge_index=[2, 100864], ptr=[513], x=[27529, 12], y=[512])

Step 199:
=====
Number of graphs in current batch: 512
Batch(batch=[27928], edge_index=[2, 100864], ptr=[513], x=[27928, 12], y=[512])

Step 200:
=====
Number of graphs in current batch: 512
Batch(batch=[27772], edge_index=[2, 100864], ptr=[513], x=[27772, 12], y=[512])

Step 201:
=====
Number of graphs in current batc

In [52]:
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv,GENConv
from torch_geometric.nn import global_mean_pool

class GCN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(GCN, self).__init__()
        torch.manual_seed(5)
        self.conv1 = GCNConv(12, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        #self.conv1 = GENConv(12, hidden_channels)
        #self.conv2 = GENConv(hidden_channels, hidden_channels)
        #self.conv3 = GENConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, 2)
        
    def forward(self, x, edge_index, batch):
        # Get node embeeding
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv3(x, edge_index)
        
        # Readout alayer
        x = global_mean_pool(x, batch)
        
        # Out layer
        x = F.dropout(x, p=0.4, training=self.training)
        x = self.lin(x)
        return x
    
model = GCN(hidden_channels=512)
print(model)

GCN(
  (conv1): GCNConv(12, 512)
  (conv2): GCNConv(512, 512)
  (conv3): GCNConv(512, 512)
  (lin): Linear(in_features=512, out_features=2, bias=True)
)


In [53]:
device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
print(device)
torch.cuda.set_device(1)
print('Current cuda device ID:',torch.cuda.current_device())
print('Current cuda device name:', torch.cuda.get_device_name())

cuda:1
Current cuda device ID: 1
Current cuda device name: Tesla V100-PCIE-32GB


In [54]:
# Train/test
model = GCN(hidden_channels=256)
model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
loss_func = torch.nn.CrossEntropyLoss()

def to_device(data, device):
    return data.to(device, non_blocking=True)

def train():
    model.train()
    for data in train_loader:
        a = to_device(data.x, device)
        b = to_device(data.edge_index, device)
        c = to_device(data.batch, device)
        d = to_device(data.y, device)
        
        out = model(a, b, c)
        #print(out)
        
        loss = loss_func(out, d)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
    
def test(loader):
    model.eval()
    correct = 0
    aupr=[]
    for data in loader:
        a = to_device(data.x, device)
        b = to_device(data.edge_index, device)
        c = to_device(data.batch, device)
        d = to_device(data.y, device)
        
        out = model(a, b, c)
        
        precision, recall, thresholds = precision_recall_curve(d.detach().cpu().clone().numpy(), out[:,1].detach().cpu().clone().numpy())
        aupr.append(auc(recall, precision))
        # print(aupr)
        
        pred = out.argmax(dim=1)
        correct += int((pred==d).sum())
        
    return correct/len(loader.dataset), np.nanmean(aupr)

In [55]:
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import auc

In [None]:
np.seterr(divide='ignore', invalid='ignore')
train_, test_ = [], []
for epoch in range(1, 100):
    train()
    train_acc, train_aupr = test(train_loader)
    test_acc, test_aupr = test(test_loader)
    
    train_.append(train_acc)
    test_.append(test_acc)
    
    print(f'Epoch:{epoch:03d}, Train acc: {train_acc:.4f}, Train aupr: {train_aupr:.4f}, Test acc: {test_acc:.4f}, Test aupr: {test_aupr:.4f}')

Epoch:001, Train acc: 0.9944, Train aupr: 0.0080, Test acc: 0.9944, Test aupr: 0.0232
Epoch:002, Train acc: 0.9944, Train aupr: 0.3239, Test acc: 0.9944, Test aupr: 0.3450
Epoch:003, Train acc: 0.9944, Train aupr: 0.3229, Test acc: 0.9944, Test aupr: 0.3307
