In [57]:
#Import necessary packages
import pandas as pd
import torch
from transformers import BertTokenizer, BertForTokenClassification, AdamW, BertModel
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import accuracy_score, precision_score, recall_score
from sklearn.model_selection import train_test_split
import torch.nn as nn
import torch.nn.functional as F


In [26]:
#Import tokenizer from ProtBERT model developed by the Rost lab
# Tokenizer for ProtBERT
tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert")

In [27]:
#Define BERT-EE (BERT plus entity extraction for the NLS motifs)
class BERT_EE(nn.Module):
    def __init__(self, input_dim, motif_vocab_size, hidden_dim, num_classes):
        super(BERT_EE, self).__init__()
        self.bert = BertModel.from_pretrained("Rostlab/prot_bert")
        self.embedding_layer = nn.Embedding(motif_vocab_size, hidden_dim)
        self.linear = nn.Linear(input_dim + hidden_dim, num_classes)

    def forward(self, input_ids, attention_mask, motif_features):
        # ProtBERT embeddings for the protein sequence
        bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        bert_embeddings = bert_output.last_hidden_state[:, 0, :]  # CLS token embedding

        # Embedding layer for motif features
        motif_embeddings = self.embedding_layer(motif_features)

        # Concatenate ProtBERT embeddings with motif embeddings
        combined_embeddings = torch.cat((bert_embeddings, motif_embeddings), dim=1)

        # Final classification layer
        logits = self.linear(combined_embeddings)
        return logits


In [28]:
# Read the CSV file
df = pd.read_csv("finalized_complete_NLS_sequence_table.csv")

In [8]:
df


Unnamed: 0,UniProt ID,Sequence_full,Name,Begin,End,Sequence_nls,Length,Evidence,ECO code
0,Q14738,MPYKLKKEKEPPKVAKCTAKPSSSGKDGGGENTEEAQPQPQPQPQP...,Serine/threonine-protein phosphatase 2A 56 kDa...,548,565,KRTVETEAVQMLKDIKKE,18,Sequence Analysis,ECO:0000255
1,Q13362,MLTCNKAGSRMVVDAANSNGPFQPVVLLHIRDVPPADQEKLFIQKL...,Serine/threonine-protein phosphatase 2A 56 kDa...,416,422,KLKEKLK,7,Sequence Analysis,ECO:0000255
2,Q9NRA8,MDRRSMGETESGDAFLDLKKPPASKCPHRYTKEELLDIKELPHSKQ...,Eukaryotic translation initiation factor 4E tr...,195,211,RREFGDSKRVFGERRRN,17,,
3,P42684,MGQQVGRVGEAPGLQQPQPRGIRGSSAARPSGRRRDPAGRTTETGF...,Abelson tyrosine-protein kinase 2,658,660,KKR,3,Sequence Analysis,ECO:0000255
4,Q4JIM5,MGQQVGRVGEAPGLQQPQPRGIRGSSAARPSGRRRDPAGRTADAGF...,Abelson tyrosine-protein kinase 2,659,661,KKR,3,Sequence Analysis,ECO:0000255
...,...,...,...,...,...,...,...,...,...
1358,Q96CK0,MAERALEPEAEAEAEAGAGGEAAAEEGAAGRKARGRPRLTESDRAR...,Zinc finger protein 653,107,118,PKKPKRKKRRRR,12,Sequence Analysis,ECO:0000255
1359,Q96CK0,MAERALEPEAEAEAEAGAGGEAAAEEGAAGRKARGRPRLTESDRAR...,Zinc finger protein 653,445,451,EPEKRRR,7,Sequence Analysis,ECO:0000255
1360,Q24JY4,MVEKKTSVRSQDPGQRRVLDRAARQRRINRQLEALENDNFQDDPHA...,Zinc finger HIT domain-containing protein 1,38,47,DNFQDDPHAG,10,By similarity,ECO:0000250
1361,O43257,MVEKKTSVRSQDPGQRRVLDRAARQRRINRQLEALENDNFQDDPHA...,Zinc finger HIT domain-containing protein 1,38,47,DNFQDDPHAG,10,,


