In [None]:
import json
import os
import pickle
import re

# Specify the working directory
# NOTE(review): presumably this has to run before the project-local `mod` import
# below so that the package resolves from the project root — confirm.
os.chdir('/Users/david/Desktop/FinetuneEmbed')

import numpy as np
import torch
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import StratifiedKFold
from transformers import AutoTokenizer, TrainingArguments, AutoModelForSequenceClassification, EarlyStoppingCallback
from transformers import Trainer, set_seed

from mod.mod_text import *

# Load the pickled train/test splits.
# pickle.load can execute arbitrary code, so only point this at trusted files.
def _read_pickle(path):
    """Return the object deserialized from the pickle file at ``path``."""
    with open(path, "rb") as handle:
        return pickle.load(handle)


train_data = _read_pickle("./data/MethylationState/bivalent_vs_lys4/train_data.pkl")
test_data = _read_pickle("./data/MethylationState/bivalent_vs_lys4/test_data.pkl")

# Pull the text descriptions ('desc') and their labels out of each split
train_texts_all = train_data['desc']
train_labels_all = train_data['labels']
test_texts = test_data['desc']
test_labels = test_data['labels']

# Checkpoint name shared by the tokenizer here and the per-fold models later on
model_name = "sentence-transformers/all-MiniLM-L6-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)

: 

In [2]:
# Held-out test set, wrapped once so it can be reused after cross-validation
test_dataset = TextDataset(test_texts, test_labels, tokenizer)

# Stratified, shuffled K-fold splitter with a fixed seed so the folds are repeatable
n_splits = 5
kf = StratifiedKFold(shuffle=True, random_state=8, n_splits=n_splits)

# Per-fold bookkeeping filled in (or left for later use) by the following cells:
#   val_auc_scores  - validation AUC of each fold
#   test_auc_scores - AUC on the test data
#   output_dirs     - output directory used by each fold
val_auc_scores, test_auc_scores, output_dirs = [], [], []

In [3]:
# Stratified K-fold cross-validation.
# For every fold: build train/validation datasets, fine-tune a freshly initialised
# classifier with early stopping on validation AUC, then record the validation AUC
# of the model the trainer ends up with (the best one, via load_best_model_at_end).
for fold, (train_index, val_index) in enumerate(kf.split(train_texts_all, train_labels_all)):
    print(f"Fold {fold + 1}/{n_splits}")

    # Split data into training and validation for this fold
    train_texts, val_texts = [train_texts_all[i] for i in train_index], [train_texts_all[i] for i in val_index]
    train_labels, val_labels = [train_labels_all[i] for i in train_index], [train_labels_all[i] for i in val_index]

    # Create PyTorch datasets
    train_dataset = TextDataset(train_texts, train_labels, tokenizer)
    val_dataset = TextDataset(val_texts, val_labels, tokenizer)

    # Define output directory for this fold (checkpoints are written beneath it)
    output_dir = f"/Volumes/Raven/Research/FinetuneEmbed/results/BivalentLys4/fold_{fold + 1}"
    os.makedirs(output_dir, exist_ok=True)
    output_dirs.append(output_dir)

    # Training arguments
    training_args = TrainingArguments(
        output_dir=output_dir,
        evaluation_strategy="epoch",
        save_strategy="epoch", # Save the model after each epoch
        load_best_model_at_end=True, # Load the best model at the end of each fold
        save_total_limit=1, # Keep only the best model checkpoint
        learning_rate=1e-5, 
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        num_train_epochs=20,
        max_grad_norm=1.0,
        # warmup_ratio=0.1,
        weight_decay=0.01,
        metric_for_best_model="AUC",
        greater_is_better=True
    )

    # Seed all RNGs *before* the model is built. from_pretrained() randomly
    # initialises the new classification head, and transformers.Trainer only seeds
    # inside its own constructor, i.e. after the head has already been drawn.
    # Without this call the first fold's initialisation depends on whatever RNG
    # state the kernel happens to be in, so the run is not reproducible.
    # (Assumes CustomTrainer follows transformers.Trainer in this respect.)
    set_seed(training_args.seed)

    # Initialize the model and Trainer for this fold
    model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

    trainer = CustomTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics,
        eval_metric="AUC",
        # Stop once validation AUC has not improved for 5 consecutive evaluations
        callbacks=[EarlyStoppingCallback(early_stopping_patience=5)]
    )

    # Train the model on this fold
    trainer.train()

    # Evaluate on the validation set and save the best model's AUC
    val_results = trainer.evaluate()
    val_auc = val_results["eval_AUC"]
    print(f"Fold {fold + 1} Validation AUC: {val_auc}")
    val_auc_scores.append(val_auc)


Fold 1/5


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at sentence-transformers/all-MiniLM-L6-v2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
                                                
  5%|▌         | 16/320 [00:11<03:48,  1.33it/s]

