Skip to content

Commit

Permalink
Support shared in-memory databases (#88)
Browse files Browse the repository at this point in the history
* support shared in-memory databases

* support connection duplication for isolated in-memory database

* fix it's creating on-disk databases when using 'DataSource=:memory:?cache=shared' connection string
  • Loading branch information
yxmm-wxe authored Apr 3, 2023
1 parent 28706a9 commit 9a35b6e
Show file tree
Hide file tree
Showing 6 changed files with 203 additions and 10 deletions.
10 changes: 8 additions & 2 deletions DuckDB.NET.Data/ConnectionString/DuckDBConnectionStringParser.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using DuckDB.NET.Data.Internal;
using System;
using System.Collections.Generic;
using System.Linq;
Expand All @@ -10,7 +11,7 @@ public static DuckDBConnectionString Parse(string connectionString)
{
var properties = connectionString
.Split(new[] {';'}, StringSplitOptions.RemoveEmptyEntries)
.Select(pair => pair.Split(new[] {'='}, StringSplitOptions.RemoveEmptyEntries))
.Select(pair => pair.Split(new[] {'='}, 2, StringSplitOptions.RemoveEmptyEntries))
.ToDictionary(pair => pair[0].Trim(), pair => pair[1].Trim());

var dataSource = GetDataSource(properties);
Expand All @@ -22,7 +23,12 @@ public static DuckDBConnectionString Parse(string connectionString)

if (dataSource.Equals(DuckDBConnectionStringBuilder.InMemoryDataSource, StringComparison.OrdinalIgnoreCase))
{
dataSource = string.Empty;
dataSource = InMemoryDataSource.Default;
}

if (dataSource.Equals(DuckDBConnectionStringBuilder.InMemorySharedDataSource, StringComparison.OrdinalIgnoreCase))
{
dataSource = InMemoryDataSource.CacheShared;
}

return new DuckDBConnectionString(dataSource);
Expand Down
40 changes: 35 additions & 5 deletions DuckDB.NET.Data/DuckDBConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@ namespace DuckDB.NET.Data
public class DuckDBConnection : DbConnection
{
private readonly ConnectionManager connectionManager = ConnectionManager.Default;
private ConnectionReference connectionReference;
private ConnectionState connectionState = ConnectionState.Closed;

internal bool InMemoryDuplication { get; set; }
internal ConnectionReference ConnectionReference { get; set; }

#region Protected Properties

protected override DbProviderFactory DbProviderFactory => DuckDBClientFactory.Instance;
Expand All @@ -35,7 +37,7 @@ public DuckDBConnection(string connectionString)

public override string DataSource { get; }

internal DuckDBNativeConnection NativeConnection => connectionReference.NativeConnection;
internal DuckDBNativeConnection NativeConnection => ConnectionReference.NativeConnection;

public override string ServerVersion => NativeMethods.Startup.DuckDBLibraryVersion().ToManagedString(false);

Expand Down Expand Up @@ -63,9 +65,16 @@ public override void Open()
throw new InvalidOperationException("Connection is already open.");
}

var connectionString = DuckDBConnectionStringParser.Parse(ConnectionString);
if (InMemoryDuplication)
{
ConnectionReference = connectionManager.DuplicateConnectionReference(ConnectionReference);
}
else
{
var connectionString = DuckDBConnectionStringParser.Parse(ConnectionString);

connectionReference = connectionManager.GetConnectionReference(connectionString);
ConnectionReference = connectionManager.GetConnectionReference(connectionString);
}

connectionState = ConnectionState.Open;
}
Expand Down Expand Up @@ -129,7 +138,7 @@ protected override void Dispose(bool disposing)
{
if (connectionState == ConnectionState.Open)
{
connectionManager.ReturnConnectionReference(connectionReference);
connectionManager.ReturnConnectionReference(ConnectionReference);
connectionState = ConnectionState.Closed;
}
}
Expand All @@ -144,5 +153,26 @@ private void EnsureConnectionOpen([CallerMemberName]string operation = "")
throw new InvalidOperationException($"{operation} requires an open connection");
}
}

