Skip to content

Commit

Permalink
Merge branch 'sessions' into pr/30
Browse files Browse the repository at this point in the history
  • Loading branch information
LLukas22 committed Aug 19, 2023
2 parents 8c46429 + 5c4bf4b commit b314cf1
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 22 deletions.
15 changes: 7 additions & 8 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "llm-rs"
version = "0.2.14"
version = "0.2.15"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
Expand All @@ -14,15 +14,14 @@ rand_chacha = "0.3.1"
log = "0.4.17"
serde = "1.0.163"
serde_json = "1.0"

llm = { git = "https://github.com/rustformers/llm.git", rev = "5d09eed"}
llm-base = { git = "https://github.com/rustformers/llm.git", rev = "5d09eed"}
llm = { git = "https://github.com/rustformers/llm.git", rev = "129b84a" , features = ["falcon"]}
llm-base = { git = "https://github.com/rustformers/llm.git", rev = "129b84a"}

[dependencies.pyo3]
version = "0.19.1"
# "abi3-py37" tells pyo3 (and maturin) to build using the stable ABI with
# Python 3.7 or later.
features = ["abi3-py37","extension-module", "generate-import-lib"]
version = "0.19.2"
# "abi3-py38" tells pyo3 (and maturin) to build using the stable ABI with
# Python 3.8 or later.
features = ["abi3-py38","extension-module", "generate-import-lib"]