{'eval_loss': 0.6809014081954956, 'eval_AUC': 0.7478991596638656, 'eval_runtime': 1.413, 'eval_samples_per_second': 21.94, 'eval_steps_per_second': 2.831, 'epoch': 1.0}


                                                
 10%|█         | 32/320 [00:16<00:40,  7.11it/s]

{'eval_loss': 0.6544880270957947, 'eval_AUC': 0.8067226890756302, 'eval_runtime': 0.1226, 'eval_samples_per_second': 252.923, 'eval_steps_per_second': 32.635, 'epoch': 2.0}


                                                
 15%|█▌        | 48/320 [00:23<00:40,  6.75it/s]

{'eval_loss': 0.6201480031013489, 'eval_AUC': 0.8235294117647058, 'eval_runtime': 0.1741, 'eval_samples_per_second': 178.042, 'eval_steps_per_second': 22.973, 'epoch': 3.0}


                                                
 20%|██        | 64/320 [00:27<00:34,  7.34it/s]

{'eval_loss': 0.5837067365646362, 'eval_AUC': 0.8151260504201681, 'eval_runtime': 0.1225, 'eval_samples_per_second': 253.022, 'eval_steps_per_second': 32.648, 'epoch': 4.0}


                                                
 25%|██▌       | 80/320 [00:32<00:31,  7.57it/s]

{'eval_loss': 0.5504868030548096, 'eval_AUC': 0.8235294117647058, 'eval_runtime': 0.1266, 'eval_samples_per_second': 244.874, 'eval_steps_per_second': 31.597, 'epoch': 5.0}


                                                
 30%|███       | 96/320 [00:37<00:31,  7.08it/s]

{'eval_loss': 0.5189170837402344, 'eval_AUC': 0.8487394957983192, 'eval_runtime': 0.1257, 'eval_samples_per_second': 246.617, 'eval_steps_per_second': 31.821, 'epoch': 6.0}


                                                 
 35%|███▌      | 112/320 [00:42<00:28,  7.34it/s]

{'eval_loss': 0.4918759763240814, 'eval_AUC': 0.8529411764705882, 'eval_runtime': 0.1206, 'eval_samples_per_second': 257.019, 'eval_steps_per_second': 33.164, 'epoch': 7.0}


                                                 
 40%|████      | 128/320 [00:47<00:25,  7.52it/s]

{'eval_loss': 0.47353553771972656, 'eval_AUC': 0.861344537815126, 'eval_runtime': 0.1272, 'eval_samples_per_second': 243.692, 'eval_steps_per_second': 31.444, 'epoch': 8.0}


                                                 
 45%|████▌     | 144/320 [00:52<00:25,  6.78it/s]

{'eval_loss': 0.46294039487838745, 'eval_AUC': 0.8571428571428571, 'eval_runtime': 0.1251, 'eval_samples_per_second': 247.784, 'eval_steps_per_second': 31.972, 'epoch': 9.0}


                                                 
 50%|█████     | 160/320 [00:57<00:21,  7.36it/s]

{'eval_loss': 0.45587044954299927, 'eval_AUC': 0.8571428571428571, 'eval_runtime': 0.1228, 'eval_samples_per_second': 252.425, 'eval_steps_per_second': 32.571, 'epoch': 10.0}


                                                 
 55%|█████▌    | 176/320 [01:03<00:19,  7.52it/s]

{'eval_loss': 0.4504850506782532, 'eval_AUC': 0.8571428571428571, 'eval_runtime': 0.1252, 'eval_samples_per_second': 247.507, 'eval_steps_per_second': 31.936, 'epoch': 11.0}


                                                 
 60%|██████    | 192/320 [01:09<00:17,  7.36it/s]

{'eval_loss': 0.44812050461769104, 'eval_AUC': 0.8529411764705882, 'eval_runtime': 0.1204, 'eval_samples_per_second': 257.524, 'eval_steps_per_second': 33.229, 'epoch': 12.0}


                                                 
 65%|██████▌   | 208/320 [01:15<00:15,  7.44it/s]

{'eval_loss': 0.44656312465667725, 'eval_AUC': 0.8529411764705882, 'eval_runtime': 0.1228, 'eval_samples_per_second': 252.519, 'eval_steps_per_second': 32.583, 'epoch': 13.0}


 65%|██████▌   | 208/320 [01:19<00:42,  2.61it/s]


{'train_runtime': 79.655, 'train_samples_per_second': 30.632, 'train_steps_per_second': 4.017, 'train_loss': 0.4904493185190054, 'epoch': 13.0}


100%|██████████| 4/4 [00:00<00:00, 42.42it/s]
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at sentence-transformers/all-MiniLM-L6-v2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Fold 1 Validation AUC: 0.861344537815126
Fold 2/5


  5%|▌         | 16/320 [00:02<00:40,  7.48it/s]
  5%|▌         | 16/320 [00:02<00:40,  7.48it/s]

