Skip to content

Commit

Permalink
Expose sample top p from prob
Browse files Browse the repository at this point in the history
  • Loading branch information
CharlieFRuan committed Mar 14, 2024
1 parent fe2b314 commit 7c1bba8
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions web/src/runtime.ts
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ class RuntimeContext implements Disposable {
makeShapeTuple: PackedFunc;
ndarrayCreateView: PackedFunc;
sampleTopPFromLogits: PackedFunc;
sampleTopPFromProb: PackedFunc;
applyRepetitionPenalty: PackedFunc;
applyPresenceAndFrequencyPenalty: PackedFunc;
applySoftmaxWithTemperature: PackedFunc;
Expand All @@ -188,6 +189,7 @@ class RuntimeContext implements Disposable {
this.makeShapeTuple = getGlobalFunc("runtime.ShapeTuple");
this.ndarrayCreateView = getGlobalFunc("runtime.TVMArrayCreateView");
this.sampleTopPFromLogits = getGlobalFunc("vm.builtin.sample_top_p_from_logits");
this.sampleTopPFromProb = getGlobalFunc("vm.builtin.sample_top_p_from_prob");
this.applyRepetitionPenalty = getGlobalFunc("vm.builtin.apply_repetition_penalty");
this.applyPresenceAndFrequencyPenalty = getGlobalFunc("vm.builtin.apply_presence_and_frequency_penalty");
this.applySoftmaxWithTemperature = getGlobalFunc("vm.builtin.apply_softmax_with_temperature");
Expand Down Expand Up @@ -1826,6 +1828,17 @@ export class Instance implements Disposable {
return this.ctx.sampleTopPFromLogits(logits, temperature, top_p, Math.random());
}

/**
* Sample index via top-p sampling.
*
* @param prob The distribution, i.e. logits after `applySoftmaxWithTemperature()` is performed.
* @param top_p The top_p
* @returns The sampled index.
*/
sampleTopPFromProb(prob: NDArray, top_p: number): number {
return this.ctx.sampleTopPFromProb(prob, top_p, Math.random());
}

/**
* Apply repetition penalty to the logits.
* @param logits The input logits before penalty.
Expand Down

0 comments on commit 7c1bba8

Please sign in to comment.