feat: impl mirostat for llama.cpp
hlhr202 committed May 19, 2023
1 parent 19bae9a commit f13d412
Showing 11 changed files with 232 additions and 45 deletions.
4 changes: 2 additions & 2 deletions example/ts/llama-cpp/abortable.ts
@@ -1,4 +1,4 @@
import type { LlamaInvocation } from "@llama-node/llama-cpp";
import type { Generate } from "@llama-node/llama-cpp";
import { LLM } from "llama-node";
import { LLamaCpp, type LoadConfig } from "llama-node/dist/llm/llama-cpp.js";
import path from "path";
@@ -28,7 +28,7 @@ const prompt = `A chat between a user and an assistant.
USER: ${template}
ASSISTANT:`;

const params: LlamaInvocation = {
const params: Generate = {
nThreads: 4,
nTokPredict: 2048,
topK: 40,
4 changes: 2 additions & 2 deletions example/ts/llama-cpp/inference.ts
@@ -1,4 +1,4 @@
import type { LlamaInvocation } from "@llama-node/llama-cpp";
import type { Generate } from "@llama-node/llama-cpp";
import { LLM } from "llama-node";
import { LLamaCpp, type LoadConfig } from "llama-node/dist/llm/llama-cpp.js";
import path from "path";
@@ -28,7 +28,7 @@ const prompt = `A chat between a user and an assistant.
USER: ${template}
ASSISTANT:`;

const params: LlamaInvocation = {
const params: Generate = {
nThreads: 4,
nTokPredict: 2048,
topK: 40,
4 changes: 2 additions & 2 deletions packages/llama-cpp/example/abortable.ts
@@ -1,4 +1,4 @@
import { LLama, LlamaInvocation } from "../index";
import { LLama, Generate } from "../index";
import path from "path";

const run = async () => {
@@ -15,7 +15,7 @@ const run = async () => {
USER: ${template}
ASSISTANT:`;

const params: LlamaInvocation = {
const params: Generate = {
nThreads: 4,
nTokPredict: 2048,
topK: 40,
4 changes: 2 additions & 2 deletions packages/llama-cpp/example/embedding.ts
@@ -1,4 +1,4 @@
import { LLama, LlamaInvocation } from "../index";
import { LLama, Generate } from "../index";
import path from "path";

const run = async () => {
@@ -21,7 +21,7 @@ const run = async () => {

const prompt = `Who is the president of the United States?`;

const params: LlamaInvocation = {
const params: Generate = {
nThreads: 4,
nTokPredict: 2048,
topK: 40,
4 changes: 2 additions & 2 deletions packages/llama-cpp/example/inference.ts
@@ -1,5 +1,5 @@
import { InferenceResultType } from "../index";
import { LLama, LlamaInvocation } from "../index";
import { LLama, Generate } from "../index";
import path from "path";

const run = async () => {
@@ -26,7 +26,7 @@ const run = async () => {
USER: ${template}
ASSISTANT:`;

const params: LlamaInvocation = {
const params: Generate = {
nThreads: 4,
nTokPredict: 2048,
topK: 40,
90 changes: 86 additions & 4 deletions packages/llama-cpp/index.d.ts
@@ -17,20 +17,102 @@ export interface InferenceResult {
data?: InferenceToken
message?: string
}
export interface LlamaInvocation {
export interface LogitBias {
token: number
bias: number
}
export interface Generate {
nThreads: number
nTokPredict: number
topK: number
/**
* logit bias for specific tokens
* Default: None
*/
logitBias?: Array<LogitBias>
/**
* top k tokens to sample from
* Range: <= 0 to use vocab size
* Default: 40
*/
topK?: number
/**
* top p tokens to sample from
* Default: 0.95
* 1.0 = disabled
*/
topP?: number
/**
* tail free sampling
* Default: 1.0
* 1.0 = disabled
*/
tfsZ?: number
/**
* temperature
* Default: 0.80
* 1.0 = disabled
*/
temp?: number
/**
* locally typical sampling
* Default: 1.0
* 1.0 = disabled
*/
typicalP?: number
/**
* repeat penalty
* Default: 1.10
* 1.0 = disabled
*/
repeatPenalty?: number
/**
* last n tokens to penalize
* Default: 64
* 0 = disable penalty, -1 = context size
*/
repeatLastN?: number
/**
* frequency penalty
* Default: 0.00
* 0.0 = disabled
*/
frequencyPenalty?: number
/**
* presence penalty
* Default: 0.00
* 0.0 = disabled
*/
presencePenalty?: number
/**
* Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
* Mirostat: A Neural Text Decoding Algorithm that Directly Controls Perplexity
* Default: 0
* 0 = disabled
* 1 = mirostat 1.0
* 2 = mirostat 2.0
*/
mirostat?: number
/**
* The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.
* Default: 5.0
*/
mirostatTau?: number
/**
* The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
* Default: 0.1
*/
mirostatEta?: number
/**
* stop sequence
* Default: None
*/
stopSequence?: string
/**
* consider newlines as a repeatable token
* Default: true
*/
penalizeNl?: boolean
/** prompt */
prompt: string
}
export interface ModelLoad {
@@ -54,7 +136,7 @@ export interface LlamaLoraAdaptor {
}
export class LLama {
static load(params: Partial<LoadModel>, enableLogger: boolean): Promise<LLama>
getWordEmbedding(params: LlamaInvocation): Promise<Array<number>>
getWordEmbedding(params: Generate): Promise<Array<number>>
tokenize(params: string): Promise<Array<number>>
inference(params: LlamaInvocation, callback: (result: InferenceResult) => void): () => void
inference(params: Generate, callback: (result: InferenceResult) => void): () => void
}
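
For orientation (not part of this commit's diff), here is a minimal TypeScript sketch of how the new `Generate` fields could be driven through the `LLama` class declared above. It enables Mirostat 2.0 and an additive logit bias. The `LoadModel` fields (`path`, `nCtx`), the `InferenceResultType.Data` variant name, and the `InferenceToken.token` field are assumptions, since those definitions are not shown in this hunk; the example files in this commit show the exact shapes.

import { LLama, Generate, InferenceResultType } from "../index";
import path from "path";

const run = async () => {
    // LoadModel field names (path, nCtx) are assumptions for illustration.
    const llama = await LLama.load(
        { path: path.resolve(process.cwd(), "./ggml-model-q4_0.bin"), nCtx: 2048 },
        true // enableLogger
    );

    const params: Generate = {
        nThreads: 4,
        nTokPredict: 256,
        prompt: "USER: Write one sentence about llamas.\nASSISTANT:",
        temp: 0.7,
        // Enable Mirostat 2.0; on this sampling path topK/topP/tfsZ/typicalP are skipped,
        // while the repeat/frequency/presence penalties still apply.
        mirostat: 2,
        mirostatTau: 5.0,
        mirostatEta: 0.1,
        // Additive bias on a token's logit; id 13 is a placeholder, real ids come from tokenize().
        logitBias: [{ token: 13, bias: -1.0 }],
    };

    const abort = llama.inference(params, (result) => {
        // InferenceResultType.Data and result.data?.token are assumed names.
        if (result.type === InferenceResultType.Data) {
            process.stdout.write(result.data?.token ?? "");
        }
    });

    // Call abort() to stop generation early, as in example/abortable.ts.
    void abort;
};

run();
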
55 changes: 45 additions & 10 deletions packages/llama-cpp/src/context.rs
@@ -6,12 +6,12 @@ use llama_sys::{
llama_get_embeddings, llama_get_logits, llama_init_from_file, llama_n_embd, llama_n_vocab,
llama_print_system_info, llama_sample_frequency_and_presence_penalties,
llama_sample_repetition_penalty, llama_sample_tail_free, llama_sample_temperature,
llama_sample_token, llama_sample_token_greedy, llama_sample_top_k, llama_sample_top_p,
llama_sample_typical, llama_token, llama_token_data, llama_token_data_array, llama_token_nl,
llama_token_to_str,
llama_sample_token, llama_sample_token_greedy, llama_sample_token_mirostat,
llama_sample_token_mirostat_v2, llama_sample_top_k, llama_sample_top_p, llama_sample_typical,
llama_token, llama_token_data, llama_token_data_array, llama_token_nl, llama_token_to_str,
};

use crate::types::{LlamaInvocation, ModelLoad};
use crate::types::{Generate, ModelLoad};

// Represents the LLamaContext which wraps FFI calls to the llama.cpp library.
pub struct LLamaContext {
@@ -73,15 +73,16 @@ impl LLamaContext {
pub fn llama_sample(
&self,
last_n_tokens: &mut [llama_token],
input: &LlamaInvocation,
input: &Generate,
context_params: &llama_context_params,
) -> i32 {
let n_ctx = context_params.n_ctx;
let top_p = input.top_p.unwrap_or(0.95) as f32;
let top_k = if input.top_k <= 0 {
let top_k = input.top_k.unwrap_or(40);
let top_k = if top_k <= 0 {
unsafe { llama_n_vocab(self.ctx) }
} else {
input.top_k
top_k
};
let tfs_z = input.tfs_z.unwrap_or(1.0) as f32;
let temp = input.temp.unwrap_or(0.8) as f32;
@@ -97,11 +98,20 @@ let alpha_presence = input.presence_penalty.unwrap_or(0.0) as f32;
let alpha_presence = input.presence_penalty.unwrap_or(0.0) as f32;
let penalize_nl = input.penalize_nl.unwrap_or(true);

let empty_logit_bias = Vec::new();
let logit_bias = input.logit_bias.as_ref().unwrap_or(&empty_logit_bias);

let mirostat = input.mirostat.unwrap_or(0);
let mirostat_tau = input.mirostat_tau.unwrap_or(5.0) as f32;
let mirostat_eta = input.mirostat_eta.unwrap_or(0.1) as f32;

let n_vocab = unsafe { llama_n_vocab(self.ctx) };
let logits_ptr = unsafe { llama_get_logits(self.ctx) };
let logits = unsafe { slice::from_raw_parts_mut(logits_ptr, n_vocab as usize) };

// TODO: apply logit bias
for i in logit_bias.iter() {
logits[i.token as usize] += i.bias as f32;
}

let mut candidates: Vec<llama_token_data> = Vec::with_capacity(n_vocab as usize);

@@ -165,8 +175,33 @@

if temp <= 0.0_f32 {
id = unsafe { llama_sample_token_greedy(self.ctx, candidates_p) };
} else if mirostat == 1 {
let mut mirostat_mu = 2.0_f32 * mirostat_tau;
let mirostat_m = 100;
unsafe { llama_sample_temperature(self.ctx, candidates_p, temp) };
id = unsafe {
llama_sample_token_mirostat(
self.ctx,
candidates_p,
mirostat_tau,
mirostat_eta,
mirostat_m,
&mut mirostat_mu,
)
}
} else if mirostat == 2 {
let mut mirostat_mu = 2.0_f32 * mirostat_tau;
unsafe { llama_sample_temperature(self.ctx, candidates_p, temp) };
id = unsafe {
llama_sample_token_mirostat_v2(
self.ctx,
candidates_p,
mirostat_tau,
mirostat_eta,
&mut mirostat_mu,
)
}
} else {
// TODO: here we just do temp for first approach, I dont understand microstat very well, will impl later
id = unsafe {
llama_sample_top_k(self.ctx, candidates_p, top_k, 1);
llama_sample_tail_free(self.ctx, candidates_p, tfs_z, 1);
@@ -210,7 +245,7 @@
tokens: &[llama_token],
n_tokens: i32,
n_past: i32,
input: &LlamaInvocation,
input: &Generate,
) -> Result<(), napi::Error> {
let res =
unsafe { llama_eval(self.ctx, tokens.as_ptr(), n_tokens, n_past, input.n_threads) };
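As background on the two mirostat branches above, this is a sketch of the rule from the Mirostat paper (arXiv:2007.14966), not a transcription of llama.cpp's internals. With target surprise tau (mirostatTau) and learning rate eta (mirostatEta), the sampler keeps a running threshold mu, initialized as in the code to 2 * tau, and after emitting a token x_t it nudges the threshold toward the target surprise:

    \mu_0 = 2\tau, \qquad s_t = -\log_2 p(x_t), \qquad \mu_{t+1} = \mu_t - \eta\,(s_t - \tau)

Mirostat 1.0 additionally estimates a Zipf exponent from the top `mirostat_m = 100` candidates to choose a dynamic top-k cutoff before sampling, while 2.0 truncates directly against mu. Note that `mirostat_mu` is declared inside `llama_sample` here, so if that function runs once per generated token, mu is re-seeded at 2 * tau on every step rather than carried across the whole generation.
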
6 changes: 3 additions & 3 deletions packages/llama-cpp/src/lib.rs
@@ -21,7 +21,7 @@ use napi::{
JsFunction,
};
use tokio::sync::Mutex;
use types::{InferenceResult, InferenceResultType, LlamaInvocation, ModelLoad};
use types::{InferenceResult, InferenceResultType, Generate, ModelLoad};

#[napi]
pub struct LLama {
@@ -46,7 +46,7 @@ impl LLama {
}

#[napi]
pub async fn get_word_embedding(&self, params: LlamaInvocation) -> Result<Vec<f64>> {
pub async fn get_word_embedding(&self, params: Generate) -> Result<Vec<f64>> {
let llama = self.llama.lock().await;
llama.embedding(&params).await
}
@@ -61,7 +61,7 @@
pub fn inference(
&self,
env: Env,
params: LlamaInvocation,
params: Generate,
#[napi(ts_arg_type = "(result: InferenceResult) => void")] callback: JsFunction,
) -> Result<JsFunction> {
let tsfn: ThreadsafeFunction<InferenceResult, ErrorStrategy::Fatal> = callback
6 changes: 3 additions & 3 deletions packages/llama-cpp/src/llama.rs
@@ -6,7 +6,7 @@ use tokio::sync::Mutex;
use crate::{
context::LLamaContext,
tokenizer::{llama_token_eos, tokenize},
types::{InferenceResult, InferenceResultType, InferenceToken, LlamaInvocation, ModelLoad},
types::{InferenceResult, InferenceResultType, InferenceToken, Generate, ModelLoad},
};

pub struct LLamaInternal {
@@ -37,7 +37,7 @@ impl LLamaInternal {
Ok(tokenize(context, input, false))
}

pub async fn embedding(&self, input: &LlamaInvocation) -> Result<Vec<f64>, napi::Error> {
pub async fn embedding(&self, input: &Generate) -> Result<Vec<f64>, napi::Error> {
let context = &self.context;
let embd_inp = tokenize(context, input.prompt.as_str(), true);

@@ -60,7 +60,7 @@

pub fn inference(
&self,
input: &LlamaInvocation,
input: &Generate,
running: Arc<Mutex<bool>>,
callback: impl Fn(InferenceResult),
) -> Result<(), napi::Error> {