{'eval_loss': 0.6747286319732666, 'eval_AUC': 0.8865546218487395, 'eval_runtime': 0.122, 'eval_samples_per_second': 254.18, 'eval_steps_per_second': 32.797, 'epoch': 1.0}


 10%|█         | 32/320 [00:07<00:39,  7.34it/s]
 10%|█         | 32/320 [00:08<00:39,  7.34it/s]

{'eval_loss': 0.6468161344528198, 'eval_AUC': 0.9033613445378151, 'eval_runtime': 0.1244, 'eval_samples_per_second': 249.1, 'eval_steps_per_second': 32.142, 'epoch': 2.0}


 15%|█▌        | 48/320 [00:13<00:38,  7.06it/s]
 15%|█▌        | 48/320 [00:13<00:38,  7.06it/s]

{'eval_loss': 0.5982323884963989, 'eval_AUC': 0.8991596638655462, 'eval_runtime': 0.1225, 'eval_samples_per_second': 252.971, 'eval_steps_per_second': 32.641, 'epoch': 3.0}


 20%|██        | 64/320 [00:18<00:34,  7.36it/s]
 20%|██        | 64/320 [00:18<00:34,  7.36it/s]

{'eval_loss': 0.5516197085380554, 'eval_AUC': 0.8991596638655462, 'eval_runtime': 0.125, 'eval_samples_per_second': 247.904, 'eval_steps_per_second': 31.988, 'epoch': 4.0}


 25%|██▌       | 80/320 [00:25<00:33,  7.15it/s]
 25%|██▌       | 80/320 [00:25<00:33,  7.15it/s]

{'eval_loss': 0.5111310482025146, 'eval_AUC': 0.8991596638655462, 'eval_runtime': 0.1218, 'eval_samples_per_second': 254.618, 'eval_steps_per_second': 32.854, 'epoch': 5.0}


 30%|███       | 96/320 [00:32<00:31,  7.06it/s]
 30%|███       | 96/320 [00:32<00:31,  7.06it/s]

{'eval_loss': 0.4767445921897888, 'eval_AUC': 0.907563025210084, 'eval_runtime': 0.119, 'eval_samples_per_second': 260.607, 'eval_steps_per_second': 33.627, 'epoch': 6.0}


 35%|███▌      | 112/320 [00:37<00:27,  7.47it/s]
 35%|███▌      | 112/320 [00:37<00:27,  7.47it/s]

{'eval_loss': 0.4456208646297455, 'eval_AUC': 0.9117647058823529, 'eval_runtime': 0.1294, 'eval_samples_per_second': 239.617, 'eval_steps_per_second': 30.918, 'epoch': 7.0}


 40%|████      | 128/320 [00:43<00:25,  7.51it/s]
 40%|████      | 128/320 [00:43<00:25,  7.51it/s]

{'eval_loss': 0.42406827211380005, 'eval_AUC': 0.9117647058823529, 'eval_runtime': 0.1207, 'eval_samples_per_second': 256.941, 'eval_steps_per_second': 33.154, 'epoch': 8.0}


 45%|████▌     | 144/320 [00:48<00:22,  7.70it/s]
 45%|████▌     | 144/320 [00:48<00:22,  7.70it/s]

{'eval_loss': 0.40709927678108215, 'eval_AUC': 0.8991596638655461, 'eval_runtime': 0.1243, 'eval_samples_per_second': 249.463, 'eval_steps_per_second': 32.189, 'epoch': 9.0}


 50%|█████     | 160/320 [00:54<00:23,  6.93it/s]
 50%|█████     | 160/320 [00:54<00:23,  6.93it/s]

{'eval_loss': 0.3961624205112457, 'eval_AUC': 0.903361344537815, 'eval_runtime': 0.1251, 'eval_samples_per_second': 247.895, 'eval_steps_per_second': 31.986, 'epoch': 10.0}


 55%|█████▌    | 176/320 [01:00<00:19,  7.48it/s]
 55%|█████▌    | 176/320 [01:00<00:19,  7.48it/s]

{'eval_loss': 0.3912981450557709, 'eval_AUC': 0.903361344537815, 'eval_runtime': 0.123, 'eval_samples_per_second': 252.107, 'eval_steps_per_second': 32.53, 'epoch': 11.0}


 60%|██████    | 192/320 [01:07<00:16,  7.53it/s]
 60%|██████    | 192/320 [01:07<00:16,  7.53it/s]

{'eval_loss': 0.3905177414417267, 'eval_AUC': 0.903361344537815, 'eval_runtime': 0.1232, 'eval_samples_per_second': 251.579, 'eval_steps_per_second': 32.462, 'epoch': 12.0}


 60%|██████    | 192/320 [01:16<00:51,  2.50it/s]


