In [1]:
import os
# Specify the working directory
os.chdir('/Users/david/Desktop/FinetuneEmbed')
import pickle
import numpy as np
from sklearn.model_selection import StratifiedKFold
from transformers import AutoTokenizer, TrainingArguments, AutoModelForSequenceClassification, EarlyStoppingCallback

from mod.mod_text import *

# prepare the input data
with open("./data/MethylationState/bivalent_vs_lys4/train_data.pkl", "rb") as f:
    train_data = pickle.load(f)
with open("./data/MethylationState/bivalent_vs_lys4/test_data.pkl", "rb") as f:
    test_data = pickle.load(f)

# Prepare datasets
train_texts_all, train_labels_all = train_data['desc'], train_data['labels']
test_texts, test_labels = test_data['desc'], test_data['labels']

# Load model and tokenizer
model_name = "sentence-transformers/all-MiniLM-L6-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)

  Referenced from: <9A4710B9-0DA3-36BB-9129-645F282E64B2> /Users/david/anaconda3/envs/myenv/lib/python3.10/site-packages/torchvision/image.so
  warn(


In [2]:
n_splits = 5  # Number of folds
kf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=8)

# Initialize a list to store AUC scores for each fold and on the test data
val_auc_scores = []
test_auc_scores = []
output_dirs = []  # Track output directories for each fold

# Create test dataset
test_dataset = TextDataset(test_texts, test_labels, tokenizer)

