In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from transformers import EarlyStoppingCallback, AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
from evaluate import load
import numpy as np
import pandas as pd
from datasets import Dataset
from transformers import DataCollatorWithPadding
import os
from accelerate import Accelerator
from peft import get_peft_config, PeftModel, PeftConfig, get_peft_model, LoraConfig, TaskType, prepare_model_for_kbit_training
from data_prepare import make_folds

In [3]:
from HF_utils import ClearMLCallback

In [5]:
# NOTE(review): CUDA_LAUNCH_BLOCKING=1 is normally a debugging switch (synchronous
# kernel launches, so errors surface at the failing call). It is expected to slow
# training — confirm it is still wanted for real runs.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
# Explicitly opt in to parallelism inside the Hugging Face fast tokenizers.
os.environ["TOKENIZERS_PARALLELISM"] = "true"

In [6]:
accelerator = Accelerator()

In [24]:
def lora_init(model_checkpoint, num_labels=2, modules_to_save=("classification_head",)):
    """Load a sequence-classification model and wrap it with LoRA adapters.

    Args:
        model_checkpoint: Hub id or local path handed to ``from_pretrained``.
        num_labels: Number of output classes for the classification head.
        modules_to_save: Names of non-LoRA modules that must remain trainable and
            be stored with the adapter. Defaults to ``classification_head``, the
            module this checkpoint reports as newly initialized. Pass ``None``
            to fall back to PEFT's built-in defaults.

    Returns:
        The PEFT-wrapped model after being passed through the notebook-level
        ``accelerator.prepare``.
    """
    base_model = AutoModelForSequenceClassification.from_pretrained(
        model_checkpoint, num_labels=num_labels
    )
    config = LoraConfig(
        task_type=TaskType.SEQ_CLS,
        r=16,
        lora_alpha=16,
        target_modules=["q", "k", "v"],  # other candidates: "wi_0", "wi_1", "o"
        lora_dropout=0.1,
        bias="all",  # alternative: "lora_only"
        # The head is freshly (randomly) initialized. Per the PEFT docs, the SEQ_CLS
        # task type only auto-unfreezes heads named like "classifier"/"score", so a
        # head called "classification_head" must be listed explicitly or its weights
        # would stay frozen at their random values.
        modules_to_save=list(modules_to_save) if modules_to_save else None,
    )
    lora_model = get_peft_model(base_model, config)
    return accelerator.prepare(lora_model)

In [18]:
model_checkpoint = "ElnaggarLab/ankh-base"

In [8]:
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)



In [9]:
# NOTE(review): this full (non-LoRA) model is overwritten shortly afterwards by
# `model = lora_init(model_checkpoint)`. It only appears to be needed for the
# commented-out full-parameter count further down — consider dropping this load
# to avoid holding two copies of the weights.
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=2)

