From f83be1a68c9fccfcc6ac84b0e6cf75628504c2c7 Mon Sep 17 00:00:00 2001 From: Pringled Date: Fri, 14 Feb 2025 09:39:28 +0100 Subject: [PATCH 1/2] Added min and max epochs to fit --- model2vec/train/classifier.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/model2vec/train/classifier.py b/model2vec/train/classifier.py index 9ec30ee0..67264273 100644 --- a/model2vec/train/classifier.py +++ b/model2vec/train/classifier.py @@ -96,6 +96,8 @@ def fit( y: list[str], learning_rate: float = 1e-3, batch_size: int | None = None, + min_epochs: int = 1, + max_epochs: int = 500, early_stopping_patience: int | None = 5, test_size: float = 0.1, device: str = "auto", @@ -114,6 +116,8 @@ def fit( :param learning_rate: The learning rate. :param batch_size: The batch size. If this is None, a good batch size is chosen automatically. + :param min_epochs: The minimum number of epochs to train for. + :param max_epochs: The maximum number of epochs to train for. :param early_stopping_patience: The patience for early stopping. If this is None, early stopping is disabled. :param test_size: The test size for the train-test split. @@ -158,7 +162,8 @@ def fit( with TemporaryDirectory() as tempdir: trainer = pl.Trainer( - max_epochs=500, + min_epochs=min_epochs, + max_epochs=max_epochs, callbacks=callbacks, val_check_interval=val_check_interval, check_val_every_n_epoch=check_val_every_epoch, From fba502ae308a2f40b25618a1ca440059c79d77e3 Mon Sep 17 00:00:00 2001 From: Pringled Date: Fri, 14 Feb 2025 09:58:15 +0100 Subject: [PATCH 2/2] Updated default args --- model2vec/train/classifier.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/model2vec/train/classifier.py b/model2vec/train/classifier.py index 67264273..d90986a8 100644 --- a/model2vec/train/classifier.py +++ b/model2vec/train/classifier.py @@ -96,8 +96,8 @@ def fit( y: list[str], learning_rate: float = 1e-3, batch_size: int | None = None, - min_epochs: int = 1, - max_epochs: int = 500, + min_epochs: int | None = None, + max_epochs: int | None = -1, early_stopping_patience: int | None = 5, test_size: float = 0.1, device: str = "auto", @@ -118,6 +118,7 @@ def fit( If this is None, a good batch size is chosen automatically. :param min_epochs: The minimum number of epochs to train for. :param max_epochs: The maximum number of epochs to train for. + If this is -1, the model trains until early stopping is triggered. :param early_stopping_patience: The patience for early stopping. If this is None, early stopping is disabled. :param test_size: The test size for the train-test split.