-
Notifications
You must be signed in to change notification settings - Fork 62
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
16 changed files
with
949 additions
and
98 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
public |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
<!DOCTYPE html>
<html lang="en">

<head>
  <meta charset="UTF-8">
  <meta http-equiv="X-UA-Compatible" content="IE=edge">
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <title>Document</title>
</head>

<body>
  <!-- Mount point for the React app; populated by src/web.tsx via createRoot. -->
  <div id="root"></div>
  <!-- Vite entry point: type="module" lets the dev server serve TSX directly. -->
  <script type="module" src="./src/web.tsx"></script>
</body>

</html>
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
import { Tokenizer } from "@llama-node/hf-tokenizer";
import { Tensor } from "onnxruntime-node";
import { IsomorphicTokenizer } from "../shared/tokenizer";

/**
 * Node.js binding of the shared tokenizer: injects the native
 * hf-tokenizer class and the onnxruntime-node Tensor constructor
 * into the platform-agnostic IsomorphicTokenizer base.
 */
export class AutoTokenizer extends IsomorphicTokenizer {
  constructor() {
    super(Tokenizer, Tensor);
  }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
import type { Tokenizer } from "@llama-node/hf-tokenizer"; | ||
import type { TensorConstructor } from "onnxruntime-node"; | ||
|
||
export class IsomorphicTokenizer { | ||
tokenizer?: Tokenizer; | ||
|
||
constructor( | ||
private tokenizerConstructor: new (json: string) => Tokenizer, | ||
private tensorConstructor: TensorConstructor | ||
) {} | ||
|
||
async initFromUrl(url: string) { | ||
const json = await (await fetch(url)).json(); | ||
|
||
this.tokenizer = new this.tokenizerConstructor(JSON.stringify(json)); | ||
} | ||
|
||
free() { | ||
this.tokenizer?.free(); | ||
} | ||
|
||
decode(ids: Uint32Array, skipSpecialTokens = false) { | ||
return this.tokenizer?.decode(ids, skipSpecialTokens) ?? ""; | ||
} | ||
|
||
tokenize(text: string, addSpecialTokens = true) { | ||
const encoded = this.tokenizer?.encode(text, addSpecialTokens); | ||
|
||
const inputIdArray = BigInt64Array.from( | ||
Array.from(encoded?.ids ?? []).map((x) => BigInt(x)) | ||
); | ||
|
||
encoded?.free(); | ||
|
||
const input_ids = new this.tensorConstructor("int64", inputIdArray, [ | ||
1, | ||
inputIdArray.length, | ||
]); | ||
|
||
const attentionMaskArray = inputIdArray.map(() => BigInt(1)); | ||
|
||
const attention_mask = new this.tensorConstructor( | ||
"int64", | ||
attentionMaskArray, | ||
[1, attentionMaskArray.length] | ||
); | ||
|
||
return { input_ids, attention_mask }; | ||
} | ||
} |
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
/// <reference types="vite/client" /> |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
import React from "react";
import { createRoot } from "react-dom/client";
import "./web/worker";
import type { ICommand } from "./web/worker";
import MyWorker from "./web/worker?worker";

// Tokenizer definition for GPT-2, fetched by the worker at load time.
const tokenizerUrl = "https://huggingface.co/gpt2/raw/main/tokenizer.json";

// Single shared worker; model loading and inference run off the main thread.
const worker = new MyWorker();

console.log("worker", worker);

const App = () => {
  const [inputText, setInputText] = React.useState("");
  const [inferenceText, setInferenceText] = React.useState("");

  // Ships the raw model bytes plus the tokenizer URL to the worker.
  const handleLoad = React.useCallback((model: ArrayBufferLike) => {
    const message: ICommand = {
      type: "load",
      data: {
        model,
        tokenizerUrl,
      },
    };
    worker.postMessage(message);
  }, []);

  // Kicks off generation for the current prompt.
  const handleInference = React.useCallback(() => {
    const message: ICommand = {
      type: "inference",
      data: {
        prompt: inputText,
        topK: 1,
        numPredict: 128,
      },
    };
    worker.postMessage(message);
  }, [inputText]);

  const handleFileChange = React.useCallback(
    async (e: React.ChangeEvent<HTMLInputElement>) => {
      const model = await e.target.files?.[0]?.arrayBuffer();
      if (!model) {
        throw new Error("No model");
      }
      handleLoad(model);
    },
    // handleLoad is referenced inside, so it belongs in the dependency list
    // (it is stable, but omitting it violates exhaustive-deps and breaks if
    // handleLoad ever gains dependencies).
    [handleLoad]
  );

  React.useEffect(() => {
    // Streamed tokens arrive one message at a time; append each to the output.
    worker.onmessage = (e) => {
      setInferenceText((text) => text + e.data);
    };
    // Detach on unmount so a stale closure never calls a dead setState.
    return () => {
      worker.onmessage = null;
    };
  }, []);

  return (
    <div>
      <div>
        <input onChange={handleFileChange} type="file"></input>
      </div>
      <div>
        <input
          value={inputText}
          onChange={(e) => setInputText(e.target.value)}
        />
        <button onClick={handleInference}>Inference</button>
      </div>
      <div>{inferenceText}</div>
    </div>
  );
};

const root = createRoot(document.getElementById("root")!);
root.render(<App />);
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
import * as onnx from "onnxruntime-web"; | ||
import * as tfjs from "@tensorflow/tfjs"; | ||
import { AutoTokenizer } from "./tokenizer"; | ||
import type { IsomorphicTokenizer } from "../shared/tokenizer"; | ||
|
||
interface IGPT2OnnxOptions { | ||
modelPath: string | ArrayBufferLike; | ||
tokenizerUrl: string; | ||
} | ||
|
||
interface IGPT2OnnxInferenceOptions { | ||
numPredict?: number; | ||
prompt: string; | ||
topK?: number; | ||
endToken?: number; | ||
onProgress: (data: string) => void; | ||
} | ||
|
||
export class GPT2Onnx { | ||
tokenizer: IsomorphicTokenizer | undefined; | ||
session: onnx.InferenceSession | undefined; | ||
|
||
static async create(options: IGPT2OnnxOptions) { | ||
const tokenizer = new AutoTokenizer(); | ||
await tokenizer.initFromUrl(options.tokenizerUrl); | ||
const gpt2Onnx = new GPT2Onnx(); | ||
gpt2Onnx.tokenizer = tokenizer; | ||
gpt2Onnx.session = await onnx.InferenceSession.create( | ||
options.modelPath as ArrayBufferLike | ||
); | ||
|
||
return gpt2Onnx; | ||
} | ||
|
||
free() { | ||
this.tokenizer?.free(); | ||
} | ||
|
||
getLogits(onnxTensor: onnx.TypedTensor<"float32">) { | ||
let output = tfjs | ||
.tensor<tfjs.Rank.R3>(onnxTensor.data, onnxTensor.dims as any) | ||
.slice(0, 1); | ||
|
||
return output | ||
.slice( | ||
[0, output.shape[1] - 1, 0], | ||
[output.shape[0], 1, output.shape[2]] | ||
) | ||
.squeeze(); | ||
} | ||
|
||
async inference(inferArgs: IGPT2OnnxInferenceOptions) { | ||
if (!this.tokenizer) { | ||
throw new Error("Tokenizer not initialized"); | ||
} | ||
|
||
if (!this.session) { | ||
throw new Error("Session not initialized"); | ||
} | ||
|
||
const numPredict = inferArgs.numPredict ?? 128; | ||
const topK = inferArgs.topK ?? 1; | ||
|
||
let text = inferArgs.prompt; | ||
let remain = numPredict; | ||
|
||
while (remain > 0) { | ||
remain -= 1; | ||
|
||
const inputs = this.tokenizer.tokenize(text, true); | ||
|
||
const result = await this.session.run(inputs); | ||
|
||
const logits = this.getLogits( | ||
result["last_hidden_state"] as onnx.TypedTensor<"float32"> | ||
); | ||
|
||
let probs = tfjs.softmax(logits, -1); | ||
|
||
// TODO: implement topP | ||
probs = probs.topk(topK, true).indices.slice(0, 1).squeeze(); | ||
|
||
const token = probs.dataSync(); | ||
|
||
// TODO: implement end of sentence | ||
if ( | ||
token[0] >= 50256 || | ||
token[0] === 0 || | ||
token[0] === 1 || | ||
(inferArgs.endToken && token[0] === inferArgs.endToken) | ||
) { | ||
break; | ||
} | ||
|
||
const tokenText = this.tokenizer.decode( | ||
Uint32Array.from(token), | ||
true | ||
); | ||
|
||
inferArgs.onProgress(tokenText); | ||
|
||
text += tokenText; | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import init, { Tokenizer } from "@llama-node/hf-tokenizer/web"; | ||
import { Tensor } from "onnxruntime-web"; | ||
import { IsomorphicTokenizer } from "../shared/tokenizer"; | ||
|
||
export class AutoTokenizer extends IsomorphicTokenizer { | ||
constructor() { | ||
super(Tokenizer, Tensor); | ||
} | ||
|
||
async initFromUrl(url: string) { | ||
await init(); | ||
|
||
await super.initFromUrl(url); | ||
} | ||
} |
Oops, something went wrong.