In [1]:
import os
# Specify the working directory
os.chdir('/Users/david/Desktop/FinetuneEmbed')
import pickle
import numpy as np
from sklearn.model_selection import StratifiedKFold
from transformers import AutoTokenizer, TrainingArguments, AutoModelForSequenceClassification, EarlyStoppingCallback

from mod.mod_text import *


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# prepare the input data
with open("./data/MethylationState/bivalent_vs_no_methyl/train_data.pkl", "rb") as f:
    train_data = pickle.load(f)
with open("./data/MethylationState/bivalent_vs_no_methyl/test_data.pkl", "rb") as f:
    test_data = pickle.load(f)

# Prepare datasets
train_texts_all, train_labels_all = train_data['desc'], train_data['labels']
test_texts, test_labels = test_data['desc'], test_data['labels']

# Load model and tokenizer
model_name = "sentence-transformers/all-MiniLM-L6-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [3]:
n_splits = 5  # Number of folds
kf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=10)

# Initialize a list to store AUC scores for each fold and on the test data
val_auc_scores = []
output_dirs = []  # Track output directories for each fold

# Create test dataset
test_dataset = TextDataset(test_texts, test_labels, tokenizer)

In [4]:
# Loop over each fold
for fold, (train_index, val_index) in enumerate(kf.split(train_texts_all, train_labels_all)):
    print(f"Fold {fold + 1}/{n_splits}")

    # Split data into training and validation for this fold
    train_texts, val_texts = [train_texts_all[i] for i in train_index], [train_texts_all[i] for i in val_index]
    train_labels, val_labels = [train_labels_all[i] for i in train_index], [train_labels_all[i] for i in val_index]

    # Create PyTorch datasets
    train_dataset = TextDataset(train_texts, train_labels, tokenizer)
    val_dataset = TextDataset(val_texts, val_labels, tokenizer)

    # Define output directory for this fold
    output_dir = f"/Volumes/Raven/Research/FinetuneEmbed/results/BivalentNoMethyl/fold_{fold + 1}"
    os.makedirs(output_dir, exist_ok=True)
    output_dirs.append(output_dir)

    # Training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        evaluation_strategy="epoch",
        save_strategy="epoch", # Save the model after each epoch
        load_best_model_at_end=True, # Load the best model at the end of each fold
        save_total_limit=1, # Keep only the best model checkpoint
        learning_rate=1e-4, 
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        num_train_epochs=20,
        max_grad_norm=1.0,
        weight_decay=0.01,
        metric_for_best_model="AUC",
        greater_is_better=True
    )

    # Initialize the model and Trainer for this fold
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

    trainer = CustomTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics,
        eval_metric="AUC",
        callbacks=[EarlyStoppingCallback(early_stopping_patience=5)]
    )

    # Train the model on this fold
    trainer.train()

    # Evaluate on the validation set and save the best model's AUC
    val_results = trainer.evaluate()
    val_auc = val_results["eval_AUC"]
    print(f"Fold {fold + 1} Validation AUC: {val_auc}")
    val_auc_scores.append(val_auc)


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at sentence-transformers/all-MiniLM-L6-v2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Fold 1/5


                                                
  5%|▌         | 12/240 [00:07<00:39,  5.71it/s]

{'eval_loss': 0.4805563986301422, 'eval_AUC': 0.9411764705882353, 'eval_runtime': 0.3875, 'eval_samples_per_second': 61.943, 'eval_steps_per_second': 7.743, 'epoch': 1.0}


                                                
 10%|█         | 24/240 [00:13<00:35,  6.04it/s]

{'eval_loss': 0.30268394947052, 'eval_AUC': 0.9411764705882353, 'eval_runtime': 0.0866, 'eval_samples_per_second': 277.123, 'eval_steps_per_second': 34.64, 'epoch': 2.0}


                                                
 15%|█▌        | 36/240 [00:18<00:32,  6.37it/s]

{'eval_loss': 0.22963207960128784, 'eval_AUC': 0.9495798319327731, 'eval_runtime': 0.0893, 'eval_samples_per_second': 268.814, 'eval_steps_per_second': 33.602, 'epoch': 3.0}


                                                
 20%|██        | 48/240 [00:22<00:30,  6.28it/s]

