Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 15 additions & 13 deletions src/paperqa/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,21 +825,23 @@ def _deprecated_field(self) -> Self:
return self

@model_validator(mode="after")
def _validate_temperature_for_o1_preview(self) -> Self:
"""Ensures temperature is 1 if the LLM is 'o1-preview' or 'o1-mini'.
def _update_temperature(self) -> Self:
"""Ensures temperature is 1 if the LLM requires it.

o1 reasoning models only support temperature = 1. See
https://platform.openai.com/docs/guides/reasoning/quickstart
o1 reasoning models only support temperature = 1.
SEE: https://platform.openai.com/docs/guides/reasoning/quickstart
"""
if self.llm.startswith("o1-") and self.temperature != 1:
warnings.warn(
"When dealing with OpenAI o1 models, the temperature must be set to 1."
f" The specified temperature {self.temperature} has been overridden"
" to 1.",
category=UserWarning,
stacklevel=2,
)
self.temperature = 1
for model_prefix in ("o1", "gpt-5"):
if self.llm.startswith(model_prefix) and self.temperature != 1:
warnings.warn(
f"When dealing with OpenAI {model_prefix} models,"
" the temperature must be set to 1,"
f" so the specified temperature {self.temperature}"
" has been overridden to 1.",
category=UserWarning,
stacklevel=2,
)
self.temperature = 1
return self

@computed_field # type: ignore[prop-decorator]
Expand Down
28 changes: 17 additions & 11 deletions tests/test_configs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import os
import pathlib
import warnings
from unittest.mock import patch

import pytest
Expand Down Expand Up @@ -102,18 +101,25 @@ def test_router_kwargs_present_in_models() -> None:
assert settings.get_summary_llm().config["router_kwargs"] is not None


def test_o1_requires_temp_equals_1() -> None:
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
s = Settings(llm="o1-thismodeldoesnotexist", temperature=0)
assert "temperature must be set to 1" in str(w[-1].message)
@pytest.mark.parametrize(
"model_name",
[
"o1",
"o1-thismodeldoesnotexist",
"gpt-5",
"gpt-5-2025-08-07",
"gpt-5-mini-2025-08-07",
],
)
def test_models_requiring_temp_1(model_name: str) -> None:
with pytest.warns(UserWarning, match="temperature") as record: # noqa: PT031
s = Settings(llm=model_name, temperature=0)
(w,) = record.list
assert "temperature must be set to 1" in str(w.message)
assert s.temperature == 1

# Test that temperature=1 produces no warning
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
_ = Settings(llm="o1-thismodeldoesnotexist", temperature=1)
assert not w
Settings(llm=model_name, temperature=1)
assert record.list == [w], "Expected no new warnings with correct temperature"


@pytest.mark.parametrize(
Expand Down