In [None]:
from transformers.trainer_callback import EarlyStoppingCallback
from transformers import TrainingArguments

# Define the EarlyStoppingCallback with debug prints
class EarlyStoppingCallbackWithDebug(EarlyStoppingCallback):
    """Stop training once ``eval_loss`` stops improving, printing every evaluation.

    The best (lowest) ``eval_loss`` seen so far is tracked by this callback
    itself, so the stopping decision does not rely on
    ``TrainingArguments.metric_for_best_model``.  After ``patience``
    consecutive evaluations without a strictly lower ``eval_loss``, training
    is stopped by setting ``control.should_training_stop``.

    Args:
        patience: Number of consecutive non-improving evaluations tolerated
            before training is stopped.
    """

    def __init__(self, patience=2):
        # The parent stores `early_stopping_patience` and initialises
        # `early_stopping_patience_counter`; both are reused below.
        super().__init__(early_stopping_patience=patience)
        # Lowest eval_loss observed so far; None until the first evaluation.
        self.best_score = None

    def on_train_begin(self, args, state, control, **kwargs):
        """Reset the tracked state at the start of each ``train()`` call.

        The parent implementation is intentionally not called: the stock
        ``EarlyStoppingCallback.on_train_begin`` requires
        ``load_best_model_at_end`` and ``metric_for_best_model`` to be set,
        and this eval_loss-based logic uses neither.
        """
        self.best_score = None
        self.early_stopping_patience_counter = 0

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        """Print the evaluation metrics and update the early-stopping counter.

        Args:
            args: The ``TrainingArguments`` of the run (unused here).
            state: The ``TrainerState`` (unused here).
            control: The ``TrainerControl``; ``should_training_stop`` is set
                to ``True`` when patience is exhausted.
            metrics: Dict of evaluation metrics passed by the trainer; the
                ``"eval_loss"`` entry is the one tracked.
        """
        print("Evaluation result:", metrics)  # Debug output for every evaluation

        current_score = metrics.get("eval_loss") if metrics else None
        if current_score is None:
            # No eval_loss to compare against: leave the counter untouched
            # instead of crashing on a `None < float` comparison.
            return

        if self.best_score is None or current_score < self.best_score:
            # Improvement: remember it and restart the patience window.
            self.best_score = current_score
            self.early_stopping_patience_counter = 0
            return

        self.early_stopping_patience_counter += 1
        if self.early_stopping_patience_counter >= self.early_stopping_patience:
            print(
                f"Early stopping triggered after {self.early_stopping_patience} "
                "evaluations without improvement."
            )
            # The supported way for a callback to end training early.
            control.should_training_stop = True

# --- Early-stopping training run --------------------------------------------
# `logging_steps`, `patience` and `SFTTrainer` are not defined in this cell;
# presumably they come from an earlier cell — verify before running.

# Run a validation pass on a step schedule (every `logging_steps * 2` steps)
# so the early-stopping callback receives a fresh eval_loss regularly.
training_arguments = TrainingArguments(
    # Other arguments...
    # NOTE(review): newer transformers releases rename this keyword to
    # `eval_strategy`; confirm against the installed version.
    evaluation_strategy="steps",
    eval_steps=logging_steps * 2,
)

# Callback that prints each evaluation and stops after `patience` evaluations
# without improvement.
early_stopping_callback = EarlyStoppingCallbackWithDebug(patience=patience)

# Register the callback on the trainer, then start training.
trainer = SFTTrainer(
    # Other arguments...
    args=training_arguments,
    callbacks=[early_stopping_callback],
)

trainer.train()
