update: refactor onnx (#87)
* update: refactor onnx for shared code between nodejs/browser

* update: use alias for isomorphic import
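
The "alias for isomorphic import" note above is what lets the shared sources import the Node packages directly ("onnxruntime-node", "@tensorflow/tfjs-node", the Node build of the tokenizer) while browser builds remap those specifiers to web counterparts. The bundler configuration itself is not part of this diff; a minimal sketch of the idea, assuming a Vite-style alias (the web packages on the right are already listed in package.json, except the tokenizer path, which is a guess):

// vite.config.ts (hypothetical) — remap Node-only packages so the shared
// code compiles unchanged for the browser.
import { defineConfig } from "vite";

export default defineConfig({
    resolve: {
        alias: {
            "onnxruntime-node": "onnxruntime-web",
            "@tensorflow/tfjs-node": "@tensorflow/tfjs",
            // assumed web counterpart of the tokenizer import used below
            "@llama-node/hf-tokenizer/nodejs/tokenizer-node":
                "@llama-node/hf-tokenizer/web/tokenizer-web",
        },
    },
});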
hlhr202 committed May 29, 2023
1 parent 649457a commit ea663be
Showing 12 changed files with 243 additions and 464 deletions.
2 changes: 2 additions & 0 deletions packages/gpt2-onnx/example/index.ts
@@ -1,5 +1,6 @@
 import { GPT2Onnx } from "../src/node";
 import path from "path";
+import { AutoTokenizer } from "../src/node/tokenizer";
 
 const modelPath = path.join(process.cwd(), "../../gpt2.onnx");
 
@@ -16,6 +17,7 @@ const run = async () => {
     const gpt2 = await GPT2Onnx.create({
         modelPath,
         tokenizerUrl,
+        tokenizer: new AutoTokenizer(),
     });
 
     process.stdout.write(prompt);
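
Because IGPT2OnnxOptions.modelPath (see src/shared/context.ts below) accepts string | ArrayBufferLike, the browser flavor of this example would fetch the model bytes instead of building a filesystem path. A sketch under stated assumptions — the browser entry-point paths and the tokenizer URL are guesses, not shown in this diff:

import { GPT2Onnx } from "../src/browser"; // assumed browser entry point
import { AutoTokenizer } from "../src/browser/tokenizer"; // assumed path

const run = async () => {
    // In a browser the model arrives as raw bytes rather than a file path.
    const modelPath = await (await fetch("/gpt2.onnx")).arrayBuffer();

    const gpt2 = await GPT2Onnx.create({
        modelPath,
        tokenizerUrl: "https://huggingface.co/gpt2/resolve/main/tokenizer.json",
        tokenizer: new AutoTokenizer(),
    });

    await gpt2.inference({
        prompt: "Hello, my name is",
        onProgress: (data) => console.log(data),
    });
};

run();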
1 change: 1 addition & 0 deletions packages/gpt2-onnx/package.json
@@ -14,6 +14,7 @@
     "@llama-node/hf-tokenizer": "0.0.37",
     "@tensorflow/tfjs": "^4.6.0",
     "@tensorflow/tfjs-node": "^4.6.0",
+    "axios": "*",
     "onnxruntime-node": "^1.14.0",
     "onnxruntime-web": "^1.14.0",
     "react": "^18.2.0",
106 changes: 2 additions & 104 deletions packages/gpt2-onnx/src/node/index.ts
@@ -1,105 +1,3 @@
-import * as onnx from "onnxruntime-node";
-import * as tfjs from "@tensorflow/tfjs-node";
-import { AutoTokenizer } from "./tokenizer";
-import type { IsomorphicTokenizer } from "../shared/tokenizer";
+import { IsomorphicContext as GPT2Onnx } from "../shared/context";
 
-interface IGPT2OnnxOptions {
-    modelPath: string;
-    tokenizerUrl: string;
-}
-
-interface IGPT2OnnxInferenceOptions {
-    numPredict?: number;
-    prompt: string;
-    topK?: number;
-    endToken?: number;
-    onProgress: (data: string) => void;
-}
-
-export class GPT2Onnx {
-    tokenizer: IsomorphicTokenizer | undefined;
-    session: onnx.InferenceSession | undefined;
-
-    static async create(options: IGPT2OnnxOptions) {
-        const tokenizer = new AutoTokenizer();
-        await tokenizer.initFromUrl(options.tokenizerUrl);
-        const gpt2Onnx = new GPT2Onnx();
-        gpt2Onnx.tokenizer = tokenizer;
-        gpt2Onnx.session = await onnx.InferenceSession.create(
-            options.modelPath
-        );
-
-        return gpt2Onnx;
-    }
-
-    free() {
-        this.tokenizer?.free();
-    }
-
-    getLogits(onnxTensor: onnx.TypedTensor<"float32">) {
-        let output = tfjs
-            .tensor<tfjs.Rank.R3>(onnxTensor.data, onnxTensor.dims as any)
-            .slice(0, 1);
-
-        return output
-            .slice(
-                [0, output.shape[1] - 1, 0],
-                [output.shape[0], 1, output.shape[2]]
-            )
-            .squeeze();
-    }
-
-    async inference(inferArgs: IGPT2OnnxInferenceOptions) {
-        if (!this.tokenizer) {
-            throw new Error("Tokenizer not initialized");
-        }
-
-        if (!this.session) {
-            throw new Error("Session not initialized");
-        }
-
-        const numPredict = inferArgs.numPredict ?? 128;
-        const topK = inferArgs.topK ?? 1;
-
-        let text = inferArgs.prompt;
-        let remain = numPredict;
-
-        while (remain > 0) {
-            remain -= 1;
-
-            const inputs = this.tokenizer.tokenize(text, true);
-
-            const result = await this.session.run(inputs);
-
-            const logits = this.getLogits(
-                result["last_hidden_state"] as onnx.TypedTensor<"float32">
-            );
-
-            let probs = tfjs.softmax(logits, -1);
-
-            // TODO: implement topP
-            probs = probs.topk(topK, true).indices.slice(0, 1).squeeze();
-
-            const token = probs.dataSync();
-
-            // TODO: implement end of sentence
-            if (
-                token[0] >= 50256 ||
-                token[0] === 0 ||
-                token[0] === 1 ||
-                (inferArgs.endToken && token[0] === inferArgs.endToken)
-            ) {
-                break;
-            }
-
-            const tokenText = this.tokenizer.decode(
-                Uint32Array.from(token),
-                true
-            );
-
-            inferArgs.onProgress(tokenText);
-
-            text += tokenText;
-        }
-    }
-}
+export { GPT2Onnx };
8 changes: 1 addition & 7 deletions packages/gpt2-onnx/src/node/tokenizer.ts
@@ -1,9 +1,3 @@
-import { Tokenizer } from "@llama-node/hf-tokenizer";
-import { Tensor } from "onnxruntime-node";
 import { IsomorphicTokenizer } from "../shared/tokenizer";
 
-export class AutoTokenizer extends IsomorphicTokenizer {
-    constructor() {
-        super(Tokenizer, Tensor);
-    }
-}
+export class AutoTokenizer extends IsomorphicTokenizer {}
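
With the constructor injection gone, a platform entry point reduces to an empty subclass: the bundler alias decides which concrete Tokenizer and Tensor the shared base class actually imports. A hypothetical browser-side twin would be the mirror image (this page shows only 6 of the 12 changed files, so the path is an assumption):

// src/browser/tokenizer.ts (hypothetical)
import { IsomorphicTokenizer } from "../shared/tokenizer";

export class AutoTokenizer extends IsomorphicTokenizer {}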
106 changes: 106 additions & 0 deletions packages/gpt2-onnx/src/shared/context.ts
@@ -0,0 +1,106 @@
import * as tfjs from "@tensorflow/tfjs-node";
import * as onnx from "onnxruntime-node";
import type { Rank } from "@tensorflow/tfjs";
import type { IsomorphicTokenizer } from "./tokenizer";
import type { InferenceSession, TypedTensor } from "onnxruntime-node";

interface IGPT2OnnxOptions {
    tokenizer: IsomorphicTokenizer;
    modelPath: string | ArrayBufferLike;
    tokenizerUrl: string;
}

interface IGPT2OnnxInferenceOptions {
    numPredict?: number;
    prompt: string;
    topK?: number;
    endToken?: number;
    onProgress: (data: string) => void;
}

export class IsomorphicContext {
    tokenizer?: IsomorphicTokenizer;
    session?: InferenceSession;

    static async create(options: IGPT2OnnxOptions) {
        const tokenizer = options.tokenizer;
        await tokenizer.initFromUrl(options.tokenizerUrl);
        const gpt2Onnx = new IsomorphicContext();

        gpt2Onnx.tokenizer = tokenizer;
        gpt2Onnx.session = await onnx.InferenceSession.create(
            options.modelPath as ArrayBufferLike
        );

        return gpt2Onnx;
    }

    free() {
        this.tokenizer?.free();
    }

    getLogits(onnxTensor: TypedTensor<"float32">) {
        let output = tfjs
            .tensor<Rank.R3>(onnxTensor.data, onnxTensor.dims as any)
            .slice(0, 1);

        return output
            .slice(
                [0, output.shape[1] - 1, 0],
                [output.shape[0], 1, output.shape[2]]
            )
            .squeeze();
    }

    async inference(inferArgs: IGPT2OnnxInferenceOptions) {
        if (!this.tokenizer) {
            throw new Error("Tokenizer not initialized");
        }

        if (!this.session) {
            throw new Error("Session not initialized");
        }

        const numPredict = inferArgs.numPredict ?? 128;
        const topK = inferArgs.topK ?? 1;

        let remain = numPredict;
        const tokens = this.tokenizer.tokenize(inferArgs.prompt, true);

        while (remain > 0) {
            remain -= 1;

            const inputs = this.tokenizer.toOnnx(tokens);
            const result = await this.session.run(inputs);

            const logits = this.getLogits(
                result["last_hidden_state"] as TypedTensor<"float32">
            );

            let probs = tfjs.softmax(logits, -1);

            // TODO: implement topP
            probs = probs.topk(topK, true).indices.slice(0, 1).squeeze();

            const token = probs.dataSync();

            // TODO: implement end of sentence
            if (
                token[0] >= 50256 ||
                token[0] === 0 ||
                token[0] === 1 ||
                (inferArgs.endToken && token[0] === inferArgs.endToken)
            ) {
                break;
            }

            const tokenText = this.tokenizer.decode(
                Uint32Array.from(token),
                true
            );

            inferArgs.onProgress(tokenText);
            tokens.push(BigInt(token[0]));
        }
    }
}
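
Two details of this shared loop are easy to miss: the prompt is tokenized once and each generated id is appended to tokens (the old Node-only loop re-tokenized the whole growing text every step), and with the default topK = 1 the sampling step is plain greedy decoding. A standalone sketch of that step, assuming a 1-D [vocab] logits tensor (the helper name is illustrative, not part of the diff):

import * as tfjs from "@tensorflow/tfjs";

// Pick the next token id from logits: softmax, then top-k; with topK = 1,
// taking the first index is equivalent to argmax (greedy decoding).
const pickNextToken = (logits: tfjs.Tensor, topK = 1): number => {
    const probs = tfjs.softmax(logits, -1);
    const { indices } = probs.topk(topK, true);
    return indices.dataSync()[0];
};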
56 changes: 34 additions & 22 deletions packages/gpt2-onnx/src/shared/tokenizer.ts
@@ -1,18 +1,17 @@
-import type { Tokenizer } from "@llama-node/hf-tokenizer";
+import { Tokenizer } from "@llama-node/hf-tokenizer/nodejs/tokenizer-node";
+import { Tensor } from "onnxruntime-node";
+import axios from "axios";
 import type { TensorConstructor } from "onnxruntime-node";
 
 export class IsomorphicTokenizer {
     tokenizer?: Tokenizer;
 
-    constructor(
-        private tokenizerConstructor: new (json: string) => Tokenizer,
-        private tensorConstructor: TensorConstructor
-    ) {}
+    Tensor?: TensorConstructor;
 
     async initFromUrl(url: string) {
-        const json = await (await fetch(url)).json();
+        const json = await axios.get(url).then((res) => res.data);
 
-        this.tokenizer = new this.tokenizerConstructor(JSON.stringify(json));
+        this.Tensor = Tensor;
+        this.tokenizer = new Tokenizer(JSON.stringify(json));
     }
 
     free() {
@@ -23,28 +22,41 @@ export class IsomorphicTokenizer {
         return this.tokenizer?.decode(ids, skipSpecialTokens) ?? "";
     }
 
-    tokenize(text: string, addSpecialTokens = true) {
-        const encoded = this.tokenizer?.encode(text, addSpecialTokens);
+    toOnnx(input: bigint[]) {
+        if (!this.Tensor) {
+            throw new Error("Onnx is not initialized");
+        }
 
-        const inputIdArray = BigInt64Array.from(
-            Array.from(encoded?.ids ?? []).map((x) => BigInt(x))
-        );
+        const inputArray = BigInt64Array.from(input);
 
-        encoded?.free();
+        const input_ids = new this.Tensor("int64", inputArray, [
+            1,
+            inputArray.length,
+        ]);
 
-        const input_ids = new this.tensorConstructor("int64", inputIdArray, [
+        const attentionMaskArray = inputArray.map(() => BigInt(1));
+
+        const attention_mask = new this.Tensor("int64", attentionMaskArray, [
             1,
-            inputIdArray.length,
+            attentionMaskArray.length,
         ]);
 
-        const attentionMaskArray = inputIdArray.map(() => BigInt(1));
+        return { input_ids, attention_mask };
+    }
+
+    tokenize(text: string, addSpecialTokens = true) {
+        if (!this.Tensor) {
+            throw new Error("Onnx is not initialized");
+        }
 
+        const encoded = this.tokenizer?.encode(text, addSpecialTokens);
 
-        const attention_mask = new this.tensorConstructor(
-            "int64",
-            attentionMaskArray,
-            [1, attentionMaskArray.length]
+        const inputIdArray = Array.from(encoded?.ids ?? []).map((x) =>
+            BigInt(x)
         );
 
-        return { input_ids, attention_mask };
+        encoded?.free();
 
+        return inputIdArray;
     }
 }
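
The net effect of this split: tokenize now returns plain bigint ids and toOnnx turns an id list into the int64 input_ids/attention_mask feeds, so callers can append generated ids without re-encoding text. A minimal usage sketch (the tokenizer URL is illustrative, and the import path assumes the example directory above):

import { AutoTokenizer } from "../src/node/tokenizer";

const run = async () => {
    const tokenizer = new AutoTokenizer();
    await tokenizer.initFromUrl(
        "https://huggingface.co/gpt2/resolve/main/tokenizer.json"
    );

    const ids = tokenizer.tokenize("Hello world", true); // bigint[]
    ids.push(BigInt(13)); // append a generated id without re-encoding

    const feeds = tokenizer.toOnnx(ids); // { input_ids, attention_mask }
    console.log(feeds.input_ids.dims); // [1, ids.length]
};

run();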
