Skip to content

Commit

Permalink
feat: Support Together AI in OpenAIEmbeddings wrapper (davidmigloz#304)
Browse files Browse the repository at this point in the history
  • Loading branch information
davidmigloz authored and KennethKnudsen97 committed Apr 22, 2024
1 parent 9c38fef commit 6d8edf2
Show file tree
Hide file tree
Showing 8 changed files with 152 additions and 8 deletions.
1 change: 1 addition & 0 deletions docs/_sidebar.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
- [GCP Vertex AI](/modules/retrieval/text_embedding/integrations/gcp_vertex_ai.md)
- [Ollama](/modules/retrieval/text_embedding/integrations/ollama.md)
- [Mistral AI](/modules/retrieval/text_embedding/integrations/mistralai.md)
- [Together AI](/modules/retrieval/text_embedding/integrations/together_ai.md)
- [Prem App](/modules/retrieval/text_embedding/integrations/prem.md)
- [Vector stores](/modules/retrieval/vector_stores/vector_stores.md)
- Integrations
Expand Down
25 changes: 18 additions & 7 deletions docs/modules/retrieval/text_embedding/integrations/openai.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,22 @@
# OpenAI
# OpenAIEmbeddings

Let's load the OpenAI Embedding class.
You can use the `OpenAIEmbeddings` wrapper to consume OpenAI embedding models.

```dart
final openaiApiKey = Platform.environment['OPENAI_API_KEY'];
final embeddings = OpenAIEmbeddings(apiKey: openaiApiKey);
const text = 'This is a test document.';
final res = await embeddings.embedQuery(text);
final res = await embeddings.embedDocuments([text]);
final openAiApiKey = Platform.environment['OPENAI_API_KEY'];
final embeddings = OpenAIEmbeddings(apiKey: openAiApiKey);
// Embedding a document
const doc = Document(pageContent: 'This is a test document.');
final res1 = await embeddings.embedDocuments([doc]);
print(res1);
// [[-0.003105443, 0.011136302, -0.0040295827, -0.011749065, ...]]
// Embedding a retrieval query
const text = 'This is a test query.';
final res2 = await embeddings.embedQuery(text);
print(res2);
// [-0.005047946, 0.0050882488, -0.0051957234, -0.019143905, ...]
embeddings.close();
```
30 changes: 30 additions & 0 deletions docs/modules/retrieval/text_embedding/integrations/together_ai.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Together AI Embeddings

[Together AI](https://www.together.ai/) offers several leading [embedding models](https://docs.together.ai/docs/embedding-models#embedding-models) through its OpenAI compatible API.

You can consume Together AI API using the `OpenAIEmbeddings` wrapper in the same way you would use the OpenAI API.

The only difference is that you need to change the base URL to `https://api.together.xyz/v1`:

```dart
final togetherAiApiKey = Platform.environment['TOGETHER_AI_API_KEY'];
final embeddings = OpenAIEmbeddings(
apiKey: togetherAiApiKey,
baseUrl: 'https://api.together.xyz/v1',
model: 'togethercomputer/m2-bert-80M-32k-retrieval',
);
// Embedding a document
const doc = Document(pageContent: 'This is a test document.');
final res1 = await embeddings.embedDocuments([doc]);
print(res1);
// [[-0.038838703, 0.0580902, 0.022614542, 0.0078403875, ...]]
// Embedding a retrieval query
const text = 'This is a test query.';
final res2 = await embeddings.embedQuery(text);
print(res2);
// [-0.019722218, 0.04656633, -0.0074559706, 0.005712764, ...]
embeddings.close();
```
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
// ignore_for_file: avoid_print
import 'dart:io';

import 'package:langchain/langchain.dart';
import 'package:langchain_openai/langchain_openai.dart';

void main(final List<String> arguments) async {
final openAiApiKey = Platform.environment['OPENAI_API_KEY'];
final embeddings = OpenAIEmbeddings(apiKey: openAiApiKey);

// Embedding a document
const doc = Document(pageContent: 'This is a test document.');
final res1 = await embeddings.embedDocuments([doc]);
print(res1);
// [[-0.003105443, 0.011136302, -0.0040295827, -0.011749065, ...]]

// Embedding a retrieval query
const text = 'This is a test query.';
final res2 = await embeddings.embedQuery(text);
print(res2);
// [-0.005047946, 0.0050882488, -0.0051957234, -0.019143905, ...]

embeddings.close();
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
// ignore_for_file: avoid_print
import 'dart:io';

import 'package:langchain/langchain.dart';
import 'package:langchain_openai/langchain_openai.dart';

void main(final List<String> arguments) async {
final togetherAiApiKey = Platform.environment['TOGETHER_AI_API_KEY'];
final embeddings = OpenAIEmbeddings(
apiKey: togetherAiApiKey,
baseUrl: 'https://api.together.xyz/v1',
model: 'togethercomputer/m2-bert-80M-32k-retrieval',
);

// Embedding a document
const doc = Document(pageContent: 'This is a test document.');
final res1 = await embeddings.embedDocuments([doc]);
print(res1);
// [[-0.038838703, 0.0580902, 0.022614542, 0.0078403875, ...]]

// Embedding a retrieval query
const text = 'This is a test query.';
final res2 = await embeddings.embedQuery(text);
print(res2);
// [-0.019722218, 0.04656633, -0.0074559706, 0.005712764, ...]

embeddings.close();
}
9 changes: 8 additions & 1 deletion packages/langchain_openai/lib/src/embeddings/openai.dart
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ import 'package:openai_dart/openai_dart.dart';
/// - [Embeddings guide](https://platform.openai.com/docs/guides/embeddings/limitations-risks)
/// - [Embeddings API docs](https://platform.openai.com/docs/api-reference/embeddings)
///
/// You can also use this wrapper to consume OpenAI-compatible APIs like [Together AI](https://www.together.ai).
///
/// ### Authentication
///
/// The OpenAI API uses API keys for authentication. Visit your
Expand Down Expand Up @@ -122,7 +124,7 @@ class OpenAIEmbeddings implements Embeddings {
OpenAIEmbeddings({
final String? apiKey,
final String? organization,
final String? baseUrl,
final String baseUrl = 'https://api.openai.com/v1',
final Map<String, String>? headers,
final Map<String, dynamic>? queryParams,
final http.Client? client,
Expand Down Expand Up @@ -199,4 +201,9 @@ class OpenAIEmbeddings implements Embeddings {
);
return data.data.first.embeddingVector;
}

/// Closes the client and cleans up any resources associated with it.
void close() {
_client.endSession();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
@TestOn('vm')
library; // Uses dart:io

import 'dart:io';

import 'package:langchain_openai/langchain_openai.dart';
import 'package:test/test.dart';

void main() {
group('Together AI Embeddings tests', () {
late OpenAIEmbeddings embeddings;

setUp(() async {
embeddings = OpenAIEmbeddings(
apiKey: Platform.environment['TOGETHER_AI_API_KEY'],
baseUrl: 'https://api.together.xyz/v1',
);
});

tearDown(() {
embeddings.close();
});

test('Test AI Embeddings models', () async {
final models = [
'togethercomputer/m2-bert-80M-2k-retrieval',
'togethercomputer/m2-bert-80M-8k-retrieval',
'togethercomputer/m2-bert-80M-32k-retrieval',
'WhereIsAI/UAE-Large-V1',
'BAAI/bge-large-en-v1.5',
'BAAI/bge-base-en-v1.5',
'sentence-transformers/msmarco-bert-base-dot-v5',
'bert-base-uncased',
];
for (final model in models) {
embeddings.model = model;
final res = await embeddings.embedQuery('Hello world');
expect(res.length, greaterThan(0));
await Future<void>.delayed(const Duration(seconds: 1)); // Rate limit
}
});
});
}

0 comments on commit 6d8edf2

Please sign in to comment.