In [3]:
# Loop over each fold
for fold, (train_index, val_index) in enumerate(kf.split(train_texts_all, train_labels_all)):
    print(f"Fold {fold + 1}/{n_splits}")

    # Split data into training and validation for this fold
    train_texts, val_texts = [train_texts_all[i] for i in train_index], [train_texts_all[i] for i in val_index]
    train_labels, val_labels = [train_labels_all[i] for i in train_index], [train_labels_all[i] for i in val_index]

    # Create PyTorch datasets
    train_dataset = TextDataset(train_texts, train_labels, tokenizer)
    val_dataset = TextDataset(val_texts, val_labels, tokenizer)

    # Define output directory for this fold
    output_dir = f"./results/BivalentLys4/fold_{fold + 1}"
    os.makedirs(output_dir, exist_ok=True)
    output_dirs.append(output_dir)

    # Training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        evaluation_strategy="epoch",
        save_strategy="epoch", # Save the model after each epoch
        load_best_model_at_end=True, # Load the best model at the end of each fold
        save_total_limit=1, # Keep only the best model checkpoint
        learning_rate=1e-5, 
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        num_train_epochs=20,
        max_grad_norm=1.0,
        # warmup_ratio=0.1,
        weight_decay=0.01,
        metric_for_best_model="AUC",
        greater_is_better=True
    )

    # Initialize the model and Trainer for this fold
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

    trainer = CustomTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics,
        eval_metric="AUC",
        callbacks=[EarlyStoppingCallback(early_stopping_patience=5)]
    )

    # Train the model on this fold
    trainer.train()

    # Evaluate on the validation set and save the best model's AUC
    val_results = trainer.evaluate()
    val_auc = val_results["eval_AUC"]
    print(f"Fold {fold + 1} Validation AUC: {val_auc}")
    val_auc_scores.append(val_auc)


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at sentence-transformers/all-MiniLM-L6-v2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Fold 1/5


  0%|          | 0/320 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.6829316020011902, 'eval_AUC': 0.8865546218487395, 'eval_runtime': 1.5068, 'eval_samples_per_second': 20.573, 'eval_steps_per_second': 2.655, 'epoch': 1.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.6652384400367737, 'eval_AUC': 0.9117647058823528, 'eval_runtime': 0.1669, 'eval_samples_per_second': 185.709, 'eval_steps_per_second': 23.962, 'epoch': 2.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.6272950172424316, 'eval_AUC': 0.9243697478991597, 'eval_runtime': 0.1618, 'eval_samples_per_second': 191.584, 'eval_steps_per_second': 24.721, 'epoch': 3.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.572930634021759, 'eval_AUC': 0.9201680672268908, 'eval_runtime': 0.1721, 'eval_samples_per_second': 180.137, 'eval_steps_per_second': 23.244, 'epoch': 4.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.5278266668319702, 'eval_AUC': 0.9159663865546218, 'eval_runtime': 0.132, 'eval_samples_per_second': 234.822, 'eval_steps_per_second': 30.3, 'epoch': 5.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.4841545820236206, 'eval_AUC': 0.9243697478991597, 'eval_runtime': 0.1455, 'eval_samples_per_second': 213.097, 'eval_steps_per_second': 27.496, 'epoch': 6.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.4475550353527069, 'eval_AUC': 0.9327731092436975, 'eval_runtime': 0.1349, 'eval_samples_per_second': 229.72, 'eval_steps_per_second': 29.641, 'epoch': 7.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.4237571358680725, 'eval_AUC': 0.9327731092436975, 'eval_runtime': 0.1533, 'eval_samples_per_second': 202.28, 'eval_steps_per_second': 26.101, 'epoch': 8.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.405909925699234, 'eval_AUC': 0.9327731092436975, 'eval_runtime': 0.1652, 'eval_samples_per_second': 187.602, 'eval_steps_per_second': 24.207, 'epoch': 9.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.3892882466316223, 'eval_AUC': 0.9369747899159664, 'eval_runtime': 0.1676, 'eval_samples_per_second': 184.962, 'eval_steps_per_second': 23.866, 'epoch': 10.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.3790244460105896, 'eval_AUC': 0.9369747899159664, 'eval_runtime': 0.163, 'eval_samples_per_second': 190.182, 'eval_steps_per_second': 24.54, 'epoch': 11.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.3744599521160126, 'eval_AUC': 0.9285714285714286, 'eval_runtime': 0.1614, 'eval_samples_per_second': 192.022, 'eval_steps_per_second': 24.777, 'epoch': 12.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.368773877620697, 'eval_AUC': 0.9327731092436975, 'eval_runtime': 0.1453, 'eval_samples_per_second': 213.284, 'eval_steps_per_second': 27.521, 'epoch': 13.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.3667847216129303, 'eval_AUC': 0.9243697478991596, 'eval_runtime': 0.1633, 'eval_samples_per_second': 189.89, 'eval_steps_per_second': 24.502, 'epoch': 14.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.36478957533836365, 'eval_AUC': 0.9285714285714286, 'eval_runtime': 0.1577, 'eval_samples_per_second': 196.544, 'eval_steps_per_second': 25.36, 'epoch': 15.0}
{'train_runtime': 53.5308, 'train_samples_per_second': 45.581, 'train_steps_per_second': 5.978, 'train_loss': 0.4733455657958984, 'epoch': 15.0}


  0%|          | 0/4 [00:00<?, ?it/s]

Fold 1 Validation AUC: 0.9369747899159664
Fold 2/5


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at sentence-transformers/all-MiniLM-L6-v2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


  0%|          | 0/320 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.6842238903045654, 'eval_AUC': 0.8109243697478992, 'eval_runtime': 0.1356, 'eval_samples_per_second': 228.644, 'eval_steps_per_second': 29.502, 'epoch': 1.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.6695904731750488, 'eval_AUC': 0.8865546218487396, 'eval_runtime': 0.132, 'eval_samples_per_second': 234.904, 'eval_steps_per_second': 30.31, 'epoch': 2.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.6440141201019287, 'eval_AUC': 0.8487394957983193, 'eval_runtime': 0.1378, 'eval_samples_per_second': 224.949, 'eval_steps_per_second': 29.026, 'epoch': 3.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.6049315929412842, 'eval_AUC': 0.8487394957983193, 'eval_runtime': 0.1378, 'eval_samples_per_second': 224.954, 'eval_steps_per_second': 29.026, 'epoch': 4.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.5637047290802002, 'eval_AUC': 0.8697478991596638, 'eval_runtime': 0.142, 'eval_samples_per_second': 218.343, 'eval_steps_per_second': 28.173, 'epoch': 5.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.5264471769332886, 'eval_AUC': 0.8823529411764706, 'eval_runtime': 0.1379, 'eval_samples_per_second': 224.734, 'eval_steps_per_second': 28.998, 'epoch': 6.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.4985487759113312, 'eval_AUC': 0.8823529411764707, 'eval_runtime': 0.1404, 'eval_samples_per_second': 220.735, 'eval_steps_per_second': 28.482, 'epoch': 7.0}
{'train_runtime': 22.1644, 'train_samples_per_second': 110.087, 'train_steps_per_second': 14.438, 'train_loss': 0.6004339626857212, 'epoch': 7.0}


  0%|          | 0/4 [00:00<?, ?it/s]