{'train_runtime': 76.6567, 'train_samples_per_second': 31.83, 'train_steps_per_second': 4.174, 'train_loss': 0.5245353778203329, 'epoch': 12.0}


100%|██████████| 4/4 [00:00<00:00, 51.83it/s]
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at sentence-transformers/all-MiniLM-L6-v2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Fold 2 Validation AUC: 0.9117647058823529
Fold 3/5


  5%|▌         | 16/320 [00:02<00:39,  7.76it/s]
  5%|▌         | 16/320 [00:02<00:39,  7.76it/s]

{'eval_loss': 0.6850272417068481, 'eval_AUC': 0.8865546218487396, 'eval_runtime': 0.1227, 'eval_samples_per_second': 252.717, 'eval_steps_per_second': 32.609, 'epoch': 1.0}


 10%|█         | 32/320 [00:08<00:38,  7.55it/s]
 10%|█         | 32/320 [00:08<00:38,  7.55it/s]

{'eval_loss': 0.6693935394287109, 'eval_AUC': 0.9453781512605042, 'eval_runtime': 0.1541, 'eval_samples_per_second': 201.154, 'eval_steps_per_second': 25.955, 'epoch': 2.0}


 15%|█▌        | 48/320 [00:13<00:35,  7.61it/s]
 15%|█▌        | 48/320 [00:13<00:35,  7.61it/s]

{'eval_loss': 0.6446230411529541, 'eval_AUC': 0.9579831932773109, 'eval_runtime': 0.1244, 'eval_samples_per_second': 249.142, 'eval_steps_per_second': 32.147, 'epoch': 3.0}


 20%|██        | 64/320 [00:18<00:33,  7.64it/s]
 20%|██        | 64/320 [00:19<00:33,  7.64it/s]

{'eval_loss': 0.6033316254615784, 'eval_AUC': 0.953781512605042, 'eval_runtime': 0.123, 'eval_samples_per_second': 252.092, 'eval_steps_per_second': 32.528, 'epoch': 4.0}


 25%|██▌       | 80/320 [00:25<00:32,  7.47it/s]
 25%|██▌       | 80/320 [00:25<00:32,  7.47it/s]

{'eval_loss': 0.554007887840271, 'eval_AUC': 0.953781512605042, 'eval_runtime': 0.1196, 'eval_samples_per_second': 259.182, 'eval_steps_per_second': 33.443, 'epoch': 5.0}


 30%|███       | 96/320 [00:31<00:29,  7.57it/s]
 30%|███       | 96/320 [00:31<00:29,  7.57it/s]

{'eval_loss': 0.5058214664459229, 'eval_AUC': 0.9621848739495797, 'eval_runtime': 0.1235, 'eval_samples_per_second': 251.008, 'eval_steps_per_second': 32.388, 'epoch': 6.0}


 35%|███▌      | 112/320 [00:37<00:28,  7.41it/s]
 35%|███▌      | 112/320 [00:38<00:28,  7.41it/s]

{'eval_loss': 0.4657045900821686, 'eval_AUC': 0.9579831932773109, 'eval_runtime': 0.1217, 'eval_samples_per_second': 254.67, 'eval_steps_per_second': 32.861, 'epoch': 7.0}


 40%|████      | 128/320 [00:44<00:25,  7.53it/s]
 40%|████      | 128/320 [00:44<00:25,  7.53it/s]

{'eval_loss': 0.4277274012565613, 'eval_AUC': 0.9579831932773109, 'eval_runtime': 0.1165, 'eval_samples_per_second': 266.131, 'eval_steps_per_second': 34.339, 'epoch': 8.0}


 45%|████▌     | 144/320 [00:49<00:23,  7.55it/s]
 45%|████▌     | 144/320 [00:49<00:23,  7.55it/s]

{'eval_loss': 0.39889368414878845, 'eval_AUC': 0.9621848739495797, 'eval_runtime': 0.124, 'eval_samples_per_second': 250.009, 'eval_steps_per_second': 32.259, 'epoch': 9.0}


 50%|█████     | 160/320 [00:54<00:21,  7.60it/s]
 50%|█████     | 160/320 [00:55<00:21,  7.60it/s]

{'eval_loss': 0.37272441387176514, 'eval_AUC': 0.9579831932773109, 'eval_runtime': 0.125, 'eval_samples_per_second': 248.016, 'eval_steps_per_second': 32.002, 'epoch': 10.0}


 55%|█████▌    | 176/320 [01:00<00:19,  7.45it/s]
 55%|█████▌    | 176/320 [01:00<00:19,  7.45it/s]

{'eval_loss': 0.35272905230522156, 'eval_AUC': 0.9705882352941178, 'eval_runtime': 0.1184, 'eval_samples_per_second': 261.866, 'eval_steps_per_second': 33.789, 'epoch': 11.0}


 60%|██████    | 192/320 [01:06<00:16,  7.60it/s]
 60%|██████    | 192/320 [01:06<00:16,  7.60it/s]

