From 827a0764b73816b92e4cf97206d39cd5d50864ee Mon Sep 17 00:00:00 2001 From: Juraj Carnogursky Date: Sun, 3 Sep 2023 22:22:00 +0200 Subject: [PATCH] Fix random test fails due to segfault by chromadb & Float32Array not shared in jest context isolation The way chromadb imports @xenova/transformers package in file chromadb/src/embeddings/TransformersEmbeddingFunction.ts:33 makes it result in random segment fault errors terminating the tests prematurely. This fix contains a code that bypasses chromadb package and directly uses the @xenova/transformers package Due to how jest isolates the context of each running test (https://github.com/xenova/transformers.js/issues/57, https://github.com/kayahr/jest-environment-node-single-context, https://github.com/jestjs/jest/issues/2549) - it makes it impossible for onnxruntime-node package to validate the array passed as an input to it is actually an `instanceof Float32Array` type. The `instanceof` results in false because the globals are different between context. This commit shares the Float32Array global between each context. --- .vscode/launch.json | 57 +++++++++++++++++++ langchain/jest.config.cjs | 2 +- langchain/jest.env.cjs | 11 ++++ langchain/src/embeddings/hf_transformers.ts | 25 ++++---- .../tests/hf_transformers.int.test.ts | 2 +- 5 files changed, 85 insertions(+), 12 deletions(-) create mode 100644 .vscode/launch.json create mode 100644 langchain/jest.env.cjs diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 000000000000..857eabc62808 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,57 @@ +{ + "version": "0.2.0", + "configurations": [ + { + "type": "node", + "request": "launch", + "name": "Jest single run all tests", + "program": "${workspaceRoot}/node_modules/jest/bin/jest.js", + "env": { "CI": "true" }, + "args": [ + "-c", + "./langchain/jest.config.cjs", + "--verbose", + "-i", + "--no-cache" + ], + "console": "integratedTerminal", + "internalConsoleOptions": "neverOpen" + }, + { + "type": "node", + "request": "launch", + "name": "Jest watch all tests", + "program": "${workspaceRoot}/node_modules/jest/bin/jest.js", + "args": [ + "-c", + "./langchain/jest.config.cjs", + "--verbose", + "-i", + "--no-cache", + "--watchAll" + ], + "console": "integratedTerminal", + "internalConsoleOptions": "neverOpen" + }, + { + "type": "node", + "request": "launch", + "name": "Jest watch current file", + "program": "${workspaceFolder}/node_modules/jest/bin/jest", + "env": { + "NODE_OPTIONS": "--experimental-vm-modules" + }, + "args": [ + "${fileBasename}", + "-c", + "./langchain/jest.config.cjs", + "--verbose", + "-i", + "--no-cache", + "--watchAll" + ], + "console": "integratedTerminal", + "internalConsoleOptions": "neverOpen" + } + ] +} \ No newline at end of file diff --git a/langchain/jest.config.cjs b/langchain/jest.config.cjs index cf07fac61e60..26ffcdb2cdd1 100644 --- a/langchain/jest.config.cjs +++ b/langchain/jest.config.cjs @@ -1,7 +1,7 @@ /** @type {import('ts-jest').JestConfigWithTsJest} */ module.exports = { preset: "ts-jest/presets/default-esm", - testEnvironment: "node", + testEnvironment: "./jest.env.cjs", modulePathIgnorePatterns: ["dist/", "docs/"], moduleNameMapper: { "^(\\.{1,2}/.*)\\.js$": "$1", diff --git a/langchain/jest.env.cjs b/langchain/jest.env.cjs new file mode 100644 index 000000000000..f4024e373d1c --- /dev/null +++ b/langchain/jest.env.cjs @@ -0,0 +1,11 @@ +const { TestEnvironment } = require("jest-environment-node"); + +class AdjustedTestEnvironmentToSupportFloat32Array extends TestEnvironment { + constructor(config, context) { + // Make `instanceof Float32Array return true in tests + super(config, context); + this.global.Float32Array = Float32Array; + } +} + +module.exports = AdjustedTestEnvironmentToSupportFloat32Array; diff --git a/langchain/src/embeddings/hf_transformers.ts b/langchain/src/embeddings/hf_transformers.ts index ab98eb89d1c6..e07989ea7fbb 100644 --- a/langchain/src/embeddings/hf_transformers.ts +++ b/langchain/src/embeddings/hf_transformers.ts @@ -1,4 +1,4 @@ -import { TransformersEmbeddingFunction } from "chromadb"; +import { Pipeline, pipeline } from "@xenova/transformers"; import { chunkArray } from "../util/chunk.js"; import { Embeddings, EmbeddingsParams } from "./base.js"; @@ -30,23 +30,18 @@ export class HuggingFaceTransformersEmbeddings { modelName = "Xenova/all-MiniLM-L6-v2"; - embedder = new TransformersEmbeddingFunction({ - model: this.modelName, - }); - batchSize = 512; stripNewLines = true; timeout?: number; + private pipelinePromise: Promise; + constructor(fields?: Partial) { super(fields ?? {}); this.modelName = fields?.modelName ?? this.modelName; - this.embedder = new TransformersEmbeddingFunction({ - model: this.modelName, - }); this.stripNewLines = fields?.stripNewLines ?? this.stripNewLines; this.timeout = fields?.timeout; } @@ -61,7 +56,7 @@ export class HuggingFaceTransformersEmbeddings for (let i = 0; i < subPrompts.length; i += 1) { const input = subPrompts[i]; - const data = await this.embedder.generate(input); + const data = await this.runEmbedding(input); for (let j = 0; j < input.length; j += 1) { embeddings.push(data[j]); @@ -72,9 +67,19 @@ export class HuggingFaceTransformersEmbeddings } async embedQuery(text: string): Promise { - const data = await this.embedder.generate([ + const data = await this.runEmbedding([ this.stripNewLines ? text.replace(/\n/g, " ") : text, ]); return data[0]; } + + private async runEmbedding(texts: string[]) { + const pipe = await (this.pipelinePromise ??= pipeline( + "feature-extraction", + this.modelName + )); + + const output = await pipe(texts, { pooling: "mean", normalize: true }); + return output.tolist(); + } } diff --git a/langchain/src/embeddings/tests/hf_transformers.int.test.ts b/langchain/src/embeddings/tests/hf_transformers.int.test.ts index 254affcd8232..0a15a8e1d130 100644 --- a/langchain/src/embeddings/tests/hf_transformers.int.test.ts +++ b/langchain/src/embeddings/tests/hf_transformers.int.test.ts @@ -16,7 +16,7 @@ test("HuggingFaceTransformersEmbeddings", async () => { ]; const queryEmbedding = await embeddings.embedQuery(documents[0]); - expect(queryEmbedding).toHaveLength(768); + expect(queryEmbedding).toHaveLength(384); expect(typeof queryEmbedding[0]).toBe("number"); const store = new MemoryVectorStore(embeddings);