In [1]:
import os

# huggingface_hub reads HF_HUB_DISABLE_SYMLINKS_WARNING into a module-level
# constant when it is first imported, so the variable has to be set *before*
# transformers/datasets (which import huggingface_hub) are imported.
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"

import torch
from datasets import load_dataset
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments


# do_lower_case=False mirrors the usage shown on the ProtBert model card:
# amino-acid tokens are upper-case letters and must not be lower-cased.
tokenizer = BertTokenizer.from_pretrained("Rostlab/prot_bert_bfd", do_lower_case=False)
dataset = load_dataset('GleghornLab/enzyme_kcat')

# Quick sanity check of the schema the later cells rely on ('seqs', 'labels').
print(dataset['train'].column_names)


['seqs', 'labels']


In [2]:
max_length = 512

# Per the ProtBert model card, the rare residues U, Z, O and B are mapped to X
# before tokenization.
RARE_RESIDUE_TABLE = str.maketrans("UZOB", "XXXX")


def prepare_protein_sequence(sequence):
    """Return `sequence` in the form the ProtBert tokenizer expects.

    ProtBert uses a per-residue vocabulary, so (per its model card) every
    amino acid must be its own whitespace-separated token. An unspaced string
    such as "MKT" would otherwise be treated as a single unknown word.

    The function is idempotent: existing whitespace is removed first, so
    already-spaced input ("M K T") is returned unchanged.

    Args:
        sequence: Amino-acid sequence as a string, spaced or unspaced.

    Returns:
        The residues joined by single spaces, with U/Z/O/B replaced by X.
    """
    residues = "".join(sequence.split())
    return " ".join(residues.translate(RARE_RESIDUE_TABLE))


def tokenize_function(examples):
    """Tokenize the 'seqs' column of a dataset example or batch.

    Args:
        examples: Mapping with a 'seqs' entry holding either one sequence
            string or a list of them (the batched form used by `dataset.map`).

    Returns:
        Whatever the module-level `tokenizer` returns for the prepared
        sequences, padded/truncated to `max_length`.
    """
    seqs = examples['seqs']
    if isinstance(seqs, str):
        # Non-batched call: a bare string must not be iterated per character.
        prepared = prepare_protein_sequence(seqs)
    else:
        prepared = [prepare_protein_sequence(seq) for seq in seqs]
    return tokenizer(prepared, padding="max_length", truncation=True, max_length=max_length)

# Tokenize every split in batches and expose only the tensors the model consumes.
tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

if 'valid' in tokenized_datasets:
    train_dataset = tokenized_datasets['train']
    eval_dataset = tokenized_datasets['valid']
else:
    # No dedicated validation split: carve one out of 'train' and train only on
    # the remainder, so the evaluation rows are never seen during training.
    # The seed keeps the split reproducible across kernel restarts.
    holdout_split = tokenized_datasets['train'].train_test_split(test_size=0.1, seed=42)
    train_dataset = holdout_split['train']
    eval_dataset = holdout_split['test']

In [3]:
# Load the pretrained encoder with a freshly initialised classification head
# that has a single output unit (num_labels=1).
model = BertForSequenceClassification.from_pretrained(
    "Rostlab/prot_bert_bfd",
    num_labels=1,
)

def compute_loss(model, inputs, labels):
    """Mean-squared-error loss between the model's logits and `labels`.

    NOTE(review): in the cells visible here this function is defined but not
    passed to the Trainer, so training does not call it.

    Args:
        model: Callable invoked as `model(**inputs)`; its result must expose
            a `.logits` tensor (assumed shape (batch, 1) for a single-output
            head — TODO confirm if used with another model).
        inputs: Mapping of keyword arguments forwarded to `model`.
        labels: Tensor of regression targets, one per example.

    Returns:
        A scalar tensor holding the MSE loss.
    """
    outputs = model(**inputs)
    # Drop only the trailing label dimension. A bare .squeeze() would also
    # remove the batch dimension for a batch of one, yielding a 0-d tensor
    # that only matches the labels through broadcasting.
    predictions = outputs.logits.squeeze(-1)
    # Align dtype and shape so integer or (batch, 1)-shaped labels cannot
    # trigger a dtype error or a silent (batch, batch) broadcast.
    targets = labels.to(dtype=predictions.dtype).reshape(predictions.shape)
    return torch.nn.functional.mse_loss(predictions, targets)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at Rostlab/prot_bert_bfd and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
# Hyperparameters and bookkeeping for the fine-tuning run.
training_args = TrainingArguments(
    # Where checkpoints and logs are written.
    output_dir="./results",
    logging_dir="./logs",
    logging_steps=10,
    # Schedule: three passes over the data, evaluating after each one.
    num_train_epochs=3,
    eval_strategy="epoch",
    # Batch sizes per device.
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    # Optimisation settings.
    warmup_steps=500,
    weight_decay=0.01,
)

# Bundle the model, the run configuration and both dataset splits into a Trainer.
trainer = Trainer(
    args=training_args,
    model=model,
    tokenizer=tokenizer,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

In [None]:
# Run the fine-tuning loop using the Trainer built in the previous cell.
trainer.train()
# Evaluate once more after training and keep whatever evaluate() returns
# so it can be inspected in later cells.
results = trainer.evaluate()
print("Evaluation Results:", results)