Skip to content

Commit

Permalink
Fix random test fails due to segfault by chromadb & Float32Array not …
Browse files Browse the repository at this point in the history
…shared in jest context isolation

The way chromadb imports the @xenova/transformers package in
chromadb/src/embeddings/TransformersEmbeddingFunction.ts:33 causes random segmentation faults that terminate the tests prematurely.
This fix bypasses the chromadb package and uses the @xenova/transformers package directly.

Because jest isolates the context of each running test (xenova/transformers.js#57, https://github.com/kayahr/jest-environment-node-single-context,
jestjs/jest#2549), the onnxruntime-node package cannot verify that the array passed to it as input is actually an `instanceof Float32Array`.
The `instanceof` check evaluates to false because each context has its own globals. This commit shares the Float32Array global across all contexts.
  • Loading branch information
Juraj Carnogursky committed Sep 3, 2023
1 parent d3c1247 commit 827a076
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 12 deletions.
57 changes: 57 additions & 0 deletions .vscode/launch.json
@@ -0,0 +1,57 @@
{
"version": "0.2.0",
"configurations": [
{
"type": "node",
"request": "launch",
"name": "Jest single run all tests",
"program": "${workspaceRoot}/node_modules/jest/bin/jest.js",
"env": { "CI": "true" },
"args": [
"-c",
"./langchain/jest.config.cjs",
"--verbose",
"-i",
"--no-cache"
],
"console": "integratedTerminal",
"internalConsoleOptions": "neverOpen"
},
{
"type": "node",
"request": "launch",
"name": "Jest watch all tests",
"program": "${workspaceRoot}/node_modules/jest/bin/jest.js",
"args": [
"-c",
"./langchain/jest.config.cjs",
"--verbose",
"-i",
"--no-cache",
"--watchAll"
],
"console": "integratedTerminal",
"internalConsoleOptions": "neverOpen"
},
{
"type": "node",
"request": "launch",
"name": "Jest watch current file",
"program": "${workspaceFolder}/node_modules/jest/bin/jest",
"env": {
"NODE_OPTIONS": "--experimental-vm-modules"
},
"args": [
"${fileBasename}",
"-c",
"./langchain/jest.config.cjs",
"--verbose",
"-i",
"--no-cache",
"--watchAll"
],
"console": "integratedTerminal",
"internalConsoleOptions": "neverOpen"
}
]
}
2 changes: 1 addition & 1 deletion langchain/jest.config.cjs
@@ -1,7 +1,7 @@
/** @type {import('ts-jest').JestConfigWithTsJest} */
module.exports = {
preset: "ts-jest/presets/default-esm",
testEnvironment: "node",
testEnvironment: "./jest.env.cjs",
modulePathIgnorePatterns: ["dist/", "docs/"],
moduleNameMapper: {
"^(\\.{1,2}/.*)\\.js$": "$1",
Expand Down
11 changes: 11 additions & 0 deletions langchain/jest.env.cjs
@@ -0,0 +1,11 @@
const { TestEnvironment } = require("jest-environment-node");

/**
 * Jest test environment built on top of the stock `jest-environment-node`
 * environment.
 *
 * After the base environment is constructed, the `Float32Array` constructor
 * visible to this module is installed on the environment's `global` object,
 * so that code running inside a test sees the same `Float32Array` as code
 * running outside of it.
 */
class AdjustedTestEnvironmentToSupportFloat32Array extends TestEnvironment {
  /**
   * @param config - Forwarded unchanged to the base `TestEnvironment`.
   * @param context - Forwarded unchanged to the base `TestEnvironment`.
   */
  constructor(config, context) {
    super(config, context);

    // Make `instanceof Float32Array` return true in tests: overwrite the
    // test global's constructor with the one from this module's realm.
    const testGlobal = this.global;
    testGlobal.Float32Array = Float32Array;
  }
}

module.exports = AdjustedTestEnvironmentToSupportFloat32Array;
25 changes: 15 additions & 10 deletions langchain/src/embeddings/hf_transformers.ts
@@ -1,4 +1,4 @@
import { TransformersEmbeddingFunction } from "chromadb";
import { Pipeline, pipeline } from "@xenova/transformers";
import { chunkArray } from "../util/chunk.js";
import { Embeddings, EmbeddingsParams } from "./base.js";

Expand Down Expand Up @@ -30,23 +30,18 @@ export class HuggingFaceTransformersEmbeddings
{
modelName = "Xenova/all-MiniLM-L6-v2";

embedder = new TransformersEmbeddingFunction({
model: this.modelName,
});

batchSize = 512;

stripNewLines = true;

timeout?: number;

private pipelinePromise: Promise<Pipeline>;

constructor(fields?: Partial<HuggingFaceTransformersEmbeddingsParams>) {
super(fields ?? {});

this.modelName = fields?.modelName ?? this.modelName;
this.embedder = new TransformersEmbeddingFunction({
model: this.modelName,
});
this.stripNewLines = fields?.stripNewLines ?? this.stripNewLines;
this.timeout = fields?.timeout;
}
Expand All @@ -61,7 +56,7 @@ export class HuggingFaceTransformersEmbeddings

for (let i = 0; i < subPrompts.length; i += 1) {
const input = subPrompts[i];
const data = await this.embedder.generate(input);
const data = await this.runEmbedding(input);

for (let j = 0; j < input.length; j += 1) {
embeddings.push(data[j]);
Expand All @@ -72,9 +67,19 @@ export class HuggingFaceTransformersEmbeddings
}

async embedQuery(text: string): Promise<number[]> {
const data = await this.embedder.generate([
const data = await this.runEmbedding([
this.stripNewLines ? text.replace(/\n/g, " ") : text,
]);
return data[0];
}

private async runEmbedding(texts: string[]) {
const pipe = await (this.pipelinePromise ??= pipeline(
"feature-extraction",
this.modelName
));

const output = await pipe(texts, { pooling: "mean", normalize: true });
return output.tolist();
}
}
2 changes: 1 addition & 1 deletion langchain/src/embeddings/tests/hf_transformers.int.test.ts
Expand Up @@ -16,7 +16,7 @@ test("HuggingFaceTransformersEmbeddings", async () => {
];

const queryEmbedding = await embeddings.embedQuery(documents[0]);
expect(queryEmbedding).toHaveLength(768);
expect(queryEmbedding).toHaveLength(384);
expect(typeof queryEmbedding[0]).toBe("number");

const store = new MemoryVectorStore(embeddings);
Expand Down

0 comments on commit 827a076

Please sign in to comment.