{'eval_loss': 0.37562084197998047, 'eval_AUC': 0.9411764705882353, 'eval_runtime': 0.0862, 'eval_samples_per_second': 278.403, 'eval_steps_per_second': 34.8, 'epoch': 4.0}


                                                
 25%|██▌       | 60/240 [00:27<00:27,  6.46it/s]

{'eval_loss': 0.3860814869403839, 'eval_AUC': 0.9243697478991597, 'eval_runtime': 0.0851, 'eval_samples_per_second': 282.168, 'eval_steps_per_second': 35.271, 'epoch': 5.0}


                                                
 30%|███       | 72/240 [00:31<00:25,  6.63it/s]

{'eval_loss': 0.39129891991615295, 'eval_AUC': 0.9495798319327731, 'eval_runtime': 0.0849, 'eval_samples_per_second': 282.649, 'eval_steps_per_second': 35.331, 'epoch': 6.0}


                                                
 35%|███▌      | 84/240 [00:36<00:24,  6.46it/s]

{'eval_loss': 0.4797574579715729, 'eval_AUC': 0.9243697478991597, 'eval_runtime': 0.0908, 'eval_samples_per_second': 264.271, 'eval_steps_per_second': 33.034, 'epoch': 7.0}


                                                
 40%|████      | 96/240 [00:40<00:22,  6.50it/s]

{'eval_loss': 0.6465464234352112, 'eval_AUC': 0.8907563025210085, 'eval_runtime': 0.0877, 'eval_samples_per_second': 273.559, 'eval_steps_per_second': 34.195, 'epoch': 8.0}


 40%|████      | 96/240 [00:44<01:06,  2.15it/s]


{'train_runtime': 44.5723, 'train_samples_per_second': 43.076, 'train_steps_per_second': 5.385, 'train_loss': 0.2154291272163391, 'epoch': 8.0}


100%|██████████| 3/3 [00:00<00:00, 33.38it/s]
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at sentence-transformers/all-MiniLM-L6-v2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Fold 1 Validation AUC: 0.9495798319327731
Fold 2/5


  5%|▌         | 12/240 [00:01<00:32,  7.00it/s]
  5%|▌         | 12/240 [00:01<00:32,  7.00it/s]

{'eval_loss': 0.5260520577430725, 'eval_AUC': 0.8571428571428571, 'eval_runtime': 0.091, 'eval_samples_per_second': 263.612, 'eval_steps_per_second': 32.951, 'epoch': 1.0}


 10%|█         | 24/240 [00:06<00:33,  6.52it/s]
 10%|█         | 24/240 [00:06<00:33,  6.52it/s]

{'eval_loss': 0.44400015473365784, 'eval_AUC': 0.865546218487395, 'eval_runtime': 0.0875, 'eval_samples_per_second': 274.183, 'eval_steps_per_second': 34.273, 'epoch': 2.0}


 15%|█▌        | 36/240 [00:11<00:31,  6.41it/s]
 15%|█▌        | 36/240 [00:11<00:31,  6.41it/s]

{'eval_loss': 0.5394663214683533, 'eval_AUC': 0.8823529411764707, 'eval_runtime': 0.0857, 'eval_samples_per_second': 279.962, 'eval_steps_per_second': 34.995, 'epoch': 3.0}


 20%|██        | 48/240 [00:15<00:29,  6.56it/s]
 20%|██        | 48/240 [00:15<00:29,  6.56it/s]

{'eval_loss': 0.8600875735282898, 'eval_AUC': 0.907563025210084, 'eval_runtime': 0.0848, 'eval_samples_per_second': 283.159, 'eval_steps_per_second': 35.395, 'epoch': 4.0}


 25%|██▌       | 60/240 [00:20<00:27,  6.44it/s]
 25%|██▌       | 60/240 [00:20<00:27,  6.44it/s]

{'eval_loss': 0.8992486596107483, 'eval_AUC': 0.865546218487395, 'eval_runtime': 0.0848, 'eval_samples_per_second': 283.145, 'eval_steps_per_second': 35.393, 'epoch': 5.0}


 30%|███       | 72/240 [00:24<00:25,  6.60it/s]
 30%|███       | 72/240 [00:24<00:25,  6.60it/s]

