Support a :max_tokens param in OpenAIValidator.validate_max_tokens #388

Merged · 2 commits · Nov 12, 2023
8 changes: 4 additions & 4 deletions lib/langchain/llm/openai.rb
@@ -69,7 +69,7 @@ def complete(prompt:, **params)
return legacy_complete(prompt, parameters) if is_legacy_model?(parameters[:model])

parameters[:messages] = compose_chat_messages(prompt: prompt)
parameters[:max_tokens] = validate_max_tokens(parameters[:messages], parameters[:model])
parameters[:max_tokens] = validate_max_tokens(parameters[:messages], parameters[:model], parameters[:max_tokens])
Collaborator (andreibondarev):

max_tokens = [max_tokens, options[:max_tokens]].min

Do you think this ^^ logic should actually live inside of this method and not inside the base_validator or the openai_validator?

For example we could have something like this:

parameters[:max_tokens] = [
  parameters[:max_tokens],
  validate_max_tokens(parameters[:messages], parameters[:model])
].min

I guess what's missing with this approach is raising an error?

bricolage (Contributor, Author), Nov 12, 2023:

@andreibondarev oh yeah, that's leaky; it goes by other key names in other LLMs. My vote is to let the validator subclasses handle it (e.g. [options[:max_tokens], super(...)].min). If we did it directly in the LLM classes, every method that uses it would repeat the same logic; for example, OpenAI#chat and #complete would both have similar code. This way the other LLM classes can each pass through their own param names correctly. What do you think?
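
A minimal sketch of the subclass-override approach described above, assuming the module and class names implied by the file paths in this PR (the committed version shown further down uses reject(&:nil?); compact is equivalent):

module Langchain
  module Utils
    module TokenLength
      # Sketch only: BaseValidator is assumed to define
      # validate_max_tokens!(content, model_name, options) and to return the
      # number of tokens left for the completion.
      class OpenAIValidator < BaseValidator
        def self.validate_max_tokens!(content, model_name, options = {})
          computed = super(content, model_name, options)
          # Clamp a caller-supplied :max_tokens against the computed ceiling.
          [options[:max_tokens], computed].compact.min
        end
      end
    end
  end
end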

bricolage (Contributor, Author):


@andreibondarev also what error condition did you have in mind? e.g. if :max_tokens is passed in as a non-number value? Or possibly as a negative value?

bricolage (Contributor, Author):


@andreibondarev check out the latest commit. I moved it into the OpenAIValidator subclass.

Collaborator (andreibondarev):


> @andreibondarev also what error condition did you have in mind? e.g. if :max_tokens is passed in as a non-number value? Or possibly as a negative value?

This is the error that I was referring to: https://github.com/andreibondarev/langchainrb/pull/388/files#diff-79c9e673e647cc360e11c2946b3a4def6a5dfe52480860dc9ccd2101906902a9R30
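
For context, the raise being referred to lives in base_validator.rb (further down in this diff): when the prompt alone already exceeds the model's context window, the computed max_tokens goes negative and limit_exceeded_exception is raised, regardless of any :max_tokens the caller supplied. A hypothetical illustration (the exact exception class is not shown in this diff, so it is assumed here):

# Prompt far larger than gpt-4's 8,192-token context window: the base validator
# computes a negative max_tokens and raises before the user-supplied value matters.
Langchain::Utils::TokenLength::OpenAIValidator.validate_max_tokens!(
  "lorem ipsum " * 10_000,
  "gpt-4",
  max_tokens: 256
)
# => raises the validator's token-limit-exceeded error (see limit_exceeded_exception)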


response = with_api_error_handling do
client.chat(parameters: parameters)
@@ -131,7 +131,7 @@ def chat(prompt: "", messages: [], context: "", examples: [], **options, &block)
if functions
parameters[:functions] = functions
else
parameters[:max_tokens] = validate_max_tokens(parameters[:messages], parameters[:model])
parameters[:max_tokens] = validate_max_tokens(parameters[:messages], parameters[:model], parameters[:max_tokens])
end

response = with_api_error_handling { client.chat(parameters: parameters) }
@@ -230,8 +230,8 @@ def with_api_error_handling
response
end

def validate_max_tokens(messages, model)
LENGTH_VALIDATOR.validate_max_tokens!(messages, model)
def validate_max_tokens(messages, model, max_tokens = nil)
LENGTH_VALIDATOR.validate_max_tokens!(messages, model, max_tokens: max_tokens)
end

def extract_response(response)
9 changes: 5 additions & 4 deletions lib/langchain/utils/token_length/base_validator.rb
@@ -20,16 +20,17 @@ def self.validate_max_tokens!(content, model_name, options = {})
end

leftover_tokens = token_limit(model_name) - text_token_length
# Some models have a separate token limit for completion (e.g. GPT-4 Turbo)

# Some models have a separate token limit for completions (e.g. GPT-4 Turbo)
# We want the lower of the two limits
leftover_tokens = [leftover_tokens, completion_token_limit(model_name)].min
max_tokens = [leftover_tokens, completion_token_limit(model_name)].min

# Raise an error even if whole prompt is equal to the model's token limit (leftover_tokens == 0)
if leftover_tokens < 0
if max_tokens < 0
raise limit_exceeded_exception(token_limit(model_name), text_token_length)
end

leftover_tokens
max_tokens
end

def self.limit_exceeded_exception(limit, length)
6 changes: 6 additions & 0 deletions lib/langchain/utils/token_length/openai_validator.rb
@@ -67,6 +67,12 @@ def self.token_limit(model_name)
def self.completion_token_limit(model_name)
COMPLETION_TOKEN_LIMITS[model_name] || token_limit(model_name)
end

# If :max_tokens is passed in, take the lower of it and the calculated max_tokens
def self.validate_max_tokens!(content, model_name, options = {})
max_tokens = super(content, model_name, options)
[options[:max_tokens], max_tokens].reject(&:nil?).min
end
end
end
end
22 changes: 22 additions & 0 deletions spec/langchain/utils/token_length/openai_validator_spec.rb
@@ -77,6 +77,28 @@
expect(subject).to eq(0)
end
end

context "when :max_tokens is passed in" do
context "when :max_tokens is lower than the leftover tokens" do
subject { described_class.validate_max_tokens!(content, model, max_tokens: 10) }
let(:content) { "lorem ipsum" * 100 }
let(:model) { "gpt-4" }

it "returns the correct max_tokens" do
expect(subject).to eq(10)
end
end

context "when :max_tokens is greater than the leftover tokens" do
subject { described_class.validate_max_tokens!(content, model, max_tokens: 8000) }
let(:content) { "lorem ipsum" * 100 }
let(:model) { "gpt-4" }

it "returns the correct max_tokens" do
expect(subject).to eq(7892)
end
end
end
end

context "with array argument" do