In [1]:
#Import necessary packages
import os

import dgl
import networkx as nx
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from transformers import BertModel, BertTokenizer


2024-04-17 19:49:56.974653: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
#Import model and tokenizer from ProtBERT model developed by the Rost lab
PROT_BERT_CHECKPOINT = "Rostlab/prot_bert"
bert_model = BertModel.from_pretrained(PROT_BERT_CHECKPOINT)
# from_pretrained already returns the model in eval mode; made explicit because BERT is only
# used as a fixed feature extractor in this notebook (the optimizer only receives the
# classifier head's parameters), so its dropout must stay off.
bert_model.eval()
# Tokenizer for ProtBERT. The Rostlab/prot_bert model card loads it with do_lower_case=False:
# the vocabulary is upper-case amino-acid letters, so lower-casing would turn residues into [UNK].
tokenizer = BertTokenizer.from_pretrained(PROT_BERT_CHECKPOINT, do_lower_case=False)

Downloading:   0%|          | 0.00/81.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/112 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/86.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/361 [00:00<?, ?B/s]

In [None]:
# Obtain per-sequence embeddings with the ProtBERT model. 2048 is used as the maximum length,
# following the original paper describing the method.
_RARE_RESIDUES_TO_X = str.maketrans("UZOB", "XXXX")


def _format_for_protbert(sequence):
    """Return `sequence` in the input format shown on the ProtBERT model card.

    Whitespace is removed, residues are upper-cased, the rare amino acids U, Z, O and B are
    mapped to X, and residues are separated by single spaces. Input that is already in this
    form ("M K T ...") comes back unchanged, so applying it twice is harmless.
    """
    residues = "".join(sequence.split()).upper().translate(_RARE_RESIDUES_TO_X)
    return " ".join(residues)


def get_bert_embeddings(sequence, max_length=2048):
    """Embed one protein sequence, or a batch of them, with ProtBERT.

    Args:
        sequence: a single amino-acid string or a list/tuple of strings (e.g. the batch of
            sequences a DataLoader yields). Raw ("MKT...") and space-separated ("M K T ...")
            forms are both accepted.
        max_length: per-sequence token budget passed to the tokenizer; longer inputs are
            truncated.

    Returns:
        Float tensor of shape (batch, hidden_size): the mean of the last hidden layer over each
        sequence's non-padding tokens (special tokens included, padding excluded). A single
        string yields a batch of one.
    """
    batch = [sequence] if isinstance(sequence, str) else list(sequence)
    inputs = tokenizer(
        [_format_for_protbert(seq) for seq in batch],
        return_tensors='pt', padding=True, max_length=max_length, truncation=True,
    )
    # BERT's weights are never optimized here, so don't build an autograd graph through it:
    # that only costs memory and accumulates unused gradients on the BERT parameters.
    with torch.no_grad():
        outputs = bert_model(**inputs)
    hidden = outputs.last_hidden_state
    # Average only over real tokens. With padding=True the shorter sequences in a batch are
    # padded, and a plain .mean(dim=1) would mix those padding states into their embedding.
    mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
    return (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)

In [None]:
# Integrating node embeddings of contact maps generated by Node2Vec
def load_node_embeddings(directory, protein_ids, embedding_dim=None):
    """Load one Node2Vec contact-map embedding vector per protein as a float32 tensor.

    Each protein's embedding is read from ``<directory>/<protein_id>_contact_map_embedding.npy``.
    A 1-D array is used as-is. An array with more dimensions (presumably one Node2Vec vector per
    residue/node, i.e. (num_nodes, dim) -- TODO confirm against the embedding script) is
    mean-pooled over every axis but the last, so each protein contributes one ``dim``-length row.

    Args:
        directory: folder containing the ``*_contact_map_embedding.npy`` files.
        protein_ids: iterable of protein identifiers, in batch order.
        embedding_dim: width of the zero placeholder used for proteins whose file is missing.
            If None, it is taken from the first embedding that was found.

    Returns:
        torch.float32 tensor of shape (len(protein_ids), embedding_dim), ready to be
        concatenated with per-sequence BERT embeddings along dim 1.

    Raises:
        ValueError: if every requested file is missing and ``embedding_dim`` was not given, or
            if the loaded embeddings do not all have the same width.
    """
    vectors = []
    for protein_id in protein_ids:
        embedding_path = os.path.join(directory, f"{protein_id}_contact_map_embedding.npy")
        if os.path.exists(embedding_path):
            embedding = np.asarray(np.load(embedding_path), dtype=np.float32)
            if embedding.ndim > 1:
                # Collapse the per-node matrix to a single per-protein vector.
                embedding = embedding.reshape(-1, embedding.shape[-1]).mean(axis=0)
            vectors.append(embedding.reshape(-1))
        else:
            # Missing file: replaced below by a zero placeholder of the right width.
            vectors.append(None)

    if embedding_dim is None:
        embedding_dim = next((vec.shape[0] for vec in vectors if vec is not None), None)
        if embedding_dim is None:
            if not vectors:
                return torch.zeros((0, 0), dtype=torch.float32)
            raise ValueError(
                f"No contact-map embeddings found in {directory!r} for the requested proteins; "
                "pass embedding_dim to use zero placeholders."
            )

    rows = [np.zeros(embedding_dim, dtype=np.float32) if vec is None else vec for vec in vectors]
    if not rows:
        return torch.zeros((0, embedding_dim), dtype=torch.float32)
    widths = {row.shape[0] for row in rows}
    if widths != {embedding_dim}:
        raise ValueError(
            f"Inconsistent embedding widths {sorted(widths)}; expected {embedding_dim}."
        )
    return torch.tensor(np.stack(rows), dtype=torch.float32)