{'eval_loss': 1.273851990699768, 'eval_AUC': 0.8403361344537815, 'eval_runtime': 0.086, 'eval_samples_per_second': 279.086, 'eval_steps_per_second': 34.886, 'epoch': 6.0}


 35%|███▌      | 84/240 [00:30<00:27,  5.78it/s]
 35%|███▌      | 84/240 [00:30<00:27,  5.78it/s]

{'eval_loss': 1.2908048629760742, 'eval_AUC': 0.8403361344537815, 'eval_runtime': 0.0971, 'eval_samples_per_second': 247.247, 'eval_steps_per_second': 30.906, 'epoch': 7.0}


 40%|████      | 96/240 [00:37<00:24,  5.87it/s]
 40%|████      | 96/240 [00:37<00:24,  5.87it/s]

{'eval_loss': 0.892333984375, 'eval_AUC': 0.8739495798319328, 'eval_runtime': 0.0857, 'eval_samples_per_second': 280.083, 'eval_steps_per_second': 35.01, 'epoch': 8.0}


 45%|████▌     | 108/240 [00:45<00:22,  5.84it/s]
 45%|████▌     | 108/240 [00:45<00:22,  5.84it/s]

{'eval_loss': 0.9490193724632263, 'eval_AUC': 0.8739495798319328, 'eval_runtime': 0.0893, 'eval_samples_per_second': 268.841, 'eval_steps_per_second': 33.605, 'epoch': 9.0}


 45%|████▌     | 108/240 [00:50<01:01,  2.16it/s]


{'train_runtime': 50.0516, 'train_samples_per_second': 38.36, 'train_steps_per_second': 4.795, 'train_loss': 0.17606468553896304, 'epoch': 9.0}


100%|██████████| 3/3 [00:00<00:00, 34.45it/s]


Fold 2 Validation AUC: 0.907563025210084
Fold 3/5


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at sentence-transformers/all-MiniLM-L6-v2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  5%|▌         | 12/240 [00:01<00:31,  7.21it/s]
  5%|▌         | 12/240 [00:01<00:31,  7.21it/s]

{'eval_loss': 0.46840524673461914, 'eval_AUC': 0.9159663865546218, 'eval_runtime': 0.0866, 'eval_samples_per_second': 277.261, 'eval_steps_per_second': 34.658, 'epoch': 1.0}


 10%|█         | 24/240 [00:07<00:35,  6.09it/s]
 10%|█         | 24/240 [00:07<00:35,  6.09it/s]

{'eval_loss': 0.3150760233402252, 'eval_AUC': 0.9411764705882353, 'eval_runtime': 0.0872, 'eval_samples_per_second': 275.267, 'eval_steps_per_second': 34.408, 'epoch': 2.0}


 15%|█▌        | 36/240 [00:11<00:32,  6.21it/s]
 15%|█▌        | 36/240 [00:11<00:32,  6.21it/s]

{'eval_loss': 0.4414774179458618, 'eval_AUC': 0.8823529411764706, 'eval_runtime': 0.094, 'eval_samples_per_second': 255.29, 'eval_steps_per_second': 31.911, 'epoch': 3.0}


 20%|██        | 48/240 [00:16<00:29,  6.46it/s]
 20%|██        | 48/240 [00:16<00:29,  6.46it/s]

{'eval_loss': 0.5504828095436096, 'eval_AUC': 0.8907563025210085, 'eval_runtime': 0.0867, 'eval_samples_per_second': 276.976, 'eval_steps_per_second': 34.622, 'epoch': 4.0}


 25%|██▌       | 60/240 [00:21<00:28,  6.42it/s]
 25%|██▌       | 60/240 [00:21<00:28,  6.42it/s]

{'eval_loss': 0.8843199610710144, 'eval_AUC': 0.8991596638655461, 'eval_runtime': 0.0873, 'eval_samples_per_second': 275.0, 'eval_steps_per_second': 34.375, 'epoch': 5.0}


 30%|███       | 72/240 [00:25<00:25,  6.58it/s]
 30%|███       | 72/240 [00:25<00:25,  6.58it/s]

