add tags to progress function for identifying process
amitsingh19975 committed May 2, 2023
1 parent 91a5c11 commit 5f98c57
Showing 9 changed files with 103 additions and 30 deletions.
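In short, every progress callback now takes a leading ProgressTag (Init, Load, Save, Ingest, LoRA attach/detach) so a listener can tell which operation is reporting. A minimal Python sketch of a logger that consumes the new signature follows; it mirrors the example-logging.py change below, the subclass name is illustrative, and the Model construction is left out because the example file is truncated on this page.

from fastllama import Logger, ProgressTag

class TaggedLogger(Logger):
    def progress(self, tag: ProgressTag, done_size: int, total_size: int) -> None:
        # The tag identifies which operation is reporting (Init, Load, Ingest, ...).
        print(f"[Progress:{tag.name}] {done_size}/{total_size}", flush=True)

An instance of such a class would be handed to Model in place of the default Logger, just as examples/python/example-logging.py does in the diff below.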
2 changes: 2 additions & 0 deletions .gitignore
@@ -34,3 +34,5 @@ build/

examples/python/fastllama/api.py
examples/python/fastllama/pyfastllama.so

node_modules/
5 changes: 4 additions & 1 deletion examples/python/example-logging.py
@@ -1,4 +1,4 @@
from fastllama import Model, Logger
from fastllama import Model, Logger, ProgressTag

MODEL_PATH = "./models/ALPACA-LORA-7B/alpaca-lora-q4_0.bin"

@@ -25,6 +25,9 @@ def log_err(self, func_name: str, message: str) -> None:
def log_warn(self, func_name: str, message: str) -> None:
#Modify this to do whatever you want when you see warning logs
print(f"[Warn]: Func('{func_name}') {message}", flush=True, end='', file=self.file)

def progress(self, tag: ProgressTag, done_size: int, total_size: int) -> None:
print(f"[Progress]: {tag} {done_size}/{total_size}", flush=True, file=self.file)


model = Model(
8 changes: 4 additions & 4 deletions include/file_loader.hpp
@@ -528,7 +528,7 @@ namespace fastllama {
std::size_t done_size{};
for(auto& tl : tensors_map.tensors) {
if (call_progress_callback) {
logger->progress(done_size, data_size);
logger->progress(ProgressTag::Init, done_size, data_size);
}

FAST_LLAMA_ASSERT(tl.tensor, "tensor not created");
@@ -543,7 +543,7 @@ }
}

if (call_progress_callback) {
logger->progress(data_size, data_size);
logger->progress(ProgressTag::Init, data_size, data_size);
}
}

@@ -566,7 +566,7 @@
auto worker = [&](parallel::Block block) {

for(auto i = block.start; i < block.end; ++i) {
if (call_progress_callback) logger->progress(done_size.load(std::memory_order_relaxed), data_size);
if (call_progress_callback) logger->progress(ProgressTag::Init, done_size.load(std::memory_order_relaxed), data_size);
auto& tl = tensors_map.tensors[i];
FAST_LLAMA_ASSERT(tl.tensor, "tensor not created");
tl.data = static_cast<std::uint8_t*>(tl.tensor->data);
@@ -582,7 +582,7 @@
parallel::for_(thread_pool, parallel::Range{ 0, tensors_map.tensors.size(), static_cast<std::size_t>(block_size) }, std::move(worker));

if (call_progress_callback) {
logger->progress(data_size, data_size);
logger->progress(ProgressTag::Init, data_size, data_size);
}
}

35 changes: 30 additions & 5 deletions include/logger.hpp
@@ -9,10 +9,33 @@

namespace fastllama {

enum class ProgressTag : std::uint8_t {
Unknown = 0,
Init = 1,
Load = 2,
Save = 3,
Ingest = 4,
AttachLoraAdapter = 5,
DetachLoraAdapter = 6,
};

inline static std::string_view to_string(ProgressTag tag) noexcept {
switch (tag) {
case ProgressTag::Unknown: return "Unknown";
case ProgressTag::Init: return "Init";
case ProgressTag::Load: return "Load";
case ProgressTag::Save: return "Save";
case ProgressTag::Ingest: return "Ingest";
case ProgressTag::AttachLoraAdapter: return "AttachLoraAdapter";
case ProgressTag::DetachLoraAdapter: return "DetachLoraAdapter";
default: return "Unknown";
}
}

struct DefaultLogger {
using LoggerFunction = std::function<void(char const*, int, char const*, int)>;
using LoggerResetFunction = std::function<void()>;
using ProgressCallback = std::function<void(std::size_t, std::size_t)>;
using ProgressCallback = std::function<void(ProgressTag, std::size_t, std::size_t)>;

DefaultLogger() noexcept = default;
DefaultLogger(DefaultLogger const&) = delete;
@@ -41,7 +64,9 @@ namespace fastllama {
fflush(stdout);
}

static void progress_func(std::size_t done, std::size_t total) {
static void progress_func(ProgressTag tag, std::size_t done, std::size_t total) {
if (tag == ProgressTag::Ingest) return;

auto perc = (static_cast<float>(done) / static_cast<float>(total)) * 100.0f;
auto perc_int = static_cast<int>(perc);

@@ -70,7 +95,7 @@ namespace fastllama {
DefaultLogger::log_err = [](char const*, int, char const*, int) {};
DefaultLogger::log_warn = [](char const*, int, char const*, int) {};
DefaultLogger::reset = []() {};
DefaultLogger::progress = [](std::size_t, std::size_t) {};
DefaultLogger::progress = [](ProgressTag, std::size_t, std::size_t) {};
}
};

@@ -130,9 +155,9 @@ namespace fastllama {
m_sink.log_warn(func_name.data(), static_cast<int>(func_name.size()), message.data(), static_cast<int>(message.size()));
}

void progress(std::size_t done, std::size_t total) const {
void progress(ProgressTag tag, std::size_t done, std::size_t total) const {
if (!m_sink.progress) return;
m_sink.progress(std::min(done, total), total);
m_sink.progress(tag, std::min(done, total), total);
}
private:
DefaultLogger m_sink{};
12 changes: 11 additions & 1 deletion interfaces/c/fastllama.h
@@ -9,9 +9,19 @@
extern "C" {
#endif

enum progress_type_tag : uint8_t {
PROGRESS_TAG_UNKNOWN = 0,
PROGRESS_TAG_INIT = 1,
PROGRESS_TAG_LOAD = 2,
PROGRESS_TAG_SAVE = 3,
PROGRESS_TAG_INGEST = 4,
PROGRESS_TAG_ATTACH_LORA_ADAPTER = 5,
PROGRESS_TAG_DETACH_LORA_ADAPTER = 6,
};

typedef void(*LLAMA_LOGGER_FUNC)(char const* function_name, int function_name_size, char const* message, int message_size);
typedef void(*LLAMA_LOGGER_RESET_FUNC)();
typedef void(*LLAMA_LOGGER_PROGRESS_FUNC)(size_t done_size, size_t total_size);
typedef void(*LLAMA_LOGGER_PROGRESS_FUNC)(progress_type_tag, size_t done_size, size_t total_size);
typedef void(*LLAMA_STREAM_FUNC)(char const* token_stream, int token_stream_size);

struct llama_model_context;
6 changes: 3 additions & 3 deletions interfaces/c/main.cpp
@@ -49,8 +49,8 @@ extern "C" {
result.logger.log_err = make_def_err_logger_func();
result.logger.log_warn = make_def_warn_logger_func();
result.logger.reset = make_def_reset_logger_func();
result.logger.progress = +[](std::size_t progress, std::size_t total) {
fastllama::Logger::get_default_logger().progress(progress, total);
result.logger.progress = +[](progress_type_tag tag, std::size_t progress, std::size_t total) {
fastllama::Logger::get_default_logger().progress(static_cast<fastllama::ProgressTag>(tag), progress, total);
};

result.use_mmap = false;
@@ -90,7 +90,7 @@ extern "C" {
def_logger.log_err = arg.logger.log_err;
def_logger.log_warn = arg.logger.log_warn;
def_logger.reset = arg.logger.reset;
def_logger.progress = arg.logger.progress;
def_logger.progress = reinterpret_cast<void(*)(fastllama::ProgressTag, std::size_t, std::size_t)>(arg.logger.progress);

builder.logger = Logger(std::move(def_logger));

40 changes: 35 additions & 5 deletions interfaces/python/fastllama.py
@@ -28,6 +28,34 @@ def progressBar(count_value, total, suffix=''):
sys.stdout.write('[%s] %s%s ...%s\r' %(bar, percentage, '%', suffix))
sys.stdout.flush()

class ProgressTag(Enum):
Unknown = 0
Init = 1
Load = 2
Save = 3
Ingest = 4
AttachLoraAdapter = 5
DetachLoraAdapter = 6

@staticmethod
def from_int(value: int) -> 'ProgressTag':
if value == 0:
return ProgressTag.Unknown
elif value == 1:
return ProgressTag.Init
elif value == 2:
return ProgressTag.Load
elif value == 3:
return ProgressTag.Save
elif value == 4:
return ProgressTag.Ingest
elif value == 5:
return ProgressTag.AttachLoraAdapter
elif value == 6:
return ProgressTag.DetachLoraAdapter
else:
raise Exception(f"Unknown progress tag value: {value}")

class Logger:
"""
Logger class for reporting messages.
@@ -59,13 +87,15 @@ def log_warn(self, func_name: str, message: str) -> None:
"""
print(f"[Warn]: Func('{func_name}') {message}", flush=True, end='')

def progress(self, done_size: int, total_size: int) -> None:
def progress(self, tag: ProgressTag, done_size: int, total_size: int) -> None:
"""
Logs progress messages.
:param done_size(int): size of the completed task
:param total_size(int): total size of the task
"""
if tag == ProgressTag.Ingest:
return
progressBar(done_size, total_size)

def reset(self) -> None:
@@ -76,7 +106,7 @@ def reset(self) -> None:

C_LLAMA_LOGGER_FUNC = ctypes.CFUNCTYPE(None, ctypes.c_char_p, ctypes.c_int, ctypes.c_char_p, ctypes.c_int)
C_LLAMA_LOGGER_RESET_FUNC = ctypes.CFUNCTYPE(None)
C_LLAMA_LOGGER_PROGRESS_FUNC = ctypes.CFUNCTYPE(None, ctypes.c_size_t, ctypes.c_size_t)
C_LLAMA_LOGGER_PROGRESS_FUNC = ctypes.CFUNCTYPE(None, ctypes.c_uint8, ctypes.c_size_t, ctypes.c_size_t)

class c_llama_logger(ctypes.Structure):
"""
@@ -139,15 +169,15 @@ def c_logger_func(func_name: ctypes.c_char_p, func_name_len: ctypes.c_int, messa
func(ctypes.string_at(func_name, int(func_name_len)).decode('utf-8'), ctypes.string_at(message, int(message_len)).decode('utf-8'))
return C_LLAMA_LOGGER_FUNC(c_logger_func)

def make_c_progress_func(func: Callable[[int, int], None]) -> Any:
def make_c_progress_func(func: Callable[[ProgressTag, int, int], None]) -> Any:
"""
Creates a C-compatible progress function from a Python callable.
:param func: Python callable to be converted to a C-compatible progress function.
:return: C-compatible progress function.
"""
def c_progress_func(done_size: ctypes.c_size_t, total_size: ctypes.c_size_t) -> None:
func(int(done_size), int(total_size))
def c_progress_func(tag: ctypes.c_uint8, done_size: ctypes.c_size_t, total_size: ctypes.c_size_t) -> None:
func(ProgressTag.from_int(int(tag)), int(done_size), int(total_size))
return C_LLAMA_LOGGER_PROGRESS_FUNC(c_progress_func)

def make_c_logger_reset_func(func: Callable[[], None]) -> Any:
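As a standalone sketch of the ABI change (it needs no fastllama build, so everything here is illustrative): the tag crosses the ctypes boundary as a plain uint8 and is decoded back into an enum on the Python side, which is what the updated make_c_progress_func above does via ProgressTag.from_int. The enum values mirror progress_type_tag from interfaces/c/fastllama.h.

import ctypes
from enum import Enum

class ProgressTag(Enum):
    Unknown = 0
    Init = 1
    Load = 2
    Save = 3
    Ingest = 4
    AttachLoraAdapter = 5
    DetachLoraAdapter = 6

# New prototype: a leading uint8 tag, then done/total sizes.
C_PROGRESS_FUNC = ctypes.CFUNCTYPE(None, ctypes.c_uint8, ctypes.c_size_t, ctypes.c_size_t)

def on_progress(tag, done, total):
    # ctypes hands the arguments to the Python callable as plain ints; decode the tag.
    print(f"{ProgressTag(int(tag)).name}: {int(done)}/{int(total)}")

c_callback = C_PROGRESS_FUNC(on_progress)
c_callback(ProgressTag.Load.value, 512, 1024)  # prints "Load: 512/1024"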
21 changes: 13 additions & 8 deletions lib/bridge.cpp
@@ -193,29 +193,31 @@ namespace fastllama {
prompt.insert(0, 1, ' ');

auto embd_input = tokenize(m_model.vocabulary, prompt, true);

auto const embd_input_size = embd_input.size();

auto max_input_size = m_model.params.n_ctx - 4;
if (embd_input.size() > static_cast<std::size_t>(max_input_size)) {
m_model.logger.log_err("ingest", "prompt size(='", embd_input.size(), "') exceeds maximum allowed size('", max_input_size, "')");
if (embd_input_size > static_cast<std::size_t>(max_input_size)) {
m_model.logger.log_err("ingest", "prompt size(='", embd_input_size, "') exceeds maximum allowed size('", max_input_size, "')");
return false;
}

if (is_system_prompt) {
if (m_keep < static_cast<int>(embd_input.size())) {
m_model.logger.log_err("ingest", "system prompt size(='", embd_input.size(), "') exceeds 'n_keep'(='", m_keep, "')");
if (m_keep < static_cast<int>(embd_input_size)) {
m_model.logger.log_err("ingest", "system prompt size(='", embd_input_size, "') exceeds 'n_keep'(='", m_keep, "')");
return false;
}
m_system_prompt = embd_input;
}

auto const n_batch = m_model.n_batch;

for(auto i = 0ul; i < embd_input.size(); i += static_cast<std::size_t>(n_batch)) {
auto block = std::min(static_cast<std::size_t>(n_batch), embd_input.size() - i);
get_logger().progress(ProgressTag::Ingest, 0, embd_input_size);

recycle_embed_if_exceeds_context();
for(auto i = 0ul; i < embd_input_size; i += static_cast<std::size_t>(n_batch)) {
auto block = std::min(static_cast<std::size_t>(n_batch), embd_input_size - i);

// std::cout<<"E Size: " << m_embd.size()<<", Past: "<<n_past<<", Mem: "<<m_mem_per_token<<std::endl;
recycle_embed_if_exceeds_context();

if (!m_embd.empty()) {
if (!m_model.eval(static_cast<std::size_t>(n_past), m_embd, m_logits, m_mem_per_token)) {
Expand All @@ -228,8 +230,11 @@ namespace fastllama {

std::copy_n(embd_input.begin() + static_cast<std::ptrdiff_t>(i), block, std::back_inserter(m_embd));
std::copy_n(embd_input.begin() + static_cast<std::ptrdiff_t>(i), block, std::back_inserter(m_last_n_tokens));
get_logger().progress(ProgressTag::Ingest, block, embd_input_size);
}

get_logger().progress(ProgressTag::Ingest, embd_input_size, embd_input_size);

m_last_n_tokens.clear();
return true;
}
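Note that ingest() now reports ProgressTag::Ingest before the batch loop, once per batch, and again on completion, while both stock sinks (DefaultLogger::progress_func in logger.hpp and Logger.progress in fastllama.py) return early on Ingest. A caller who wants to watch prompt ingestion therefore overrides progress without filtering that tag; a hedged sketch follows, with the subclass name being illustrative.

from fastllama import Logger, ProgressTag

class IngestAwareLogger(Logger):
    def progress(self, tag: ProgressTag, done_size: int, total_size: int) -> None:
        if tag == ProgressTag.Ingest:
            # Unlike the default Logger, report ingestion updates instead of dropping them.
            print(f"[Ingest] {done_size}/{total_size}", flush=True)
        else:
            super().progress(tag, done_size, total_size)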
4 changes: 1 addition & 3 deletions lib/llama.cpp
@@ -878,14 +878,12 @@ namespace fastllama {
gf.n_threads = model.threads;
ggml_graph_compute(model_loader.mem_ctx, &gf);

// why is the sun flat?

model_loader.mem_ctx.free();

}

data_loaded += ggml_nelements(current_lora_tensor);
logger.progress(data_loaded, total_size);
logger.progress(is_detach ? ProgressTag::DetachLoraAdapter : ProgressTag::AttachLoraAdapter, data_loaded, total_size);

}

