In [22]:
import pandas as pd
from Bio import SeqIO
from transformers import BertTokenizer
from torch.utils.data import DataLoader, Dataset, random_split
import torch

class ProteinDataset(Dataset):
    def __init__(self, tokenized_inputs, labels):
        self.tokenized_inputs = tokenized_inputs
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        item = {key: val.squeeze() for key, val in self.tokenized_inputs[idx].items()}
        item['labels'] = self.labels[idx]
        return item


binding_sites_df = pd.read_excel("binding_sites_info.xlsx")


binding_sites_info = {}
for _, row in binding_sites_df.iterrows():
    binding_sites_info[row['Identifier']] = list(range(row['Start Index'], row['End Index'] + 1))

tokenizer = BertTokenizer.from_pretrained('Rostlab/prot_bert', do_lower_case=False)

def preprocess_and_align_labels(fasta_path, binding_sites_info, tokenizer):
    tokenized_inputs = []
    labels = []
    for record in SeqIO.parse(fasta_path, "fasta"):
        seq_id = record.id
        sequence = str(record.seq)
        encoded_input = tokenizer(sequence, truncation=True, padding='max_length', max_length=512, return_tensors="pt")
        seq_labels = ['O'] * len(encoded_input['input_ids'][0])
        if seq_id in binding_sites_info:
            for index in binding_sites_info[seq_id]:
                if index < len(seq_labels):
                    seq_labels[index] = 'B'
        label_ids = [1 if label == 'B' else 0 for label in seq_labels]
        tokenized_inputs.append(encoded_input)
        labels.append(torch.tensor(label_ids, dtype=torch.long))
    return tokenized_inputs, labels

# Load data and labels
fasta_path = "DNA_Gyrase_UniProt_Cleaned.fasta"
tokenizer = BertTokenizer.from_pretrained('Rostlab/prot_bert', do_lower_case=False)
tokenized_inputs, labels = preprocess_and_align_labels(fasta_path, binding_sites_info, tokenizer)

# Create Dataset
dataset = ProteinDataset(tokenized_inputs, labels)

# Splitting the dataset
train_size = int(0.7 * len(dataset))
valid_size = int(0.15 * len(dataset))
test_size = len(dataset) - train_size - valid_size

train_dataset, valid_dataset, test_dataset = random_split(dataset, [train_size, valid_size, test_size])

# Creating DataLoader for each dataset
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
valid_loader = DataLoader(valid_dataset, batch_size=4, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=4, shuffle=False)


In [23]:
from transformers import BertForTokenClassification, AdamW

model = BertForTokenClassification.from_pretrained('Rostlab/prot_bert', num_labels=2)  # 2 labels: 'O' and 'B'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model.to(device)

Some weights of BertForTokenClassification were not initialized from the model checkpoint at Rostlab/prot_bert and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


cuda


BertForTokenClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30, 1024, padding_idx=0)
      (position_embeddings): Embedding(40000, 1024)
      (token_type_embeddings): Embedding(2, 1024)
      (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-29): 30 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (LayerNorm): LayerNorm((1024,), e

In [None]:
from transformers import AdamW
import torch

# Assuming model, train_loader, and device are already defined and properly initialized

optimizer = AdamW(model.parameters(), lr=5e-5)
model.train()
num_epochs = 3  # Define the number of epochs

def calculate_accuracy(logits, labels):
    predictions = torch.argmax(logits, dim=-1)
    correct_predictions = torch.eq(predictions, labels).float()
    return correct_predictions.mean()

for epoch in range(num_epochs):
    total_loss, total_accuracy = 0.0, 0.0
    num_batches = 0
    for batch in train_loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        logits = outputs.logits
        accuracy = calculate_accuracy(logits, batch['labels'])

        total_loss += loss.item()
        total_accuracy += accuracy.item()
        num_batches += 1
    
    epoch_loss = total_loss / num_batches
    epoch_accuracy = total_accuracy / num_batches
    print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {epoch_accuracy:.4f}')

model.save_pretrained('protbert_trained')
tokenizer.save_pretrained('protbert_tokenizer')




Epoch 1/3, Loss: 0.0049, Accuracy: 1.0000
Epoch 2/3, Loss: 0.0003, Accuracy: 1.0000


In [None]:
def evaluate(model, valid_loader):
    model.eval()
    total_loss, total_accuracy = 0, 0
    for batch in valid_loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        total_loss += loss.item()
        logits = outputs.logits
        accuracy = calculate_accuracy(logits, batch['labels'])
        predictions = torch.argmax(logits, dim=-1)
        labels = batch['labels']
        total_accuracy += (predictions == labels).float().mean()
        
    return total_loss / len(valid_loader), total_accuracy / len(valid_loader)

val_loss, val_accuracy = evaluate(model, valid_loader)
print(f"Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}")