{'eval_loss': 1.106290340423584, 'eval_AUC': 0.8487394957983193, 'eval_runtime': 0.0868, 'eval_samples_per_second': 276.393, 'eval_steps_per_second': 34.549, 'epoch': 6.0}


 35%|███▌      | 84/240 [00:30<00:25,  6.18it/s]
 35%|███▌      | 84/240 [00:30<00:25,  6.18it/s]

{'eval_loss': 1.1682220697402954, 'eval_AUC': 0.8487394957983193, 'eval_runtime': 0.0887, 'eval_samples_per_second': 270.678, 'eval_steps_per_second': 33.835, 'epoch': 7.0}


 35%|███▌      | 84/240 [00:35<01:05,  2.37it/s]


{'train_runtime': 35.3972, 'train_samples_per_second': 54.242, 'train_steps_per_second': 6.78, 'train_loss': 0.2047340983436221, 'epoch': 7.0}


100%|██████████| 3/3 [00:00<00:00, 58.48it/s]


Fold 3 Validation AUC: 0.9411764705882353
Fold 4/5


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at sentence-transformers/all-MiniLM-L6-v2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  5%|▌         | 12/240 [00:01<00:31,  7.16it/s]
  5%|▌         | 12/240 [00:01<00:31,  7.16it/s]

{'eval_loss': 0.45114004611968994, 'eval_AUC': 0.9831932773109244, 'eval_runtime': 0.0895, 'eval_samples_per_second': 268.12, 'eval_steps_per_second': 33.515, 'epoch': 1.0}


 10%|█         | 24/240 [00:06<00:35,  6.03it/s]
 10%|█         | 24/240 [00:07<00:35,  6.03it/s]

{'eval_loss': 0.2645821273326874, 'eval_AUC': 0.9831932773109244, 'eval_runtime': 0.0913, 'eval_samples_per_second': 262.815, 'eval_steps_per_second': 32.852, 'epoch': 2.0}


 15%|█▌        | 36/240 [00:11<00:30,  6.63it/s]
 15%|█▌        | 36/240 [00:11<00:30,  6.63it/s]

{'eval_loss': 0.24395973980426788, 'eval_AUC': 0.9579831932773109, 'eval_runtime': 0.0863, 'eval_samples_per_second': 278.065, 'eval_steps_per_second': 34.758, 'epoch': 3.0}


 20%|██        | 48/240 [00:15<00:30,  6.21it/s]
 20%|██        | 48/240 [00:15<00:30,  6.21it/s]

{'eval_loss': 0.3705669343471527, 'eval_AUC': 0.9411764705882353, 'eval_runtime': 0.098, 'eval_samples_per_second': 244.988, 'eval_steps_per_second': 30.623, 'epoch': 4.0}


 25%|██▌       | 60/240 [00:19<00:27,  6.46it/s]
 25%|██▌       | 60/240 [00:19<00:27,  6.46it/s]

{'eval_loss': 0.336822509765625, 'eval_AUC': 0.9495798319327731, 'eval_runtime': 0.1193, 'eval_samples_per_second': 201.147, 'eval_steps_per_second': 25.143, 'epoch': 5.0}


 30%|███       | 72/240 [00:25<00:27,  6.01it/s]
 30%|███       | 72/240 [00:25<00:27,  6.01it/s]

{'eval_loss': 0.38627147674560547, 'eval_AUC': 0.9579831932773109, 'eval_runtime': 0.0955, 'eval_samples_per_second': 251.351, 'eval_steps_per_second': 31.419, 'epoch': 6.0}


 30%|███       | 72/240 [00:29<01:07,  2.48it/s]


{'train_runtime': 29.0219, 'train_samples_per_second': 66.157, 'train_steps_per_second': 8.27, 'train_loss': 0.246666775809394, 'epoch': 6.0}


100%|██████████| 3/3 [00:00<00:00, 55.56it/s]


Fold 4 Validation AUC: 0.9831932773109244
Fold 5/5


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at sentence-transformers/all-MiniLM-L6-v2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  5%|▌         | 12/240 [00:01<00:32,  7.05it/s]
  5%|▌         | 12/240 [00:01<00:32,  7.05it/s]