{'eval_loss': 0.33199217915534973, 'eval_AUC': 0.9663865546218489, 'eval_runtime': 0.1212, 'eval_samples_per_second': 255.723, 'eval_steps_per_second': 32.997, 'epoch': 12.0}


 65%|██████▌   | 208/320 [01:12<00:15,  7.01it/s]
 65%|██████▌   | 208/320 [01:12<00:15,  7.01it/s]

{'eval_loss': 0.3169700801372528, 'eval_AUC': 0.9705882352941176, 'eval_runtime': 0.1375, 'eval_samples_per_second': 225.523, 'eval_steps_per_second': 29.1, 'epoch': 13.0}


 70%|███████   | 224/320 [01:18<00:13,  7.31it/s]
 70%|███████   | 224/320 [01:18<00:13,  7.31it/s]

{'eval_loss': 0.3044978976249695, 'eval_AUC': 0.9705882352941176, 'eval_runtime': 0.1213, 'eval_samples_per_second': 255.556, 'eval_steps_per_second': 32.975, 'epoch': 14.0}


 75%|███████▌  | 240/320 [01:24<00:10,  7.55it/s]
 75%|███████▌  | 240/320 [01:24<00:10,  7.55it/s]

{'eval_loss': 0.2950991690158844, 'eval_AUC': 0.9705882352941176, 'eval_runtime': 0.1212, 'eval_samples_per_second': 255.859, 'eval_steps_per_second': 33.014, 'epoch': 15.0}


 80%|████████  | 256/320 [01:30<00:08,  7.50it/s]
 80%|████████  | 256/320 [01:30<00:08,  7.50it/s]

{'eval_loss': 0.2850187122821808, 'eval_AUC': 0.9705882352941176, 'eval_runtime': 0.1236, 'eval_samples_per_second': 250.783, 'eval_steps_per_second': 32.359, 'epoch': 16.0}


 80%|████████  | 256/320 [01:34<00:23,  2.71it/s]


{'train_runtime': 94.3193, 'train_samples_per_second': 25.87, 'train_steps_per_second': 3.393, 'train_loss': 0.4803504943847656, 'epoch': 16.0}


100%|██████████| 4/4 [00:00<00:00, 30.25it/s]
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at sentence-transformers/all-MiniLM-L6-v2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Fold 3 Validation AUC: 0.9705882352941178
Fold 4/5


  5%|▌         | 16/320 [00:04<04:19,  1.17it/s]
  5%|▌         | 16/320 [00:05<04:19,  1.17it/s]

{'eval_loss': 0.6814424395561218, 'eval_AUC': 0.7285067873303167, 'eval_runtime': 1.1958, 'eval_samples_per_second': 25.088, 'eval_steps_per_second': 3.345, 'epoch': 1.0}


 10%|█         | 32/320 [00:12<00:42,  6.78it/s]
 10%|█         | 32/320 [00:12<00:42,  6.78it/s]

{'eval_loss': 0.6624454259872437, 'eval_AUC': 0.7782805429864253, 'eval_runtime': 0.1272, 'eval_samples_per_second': 235.911, 'eval_steps_per_second': 31.455, 'epoch': 2.0}


 15%|█▌        | 48/320 [00:18<00:37,  7.29it/s]
 15%|█▌        | 48/320 [00:18<00:37,  7.29it/s]

{'eval_loss': 0.63033127784729, 'eval_AUC': 0.7918552036199095, 'eval_runtime': 0.137, 'eval_samples_per_second': 218.949, 'eval_steps_per_second': 29.193, 'epoch': 3.0}


 20%|██        | 64/320 [00:22<00:33,  7.66it/s]
 20%|██        | 64/320 [00:23<00:33,  7.66it/s]

{'eval_loss': 0.5929999947547913, 'eval_AUC': 0.7963800904977376, 'eval_runtime': 0.1193, 'eval_samples_per_second': 251.446, 'eval_steps_per_second': 33.526, 'epoch': 4.0}


 25%|██▌       | 80/320 [00:28<00:32,  7.45it/s]
 25%|██▌       | 80/320 [00:28<00:32,  7.45it/s]

{'eval_loss': 0.5626747012138367, 'eval_AUC': 0.8144796380090499, 'eval_runtime': 0.1234, 'eval_samples_per_second': 243.144, 'eval_steps_per_second': 32.419, 'epoch': 5.0}


 30%|███       | 96/320 [00:33<00:29,  7.52it/s]
 30%|███       | 96/320 [00:33<00:29,  7.52it/s]

{'eval_loss': 0.5382727384567261, 'eval_AUC': 0.8235294117647058, 'eval_runtime': 0.1223, 'eval_samples_per_second': 245.301, 'eval_steps_per_second': 32.707, 'epoch': 6.0}


 35%|███▌      | 112/320 [00:38<00:27,  7.57it/s]
 35%|███▌      | 112/320 [00:38<00:27,  7.57it/s]