Some weights of T5ForSequenceClassification were not initialized from the model checkpoint at ElnaggarLab/ankh-base and are newly initialized: ['classification_head.dense.bias', 'classification_head.dense.weight', 'classification_head.out_proj.bias', 'classification_head.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [25]:
model = lora_init(model_checkpoint)

Some weights of T5ForSequenceClassification were not initialized from the model checkpoint at ElnaggarLab/ankh-base and are newly initialized: ['classification_head.dense.bias', 'classification_head.dense.weight', 'classification_head.out_proj.bias', 'classification_head.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [20]:
# for name, module in model.named_modules():
#     print(name, ":", module)

In [26]:
# Count only the parameters that will receive gradients, then display the total
# in whole millions (floor division by a float, so the result is a float).
number_parameters = sum(
    param.numel()
    for param in model.parameters()
    if param.requires_grad
)
number_parameters // 1e6

7.0

In [8]:
# number_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
# number_parameters // 1e6

736.0

In [27]:
df = pd.read_csv("../data/splits/train_p3_pdb2272.csv")

In [28]:
train_dfs, valid_dfs = make_folds(df)

In [29]:
# Only the first fold returned by make_folds is used in this notebook.
train = train_dfs[0]
valid = valid_dfs[0]

In [23]:
# Pull the raw sequences and their labels out of each fold as plain Python lists.
train_sequences, train_labels = (train[column].tolist() for column in ("sequence", "label"))
valid_sequences, valid_labels = (valid[column].tolist() for column in ("sequence", "label"))

In [24]:
# Tokenize the raw sequences. No padding or truncation is requested here, so every
# example keeps its own token length.
# NOTE(review): presumably batches are padded later by the Trainer's data collator,
# and very long sequences are not capped — confirm this is intended (memory usage).
train_tokenized = tokenizer(train_sequences)
valid_tokenized = tokenizer(valid_sequences)

In [25]:
batch_size = 64

In [26]:
train_dataset = Dataset.from_dict(train_tokenized)
valid_dataset = Dataset.from_dict(valid_tokenized)

In [27]:
train_dataset = train_dataset.add_column("labels", train_labels)
valid_dataset = valid_dataset.add_column("labels", valid_labels)

In [21]:
# lengths = [len(seq) for seq in valid_dataset['input_ids']]
# lengths

In [28]:
# model = accelerator.prepare(model)
# NOTE(review): these are `datasets.Dataset` objects, not DataLoaders. As far as I can
# tell, `Accelerator.prepare` returns such objects unchanged, so these two calls look
# like no-ops (the Trainer builds its own dataloaders) — confirm and consider removing.
train_dataset = accelerator.prepare(train_dataset)
valid_dataset = accelerator.prepare(valid_dataset)

In [29]:
# Load the metric functions
accuracy_metric = load("accuracy")
f1_metric = load("f1")
matthews_metric = load("matthews_correlation")
precision_metric = load("precision")
recall_metric = load("recall")
roc_auc_metric = load("roc_auc")

def compute_metrics(eval_pred):
    """Compute binary-classification metrics for the Hugging Face ``Trainer``.

    Args:
        eval_pred: Pair ``(predictions, labels)`` supplied by the Trainer.
            ``predictions`` is either the logits array of shape
            ``(n_samples, n_classes)`` or a tuple whose first element is that
            array (encoder-decoder classifiers also return hidden states).

    Returns:
        dict mapping metric name to value for accuracy, F1, Matthews correlation,
        precision, recall and ROC AUC.
    """
    logits, labels = eval_pred
    # Models that return several output tensors hand the Trainer a tuple; the
    # classification logits come first. np.argmax on the raw tuple would fail.
    if isinstance(logits, (tuple, list)):
        logits = logits[0]
    logits = np.asarray(logits)
    labels = np.asarray(labels)

    predictions = np.argmax(logits, axis=1)

    # ROC AUC needs a continuous score, not hard 0/1 labels: use the softmax
    # probability of the positive class (index 1). Subtracting the row maximum
    # keeps the exponentials numerically stable.
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp_logits = np.exp(shifted)
    positive_scores = (exp_logits / exp_logits.sum(axis=1, keepdims=True))[:, 1]

    # Compute each metric
    accuracy = accuracy_metric.compute(predictions=predictions, references=labels)
    f1 = f1_metric.compute(predictions=predictions, references=labels)
    matthews = matthews_metric.compute(predictions=predictions, references=labels)
    precision = precision_metric.compute(predictions=predictions, references=labels)
    recall = recall_metric.compute(predictions=predictions, references=labels)
    # The roc_auc metric takes `prediction_scores`, not `predictions`.
    roc_auc = roc_auc_metric.compute(prediction_scores=positive_scores, references=labels)

    metrics = {
        "accuracy": accuracy["accuracy"],
        "f1": f1["f1"],
        "matthews_correlation": matthews["matthews_correlation"],
        "precision": precision["precision"],
        "recall": recall["recall"],
        "roc_auc": roc_auc["roc_auc"],
    }

    return metrics

In [22]:
model_name = model_checkpoint.split("/")[-1]

In [23]:
# Training configuration: evaluate and checkpoint once per epoch, then restore the
# checkpoint with the lowest validation loss when training finishes.
args = TrainingArguments(
    # Output / bookkeeping
    output_dir=f"{model_name}-finetuned",
    push_to_hub=False,
    seed=42,
    # Optimisation
    learning_rate=2e-5,
    weight_decay=0.01,
    num_train_epochs=10,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=1,  # no accumulation: one optimizer step per batch
    # Evaluation, checkpointing and best-model selection
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,  # lower eval_loss is better
)

In [25]:
clearml_callback = ClearMLCallback(task_name="Training Ankh HF")

ClearML Task: created new task id=1f7f29e9799e4741b260beeffe1d84fc
2024-09-02 18:51:58,286 - clearml.Task - INFO - Storing jupyter notebook directly as code
ClearML results page: https://app.clear.ml/projects/45ba7ff7a93646a8a76d1950065cf1d5/experiments/1f7f29e9799e4741b260beeffe1d84fc/output/log


In [27]:
# Assemble the Trainer from the LoRA model, the tokenized folds, the metric function
# and the ClearML logging callback. Passing the tokenizer lets the Trainer pad each
# batch and save the tokenizer alongside checkpoints.
# A trailing `DataLoaderConfiguration(...)` assignment was removed from this cell: the
# name was never imported (NameError) and its result was not used anywhere.
trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    callbacks=[clearml_callback],
)


In [28]:
trainer.train()

Unsupported key of type '<class 'int'>' found when connecting dictionary. It will be converted to str
You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


ClearML Monitor: Could not detect iteration reporting, falling back to iterations as seconds-from-start
