Commit d2fecf7

Feature/deterministic (#35)
* make vllm class deterministic
* fixes in prompt creation
1 parent ec5b709 commit d2fecf7

7 files changed: +50 -32 lines changed


promptolution/callbacks.py

Lines changed: 12 additions & 9 deletions
@@ -1,7 +1,7 @@
 """Callback classes for logging, saving, and tracking optimization progress."""
 
 import os
-import time
+from datetime import datetime
 from typing import Literal
 
 import numpy as np
@@ -64,7 +64,8 @@ def __init__(self, logger):
     def on_step_end(self, optimizer):
         """Log information about the current step."""
         self.step += 1
-        self.logger.critical(f"✨Step {self.step} ended✨")
+        time = datetime.now().strftime("%d-%m-%y %H:%M:%S:%f")
+        self.logger.critical(f"{time} - ✨Step {self.step} ended✨")
         for i, (prompt, score) in enumerate(zip(optimizer.prompts, optimizer.scores)):
             self.logger.critical(f"*** Prompt {i}: Score: {score}")
             self.logger.critical(f"{prompt}")
@@ -78,10 +79,11 @@ def on_train_end(self, optimizer, logs=None):
             optimizer: The optimizer object that called the callback.
             logs: Additional information to log.
         """
+        time = datetime.now().strftime("%d-%m-%y %H:%M:%S:%f")
         if logs is None:
-            self.logger.critical("Training ended")
+            self.logger.critical(f"{time} - Training ended")
         else:
-            self.logger.critical(f"Training ended - {logs}")
+            self.logger.critical(f"{time} - Training ended - {logs}")
 
         return True
 
@@ -109,8 +111,8 @@ def __init__(self, dir):
         self.step = 0
         self.input_tokens = 0
         self.output_tokens = 0
-        self.start_time = time.time()
-        self.step_time = time.time()
+        self.start_time = datetime.now()
+        self.step_time = datetime.now()
 
     def on_step_end(self, optimizer):
         """Save prompts and scores to csv.
@@ -124,12 +126,12 @@ def on_step_end(self, optimizer):
                 "step": [self.step] * len(optimizer.prompts),
                 "input_tokens": [optimizer.meta_llm.input_token_count - self.input_tokens] * len(optimizer.prompts),
                 "output_tokens": [optimizer.meta_llm.output_token_count - self.output_tokens] * len(optimizer.prompts),
-                "time_elapsed": [time.time() - self.step_time] * len(optimizer.prompts),
+                "time_elapsed": [(datetime.now() - self.step_time).total_seconds()] * len(optimizer.prompts),
                 "score": optimizer.scores,
                 "prompt": optimizer.prompts,
             }
         )
-        self.step_time = time.time()
+        self.step_time = datetime.now()
        self.input_tokens = optimizer.meta_llm.input_token_count
        self.output_tokens = optimizer.meta_llm.output_token_count
 
@@ -151,7 +153,8 @@ def on_train_end(self, optimizer):
                 steps=self.step,
                 input_tokens=optimizer.meta_llm.input_token_count,
                 output_tokens=optimizer.meta_llm.output_token_count,
-                time_elapsed=time.time() - self.start_time,
+                time_elapsed=(datetime.now() - self.start_time).total_seconds(),
+                time=datetime.now(),
                 score=np.array(optimizer.scores).mean(),
                 best_prompts=str(optimizer.prompts),
             ),
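The change above replaces time.time() with datetime.now() so the callbacks can both log a formatted timestamp and compute elapsed seconds. A minimal standalone sketch of that pattern (variable names here are illustrative, not from the repo):

    from datetime import datetime

    step_start = datetime.now()
    # ... run one optimization step ...
    elapsed_seconds = (datetime.now() - step_start).total_seconds()  # float, like a time.time() delta
    timestamp = datetime.now().strftime("%d-%m-%y %H:%M:%S:%f")      # same format as the log messages
    print(f"{timestamp} - step took {elapsed_seconds:.3f}s")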

promptolution/helpers.py

Lines changed: 5 additions & 2 deletions
@@ -27,7 +27,7 @@ def run_experiment(config: Config):
     return df
 
 
-def run_optimization(config: Config, callbacks: List = None):
+def run_optimization(config: Config, callbacks: List = None, use_token: bool = False):
     """Run the optimization phase of the experiment.
 
     Args:
@@ -37,7 +37,10 @@ def run_optimization(config: Config, callbacks: List = None):
         List[str]: The optimized list of prompts.
     """
     task = get_task(config)
-    llm = get_llm(config.meta_llm, token=config.api_token, model_storage_path=config.model_storage_path)
+    if use_token:
+        llm = get_llm(config.meta_llm, token=config.api_token)
+    else:
+        llm = get_llm(config.meta_llm, model_storage_path=config.model_storage_path, seed=config.random_seed)
     if config.predictor == "MarkerBasedClassificator":
         predictor = MarkerBasedClassificator(llm, classes=task.classes)
     elif config.predictor == "FirstOccurenceClassificator":
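With the new use_token flag, run_optimization either builds an API-backed meta LLM from config.api_token or a locally stored model seeded with config.random_seed. A hedged usage sketch; the Config import path and any fields beyond those touched in this diff are assumptions:

    from promptolution.config import Config      # assumed import path
    from promptolution.helpers import run_optimization

    config = Config(
        meta_llm="some-model-id",                # placeholder
        api_token="...",                         # only used when use_token=True
        model_storage_path="/path/to/models",    # placeholder
        random_seed=42,
        # ... other task/optimizer fields required by Config ...
    )

    prompts = run_optimization(config)                  # local model, deterministic via the seed
    prompts = run_optimization(config, use_token=True)  # API model via the token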

promptolution/llms/vllm.py

Lines changed: 6 additions & 8 deletions
@@ -44,12 +44,12 @@ def __init__(
         temperature: float = 0.1,
         top_p: float = 0.9,
         model_storage_path: str | None = None,
-        token: str | None = None,
         dtype: str = "auto",
         tensor_parallel_size: int = 1,
         gpu_memory_utilization: float = 0.95,
         max_model_len: int = 2048,
         trust_remote_code: bool = False,
+        seed: int = 42,
         **kwargs,
     ):
         """Initialize the VLLM with a specific model.
@@ -61,12 +61,12 @@ def __init__(
             temperature (float, optional): Sampling temperature. Defaults to 0.1.
             top_p (float, optional): Top-p sampling parameter. Defaults to 0.9.
             model_storage_path (str, optional): Directory to store the model. Defaults to None.
-            token: (str, optional): Token for accessing the model - not used in implementation yet.
             dtype (str, optional): Data type for model weights. Defaults to "float16".
             tensor_parallel_size (int, optional): Number of GPUs for tensor parallelism. Defaults to 1.
             gpu_memory_utilization (float, optional): Fraction of GPU memory to use. Defaults to 0.95.
             max_model_len (int, optional): Maximum sequence length for the model. Defaults to 2048.
             trust_remote_code (bool, optional): Whether to trust remote code. Defaults to False.
+            seed (int, optional): Random seed for the model. Defaults to 42.
             **kwargs: Additional keyword arguments to pass to the LLM class initialization.
 
         Note:
@@ -81,7 +81,9 @@ def __init__(
         self.trust_remote_code = trust_remote_code
 
         # Configure sampling parameters
-        self.sampling_params = SamplingParams(temperature=temperature, top_p=top_p, max_tokens=max_generated_tokens)
+        self.sampling_params = SamplingParams(
+            temperature=temperature, top_p=top_p, max_tokens=max_generated_tokens, seed=seed
+        )
 
         # Initialize the vLLM engine with both explicit parameters and any additional kwargs
         llm_params = {
@@ -93,6 +95,7 @@ def __init__(
             "max_model_len": self.max_model_len,
             "download_dir": model_storage_path,
             "trust_remote_code": self.trust_remote_code,
+            "seed": seed,
             **kwargs,
         }
 
@@ -136,11 +139,6 @@ def _get_response(self, inputs: list[str]):
             for input in inputs
         ]
 
-        # Count input tokens
-        for prompt in prompts:
-            input_tokens = self.tokenizer.encode(prompt)
-            self.input_token_count += len(input_tokens)
-
         # generate responses for self.batch_size prompts at the same time
         all_responses = []
         for i in range(0, len(prompts), self.batch_size):
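The seed is now threaded into both SamplingParams and the engine constructor, which is what makes repeated generations reproducible. A standalone sketch of the same idea using vLLM directly (the model id is a placeholder; requires vllm and a GPU):

    from vllm import LLM, SamplingParams

    seed = 42
    sampling_params = SamplingParams(temperature=0.1, top_p=0.9, max_tokens=256, seed=seed)
    llm = LLM(model="some-org/some-model", seed=seed, max_model_len=2048)  # placeholder model id

    # With the seed fixed in both places, repeated calls on the same prompt
    # should produce identical samples within the same engine instance.
    out1 = llm.generate(["Write an instruction for sentiment classification."], sampling_params)
    out2 = llm.generate(["Write an instruction for sentiment classification."], sampling_params)
    assert out1[0].outputs[0].text == out2[0].outputs[0].text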

promptolution/optimizers/evoprompt_ga.py

Lines changed: 1 addition & 0 deletions
@@ -81,6 +81,7 @@ def optimize(self, n_steps: int) -> List[str]:
             if not continue_optimization:
                 break
 
+        self._on_train_end()
         return self.prompts
 
     def _crossover(self, prompts, scores) -> str:
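The one-line addition makes the optimizer fire its train-end hook once the loop finishes (including after an early break), so callbacks such as CSVCallback.on_train_end actually run. A simplified sketch of the resulting control flow; only _on_train_end and prompts appear in this diff, the rest is illustrative:

    def optimize(self, n_steps: int):
        for _ in range(n_steps):
            continue_optimization = self._do_step()  # hypothetical per-step helper
            if not continue_optimization:
                break

        self._on_train_end()   # notify callbacks even if the loop stopped early
        return self.prompts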

promptolution/predictors/classificator.py

Lines changed: 6 additions & 6 deletions
@@ -75,12 +75,12 @@ class MarkerBasedClassificator(BasePredictor):
         BasePredictor: The base class for predictors in the promptolution library.
     """
 
-    def __init__(self, llm, classes, marker="<final_answer>", *args, **kwargs):
+    def __init__(self, llm, classes=None, marker="<final_answer>", *args, **kwargs):
         """Initialize the Classificator.
 
         Args:
             llm: The language model to use for predictions.
-            classes (List[str]): The list of valid class labels.
+            classes (List[str]): The list of valid class labels. If None, does not force any class.
             marker (str): The marker to use for extracting the class label.
             *args, **kwargs: Additional arguments for the BasePredictor.
         """
@@ -101,11 +101,11 @@ def _extract_preds(self, preds: List[str], shape: Tuple[int, int]) -> np.ndarray
         """
         response = []
         for pred in preds:
-            predicted_class = pred.split(self.marker)[-1].strip()
-            if predicted_class not in self.classes:
-                predicted_class = self.classes[0]
+            pred = pred.split(self.marker)[-1].strip()
+            if self.classes is not None and pred not in self.classes:
+                pred = self.classes[0]
 
-            response.append(predicted_class)
+            response.append(pred)
 
         response = np.array(response).reshape(*shape)
         return response
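After this change the extractor tolerates classes=None: the text after the last marker is kept as-is, and snapping to a known label (with a fallback to the first class) only happens when a class list is provided. A minimal sketch of just that extraction step, outside the predictor class (the real _extract_preds additionally reshapes the result):

    import numpy as np

    def extract_preds(preds, classes=None, marker="<final_answer>"):
        """Keep text after the marker; optionally force it into the class list."""
        response = []
        for pred in preds:
            pred = pred.split(marker)[-1].strip()
            if classes is not None and pred not in classes:
                pred = classes[0]  # unknown label falls back to the first class
            response.append(pred)
        return np.array(response)

    extract_preds(["... <final_answer> positive"], classes=["negative", "positive"])  # ["positive"]
    extract_preds(["... <final_answer> anything goes"])                               # ["anything goes"]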

promptolution/utils/prompt_creation.py

Lines changed: 14 additions & 6 deletions
@@ -42,6 +42,7 @@ def create_prompts_from_samples(
     n_samples: int = 3,
     task_description: str = None,
     n_prompts: int = 1,
+    get_uniform_labels: bool = False,
 ) -> List[str]:
     """Generate a set of prompts from dataset examples sampled from a given task.
 
@@ -59,13 +60,23 @@ def create_prompts_from_samples(
         n_samples (int): The number of samples to use for generating prompts.
         task_description (str): The description of the task to include in the prompt.
         n_prompts (int): The number of prompts to generate.
+        get_uniform_labels (bool): If True, samples are selected such that all classes are represented.
 
     Returns:
         List[str]: A list of generated prompts.
     """
+    if meta_prompt is None and task_description is None:
+        meta_prompt_template = PROMPT_CREATION_TEMPLATE
+    elif meta_prompt is None and task_description is not None:
+        meta_prompt_template = PROMPT_CREATION_TEMPLATE_TD.replace("<task_desc>", task_description)
+    elif meta_prompt is not None and task_description is None:
+        meta_prompt_template = meta_prompt
+    elif meta_prompt is not None and task_description is not None:
+        meta_prompt_template = meta_prompt.replace("<task_desc>", task_description)
+
     meta_prompts = []
     for _ in range(n_prompts):
-        if isinstance(task, ClassificationTask):
+        if isinstance(task, ClassificationTask) and get_uniform_labels:
             # if classification task sample such that all classes are represented
             unique_labels, counts = np.unique(task.ys, return_counts=True)
             proportions = counts / len(task.ys)
@@ -87,13 +98,10 @@ def create_prompts_from_samples(
         xs = task.xs[indices].tolist()
         ys = task.ys[indices].tolist()
 
-        if meta_prompt is None:
-            meta_prompt = PROMPT_CREATION_TEMPLATE
-        if task_description is None:
-            meta_prompt = PROMPT_CREATION_TEMPLATE_TD.replace("<task_desc>", task_description)
         examples = "\n\n".join([f"Input: {x}\nOutput: {y}" for x, y in zip(xs, ys)])
-        meta_prompt = meta_prompt.replace("<input_output_pairs>", examples)
+        meta_prompt = meta_prompt_template.replace("<input_output_pairs>", examples)
         meta_prompts.append(meta_prompt)
+
     prompts = llm.get_response(meta_prompts)
     prompts = [prompt.split("</prompt>")[0].split("<prompt>")[-1].strip() for prompt in prompts]
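The template choice that previously lived inside the sampling loop (and selected the task-description template precisely when task_description was None) is now resolved once before the loop. A compact sketch of the four-way selection with placeholder templates:

    PROMPT_CREATION_TEMPLATE = "Write a prompt for these examples:\n\n<input_output_pairs>"    # placeholder
    PROMPT_CREATION_TEMPLATE_TD = "Task: <task_desc>\n\nExamples:\n\n<input_output_pairs>"     # placeholder

    def resolve_template(meta_prompt=None, task_description=None):
        if meta_prompt is None and task_description is None:
            return PROMPT_CREATION_TEMPLATE
        if meta_prompt is None:
            return PROMPT_CREATION_TEMPLATE_TD.replace("<task_desc>", task_description)
        if task_description is None:
            return meta_prompt
        return meta_prompt.replace("<task_desc>", task_description)

    # inside the per-prompt loop only the examples are filled in:
    # meta_prompt = resolve_template(...).replace("<input_output_pairs>", examples)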

scripts/optimizer_test_run.py

Lines changed: 6 additions & 1 deletion
@@ -16,6 +16,7 @@
 parser.add_argument("--optimizer", default="evopromptde")
 parser.add_argument("--n-steps", type=int, default=10)
 parser.add_argument("--token", default=None)
+parser.add_argument("--seed", type=int, default=187)
 args = parser.parse_args()
 
 config = Config(
@@ -29,8 +30,12 @@
     evaluation_llm=args.model,
     api_token=args.token,
     model_storage_path=args.model_storage_path,
+    random_seed=args.seed,
 )
 
-prompts = run_optimization(config, callbacks=[LoggerCallback(logger), CSVCallback(f"results/{args.model}/")])
+if args.token is None:
+    prompts = run_optimization(config, callbacks=[LoggerCallback(logger), CSVCallback(f"results/seedingtest/{args.model}/")])
+else:
+    prompts = run_optimization(config, callbacks=[LoggerCallback(logger), CSVCallback(f"results/seedingtest/{args.model}/")], use_token=True)
 
 logger.info(f"Optimized prompts: {prompts}")
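The test script now seeds the run (default 187) via Config.random_seed, writes results under results/seedingtest/<model>/, and only passes use_token=True when a token is supplied. A local, deterministic invocation using just the arguments visible in this hunk would look like: python scripts/optimizer_test_run.py --n-steps 10 --seed 187 (any model or storage-path arguments defined elsewhere in the script are omitted here).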
