diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs index 2f667be0b..a55bfedb7 100644 --- a/LLama/ChatSession.cs +++ b/LLama/ChatSession.cs @@ -682,11 +682,9 @@ public void Save(string path) Directory.CreateDirectory(path); string modelStateFilePath = Path.Combine(path, ChatSession.MODEL_STATE_FILENAME); - var bytes = ContextState?.ToByteArray(); - if (bytes is not null) - { - File.WriteAllBytes(modelStateFilePath, bytes); - } + if (ContextState != null) + using (var stateStream = File.Create(modelStateFilePath)) + ContextState?.Save(stateStream); string executorStateFilepath = Path.Combine(path, ChatSession.EXECUTOR_STATE_FILENAME); File.WriteAllText(executorStateFilepath, JsonSerializer.Serialize(ExecutorState)); @@ -722,10 +720,11 @@ public static SessionState Load(string path) throw new ArgumentException("Directory does not exist", nameof(path)); } - string modelStateFilePath = Path.Combine(path, ChatSession.MODEL_STATE_FILENAME); - var contextState = File.Exists(modelStateFilePath) ? - State.FromByteArray(File.ReadAllBytes(modelStateFilePath)) - : null; + var modelStateFilePath = Path.Combine(path, ChatSession.MODEL_STATE_FILENAME); + State? 
contextState = default; + if (File.Exists(modelStateFilePath)) + using (var modelStateStream = File.OpenRead(modelStateFilePath)) + contextState = State.Load(modelStateStream); string executorStateFilepath = Path.Combine(path, ChatSession.EXECUTOR_STATE_FILENAME); var executorState = JsonSerializer.Deserialize(File.ReadAllText(executorStateFilepath)); diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs index 6488e6468..892ec9f29 100644 --- a/LLama/LLamaContext.cs +++ b/LLama/LLamaContext.cs @@ -13,6 +13,7 @@ using LLama.Sampling; using Microsoft.Extensions.Logging; using System.Threading; +using System.Security.Cryptography; namespace LLama { @@ -622,28 +623,70 @@ protected override bool ReleaseHandle() } /// <summary> - /// Convert this state to a byte array + /// Write all the bytes of this state to the given stream /// </summary> + /// <param name="stream"></param> + public async Task SaveAsync(Stream stream) + { + UnmanagedMemoryStream from; + unsafe + { + from = new UnmanagedMemoryStream((byte*)handle.ToPointer(), checked((long)Size)); + } + await from.CopyToAsync(stream); + } + + /// <summary> + /// Write all the bytes of this state to the given stream + /// </summary> + /// <param name="stream"></param> + public void Save(Stream stream) + { + UnmanagedMemoryStream from; + unsafe + { + from = new UnmanagedMemoryStream((byte*)handle.ToPointer(), checked((long)Size)); + } + from.CopyTo(stream); + } + + /// <summary> + /// Load a state from a stream + /// </summary> + /// <param name="stream"></param> /// <returns></returns> - [Obsolete("It is not generally safe to convert a state into a byte array - it will fail if the state is very large")] - public byte[] ToByteArray() + public static async Task<State> LoadAsync(Stream stream) { - var bytes = new byte[_size]; - Marshal.Copy(handle, bytes, 0, (int)_size); - return bytes; + var memory = Marshal.AllocHGlobal((nint)stream.Length); + var state = new State(memory, checked((ulong)stream.Length)); + + UnmanagedMemoryStream dest; + unsafe + { + dest = new UnmanagedMemoryStream((byte*)memory.ToPointer(), stream.Length, stream.Length, FileAccess.Write); + } + await stream.CopyToAsync(dest); + + return state; } /// <summary> - /// 
Load state from a byte array + /// Load a state from a stream /// </summary> - /// <param name="bytes"></param> + /// <param name="stream"></param> /// <returns></returns> - [Obsolete("It is not generally safe to convert a state into a byte array - it will fail if the state is very large")] - public static State FromByteArray(byte[] bytes) + public static State Load(Stream stream) { - var memory = Marshal.AllocHGlobal(bytes.Length); - Marshal.Copy(bytes, 0, memory, bytes.Length); - return new State(memory, (ulong)bytes.Length); + var memory = Marshal.AllocHGlobal((nint)stream.Length); + var state = new State(memory, checked((ulong)stream.Length)); + + unsafe + { + var dest = new UnmanagedMemoryStream((byte*)memory.ToPointer(), stream.Length, stream.Length, FileAccess.Write); + stream.CopyTo(dest); + } + + return state; } } diff --git a/LLama/Native/LLamaTokenType.cs b/LLama/Native/LLamaTokenType.cs deleted file mode 100644 index 3dbaae3d3..000000000 --- a/LLama/Native/LLamaTokenType.cs +++ /dev/null @@ -1,37 +0,0 @@ -using System; - -namespace LLama.Native; - -/// -/// Token Types -/// -/// C# equivalent of llama_token_get_type -[Obsolete("will be removed from llama.cpp once per token attributes are available from GGUF file")] -public enum LLamaTokenType -{ - /// - /// No specific type has been set for this token - /// - LLAMA_TOKEN_TYPE_UNDEFINED = 0, - - /// - /// This is a "normal" token - /// - LLAMA_TOKEN_TYPE_NORMAL = 1, - - /// - /// An "unknown" character/text token e.g. <unk> - /// - LLAMA_TOKEN_TYPE_UNKNOWN = 2, - - /// - /// A special control token e.g. 
</s> - /// - LLAMA_TOKEN_TYPE_CONTROL = 3, - - LLAMA_TOKEN_TYPE_USER_DEFINED = 4, - - LLAMA_TOKEN_TYPE_UNUSED = 5, - - LLAMA_TOKEN_TYPE_BYTE = 6, -} \ No newline at end of file diff --git a/LLama/Native/NativeLogConfig.cs b/LLama/Native/NativeLogConfig.cs index 82b097fb3..939c4d7f2 100644 --- a/LLama/Native/NativeLogConfig.cs +++ b/LLama/Native/NativeLogConfig.cs @@ -1,4 +1,4 @@ -using System.Runtime.InteropServices; +using System.Runtime.InteropServices; using System.Text; using System.Threading; using Microsoft.Extensions.Logging; @@ -52,7 +52,7 @@ public static void llama_log_set(LLamaLogCallback? logCallback) { // We can't set the log method yet since that would cause the llama.dll to load. // Instead configure it to be set when the native library loading is done - NativeLibraryConfig.Instance.WithLogCallback(logCallback); + NativeLibraryConfig.All.WithLogCallback(logCallback); } } diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs index 9d5ca4ffd..4fe9f1ff5 100644 --- a/LLama/Native/SafeLLamaContextHandle.cs +++ b/LLama/Native/SafeLLamaContextHandle.cs @@ -536,10 +536,7 @@ public unsafe ulong GetState(byte* dest, ulong size) if (size < required) throw new ArgumentOutOfRangeException(nameof(size), $"Allocated space is too small, {size} < {required}"); - unsafe - { - return llama_state_get_data(this, dest); - } + return llama_state_get_data(this, dest); } /// @@ -589,17 +586,6 @@ public void SetSeed(uint seed) llama_set_rng_seed(this, seed); } - /// - /// Set the number of threads used for decoding - /// - /// n_threads is the number of threads used for generation (single token) - /// n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens) - [Obsolete("Use `GenerationThreads` and `BatchThreads` properties")] - public void SetThreads(uint threads, uint threadsBatch) - { - llama_set_n_threads(this, threads, threadsBatch); - } - #region timing /// /// Get performance information