In [None]:
#Feature Extraction - Assign value based on presence of (known) NLS
def extract_nls_labels(df):
    """Return a binary NLS label for every row of `df`, in row order.

    A row is labelled 1 (NLS sequence) only when both its "Begin" and "End" values are
    non-zero; a zero in either column marks a non-NLS sequence and yields 0. The result is a
    plain list of Python ints.
    """
    if len(df) == 0:
        return []
    has_nls = (df["Begin"] != 0) & (df["End"] != 0)
    return has_nls.astype(int).tolist()

# NOTE(review): `df` is never defined in this notebook. It must be a pandas DataFrame with
# "Begin" and "End" columns, loaded before this line, whose row order matches the `sequences`
# and `structures` passed to the train-test split below.
nls_labels = extract_nls_labels(df)

In [None]:
# Join the sequence (ProtBERT) and structure (Node2Vec) features into one input vector
def combine_embeddings(bert_embeddings, node_embeddings):
    """Concatenate the two feature tensors side by side along dim 1.

    Both tensors must agree on every dimension except dim 1. For (batch, d_bert) and
    (batch, d_node) inputs the result has shape (batch, d_bert + d_node), BERT features first.
    """
    return torch.cat([bert_embeddings, node_embeddings], dim=1)

In [None]:
class GCN_BERT_Model(nn.Module):
    """Two-layer feed-forward classifier over concatenated ProtBERT + Node2Vec features.

    Despite the name, no graph convolution happens in this module: structural information only
    enters through the precomputed node embeddings that are part of the input vector.

    Architecture: Linear(input_dim -> hidden_dim) -> ReLU -> Dropout(p=0.5)
    -> Linear(hidden_dim -> output_dim). `forward` returns raw logits (no softmax), which is
    what nn.CrossEntropyLoss expects.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Registration order (fc1, fc2, relu, dropout) is kept so state_dict keys and the
        # parameter order seen by the optimizer stay the same.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        hidden = self.dropout(self.relu(self.fc1(x)))
        return self.fc2(hidden)


In [None]:
class CustomDataset(Dataset):
    """Map-style dataset yielding (sequence, structure, label) triples.

    The three inputs are parallel, index-aligned collections: protein sequences, the matching
    structure keys (the training loop unpacks this slot as `protein_ids`, so presumably these
    are the IDs used to locate contact-map embeddings -- confirm), and the NLS labels.
    Length is taken from `labels`.
    """

    def __init__(self, sequences, structures, labels):
        self.sequences, self.structures, self.labels = sequences, structures, labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return self.sequences[idx], self.structures[idx], self.labels[idx]

In [None]:
#Train-test split
# Stratify on the NLS label so both splits keep the same positive/negative ratio
# (train_test_split requires at least two examples of each class for this).
train_sequences, test_sequences, train_structures, test_structures, train_labels, test_labels = train_test_split(
    sequences, structures, nls_labels, test_size=0.2, random_state=42, stratify=nls_labels)

train_dataset = CustomDataset(train_sequences, train_structures, train_labels)
test_dataset = CustomDataset(test_sequences, test_structures, test_labels)

# Seed the shuffling too (same seed as the split), so Restart & Run All reproduces the same
# batch order and therefore the same training run.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True,
                          generator=torch.Generator().manual_seed(42))
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [None]:
# Define hyperparameters
# Directory holding the per-protein "<id>_contact_map_embedding.npy" Node2Vec files used by
# the training and evaluation cells. NOTE(review): path assumed -- confirm the Node2Vec output location.
graph_embedding_directory = "embed_dir"

# Width of one Node2Vec vector, read from a saved embedding file. Reading a `node_embeddings`
# variable here only worked after the training loop had already run (fails on a fresh kernel).
embedding_files = sorted(
    name for name in os.listdir(graph_embedding_directory)
    if name.endswith("_contact_map_embedding.npy")
)
if not embedding_files:
    raise FileNotFoundError(
        f"No *_contact_map_embedding.npy files found in {graph_embedding_directory!r}"
    )
node_embedding_dim = np.load(os.path.join(graph_embedding_directory, embedding_files[0])).shape[-1]

input_dim = bert_model.config.hidden_size + node_embedding_dim
hidden_dim = 128
output_dim = 2  # Binary classification

In [None]:
# Initialize model, loss, and optimizer
# Seed weight initialisation and dropout so Restart & Run All reproduces the same run.
torch.manual_seed(42)
model = GCN_BERT_Model(input_dim, hidden_dim, output_dim)
criterion = nn.CrossEntropyLoss()  # takes raw logits and integer class labels
# Only the classifier head's parameters are optimized; ProtBERT stays frozen.
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [None]:
# Train the model
num_epochs = 3
for epoch in range(num_epochs):
    # Re-enter training mode every epoch: the evaluation pass at the end of each epoch calls
    # model.eval(), which would otherwise leave dropout switched off from epoch 2 onward.
    model.train()
    running_loss = 0.0
    # `batch_sequences` (not `sequences`) so the loop does not overwrite the notebook-level
    # `sequences` consumed by the train-test split cell.
    for batch_sequences, protein_ids, labels in train_loader:
        bert_embeddings = get_bert_embeddings(batch_sequences)
        node_embeddings = load_node_embeddings(graph_embedding_directory, protein_ids)
        combined_embeddings = combine_embeddings(bert_embeddings, node_embeddings)
        optimizer.zero_grad()
        outputs = model(combined_embeddings)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # CrossEntropyLoss averages over the batch; re-weight by batch size for an epoch mean.
        running_loss += loss.item() * labels.size(0)

    # After each epoch, evaluate on the training set. Inference only: no autograd graph needed.
    model.eval()
    train_predictions = []
    train_true_labels = []
    with torch.no_grad():
        for batch_sequences, protein_ids, labels in train_loader:
            bert_embeddings = get_bert_embeddings(batch_sequences)
            node_embeddings = load_node_embeddings(graph_embedding_directory, protein_ids)
            combined_embeddings = combine_embeddings(bert_embeddings, node_embeddings)
            outputs = model(combined_embeddings)
            _, predicted = torch.max(outputs, 1)
            train_predictions.extend(predicted.tolist())
            train_true_labels.extend(labels.tolist())

    train_accuracy = accuracy_score(train_true_labels, train_predictions)
    train_report = classification_report(train_true_labels, train_predictions)

    print(f"Epoch {epoch + 1} - Training Metrics:")
    print(f"Mean training loss: {running_loss / len(train_loader.dataset):.4f}")
    print(f"Accuracy: {train_accuracy}")
    print("Classification Report:")
    print(train_report)

In [None]:
#Evaluate the model on the held-out test set and obtain performance metrics
model.eval()
predictions = []
true_labels = []
# Inference only, so no autograd graph is built. `batch_sequences` (not `sequences`) avoids
# overwriting the notebook-level `sequences` used by the train-test split cell.
with torch.no_grad():
    for batch_sequences, protein_ids, labels in test_loader:
        bert_embeddings = get_bert_embeddings(batch_sequences)
        node_embeddings = load_node_embeddings(graph_embedding_directory, protein_ids)
        combined_embeddings = combine_embeddings(bert_embeddings, node_embeddings)
        outputs = model(combined_embeddings)
        _, predicted = torch.max(outputs, 1)
        predictions.extend(predicted.tolist())
        true_labels.extend(labels.tolist())

accuracy = accuracy_score(true_labels, predictions)
test_report = classification_report(true_labels, predictions)

print("Test Metrics:")
print(f"Accuracy: {accuracy}")
print("Classification Report:")
print(test_report)