{'eval_loss': 0.5157743096351624, 'eval_AUC': 0.8461538461538461, 'eval_runtime': 0.1214, 'eval_samples_per_second': 247.166, 'eval_steps_per_second': 32.955, 'epoch': 7.0}


 40%|████      | 128/320 [00:43<00:25,  7.51it/s]
 40%|████      | 128/320 [00:43<00:25,  7.51it/s]

{'eval_loss': 0.5070881247520447, 'eval_AUC': 0.8461538461538461, 'eval_runtime': 0.1233, 'eval_samples_per_second': 243.321, 'eval_steps_per_second': 32.443, 'epoch': 8.0}


 45%|████▌     | 144/320 [00:48<00:23,  7.59it/s]
 45%|████▌     | 144/320 [00:48<00:23,  7.59it/s]

{'eval_loss': 0.5017325282096863, 'eval_AUC': 0.8506787330316742, 'eval_runtime': 0.1248, 'eval_samples_per_second': 240.362, 'eval_steps_per_second': 32.048, 'epoch': 9.0}


 50%|█████     | 160/320 [00:54<00:21,  7.45it/s]
 50%|█████     | 160/320 [00:54<00:21,  7.45it/s]

{'eval_loss': 0.4926769435405731, 'eval_AUC': 0.8506787330316743, 'eval_runtime': 0.1241, 'eval_samples_per_second': 241.708, 'eval_steps_per_second': 32.228, 'epoch': 10.0}


 55%|█████▌    | 176/320 [00:59<00:19,  7.22it/s]
 55%|█████▌    | 176/320 [00:59<00:19,  7.22it/s]

{'eval_loss': 0.4943302273750305, 'eval_AUC': 0.8461538461538461, 'eval_runtime': 0.1235, 'eval_samples_per_second': 242.886, 'eval_steps_per_second': 32.385, 'epoch': 11.0}


 60%|██████    | 192/320 [01:05<00:17,  7.40it/s]
 60%|██████    | 192/320 [01:05<00:17,  7.40it/s]

{'eval_loss': 0.4988306164741516, 'eval_AUC': 0.8371040723981901, 'eval_runtime': 0.1313, 'eval_samples_per_second': 228.413, 'eval_steps_per_second': 30.455, 'epoch': 12.0}


 65%|██████▌   | 208/320 [01:10<00:15,  7.39it/s]
 65%|██████▌   | 208/320 [01:10<00:15,  7.39it/s]

{'eval_loss': 0.49888384342193604, 'eval_AUC': 0.8461538461538461, 'eval_runtime': 0.1224, 'eval_samples_per_second': 245.068, 'eval_steps_per_second': 32.676, 'epoch': 13.0}


 70%|███████   | 224/320 [01:15<00:12,  7.50it/s]
 70%|███████   | 224/320 [01:15<00:12,  7.50it/s]

{'eval_loss': 0.5077879428863525, 'eval_AUC': 0.8416289592760181, 'eval_runtime': 0.1225, 'eval_samples_per_second': 244.8, 'eval_steps_per_second': 32.64, 'epoch': 14.0}


 75%|███████▌  | 240/320 [01:21<00:11,  7.22it/s]
 75%|███████▌  | 240/320 [01:21<00:11,  7.22it/s]

{'eval_loss': 0.5176594853401184, 'eval_AUC': 0.8371040723981901, 'eval_runtime': 0.1233, 'eval_samples_per_second': 243.339, 'eval_steps_per_second': 32.445, 'epoch': 15.0}


 75%|███████▌  | 240/320 [01:25<00:28,  2.82it/s]


{'train_runtime': 85.0126, 'train_samples_per_second': 28.937, 'train_steps_per_second': 3.764, 'train_loss': 0.459443473815918, 'epoch': 15.0}


100%|██████████| 4/4 [00:00<00:00, 43.87it/s]


Fold 4 Validation AUC: 0.8506787330316743
Fold 5/5


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at sentence-transformers/all-MiniLM-L6-v2 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  5%|▌         | 16/320 [00:02<00:40,  7.57it/s]
  5%|▌         | 16/320 [00:02<00:40,  7.57it/s]

{'eval_loss': 0.687924325466156, 'eval_AUC': 0.7692307692307693, 'eval_runtime': 0.1212, 'eval_samples_per_second': 247.605, 'eval_steps_per_second': 33.014, 'epoch': 1.0}


 10%|█         | 32/320 [00:07<00:39,  7.32it/s]
 10%|█         | 32/320 [00:08<00:39,  7.32it/s]

