### 1. Preprocess Data

In [1]:
import numpy as np
import torch
import scipy.sparse
import os
import pandas as pd
import numpy as np
import networkx
import obonet
import json
import configparser

In [2]:
config_path = "../config/main.conf"
conf = configparser.ConfigParser()
conf.read(config_path)

['../config/main.conf']

In [3]:
contact_maps = {}
for file in os.listdir('../data/contact_maps/sparse_matrices')[:10000]:
    if file.endswith('.npz'):
        pdb_code = file.split('.')[0]
        contact_map = scipy.sparse.load_npz('../data/contact_maps/sparse_matrices/' + file)
        contact_map.resize((3000, 3000))
        contact_maps[pdb_code] = contact_map

In [4]:
url = "http://purl.obolibrary.org/obo/go/go-basic.obo"
graph = obonet.read_obo(url)
goid_to_category = {id_: data.get('namespace') for id_, data in graph.nodes(data=True)}

In [5]:
df = pd.read_csv('../data/GO/pdb_chain_go.csv', skiprows = 1, error_bad_lines=False, warn_bad_lines = False)
df = df[['PDB', 'GO_ID']]
pdb_to_go = {}
for key, value in df.values:
    go_list = pdb_to_go.get(key.upper(), set())
    go_list.add(value)
    pdb_to_go[key.upper()] = go_list

In [6]:
with open("../data/contact_maps/pdb_sequences.json","r") as f:
    sequence_dict = json.load(f)

In [7]:
X_mf_data = []
y_mf_data = []
for key, value in contact_maps.items():
    go_list = []
    for go_term in pdb_to_go.get(key, []):
        if goid_to_category.get(go_term, '') == 'molecular_function':
            go_list.append(go_term)
    if len(go_list) > 0:
        X_mf_data.append((sequence_dict[key], value))
        y_mf_data.append(go_list)


In [8]:
mf_counts = {}
for labels in y_mf_data:
    for label in labels:
        mf_counts[label] = mf_counts.get(label, 0) + 1

In [9]:
removal_indices = []
for i in range(len(y_mf_data)):
    labels = y_mf_data[i]
    for label in labels:
        if mf_counts[label] < 25:
            labels.remove(label)
            if len(labels) == 0:
                removal_indices.append(i)
                
for i in reversed(removal_indices):
    X_mf_data.pop(i)
    y_mf_data.pop(i)
            

In [10]:
mf_terms = set()
for labels in y_mf_data:
    for label in labels:
        mf_terms.add(label)
        
mf_vocab = {mf: idx for idx, mf in enumerate(mf_terms)}
len(mf_vocab)

593

In [11]:
# from utils/data.py
def prepare_sequence(seq, vocab, padding):
    """
    function to process the data, padding them
    TODO later will move to specific preprocessing part
    """
    res = ['<PAD>'] * padding
    res[:min(padding, len(seq))] = seq[:min(padding, len(seq))]
    # use 0 for padding
    idxs = [vocab[w] for w in res]
    return torch.tensor(idxs, dtype=torch.long)

vocab_path = '../config/genome_vocab.json'
with open(vocab_path, 'r') as of:
    vocab = json.load(of)

prepare_sequence(X_mf_data[0][0], vocab, 3000).shape

torch.Size([3000])

In [12]:
print(len(X_mf_data), len(y_mf_data))

7420 7420


### 2. Train Model

In [13]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_add_pool as gap
from torch_geometric.data import Data
from torch_geometric.utils import from_scipy_sparse_matrix
from torch_geometric.data import DataLoader
from sklearn.metrics import accuracy_score

In [14]:
from sklearn.model_selection import train_test_split
X_train_mf, X_val_mf, y_train_mf, y_val_mf = train_test_split(X_mf_data, y_mf_data, test_size=0.125, random_state=41)
dataset_train = []
for data, raw_labels in zip(X_train_mf, y_train_mf):
    x = prepare_sequence(data[0], vocab, 3000).float()
    edge_index = from_scipy_sparse_matrix(data[1])[0]
    labels = torch.tensor([mf_vocab[y] for y in raw_labels])
    targets = torch.zeros((len(mf_vocab)))
    targets[labels] = 1
    targets = torch.tensor(targets)
    dataset_train.append(Data(x = x, edge_index = edge_index, y = targets))
dataset_val = []
for data, raw_labels in zip(X_val_mf, y_val_mf):
    x = prepare_sequence(data[0], vocab, 3000).float()
    edge_index = from_scipy_sparse_matrix(data[1])[0]
    labels = torch.tensor([mf_vocab[y] for y in raw_labels])
    targets = torch.zeros((len(mf_vocab)))
    targets[labels] = 1
    targets = torch.tensor(targets)
    dataset_val.append(Data(x = x, edge_index = edge_index, y = targets))

  # Remove the CWD from sys.path while we load stuff.


In [15]:
language_model = torch.load('model_LSTM.attn_4epoch', map_location=torch.device('cpu'))
lm_embedding = language_model['model_state_dict']['module.word_embeddings.weight']
lm_embedding.shape

torch.Size([27, 80])

