Skip to content

Commit

Permalink
feat(retrievers)!: Move all retriever config options to RetrieverOpti…
Browse files Browse the repository at this point in the history
  • Loading branch information
davidmigloz authored and KennethKnudsen97 committed Apr 22, 2024
1 parent f5184fc commit 208f476
Show file tree
Hide file tree
Showing 9 changed files with 76 additions and 25 deletions.
4 changes: 3 additions & 1 deletion docs/expression_language/cookbook/retrieval.md
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,9 @@ In this example, we will add a memory to the chain and return the source documen

```dart
final retriever = vectorStore.asRetriever(
searchType: const VectorStoreSimilaritySearch(k: 1),
defaultOptions: const VectorStoreRetrieverOptions(
searchType: VectorStoreSimilaritySearch(k: 1),
),
);
final model = ChatOpenAI(apiKey: openaiApiKey);
final stringOutputParser = const StringOutputParser();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,9 @@ Future<void> _conversationalRetrievalChainMemoryAndDocs() async {
final vectorStore = Chroma(embeddings: embeddings);

final retriever = vectorStore.asRetriever(
searchType: const VectorStoreSimilaritySearch(k: 1),
defaultOptions: const VectorStoreRetrieverOptions(
searchType: VectorStoreSimilaritySearch(k: 1),
),
);
final model = ChatOpenAI(apiKey: openaiApiKey);
const stringOutputParser = StringOutputParser();
Expand Down
6 changes: 3 additions & 3 deletions packages/langchain/lib/src/chains/retrieval_qa.dart
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ class RetrievalQAChain extends BaseChain {
});

/// Retriever to use.
final BaseRetriever retriever;
final Retriever retriever;

/// Chain to use to combine the documents.
final BaseCombineDocumentsChain combineDocumentsChain;
Expand Down Expand Up @@ -111,7 +111,7 @@ class RetrievalQAChain extends BaseChain {
String get chainType => 'retrieval_qa';

/// Creates a [RetrievalQAChain] from a [BaseLanguageModel] and a
/// [BaseRetriever].
/// [Retriever].
///
/// By default, it uses a prompt template optimized for question answering
/// that includes the retrieved documents and the question.
Expand All @@ -134,7 +134,7 @@ class RetrievalQAChain extends BaseChain {
/// [prompt]. Use 'context' and 'question' as the variable names.
factory RetrievalQAChain.fromLlm({
required final BaseLanguageModel llm,
required final BaseRetriever retriever,
required final Retriever retriever,
final PromptTemplate? prompt,
}) {
return RetrievalQAChain(
Expand Down
18 changes: 12 additions & 6 deletions packages/langchain/lib/src/documents/retrievers/base.dart
Original file line number Diff line number Diff line change
@@ -1,27 +1,33 @@
import '../../core/core.dart';
import '../models/models.dart';
import 'models/models.dart';

/// {@template base_retriever}
/// Base Index class. All indexes should extend this class.
/// {@endtemplate}
abstract class BaseRetriever
extends Runnable<String, BaseLangChainOptions, List<Document>> {
abstract class Retriever<Options extends RetrieverOptions>
extends Runnable<String, Options, List<Document>> {
/// {@macro base_retriever}
const BaseRetriever();
const Retriever();

/// Get the most relevant documents for a given query.
///
/// - [input] - The query to search for.
/// - [options] - Retrieval options.
@override
Future<List<Document>> invoke(
final String input, {
final BaseLangChainOptions? options,
final Options? options,
}) {
return getRelevantDocuments(input);
return getRelevantDocuments(input, options: options);
}

/// Get the most relevant documents for a given query.
///
/// - [query] - The query to search for.
Future<List<Document>> getRelevantDocuments(final String query);
/// - [options] - Retrieval options.
Future<List<Document>> getRelevantDocuments(
final String query, {
final Options? options,
});
}
8 changes: 6 additions & 2 deletions packages/langchain/lib/src/documents/retrievers/fake.dart
Original file line number Diff line number Diff line change
@@ -1,19 +1,23 @@
import '../models/models.dart';
import 'base.dart';
import 'models/models.dart';

/// {@template fake_retriever}
/// A retriever that returns a fixed list of documents.
/// This class is meant for testing purposes only.
/// {@endtemplate}
class FakeRetriever extends BaseRetriever {
class FakeRetriever extends Retriever<RetrieverOptions> {
/// {@macro fake_retriever}
const FakeRetriever(this.docs);

/// The documents to return.
final List<Document> docs;

@override
Future<List<Document>> getRelevantDocuments(final String query) {
Future<List<Document>> getRelevantDocuments(
final String query, {
final RetrieverOptions? options,
}) {
return Future.value(docs);
}
}
29 changes: 29 additions & 0 deletions packages/langchain/lib/src/documents/retrievers/models/models.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import 'package:meta/meta.dart';

import '../../../core/base.dart';
import '../../vector_stores/models/models.dart';
import '../retrievers.dart';

/// {@template retriever_options}
/// Base class for [Retriever] options.
/// {@endtemplate}
@immutable
class RetrieverOptions extends BaseLangChainOptions {
/// {@macro retriever_options}
const RetrieverOptions();
}

/// {@template vector_store_retriever_options}
/// Options for [VectorStoreRetriever].
/// {@endtemplate}
class VectorStoreRetrieverOptions extends RetrieverOptions {
/// {@macro vector_store_retriever_options}
const VectorStoreRetrieverOptions({
this.searchType = const VectorStoreSimilaritySearch(),
});

/// The type of search to perform, either:
/// - [VectorStoreSearchType.similarity] (default)
/// - [VectorStoreSearchType.mmr]
final VectorStoreSearchType searchType;
}
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
export 'base.dart';
export 'fake.dart';
export 'models/models.dart';
export 'vector_store.dart';
20 changes: 14 additions & 6 deletions packages/langchain/lib/src/documents/retrievers/vector_store.dart
Original file line number Diff line number Diff line change
@@ -1,25 +1,33 @@
import '../models/models.dart';
import '../vector_stores/vector_stores.dart';
import 'base.dart';
import 'models/models.dart';

/// {@template vector_store_retriever}
/// A retriever that uses a vector store to retrieve documents.
/// {@endtemplate}
class VectorStoreRetriever<V extends VectorStore> extends BaseRetriever {
class VectorStoreRetriever<V extends VectorStore>
extends Retriever<VectorStoreRetrieverOptions> {
/// {@macro vector_store_retriever}
const VectorStoreRetriever({
required this.vectorStore,
this.searchType = const VectorStoreSimilaritySearch(),
this.defaultOptions = const VectorStoreRetrieverOptions(),
});

/// The vector store to retrieve documents from.
final V vectorStore;

/// The type of search to perform.
final VectorStoreSearchType searchType;
/// Default options for this retriever.
final VectorStoreRetrieverOptions defaultOptions;

@override
Future<List<Document>> getRelevantDocuments(final String query) {
return vectorStore.search(query: query, searchType: searchType);
Future<List<Document>> getRelevantDocuments(
final String query, {
final VectorStoreRetrieverOptions? options,
}) {
return vectorStore.search(
query: query,
searchType: options?.searchType ?? defaultOptions.searchType,
);
}
}
11 changes: 5 additions & 6 deletions packages/langchain/lib/src/documents/vector_stores/base.dart
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// ignore_for_file: avoid_unused_constructor_parameters
import '../embeddings/base.dart';
import '../models/models.dart';
import '../retrievers/models/models.dart';
import '../retrievers/vector_store.dart';
import 'models/models.dart';

Expand Down Expand Up @@ -170,16 +171,14 @@ abstract class VectorStore {

/// Returns a [VectorStoreRetriever] that uses this vector store.
///
/// - [searchType] is the type of search to perform, either
/// [VectorStoreSearchType.similarity] (default) or
/// [VectorStoreSearchType.mmr].
/// - [defaultOptions] are the default options for the retriever.
VectorStoreRetriever asRetriever({
final VectorStoreSearchType searchType =
const VectorStoreSimilaritySearch(),
final VectorStoreRetrieverOptions defaultOptions =
const VectorStoreRetrieverOptions(),
}) {
return VectorStoreRetriever(
vectorStore: this,
searchType: searchType,
defaultOptions: defaultOptions,
);
}
}

0 comments on commit 208f476

Please sign in to comment.