Fold 2 Validation AUC: 0.8865546218487396
Fold 3/5


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at sentence-transformers/all-MiniLM-L6-v2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


  0%|          | 0/320 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.6780127286911011, 'eval_AUC': 0.8739495798319327, 'eval_runtime': 0.1854, 'eval_samples_per_second': 167.241, 'eval_steps_per_second': 21.579, 'epoch': 1.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.6489608883857727, 'eval_AUC': 0.861344537815126, 'eval_runtime': 0.18, 'eval_samples_per_second': 172.225, 'eval_steps_per_second': 22.223, 'epoch': 2.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.6098731756210327, 'eval_AUC': 0.8277310924369748, 'eval_runtime': 0.1731, 'eval_samples_per_second': 179.095, 'eval_steps_per_second': 23.109, 'epoch': 3.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.5732021927833557, 'eval_AUC': 0.819327731092437, 'eval_runtime': 0.166, 'eval_samples_per_second': 186.756, 'eval_steps_per_second': 24.098, 'epoch': 4.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.5361990928649902, 'eval_AUC': 0.8319327731092437, 'eval_runtime': 0.1643, 'eval_samples_per_second': 188.678, 'eval_steps_per_second': 24.346, 'epoch': 5.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.505531907081604, 'eval_AUC': 0.8529411764705882, 'eval_runtime': 0.1357, 'eval_samples_per_second': 228.422, 'eval_steps_per_second': 29.474, 'epoch': 6.0}
{'train_runtime': 18.344, 'train_samples_per_second': 133.013, 'train_steps_per_second': 17.444, 'train_loss': 0.5927296876907349, 'epoch': 6.0}


  0%|          | 0/4 [00:00<?, ?it/s]