In [16]:
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.embedding = torch.nn.Embedding.from_pretrained(lm_embedding)
        self.conv1 = GCNConv(80, 128)
        self.conv2 = GCNConv(128, 128)
        self.conv3 = GCNConv(128, 256)
        self.linear1 = torch.nn.Linear(256, 512)
        self.linear2 = torch.nn.Linear(512, len(mf_vocab))

    def forward(self, data):
        
        x, edge_index, batch = data.x, data.edge_index, data.batch
                
        x = self.embedding(x.long())
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv3(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = gap(x, batch)
        x = self.linear1(x)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.linear2(x)
        x = torch.sigmoid(x)
        
        return x

In [17]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device).float()
optimizer = torch.optim.Adam(model.parameters(), lr=0.00005, weight_decay=5e-4)
train_loader = DataLoader(dataset_train, shuffle = True, batch_size=64)
model.train()
for epoch in range(20):
    for iteration, data in enumerate(train_loader):
        optimizer.zero_grad()
        out = model(data)
        criterion = torch.nn.BCELoss()
        loss = criterion(out, data.y.float().reshape(-1, len(mf_vocab)))
        loss.backward()
        optimizer.step()
        if iteration % 40 == 0:
            print(f'epoch {epoch + 1}, iteration {iteration}')
            print(f'loss = {loss}')
            pred = (out > 0.5).float()
            label = data.y.float().reshape(-1, len(mf_vocab))
            print('precision =', ((label * pred).sum()/label.sum()).item())
            print('recall =', ((label * pred).sum()/pred.sum()).item())
            print('total # of go terms =', label.sum().item())
            print('predicted # of go terms =', pred.sum().item())
            print()

epoch 1, iteration 0
loss = 2.5587947368621826
precision = 0.5746268630027771
recall = 0.007982169277966022
total # of go terms = 268.0
predicted # of go terms = 19293.0

epoch 1, iteration 40
loss = 0.4654340445995331
precision = 0.18241041898727417
recall = 0.00860611628741026
total # of go terms = 307.0
predicted # of go terms = 6507.0

epoch 1, iteration 80
loss = 0.1372823566198349
precision = 0.04545454680919647
recall = 0.009188361465930939
total # of go terms = 264.0
predicted # of go terms = 1306.0

epoch 2, iteration 0
loss = 0.0759175568819046
precision = 0.06086956337094307
recall = 0.029723990708589554
total # of go terms = 230.0
predicted # of go terms = 471.0

epoch 2, iteration 40
loss = 0.06926358491182327
precision = 0.04642857238650322
recall = 0.05179283022880554
total # of go terms = 280.0
predicted # of go terms = 251.0

epoch 2, iteration 80
loss = 0.05172669515013695
precision = 0.07058823853731155
recall = 0.1090909093618393
total # of go terms = 255.0
predicte

In [None]:
model.eval()
val_loader = DataLoader(dataset_val, batch_size = 64)
with torch.no_grad():
    train_logits = {}
    val_logits = {}
    train_labels = {}
    val_labels = {}
    train_precisions = []
    train_recalls = []
    val_precisions = []
    val_recalls = []
    for thresh in np.arange(0.01, .91, .001):
        train_precision = 0.0
        train_recall = 0.0
        train_num_batches = 0.0
        for idx, data in enumerate(train_loader):
            data = data.to(device)
            if idx in train_logits:
                pred = (train_logits[idx] > thresh).float()
                label = train_labels[idx]
            else:
                logits = model(data)
                train_logits[idx] = logits
                pred = (logits > thresh).float()
                label = data.y.float().reshape(-1, len(mf_vocab))
                train_labels[idx] = label
            train_precision += ((label * pred).sum()/label.sum()).item()
            train_recall += ((label * pred).sum()/pred.sum()).item()
            train_num_batches += 1
        #print('precision = ', train_precision/train_num_batches)
        #print('recall = ', train_recall/train_num_batches)
        train_precisions.append(train_precision/train_num_batches)
        train_recalls.append(train_recall/train_num_batches)
        
        val_precision = 0.0
        val_recall = 0.0
        val_num_batches = 0.0
    
        for data in val_loader:
            data = data.to(device) 
            if idx in val_logits:
                pred = (val_logits[idx] > thresh).float()
                label = val_labels[idx]
            else:
                logits = model(data)
                val_logits[idx] = logits
                pred = (logits > thresh).float()
                label = data.y.float().reshape(-1, len(mf_vocab))
                val_labels[idx] = label
            val_precision += ((label * pred).sum()/label.sum()).item()
            val_recall += ((label * pred).sum()/pred.sum()).item()
            val_num_batches += 1
        #print('precision = ', val_precision/val_num_batches)
        #print('recall = ', val_recall/val_num_batches)
        val_precisions.append(val_precision/val_num_batches)
        val_recalls.append(val_recall/val_num_batches)

In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import auc
plt.plot(train_precisions, train_recalls)
plt.plot(val_precisions, train_recalls)
plt.title('Averaged Precision Recall for Training data')
plt.legend([f'train auc = {round(auc(train_precisions, train_recalls), 3)}', f'val auc = {round(auc(val_precisions, train_recalls), 3)}'])
plt.xlabel('Precision')
plt.ylabel('Recall')
plt.show()

### 