Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions LLama/ChatSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -682,11 +682,9 @@ public void Save(string path)
Directory.CreateDirectory(path);

string modelStateFilePath = Path.Combine(path, ChatSession.MODEL_STATE_FILENAME);
var bytes = ContextState?.ToByteArray();
if (bytes is not null)
{
File.WriteAllBytes(modelStateFilePath, bytes);
}
if (ContextState != null)
using (var stateStream = File.Create(modelStateFilePath))
ContextState?.Save(stateStream);

string executorStateFilepath = Path.Combine(path, ChatSession.EXECUTOR_STATE_FILENAME);
File.WriteAllText(executorStateFilepath, JsonSerializer.Serialize(ExecutorState));
Expand Down Expand Up @@ -722,10 +720,11 @@ public static SessionState Load(string path)
throw new ArgumentException("Directory does not exist", nameof(path));
}

string modelStateFilePath = Path.Combine(path, ChatSession.MODEL_STATE_FILENAME);
var contextState = File.Exists(modelStateFilePath) ?
State.FromByteArray(File.ReadAllBytes(modelStateFilePath))
: null;
var modelStateFilePath = Path.Combine(path, ChatSession.MODEL_STATE_FILENAME);
State? contextState = default;
if (File.Exists(modelStateFilePath))
using (var modelStateStream = File.OpenRead(modelStateFilePath))
contextState = State.Load(modelStateStream);

string executorStateFilepath = Path.Combine(path, ChatSession.EXECUTOR_STATE_FILENAME);
var executorState = JsonSerializer.Deserialize<ExecutorBaseState>(File.ReadAllText(executorStateFilepath));
Expand Down
69 changes: 56 additions & 13 deletions LLama/LLamaContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
using LLama.Sampling;
using Microsoft.Extensions.Logging;
using System.Threading;
using System.Security.Cryptography;

namespace LLama
{
Expand Down Expand Up @@ -622,28 +623,70 @@ protected override bool ReleaseHandle()
}

/// <summary>
/// Convert this state to a byte array
/// Write all the bytes of this state to the given stream
/// </summary>
/// <param name="stream"></param>
public async Task SaveAsync(Stream stream)
{
    // 'await' is not permitted inside an unsafe context, so the unmanaged
    // view over the native state buffer is created first and awaited after.
    UnmanagedMemoryStream source;
    unsafe
    {
        // Read-only window over the raw state memory; Size is the byte count.
        source = new UnmanagedMemoryStream((byte*)handle.ToPointer(), checked((long)Size));
    }
    await source.CopyToAsync(stream);
}

/// <summary>
/// Write all the bytes of this state to the given stream
/// </summary>
/// <param name="stream">Destination stream which receives a copy of the raw state bytes</param>
public void Save(Stream stream)
{
    unsafe
    {
        // Wrap the unmanaged state buffer in a read-only stream view and
        // copy it straight out to the destination (no managed intermediate buffer).
        var source = new UnmanagedMemoryStream((byte*)handle.ToPointer(), checked((long)Size));
        source.CopyTo(stream);
    }
}

/// <summary>
/// Load a state from a stream
/// </summary>
/// <param name="stream"></param>
/// <returns></returns>
[Obsolete("It is not generally safe to convert a state into a byte array - it will fail if the state is very large")]
public byte[] ToByteArray()
public static async Task<State> LoadAsync(Stream stream)
{
var bytes = new byte[_size];
Marshal.Copy(handle, bytes, 0, (int)_size);
return bytes;
var memory = Marshal.AllocHGlobal((nint)stream.Length);
var state = new State(memory, checked((ulong)stream.Length));

UnmanagedMemoryStream dest;
unsafe
{
dest = new UnmanagedMemoryStream((byte*)memory.ToPointer(), stream.Length);
}
await stream.CopyToAsync(dest);

return state;
}

/// <summary>
/// Load state from a byte array
/// Load a state from a stream
/// </summary>
/// <param name="bytes"></param>
/// <param name="stream"></param>
/// <returns></returns>
[Obsolete("It is not generally safe to convert a state into a byte array - it will fail if the state is very large")]
public static State FromByteArray(byte[] bytes)
public static State Load(Stream stream)
{
var memory = Marshal.AllocHGlobal(bytes.Length);
Marshal.Copy(bytes, 0, memory, bytes.Length);
return new State(memory, (ulong)bytes.Length);
var memory = Marshal.AllocHGlobal((nint)stream.Length);
var state = new State(memory, checked((ulong)stream.Length));

unsafe
{
var dest = new UnmanagedMemoryStream((byte*)memory.ToPointer(), stream.Length);
stream.CopyTo(dest);
}

return state;
}
}

Expand Down
37 changes: 0 additions & 37 deletions LLama/Native/LLamaTokenType.cs

This file was deleted.

4 changes: 2 additions & 2 deletions LLama/Native/NativeLogConfig.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using System.Runtime.InteropServices;
using System.Runtime.InteropServices;
using System.Text;
using System.Threading;
using Microsoft.Extensions.Logging;
Expand Down Expand Up @@ -52,7 +52,7 @@ public static void llama_log_set(LLamaLogCallback? logCallback)
{
// We can't set the log method yet since that would cause the llama.dll to load.
// Instead configure it to be set when the native library loading is done
NativeLibraryConfig.Instance.WithLogCallback(logCallback);
NativeLibraryConfig.All.WithLogCallback(logCallback);
}
}

Expand Down
16 changes: 1 addition & 15 deletions LLama/Native/SafeLLamaContextHandle.cs
Original file line number Diff line number Diff line change
Expand Up @@ -536,10 +536,7 @@ public unsafe ulong GetState(byte* dest, ulong size)
if (size < required)
throw new ArgumentOutOfRangeException(nameof(size), $"Allocated space is too small, {size} < {required}");

unsafe
{
return llama_state_get_data(this, dest);
}
return llama_state_get_data(this, dest);
}

/// <summary>
Expand Down Expand Up @@ -589,17 +586,6 @@ public void SetSeed(uint seed)
llama_set_rng_seed(this, seed);
}

/// <summary>
/// Set the number of threads used for decoding
/// </summary>
/// <param name="threads">n_threads is the number of threads used for generation (single token)</param>
/// <param name="threadsBatch">n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)</param>
/// <remarks>
/// Thin wrapper that forwards directly to the native <c>llama_set_n_threads</c> call
/// on this context handle. Prefer the replacement properties named in the attribute.
/// </remarks>
[Obsolete("Use `GenerationThreads` and `BatchThreads` properties")]
public void SetThreads(uint threads, uint threadsBatch)
{
    llama_set_n_threads(this, threads, threadsBatch);
}

#region timing
/// <summary>
/// Get performance information
Expand Down