refactor: Remove tiktoken in favour of countTokens API on VertexAI (d…
davidmigloz authored and KennethKnudsen97 committed Apr 22, 2024
1 parent 93f6a22 commit 56e7c7a
Showing 5 changed files with 66 additions and 46 deletions.
50 changes: 39 additions & 11 deletions packages/langchain_google/lib/src/chat_models/vertex_ai.dart
@@ -1,6 +1,5 @@
 import 'package:http/http.dart' as http;
 import 'package:langchain/langchain.dart';
-import 'package:langchain_tiktoken/langchain_tiktoken.dart';
 import 'package:uuid/uuid.dart';
 import 'package:vertex_ai/vertex_ai.dart';
 
@@ -190,20 +189,49 @@ class ChatVertexAI extends BaseChatModel<ChatVertexAIOptions> {
     return result.toChatResult(id, model);
   }
 
-  /// Tokenizes the given prompt using tiktoken.
-  ///
-  /// Currently Google does not provide a tokenizer for Vertex AI models.
-  /// So we use tiktoken and cl100k_base encoding to get an approximation
-  /// for counting tokens. Mind that the actual tokens will be totally
-  /// different from the ones used by the Vertex AI model.
-  ///
-  /// - [promptValue] The prompt to tokenize.
   @override
   Future<List<int>> tokenize(
     final PromptValue promptValue, {
     final ChatVertexAIOptions? options,
   }) async {
-    final encoding = getEncoding('cl100k_base');
-    return encoding.encode(promptValue.toString());
+    throw UnsupportedError(
+      'ChatVertexAI does not support tokenize, only countTokens',
+    );
   }
+
+  @override
+  Future<int> countTokens(
+    final PromptValue promptValue, {
+    final ChatVertexAIOptions? options,
+  }) async {
+    final messages = promptValue.toChatMessages();
+    String? context;
+    final vertexMessages = <VertexAITextChatModelMessage>[];
+    for (final message in messages) {
+      if (message is SystemChatMessage) {
+        context = message.content;
+        continue;
+      } else {
+        vertexMessages.add(message.toVertexAIChatMessage());
+      }
+    }
+    final examples = (options?.examples ?? defaultOptions.examples)
+        ?.map((final e) => e.toVertexAIChatExample())
+        .toList(growable: false);
+    final model =
+        options?.model ?? defaultOptions.model ?? throwNullModelError();
+
+    final res = await client.chat.countTokens(
+      context: context,
+      examples: examples,
+      messages: vertexMessages,
+      publisher: options?.publisher ??
+          ArgumentError.checkNotNull(
+            defaultOptions.publisher,
+            'VertexAIOptions.publisher',
+          ),
+      model: model,
+    );
+    return res.totalTokens;
+  }
 }
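
For reference, a minimal sketch of how token counting looks after this change. The auth setup (clientViaApplicationDefaultCredentials, the cloud-platform scope) and the project ID are illustrative assumptions, not part of this commit; countTokens now delegates to the server-side Vertex AI countTokens endpoint, while tokenize throws:

// Sketch: counting tokens via the Vertex AI API instead of tiktoken.
import 'package:googleapis_auth/auth_io.dart';
import 'package:langchain/langchain.dart';
import 'package:langchain_google/langchain_google.dart';

Future<void> main() async {
  // Assumption: application default credentials are configured locally.
  final authClient = await clientViaApplicationDefaultCredentials(
    scopes: ['https://www.googleapis.com/auth/cloud-platform'],
  );
  final chat = ChatVertexAI(
    httpClient: authClient,
    project: 'my-gcp-project', // hypothetical project ID
  );

  // countTokens now calls the Vertex AI countTokens endpoint, so the
  // count matches the model's real tokenizer.
  final numTokens = await chat.countTokens(
    PromptValue.string('Hello, how are you?'),
  );
  print(numTokens);

  // tokenize is no longer supported after this commit.
  try {
    await chat.tokenize(PromptValue.string('Hello'));
  } on UnsupportedError catch (e) {
    print(e.message);
  }
}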
37 changes: 26 additions & 11 deletions packages/langchain_google/lib/src/llms/vertex_ai.dart
@@ -1,6 +1,5 @@
 import 'package:http/http.dart' as http;
 import 'package:langchain/langchain.dart';
-import 'package:langchain_tiktoken/langchain_tiktoken.dart';
 import 'package:vertex_ai/vertex_ai.dart';
 
 import 'models/mappers.dart';
@@ -80,6 +79,10 @@ import 'models/models.dart';
 /// - `text-bison-32k`
 ///   * Max input and output tokens combined: 32k
 ///   * Training data: Up to Aug 2023
+/// - `text-unicorn`
+///   * Max input token: 8192
+///   * Max output tokens: 1024
+///   * Training data: Up to Feb 2023
 ///
 /// The previous list of models may not be exhaustive or up-to-date. Check out
 /// the [Vertex AI documentation](https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models)
@@ -170,20 +173,32 @@ class VertexAI extends BaseLLM<VertexAIOptions> {
     return result.toLLMResult(model);
   }
 
-  /// Tokenizes the given prompt using tiktoken.
-  ///
-  /// Currently Google does not provide a tokenizer for Vertex AI models.
-  /// So we use tiktoken and cl100k_base encoding to get an approximation
-  /// for counting tokens. Mind that the actual tokens will be totally
-  /// different from the ones used by the Vertex AI model.
-  ///
-  /// - [promptValue] The prompt to tokenize.
   @override
   Future<List<int>> tokenize(
     final PromptValue promptValue, {
     final VertexAIOptions? options,
   }) async {
-    final encoding = getEncoding('cl100k_base');
-    return encoding.encode(promptValue.toString());
+    throw UnsupportedError(
+      'VertexAI does not support tokenize, only countTokens',
+    );
   }
+
+  @override
+  Future<int> countTokens(
+    final PromptValue promptValue, {
+    final VertexAIOptions? options,
+  }) async {
+    final model =
+        options?.model ?? defaultOptions.model ?? throwNullModelError();
+    final res = await client.text.countTokens(
+      prompt: promptValue.toString(),
+      publisher: options?.publisher ??
+          ArgumentError.checkNotNull(
+            defaultOptions.publisher,
+            'VertexAIOptions.publisher',
+          ),
+      model: model,
+    );
+    return res.totalTokens;
+  }
 }
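
The text-model counterpart is analogous. A short sketch, reusing the assumed authClient from the sketch above; the only difference is that VertexAI sends the plain prompt string to client.text.countTokens rather than mapped chat messages:

// Sketch (same assumed auth setup as above).
final llm = VertexAI(
  httpClient: authClient,
  project: 'my-gcp-project', // hypothetical project ID
);
// Counted server-side against the text model's own tokenizer.
final numTokens = await llm.countTokens(
  PromptValue.string('Hello, how are you?'),
);
print(numTokens);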
1 change: 0 additions & 1 deletion packages/langchain_google/pubspec.yaml
@@ -24,7 +24,6 @@ dependencies:
   googleapis_auth: ^1.4.1
   http: ^1.1.0
   langchain: ^0.3.2
-  langchain_tiktoken: ^1.0.1
   meta: ^1.9.1
   uuid: ^4.0.0
   vertex_ai: ^0.0.8
13 changes: 1 addition & 12 deletions packages/langchain_google/test/chat_models/vertex_ai_test.dart
@@ -181,17 +181,6 @@ void main() async {
       expect(res2.generations.length, 5);
     });
 
-    test('Test tokenize', () async {
-      final chat = ChatVertexAI(
-        httpClient: authHttpClient,
-        project: Platform.environment['VERTEX_AI_PROJECT_ID']!,
-      );
-      const text = 'Hello, how are you?';
-
-      final tokens = await chat.tokenize(PromptValue.string(text));
-      expect(tokens, [9906, 11, 1268, 527, 499, 30]);
-    });
-
     test('Test countTokens string', () async {
       final chat = ChatVertexAI(
         httpClient: authHttpClient,
@@ -226,7 +215,7 @@
       ];
 
       final numTokens = await chat.countTokens(PromptValue.chat(messages));
-      expect(numTokens, 41);
+      expect(numTokens, 37);
     });
   });
 }
11 changes: 0 additions & 11 deletions packages/langchain_google/test/llms/vertex_ai_test.dart
@@ -146,17 +146,6 @@ Future<void> main() async {
   expect(res2.generations.length, 5);
   });
 
-  test('Test tokenize', () async {
-    final llm = VertexAI(
-      httpClient: authHttpClient,
-      project: Platform.environment['VERTEX_AI_PROJECT_ID']!,
-    );
-    const text = 'Hello, how are you?';
-
-    final tokens = await llm.tokenize(PromptValue.string(text));
-    expect(tokens, [9906, 11, 1268, 527, 499, 30]);
-  });
-
   test('Test countTokens', () async {
     final llm = VertexAI(
       httpClient: authHttpClient,