Low level new loading system #64

Merged
38 changes: 37 additions & 1 deletion LLama/Native/LLamaContextParams.cs
@@ -13,65 +13,101 @@ public struct LLamaContextParams
/// RNG seed, -1 for random
/// </summary>
public int seed;

/// <summary>
/// text context
/// </summary>
public int n_ctx;

/// <summary>
/// prompt processing batch size
/// </summary>
public int n_batch;

/// <summary>
/// grouped-query attention (TEMP - will be moved to model hparams)
/// </summary>
public int n_gqa;

/// <summary>
/// rms norm epsilon (TEMP - will be moved to model hparams)
/// </summary>
public float rms_norm_eps;

/// <summary>
/// number of layers to store in VRAM
/// </summary>
public int n_gpu_layers;

/// <summary>
/// the GPU that is used for scratch and small tensors
/// </summary>
public int main_gpu;

/// <summary>
/// how to split layers across multiple GPUs
/// </summary>
public TensorSplits tensor_split;
public float[] tensor_split;

/// <summary>
/// ref: https://github.com/ggerganov/llama.cpp/pull/2054
/// RoPE base frequency
/// </summary>
public float rope_freq_base;

/// <summary>
/// ref: https://github.com/ggerganov/llama.cpp/pull/2054
/// RoPE frequency scaling factor
/// </summary>
public float rope_freq_scale;

/// <summary>
/// called with a progress value between 0 and 1, pass NULL to disable
/// </summary>
public IntPtr progress_callback;

/// <summary>
/// context pointer passed to the progress callback
/// </summary>
public IntPtr progress_callback_user_data;


/// <summary>
/// if true, reduce VRAM usage at the cost of performance
/// </summary>
[MarshalAs(UnmanagedType.I1)]
public bool low_vram;

/// <summary>
/// use fp16 for KV cache
/// </summary>
[MarshalAs(UnmanagedType.I1)]
public bool f16_kv;

/// <summary>
/// the llama_eval() call computes all logits, not just the last one
/// </summary>
[MarshalAs(UnmanagedType.I1)]
public bool logits_all;

/// <summary>
/// only load the vocabulary, no weights
/// </summary>
[MarshalAs(UnmanagedType.I1)]
public bool vocab_only;

/// <summary>
/// use mmap if possible
/// </summary>
[MarshalAs(UnmanagedType.I1)]
public bool use_mmap;

/// <summary>
/// force system to keep model in RAM
/// </summary>
[MarshalAs(UnmanagedType.I1)]
public bool use_mlock;

/// <summary>
/// embedding mode only
/// </summary>
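
The struct above mirrors llama.cpp's llama_context_params field for field, so it can be passed directly to the native loader. As a rough, illustrative sketch (not part of this diff), a caller might populate it like this; all field values here are placeholders:

// Illustrative only: populate the P/Invoke struct before handing it to the native loader.
var ctxParams = new LLamaContextParams
{
    seed = -1,                                 // -1 lets llama.cpp pick a random RNG seed
    n_ctx = 2048,                              // text context size
    n_batch = 512,                             // prompt processing batch size
    n_gpu_layers = 0,                          // keep every layer on the CPU
    main_gpu = 0,                              // GPU used for scratch and small tensors
    rope_freq_base = 10000f,                   // RoPE base frequency
    rope_freq_scale = 1f,                      // RoPE frequency scaling factor
    progress_callback = IntPtr.Zero,           // NULL disables progress reporting
    progress_callback_user_data = IntPtr.Zero,
    f16_kv = true,                             // fp16 KV cache
    use_mmap = true,                           // mmap the weights if possible
    use_mlock = false                          // do not force the model to stay in RAM
};
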
36 changes: 30 additions & 6 deletions LLama/Native/NativeApi.cs
@@ -1,6 +1,4 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Runtime.InteropServices;
using System.Text;
using LLama.Exceptions;
@@ -29,7 +27,7 @@ static NativeApi()
}
private const string libraryName = "libllama";

[DllImport("libllama", EntryPoint = "llama_mmap_supported", CallingConvention = CallingConvention.Cdecl)]
[DllImport(libraryName, EntryPoint = "llama_mmap_supported", CallingConvention = CallingConvention.Cdecl)]
public static extern bool llama_empty_call();

[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
@@ -56,7 +54,10 @@ static NativeApi()
/// <param name="params_"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern IntPtr llama_init_from_file(string path_model, LLamaContextParams params_);
public static extern IntPtr llama_load_model_from_file(string path_model, LLamaContextParams params_);

[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern IntPtr llama_new_context_with_model(SafeLlamaModelHandle model, LLamaContextParams params_);

/// <summary>
/// not great API - very likely to change.
@@ -65,27 +66,35 @@ static NativeApi()
/// </summary>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_backend_init(bool numa);

/// <summary>
/// Frees all allocated memory
/// </summary>
/// <param name="ctx"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_free(IntPtr ctx);

/// <summary>
/// Frees all allocated memory associated with a model
/// </summary>
/// <param name="model"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern void llama_free_model(IntPtr model);

/// <summary>
/// Apply a LoRA adapter to a loaded model
/// path_base_model is the path to a higher quality model to use as a base for
/// the layers modified by the adapter. Can be NULL to use the current loaded model.
/// The model needs to be reloaded before applying a new adapter, otherwise the adapter
/// will be applied on top of the previous one
/// </summary>
/// <param name="ctx"></param>
/// <param name="model_ptr"></param>
/// <param name="path_lora"></param>
/// <param name="path_base_model"></param>
/// <param name="n_threads"></param>
/// <returns>Returns 0 on success</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_apply_lora_from_file(SafeLLamaContextHandle ctx, string path_lora, string path_base_model, int n_threads);
public static extern int llama_model_apply_lora_from_file(SafeLlamaModelHandle model_ptr, string path_lora, string? path_base_model, int n_threads);

/// <summary>
/// Returns the number of tokens in the KV cache
@@ -294,5 +303,20 @@ public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encodi
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern IntPtr llama_print_system_info();

[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_n_vocab_from_model(SafeLlamaModelHandle model);

[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_n_ctx_from_model(SafeLlamaModelHandle model);

[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_n_embd_from_model(SafeLlamaModelHandle model);

[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern byte* llama_token_to_str_with_model(SafeLlamaModelHandle model, int llamaToken);

[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern int llama_tokenize_with_model(SafeLlamaModelHandle model, byte* text, int* tokens, int n_max_tokens, bool add_bos);
}
}
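
Taken together, these bindings replace the old single-step llama_init_from_file with a two-step flow: load the model weights once, then create one or more contexts that share them. A minimal sketch of that flow follows; the SafeLlamaModelHandle.LoadFromFile factory used here is assumed for illustration and is not part of this excerpt:

// Sketch of the new two-step loading flow (illustrative, not part of this diff).
var lparams = new LLamaContextParams { seed = -1, n_ctx = 2048, n_batch = 512, use_mmap = true };

// 1. Load the model weights (wraps llama_load_model_from_file) - hypothetical factory.
using var model = SafeLlamaModelHandle.LoadFromFile("model.bin", lparams);

// 2. Create a context over the loaded model (wraps llama_new_context_with_model).
using var ctx = SafeLLamaContextHandle.Create(model, lparams);

// Optionally apply a LoRA adapter to the loaded model; 0 indicates success.
int loraResult = NativeApi.llama_model_apply_lora_from_file(model, "adapter.bin", null, n_threads: 4);
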
51 changes: 43 additions & 8 deletions LLama/Native/SafeLLamaContextHandle.cs
@@ -1,26 +1,61 @@
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;
using LLama.Exceptions;

namespace LLama.Native
{
public class SafeLLamaContextHandle: SafeLLamaHandleBase
/// <summary>
/// A safe wrapper around a llama_context
/// </summary>
public class SafeLLamaContextHandle
: SafeLLamaHandleBase
{
protected SafeLLamaContextHandle()
{
}
/// <summary>
/// This field guarantees that a reference to the model is held for as long as this handle is held
/// </summary>
private SafeLlamaModelHandle? _model;

public SafeLLamaContextHandle(IntPtr handle)
/// <summary>
/// Create a new SafeLLamaContextHandle
/// </summary>
/// <param name="handle">pointer to an allocated llama_context</param>
/// <param name="model">the model which this context was created from</param>
public SafeLLamaContextHandle(IntPtr handle, SafeLlamaModelHandle model)
: base(handle)
{
// Increment the model reference count while this context exists
_model = model;
var success = false;
_model.DangerousAddRef(ref success);
if (!success)
throw new RuntimeError("Failed to increment model refcount");
}

/// <inheritdoc />
protected override bool ReleaseHandle()
{
// Decrement refcount on model
_model?.DangerousRelease();
_model = null;

NativeApi.llama_free(handle);
SetHandle(IntPtr.Zero);
return true;
}

/// <summary>
/// Create a new llama_context for the given model
/// </summary>
/// <param name="model"></param>
/// <param name="lparams"></param>
/// <returns></returns>
/// <exception cref="RuntimeError"></exception>
public static SafeLLamaContextHandle Create(SafeLlamaModelHandle model, LLamaContextParams lparams)
{
var ctx_ptr = NativeApi.llama_new_context_with_model(model, lparams);
if (ctx_ptr == IntPtr.Zero)
throw new RuntimeError("Failed to create context from model");

return new(ctx_ptr, model);
}
}
}
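
The DangerousAddRef/DangerousRelease pair above is what lets a context outlive the managed model handle it was created from: the native model cannot be freed while any context still references it. An illustrative sketch only (model and lparams come from the loading example above):

// The model handle can be disposed before the context; its native memory is only
// released once the context drops its reference in ReleaseHandle.
using (var ctx = SafeLLamaContextHandle.Create(model, lparams))
{
    model.Dispose();   // deferred: ctx still holds a refcount on the model
    // ... evaluate tokens with ctx ...
}                      // ctx released: llama_free(ctx), then the model refcount is dropped
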
10 changes: 7 additions & 3 deletions LLama/Native/SafeLLamaHandleBase.cs
@@ -1,11 +1,13 @@
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text;

namespace LLama.Native
{
public abstract class SafeLLamaHandleBase: SafeHandle
/// <summary>
/// Base class for all llama handles to native resources
/// </summary>
public abstract class SafeLLamaHandleBase
: SafeHandle
{
private protected SafeLLamaHandleBase()
: base(IntPtr.Zero, ownsHandle: true)
@@ -24,8 +26,10 @@ private protected SafeLLamaHandleBase(IntPtr handle, bool ownsHandle)
SetHandle(handle);
}

/// <inheritdoc />
public override bool IsInvalid => handle == IntPtr.Zero;

/// <inheritdoc />
public override string ToString()
=> $"0x{handle.ToString("x16")}";
}
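
Any new native resource wrapper (model, context, or future additions) follows the same pattern: derive from this base, keep the raw pointer in the inherited handle field, and centralise the native free call in ReleaseHandle. A minimal in-assembly sketch, with a hypothetical native free function:

// Illustrative only: how a derived handle would build on SafeLLamaHandleBase.
public sealed class SafeExampleHandle : SafeLLamaHandleBase
{
    private SafeExampleHandle()
    {
    }

    protected override bool ReleaseHandle()
    {
        // NativeApi.llama_free_example(handle);   // hypothetical native free call
        SetHandle(IntPtr.Zero);
        return true;
    }
}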