In [1]:
import os

# huggingface_hub reads this flag when it is first imported, so it must be set
# before `transformers` / `datasets` pull that package in. Setting it after
# those imports does not silence the symlink-cache warning.
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"

import torch
from transformers import EsmTokenizer, EsmForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset

# One definition of the checkpoint id so the tokenizer and the model
# cannot accidentally be loaded from different checkpoints.
MODEL_CHECKPOINT = "facebook/esm2_t6_8M_UR50D"

dataset = load_dataset("GleghornLab/enzyme_kcat")
tokenizer = EsmTokenizer.from_pretrained(MODEL_CHECKPOINT)
# num_labels=1 requests a single-output head on top of the pretrained encoder;
# that head is newly initialised and has to be trained before use.
model = EsmForSequenceClassification.from_pretrained(MODEL_CHECKPOINT, num_labels=1)

# Quick sanity check of the column names the later cells rely on.
print(dataset['train'].column_names)




Some weights of EsmForSequenceClassification were not initialized from the model checkpoint at facebook/esm2_t6_8M_UR50D and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


['seqs', 'labels']


In [2]:
def tokenize_function(examples):
    """Tokenize a batch of sequences to a fixed length of 128 tokens.

    Reads the ``'seqs'`` entry of ``examples`` and passes it to the
    notebook-level ``tokenizer``; longer inputs are truncated and shorter
    ones are padded up to ``max_length``. Returns whatever the tokenizer
    returns for that batch.
    """
    sequences = examples['seqs']
    return tokenizer(
        sequences,
        max_length=128,
        truncation=True,
        padding="max_length",
    )

# Tokenize every split in batches, then expose only the listed columns
# as torch tensors.
tokenized_datasets = dataset.map(tokenize_function, batched=True)
tokenized_datasets.set_format(
    type='torch',
    columns=['input_ids', 'attention_mask', 'labels'],
)

# Splits consumed by the training cell.
train_dataset, eval_dataset = tokenized_datasets['train'], tokenized_datasets['valid']

Map:   0%|          | 0/6837 [00:00<?, ? examples/s]

Map:   0%|          | 0/498 [00:00<?, ? examples/s]

Map:   0%|          | 0/469 [00:00<?, ? examples/s]

In [3]:
# Resolve CUDA availability once: both device placement and the
# mixed-precision flag below depend on it.
cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if cuda_available else "cpu")
model = model.to(device)

training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,  # 4 x 4 = 16 samples per optimizer step per device
    per_device_eval_batch_size=4,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=10,
    eval_strategy="epoch",
    # fp16 mixed precision requires a GPU. Hard-coding True makes
    # TrainingArguments raise on a CPU-only machine, which defeats the
    # CPU fallback chosen for `device` above.
    fp16=cuda_available,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer
)

# Fine-tune, then save the resulting model to ./saved_model.
trainer.train()
trainer.save_model("./saved_model")

Epoch,Training Loss,Validation Loss
0,21.5855,23.969139
2,15.4503,19.254353


In [4]:
# Run evaluation on the trainer's eval dataset and list each returned metric.
eval_results = trainer.evaluate()

print("Evaluation results:")
for metric_name, metric_value in eval_results.items():
    print(f"{metric_name}: {metric_value}")

Evaluation results:
eval_loss: 19.254352569580078
eval_runtime: 9.8647
eval_samples_per_second: 50.483
eval_steps_per_second: 12.671
epoch: 2.9964912280701754