In [36]:
# Tokenize and encode sequences
# Maximum length selected based on original paper describing the ProtBERT model parameters
#In order to be stacked, input IDs and attention masks must first be in list format
input_ids_list = []
attention_masks_list = []

# Tokenize and encode sequences
for sequence in df['Sequence_nls']:
    encoded_seq = tokenizer(sequence, padding='max_length', truncation=True, max_length=2048, return_tensors='pt')
    input_ids_list.append(encoded_seq['input_ids'])
    attention_masks_list.append(encoded_seq['attention_mask'])

# Convert input IDs and attention masks to tensors
input_ids = torch.cat(input_ids_list, dim=0)
attention_masks = torch.cat(attention_masks_list, dim=0)



In [37]:
# Initialize motif features tensor, zeros as default
max_sequence_length = max(len(seq) for seq in df['Sequence_nls'])
motif_features = torch.zeros(len(df), max_sequence_length)  # Initialize motif features tensor, inflate everything to 2,048 zeros for each tensor

In [38]:
# Initialize start and end positions tensors, look for maximum length sequence here, populate tensors with zeros
max_sequence_length = max(len(seq) for seq in df['Sequence_nls'])
start_positions = torch.zeros(len(df), max_sequence_length, dtype=torch.long)
end_positions = torch.zeros(len(df), max_sequence_length, dtype=torch.long)

In [39]:
# Populate motif features tensor with ones representing the location of NLS sequences
for idx, row in df.iterrows():
    start = row['Begin'] #row for each protein
    end = row['End']
    motif_features[idx, start:end+1] = 1  # Set the range of positions corresponding to the motif (NLS in this case) to 1

In [40]:
# Labeling (since all these sequences contain NLS sequences, they'll be labelled as such for now
labels = torch.ones(len(df), dtype=torch.long)  # Assume all sequences have NLS motifs, so label them as 1


In [50]:

# Split dataset into train and test sets
X_train, X_test, y_train, y_test = train_test_split(input_ids, labels, test_size=0.2, random_state=831)


In [55]:
# DataLoaders for training and test datasets
batch_size = 8
train_dataset = TensorDataset(X_train, y_train)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = TensorDataset(X_test, y_test)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

In [58]:
# Model initialization, optimizer learning rate defined, and cross-entropy loss function
# Number of classes set to two for straightforward binary classification
model = BERT_EE(input_dim=input_ids.size(1), motif_vocab_size=max_sequence_length*2, hidden_dim=128, num_classes=2)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
criterion = nn.CrossEntropyLoss()

Downloading:   0%|          | 0.00/1.57G [00:00<?, ?B/s]

Some weights of the model checkpoint at Rostlab/prot_bert were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
# Training loop
epochs = 5
for epoch in range(epochs):
    model.train()
    total_loss = 0
    for batch in train_dataloader:
        input_ids, labels = batch
        optimizer.zero_grad()
        logits = model(input_ids, attention_masks, motif_features)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    avg_train_loss = total_loss / len(train_dataloader)

    print(f"Epoch {epoch + 1}/{epochs}:")
    print(f"  Train Loss: {avg_train_loss:.4f}")

In [None]:
# Test loop
model.eval()
test_preds = []
test_labels = []
with torch.no_grad():
    for batch in test_dataloader:
        input_ids, labels = batch
        logits = model(input_ids)
        _, predicted = torch.max(logits, dim=1)
        test_preds.extend(predicted.tolist())
        test_labels.extend(labels.tolist())

# Calculate test metrics
test_accuracy = accuracy_score(test_labels, test_preds)
test_precision = precision_score(test_labels, test_preds)
test_recall = recall_score(test_labels, test_preds)

print("Test Metrics:")
print(f"  Test Accuracy: {test_accuracy:.4f}")
print(f"  Test Precision: {test_precision:.4f}")
print(f"  Test Recall: {test_recall:.4f}")