{'eval_loss': 0.6011874675750732, 'eval_AUC': 0.6638655462184874, 'eval_runtime': 0.1475, 'eval_samples_per_second': 162.67, 'eval_steps_per_second': 20.334, 'epoch': 1.0}


 10%|█         | 24/240 [00:08<00:35,  6.08it/s]
 10%|█         | 24/240 [00:08<00:35,  6.08it/s]

{'eval_loss': 0.6041375398635864, 'eval_AUC': 0.45378151260504207, 'eval_runtime': 0.0857, 'eval_samples_per_second': 279.952, 'eval_steps_per_second': 34.994, 'epoch': 2.0}


 15%|█▌        | 36/240 [00:13<00:32,  6.32it/s]
 15%|█▌        | 36/240 [00:13<00:32,  6.32it/s]

{'eval_loss': 0.6043354868888855, 'eval_AUC': 0.7815126050420168, 'eval_runtime': 0.0889, 'eval_samples_per_second': 269.988, 'eval_steps_per_second': 33.748, 'epoch': 3.0}


 20%|██        | 48/240 [00:17<00:29,  6.55it/s]
 20%|██        | 48/240 [00:17<00:29,  6.55it/s]

{'eval_loss': 0.6111755967140198, 'eval_AUC': 0.25210084033613445, 'eval_runtime': 0.0848, 'eval_samples_per_second': 282.866, 'eval_steps_per_second': 35.358, 'epoch': 4.0}


 25%|██▌       | 60/240 [00:22<00:27,  6.58it/s]
 25%|██▌       | 60/240 [00:22<00:27,  6.58it/s]

{'eval_loss': 0.6038179397583008, 'eval_AUC': 0.7731092436974789, 'eval_runtime': 0.0845, 'eval_samples_per_second': 284.128, 'eval_steps_per_second': 35.516, 'epoch': 5.0}


 30%|███       | 72/240 [00:26<00:26,  6.27it/s]
 30%|███       | 72/240 [00:26<00:26,  6.27it/s]

{'eval_loss': 0.6054918169975281, 'eval_AUC': 0.722689075630252, 'eval_runtime': 0.0882, 'eval_samples_per_second': 272.078, 'eval_steps_per_second': 34.01, 'epoch': 6.0}


 35%|███▌      | 84/240 [00:30<00:25,  6.23it/s]
 35%|███▌      | 84/240 [00:30<00:25,  6.23it/s]

{'eval_loss': 0.6034266948699951, 'eval_AUC': 0.7815126050420168, 'eval_runtime': 0.1288, 'eval_samples_per_second': 186.264, 'eval_steps_per_second': 23.283, 'epoch': 7.0}


 40%|████      | 96/240 [00:34<00:21,  6.63it/s]
 40%|████      | 96/240 [00:35<00:21,  6.63it/s]

{'eval_loss': 0.60329270362854, 'eval_AUC': 0.7983193277310925, 'eval_runtime': 0.0855, 'eval_samples_per_second': 280.843, 'eval_steps_per_second': 35.105, 'epoch': 8.0}


 45%|████▌     | 108/240 [00:39<00:20,  6.55it/s]
 45%|████▌     | 108/240 [00:39<00:20,  6.55it/s]

{'eval_loss': 0.5966820120811462, 'eval_AUC': 0.7899159663865547, 'eval_runtime': 0.0867, 'eval_samples_per_second': 276.929, 'eval_steps_per_second': 34.616, 'epoch': 9.0}


 50%|█████     | 120/240 [00:44<00:18,  6.41it/s]
 50%|█████     | 120/240 [00:44<00:18,  6.41it/s]

{'eval_loss': 0.5356356501579285, 'eval_AUC': 0.8571428571428572, 'eval_runtime': 0.088, 'eval_samples_per_second': 272.737, 'eval_steps_per_second': 34.092, 'epoch': 10.0}


 55%|█████▌    | 132/240 [00:50<00:17,  6.06it/s]
 55%|█████▌    | 132/240 [00:50<00:17,  6.06it/s]

{'eval_loss': 0.5069840550422668, 'eval_AUC': 0.8151260504201681, 'eval_runtime': 0.0869, 'eval_samples_per_second': 276.142, 'eval_steps_per_second': 34.518, 'epoch': 11.0}


 60%|██████    | 144/240 [00:56<00:15,  6.11it/s]
 60%|██████    | 144/240 [00:56<00:15,  6.11it/s]

