From 7c1bba8abd942cfe9d77c276541f52f6b94b1b19 Mon Sep 17 00:00:00 2001 From: Charlie Ruan <53290280+CharlieFRuan@users.noreply.github.com> Date: Thu, 14 Mar 2024 03:37:45 -0400 Subject: [PATCH] Expose sample top p from prob --- web/src/runtime.ts | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/web/src/runtime.ts b/web/src/runtime.ts index d9eaafb068a9..ea022d1b3e9d 100644 --- a/web/src/runtime.ts +++ b/web/src/runtime.ts @@ -165,6 +165,7 @@ class RuntimeContext implements Disposable { makeShapeTuple: PackedFunc; ndarrayCreateView: PackedFunc; sampleTopPFromLogits: PackedFunc; + sampleTopPFromProb: PackedFunc; applyRepetitionPenalty: PackedFunc; applyPresenceAndFrequencyPenalty: PackedFunc; applySoftmaxWithTemperature: PackedFunc; @@ -188,6 +189,7 @@ class RuntimeContext implements Disposable { this.makeShapeTuple = getGlobalFunc("runtime.ShapeTuple"); this.ndarrayCreateView = getGlobalFunc("runtime.TVMArrayCreateView"); this.sampleTopPFromLogits = getGlobalFunc("vm.builtin.sample_top_p_from_logits"); + this.sampleTopPFromProb = getGlobalFunc("vm.builtin.sample_top_p_from_prob"); this.applyRepetitionPenalty = getGlobalFunc("vm.builtin.apply_repetition_penalty"); this.applyPresenceAndFrequencyPenalty = getGlobalFunc("vm.builtin.apply_presence_and_frequency_penalty"); this.applySoftmaxWithTemperature = getGlobalFunc("vm.builtin.apply_softmax_with_temperature"); @@ -1826,6 +1828,17 @@ export class Instance implements Disposable { return this.ctx.sampleTopPFromLogits(logits, temperature, top_p, Math.random()); } + /** + * Sample index via top-p sampling. + * + * @param prob The distribution, i.e. logits after `applySoftmaxWithTemperature()` is performed. + * @param top_p The top_p + * @returns The sampled index. + */ + sampleTopPFromProb(prob: NDArray, top_p: number): number { + return this.ctx.sampleTopPFromProb(prob, top_p, Math.random()); + } + /** * Apply repetition penalty to the logits. * @param logits The input logits before penalty.