{'eval_loss': 0.6700284481048584, 'eval_AUC': 0.9049773755656108, 'eval_runtime': 0.1244, 'eval_samples_per_second': 241.165, 'eval_steps_per_second': 32.155, 'epoch': 2.0}


 15%|█▌        | 48/320 [00:13<00:37,  7.30it/s]
 15%|█▌        | 48/320 [00:14<00:37,  7.30it/s]

{'eval_loss': 0.6387723684310913, 'eval_AUC': 0.9366515837104072, 'eval_runtime': 0.1211, 'eval_samples_per_second': 247.654, 'eval_steps_per_second': 33.02, 'epoch': 3.0}


 20%|██        | 64/320 [00:19<00:34,  7.36it/s]
 20%|██        | 64/320 [00:19<00:34,  7.36it/s]

{'eval_loss': 0.590529203414917, 'eval_AUC': 0.9547511312217195, 'eval_runtime': 0.1245, 'eval_samples_per_second': 241.02, 'eval_steps_per_second': 32.136, 'epoch': 4.0}


 25%|██▌       | 80/320 [00:25<00:33,  7.16it/s]
 25%|██▌       | 80/320 [00:25<00:33,  7.16it/s]

{'eval_loss': 0.5357155203819275, 'eval_AUC': 0.9638009049773756, 'eval_runtime': 0.1241, 'eval_samples_per_second': 241.681, 'eval_steps_per_second': 32.224, 'epoch': 5.0}


 30%|███       | 96/320 [00:31<00:30,  7.41it/s]
 30%|███       | 96/320 [00:31<00:30,  7.41it/s]

{'eval_loss': 0.4848749041557312, 'eval_AUC': 0.9683257918552036, 'eval_runtime': 0.1211, 'eval_samples_per_second': 247.719, 'eval_steps_per_second': 33.029, 'epoch': 6.0}


 35%|███▌      | 112/320 [00:37<00:28,  7.27it/s]
 35%|███▌      | 112/320 [00:38<00:28,  7.27it/s]

{'eval_loss': 0.4371805787086487, 'eval_AUC': 0.9864253393665158, 'eval_runtime': 0.1221, 'eval_samples_per_second': 245.726, 'eval_steps_per_second': 32.763, 'epoch': 7.0}


 40%|████      | 128/320 [00:43<00:26,  7.24it/s]
 40%|████      | 128/320 [00:43<00:26,  7.24it/s]

{'eval_loss': 0.3914248049259186, 'eval_AUC': 0.995475113122172, 'eval_runtime': 0.1173, 'eval_samples_per_second': 255.709, 'eval_steps_per_second': 34.094, 'epoch': 8.0}


 45%|████▌     | 144/320 [00:50<00:24,  7.12it/s]
 45%|████▌     | 144/320 [00:50<00:24,  7.12it/s]

{'eval_loss': 0.36455753445625305, 'eval_AUC': 0.995475113122172, 'eval_runtime': 0.1218, 'eval_samples_per_second': 246.299, 'eval_steps_per_second': 32.84, 'epoch': 9.0}


 50%|█████     | 160/320 [00:56<00:21,  7.50it/s]
 50%|█████     | 160/320 [00:56<00:21,  7.50it/s]

{'eval_loss': 0.3462837338447571, 'eval_AUC': 0.9909502262443439, 'eval_runtime': 0.1188, 'eval_samples_per_second': 252.491, 'eval_steps_per_second': 33.666, 'epoch': 10.0}


 55%|█████▌    | 176/320 [01:03<00:19,  7.24it/s]
 55%|█████▌    | 176/320 [01:03<00:19,  7.24it/s]

{'eval_loss': 0.31820032000541687, 'eval_AUC': 0.9909502262443439, 'eval_runtime': 0.1206, 'eval_samples_per_second': 248.744, 'eval_steps_per_second': 33.166, 'epoch': 11.0}


 60%|██████    | 192/320 [01:08<00:17,  7.33it/s]
 60%|██████    | 192/320 [01:09<00:17,  7.33it/s]

{'eval_loss': 0.3006727397441864, 'eval_AUC': 0.9909502262443439, 'eval_runtime': 0.1233, 'eval_samples_per_second': 243.345, 'eval_steps_per_second': 32.446, 'epoch': 12.0}


 65%|██████▌   | 208/320 [01:14<00:15,  7.37it/s]
 65%|██████▌   | 208/320 [01:14<00:15,  7.37it/s]

{'eval_loss': 0.2964152991771698, 'eval_AUC': 0.9909502262443439, 'eval_runtime': 0.1225, 'eval_samples_per_second': 244.878, 'eval_steps_per_second': 32.65, 'epoch': 13.0}


 65%|██████▌   | 208/320 [01:19<00:42,  2.63it/s]


{'train_runtime': 79.1706, 'train_samples_per_second': 31.072, 'train_steps_per_second': 4.042, 'train_loss': 0.5258457110478327, 'epoch': 13.0}


100%|██████████| 4/4 [00:00<00:00, 51.58it/s]

