diff --git a/model2vec/train/classifier.py b/model2vec/train/classifier.py index 9ec30ee0..d90986a8 100644 --- a/model2vec/train/classifier.py +++ b/model2vec/train/classifier.py @@ -96,6 +96,8 @@ def fit( y: list[str], learning_rate: float = 1e-3, batch_size: int | None = None, + min_epochs: int | None = None, + max_epochs: int | None = -1, early_stopping_patience: int | None = 5, test_size: float = 0.1, device: str = "auto", @@ -114,6 +116,9 @@ def fit( :param learning_rate: The learning rate. :param batch_size: The batch size. If this is None, a good batch size is chosen automatically. + :param min_epochs: The minimum number of epochs to train for. + :param max_epochs: The maximum number of epochs to train for. + If this is -1, the model trains until early stopping is triggered. :param early_stopping_patience: The patience for early stopping. If this is None, early stopping is disabled. :param test_size: The test size for the train-test split. @@ -158,7 +163,8 @@ def fit( with TemporaryDirectory() as tempdir: trainer = pl.Trainer( - max_epochs=500, + min_epochs=min_epochs, + max_epochs=max_epochs, callbacks=callbacks, val_check_interval=val_check_interval, check_val_every_n_epoch=check_val_every_epoch,