In [None]:
import pandas as pd
from Bio import SeqIO
from transformers import BertTokenizer, BertForTokenClassification
from torch.utils.data import DataLoader, Dataset
import torch


# Spreadsheet with one row per annotated binding site; the columns read below
# are 'Identifier', 'Start Index' and 'End Index'.
binding_sites_df = pd.read_excel("binding_sites_info.xlsx")


# Map each sequence identifier to every residue index covered by its binding
# site(s). Ranges are accumulated (not overwritten) so an identifier that
# appears on several rows keeps all of its sites; End Index is inclusive.
binding_sites_info = {}
site_columns = binding_sites_df[['Identifier', 'Start Index', 'End Index']]
for identifier, start, end in site_columns.itertuples(index=False, name=None):
    # int() keeps range() working when Excel columns were parsed as floats.
    binding_sites_info.setdefault(identifier, []).extend(range(int(start), int(end) + 1))


# ProtBert's vocabulary is upper-case amino-acid letters, hence do_lower_case=False.
tokenizer = BertTokenizer.from_pretrained('Rostlab/prot_bert', do_lower_case=False)

def preprocess_and_align_labels(fasta_path, binding_sites_info, max_length=512, index_base=1):
    """Tokenize every FASTA record for ProtBert and build per-token binding-site labels.

    Args:
        fasta_path: Path of the FASTA file handed to ``SeqIO.parse``.
        binding_sites_info: Mapping of record id -> iterable of residue indices
            that belong to a binding site.
        max_length: Fixed number of tokens each sequence is padded/truncated
            to, including the [CLS] and [SEP] tokens.
        index_base: Index that ``binding_sites_info`` uses for the first
            residue. With the default of 1, residue index ``i`` is labelled on
            token ``i`` (token 0 is [CLS]); pass 0 for 0-based indices.
            NOTE(review): confirm which convention the spreadsheet uses.

    Returns:
        ``(tokenized_inputs, labels)``: two parallel lists with one entry per
        record. Each tokenized input is the tokenizer output with tensors of
        shape ``(1, max_length)``; each label entry is a 1-D long tensor of
        length ``max_length`` holding 1 for binding-site residues, 0 for other
        residues and -100 for [CLS], [SEP] and padding so the loss ignores them.
    """
    # ProtBert expects rare amino acids (U, Z, O, B) to be mapped to X.
    rare_to_x = str.maketrans("UZOB", "XXXX")

    tokenized_inputs = []
    labels = []

    for record in SeqIO.parse(fasta_path, "fasta"):
        seq_id = record.id
        sequence = str(record.seq).translate(rare_to_x)

        # ProtBert's tokenizer splits on whitespace, so residues must be
        # space-separated; an unspaced sequence becomes a single unknown token.
        encoded_input = tokenizer(
            " ".join(sequence),
            truncation=True,
            padding='max_length',
            max_length=max_length,
            return_tensors="pt",
        )

        attention_mask = encoded_input['attention_mask'][0]
        # Non-padding tokens are [CLS] + residues + [SEP] (padding is appended
        # on the right by default), so residues occupy 1 .. n_tokens - 2.
        n_tokens = int(attention_mask.sum())

        # Start with everything ignored (-100), then mark real residues as 'O' (0).
        label_ids = torch.full((attention_mask.shape[0],), -100, dtype=torch.long)
        label_ids[1:n_tokens - 1] = 0

        for index in binding_sites_info.get(seq_id, ()):
            token_pos = index - index_base + 1  # shift past the leading [CLS]
            # Skip indices outside the (possibly truncated) residue span; this
            # also stops negative indices from wrapping around.
            if 1 <= token_pos < n_tokens - 1:
                label_ids[token_pos] = 1  # 'B' for Binding site

        tokenized_inputs.append(encoded_input)
        labels.append(label_ids)

    return tokenized_inputs, labels

# Tokenize every record in the FASTA file and build the matching per-token labels.
fasta_path = "DNA_Gyrase_UniProt_Cleaned.fasta"
tokenized_inputs, labels = preprocess_and_align_labels(fasta_path, binding_sites_info)


In [None]:
class ProteinDataset(Dataset):
    """Dataset pairing each tokenized sequence with its label tensor.

    ``tokenized_inputs[i]`` is a mapping of field name -> tensor whose first
    axis is indexed with 0 on access; ``labels[i]`` is returned unchanged under
    the ``'labels'`` key.
    """

    def __init__(self, tokenized_inputs, labels):
        self.tokenized_inputs = tokenized_inputs
        self.labels = labels

    def __len__(self):
        # One sample per tokenized sequence.
        return len(self.tokenized_inputs)

    def __getitem__(self, idx):
        encoding = self.tokenized_inputs[idx]
        sample = {}
        for field, tensor in encoding.items():
            # Take the first entry along the leading axis of every field.
            sample[field] = tensor[0]
        sample['labels'] = self.labels[idx]
        return sample

# Wrap the tokenized inputs and labels so a DataLoader can batch them.
dataset = ProteinDataset(tokenized_inputs, labels)


In [None]:
from transformers import BertForTokenClassification
# transformers.AdamW is deprecated in favour of the PyTorch implementation.
from torch.optim import AdamW

# ProtBert encoder with a token-classification head.
model = BertForTokenClassification.from_pretrained('Rostlab/prot_bert', num_labels=2)  # 2 labels: 'O' and 'B'


# Train on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


# weight_decay=0.0 and eps=1e-6 mirror the defaults of the deprecated
# transformers.AdamW so the optimisation behaves as it did before the swap.
optimizer = AdamW(model.parameters(), lr=5e-5, weight_decay=0.0, eps=1e-6)


data_loader = DataLoader(dataset, batch_size=4, shuffle=True)

# Single pass over the data: one optimizer step per batch.
model.train()
for batch in data_loader:
    batch = {k: v.to(device) for k, v in batch.items()}
    outputs = model(**batch)  # the batch includes 'labels', so the model returns a loss
    loss = outputs.loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

# Raw string: in a normal string literal "\U" and "\N" start unicode escapes,
# which makes this Windows path a SyntaxError.
model_path = r"C:\Users\Asus\Documents\University Of Florida\NLP\protbert model"
model.save_pretrained(model_path)
tokenizer.save_pretrained(model_path)
