# Load model & tokenizer

In [1]:
import os
import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

# Select visible gpus
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3'
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

MODEL_NAME = 'Rostlab/prot_t5_xl_uniref50'
DEVICE = 'cuda:0'
OUTPUT_DIR = f'{MODEL_NAME.split("/")[1]}_finetune'
TRAIN_EPOCHS = 10
SEED = 42
RESUME_FROM_CHECKPOINT=False
CHECKPOINT_STEP = '19082'

pretrained_model_name = os.path.join(OUTPUT_DIR, 'checkpoint-' + CHECKPOINT_STEP) \
    if RESUME_FROM_CHECKPOINT \
    else MODEL_NAME

model = T5ForConditionalGeneration.from_pretrained(pretrained_model_name).to(DEVICE)

model.parallelize()

tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME, do_lower_case=False)
tokenizer.add_tokens(['0', '1']) # Add tokens for labels

model.resize_token_embeddings(len(tokenizer))

print('output_dir:', OUTPUT_DIR)
print('CUDA_VISIBLE_DEVICES:', os.environ['CUDA_VISIBLE_DEVICES'])
print('model:', model.device)

output_dir: prot_t5_xl_uniref50_finetune
CUDA_VISIBLE_DEVICES: 0,1,2,3
model: cuda:0


# Define dataset

In [2]:
import datasets
from datasets import load_dataset
datasets.logging.set_verbosity_error()

raw_train_dataset = load_dataset(
    'csv', data_files='train.csv'
)['train'].shuffle(seed=SEED).train_test_split(test_size=0.1)

raw_test_dataset = load_dataset(
    'csv', data_files={'test': 'test.csv'},
    split='test',
)

print(raw_train_dataset)
print(raw_test_dataset)

  0%|          | 0/1 [00:00<?, ?it/s]

DatasetDict({
    train: Dataset({
        features: ['id', 'epitope_seq', 'antigen_seq', 'antigen_code', 'start_position', 'end_position', 'number_of_tested', 'number_of_responses', 'assay_method_technique', 'assay_group', 'disease_type', 'disease_state', 'reference_date', 'reference_journal', 'reference_title', 'reference_IRI', 'qualitative_label', 'label'],
        num_rows: 171729
    })
    test: Dataset({
        features: ['id', 'epitope_seq', 'antigen_seq', 'antigen_code', 'start_position', 'end_position', 'number_of_tested', 'number_of_responses', 'assay_method_technique', 'assay_group', 'disease_type', 'disease_state', 'reference_date', 'reference_journal', 'reference_title', 'reference_IRI', 'qualitative_label', 'label'],
        num_rows: 19082
    })
})
Dataset({
    features: ['id', 'epitope_seq', 'antigen_seq', 'antigen_code', 'start_position', 'end_position', 'number_of_tested', 'number_of_responses', 'assay_method_technique', 'assay_group', 'disease_type', 'disease_sta

# Dataset Preprocessing

In [3]:
import datasets
datasets.logging.disable_progress_bar()

def dataset_preproc(dataset: datasets.Dataset, num_proc=4):
    if 'label' in dataset.column_names:
        dataset = dataset.map(
            lambda x: tokenizer(str(x['label'])),
            num_proc=num_proc,
        )
        dataset = dataset.rename_column('input_ids', 'labels')
        # dataset = dataset.rename_column('attention_mask', 'decoder_attention_mask')
        # dataset = dataset.add_column('decoder_input_ids', column=dataset['labels'])
    dataset = dataset.map(
        lambda x: tokenizer(' '.join(x['epitope_seq'])),
        num_proc=num_proc,
    )
    dataset = dataset.remove_columns(list(set(dataset.column_names) - set(['input_ids', 'labels'])))
    return dataset

train_dataset = dataset_preproc(raw_train_dataset['train'], num_proc=16)
print(train_dataset)

valid_dataset = dataset_preproc(raw_train_dataset['test'], num_proc=16)
print(valid_dataset)

test_dataset = dataset_preproc(raw_test_dataset, num_proc=16)
print(test_dataset)

Dataset({
    features: ['labels', 'input_ids'],
    num_rows: 171729
})
Dataset({
    features: ['labels', 'input_ids'],
    num_rows: 19082
})
Dataset({
    features: ['input_ids'],
    num_rows: 120944
})


In [4]:
[tokenizer.decode(train_dataset[0][k]) for k in ['labels', 'input_ids']]

['0 </s>', 'V S G K E E M E R S S E E E G</s>']

# Define trainer & Train

In [5]:
import os
import torch
import gc

from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
from transformers import DataCollatorForSeq2Seq
from sklearn.metrics import f1_score

## Wipe memory
gc.collect()
torch.cuda.empty_cache()

training_args = Seq2SeqTrainingArguments(
    output_dir=OUTPUT_DIR,
    num_train_epochs=1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    evaluation_strategy='epoch',
    report_to=None,
    logging_strategy='epoch',
    log_level='warning',
    save_strategy='epoch',
    save_total_limit=2,
    metric_for_best_model='f1-score',
    load_best_model_at_end=True,
    # push_to_hub=True,
    hub_private_repo=True,
    auto_find_batch_size=True,
)

data_collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    label_pad_token_id=tokenizer.pad_token_id
)

print('output_dir:', OUTPUT_DIR)
print('CUDA_VISIBLE_DEVICES:', os.environ['CUDA_VISIBLE_DEVICES'])
print('model:', model.device)
print('trainer:', training_args.device)


def f1_score_metrics(eval_pred):
    y_pred = eval_pred.predictions[0].argmax(-1)[:, 0]
    y_true = eval_pred.label_ids[:, 0]
    score = f1_score(y_true, y_pred, average='macro')
    return {'f1-score': score}

trainer = Seq2SeqTrainer(
    model,
    training_args,
    data_collator=data_collator,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    compute_metrics=f1_score_metrics,
)

train_output = trainer.train(resume_from_checkpoint=RESUME_FROM_CHECKPOINT)

output_dir: prot_t5_xl_uniref50_finetune
CUDA_VISIBLE_DEVICES: 0,1,2,3
model: cuda:0
trainer: cuda:0




Epoch,Training Loss,Validation Loss,F1-score
1,4.5041,0.304958,0.428571


TrainOutput(global_step=11, training_loss=4.504134438254616, metrics={'train_runtime': 169.3498, 'train_samples_per_second': 1.016, 'train_steps_per_second': 0.065, 'total_flos': 53442499215360.0, 'train_loss': 4.504134438254616, 'epoch': 1.0})

In [None]:
trainer.push_to_hub()

# Model test

In [None]:
test_output = trainer.predict(test_dataset)