[features]
cublas = ["llm/cublas", "llm-base/cublas"]
Expand Down
1 change: 1 addition & 0 deletions examples/haystack_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,5 @@
User: Tell me a Story about a Lama riding the crab named Ferris in about 1000 words.
Assistant:
"""

model.invoke(prompt=prompt,stream=True)
6 changes: 5 additions & 1 deletion llm_rs/config.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ class SessionConfig():
batch_size:int
keys_memory_type:Precision
values_memory_type:Precision
rope_frequency_scale:Optional[float]
rope_frequency_base:Optional[int]


@property
Expand All @@ -62,5 +64,7 @@ class SessionConfig():
values_memory_type:Precision=values_memory_type.FP16,
prefer_mmap:bool=True,
use_gpu:bool=False,
gpu_layers:Optional[int]=None
gpu_layers:Optional[int]=None,
rope_frequency_scale:Optional[float]=None,
rope_frequency_base:Optional[int]=None
) -> None: ...
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ build-backend = "maturin"
[project]
name = "llm-rs"
description = "Unofficial python bindings for llm-rs. 馃悕鉂わ笍馃"
requires-python = ">=3.7"
requires-python = ">=3.8"
classifiers = [ "Programming Language :: Rust", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License",]
keywords = [ "LLM", "Transformers",]
dependencies = [ "blake3", "huggingface-hub >= 0.14.1",]
Expand Down
62 changes: 50 additions & 12 deletions src/configs.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::stopwords::StopWordHandler;
use llm::{InferenceParameters, InferenceSessionConfig, ModelKVMemoryType, TokenBias};
use llm::{InferenceParameters, InferenceSessionConfig, ModelKVMemoryType, RoPEOverrides};
use pyo3::{prelude::*, types::PyBytes};
use serde::{Deserialize, Serialize};

Expand Down Expand Up @@ -116,16 +116,20 @@ impl GenerationConfig {

impl GenerationConfig {
pub fn to_llm_params(&self) -> InferenceParameters {
InferenceParameters {
sampler: std::sync::Arc::new(llm::samplers::TopPTopK {
top_k: self.top_k,
top_p: self.top_p,
temperature: self.temperature,
repeat_penalty: self.repetition_penalty,
repetition_penalty_last_n: self.repetition_penalty_last_n,
bias_tokens: TokenBias::default(),
}),
}
// Yup, this is awful. But it works for now.
let sampler_string = format!("repetition:last_n={last_n}:penalty={penalty}/topk:k={top_k}/topp:p={top_p}/temperature:temperature={temperature}",
last_n = self.repetition_penalty_last_n,
penalty = self.repetition_penalty,
top_k = self.top_k,
top_p = self.top_p,
temperature = self.temperature
);

let sampler_config = &[sampler_string];

let sampler =
llm_base::samplers::build_sampler(0, Default::default(), sampler_config).unwrap();
InferenceParameters { sampler }
}
}

Expand Down Expand Up @@ -194,6 +198,10 @@ pub struct SessionConfig {
pub use_gpu: bool,
#[pyo3(get)]
pub gpu_layers: Option<usize>,
#[pyo3(get, set)]
pub rope_frequency_scale: Option<f32>,
#[pyo3(get, set)]
pub rope_frequency_base: Option<usize>,
}

impl Default for SessionConfig {
Expand All @@ -207,6 +215,8 @@ impl Default for SessionConfig {
prefer_mmap: true,
use_gpu: false,
gpu_layers: None,
rope_frequency_scale: None,
rope_frequency_base: None,
}
}
}
Expand All @@ -224,6 +234,8 @@ impl SessionConfig {
prefer_mmap: Option<bool>,
use_gpu: Option<bool>,
gpu_layers: Option<usize>,
rope_frequency_scale: Option<f32>,
rope_frequency_base: Option<usize>,
) -> Self {
SessionConfig {
threads: threads.unwrap_or(8),
Expand All @@ -234,6 +246,8 @@ impl SessionConfig {
prefer_mmap: prefer_mmap.unwrap_or(true),
use_gpu: use_gpu.unwrap_or(false),
gpu_layers,
rope_frequency_scale,
rope_frequency_base,
}
}

Expand All @@ -247,7 +261,18 @@ impl SessionConfig {
#[allow(clippy::type_complexity)]
pub fn __getnewargs__(
&self,
) -> PyResult<(usize, usize, usize, Precision, Precision, bool, bool, usize)> {
) -> PyResult<(
usize,
usize,
usize,
Precision,
Precision,
bool,
bool,
usize,
f32,
usize,
)> {
Ok((
self.threads,
self.batch_size,
Expand All @@ -257,6 +282,8 @@ impl SessionConfig {
self.prefer_mmap,
self.use_gpu,
self.gpu_layers.unwrap_or(0),
self.rope_frequency_scale.unwrap_or(0.0),
self.rope_frequency_base.unwrap_or(0),
))
}
}
Expand All @@ -270,4 +297,15 @@ impl SessionConfig {
n_threads: self.threads,
}
}

/// Builds the optional RoPE (rotary position embedding) overrides for the
/// underlying `llm` model from this session configuration.
///
/// Returns `Some(RoPEOverrides)` if the user set either `rope_frequency_scale`
/// or `rope_frequency_base`; unset fields fall back to the conventional
/// defaults (scale `1.0`, base `10_000`). Returns `None` when neither is set,
/// so the model keeps its own built-in RoPE parameters.
// NOTE: takes `&self` rather than `self` — the method only reads two `Copy`
// option fields, so there is no reason to consume the whole config.
pub fn get_rope_overrides(&self) -> Option<RoPEOverrides> {
    if self.rope_frequency_scale.is_some() || self.rope_frequency_base.is_some() {
        Some(RoPEOverrides {
            frequency_scale: self.rope_frequency_scale.unwrap_or(1.0),
            frequency_base: self.rope_frequency_base.unwrap_or(10_000),
        })
    } else {
        None
    }
}
}
2 changes: 2 additions & 0 deletions src/model_base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -331,12 +331,14 @@ macro_rules! wrap_model {
let path = std::path::Path::new(&path);
let lora_paths = lora_paths
.map(|strings| strings.into_iter().map(std::path::PathBuf::from).collect());

let model_params = llm_base::ModelParameters {
context_size: config_to_use.context_length,
prefer_mmap: config_to_use.prefer_mmap,
lora_adapters: lora_paths.clone(),
use_gpu: config_to_use.use_gpu,
gpu_layers: config_to_use.gpu_layers,
rope_overrides: config_to_use.get_rope_overrides(),
};

let vocabulary_source: llm_base::TokenizerSource;
Expand Down

0 comments on commit b314cf1

Please sign in to comment.