{'eval_loss': 0.48066017031669617, 'eval_AUC': 0.8403361344537815, 'eval_runtime': 0.0852, 'eval_samples_per_second': 281.842, 'eval_steps_per_second': 35.23, 'epoch': 12.0}


 65%|██████▌   | 156/240 [01:00<00:12,  6.48it/s]
 65%|██████▌   | 156/240 [01:00<00:12,  6.48it/s]

{'eval_loss': 0.4689815044403076, 'eval_AUC': 0.8319327731092437, 'eval_runtime': 0.0855, 'eval_samples_per_second': 280.84, 'eval_steps_per_second': 35.105, 'epoch': 13.0}


 70%|███████   | 168/240 [01:05<00:11,  6.48it/s]
 70%|███████   | 168/240 [01:05<00:11,  6.48it/s]

{'eval_loss': 0.4869631230831146, 'eval_AUC': 0.8403361344537815, 'eval_runtime': 0.0939, 'eval_samples_per_second': 255.716, 'eval_steps_per_second': 31.965, 'epoch': 14.0}


 75%|███████▌  | 180/240 [01:09<00:09,  6.38it/s]
 75%|███████▌  | 180/240 [01:09<00:09,  6.38it/s]

{'eval_loss': 0.6437397599220276, 'eval_AUC': 0.8403361344537815, 'eval_runtime': 0.0876, 'eval_samples_per_second': 274.001, 'eval_steps_per_second': 34.25, 'epoch': 15.0}


 75%|███████▌  | 180/240 [01:13<00:24,  2.45it/s]


{'train_runtime': 73.3862, 'train_samples_per_second': 26.163, 'train_steps_per_second': 3.27, 'train_loss': 0.5308734469943577, 'epoch': 15.0}


100%|██████████| 3/3 [00:00<00:00, 55.14it/s]

Fold 5 Validation AUC: 0.8571428571428572





In [5]:
# Calculate mean and standard deviation for validation AUC scores
mean_val_auc = np.mean(val_auc_scores)
std_val_auc = np.std(val_auc_scores)

# Print the results
print(f"Validation AUC: Mean = {mean_val_auc:.4f}, Standard Deviation = {std_val_auc:.4f}")

Validation AUC: Mean = 0.9277, Standard Deviation = 0.0427


In [6]:
best_fold_idx = np.argmax(val_auc_scores)
best_model_dir = output_dirs[best_fold_idx]  # Directory of the best model
print(f"Best model found in fold {best_fold_idx + 1} with Validation AUC: {val_auc_scores[best_fold_idx]}")


Best model found in fold 4 with Validation AUC: 0.9831932773109244


In [8]:
# use the best model and do the final training
best_model_dir = '/Volumes/Raven/Research/FinetuneEmbed/results/BivalentNoMethyl/fold_4/checkpoint-12'
best_model = AutoModelForSequenceClassification.from_pretrained(best_model_dir)

full_train_dataset = TextDataset(train_texts_all, train_labels_all, tokenizer)

# Define training arguments for the final training phase
final_training_args = TrainingArguments(
    output_dir="./LongShortTF/final_model",       # Directory to save the final model
    evaluation_strategy="no",         # No evaluation during training
    save_strategy="no",            # Save the model at each epoch
    save_total_limit=1,               # Keep only the last checkpoint to save storage
    learning_rate=1e-4,
    per_device_train_batch_size=8,
    num_train_epochs=5,
    weight_decay=0.01,
    logging_steps=10000,              # Minimize logging output
    report_to="none"                  # Disable logging to external tools
)

# Initialize the Trainer with the full dataset and final training arguments
trainer = Trainer(
    model=best_model,
    args=final_training_args,
    train_dataset=full_train_dataset
)

# Evaluate on the test set
test_results = trainer.predict(test_dataset)
# Calculate AUC on the test data
test_probs = torch.nn.functional.softmax(torch.tensor(test_results.predictions), dim=1)[:, 1].numpy()
test_auc = roc_auc_score(test_results.label_ids, test_probs)
print(f"Test AUC with the best model: {test_auc}")

100%|██████████| 2/2 [00:01<00:00,  1.88it/s]

Test AUC with the best model: 0.7



