-
Notifications
You must be signed in to change notification settings - Fork 62
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* update: refactor onnx for shared code between nodejs/browser * update: use alias for isomorphic import
- Loading branch information
Showing
12 changed files
with
243 additions
and
464 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,105 +1,3 @@ | ||
import * as onnx from "onnxruntime-node"; | ||
import * as tfjs from "@tensorflow/tfjs-node"; | ||
import { AutoTokenizer } from "./tokenizer"; | ||
import type { IsomorphicTokenizer } from "../shared/tokenizer"; | ||
import { IsomorphicContext as GPT2Onnx } from "../shared/context"; | ||
|
||
// Construction options for GPT2Onnx.create: where to load the ONNX model
// and the tokenizer definition from.
interface IGPT2OnnxOptions { | ||
modelPath: string; | ||
tokenizerUrl: string; | ||
} | ||
|
||
// Per-call generation options. `onProgress` receives each decoded token
// as soon as it is produced (streaming).
interface IGPT2OnnxInferenceOptions { | ||
numPredict?: number; | ||
prompt: string; | ||
topK?: number; | ||
endToken?: number; | ||
onProgress: (data: string) => void; | ||
} | ||
|
||
export class GPT2Onnx { | ||
tokenizer: IsomorphicTokenizer | undefined; | ||
session: onnx.InferenceSession | undefined; | ||
|
||
static async create(options: IGPT2OnnxOptions) { | ||
const tokenizer = new AutoTokenizer(); | ||
await tokenizer.initFromUrl(options.tokenizerUrl); | ||
const gpt2Onnx = new GPT2Onnx(); | ||
gpt2Onnx.tokenizer = tokenizer; | ||
gpt2Onnx.session = await onnx.InferenceSession.create( | ||
options.modelPath | ||
); | ||
|
||
return gpt2Onnx; | ||
} | ||
|
||
free() { | ||
this.tokenizer?.free(); | ||
} | ||
|
||
getLogits(onnxTensor: onnx.TypedTensor<"float32">) { | ||
let output = tfjs | ||
.tensor<tfjs.Rank.R3>(onnxTensor.data, onnxTensor.dims as any) | ||
.slice(0, 1); | ||
|
||
return output | ||
.slice( | ||
[0, output.shape[1] - 1, 0], | ||
[output.shape[0], 1, output.shape[2]] | ||
) | ||
.squeeze(); | ||
} | ||
|
||
async inference(inferArgs: IGPT2OnnxInferenceOptions) { | ||
if (!this.tokenizer) { | ||
throw new Error("Tokenizer not initialized"); | ||
} | ||
|
||
if (!this.session) { | ||
throw new Error("Session not initialized"); | ||
} | ||
|
||
const numPredict = inferArgs.numPredict ?? 128; | ||
const topK = inferArgs.topK ?? 1; | ||
|
||
let text = inferArgs.prompt; | ||
let remain = numPredict; | ||
|
||
while (remain > 0) { | ||
remain -= 1; | ||
|
||
const inputs = this.tokenizer.tokenize(text, true); | ||
|
||
const result = await this.session.run(inputs); | ||
|
||
const logits = this.getLogits( | ||
result["last_hidden_state"] as onnx.TypedTensor<"float32"> | ||
); | ||
|
||
let probs = tfjs.softmax(logits, -1); | ||
|
||
// TODO: implement topP | ||
probs = probs.topk(topK, true).indices.slice(0, 1).squeeze(); | ||
|
||
const token = probs.dataSync(); | ||
|
||
// TODO: implement end of sentence | ||
if ( | ||
token[0] >= 50256 || | ||
token[0] === 0 || | ||
token[0] === 1 || | ||
(inferArgs.endToken && token[0] === inferArgs.endToken) | ||
) { | ||
break; | ||
} | ||
|
||
const tokenText = this.tokenizer.decode( | ||
Uint32Array.from(token), | ||
true | ||
); | ||
|
||
inferArgs.onProgress(tokenText); | ||
|
||
text += tokenText; | ||
} | ||
} | ||
} | ||
export { GPT2Onnx }; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,3 @@ | ||
import { Tokenizer } from "@llama-node/hf-tokenizer"; | ||
import { Tensor } from "onnxruntime-node"; | ||
import { IsomorphicTokenizer } from "../shared/tokenizer"; | ||
|
||
export class AutoTokenizer extends IsomorphicTokenizer { | ||
constructor() { | ||
super(Tokenizer, Tensor); | ||
} | ||
} | ||
export class AutoTokenizer extends IsomorphicTokenizer {} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
import * as tfjs from "@tensorflow/tfjs-node"; | ||
import * as onnx from "onnxruntime-node"; | ||
import type { Rank } from "@tensorflow/tfjs"; | ||
import type { IsomorphicTokenizer } from "./tokenizer"; | ||
import type { InferenceSession, TypedTensor } from "onnxruntime-node"; | ||
|
||
// Construction options shared by node and browser entry points. The
// platform-specific tokenizer is injected by the caller; the model may be
// a file path (node) or an in-memory buffer (browser).
interface IGPT2OnnxOptions { | ||
tokenizer: IsomorphicTokenizer; | ||
modelPath: string | ArrayBufferLike; | ||
tokenizerUrl: string; | ||
} | ||
|
||
// Per-call generation options. `onProgress` receives each decoded token
// as soon as it is produced (streaming).
interface IGPT2OnnxInferenceOptions { | ||
numPredict?: number; | ||
prompt: string; | ||
topK?: number; | ||
endToken?: number; | ||
onProgress: (data: string) => void; | ||
} | ||
|
||
export class IsomorphicContext { | ||
tokenizer?: IsomorphicTokenizer; | ||
session?: InferenceSession; | ||
|
||
static async create(options: IGPT2OnnxOptions) { | ||
const tokenizer = options.tokenizer; | ||
await tokenizer.initFromUrl(options.tokenizerUrl); | ||
const gpt2Onnx = new IsomorphicContext(); | ||
|
||
gpt2Onnx.tokenizer = tokenizer; | ||
gpt2Onnx.session = await onnx.InferenceSession.create( | ||
options.modelPath as ArrayBufferLike | ||
); | ||
|
||
return gpt2Onnx; | ||
} | ||
|
||
free() { | ||
this.tokenizer?.free(); | ||
} | ||
|
||
getLogits(onnxTensor: TypedTensor<"float32">) { | ||
let output = tfjs | ||
.tensor<Rank.R3>(onnxTensor.data, onnxTensor.dims as any) | ||
.slice(0, 1); | ||
|
||
return output | ||
.slice( | ||
[0, output.shape[1] - 1, 0], | ||
[output.shape[0], 1, output.shape[2]] | ||
) | ||
.squeeze(); | ||
} | ||
|
||
async inference(inferArgs: IGPT2OnnxInferenceOptions) { | ||
if (!this.tokenizer) { | ||
throw new Error("Tokenizer not initialized"); | ||
} | ||
|
||
if (!this.session) { | ||
throw new Error("Session not initialized"); | ||
} | ||
|
||
const numPredict = inferArgs.numPredict ?? 128; | ||
const topK = inferArgs.topK ?? 1; | ||
|
||
let remain = numPredict; | ||
const tokens = this.tokenizer.tokenize(inferArgs.prompt, true); | ||
|
||
while (remain > 0) { | ||
remain -= 1; | ||
|
||
const inputs = this.tokenizer.toOnnx(tokens); | ||
const result = await this.session.run(inputs); | ||
|
||
const logits = this.getLogits( | ||
result["last_hidden_state"] as TypedTensor<"float32"> | ||
); | ||
|
||
let probs = tfjs.softmax(logits, -1); | ||
|
||
// TODO: implement topP | ||
probs = probs.topk(topK, true).indices.slice(0, 1).squeeze(); | ||
|
||
const token = probs.dataSync(); | ||
|
||
// TODO: implement end of sentence | ||
if ( | ||
token[0] >= 50256 || | ||
token[0] === 0 || | ||
token[0] === 1 || | ||
(inferArgs.endToken && token[0] === inferArgs.endToken) | ||
) { | ||
break; | ||
} | ||
|
||
const tokenText = this.tokenizer.decode( | ||
Uint32Array.from(token), | ||
true | ||
); | ||
|
||
inferArgs.onProgress(tokenText); | ||
tokens.push(BigInt(token[0])); | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.