Fold 3 Validation AUC: 0.8739495798319327
Fold 4/5


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at sentence-transformers/all-MiniLM-L6-v2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


  0%|          | 0/320 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.6840149164199829, 'eval_AUC': 0.8506787330316743, 'eval_runtime': 1.2777, 'eval_samples_per_second': 23.48, 'eval_steps_per_second': 3.131, 'epoch': 1.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.6650834679603577, 'eval_AUC': 0.9321266968325792, 'eval_runtime': 0.1485, 'eval_samples_per_second': 202.035, 'eval_steps_per_second': 26.938, 'epoch': 2.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.6306244730949402, 'eval_AUC': 0.9638009049773756, 'eval_runtime': 0.1451, 'eval_samples_per_second': 206.782, 'eval_steps_per_second': 27.571, 'epoch': 3.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.579940676689148, 'eval_AUC': 0.9638009049773756, 'eval_runtime': 0.1353, 'eval_samples_per_second': 221.653, 'eval_steps_per_second': 29.554, 'epoch': 4.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.5244490504264832, 'eval_AUC': 0.9773755656108598, 'eval_runtime': 0.1538, 'eval_samples_per_second': 195.042, 'eval_steps_per_second': 26.006, 'epoch': 5.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.467013418674469, 'eval_AUC': 0.9819004524886878, 'eval_runtime': 0.1352, 'eval_samples_per_second': 221.897, 'eval_steps_per_second': 29.586, 'epoch': 6.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.4141673445701599, 'eval_AUC': 0.9909502262443439, 'eval_runtime': 0.132, 'eval_samples_per_second': 227.245, 'eval_steps_per_second': 30.299, 'epoch': 7.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.3694673180580139, 'eval_AUC': 0.995475113122172, 'eval_runtime': 0.1392, 'eval_samples_per_second': 215.449, 'eval_steps_per_second': 28.727, 'epoch': 8.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.3422618806362152, 'eval_AUC': 0.995475113122172, 'eval_runtime': 0.1329, 'eval_samples_per_second': 225.792, 'eval_steps_per_second': 30.106, 'epoch': 9.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.3132157623767853, 'eval_AUC': 1.0, 'eval_runtime': 0.1338, 'eval_samples_per_second': 224.182, 'eval_steps_per_second': 29.891, 'epoch': 10.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.285976380109787, 'eval_AUC': 1.0, 'eval_runtime': 0.1365, 'eval_samples_per_second': 219.83, 'eval_steps_per_second': 29.311, 'epoch': 11.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.27125754952430725, 'eval_AUC': 1.0, 'eval_runtime': 0.1382, 'eval_samples_per_second': 217.014, 'eval_steps_per_second': 28.935, 'epoch': 12.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.25487223267555237, 'eval_AUC': 1.0, 'eval_runtime': 0.1361, 'eval_samples_per_second': 220.368, 'eval_steps_per_second': 29.382, 'epoch': 13.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.2418752759695053, 'eval_AUC': 0.995475113122172, 'eval_runtime': 0.1389, 'eval_samples_per_second': 215.925, 'eval_steps_per_second': 28.79, 'epoch': 14.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.2373967468738556, 'eval_AUC': 0.995475113122172, 'eval_runtime': 0.1383, 'eval_samples_per_second': 216.846, 'eval_steps_per_second': 28.913, 'epoch': 15.0}
{'train_runtime': 54.5893, 'train_samples_per_second': 45.064, 'train_steps_per_second': 5.862, 'train_loss': 0.5077909151713054, 'epoch': 15.0}


  0%|          | 0/4 [00:00<?, ?it/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at sentence-transformers/all-MiniLM-L6-v2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Fold 4 Validation AUC: 1.0
Fold 5/5


  0%|          | 0/320 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.6817789673805237, 'eval_AUC': 0.9140271493212669, 'eval_runtime': 0.1381, 'eval_samples_per_second': 217.304, 'eval_steps_per_second': 28.974, 'epoch': 1.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.6664181351661682, 'eval_AUC': 0.923076923076923, 'eval_runtime': 0.1454, 'eval_samples_per_second': 206.284, 'eval_steps_per_second': 27.504, 'epoch': 2.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.6318687200546265, 'eval_AUC': 0.8959276018099548, 'eval_runtime': 0.1415, 'eval_samples_per_second': 211.971, 'eval_steps_per_second': 28.263, 'epoch': 3.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.5875731706619263, 'eval_AUC': 0.8868778280542987, 'eval_runtime': 0.1404, 'eval_samples_per_second': 213.64, 'eval_steps_per_second': 28.485, 'epoch': 4.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.5404641628265381, 'eval_AUC': 0.9049773755656108, 'eval_runtime': 0.1474, 'eval_samples_per_second': 203.499, 'eval_steps_per_second': 27.133, 'epoch': 5.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.4967394471168518, 'eval_AUC': 0.918552036199095, 'eval_runtime': 0.1402, 'eval_samples_per_second': 213.995, 'eval_steps_per_second': 28.533, 'epoch': 6.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.46095308661460876, 'eval_AUC': 0.9230769230769231, 'eval_runtime': 0.1345, 'eval_samples_per_second': 223.002, 'eval_steps_per_second': 29.734, 'epoch': 7.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.437398225069046, 'eval_AUC': 0.9230769230769231, 'eval_runtime': 0.1372, 'eval_samples_per_second': 218.71, 'eval_steps_per_second': 29.161, 'epoch': 8.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.4192701280117035, 'eval_AUC': 0.9276018099547512, 'eval_runtime': 0.142, 'eval_samples_per_second': 211.245, 'eval_steps_per_second': 28.166, 'epoch': 9.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.40738430619239807, 'eval_AUC': 0.9230769230769231, 'eval_runtime': 0.1448, 'eval_samples_per_second': 207.225, 'eval_steps_per_second': 27.63, 'epoch': 10.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.3955639898777008, 'eval_AUC': 0.9276018099547512, 'eval_runtime': 0.1384, 'eval_samples_per_second': 216.752, 'eval_steps_per_second': 28.9, 'epoch': 11.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.390997052192688, 'eval_AUC': 0.9276018099547512, 'eval_runtime': 0.1741, 'eval_samples_per_second': 172.285, 'eval_steps_per_second': 22.971, 'epoch': 12.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.3869172930717468, 'eval_AUC': 0.9276018099547512, 'eval_runtime': 0.1761, 'eval_samples_per_second': 170.319, 'eval_steps_per_second': 22.709, 'epoch': 13.0}


  0%|          | 0/4 [00:00<?, ?it/s]

