Skip to content

Commit ec5b709

Browse files
committed
add typing to token count callback
1 parent 739e568 commit ec5b709

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

promptolution/callbacks.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import os
44
import time
5+
from typing import Literal
56

67
import numpy as np
78
import pandas as pd
@@ -242,7 +243,11 @@ def on_train_end(self, optimizer):
242243
class TokenCountCallback(Callback):
243244
"""Callback for stopping optimization based on the total token count."""
244245

245-
def __init__(self, max_tokens_for_termination, token_type_for_termination):
246+
def __init__(
247+
self,
248+
max_tokens_for_termination: int,
249+
token_type_for_termination: Literal["input_tokens", "output_tokens", "total_tokens"],
250+
):
246251
"""Initialize the TokenCountCallback.
247252
248253
Args:

0 commit comments

Comments
 (0)