public DuckDBConnection Duplicate()
{
if (State != ConnectionState.Open)
{
throw new InvalidOperationException("Duplication requires an open connection");
}

if (!InMemoryDataSource.IsInMemoryDataSource(ConnectionReference.FileRefCounter.FileName))
{
throw new NotSupportedException();
}

var duplicatedConnection = new DuckDBConnection(ConnectionString)
{
InMemoryDuplication = true,
ConnectionReference = ConnectionReference,
};

return duplicatedConnection;
}
}
}
3 changes: 3 additions & 0 deletions DuckDB.NET.Data/DuckDBConnectionStringBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ public class DuckDBConnectionStringBuilder : DbConnectionStringBuilder
{
public const string InMemoryDataSource = ":memory:";
public const string InMemoryConnectionString = "DataSource=:memory:";

public const string InMemorySharedDataSource = ":memory:?cache=shared";
public const string InMemorySharedConnectionString = "DataSource=:meory:?cache=shared";

internal static readonly string[] DataSourceKeys = {"Data Source", "DataSource"};
private const string DataSourceKey = "DataSource";
Expand Down
25 changes: 22 additions & 3 deletions DuckDB.NET.Data/Internal/ConnectionManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ internal ConnectionReference GetConnectionReference(DuckDBConnectionString conne
{
var filename = connectionString.DataSource;

var fileRef = string.IsNullOrEmpty(filename) ? new FileRef("") : null;
var fileRef = InMemoryDataSource.IsDefault(filename) ? new FileRef("") : null;

//need to loop until we have a locked fileRef
//that is also in the cache
Expand Down Expand Up @@ -51,7 +51,7 @@ internal ConnectionReference GetConnectionReference(DuckDBConnectionString conne
{
if (fileRef.Database == null)
{
var path = filename == string.Empty ? null : filename;
var path = InMemoryDataSource.IsInMemoryDataSource(filename) ? null : filename;

var resultOpen = NativeMethods.Startup.DuckDBOpen(path, out fileRef.Database, new DuckDBConfig(), out var error);

Expand All @@ -76,7 +76,7 @@ internal ConnectionReference GetConnectionReference(DuckDBConnectionString conne
}
finally
{
if (!string.IsNullOrEmpty(filename))
if (!InMemoryDataSource.IsDefault(filename))
{
Monitor.Exit(fileRef);
}
Expand Down Expand Up @@ -112,5 +112,24 @@ internal void ReturnConnectionReference(ConnectionReference connectionReference)
}
}
}

internal ConnectionReference DuplicateConnectionReference(ConnectionReference connectionReference)
{
var fileRef = connectionReference.FileRefCounter;

lock (fileRef)
{
var resultConnect = NativeMethods.Startup.DuckDBConnect(fileRef.Database, out var duplicatedNativeConnection);
if (resultConnect.IsSuccess())
{
fileRef.Increment();
}
else
{
throw new DuckDBException("DuckDBConnect failed", resultConnect);
}
return new ConnectionReference(fileRef, duplicatedNativeConnection);
}
}
}
}
25 changes: 25 additions & 0 deletions DuckDB.NET.Data/Internal/InMemoryDataSource.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
using System;

namespace DuckDB.NET.Data.Internal
{
internal static class InMemoryDataSource
{
public static readonly string Default = string.Empty;
public static readonly string CacheShared = Guid.NewGuid().ToString();

public static bool IsInMemoryDataSource(string dataSource)
{
return IsDefault(dataSource) || IsCacheShared(dataSource);
}

public static bool IsDefault(string dataSource)
{
return dataSource == Default;
}

public static bool IsCacheShared(string dataSource)
{
return string.Equals(dataSource, CacheShared, StringComparison.OrdinalIgnoreCase);
}
}
}
110 changes: 110 additions & 0 deletions DuckDB.NET.Test/DuckDBConnectionTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -302,5 +302,115 @@ public void MultipleInMemoryConnectionsSeparateDatabases()

tableCount.Should().Be(1);
}

[Fact]
public void MultipleInMemoryConnectionsSharedDatabases()
{
var tableCount = 0;

using var firstConnection = new DuckDBConnection("DataSource=:memory:?cache=shared");
using var secondConnection = new DuckDBConnection("DataSource=:memory:?cache=shared");

firstConnection.Open();

var command = firstConnection.CreateCommand();
command.CommandText = "CREATE TABLE t1 (foo INTEGER, bar INTEGER);";
command.ExecuteNonQuery();

command.CommandText = "show tables;";
using (var dataReader = command.ExecuteReader())
{
while (dataReader.Read())
{
tableCount++;
}
}

tableCount.Should().Be(1);

// connection 2
tableCount = 0;
secondConnection.Open();
command = secondConnection.CreateCommand();

command.CommandText = "CREATE TABLE t2 (foo INTEGER, bar INTEGER);";
command.ExecuteNonQuery();

command.CommandText = "show tables;";
using (var dataReader = command.ExecuteReader())
{
while (dataReader.Read())
{
tableCount++;
}
}

tableCount.Should().Be(2);
}

[Fact]
public void DuplicateNotInMemoryConnectionError()
{
using var db1 = DisposableFile.GenerateInTemp("db", 1);
var cs = db1.ConnectionString;

using var duckDBConnection = new DuckDBConnection(cs);
duckDBConnection.Open();

duckDBConnection.Invoking(connection => connection.Duplicate()).Should().Throw<NotSupportedException>();
}

[Fact]
public void DuplicateInMemoryNotOpenedConnectionError()
{
using var duckDBConnection = new DuckDBConnection("DataSource =:memory:");

duckDBConnection.Invoking(connection => connection.Duplicate()).Should().Throw<InvalidOperationException>();
}

[Fact]
public void DuplicateInMemoryConnection()
{
var tableCount = 0;

using var firstConnection = new DuckDBConnection("DataSource=:memory:");

firstConnection.Open();

var command = firstConnection.CreateCommand();
command.CommandText = "CREATE TABLE t1 (foo INTEGER, bar INTEGER);";
command.ExecuteNonQuery();

command.CommandText = "show tables;";
using (var dataReader = command.ExecuteReader())
{
while (dataReader.Read())
{
tableCount++;
}
}

Assert.Equal(1, tableCount);

using var secondConnection = firstConnection.Duplicate();
// connection 2
tableCount = 0;
secondConnection.Open();
command = secondConnection.CreateCommand();

command.CommandText = "CREATE TABLE t2 (foo INTEGER, bar INTEGER);";
command.ExecuteNonQuery();

command.CommandText = "show tables;";
using (var dataReader = command.ExecuteReader())
{
while (dataReader.Read())
{
tableCount++;
}
}

Assert.Equal(2, tableCount);
}
}
}

0 comments on commit 9a35b6e

Please sign in to comment.