Fold 5 Validation AUC: 0.995475113122172





In [4]:
# Summarise cross-validation performance over the folds.
# std() is used with its default ddof=0, i.e. the population standard deviation.
fold_aucs = np.asarray(val_auc_scores)
mean_val_auc = fold_aucs.mean()
std_val_auc = fold_aucs.std()

# Report both statistics to four decimal places
summary = f"Validation AUC: Mean = {mean_val_auc:.4f}, Standard Deviation = {std_val_auc:.4f}"
print(summary)

Validation AUC: Mean = 0.9180, Standard Deviation = 0.0575


In [5]:
# Index (0-based) of the fold with the highest validation AUC
best_fold_idx = np.asarray(val_auc_scores).argmax()

# Output directory of that fold; the saved checkpoints live in sub-directories of it
best_model_dir = output_dirs[best_fold_idx]

print(
    f"Best model found in fold {best_fold_idx + 1} with Validation AUC: {val_auc_scores[best_fold_idx]}"
)


Best model found in fold 5 with Validation AUC: 0.995475113122172


In [7]:
# Evaluate the best cross-validation model on the held-out test set.
#
# NOTE: although a training configuration and the full training set are prepared
# below, trainer.train() is never called in this cell, so the reported test AUC is
# that of the selected fold's best checkpoint as-is (no final re-training).


def _best_checkpoint_in(fold_dir):
    """Return the path of the best checkpoint recorded under ``fold_dir``.

    The newest ``checkpoint-<step>`` sub-directory is located and the
    ``best_model_checkpoint`` entry of its ``trainer_state.json`` is returned.

    Raises:
        FileNotFoundError: if ``fold_dir`` holds no ``checkpoint-<step>``
            directory, or the recorded best checkpoint cannot be found on disk.
        ValueError: if the saved trainer state names no best checkpoint.
    """
    # Map training step -> checkpoint directory
    checkpoints = {}
    for entry in os.listdir(fold_dir):
        match = re.fullmatch(r"checkpoint-(\d+)", entry)
        if match:
            checkpoints[int(match.group(1))] = os.path.join(fold_dir, entry)
    if not checkpoints:
        raise FileNotFoundError(f"No checkpoint-* directory found in {fold_dir}")

    # The most recent checkpoint carries the final trainer state of the run
    latest = checkpoints[max(checkpoints)]
    with open(os.path.join(latest, "trainer_state.json")) as handle:
        best = json.load(handle).get("best_model_checkpoint")
    if not best:
        raise ValueError(f"{latest}/trainer_state.json records no best_model_checkpoint")

    if not os.path.isdir(best):
        # The recorded path may be stale if the results folder was moved; retry
        # with the same checkpoint name relative to the fold directory.
        relocated = os.path.join(fold_dir, os.path.basename(os.path.normpath(best)))
        if not os.path.isdir(relocated):
            raise FileNotFoundError(f"Best checkpoint {best} not found")
        best = relocated
    return best


# Resolve the checkpoint from the fold selected above instead of hard-coding
# ".../fold_5/checkpoint-128", which silently goes stale whenever CV is re-run.
best_model_dir = _best_checkpoint_in(output_dirs[best_fold_idx])
print(f"Loading best checkpoint: {best_model_dir}")
best_model = AutoModelForSequenceClassification.from_pretrained(best_model_dir)

# Full training set; handed to the Trainer below but not consumed by predict()
full_train_dataset = TextDataset(train_texts_all, train_labels_all, tokenizer)

# Arguments for the Trainer used for prediction. The optimisation settings only
# matter if trainer.train() is called, which this cell does not do.
final_training_args = TrainingArguments(
    # NOTE(review): this path names a different experiment ("LongShortTF") — confirm it is intended
    output_dir="./LongShortTF/final_model",
    evaluation_strategy="no",         # No evaluation during training
    save_strategy="no",               # Do not write checkpoints
    save_total_limit=1,               # Keep only the last checkpoint to save storage
    learning_rate=1e-4,
    per_device_train_batch_size=8,
    num_train_epochs=5,
    weight_decay=0.01,
    logging_steps=10000,              # Minimize logging output
    report_to="none"                  # Disable logging to external tools
)

# Wrap the loaded model in a Trainer so predict() can be used on the test set
trainer = Trainer(
    model=best_model,
    args=final_training_args,
    train_dataset=full_train_dataset
)

# Predict on the test set
test_results = trainer.predict(test_dataset)
# Turn the logits into the probability of class 1 and score it with ROC AUC
test_probs = torch.nn.functional.softmax(torch.tensor(test_results.predictions), dim=1)[:, 1].numpy()
test_auc = roc_auc_score(test_results.label_ids, test_probs)
print(f"Test AUC with the best model: {test_auc}")

100%|██████████| 3/3 [00:00<00:00, 15.12it/s]


Test AUC with the best model: 0.925
