feat: impl gpt2 in web

hlhr202 committed May 18, 2023
1 parent ff60d9d commit 24f9d6b
Showing 16 changed files with 949 additions and 98 deletions.
1 change: 1 addition & 0 deletions packages/gpt2-onnx/.gitignore
@@ -0,0 +1 @@
public
2 changes: 1 addition & 1 deletion packages/gpt2-onnx/example/index.ts
@@ -1,4 +1,4 @@
import { GPT2Onnx } from "../src";
import { GPT2Onnx } from "../src/node";
import path from "path";

const modelPath = path.join(process.cwd(), "../../gpt2.onnx");
16 changes: 16 additions & 0 deletions packages/gpt2-onnx/index.html
@@ -0,0 +1,16 @@
<!DOCTYPE html>
<html lang="en">

<head>
    <meta charset="UTF-8">
    <meta http-equiv="X-UA-Compatible" content="IE=edge">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Document</title>
</head>

<body>
    <div id="root"></div>
    <script type="module" src="./src/web.tsx"></script>
</body>

</html>
16 changes: 14 additions & 2 deletions packages/gpt2-onnx/package.json
@@ -4,15 +4,27 @@
     "description": "",
     "main": "index.js",
     "scripts": {
+        "serve": "vite serve",
         "test": "cross-env TF_CPP_MIN_LOG_LEVEL=2 tsx example/index.ts"
     },
     "keywords": [],
     "author": "",
     "license": "ISC",
     "dependencies": {
         "@llama-node/hf-tokenizer": "0.0.37",
-        "@tensorflow/tfjs-node": "^4.2.0",
+        "@tensorflow/tfjs": "^4.6.0",
+        "@tensorflow/tfjs-node": "^4.6.0",
         "onnxruntime-node": "^1.14.0",
-        "onnxruntime-web": "^1.14.0"
+        "onnxruntime-web": "^1.14.0",
+        "react": "^18.2.0",
+        "react-dom": "^18.2.0"
     },
+    "devDependencies": {
+        "@types/react": "^18.2.6",
+        "@types/react-dom": "^18.2.4",
+        "@vitejs/plugin-react": "^4.0.0",
+        "rollup-plugin-copy": "^3.4.0",
+        "vite-plugin-top-level-await": "^1.3.0",
+        "vite-plugin-wasm": "^3.2.2"
+    }
 }
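The new devDependencies hint at how the browser build is wired. The commit's actual vite.config.ts is among the 16 changed files but is not shown on this page; the following is only a plausible sketch reconstructed from the dependency list, so the plugin wiring and the public/ copy target are assumptions:

    import { defineConfig } from "vite";
    import react from "@vitejs/plugin-react";
    import wasm from "vite-plugin-wasm";
    import topLevelAwait from "vite-plugin-top-level-await";
    import copy from "rollup-plugin-copy";

    // Hypothetical config. wasm() + topLevelAwait() allow loading the
    // @llama-node/hf-tokenizer wasm build in the browser; copy() would explain
    // both the rollup-plugin-copy dependency and the new .gitignore entry for
    // public/, since onnxruntime-web needs its .wasm binaries served as static
    // assets.
    export default defineConfig({
        plugins: [
            react(),
            wasm(),
            topLevelAwait(),
            copy({
                targets: [
                    {
                        src: "node_modules/onnxruntime-web/dist/*.wasm",
                        dest: "public",
                    },
                ],
                hook: "buildStart",
            }),
        ],
    });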
packages/gpt2-onnx/src/node/index.ts
@@ -1,6 +1,7 @@
 import * as onnx from "onnxruntime-node";
 import * as tfjs from "@tensorflow/tfjs-node";
 import { AutoTokenizer } from "./tokenizer";
+import type { IsomorphicTokenizer } from "../shared/tokenizer";

 interface IGPT2OnnxOptions {
     modelPath: string;
@@ -16,12 +17,14 @@ interface IGPT2OnnxInferenceOptions {
 }

 export class GPT2Onnx {
-    tokenizer: AutoTokenizer | undefined;
+    tokenizer: IsomorphicTokenizer | undefined;
     session: onnx.InferenceSession | undefined;

     static async create(options: IGPT2OnnxOptions) {
+        const tokenizer = new AutoTokenizer();
+        await tokenizer.initFromUrl(options.tokenizerUrl);
         const gpt2Onnx = new GPT2Onnx();
-        gpt2Onnx.tokenizer = await AutoTokenizer.fromUrl(options.tokenizerUrl);
+        gpt2Onnx.tokenizer = tokenizer;
         gpt2Onnx.session = await onnx.InferenceSession.create(
             options.modelPath
         );
9 changes: 9 additions & 0 deletions packages/gpt2-onnx/src/node/tokenizer.ts
@@ -0,0 +1,9 @@
import { Tokenizer } from "@llama-node/hf-tokenizer";
import { Tensor } from "onnxruntime-node";
import { IsomorphicTokenizer } from "../shared/tokenizer";

export class AutoTokenizer extends IsomorphicTokenizer {
    constructor() {
        super(Tokenizer, Tensor);
    }
}
50 changes: 50 additions & 0 deletions packages/gpt2-onnx/src/shared/tokenizer.ts
@@ -0,0 +1,50 @@
import type { Tokenizer } from "@llama-node/hf-tokenizer";
import type { TensorConstructor } from "onnxruntime-node";

export class IsomorphicTokenizer {
    tokenizer?: Tokenizer;

    constructor(
        private tokenizerConstructor: new (json: string) => Tokenizer,
        private tensorConstructor: TensorConstructor
    ) {}

    async initFromUrl(url: string) {
        const json = await (await fetch(url)).json();

        this.tokenizer = new this.tokenizerConstructor(JSON.stringify(json));
    }

    free() {
        this.tokenizer?.free();
    }

    decode(ids: Uint32Array, skipSpecialTokens = false) {
        return this.tokenizer?.decode(ids, skipSpecialTokens) ?? "";
    }

    tokenize(text: string, addSpecialTokens = true) {
        const encoded = this.tokenizer?.encode(text, addSpecialTokens);

        const inputIdArray = BigInt64Array.from(
            Array.from(encoded?.ids ?? []).map((x) => BigInt(x))
        );

        encoded?.free();

        const input_ids = new this.tensorConstructor("int64", inputIdArray, [
            1,
            inputIdArray.length,
        ]);

        const attentionMaskArray = inputIdArray.map(() => BigInt(1));

        const attention_mask = new this.tensorConstructor(
            "int64",
            attentionMaskArray,
            [1, attentionMaskArray.length]
        );

        return { input_ids, attention_mask };
    }
}
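As a usage illustration (not part of the commit): the tensors returned by tokenize() are named input_ids and attention_mask, so they can be passed straight to an ONNX session as its feeds. A minimal Node sketch, assuming a local gpt2.onnx export whose graph inputs use those names:

    import * as onnx from "onnxruntime-node";
    import { AutoTokenizer } from "./node/tokenizer";

    // Hypothetical usage: the model path and output name are assumptions.
    const tokenizer = new AutoTokenizer();
    await tokenizer.initFromUrl(
        "https://huggingface.co/gpt2/raw/main/tokenizer.json"
    );

    const session = await onnx.InferenceSession.create("gpt2.onnx");

    // tokenize() returns { input_ids, attention_mask }, matching the feed
    // names the session expects, so the object can be passed directly.
    const feeds = tokenizer.tokenize("Hello, world", true);
    const result = await session.run(feeds);
    console.log(Object.keys(result)); // e.g. ["last_hidden_state"]

    tokenizer.free();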
49 changes: 0 additions & 49 deletions packages/gpt2-onnx/src/tokenizer.ts

This file was deleted.

1 change: 1 addition & 0 deletions packages/gpt2-onnx/src/vite-env.d.ts
@@ -0,0 +1 @@
/// <reference types="vite/client" />
76 changes: 76 additions & 0 deletions packages/gpt2-onnx/src/web.tsx
@@ -0,0 +1,76 @@
import React from "react";
import { createRoot } from "react-dom/client";
import "./web/worker";
import type { ICommand } from "./web/worker";
import MyWorker from "./web/worker?worker";

const tokenizerUrl = "https://huggingface.co/gpt2/raw/main/tokenizer.json";

const worker = new MyWorker();

console.log("worker", worker);

const App = () => {
    const [inputText, setInputText] = React.useState("");
    const [inferenceText, setInferenceText] = React.useState("");

    const handleLoad = React.useCallback((model: ArrayBufferLike) => {
        const message: ICommand = {
            type: "load",
            data: {
                model,
                tokenizerUrl,
            },
        };
        worker.postMessage(message);
    }, []);

    const handleInference = React.useCallback(() => {
        const message: ICommand = {
            type: "inference",
            data: {
                prompt: inputText,
                topK: 1,
                numPredict: 128,
            },
        };
        worker.postMessage(message);
    }, [inputText]);

    const handleFileChange = React.useCallback(
        async (e: React.ChangeEvent<HTMLInputElement>) => {
            const model = await e.target.files?.[0]?.arrayBuffer();
            if (!model) {
                throw new Error("No model");
            } else {
                return handleLoad(model);
            }
        },
        []
    );

    React.useEffect(() => {
        worker.onmessage = (e) => {
            setInferenceText((text) => text + e.data);
        };
    }, []);

    return (
        <div>
            <div>
                <input onChange={handleFileChange} type="file"></input>
            </div>
            <div>
                <input
                    value={inputText}
                    onChange={(e) => setInputText(e.target.value)}
                />
                <button onClick={handleInference}>Inference</button>
            </div>
            <div>{inferenceText}</div>
        </div>
    );
};

const root = createRoot(document.getElementById("root")!);
root.render(<App />);
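src/web/worker.ts is referenced above but is one of the changed files not expanded on this page. Based on the ICommand messages posted here and the GPT2Onnx web API below, it plausibly looks something like this sketch (the real file may differ):

    // Hypothetical reconstruction of src/web/worker.ts.
    import { GPT2Onnx } from "./index";

    export type ICommand =
        | { type: "load"; data: { model: ArrayBufferLike; tokenizerUrl: string } }
        | {
              type: "inference";
              data: { prompt: string; topK?: number; numPredict?: number };
          };

    // Cast needed because a shared tsconfig types `self` as Window.
    const ctx = self as unknown as Worker;

    let gpt2: GPT2Onnx | undefined;

    ctx.onmessage = async (e: MessageEvent<ICommand>) => {
        if (e.data.type === "load") {
            // Build the session from the ArrayBuffer posted by the main thread.
            gpt2 = await GPT2Onnx.create({
                modelPath: e.data.data.model,
                tokenizerUrl: e.data.data.tokenizerUrl,
            });
        } else if (e.data.type === "inference" && gpt2) {
            // Stream each decoded token back; web.tsx appends it to inferenceText.
            await gpt2.inference({
                ...e.data.data,
                onProgress: (token) => ctx.postMessage(token),
            });
        }
    };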
105 changes: 105 additions & 0 deletions packages/gpt2-onnx/src/web/index.ts
@@ -0,0 +1,105 @@
import * as onnx from "onnxruntime-web";
import * as tfjs from "@tensorflow/tfjs";
import { AutoTokenizer } from "./tokenizer";
import type { IsomorphicTokenizer } from "../shared/tokenizer";

interface IGPT2OnnxOptions {
    modelPath: string | ArrayBufferLike;
    tokenizerUrl: string;
}

interface IGPT2OnnxInferenceOptions {
    numPredict?: number;
    prompt: string;
    topK?: number;
    endToken?: number;
    onProgress: (data: string) => void;
}

export class GPT2Onnx {
    tokenizer: IsomorphicTokenizer | undefined;
    session: onnx.InferenceSession | undefined;

    static async create(options: IGPT2OnnxOptions) {
        const tokenizer = new AutoTokenizer();
        await tokenizer.initFromUrl(options.tokenizerUrl);
        const gpt2Onnx = new GPT2Onnx();
        gpt2Onnx.tokenizer = tokenizer;
        gpt2Onnx.session = await onnx.InferenceSession.create(
            options.modelPath as ArrayBufferLike
        );

        return gpt2Onnx;
    }

    free() {
        this.tokenizer?.free();
    }

    getLogits(onnxTensor: onnx.TypedTensor<"float32">) {
        let output = tfjs
            .tensor<tfjs.Rank.R3>(onnxTensor.data, onnxTensor.dims as any)
            .slice(0, 1);

        return output
            .slice(
                [0, output.shape[1] - 1, 0],
                [output.shape[0], 1, output.shape[2]]
            )
            .squeeze();
    }

    async inference(inferArgs: IGPT2OnnxInferenceOptions) {
        if (!this.tokenizer) {
            throw new Error("Tokenizer not initialized");
        }

        if (!this.session) {
            throw new Error("Session not initialized");
        }

        const numPredict = inferArgs.numPredict ?? 128;
        const topK = inferArgs.topK ?? 1;

        let text = inferArgs.prompt;
        let remain = numPredict;

        while (remain > 0) {
            remain -= 1;

            const inputs = this.tokenizer.tokenize(text, true);

            const result = await this.session.run(inputs);

            const logits = this.getLogits(
                result["last_hidden_state"] as onnx.TypedTensor<"float32">
            );

            let probs = tfjs.softmax(logits, -1);

            // TODO: implement topP
            probs = probs.topk(topK, true).indices.slice(0, 1).squeeze();

            const token = probs.dataSync();

            // TODO: implement end of sentence
            if (
                token[0] >= 50256 ||
                token[0] === 0 ||
                token[0] === 1 ||
                (inferArgs.endToken && token[0] === inferArgs.endToken)
            ) {
                break;
            }

            const tokenText = this.tokenizer.decode(
                Uint32Array.from(token),
                true
            );

            inferArgs.onProgress(tokenText);

            text += tokenText;
        }
    }
}
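For context on getLogits: the exported graph's last_hidden_state output is treated as per-token logits of shape [batch, seq_len, vocab], and the slice keeps only the final position, i.e. the distribution over the next token. A toy tfjs sketch of the same slice, with made-up sizes:

    import * as tfjs from "@tensorflow/tfjs";

    // Toy dimensions for illustration only.
    const [batch, seqLen, vocab] = [1, 4, 8];
    const hidden = tfjs.randomNormal<tfjs.Rank.R3>([batch, seqLen, vocab]);

    // Keep only the last sequence position: [1, 4, 8] -> [8].
    const nextTokenLogits = hidden
        .slice([0, seqLen - 1, 0], [batch, 1, vocab])
        .squeeze();

    console.log(nextTokenLogits.shape); // [8]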
15 changes: 15 additions & 0 deletions packages/gpt2-onnx/src/web/tokenizer.ts
@@ -0,0 +1,15 @@
import init, { Tokenizer } from "@llama-node/hf-tokenizer/web";
import { Tensor } from "onnxruntime-web";
import { IsomorphicTokenizer } from "../shared/tokenizer";

export class AutoTokenizer extends IsomorphicTokenizer {
    constructor() {
        super(Tokenizer, Tensor);
    }

    async initFromUrl(url: string) {
        await init();

        await super.initFromUrl(url);
    }
}
