In [None]:
# Import the necessary libraries
from Bio import SeqIO
from transformers import BertTokenizerFast, BertConfig, BertForMaskedLM, DataCollatorForLanguageModeling, Trainer, TrainingArguments
import torch
from torch.utils.data import Dataset

In [None]:
# Report whether PyTorch can see a CUDA device, and how many
has_cuda = torch.cuda.is_available()
print("CUDA available:", has_cuda)
gpu_count = torch.cuda.device_count()
print("Number of GPUs:", gpu_count)

In [None]:
# Load every record of the FASTA file as a plain string of residue letters
fasta_file = "./data/uniprot_sprot.fasta"
sequences = [str(record.seq) for record in SeqIO.parse(fasta_file, "fasta")]

# Stop here with a clear message if nothing was parsed; otherwise the
# sequences[0] lookup below fails with an unexplained IndexError.
if not sequences:
    raise ValueError(f"No FASTA records found in {fasta_file!r}")

print(f"Loaded {len(sequences)} sequences")
print("Example:", sequences[0])

In [None]:
# Define the amino acid vocabulary and special tokens for BERT.
# 20 standard residues, the rare residues U (selenocysteine) and O (pyrrolysine),
# and the ambiguity codes B (D or N), Z (E or Q) and X (any residue).
amino_acids = "ACDEFGHIKLMNPQRSTVWYUOBZX"
special_tokens = ["[UNK]", "[PAD]", "[CLS]", "[SEP]", "[MASK]"]  # special tokens for BERT
vocab_dict = {token: idx for idx, token in enumerate(list(amino_acids) + special_tokens)}

# Longest tokenised length, counting [CLS] and [SEP]. It must not exceed the
# model's max_position_embeddings (512 in the BertConfig cell), otherwise the
# position embedding lookup goes out of range during training.
max_seq_length = 512

# Initialize the BERT tokenizer with the custom vocabulary.
# do_lower_case=False keeps residues upper-case so they match the vocabulary;
# lower-cased input would never match it.
# NOTE(review): whether `vocab=` accepts an in-memory dict depends on the
# installed transformers version (some releases expect a `vocab_file` path) — confirm.
tokeniser = BertTokenizerFast(
    vocab=vocab_dict,
    do_lower_case=False,
    unk_token="[UNK]",
    pad_token="[PAD]",
    cls_token="[CLS]",
    sep_token="[SEP]",
    mask_token="[MASK]",
)

# The BERT pre-tokeniser splits on whitespace, and the vocabulary has no "##"
# continuation pieces, so an unspaced protein string is a single unknown word
# and becomes [UNK]. Separating residues with spaces yields one token each.
spaced_sequences = [" ".join(seq) for seq in sequences]

# Tokenise the sequences. truncation=True needs an explicit max_length here
# because this tokenizer was built without a model max length.
encodings = tokeniser(
    spaced_sequences,
    return_tensors="pt",
    truncation=True,
    padding=True,
    max_length=max_seq_length,
)
print(encodings["input_ids"].shape)

In [None]:
# Dataset wrapper exposing the tokenised sequences one example at a time
class ProteinDataset(Dataset):
    """Map-style dataset over a mapping of field name -> batched values.

    The number of examples is the first dimension of the "input_ids" entry.
    Indexing returns a dict with the same keys as the wrapped mapping, each
    holding that field's entry at the requested index.
    """

    def __init__(self, encodings):
        self.encodings = encodings

    def __len__(self):
        input_ids = self.encodings["input_ids"]
        return input_ids.shape[0]

    def __getitem__(self, idx):
        sample = {}
        for field, values in self.encodings.items():
            sample[field] = values[idx]
        return sample

# Wrap the tokenised batch so the Trainer can index individual examples
dataset = ProteinDataset(encodings=encodings)

In [None]:
# Hyper-parameters of the BERT encoder; vocabulary size and padding id are
# taken from the tokeniser so the two stay in agreement
bert_hparams = {
    "vocab_size": len(tokeniser),
    "hidden_size": 256,
    "num_hidden_layers": 4,
    "num_attention_heads": 4,
    "max_position_embeddings": 512,
    "pad_token_id": tokeniser.pad_token_id,
}
config = BertConfig(**bert_hparams)

# Build the masked-LM model from the configuration alone (no checkpoint is loaded)
model = BertForMaskedLM(config)

In [None]:
# Collator for masked language modelling; mlm_probability is set to 0.15
mask_fraction = 0.15
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokeniser,
    mlm=True,
    mlm_probability=mask_fraction,
)

In [None]:
# Training hyper-parameters plus checkpointing and logging cadence
training_options = {
    "output_dir": "./mlm_bert",
    "num_train_epochs": 1,
    "per_device_train_batch_size": 8,
    "save_steps": 500,
    "save_total_limit": 2,
    "logging_steps": 100,
    "learning_rate": 5e-5,
    "weight_decay": 0.01,
    "remove_unused_columns": False,
}
training_args = TrainingArguments(**training_options)

# Bundle model, arguments, dataset and collator into a Trainer
trainer_inputs = {
    "model": model,
    "args": training_args,
    "train_dataset": dataset,
    "data_collator": data_collator,
}
trainer = Trainer(**trainer_inputs)

# Run the training loop
trainer.train()

In [None]:
# Write the model and the tokeniser to the same output directory
save_dir = "./mlm_bert"
for artefact in (model, tokeniser):
    artefact.save_pretrained(save_dir)