{'eval_loss': 0.38332170248031616, 'eval_AUC': 0.9230769230769231, 'eval_runtime': 0.1747, 'eval_samples_per_second': 171.739, 'eval_steps_per_second': 22.899, 'epoch': 14.0}
{'train_runtime': 53.5252, 'train_samples_per_second': 45.96, 'train_steps_per_second': 5.978, 'train_loss': 0.5007151194981166, 'epoch': 14.0}


  0%|          | 0/4 [00:00<?, ?it/s]

Fold 5 Validation AUC: 0.9276018099547512


In [4]:
# Calculate mean and standard deviation for validation AUC scores
mean_val_auc = np.mean(val_auc_scores)
std_val_auc = np.std(val_auc_scores)

# Print the results
print(f"Validation AUC: Mean = {mean_val_auc:.4f}, Standard Deviation = {std_val_auc:.4f}")

Validation AUC: Mean = 0.9250, Standard Deviation = 0.0444


In [5]:
best_fold_idx = np.argmax(val_auc_scores)
best_model_dir = output_dirs[best_fold_idx]  # Directory of the best model
print(f"Best model found in fold {best_fold_idx + 1} with Validation AUC: {val_auc_scores[best_fold_idx]}")


Best model found in fold 4 with Validation AUC: 1.0


In [7]:
# use the best model and do the final training
best_model_dir = './results/BivalentLys4/fold_4/checkpoint-160'
best_model = AutoModelForSequenceClassification.from_pretrained(best_model_dir)

full_train_dataset = TextDataset(train_texts_all, train_labels_all, tokenizer)

# Define training arguments for the final training phase
final_training_args = TrainingArguments(
    output_dir="./LongShortTF/final_model",       # Directory to save the final model
    evaluation_strategy="no",         # No evaluation during training
    save_strategy="no",            # Save the model at each epoch
    save_total_limit=1,               # Keep only the last checkpoint to save storage
    learning_rate=1e-4,
    per_device_train_batch_size=8,
    num_train_epochs=5,
    weight_decay=0.01,
    logging_steps=10000,              # Minimize logging output
    report_to="none"                  # Disable logging to external tools
)

# Initialize the Trainer with the full dataset and final training arguments
trainer = Trainer(
    model=best_model,
    args=final_training_args,
    train_dataset=full_train_dataset
)

# Evaluate on the test set
test_results = trainer.predict(test_dataset)
# Calculate AUC on the test data
test_probs = torch.nn.functional.softmax(torch.tensor(test_results.predictions), dim=1)[:, 1].numpy()
test_auc = roc_auc_score(test_results.label_ids, test_probs)
print(f"Test AUC with the best model: {test_auc}")



  0%|          | 0/3 [00:00<?, ?it/s]

Test AUC with the best model: 0.875
