diff --git a/src/MongoDB.Bson/IO/ElementAppendingBsonWriter.cs b/src/MongoDB.Bson/IO/ElementAppendingBsonWriter.cs
index a587e873efc..046a5f909ac 100644
--- a/src/MongoDB.Bson/IO/ElementAppendingBsonWriter.cs
+++ b/src/MongoDB.Bson/IO/ElementAppendingBsonWriter.cs
@@ -74,6 +74,30 @@ public override void WriteEndDocument()
base.WriteEndDocument();
}
+ public override void WriteRawBsonDocument(IByteBuffer slice)
+ {
+ WriteStartDocument();
+
+ if (Wrapped is BsonBinaryWriter binaryWriter)
+ {
+ // just copy the bytes (without the length and terminating null)
+ var lengthBytes = new byte[4];
+ slice.GetBytes(0, lengthBytes, 0, 4);
+ var length = BitConverter.ToInt32(lengthBytes, 0);
+ using (var elements = slice.GetSlice(4, length - 5))
+ {
+ var stream = binaryWriter.BsonStream;
+ stream.WriteSlice(elements);
+ }
+ }
+ else
+ {
+ throw new NotSupportedException("WriteRawBsonDocument supports only BsonBinaryWriter.");
+ }
+
+ WriteEndDocument();
+ }
+
///
public override void WriteStartDocument()
{
diff --git a/src/MongoDB.Bson/IO/JsonReader.cs b/src/MongoDB.Bson/IO/JsonReader.cs
index 3c2e9e96132..f33ec492c3c 100644
--- a/src/MongoDB.Bson/IO/JsonReader.cs
+++ b/src/MongoDB.Bson/IO/JsonReader.cs
@@ -1195,7 +1195,7 @@ private BsonValue ParseDateTimeExtendedJson()
}
else if (valueToken.Type == JsonTokenType.BeginObject)
{
- VerifyToken("$numberLong");
+ VerifyString("$numberLong");
VerifyToken(":");
var millisecondsSinceEpochToken = PopToken();
if (millisecondsSinceEpochToken.Type == JsonTokenType.String)
diff --git a/src/MongoDB.Bson/ObjectModel/BsonBinarySubType.cs b/src/MongoDB.Bson/ObjectModel/BsonBinarySubType.cs
index d960b140128..eb617e0c010 100644
--- a/src/MongoDB.Bson/ObjectModel/BsonBinarySubType.cs
+++ b/src/MongoDB.Bson/ObjectModel/BsonBinarySubType.cs
@@ -51,6 +51,10 @@ public enum BsonBinarySubType
///
MD5 = 0x05,
///
+ /// Encrypted binary data.
+ ///
+ Encrypted = 0x06,
+ ///
/// User defined binary data.
///
UserDefined = 0x80
diff --git a/src/MongoDB.Driver.Core/Core/Clusters/Cluster.cs b/src/MongoDB.Driver.Core/Core/Clusters/Cluster.cs
index 2f8a1504729..249c75135ab 100644
--- a/src/MongoDB.Driver.Core/Core/Clusters/Cluster.cs
+++ b/src/MongoDB.Driver.Core/Core/Clusters/Cluster.cs
@@ -26,6 +26,7 @@
using MongoDB.Driver.Core.Events;
using MongoDB.Driver.Core.Misc;
using MongoDB.Driver.Core.Servers;
+using MongoDB.Libmongocrypt;
namespace MongoDB.Driver.Core.Clusters
{
@@ -62,6 +63,7 @@ internal abstract class Cluster : ICluster
// fields
private readonly IClusterClock _clusterClock = new ClusterClock();
private readonly ClusterId _clusterId;
+ private CryptClient _cryptClient = null;
private ClusterDescription _description;
private TaskCompletionSource _descriptionChangedTaskCompletionSource;
private readonly object _descriptionLock = new object();
@@ -109,6 +111,11 @@ public ClusterId ClusterId
get { return _clusterId; }
}
+ public CryptClient CryptClient
+ {
+ get { return _cryptClient; }
+ }
+
public ClusterDescription Description
{
get
@@ -188,7 +195,13 @@ private void ExitServerSelectionWaitQueue()
public virtual void Initialize()
{
ThrowIfDisposed();
- _state.TryChange(State.Initial, State.Open);
+ if (_state.TryChange(State.Initial, State.Open))
+ {
+ if (_settings.KmsProviders != null || _settings.SchemaMap != null)
+ {
+ _cryptClient = CryptClientCreator.CreateCryptClient(_settings.KmsProviders, _settings.SchemaMap);
+ }
+ }
}
private void RapidHeartbeatTimerCallback(object args)
@@ -350,7 +363,7 @@ private async Task WaitForDescriptionChangedAsync(IServerSelector selector, Clus
{
using (var helper = new WaitForDescriptionChangedHelper(this, selector, description, descriptionChangedTask, timeout, cancellationToken))
{
- var completedTask = await Task.WhenAny(helper.Tasks).ConfigureAwait(false);
+ var completedTask = await Task.WhenAny(helper.Tasks).ConfigureAwait(false);
helper.HandleCompletedTask(completedTask);
}
}
@@ -528,7 +541,7 @@ private sealed class WaitForDescriptionChangedHelper : IDisposable
private readonly CancellationTokenSource _timeoutCancellationTokenSource;
private readonly Task _timeoutTask;
- public WaitForDescriptionChangedHelper(Cluster cluster, IServerSelector selector, ClusterDescription description, Task descriptionChangedTask , TimeSpan timeout, CancellationToken cancellationToken)
+ public WaitForDescriptionChangedHelper(Cluster cluster, IServerSelector selector, ClusterDescription description, Task descriptionChangedTask, TimeSpan timeout, CancellationToken cancellationToken)
{
_cluster = cluster;
_description = description;
diff --git a/src/MongoDB.Driver.Core/Core/Clusters/CryptClientCreator.cs b/src/MongoDB.Driver.Core/Core/Clusters/CryptClientCreator.cs
new file mode 100644
index 00000000000..8ece458840a
--- /dev/null
+++ b/src/MongoDB.Driver.Core/Core/Clusters/CryptClientCreator.cs
@@ -0,0 +1,105 @@
+/* Copyright 2019-present MongoDB Inc.
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using MongoDB.Bson;
+using MongoDB.Bson.IO;
+using MongoDB.Driver.Core.Misc;
+using MongoDB.Libmongocrypt;
+
+namespace MongoDB.Driver.Core.Clusters
+{
+ ///
+ /// Represents a creator for CryptClient.
+ ///
+ public sealed class CryptClientCreator
+ {
+ #region static
+#pragma warning disable 3002
+ ///
+ /// Create a CryptClient instance.
+ ///
+ /// The kms providers.
+ /// The schema map.
+ /// The CryptClient instance.
+ public static CryptClient CreateCryptClient(
+ IReadOnlyDictionary> kmsProviders,
+ IReadOnlyDictionary schemaMap)
+ {
+ var helper = new CryptClientCreator(kmsProviders, schemaMap);
+ var cryptOptions = helper.CreateCryptOptions();
+ return helper.CreateCryptClient(cryptOptions);
+ }
+#pragma warning restore
+ #endregion
+
+ private readonly IReadOnlyDictionary> _kmsProviders;
+ private readonly IReadOnlyDictionary _schemaMap;
+
+ private CryptClientCreator(
+ IReadOnlyDictionary> kmsProviders,
+ IReadOnlyDictionary schemaMap)
+ {
+ _kmsProviders = Ensure.IsNotNull(kmsProviders, nameof(kmsProviders));
+ _schemaMap = schemaMap;
+ }
+
+ private CryptClient CreateCryptClient(CryptOptions options)
+ {
+ return CryptClientFactory.Create(options);
+ }
+
+ private CryptOptions CreateCryptOptions()
+ {
+ Dictionary kmsProvidersMap = null;
+ if (_kmsProviders != null && _kmsProviders.Count > 0)
+ {
+ kmsProvidersMap = new Dictionary();
+ if (_kmsProviders.TryGetValue("aws", out var awsProvider))
+ {
+ if (awsProvider.TryGetValue("accessKeyId", out var accessKeyId) &&
+ awsProvider.TryGetValue("secretAccessKey", out var secretAccessKey))
+ {
+ kmsProvidersMap.Add(KmsType.Aws, new AwsKmsCredentials((string)secretAccessKey, (string)accessKeyId));
+ }
+ }
+ if (_kmsProviders.TryGetValue("local", out var localProvider))
+ {
+ if (localProvider.TryGetValue("key", out var keyObject) && keyObject is byte[] key)
+ {
+ kmsProvidersMap.Add(KmsType.Local, new LocalKmsCredentials(key));
+ }
+ }
+ }
+ else
+ {
+ throw new ArgumentException("At least one kms provider must be specified");
+ }
+
+ byte[] schemaBytes = null;
+ if (_schemaMap != null)
+ {
+ var schemaMapElements = _schemaMap.Select(c => new BsonElement(c.Key, c.Value));
+ var schemaDocument = new BsonDocument(schemaMapElements);
+ var writerSettings = new BsonBinaryWriterSettings { GuidRepresentation = GuidRepresentation.Unspecified };
+ schemaBytes = schemaDocument.ToBson(writerSettings: writerSettings);
+ }
+
+ return new CryptOptions(kmsProvidersMap, schemaBytes);
+ }
+ }
+}
diff --git a/src/MongoDB.Driver.Core/Core/Clusters/ICluster.cs b/src/MongoDB.Driver.Core/Core/Clusters/ICluster.cs
index ca47a1936fc..5114db4c398 100644
--- a/src/MongoDB.Driver.Core/Core/Clusters/ICluster.cs
+++ b/src/MongoDB.Driver.Core/Core/Clusters/ICluster.cs
@@ -20,6 +20,7 @@
using MongoDB.Driver.Core.Clusters.ServerSelectors;
using MongoDB.Driver.Core.Configuration;
using MongoDB.Driver.Core.Servers;
+using MongoDB.Libmongocrypt;
namespace MongoDB.Driver.Core.Clusters
{
@@ -66,6 +67,14 @@ public interface ICluster : IDisposable
/// A core server session.
ICoreServerSession AcquireServerSession();
+ ///
+ /// Gets the crypt client.
+ ///
+ /// A crypt client.
+#pragma warning disable CS3003
+ CryptClient CryptClient { get; }
+#pragma warning restore
+
///
/// Initializes the cluster.
///
diff --git a/src/MongoDB.Driver.Core/Core/Configuration/ClusterSettings.cs b/src/MongoDB.Driver.Core/Core/Configuration/ClusterSettings.cs
index 791491c54f0..383edc1c0c6 100644
--- a/src/MongoDB.Driver.Core/Core/Configuration/ClusterSettings.cs
+++ b/src/MongoDB.Driver.Core/Core/Configuration/ClusterSettings.cs
@@ -17,6 +17,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Net;
+using MongoDB.Bson;
using MongoDB.Driver.Core.Clusters;
using MongoDB.Driver.Core.Clusters.ServerSelectors;
using MongoDB.Driver.Core.Misc;
@@ -36,8 +37,10 @@ public class ClusterSettings
// fields
private readonly ClusterConnectionMode _connectionMode;
private readonly IReadOnlyList _endPoints;
+ private readonly IReadOnlyDictionary> _kmsProviders;
private readonly int _maxServerSelectionWaitQueueSize;
private readonly string _replicaSetName;
+ private readonly IReadOnlyDictionary _schemaMap;
private readonly ConnectionStringScheme _scheme;
private readonly TimeSpan _serverSelectionTimeout;
private readonly IServerSelector _preServerSelector;
@@ -49,30 +52,36 @@ public class ClusterSettings
///
/// The connection mode.
/// The end points.
+ /// The kms providers.
/// Maximum size of the server selection wait queue.
/// Name of the replica set.
/// The server selection timeout.
/// The pre server selector.
/// The post server selector.
+ /// The schema map.
/// The connection string scheme.
public ClusterSettings(
Optional connectionMode = default(Optional),
Optional> endPoints = default(Optional>),
+ Optional>> kmsProviders = default(Optional>>),
Optional maxServerSelectionWaitQueueSize = default(Optional),
Optional replicaSetName = default(Optional),
Optional serverSelectionTimeout = default(Optional),
Optional preServerSelector = default(Optional),
Optional postServerSelector = default(Optional),
+ Optional> schemaMap = default(Optional>),
Optional scheme = default(Optional))
{
_connectionMode = connectionMode.WithDefault(ClusterConnectionMode.Automatic);
_endPoints = Ensure.IsNotNull(endPoints.WithDefault(__defaultEndPoints), "endPoints").ToList();
+ _kmsProviders = kmsProviders.WithDefault(null);
_maxServerSelectionWaitQueueSize = Ensure.IsGreaterThanOrEqualToZero(maxServerSelectionWaitQueueSize.WithDefault(500), "maxServerSelectionWaitQueueSize");
_replicaSetName = replicaSetName.WithDefault(null);
_serverSelectionTimeout = Ensure.IsGreaterThanOrEqualToZero(serverSelectionTimeout.WithDefault(TimeSpan.FromSeconds(30)), "serverSelectionTimeout");
_preServerSelector = preServerSelector.WithDefault(null);
_postServerSelector = postServerSelector.WithDefault(null);
_scheme = scheme.WithDefault(ConnectionStringScheme.MongoDB);
+ _schemaMap = schemaMap.WithDefault(null);
}
// properties
@@ -98,6 +107,17 @@ public IReadOnlyList EndPoints
get { return _endPoints; }
}
+ ///
+ /// Gets the kms providers.
+ ///
+ ///
+ /// The kms providers.
+ ///
+ public IReadOnlyDictionary> KmsProviders
+ {
+ get { return _kmsProviders; }
+ }
+
///
/// Gets the maximum size of the server selection wait queue.
///
@@ -120,6 +140,17 @@ public string ReplicaSetName
get { return _replicaSetName; }
}
+ ///
+ /// Gets the schema map.
+ ///
+ ///
+ /// The schema map.
+ ///
+ public IReadOnlyDictionary SchemaMap
+ {
+ get { return _schemaMap; }
+ }
+
///
/// Gets the connection string scheme.
///
@@ -170,31 +201,37 @@ public IServerSelector PostServerSelector
///
/// The connection mode.
/// The end points.
+ /// The kms providers.
/// Maximum size of the server selection wait queue.
/// Name of the replica set.
/// The server selection timeout.
/// The pre server selector.
/// The post server selector.
+ /// The schema map.
/// The connection string scheme.
/// A new ClusterSettings instance.
public ClusterSettings With(
Optional connectionMode = default(Optional),
Optional> endPoints = default(Optional>),
+ Optional>> kmsProviders = default(Optional>>),
Optional maxServerSelectionWaitQueueSize = default(Optional),
Optional replicaSetName = default(Optional),
Optional serverSelectionTimeout = default(Optional),
Optional preServerSelector = default(Optional),
Optional postServerSelector = default(Optional),
+ Optional> schemaMap = default(Optional>),
Optional scheme = default(Optional))
{
return new ClusterSettings(
connectionMode: connectionMode.WithDefault(_connectionMode),
endPoints: Optional.Enumerable(endPoints.WithDefault(_endPoints)),
+ kmsProviders: Optional.Create(kmsProviders.WithDefault(_kmsProviders)),
maxServerSelectionWaitQueueSize: maxServerSelectionWaitQueueSize.WithDefault(_maxServerSelectionWaitQueueSize),
replicaSetName: replicaSetName.WithDefault(_replicaSetName),
serverSelectionTimeout: serverSelectionTimeout.WithDefault(_serverSelectionTimeout),
preServerSelector: Optional.Create(preServerSelector.WithDefault(_preServerSelector)),
postServerSelector: Optional.Create(postServerSelector.WithDefault(_postServerSelector)),
+ schemaMap: Optional.Create(schemaMap.WithDefault(_schemaMap)),
scheme: scheme.WithDefault(_scheme));
}
}
diff --git a/src/MongoDB.Driver.Core/Core/Misc/Feature.cs b/src/MongoDB.Driver.Core/Core/Misc/Feature.cs
index c7450f5a73d..90b038813d6 100644
--- a/src/MongoDB.Driver.Core/Core/Misc/Feature.cs
+++ b/src/MongoDB.Driver.Core/Core/Misc/Feature.cs
@@ -41,6 +41,7 @@ public class Feature
private static readonly Feature __bypassDocumentValidation = new Feature("BypassDocumentValidation", new SemanticVersion(3, 2, 0));
private static readonly Feature __changeStreamStage = new Feature("ChangeStreamStage", new SemanticVersion(3, 5, 11));
private static readonly Feature __changeStreamPostBatchResumeToken = new Feature("ChangeStreamPostBatchResumeToken", new SemanticVersion(4, 0 ,7));
+ private static readonly Feature __clientSideEncryption = new Feature("ClientSideEncryption", new SemanticVersion(4, 1, 9));
private static readonly CollationFeature __collation = new CollationFeature("Collation", new SemanticVersion(3, 3, 11));
private static readonly Feature __commandMessage = new Feature("CommandMessage", new SemanticVersion(3, 6, 0));
private static readonly CommandsThatWriteAcceptWriteConcernFeature __commandsThatWriteAcceptWriteConcern = new CommandsThatWriteAcceptWriteConcernFeature("CommandsThatWriteAcceptWriteConcern", new SemanticVersion(3, 3, 11));
@@ -170,6 +171,11 @@ public class Feature
///
public static Feature ChangeStreamPostBatchResumeToken => __changeStreamPostBatchResumeToken;
+ ///
+ /// Gets the client side encryption feature.
+ ///
+ public static Feature ClientSideEncryption => __clientSideEncryption;
+
///
/// Gets the collation feature.
///
diff --git a/src/MongoDB.Driver.Core/Core/WireProtocol/CommandMessageFieldDecryptor.cs b/src/MongoDB.Driver.Core/Core/WireProtocol/CommandMessageFieldDecryptor.cs
new file mode 100644
index 00000000000..6e1bc9bbea9
--- /dev/null
+++ b/src/MongoDB.Driver.Core/Core/WireProtocol/CommandMessageFieldDecryptor.cs
@@ -0,0 +1,79 @@
+/* Copyright 2019-present MongoDB Inc.
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+using MongoDB.Bson;
+using MongoDB.Bson.Serialization.Serializers;
+using MongoDB.Driver.Core.WireProtocol.Messages;
+using MongoDB.Driver.Core.WireProtocol.Messages.Encoders;
+
+namespace MongoDB.Driver.Core.WireProtocol
+{
+ internal class CommandMessageFieldDecryptor
+ {
+ // private fields
+ private readonly IBinaryDocumentFieldDecryptor _documentFieldDecryptor;
+ private readonly MessageEncoderSettings _messageEncoderSettings;
+
+ // constructors
+ public CommandMessageFieldDecryptor(IBinaryDocumentFieldDecryptor documentFieldDecryptor, MessageEncoderSettings messageEncoderSettings)
+ {
+ _documentFieldDecryptor = documentFieldDecryptor;
+ _messageEncoderSettings = messageEncoderSettings;
+ }
+
+ // public methods
+ public CommandResponseMessage DecryptFields(CommandResponseMessage encryptedResponseMessage, CancellationToken cancellationToken)
+ {
+ var encryptedDocumentBytes = GetEncryptedDocumentBytes(encryptedResponseMessage);
+ var unencryptedDocumentBytes = _documentFieldDecryptor.DecryptFields(encryptedDocumentBytes, cancellationToken);
+ return CreateUnencryptedResponseMessage(encryptedResponseMessage, unencryptedDocumentBytes);
+ }
+
+ public async Task DecryptFieldsAsync(CommandResponseMessage encryptedResponseMessage, CancellationToken cancellationToken)
+ {
+ var encryptedDocumentBytes = GetEncryptedDocumentBytes(encryptedResponseMessage);
+ var unencryptedDocumentBytes = await _documentFieldDecryptor.DecryptFieldsAsync(encryptedDocumentBytes, cancellationToken).ConfigureAwait(false);
+ return CreateUnencryptedResponseMessage(encryptedResponseMessage, unencryptedDocumentBytes);
+ }
+
+ // private methods
+ private CommandResponseMessage CreateUnencryptedResponseMessage(CommandResponseMessage encryptedResponseMessage, byte[] unencryptedDocumentBytes)
+ {
+ var unencryptedDocument = new RawBsonDocument(unencryptedDocumentBytes);
+ var unencryptedSections = new[] { new Type0CommandMessageSection(unencryptedDocument, RawBsonDocumentSerializer.Instance) };
+ var encryptedCommandMessage = encryptedResponseMessage.WrappedMessage;
+ var unencryptedCommandMessage = new CommandMessage(
+ encryptedCommandMessage.RequestId,
+ encryptedCommandMessage.ResponseTo,
+ unencryptedSections,
+ encryptedCommandMessage.MoreToCome);
+ return new CommandResponseMessage(unencryptedCommandMessage);
+ }
+
+ private byte[] GetEncryptedDocumentBytes(CommandResponseMessage encryptedResponseMessage)
+ {
+ var encryptedCommandMessage = encryptedResponseMessage.WrappedMessage;
+ var encryptedSections = encryptedCommandMessage.Sections;
+ var encryptedType0Section = (Type0CommandMessageSection)encryptedSections.Single();
+ var encryptedDocumentSlice = encryptedType0Section.Document.Slice;
+ var encryptedDocumentBytes = new byte[encryptedDocumentSlice.Length];
+ encryptedDocumentSlice.GetBytes(0, encryptedDocumentBytes, 0, encryptedDocumentBytes.Length);
+ return encryptedDocumentBytes;
+ }
+ }
+}
diff --git a/src/MongoDB.Driver.Core/Core/WireProtocol/CommandMessageFieldEncryptor.cs b/src/MongoDB.Driver.Core/Core/WireProtocol/CommandMessageFieldEncryptor.cs
new file mode 100644
index 00000000000..8f083f7362e
--- /dev/null
+++ b/src/MongoDB.Driver.Core/Core/WireProtocol/CommandMessageFieldEncryptor.cs
@@ -0,0 +1,183 @@
+/* Copyright 2019-present MongoDB Inc.
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+using System;
+using System.IO;
+using System.Threading;
+using System.Threading.Tasks;
+using MongoDB.Bson;
+using MongoDB.Bson.IO;
+using MongoDB.Bson.Serialization.Serializers;
+using MongoDB.Driver.Core.WireProtocol.Messages;
+using MongoDB.Driver.Core.WireProtocol.Messages.Encoders;
+using MongoDB.Driver.Core.WireProtocol.Messages.Encoders.BinaryEncoders;
+
+namespace MongoDB.Driver.Core.WireProtocol
+{
+ internal class CommandMessageFieldEncryptor
+ {
+ // private fields
+ private readonly byte[] _buffer = new byte[1024];
+ private readonly IBinaryCommandFieldEncryptor _commandFieldEncryptor;
+ private readonly MessageEncoderSettings _messageEncoderSettings;
+
+ // constructors
+ public CommandMessageFieldEncryptor(IBinaryCommandFieldEncryptor commandFieldEncryptor, MessageEncoderSettings messageEncoderSettings)
+ {
+ _commandFieldEncryptor = commandFieldEncryptor;
+ _messageEncoderSettings = messageEncoderSettings;
+ }
+
+ // public static methods
+ public CommandRequestMessage EncryptFields(string databaseName, CommandRequestMessage unencryptedRequestMessage, CancellationToken cancellationToken)
+ {
+ var unencryptedCommandBytes = GetUnencryptedCommandBytes(unencryptedRequestMessage);
+ var encryptedCommandBytes = _commandFieldEncryptor.EncryptFields(databaseName, unencryptedCommandBytes, cancellationToken);
+ return CreateEncryptedRequestMessage(unencryptedRequestMessage, encryptedCommandBytes);
+ }
+
+ public async Task EncryptFieldsAsync(string databaseName, CommandRequestMessage unencryptedRequestMessage, CancellationToken cancellationToken)
+ {
+ var unencryptedCommandBytes = GetUnencryptedCommandBytes(unencryptedRequestMessage);
+ var encryptedCommandBytes = await _commandFieldEncryptor.EncryptFieldsAsync(databaseName, unencryptedCommandBytes, cancellationToken).ConfigureAwait(false);
+ return CreateEncryptedRequestMessage(unencryptedRequestMessage, encryptedCommandBytes);
+ }
+
+ // private static methods
+ private byte[] CombineCommandMessageSectionsIntoSingleDocument(Stream stream)
+ {
+ using (var inputStream = new BsonStreamAdapter(stream, ownsStream: false))
+ using (var memoryStream = new MemoryStream())
+ using (var outputStream = new BsonStreamAdapter(memoryStream, ownsStream: false))
+ {
+ var messageStartPosition = inputStream.Position;
+ var messageLength = inputStream.ReadInt32();
+ var messageEndPosition = messageStartPosition + messageLength;
+ var requestId = inputStream.ReadInt32();
+ var responseTo = inputStream.ReadInt32();
+ var opcode = inputStream.ReadInt32();
+ var flags = (OpMsgFlags)inputStream.ReadInt32();
+ if (flags.HasFlag(OpMsgFlags.ChecksumPresent))
+ {
+ messageEndPosition -= 4; // ignore checksum
+ }
+
+ CopyType0Section(inputStream, outputStream);
+ outputStream.Position -= 1;
+ while (inputStream.Position < messageEndPosition)
+ {
+ CopyType1Section(inputStream, outputStream);
+ }
+ outputStream.WriteByte(0);
+ outputStream.BackpatchSize(0);
+
+ return memoryStream.ToArray();
+ }
+ }
+
+ private void CopyBsonDocument(BsonStream inputStream, BsonStream outputStream)
+ {
+ var documentLength = inputStream.ReadInt32();
+ inputStream.Position -= 4;
+ CopyBytes(inputStream, outputStream, documentLength);
+ }
+
+ private void CopyBytes(BsonStream inputStream, BsonStream outputStream, int count)
+ {
+ while (count > 0)
+ {
+ var chunkSize = Math.Min(count, _buffer.Length);
+ inputStream.ReadBytes(_buffer, 0, chunkSize);
+ outputStream.WriteBytes(_buffer, 0, chunkSize);
+ count -= chunkSize;
+ }
+ }
+
+ private void CopyType0Section(BsonStream inputStream, BsonStream outputStream)
+ {
+ var payloadType = (PayloadType)inputStream.ReadByte();
+ if (payloadType != PayloadType.Type0)
+ {
+ throw new FormatException("Expected first section to be of type 0.");
+ }
+
+ CopyBsonDocument(inputStream, outputStream);
+ }
+
+ private void CopyType1Section(BsonStream inputStream, BsonStream outputStream)
+ {
+ var payloadType = (PayloadType)inputStream.ReadByte();
+ if (payloadType != PayloadType.Type1)
+ {
+ throw new FormatException("Expected subsequent sections to be of type 1.");
+ }
+
+ var sectionStartPosition = inputStream.Position;
+ var sectionSize = inputStream.ReadInt32();
+ var sectionEndPosition = sectionStartPosition + sectionSize;
+ var identifier = inputStream.ReadCString(Utf8Encodings.Lenient);
+
+ outputStream.WriteByte((byte)BsonType.Array);
+ outputStream.WriteCString(identifier);
+ var arrayStartPosition = outputStream.Position;
+ outputStream.WriteInt32(0); // array length will be backpatched
+ var index = 0;
+ while (inputStream.Position < sectionEndPosition)
+ {
+ outputStream.WriteByte((byte)BsonType.Document);
+ outputStream.WriteCString(index.ToString());
+ CopyBsonDocument(inputStream, outputStream);
+ index++;
+ }
+ outputStream.WriteByte(0);
+ outputStream.BackpatchSize(arrayStartPosition);
+ }
+
+ private CommandRequestMessage CreateEncryptedRequestMessage(CommandRequestMessage unencryptedRequestMessage, byte[] encryptedDocumentBytes)
+ {
+ var encryptedDocument = new RawBsonDocument(encryptedDocumentBytes);
+ var encryptedSections = new[] { new Type0CommandMessageSection(encryptedDocument, RawBsonDocumentSerializer.Instance) };
+ var unencryptedCommandMessage = unencryptedRequestMessage.WrappedMessage;
+ var encryptedCommandMessage = new CommandMessage(
+ unencryptedCommandMessage.RequestId,
+ unencryptedCommandMessage.ResponseTo,
+ encryptedSections,
+ unencryptedCommandMessage.MoreToCome);
+ return new CommandRequestMessage(encryptedCommandMessage, unencryptedRequestMessage.ShouldBeSent);
+ }
+
+ private byte[] GetUnencryptedCommandBytes(CommandRequestMessage unencryptedRequestMessage)
+ {
+ using (var stream = new MemoryStream())
+ {
+ WriteUnencryptedRequestMessageToStream(stream, unencryptedRequestMessage);
+ stream.Position = 0;
+ return CombineCommandMessageSectionsIntoSingleDocument(stream);
+ }
+ }
+
+ private void WriteUnencryptedRequestMessageToStream(
+ Stream stream,
+ CommandRequestMessage unencryptedRequestMessage)
+ {
+ var clonedMessageEncoderSettings = _messageEncoderSettings.Clone();
+ clonedMessageEncoderSettings.Set(MessageEncoderSettingsName.MaxDocumentSize, 2097152);
+ clonedMessageEncoderSettings.Set(MessageEncoderSettingsName.MaxMessageSize, 2097152 + 16384);
+ var encoderFactory = new BinaryMessageEncoderFactory(stream, clonedMessageEncoderSettings, compressorSource: null);
+ var encoder = encoderFactory.GetCommandRequestMessageEncoder();
+ encoder.WriteMessage(unencryptedRequestMessage);
+ }
+ }
+}
diff --git a/src/MongoDB.Driver.Core/Core/WireProtocol/CommandUsingCommandMessageWireProtocol.cs b/src/MongoDB.Driver.Core/Core/WireProtocol/CommandUsingCommandMessageWireProtocol.cs
index 6a1543e621c..bd5ea7936cf 100644
--- a/src/MongoDB.Driver.Core/Core/WireProtocol/CommandUsingCommandMessageWireProtocol.cs
+++ b/src/MongoDB.Driver.Core/Core/WireProtocol/CommandUsingCommandMessageWireProtocol.cs
@@ -41,6 +41,8 @@ internal class CommandUsingCommandMessageWireProtocol : IWirePro
private readonly List _commandPayloads;
private readonly IElementNameValidator _commandValidator; // TODO: how can this be supported when using CommandMessage?
private readonly DatabaseNamespace _databaseNamespace;
+ private readonly IBinaryDocumentFieldDecryptor _documentFieldDecryptor;
+ private readonly IBinaryCommandFieldEncryptor _documentFieldEncryptor;
private readonly MessageEncoderSettings _messageEncoderSettings;
private readonly Action _postWriteAction;
private readonly ReadPreference _readPreference;
@@ -78,6 +80,12 @@ internal class CommandUsingCommandMessageWireProtocol : IWirePro
_resultSerializer = Ensure.IsNotNull(resultSerializer, nameof(resultSerializer));
_messageEncoderSettings = messageEncoderSettings;
_postWriteAction = postWriteAction; // can be null
+
+ if (messageEncoderSettings != null)
+ {
+ _documentFieldDecryptor = messageEncoderSettings.GetOrDefault(MessageEncoderSettingsName.BinaryDocumentFieldDecryptor, null);
+ _documentFieldEncryptor = messageEncoderSettings.GetOrDefault(MessageEncoderSettingsName.BinaryDocumentFieldEncryptor, null);
+ }
}
// public methods
@@ -86,6 +94,7 @@ public TCommandResult Execute(IConnection connection, CancellationToken cancella
try
{
var message = CreateCommandMessage(connection.Description);
+ message = AutoEncryptFieldsIfNecessary(message, connection, cancellationToken);
try
{
@@ -100,6 +109,7 @@ public TCommandResult Execute(IConnection connection, CancellationToken cancella
{
var encoderSelector = new CommandResponseMessageEncoderSelector();
var response = (CommandResponseMessage)connection.ReceiveMessage(message.RequestId, encoderSelector, _messageEncoderSettings, cancellationToken);
+ response = AutoDecryptFieldsIfNecessary(response, cancellationToken);
return ProcessResponse(connection.ConnectionId, response.WrappedMessage);
}
else
@@ -123,6 +133,7 @@ public async Task ExecuteAsync(IConnection connection, Cancellat
try
{
var message = CreateCommandMessage(connection.Description);
+ message = await AutoEncryptFieldsIfNecessaryAsync(message, connection, cancellationToken).ConfigureAwait(false);
try
{
@@ -137,6 +148,7 @@ public async Task ExecuteAsync(IConnection connection, Cancellat
{
var encoderSelector = new CommandResponseMessageEncoderSelector();
var response = (CommandResponseMessage)await connection.ReceiveMessageAsync(message.RequestId, encoderSelector, _messageEncoderSettings, cancellationToken).ConfigureAwait(false);
+ response = await AutoDecryptFieldsIfNecessaryAsync(response, cancellationToken).ConfigureAwait(false);
return ProcessResponse(connection.ConnectionId, response.WrappedMessage);
}
else
@@ -156,6 +168,68 @@ public async Task ExecuteAsync(IConnection connection, Cancellat
}
// private methods
+ private CommandResponseMessage AutoDecryptFieldsIfNecessary(CommandResponseMessage encryptedResponseMessage, CancellationToken cancellationToken)
+ {
+ if (_documentFieldDecryptor == null)
+ {
+ return encryptedResponseMessage;
+ }
+ else
+ {
+ var messageFieldDecryptor = new CommandMessageFieldDecryptor(_documentFieldDecryptor, _messageEncoderSettings);
+ return messageFieldDecryptor.DecryptFields(encryptedResponseMessage, cancellationToken);
+ }
+ }
+
+ private async Task AutoDecryptFieldsIfNecessaryAsync(CommandResponseMessage encryptedResponseMessage, CancellationToken cancellationToken)
+ {
+ if (_documentFieldDecryptor == null)
+ {
+ return encryptedResponseMessage;
+ }
+ else
+ {
+ var messageFieldDecryptor = new CommandMessageFieldDecryptor(_documentFieldDecryptor, _messageEncoderSettings);
+ return await messageFieldDecryptor.DecryptFieldsAsync(encryptedResponseMessage, cancellationToken).ConfigureAwait(false);
+ }
+ }
+
+ private CommandRequestMessage AutoEncryptFieldsIfNecessary(CommandRequestMessage unencryptedRequestMessage, IConnection connection, CancellationToken cancellationToken)
+ {
+ if (_documentFieldEncryptor == null)
+ {
+ return unencryptedRequestMessage;
+ }
+ else
+ {
+ if (connection.Description.IsMasterResult.MaxWireVersion < 8)
+ {
+ throw new NotSupportedException("Auto-encryption requires a minimum MongoDB version of 4.2.");
+ }
+
+ var helper = new CommandMessageFieldEncryptor(_documentFieldEncryptor, _messageEncoderSettings);
+ return helper.EncryptFields(_databaseNamespace.DatabaseName, unencryptedRequestMessage, cancellationToken);
+ }
+ }
+
+ private async Task AutoEncryptFieldsIfNecessaryAsync(CommandRequestMessage unencryptedRequestMessage, IConnection connection, CancellationToken cancellationToken)
+ {
+ if (_documentFieldEncryptor == null)
+ {
+ return unencryptedRequestMessage;
+ }
+ else
+ {
+ if (connection.Description.IsMasterResult.MaxWireVersion < 8)
+ {
+ throw new NotSupportedException("Auto-encryption requires a minimum MongoDB version of 4.2.");
+ }
+
+ var helper = new CommandMessageFieldEncryptor(_documentFieldEncryptor, _messageEncoderSettings);
+ return await helper.EncryptFieldsAsync(_databaseNamespace.DatabaseName, unencryptedRequestMessage, cancellationToken).ConfigureAwait(false);
+ }
+ }
+
private CommandRequestMessage CreateCommandMessage(ConnectionDescription connectionDescription)
{
var requestId = RequestMessage.GetNextRequestId();
@@ -188,28 +262,24 @@ private Type0CommandMessageSection CreateType0Section(ConnectionDe
{
var extraElements = new List();
- var dbElement = new BsonElement("$db", _databaseNamespace.DatabaseName);
- extraElements.Add(dbElement);
+ addIfNotAlreadyAdded("$db", _databaseNamespace.DatabaseName);
if (connectionDescription.IsMasterResult.ServerType != ServerType.Standalone
&& _readPreference != null
&& _readPreference != ReadPreference.Primary)
{
var readPreferenceDocument = QueryHelper.CreateReadPreferenceDocument(_readPreference);
- var readPreferenceElement = new BsonElement("$readPreference", readPreferenceDocument);
- extraElements.Add(readPreferenceElement);
+ addIfNotAlreadyAdded("$readPreference", readPreferenceDocument);
}
if (_session.Id != null)
{
- var lsidElement = new BsonElement("lsid", _session.Id);
- extraElements.Add(lsidElement);
+ addIfNotAlreadyAdded("lsid", _session.Id);
}
if (_session.ClusterTime != null)
{
- var clusterTimeElement = new BsonElement("$clusterTime", _session.ClusterTime);
- extraElements.Add(clusterTimeElement);
+ addIfNotAlreadyAdded("$clusterTime", _session.ClusterTime);
}
Action writerSettingsConfigurator = s => s.GuidRepresentation = GuidRepresentation.Unspecified;
@@ -217,21 +287,29 @@ private Type0CommandMessageSection CreateType0Section(ConnectionDe
if (_session.IsInTransaction)
{
var transaction = _session.CurrentTransaction;
- extraElements.Add(new BsonElement("txnNumber", transaction.TransactionNumber));
+ addIfNotAlreadyAdded("txnNumber", transaction.TransactionNumber);
if (transaction.State == CoreTransactionState.Starting)
{
- extraElements.Add(new BsonElement("startTransaction", true));
+ addIfNotAlreadyAdded("startTransaction", true);
var readConcern = ReadConcernHelper.GetReadConcernForFirstCommandInTransaction(_session, connectionDescription);
if (readConcern != null)
{
- extraElements.Add(new BsonElement("readConcern", readConcern));
+ addIfNotAlreadyAdded("readConcern", readConcern);
}
}
- extraElements.Add(new BsonElement("autocommit", false));
+ addIfNotAlreadyAdded("autocommit", false);
}
var elementAppendingSerializer = new ElementAppendingSerializer(BsonDocumentSerializer.Instance, extraElements, writerSettingsConfigurator);
return new Type0CommandMessageSection(_command, elementAppendingSerializer);
+
+ void addIfNotAlreadyAdded(string key, BsonValue value)
+ {
+ if (!_command.Contains(key))
+ {
+ extraElements.Add(new BsonElement(key, value));
+ }
+ }
}
private bool IsRetryableWriteExceptionAndDeploymentDoesNotSupportRetryableWrites(MongoCommandException exception)
diff --git a/src/MongoDB.Driver.Core/Core/WireProtocol/IBinaryCommandFieldEncryptor.cs b/src/MongoDB.Driver.Core/Core/WireProtocol/IBinaryCommandFieldEncryptor.cs
new file mode 100644
index 00000000000..66eec78143e
--- /dev/null
+++ b/src/MongoDB.Driver.Core/Core/WireProtocol/IBinaryCommandFieldEncryptor.cs
@@ -0,0 +1,44 @@
+/* Copyright 2019-present MongoDB Inc.
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace MongoDB.Driver.Core.WireProtocol
+{
+ ///
+ /// Interface for decrypting fields in a binary document.
+ ///
+ public interface IBinaryCommandFieldEncryptor
+ {
+ ///
+ /// Encrypts the fields.
+ ///
+ /// The database name.
+ /// The unencrypted command bytes.
+ /// The cancellation token.
+ /// An encrypted document.
+ byte[] EncryptFields(string databaseName, byte[] unencryptedCommandBytes, CancellationToken cancellationToken);
+
+ ///
+ /// Encrypts the fields asynchronously.
+ ///
+ /// The database name.
+ /// The unencrypted command bytes.
+ /// The cancellation token.
+ /// An encrypted document.
+ Task EncryptFieldsAsync(string databaseName, byte[] unencryptedCommandBytes, CancellationToken cancellationToken);
+ }
+}
diff --git a/src/MongoDB.Driver.Core/Core/WireProtocol/IBinaryDocumentFieldDecryptor.cs b/src/MongoDB.Driver.Core/Core/WireProtocol/IBinaryDocumentFieldDecryptor.cs
new file mode 100644
index 00000000000..4b9c99e97ab
--- /dev/null
+++ b/src/MongoDB.Driver.Core/Core/WireProtocol/IBinaryDocumentFieldDecryptor.cs
@@ -0,0 +1,42 @@
+/* Copyright 2019-present MongoDB Inc.
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+using System.Threading;
+using System.Threading.Tasks;
+
+namespace MongoDB.Driver.Core.WireProtocol
+{
+ ///
+ /// Interface for decrypting fields in a binary document.
+ ///
+ public interface IBinaryDocumentFieldDecryptor
+ {
+ ///
+ /// Decrypts the fields.
+ ///
+ /// The encrypted document bytes.
+ /// The cancellation token.
+ /// An unencrypted document.
+ byte[] DecryptFields(byte[] encryptedDocumentBytes, CancellationToken cancellationToken);
+
+ ///
+ /// Decrypts the fields asynchronously.
+ ///
+ /// The encrypted document bytes.
+ /// The cancellation token.
+ /// An unencrypted document.
+ Task DecryptFieldsAsync(byte[] encryptedDocumentBytes, CancellationToken cancellationToken);
+ }
+}
diff --git a/src/MongoDB.Driver.Core/Core/WireProtocol/Messages/Encoders/BinaryEncoders/CommandMessageBinaryEncoder.cs b/src/MongoDB.Driver.Core/Core/WireProtocol/Messages/Encoders/BinaryEncoders/CommandMessageBinaryEncoder.cs
index 97ef64359ce..6d37bb3b113 100644
--- a/src/MongoDB.Driver.Core/Core/WireProtocol/Messages/Encoders/BinaryEncoders/CommandMessageBinaryEncoder.cs
+++ b/src/MongoDB.Driver.Core/Core/WireProtocol/Messages/Encoders/BinaryEncoders/CommandMessageBinaryEncoder.cs
@@ -296,7 +296,10 @@ private void WriteType1Section(BsonBinaryWriter writer, Type1CommandMessageSecti
stream.WriteCString(section.Identifier);
var batch = section.Documents;
- var maxDocumentSize = section.MaxDocumentSize ?? writer.Settings.MaxDocumentSize;
+ var maxDocumentSize =
+ IsEncryptionConfigured && MaxDocumentSize.HasValue
+ ? MaxDocumentSize.Value
+ : section.MaxDocumentSize ?? writer.Settings.MaxDocumentSize;
writer.PushSettings(s => ((BsonBinaryWriterSettings)s).MaxDocumentSize = maxDocumentSize);
writer.PushElementNameValidator(section.ElementNameValidator);
try
diff --git a/src/MongoDB.Driver.Core/Core/WireProtocol/Messages/Encoders/BinaryEncoders/MessageBinaryEncoderBase.cs b/src/MongoDB.Driver.Core/Core/WireProtocol/Messages/Encoders/BinaryEncoders/MessageBinaryEncoderBase.cs
index d4ac861a779..6735d5954ec 100644
--- a/src/MongoDB.Driver.Core/Core/WireProtocol/Messages/Encoders/BinaryEncoders/MessageBinaryEncoderBase.cs
+++ b/src/MongoDB.Driver.Core/Core/WireProtocol/Messages/Encoders/BinaryEncoders/MessageBinaryEncoderBase.cs
@@ -68,6 +68,20 @@ protected UTF8Encoding Encoding
}
}
+ ///
+ /// Gets a flag whether encryption has been configured.
+ ///
+ ///
+ /// The flag whether encryption is configured or not.
+ ///
+ protected bool IsEncryptionConfigured
+ {
+ get
+ {
+ return _encoderSettings?.GetOrDefault(MessageEncoderSettingsName.BinaryDocumentFieldEncryptor, null) != null;
+ }
+ }
+
///
/// Gets the maximum size of the document.
///
diff --git a/src/MongoDB.Driver.Core/Core/WireProtocol/Messages/Encoders/MessageEncoderSettings.cs b/src/MongoDB.Driver.Core/Core/WireProtocol/Messages/Encoders/MessageEncoderSettings.cs
index 2c7122aba05..a669475cc9d 100644
--- a/src/MongoDB.Driver.Core/Core/WireProtocol/Messages/Encoders/MessageEncoderSettings.cs
+++ b/src/MongoDB.Driver.Core/Core/WireProtocol/Messages/Encoders/MessageEncoderSettings.cs
@@ -13,12 +13,8 @@
* limitations under the License.
*/
-using System;
using System.Collections;
using System.Collections.Generic;
-using System.Linq;
-using System.Text;
-using System.Threading.Tasks;
namespace MongoDB.Driver.Core.WireProtocol.Messages.Encoders
{
@@ -28,6 +24,16 @@ namespace MongoDB.Driver.Core.WireProtocol.Messages.Encoders
public static class MessageEncoderSettingsName
{
// encoder settings used by the binary encoders
+ ///
+ /// The name of the binary document field decryptor setting.
+ ///
+ public const string BinaryDocumentFieldDecryptor = "BinaryDocumentFieldDecryptor";
+
+ ///
+ /// The name of the binary document field encryptor setting.
+ ///
+ public const string BinaryDocumentFieldEncryptor = "BinaryDocumentFieldEncryptor";
+
///
/// The name of the FixOldBinarySubTypeOnInput setting.
///
@@ -173,5 +179,15 @@ public T GetOrDefault(string name, T defaultValue)
return defaultValue;
}
}
+
+ ///
+ /// Sets the specified setting.
+ ///
+ /// The name.
+ /// The value.
+ public void Set(string name, object value)
+ {
+ _settings[name] = value;
+ }
}
}
diff --git a/src/MongoDB.Driver.Core/MongoDB.Driver.Core.csproj b/src/MongoDB.Driver.Core/MongoDB.Driver.Core.csproj
index dada86120dc..acf591fc50f 100644
--- a/src/MongoDB.Driver.Core/MongoDB.Driver.Core.csproj
+++ b/src/MongoDB.Driver.Core/MongoDB.Driver.Core.csproj
@@ -45,6 +45,7 @@
+
diff --git a/src/MongoDB.Driver.Legacy/MongoServerSettings.cs b/src/MongoDB.Driver.Legacy/MongoServerSettings.cs
index 2b01f48ce7c..df9c6ae0358 100644
--- a/src/MongoDB.Driver.Legacy/MongoServerSettings.cs
+++ b/src/MongoDB.Driver.Legacy/MongoServerSettings.cs
@@ -1062,6 +1062,7 @@ internal ClusterKey ToClusterKey()
_heartbeatInterval,
_heartbeatTimeout,
_ipv6,
+ kmsProviders: null, // not supported for legacy
_localThreshold,
_maxConnectionIdleTime,
_maxConnectionLifeTime,
@@ -1069,6 +1070,7 @@ internal ClusterKey ToClusterKey()
_minConnectionPoolSize,
MongoDefaults.TcpReceiveBufferSize,
_replicaSetName,
+ schemaMap: null, // not supported for legacy
_scheme,
_sdamLogFilename,
MongoDefaults.TcpSendBufferSize,
diff --git a/src/MongoDB.Driver/ClusterKey.cs b/src/MongoDB.Driver/ClusterKey.cs
index 7e09a441e26..8f42a81b40b 100644
--- a/src/MongoDB.Driver/ClusterKey.cs
+++ b/src/MongoDB.Driver/ClusterKey.cs
@@ -16,7 +16,9 @@
using System;
using System.Collections.Generic;
using System.Linq;
+using MongoDB.Bson;
using MongoDB.Driver.Core.Configuration;
+using MongoDB.Driver.Encryption;
using MongoDB.Shared;
namespace MongoDB.Driver
@@ -35,6 +37,7 @@ internal class ClusterKey
private readonly TimeSpan _heartbeatInterval;
private readonly TimeSpan _heartbeatTimeout;
private readonly bool _ipv6;
+ private readonly IReadOnlyDictionary> _kmsProviders;
private readonly TimeSpan _localThreshold;
private readonly TimeSpan _maxConnectionIdleTime;
private readonly TimeSpan _maxConnectionLifeTime;
@@ -42,6 +45,7 @@ internal class ClusterKey
private readonly int _minConnectionPoolSize;
private readonly int _receiveBufferSize;
private readonly string _replicaSetName;
+ private readonly IReadOnlyDictionary _schemaMap;
private readonly ConnectionStringScheme _scheme;
private readonly string _sdamLogFilename;
private readonly int _sendBufferSize;
@@ -65,6 +69,7 @@ internal class ClusterKey
TimeSpan heartbeatInterval,
TimeSpan heartbeatTimeout,
bool ipv6,
+ IReadOnlyDictionary> kmsProviders,
TimeSpan localThreshold,
TimeSpan maxConnectionIdleTime,
TimeSpan maxConnectionLifeTime,
@@ -72,6 +77,7 @@ internal class ClusterKey
int minConnectionPoolSize,
int receiveBufferSize,
string replicaSetName,
+ IReadOnlyDictionary schemaMap,
ConnectionStringScheme scheme,
string sdamLogFilename,
int sendBufferSize,
@@ -93,6 +99,7 @@ internal class ClusterKey
_heartbeatInterval = heartbeatInterval;
_heartbeatTimeout = heartbeatTimeout;
_ipv6 = ipv6;
+ _kmsProviders = kmsProviders;
_localThreshold = localThreshold;
_maxConnectionIdleTime = maxConnectionIdleTime;
_maxConnectionLifeTime = maxConnectionLifeTime;
@@ -100,6 +107,7 @@ internal class ClusterKey
_minConnectionPoolSize = minConnectionPoolSize;
_receiveBufferSize = receiveBufferSize;
_replicaSetName = replicaSetName;
+ _schemaMap = schemaMap;
_scheme = scheme;
_sdamLogFilename = sdamLogFilename;
_sendBufferSize = sendBufferSize;
@@ -125,6 +133,7 @@ internal class ClusterKey
public TimeSpan HeartbeatInterval { get { return _heartbeatInterval; } }
public TimeSpan HeartbeatTimeout { get { return _heartbeatTimeout; } }
public bool IPv6 { get { return _ipv6; } }
+ public IReadOnlyDictionary> KmsProviders { get { return _kmsProviders; } }
public TimeSpan LocalThreshold { get { return _localThreshold; } }
public TimeSpan MaxConnectionIdleTime { get { return _maxConnectionIdleTime; } }
public TimeSpan MaxConnectionLifeTime { get { return _maxConnectionLifeTime; } }
@@ -132,8 +141,9 @@ internal class ClusterKey
public int MinConnectionPoolSize { get { return _minConnectionPoolSize; } }
public int ReceiveBufferSize { get { return _receiveBufferSize; } }
public string ReplicaSetName { get { return _replicaSetName; } }
+ public IReadOnlyDictionary SchemaMap { get { return _schemaMap; } }
public ConnectionStringScheme Scheme { get { return _scheme; } }
- public string SdamLogFilename { get { return _sdamLogFilename; }}
+ public string SdamLogFilename { get { return _sdamLogFilename; } }
public int SendBufferSize { get { return _sendBufferSize; } }
public IReadOnlyList Servers { get { return _servers; } }
public TimeSpan ServerSelectionTimeout { get { return _serverSelectionTimeout; } }
@@ -172,6 +182,7 @@ public override bool Equals(object obj)
_heartbeatInterval == rhs._heartbeatInterval &&
_heartbeatTimeout == rhs._heartbeatTimeout &&
_ipv6 == rhs._ipv6 &&
+ KmsProvidersHelper.Equals(_kmsProviders, rhs.KmsProviders) &&
_localThreshold == rhs._localThreshold &&
_maxConnectionIdleTime == rhs._maxConnectionIdleTime &&
_maxConnectionLifeTime == rhs._maxConnectionLifeTime &&
@@ -179,6 +190,7 @@ public override bool Equals(object obj)
_minConnectionPoolSize == rhs._minConnectionPoolSize &&
_receiveBufferSize == rhs._receiveBufferSize &&
_replicaSetName == rhs._replicaSetName &&
+ _schemaMap.IsEquivalentTo(rhs._schemaMap, object.Equals) &&
_scheme == rhs._scheme &&
_sdamLogFilename == rhs._sdamLogFilename &&
_sendBufferSize == rhs._sendBufferSize &&
diff --git a/src/MongoDB.Driver/ClusterRegistry.cs b/src/MongoDB.Driver/ClusterRegistry.cs
index b1085d38a39..91b8f6ad295 100644
--- a/src/MongoDB.Driver/ClusterRegistry.cs
+++ b/src/MongoDB.Driver/ClusterRegistry.cs
@@ -84,10 +84,12 @@ private ClusterSettings ConfigureCluster(ClusterSettings settings, ClusterKey cl
return settings.With(
connectionMode: clusterKey.ConnectionMode.ToCore(),
endPoints: Optional.Enumerable(endPoints),
+ kmsProviders: Optional.Create(clusterKey.KmsProviders),
replicaSetName: clusterKey.ReplicaSetName,
maxServerSelectionWaitQueueSize: clusterKey.WaitQueueSize,
serverSelectionTimeout: clusterKey.ServerSelectionTimeout,
postServerSelector: new LatencyLimitingServerSelector(clusterKey.LocalThreshold),
+ schemaMap: Optional.Create(clusterKey.SchemaMap),
scheme: clusterKey.Scheme);
}
diff --git a/src/MongoDB.Driver/Encryption/AutoEncryptionLibMongoController.cs b/src/MongoDB.Driver/Encryption/AutoEncryptionLibMongoController.cs
new file mode 100644
index 00000000000..95bca8afb11
--- /dev/null
+++ b/src/MongoDB.Driver/Encryption/AutoEncryptionLibMongoController.cs
@@ -0,0 +1,227 @@
+/* Copyright 2019-present MongoDB Inc.
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+using System;
+using System.Threading;
+using System.Threading.Tasks;
+using MongoDB.Bson;
+using MongoDB.Driver.Core.Misc;
+using MongoDB.Driver.Core.WireProtocol;
+using MongoDB.Libmongocrypt;
+
+namespace MongoDB.Driver.Encryption
+{
+ internal sealed class AutoEncryptionLibMongoCryptController : LibMongoCryptControllerBase, IBinaryDocumentFieldDecryptor, IBinaryCommandFieldEncryptor
+ {
+ // private fields
+ private readonly IMongoClient _client;
+ private readonly IMongoClient _mongocryptdClient;
+ private readonly MongocryptdFactory _mongocryptdFactory;
+
+ // constructors
+ public AutoEncryptionLibMongoCryptController(
+ IMongoClient client,
+ CryptClient cryptClient,
+ AutoEncryptionOptions autoEncryptionOptions)
+ : base(
+ Ensure.IsNotNull(cryptClient, nameof(cryptClient)),
+ Ensure.IsNotNull(autoEncryptionOptions, nameof(autoEncryptionOptions)).KeyVaultClient ?? client,
+ Ensure.IsNotNull(Ensure.IsNotNull(autoEncryptionOptions, nameof(autoEncryptionOptions)).KeyVaultNamespace, nameof(autoEncryptionOptions.KeyVaultNamespace)))
+ {
+ _client = Ensure.IsNotNull(client, nameof(client)); // _client might not be fully constructed at this point, don't call any instance methods on it yet
+ _mongocryptdFactory = new MongocryptdFactory(autoEncryptionOptions.ExtraOptions);
+ _mongocryptdClient = _mongocryptdFactory.CreateMongocryptdClient();
+ }
+
+ // public methods
+ public byte[] DecryptFields(byte[] encryptedDocumentBytes, CancellationToken cancellationToken)
+ {
+ try
+ {
+ using (var context = _cryptClient.StartDecryptionContext(encryptedDocumentBytes))
+ {
+ return ProcessStates(context, databaseName: null, cancellationToken);
+ }
+ }
+ catch (Exception ex)
+ {
+ throw new MongoEncryptionException(ex);
+ }
+ }
+
+ public async Task DecryptFieldsAsync(byte[] encryptedDocumentBytes, CancellationToken cancellationToken)
+ {
+ try
+ {
+ using (var context = _cryptClient.StartDecryptionContext(encryptedDocumentBytes))
+ {
+ return await ProcessStatesAsync(context, databaseName: null, cancellationToken).ConfigureAwait(false);
+ }
+ }
+ catch (Exception ex)
+ {
+ throw new MongoEncryptionException(ex);
+ }
+ }
+
+ public byte[] EncryptFields(string databaseName, byte[] unencryptedCommandBytes, CancellationToken cancellationToken)
+ {
+ try
+ {
+ using (var context = _cryptClient.StartEncryptionContext(databaseName, unencryptedCommandBytes))
+ {
+ return ProcessStates(context, databaseName, cancellationToken);
+ }
+ }
+ catch (Exception ex)
+ {
+ throw new MongoEncryptionException(ex);
+ }
+ }
+
+ public async Task EncryptFieldsAsync(string databaseName, byte[] unencryptedCommandBytes, CancellationToken cancellationToken)
+ {
+ try
+ {
+ using (var context = _cryptClient.StartEncryptionContext(databaseName, unencryptedCommandBytes))
+ {
+ return await ProcessStatesAsync(context, databaseName, cancellationToken).ConfigureAwait(false);
+ }
+ }
+ catch (Exception ex)
+ {
+ throw new MongoEncryptionException(ex);
+ }
+ }
+
+ // protected methods
+ protected override void ProcessState(CryptContext context, string databaseName, CancellationToken cancellationToken)
+ {
+ switch (context.State)
+ {
+ case CryptContext.StateCode.MONGOCRYPT_CTX_NEED_MONGO_COLLINFO:
+ ProcessNeedCollectionInfoState(context, databaseName, cancellationToken);
+ break;
+ case CryptContext.StateCode.MONGOCRYPT_CTX_NEED_MONGO_MARKINGS:
+ ProcessNeedMongoMarkingsState(context, databaseName, cancellationToken);
+ break;
+ default:
+ base.ProcessState(context, databaseName, cancellationToken);
+ break;
+ }
+ }
+
+ protected override async Task ProcessStateAsync(CryptContext context, string databaseName, CancellationToken cancellationToken)
+ {
+ switch (context.State)
+ {
+ case CryptContext.StateCode.MONGOCRYPT_CTX_NEED_MONGO_COLLINFO:
+ await ProcessNeedCollectionInfoStateAsync(context, databaseName, cancellationToken).ConfigureAwait(false);
+ break;
+ case CryptContext.StateCode.MONGOCRYPT_CTX_NEED_MONGO_MARKINGS:
+ await ProcessNeedMongoMarkingsStateAsync(context, databaseName, cancellationToken).ConfigureAwait(false);
+ break;
+ default:
+ await base.ProcessStateAsync(context, databaseName, cancellationToken).ConfigureAwait(false);
+ break;
+ }
+ }
+
+ // private methods
+ private void ProcessNeedCollectionInfoState(CryptContext context, string databaseName, CancellationToken cancellationToken)
+ {
+ var database = _client.GetDatabase(databaseName);
+ var filterBytes = context.GetOperation().ToArray();
+ var filterDocument = new RawBsonDocument(filterBytes);
+ var filter = new BsonDocumentFilterDefinition(filterDocument);
+ var options = new ListCollectionsOptions { Filter = filter };
+ var cursor = database.ListCollections(options, cancellationToken);
+ var results = cursor.ToList(cancellationToken);
+ FeedResults(context, results);
+ }
+
+ private async Task ProcessNeedCollectionInfoStateAsync(CryptContext context, string databaseName, CancellationToken cancellationToken)
+ {
+ var database = _client.GetDatabase(databaseName);
+ var filterBytes = context.GetOperation().ToArray();
+ var filterDocument = new RawBsonDocument(filterBytes);
+ var filter = new BsonDocumentFilterDefinition(filterDocument);
+ var options = new ListCollectionsOptions { Filter = filter };
+ var cursor = await database.ListCollectionsAsync(options, cancellationToken).ConfigureAwait(false);
+ var results = await cursor.ToListAsync(cancellationToken).ConfigureAwait(false);
+ FeedResults(context, results);
+ }
+
+ private void ProcessNeedMongoMarkingsState(CryptContext context, string databaseName, CancellationToken cancellationToken)
+ {
+ var database = _mongocryptdClient.GetDatabase(databaseName);
+ var commandBytes = context.GetOperation().ToArray();
+ var commandDocument = new RawBsonDocument(commandBytes);
+ var command = new BsonDocumentCommand(commandDocument);
+
+ BsonDocument response = null;
+ for (var attempt = 1; response == null; attempt++)
+ {
+ try
+ {
+ response = database.RunCommand(command, cancellationToken: cancellationToken);
+ }
+ catch (TimeoutException) when (attempt == 1)
+ {
+ _mongocryptdFactory.SpawnMongocryptdProcessIfRequired();
+ }
+ }
+
+ RestoreDbNodeInResponse(commandDocument, response);
+ FeedResult(context, response);
+ }
+
+ private async Task ProcessNeedMongoMarkingsStateAsync(CryptContext context, string databaseName, CancellationToken cancellationToken)
+ {
+ var database = _mongocryptdClient.GetDatabase(databaseName);
+ var commandBytes = context.GetOperation().ToArray();
+ var commandDocument = new RawBsonDocument(commandBytes);
+ var command = new BsonDocumentCommand(commandDocument);
+
+ BsonDocument response = null;
+ for (var attempt = 1; response == null; attempt++)
+ {
+ try
+ {
+ response = await database.RunCommandAsync(command, cancellationToken: cancellationToken).ConfigureAwait(false);
+ }
+ catch (TimeoutException) when (attempt == 1)
+ {
+ _mongocryptdFactory.SpawnMongocryptdProcessIfRequired();
+ }
+ }
+
+ RestoreDbNodeInResponse(commandDocument, response);
+ FeedResult(context, response);
+ }
+
+ private void RestoreDbNodeInResponse(BsonDocument request, BsonDocument response)
+ {
+ if (request.TryGetElement("$db", out var db))
+ {
+ var result = response["result"].AsBsonDocument;
+ if (!result.Contains("$db"))
+ {
+ result.Add(db);
+ }
+ }
+ }
+ }
+}
diff --git a/src/MongoDB.Driver/Encryption/AutoEncryptionOptions.cs b/src/MongoDB.Driver/Encryption/AutoEncryptionOptions.cs
new file mode 100644
index 00000000000..2c854ca923a
--- /dev/null
+++ b/src/MongoDB.Driver/Encryption/AutoEncryptionOptions.cs
@@ -0,0 +1,229 @@
+/* Copyright 2019-present MongoDB Inc.
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+using System.Collections.Generic;
+using System.Linq;
+using System.Text;
+using MongoDB.Bson;
+using MongoDB.Driver.Core.Misc;
+using MongoDB.Shared;
+
+namespace MongoDB.Driver.Encryption
+{
+ ///
+ /// Auto encryption options.
+ ///
+ public class AutoEncryptionOptions
+ {
+ // private fields
+ private readonly bool _bypassAutoEncryption;
+ private readonly IReadOnlyDictionary _extraOptions;
+ private readonly IMongoClient _keyVaultClient;
+ private readonly CollectionNamespace _keyVaultNamespace;
+ private readonly IReadOnlyDictionary> _kmsProviders;
+ private readonly IReadOnlyDictionary _schemaMap;
+
+ // constructors
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ /// The keyVault namespace.
+ /// The kms providers.
+ /// The bypass auto encryption flag.
+ /// The extra options.
+ /// The keyVault client.
+ /// The schema map.
+ public AutoEncryptionOptions(
+ CollectionNamespace keyVaultNamespace,
+ IReadOnlyDictionary> kmsProviders,
+ Optional bypassAutoEncryption = default,
+ Optional> extraOptions = default,
+ Optional keyVaultClient = default,
+ Optional> schemaMap = default)
+ {
+ _keyVaultNamespace = Ensure.IsNotNull(keyVaultNamespace, nameof(keyVaultNamespace));
+ _kmsProviders = Ensure.IsNotNull(kmsProviders, nameof(kmsProviders));
+ _bypassAutoEncryption = bypassAutoEncryption.WithDefault(false);
+ _extraOptions = extraOptions.WithDefault(null);
+ _keyVaultClient = keyVaultClient.WithDefault(null);
+ _schemaMap = schemaMap.WithDefault(null);
+
+ EncryptionExtraOptionsValidator.EnsureThatExtraOptionsAreValid(_extraOptions);
+ KmsProvidersHelper.EnsureKmsProvidersAreValid(_kmsProviders);
+ }
+
+ // public properties
+ ///
+ /// Gets a value indicating whether to bypass automatic encryption.
+ ///
+ ///
+ /// true if automatic encryption should be bypasssed; otherwise, false.
+ ///
+ public bool BypassAutoEncryption => _bypassAutoEncryption;
+
+ ///
+ /// Gets the extra options.
+ ///
+ ///
+ /// The extra options.
+ ///
+ public IReadOnlyDictionary ExtraOptions => _extraOptions;
+
+ ///
+ /// Gets the key vault client.
+ ///
+ ///
+ /// The key vault client.
+ ///
+ public IMongoClient KeyVaultClient => _keyVaultClient;
+
+ ///
+ /// Gets the key vault namespace.
+ ///
+ ///
+ /// The key vault namespace.
+ ///
+ public CollectionNamespace KeyVaultNamespace => _keyVaultNamespace;
+
+ ///
+ /// Gets the KMS providers.
+ ///
+ ///
+ /// The KMS providers.
+ ///
+ public IReadOnlyDictionary> KmsProviders => _kmsProviders;
+
+ ///
+ /// Gets the schema map.
+ ///
+ ///
+ /// The schema map.
+ ///
+ public IReadOnlyDictionary SchemaMap => _schemaMap;
+
+ ///
+ /// Returns a new instance of the class.
+ ///
+ /// The keyVault namespace.
+ /// The kms providers.
+ /// The bypass auto encryption flag.
+ /// The extra options.
+ /// The keyVault client.
+ /// The schema map.
+ /// A new instance of .
+ public AutoEncryptionOptions With(
+ Optional keyVaultNamespace = default,
+ Optional>> kmsProviders = default,
+ Optional bypassAutoEncryption = default,
+ Optional> extraOptions = default,
+ Optional keyVaultClient = default,
+ Optional> schemaMap = default)
+ {
+ return new AutoEncryptionOptions(
+ keyVaultNamespace.WithDefault(_keyVaultNamespace),
+ kmsProviders.WithDefault(_kmsProviders),
+ bypassAutoEncryption.WithDefault(_bypassAutoEncryption),
+ Optional.Create(extraOptions.WithDefault(_extraOptions)),
+ Optional.Create(keyVaultClient.WithDefault(_keyVaultClient)),
+ Optional.Create(schemaMap.WithDefault(_schemaMap)));
+ }
+
+ ///
+ public override bool Equals(object obj)
+ {
+ if (object.ReferenceEquals(obj, null) || GetType() != obj.GetType()) { return false; }
+ var rhs = (AutoEncryptionOptions)obj;
+
+ return
+ _bypassAutoEncryption.Equals(rhs._bypassAutoEncryption) &&
+ ExtraOptionsEquals(_extraOptions, rhs._extraOptions) &&
+ object.ReferenceEquals(_keyVaultClient, rhs._keyVaultClient) &&
+ _keyVaultNamespace.Equals(rhs._keyVaultNamespace) &&
+ KmsProvidersHelper.Equals(_kmsProviders, rhs._kmsProviders) &&
+ _schemaMap.IsEquivalentTo(rhs._schemaMap, object.Equals);
+ }
+
+ ///
+ public override int GetHashCode()
+ {
+ return new Hasher()
+ .Hash(_bypassAutoEncryption)
+ .HashElements(_extraOptions)
+ .Hash(_keyVaultClient)
+ .Hash(_keyVaultNamespace)
+ .HashElements(_kmsProviders)
+ .HashElements(_schemaMap)
+ .GetHashCode();
+ }
+
+ ///
+ public override string ToString()
+ {
+ var sb = new StringBuilder();
+ sb.Append("{ ");
+ sb.AppendFormat("BypassAutoEncryption : {0}, ", _bypassAutoEncryption);
+ sb.AppendFormat("KmsProviders : {0}, ", _kmsProviders.ToJson());
+ if (_keyVaultNamespace != null)
+ {
+ sb.AppendFormat("KeyVaultNamespace : \"{0}\", ", _keyVaultNamespace.FullName);
+ }
+ if (_extraOptions != null)
+ {
+ sb.AppendFormat("ExtraOptions : {0}, ", _extraOptions.ToJson());
+ }
+ if (_schemaMap != null)
+ {
+ sb.AppendFormat("SchemaMap : {0}, ", _schemaMap.ToJson());
+ }
+ sb.Remove(sb.Length - 2, 2);
+ sb.Append(" }");
+ return sb.ToString();
+ }
+
+ // private methods
+ private bool ExtraOptionsEquals(IReadOnlyDictionary x, IReadOnlyDictionary y)
+ {
+ return x.IsEquivalentTo(y, ExtraOptionEquals);
+ }
+
+ private bool ExtraOptionEquals(object x, object y)
+ {
+ if (object.ReferenceEquals(x, y))
+ {
+ return true;
+ }
+
+ if (object.ReferenceEquals(x, null) || object.ReferenceEquals(y, null))
+ {
+ return false;
+ }
+
+ if (x.GetType() != y.GetType())
+ {
+ return false;
+ }
+
+ if (x is IEnumerable enumerableX)
+ {
+ var enumerableY = (IEnumerable)y;
+ return enumerableX.SequenceEqual(enumerableY);
+ }
+ else
+ {
+ return x.Equals(y);
+ }
+ }
+ }
+}
diff --git a/src/MongoDB.Driver/Encryption/ClientEncryption.cs b/src/MongoDB.Driver/Encryption/ClientEncryption.cs
new file mode 100644
index 00000000000..1892450f0ea
--- /dev/null
+++ b/src/MongoDB.Driver/Encryption/ClientEncryption.cs
@@ -0,0 +1,151 @@
+/* Copyright 2019-present MongoDB Inc.
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+using System;
+using System.Threading;
+using System.Threading.Tasks;
+using MongoDB.Bson;
+using MongoDB.Driver.Core.Clusters;
+using MongoDB.Libmongocrypt;
+
+namespace MongoDB.Driver.Encryption
+{
+ ///
+ /// Explicit client encryption.
+ ///
+ public sealed class ClientEncryption : IDisposable
+ {
+ // private fields
+ private readonly CryptClient _cryptClient;
+ private bool _disposed;
+ private readonly ExplicitEncryptionLibMongoCryptController _libMongoCryptController;
+
+ // constructors
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ /// The client encryption options.
+ public ClientEncryption(ClientEncryptionOptions clientEncryptionOptions)
+ {
+ _cryptClient = CryptClientCreator.CreateCryptClient(
+ kmsProviders: clientEncryptionOptions.KmsProviders,
+ schemaMap: null);
+ _libMongoCryptController = new ExplicitEncryptionLibMongoCryptController(
+ _cryptClient,
+ clientEncryptionOptions);
+ }
+
+ // public methods
+ ///
+ /// Creates a data key.
+ ///
+ /// The kms provider.
+ /// The data key options.
+ /// The cancellation token.
+ /// A data key.
+ public Guid CreateDataKey(string kmsProvider, DataKeyOptions dataKeyOptions, CancellationToken cancellationToken)
+ {
+ return _libMongoCryptController.CreateDataKey(
+ kmsProvider,
+ dataKeyOptions.AlternateKeyNames,
+ dataKeyOptions.MasterKey,
+ cancellationToken);
+ }
+
+ ///
+ /// Creates a data key.
+ ///
+ /// The kms provider.
+ /// The data key options.
+ /// The cancellation token.
+ /// A data key.
+ public Task CreateDataKeyAsync(string kmsProvider, DataKeyOptions dataKeyOptions, CancellationToken cancellationToken)
+ {
+ return _libMongoCryptController.CreateDataKeyAsync(
+ kmsProvider,
+ dataKeyOptions.AlternateKeyNames,
+ dataKeyOptions.MasterKey,
+ cancellationToken);
+ }
+
+ ///
+ /// Decrypts the specified value.
+ ///
+ /// The value.
+ /// The cancellation token.
+ /// The decrypted value.
+ public BsonValue Decrypt(BsonBinaryData value, CancellationToken cancellationToken)
+ {
+ return _libMongoCryptController.DecryptField(value, cancellationToken);
+ }
+
+ ///
+ /// Decrypts the specified value.
+ ///
+ /// The value.
+ /// The cancellation token.
+ /// The decrypted value.
+ public Task DecryptAsync(BsonBinaryData value, CancellationToken cancellationToken)
+ {
+ return _libMongoCryptController.DecryptFieldAsync(value, cancellationToken);
+ }
+
+ ///
+ public void Dispose()
+ {
+ if (!_disposed)
+ {
+ _cryptClient.Dispose();
+ _disposed = true;
+ }
+ }
+
+ ///
+ /// Encrypts the specified value.
+ ///
+ /// The value.
+ /// The encrypt options.
+ /// The cancellation token.
+ /// The encrypted value.
+ public BsonBinaryData Encrypt(BsonValue value, EncryptOptions encryptOptions, CancellationToken cancellationToken)
+ {
+ var algorithm = (EncryptionAlgorithm)Enum.Parse(typeof(EncryptionAlgorithm), encryptOptions.Algorithm);
+ return _libMongoCryptController.EncryptField(
+ value,
+ encryptOptions.KeyId,
+ encryptOptions.AlternateKeyName,
+ algorithm,
+ cancellationToken);
+ }
+
+ ///
+ /// Encrypts the specified value.
+ ///
+ /// The value.
+ /// The encrypt options.
+ /// The cancellation token.
+ /// The encrypted value.
+ public Task EncryptAsync(BsonValue value, EncryptOptions encryptOptions, CancellationToken cancellationToken)
+ {
+ var algorithm = (EncryptionAlgorithm)Enum.Parse(typeof(EncryptionAlgorithm), encryptOptions.Algorithm);
+ return _libMongoCryptController.EncryptFieldAsync(
+ value,
+ encryptOptions.KeyId,
+ encryptOptions.AlternateKeyName,
+ algorithm,
+ cancellationToken);
+ }
+ }
+}
diff --git a/src/MongoDB.Driver/Encryption/ClientEncryptionOptions.cs b/src/MongoDB.Driver/Encryption/ClientEncryptionOptions.cs
new file mode 100644
index 00000000000..1a50bc43ee8
--- /dev/null
+++ b/src/MongoDB.Driver/Encryption/ClientEncryptionOptions.cs
@@ -0,0 +1,92 @@
+/* Copyright 2019-present MongoDB Inc.
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+using System.Collections.Generic;
+using MongoDB.Driver.Core.Misc;
+
+namespace MongoDB.Driver.Encryption
+{
+ ///
+ /// Client encryption options.
+ ///
+ public class ClientEncryptionOptions
+ {
+ // private fields
+ private readonly IMongoClient _keyVaultClient;
+ private readonly CollectionNamespace _keyVaultNamespace;
+ private readonly IReadOnlyDictionary> _kmsProviders;
+
+ // constructors
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ /// The key vault client.
+ /// The key vault namespace.
+ /// The KMS providers.
+ public ClientEncryptionOptions(
+ IMongoClient keyVaultClient,
+ CollectionNamespace keyVaultNamespace,
+ IReadOnlyDictionary> kmsProviders)
+ {
+ _keyVaultClient = Ensure.IsNotNull(keyVaultClient, nameof(keyVaultClient));
+ _keyVaultNamespace = Ensure.IsNotNull(keyVaultNamespace, nameof(keyVaultNamespace));
+ _kmsProviders = Ensure.IsNotNull(kmsProviders, nameof(kmsProviders));
+
+ KmsProvidersHelper.EnsureKmsProvidersAreValid(_kmsProviders);
+ }
+
+ // public properties
+ ///
+ /// Gets the key vault client.
+ ///
+ ///
+ /// The key vault client.
+ ///
+ public IMongoClient KeyVaultClient => _keyVaultClient;
+
+ ///
+ /// Gets the key vault namespace.
+ ///
+ ///
+ /// The key vault namespace.
+ ///
+ public CollectionNamespace KeyVaultNamespace => _keyVaultNamespace;
+
+ ///
+ /// Gets the KMS providers.
+ ///
+ ///
+ /// The KMS providers.
+ ///
+ public IReadOnlyDictionary> KmsProviders => _kmsProviders;
+
+ ///
+ /// Returns a new ClientEncryptionOptions instance with some settings changed.
+ ///
+ /// The key vault client.
+ /// The key vault namespace.
+ /// The KMS providers.
+ public ClientEncryptionOptions With(
+ Optional keyVaultClient = default,
+ Optional keyVaultNamespace = default,
+ Optional>> kmsProviders = default)
+ {
+ return new ClientEncryptionOptions(
+ keyVaultClient: keyVaultClient.WithDefault(_keyVaultClient),
+ keyVaultNamespace: keyVaultNamespace.WithDefault(_keyVaultNamespace),
+ kmsProviders: kmsProviders.WithDefault(_kmsProviders));
+ }
+ }
+}
diff --git a/src/MongoDB.Driver/Encryption/DataKeyOptions.cs b/src/MongoDB.Driver/Encryption/DataKeyOptions.cs
new file mode 100644
index 00000000000..c564c1b7cc1
--- /dev/null
+++ b/src/MongoDB.Driver/Encryption/DataKeyOptions.cs
@@ -0,0 +1,76 @@
+/* Copyright 2019-present MongoDB Inc.
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+using System.Collections.Generic;
+using MongoDB.Bson;
+
+namespace MongoDB.Driver.Encryption
+{
+ ///
+ /// Options for creating a data key.
+ ///
+ public class DataKeyOptions
+ {
+ // private fields
+ private readonly IReadOnlyList _alternateKeyNames;
+ private readonly BsonDocument _masterKey;
+
+ // constructors
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ /// The alternate key names.
+ /// The master key.
+ public DataKeyOptions(
+ Optional> alternateKeyNames = default,
+ Optional masterKey = default)
+ {
+ _alternateKeyNames = alternateKeyNames.WithDefault(null);
+ _masterKey = masterKey.WithDefault(null);
+ }
+
+ ///
+ /// Gets the alternate key names.
+ ///
+ ///
+ /// The alternate key names.
+ ///
+ public IReadOnlyList AlternateKeyNames => _alternateKeyNames;
+
+ // public properties
+ ///
+ /// Gets the master key.
+ ///
+ ///
+ /// The master key.
+ ///
+ public BsonDocument MasterKey => _masterKey;
+
+ ///
+ /// Returns a new DataKeyOptions instance with some settings changed.
+ ///
+ /// The alternate key names.
+ /// The master key.
+ /// A new DataKeyOptions instance.
+ public DataKeyOptions With(
+ Optional> alternateKeyNames = default,
+ Optional masterKey = default)
+ {
+ return new DataKeyOptions(
+ alternateKeyNames: Optional.Create(alternateKeyNames.WithDefault(_alternateKeyNames)),
+ masterKey: Optional.Create(masterKey.WithDefault(_masterKey)));
+ }
+ }
+}
diff --git a/src/MongoDB.Driver/Encryption/EncryptOptions.cs b/src/MongoDB.Driver/Encryption/EncryptOptions.cs
new file mode 100644
index 00000000000..d6903989a67
--- /dev/null
+++ b/src/MongoDB.Driver/Encryption/EncryptOptions.cs
@@ -0,0 +1,99 @@
+/* Copyright 2019-present MongoDB Inc.
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+using MongoDB.Driver.Core.Misc;
+using System;
+
+namespace MongoDB.Driver.Encryption
+{
+ ///
+ /// Encryption options for explicit encryption.
+ ///
+ public class EncryptOptions
+ {
+ // private fields
+ private readonly string _algorithm;
+ private readonly string _alternateKeyName;
+ private readonly Guid? _keyId;
+
+ // constructors
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ /// The encryption algorithm.
+ /// The alternate key name.
+ /// The key Id.
+ public EncryptOptions(
+ string algorithm,
+ Optional alternateKeyName = default,
+ Optional keyId = default)
+ {
+ _algorithm = Ensure.IsNotNull(algorithm, nameof(algorithm));
+ _alternateKeyName = alternateKeyName.WithDefault(null);
+ _keyId = keyId.WithDefault(null);
+ EnsureThatOptionsAreValid();
+ }
+
+ // public properties
+ ///
+ /// Gets the algorithm.
+ ///
+ ///
+ /// The algorithm.
+ ///
+ public string Algorithm=> _algorithm;
+
+ ///
+ /// Gets the alternate key name.
+ ///
+ ///
+ /// The alternate key name.
+ ///
+ public string AlternateKeyName => _alternateKeyName;
+
+ ///
+ /// Gets the key identifier.
+ ///
+ ///
+ /// The key identifier.
+ ///
+ public Guid? KeyId => _keyId;
+
+ ///
+ /// Returns a new EncryptOptions instance with some settings changed.
+ ///
+ /// The encryption algorithm.
+ /// The alternate key name.
+ /// The keyId.
+ /// A new EncryptOptions instance.
+ public EncryptOptions With(
+ Optional algorithm = default,
+ Optional alternateKeyName = default,
+ Optional keyId = default)
+ {
+ return new EncryptOptions(
+ algorithm: algorithm.WithDefault(_algorithm),
+ alternateKeyName: alternateKeyName.WithDefault(_alternateKeyName),
+ keyId: keyId.WithDefault(_keyId));
+ }
+
+ // private methods
+ private void EnsureThatOptionsAreValid()
+ {
+ Ensure.That(!(!_keyId.HasValue && _alternateKeyName == null), "Key Id and AlternateKeyName may not both be null.");
+ Ensure.That(!(_keyId.HasValue && _alternateKeyName != null), "Key Id and AlternateKeyName may not both be set.");
+ }
+ }
+}
diff --git a/src/MongoDB.Driver/Encryption/ExplicitEncryptionLibMongoCryptController.cs b/src/MongoDB.Driver/Encryption/ExplicitEncryptionLibMongoCryptController.cs
new file mode 100644
index 00000000000..ddc8501c9d2
--- /dev/null
+++ b/src/MongoDB.Driver/Encryption/ExplicitEncryptionLibMongoCryptController.cs
@@ -0,0 +1,283 @@
+/* Copyright 2019-present MongoDB Inc.
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Threading;
+using System.Threading.Tasks;
+using MongoDB.Bson;
+using MongoDB.Bson.IO;
+using MongoDB.Bson.Serialization.Serializers;
+using MongoDB.Driver.Core.Misc;
+using MongoDB.Libmongocrypt;
+
+namespace MongoDB.Driver.Encryption
+{
+ internal sealed class ExplicitEncryptionLibMongoCryptController : LibMongoCryptControllerBase
+ {
+ // constructors
+ public ExplicitEncryptionLibMongoCryptController(
+ CryptClient cryptClient,
+ ClientEncryptionOptions clientEncryptionOptions)
+ : base(
+ Ensure.IsNotNull(cryptClient, nameof(cryptClient)),
+ Ensure.IsNotNull(Ensure.IsNotNull(clientEncryptionOptions, nameof(clientEncryptionOptions)).KeyVaultClient, nameof(clientEncryptionOptions.KeyVaultClient)),
+ Ensure.IsNotNull(Ensure.IsNotNull(clientEncryptionOptions, nameof(clientEncryptionOptions)).KeyVaultNamespace, nameof(clientEncryptionOptions.KeyVaultNamespace)))
+ {
+ }
+
+ // public methods
+ public Guid CreateDataKey(
+ string kmsProvider,
+ IReadOnlyList alternateKeyNames,
+ BsonDocument masterKey,
+ CancellationToken cancellationToken)
+ {
+ try
+ {
+ var kmsKeyId = GetKmsKeyId(kmsProvider, alternateKeyNames, masterKey);
+
+ using (var context = _cryptClient.StartCreateDataKeyContext(kmsKeyId))
+ {
+ var wrappedKeyBytes = ProcessStates(context, _keyVaultNamespace.DatabaseNamespace.DatabaseName, cancellationToken);
+
+ var wrappedKeyDocument = new RawBsonDocument(wrappedKeyBytes);
+ var keyId = UnwrapKeyId(wrappedKeyDocument);
+
+ _keyVaultCollection.Value.InsertOne(wrappedKeyDocument, cancellationToken: cancellationToken);
+
+ return keyId;
+ }
+ }
+ catch (Exception ex)
+ {
+ throw new MongoEncryptionException(ex);
+ }
+ }
+
+ public async Task CreateDataKeyAsync(
+ string kmsProvider,
+ IReadOnlyList alternateKeyNames,
+ BsonDocument masterKey,
+ CancellationToken cancellationToken)
+ {
+ try
+ {
+ var kmsKeyId = GetKmsKeyId(kmsProvider, alternateKeyNames, masterKey);
+
+ using (var context = _cryptClient.StartCreateDataKeyContext(kmsKeyId))
+ {
+ var wrappedKeyBytes = await ProcessStatesAsync(context, _keyVaultNamespace.DatabaseNamespace.DatabaseName, cancellationToken).ConfigureAwait(false);
+
+ var wrappedKeyDocument = new RawBsonDocument(wrappedKeyBytes);
+ var keyId = UnwrapKeyId(wrappedKeyDocument);
+
+ await _keyVaultCollection.Value.InsertOneAsync(wrappedKeyDocument, cancellationToken: cancellationToken).ConfigureAwait(false);
+
+ return keyId;
+ }
+ }
+ catch (Exception ex)
+ {
+ throw new MongoEncryptionException(ex);
+ }
+ }
+
+ public BsonValue DecryptField(BsonBinaryData encryptedValue, CancellationToken cancellationToken)
+ {
+ try
+ {
+ var wrappedValueBytes = GetWrappedValueBytes(encryptedValue);
+
+ using (var context = _cryptClient.StartExplicitDecryptionContext(wrappedValueBytes))
+ {
+ var wrappedBytes = ProcessStates(context, databaseName: null, cancellationToken);
+ return UnwrapDecryptedValue(wrappedBytes);
+ }
+ }
+ catch (Exception ex)
+ {
+ throw new MongoEncryptionException(ex);
+ }
+ }
+
+ public async Task DecryptFieldAsync(BsonBinaryData wrappedBinaryValue, CancellationToken cancellationToken)
+ {
+ try
+ {
+ var wrappedValueBytes = GetWrappedValueBytes(wrappedBinaryValue);
+
+ using (var context = _cryptClient.StartExplicitDecryptionContext(wrappedValueBytes))
+ {
+ var wrappedBytes = await ProcessStatesAsync(context, databaseName: null, cancellationToken).ConfigureAwait(false);
+ return UnwrapDecryptedValue(wrappedBytes);
+ }
+ }
+ catch (Exception ex)
+ {
+ throw new MongoEncryptionException(ex);
+ }
+ }
+
+ public BsonBinaryData EncryptField(
+ BsonValue value,
+ Guid? keyId,
+ string alternateKeyName,
+ EncryptionAlgorithm encryptionAlgorithm,
+ CancellationToken cancellationToken)
+ {
+ try
+ {
+ var wrappedValueBytes = GetWrappedValueBytes(value);
+
+ CryptContext context;
+ if (keyId.HasValue && alternateKeyName != null)
+ {
+ throw new ArgumentException("keyId and alternateKeyName cannot both be provided.");
+ }
+ else if (keyId.HasValue)
+ {
+ var keyBytes = GuidConverter.ToBytes(keyId.Value, GuidRepresentation.Standard);
+ context = _cryptClient.StartExplicitEncryptionContextWithKeyId(keyBytes, encryptionAlgorithm, wrappedValueBytes);
+ }
+ else if (alternateKeyName != null)
+ {
+ var wrappedAlternateKeyNameBytes = GetWrappedAlternateKeyNameBytes(alternateKeyName);
+ context = _cryptClient.StartExplicitEncryptionContextWithKeyAltName(wrappedAlternateKeyNameBytes, encryptionAlgorithm, wrappedValueBytes);
+ }
+ else
+ {
+ throw new ArgumentException("Either keyId or alternateKeyName must be provided.");
+ }
+
+ using (context)
+ {
+ var wrappedBytes = ProcessStates(context, databaseName: null, cancellationToken);
+ return UnwrapEncryptedValue(wrappedBytes);
+ }
+ }
+ catch (Exception ex)
+ {
+ throw new MongoEncryptionException(ex);
+ }
+ }
+
+ public async Task EncryptFieldAsync(
+ BsonValue value,
+ Guid? keyId,
+ string alternateKeyName,
+ EncryptionAlgorithm encryptionAlgorithm,
+ CancellationToken cancellationToken)
+ {
+ try
+ {
+ var wrappedValueBytes = GetWrappedValueBytes(value);
+
+ CryptContext context;
+ if (keyId.HasValue && alternateKeyName != null)
+ {
+ throw new ArgumentException("keyId and alternateKeyName cannot both be provided.");
+ }
+ else if (keyId.HasValue)
+ {
+ var bytes = GuidConverter.ToBytes(keyId.Value, GuidRepresentation.Standard);
+ context = _cryptClient.StartExplicitEncryptionContextWithKeyId(bytes, encryptionAlgorithm, wrappedValueBytes);
+ }
+ else if (alternateKeyName != null)
+ {
+ var wrappedAlternateKeyNameBytes = GetWrappedAlternateKeyNameBytes(alternateKeyName);
+ context = _cryptClient.StartExplicitEncryptionContextWithKeyAltName(wrappedAlternateKeyNameBytes, encryptionAlgorithm, wrappedValueBytes);
+ }
+ else
+ {
+ throw new ArgumentException("Either keyId or alternateKeyName must be provided.");
+ }
+
+ using (context)
+ {
+ var wrappedBytes = await ProcessStatesAsync(context, databaseName: null, cancellationToken).ConfigureAwait(false);
+ return UnwrapEncryptedValue(wrappedBytes);
+ }
+ }
+ catch (Exception ex)
+ {
+ throw new MongoEncryptionException(ex);
+ }
+ }
+
+ // private methods
+ private IKmsKeyId GetKmsKeyId(string kmsProvider, IReadOnlyList alternateKeyNames, BsonDocument masterKey)
+ {
+ IEnumerable wrappedAlternateKeyNamesBytes = null;
+ if (alternateKeyNames != null)
+ {
+ wrappedAlternateKeyNamesBytes = alternateKeyNames.Select(GetWrappedAlternateKeyNameBytes);
+ }
+
+ switch (kmsProvider)
+ {
+ case "aws":
+ var customerMasterKey = masterKey["key"].ToString();
+ var region = masterKey["region"].ToString();
+ return wrappedAlternateKeyNamesBytes != null
+ ? new AwsKeyId(customerMasterKey, region, wrappedAlternateKeyNamesBytes)
+ : new AwsKeyId(customerMasterKey, region);
+ case "local":
+ return wrappedAlternateKeyNamesBytes != null ? new LocalKeyId(wrappedAlternateKeyNamesBytes) : new LocalKeyId();
+ default:
+ throw new ArgumentException($"Invalid kmsProvider {kmsProvider}.");
+ }
+ }
+
+ private byte[] GetWrappedAlternateKeyNameBytes(string value)
+ {
+ return
+ !string.IsNullOrWhiteSpace(value)
+ ? new BsonDocument("keyAltName", value).ToBson()
+ : null;
+ }
+
+ private byte[] GetWrappedValueBytes(BsonValue value)
+ {
+ var wrappedValue = new BsonDocument("v", value);
+ var writerSettings = BsonBinaryWriterSettings.Defaults.Clone();
+ writerSettings.GuidRepresentation = GuidRepresentation.Unspecified;
+ return wrappedValue.ToBson(writerSettings: writerSettings);
+ }
+
+ private BsonValue UnwrapDecryptedValue(byte[] wrappedBytes)
+ {
+ var wrappedDocument = new RawBsonDocument(wrappedBytes);
+ return wrappedDocument["v"];
+ }
+
+ private BsonBinaryData UnwrapEncryptedValue(byte[] encryptedWrappedBytes)
+ {
+ var wrappedDocument = new RawBsonDocument(encryptedWrappedBytes);
+ return wrappedDocument["v"].AsBsonBinaryData;
+ }
+
+ private Guid UnwrapKeyId(RawBsonDocument wrappedKeyDocument)
+ {
+ var keyId = wrappedKeyDocument["_id"].AsBsonBinaryData;
+ if (keyId.SubType != BsonBinarySubType.UuidStandard)
+ {
+ throw new InvalidOperationException($"KeyId sub type must be UuidStandard, not: {keyId.SubType}.");
+ }
+ return GuidConverter.FromBytes(keyId.Bytes, GuidRepresentation.Standard);
+ }
+ }
+}
diff --git a/src/MongoDB.Driver/Encryption/KmsProvidersHelper.cs b/src/MongoDB.Driver/Encryption/KmsProvidersHelper.cs
new file mode 100644
index 00000000000..910ebf4665a
--- /dev/null
+++ b/src/MongoDB.Driver/Encryption/KmsProvidersHelper.cs
@@ -0,0 +1,64 @@
+/* Copyright 2019-present MongoDB Inc.
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using MongoDB.Driver.Core.Misc;
+
+namespace MongoDB.Driver.Encryption
+{
+ internal static class KmsProvidersHelper
+ {
+ public static void EnsureKmsProvidersAreValid(IReadOnlyDictionary> kmsProviders)
+ {
+ foreach (var kmsProvider in kmsProviders)
+ {
+ foreach (var option in Ensure.IsNotNull(kmsProvider.Value, nameof(kmsProvider)))
+ {
+ var optionValue = Ensure.IsNotNull(option.Value, "kmsProviderOption");
+ var isValid = optionValue is byte[] || optionValue is string;
+ if (!isValid)
+ {
+ throw new ArgumentException($"Invalid kms provider option type: {optionValue.GetType().Name}.");
+ }
+ }
+ }
+ }
+
+ public static bool Equals(IReadOnlyDictionary> x, IReadOnlyDictionary> y)
+ {
+ return x.IsEquivalentTo(y, KmsProviderIsEquivalentTo);
+ }
+
+ // private methods
+ private static bool KmsProviderIsEquivalentTo(IReadOnlyDictionary x, IReadOnlyDictionary y)
+ {
+ return x.IsEquivalentTo(y, KmsProviderOptionEquals);
+ }
+
+ private static bool KmsProviderOptionEquals(object x, object y)
+ {
+ if (x is byte[] xBytes && y is byte[] yBytes)
+ {
+ return xBytes.SequenceEqual(yBytes);
+ }
+ else
+ {
+ return object.Equals(x, y);
+ }
+ }
+ }
+}
diff --git a/src/MongoDB.Driver/Encryption/LibMongoCryptControllerBase.cs b/src/MongoDB.Driver/Encryption/LibMongoCryptControllerBase.cs
new file mode 100644
index 00000000000..e5beededa25
--- /dev/null
+++ b/src/MongoDB.Driver/Encryption/LibMongoCryptControllerBase.cs
@@ -0,0 +1,240 @@
+/* Copyright 2019-present MongoDB Inc.
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+using System;
+using System.Collections.Generic;
+using System.Net.Security;
+using System.Net.Sockets;
+using System.Threading;
+using System.Threading.Tasks;
+using MongoDB.Bson;
+using MongoDB.Bson.IO;
+using MongoDB.Libmongocrypt;
+
+namespace MongoDB.Driver.Encryption
+{
+ internal abstract class LibMongoCryptControllerBase
+ {
+ // protected fields
+ protected readonly CryptClient _cryptClient;
+ protected readonly IMongoClient _keyVaultClient;
+ protected readonly Lazy> _keyVaultCollection;
+ protected readonly CollectionNamespace _keyVaultNamespace;
+
+ // constructors
+ protected LibMongoCryptControllerBase(
+ CryptClient cryptClient,
+ IMongoClient keyVaultClient,
+ CollectionNamespace keyVaultNamespace)
+ {
+ _cryptClient = cryptClient;
+ _keyVaultClient = keyVaultClient; // _keyVaultClient might not be fully constructed at this point, don't call any instance methods on it yet
+ _keyVaultNamespace = keyVaultNamespace;
+ _keyVaultCollection = new Lazy>(GetKeyVaultCollection); // delay use _keyVaultClient
+ }
+
+ // protected methods
+ protected void FeedResult(CryptContext context, BsonDocument document)
+ {
+ var writerSettings = new BsonBinaryWriterSettings { GuidRepresentation = GuidRepresentation.Unspecified };
+ var documentBytes = document.ToBson(writerSettings: writerSettings);
+ context.Feed(documentBytes);
+ context.MarkDone();
+ }
+
+ protected void FeedResults(CryptContext context, IEnumerable documents)
+ {
+ var writerSettings = new BsonBinaryWriterSettings { GuidRepresentation = GuidRepresentation.Unspecified };
+ foreach (var document in documents)
+ {
+ var documentBytes = document.ToBson(writerSettings: writerSettings);
+ context.Feed(documentBytes);
+ }
+ context.MarkDone();
+ }
+
+ protected virtual void ProcessState(CryptContext context, string databaseName, CancellationToken cancellationToken)
+ {
+ switch (context.State)
+ {
+ case CryptContext.StateCode.MONGOCRYPT_CTX_NEED_KMS:
+ ProcessNeedKmsState(context, cancellationToken);
+ break;
+ case CryptContext.StateCode.MONGOCRYPT_CTX_NEED_MONGO_KEYS:
+ ProcessNeedMongoKeysState(context, cancellationToken);
+ break;
+ default:
+ throw new InvalidOperationException($"Unexpected context state: {context.State}.");
+ }
+ }
+
+ protected virtual async Task ProcessStateAsync(CryptContext context, string databaseName, CancellationToken cancellationToken)
+ {
+ switch (context.State)
+ {
+ case CryptContext.StateCode.MONGOCRYPT_CTX_NEED_KMS:
+ await ProcessNeedKmsStateAsync(context, cancellationToken).ConfigureAwait(false);
+ break;
+ case CryptContext.StateCode.MONGOCRYPT_CTX_NEED_MONGO_KEYS:
+ await ProcessNeedMongoKeysStateAsync(context, cancellationToken).ConfigureAwait(false);
+ break;
+ default:
+ throw new InvalidOperationException($"Unexpected context state: {context.State}.");
+ }
+ }
+
+ protected byte[] ProcessStates(CryptContext context, string databaseName, CancellationToken cancellationToken)
+ {
+ byte[] result = null;
+ while (context.State != CryptContext.StateCode.MONGOCRYPT_CTX_DONE)
+ {
+ if (context.State == CryptContext.StateCode.MONGOCRYPT_CTX_READY)
+ {
+ result = ProcessReadyState(context);
+ }
+ else
+ {
+ ProcessState(context, databaseName, cancellationToken);
+ }
+ }
+ return result;
+ }
+
+ protected async Task ProcessStatesAsync(CryptContext context, string databaseName, CancellationToken cancellationToken)
+ {
+ byte[] result = null;
+ while (context.State != CryptContext.StateCode.MONGOCRYPT_CTX_DONE)
+ {
+ if (context.State == CryptContext.StateCode.MONGOCRYPT_CTX_READY)
+ {
+ result = ProcessReadyState(context);
+ }
+ else
+ {
+ await ProcessStateAsync(context, databaseName, cancellationToken).ConfigureAwait(false);
+ }
+ }
+ return result;
+ }
+
+ // private methods
+ private IMongoCollection GetKeyVaultCollection()
+ {
+ var keyVaultDatabase = _keyVaultClient.GetDatabase(_keyVaultNamespace.DatabaseNamespace.DatabaseName);
+ return keyVaultDatabase.GetCollection(_keyVaultNamespace.CollectionName);
+ }
+
+ private void ProcessNeedKmsState(CryptContext context, CancellationToken cancellationToken)
+ {
+ var requests = context.GetKmsMessageRequests();
+ foreach (var request in requests)
+ {
+ SendKmsRequest(request, cancellationToken);
+ }
+ requests.MarkDone();
+ }
+
+ private async Task ProcessNeedKmsStateAsync(CryptContext context, CancellationToken cancellationToken)
+ {
+ var requests = context.GetKmsMessageRequests();
+ foreach (var request in requests)
+ {
+ await SendKmsRequestAsync(request, cancellationToken).ConfigureAwait(false);
+ }
+ requests.MarkDone();
+ }
+
+ private void ProcessNeedMongoKeysState(CryptContext context, CancellationToken cancellationToken)
+ {
+ var filterBytes = context.GetOperation().ToArray();
+ var filterDocument = new RawBsonDocument(filterBytes);
+ var filter = new BsonDocumentFilterDefinition(filterDocument);
+ var cursor = _keyVaultCollection.Value.FindSync(filter, cancellationToken: cancellationToken);
+ var results = cursor.ToList(cancellationToken);
+ FeedResults(context, results);
+ }
+
+ private async Task ProcessNeedMongoKeysStateAsync(CryptContext context, CancellationToken cancellationToken)
+ {
+ var filterBytes = context.GetOperation().ToArray();
+ var filterDocument = new RawBsonDocument(filterBytes);
+ var filter = new BsonDocumentFilterDefinition(filterDocument);
+ var cursor = await _keyVaultCollection.Value.FindAsync(filter, cancellationToken: cancellationToken).ConfigureAwait(false);
+ var results = await cursor.ToListAsync(cancellationToken).ConfigureAwait(false);
+ FeedResults(context, results);
+ }
+
+ private byte[] ProcessReadyState(CryptContext context)
+ {
+ return context.FinalizeForEncryption().ToArray();
+ }
+
+ private void SendKmsRequest(KmsRequest request, CancellationToken cancellation)
+ {
+ var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
+ socket.Connect(request.Endpoint, 443);
+
+ using (var networkStream = new NetworkStream(socket, ownsSocket: true))
+ using (var sslStream = new SslStream(networkStream, leaveInnerStreamOpen: false))
+ {
+#if NETSTANDARD1_5
+ sslStream.AuthenticateAsClientAsync(request.Endpoint).ConfigureAwait(false).GetAwaiter().GetResult();
+#else
+ sslStream.AuthenticateAsClient(request.Endpoint);
+#endif
+
+ var requestBytes = request.Message.ToArray();
+ sslStream.Write(requestBytes);
+
+ var buffer = new byte[4096];
+ while (request.BytesNeeded > 0)
+ {
+ var count = sslStream.Read(buffer, 0, buffer.Length);
+ var responseBytes = new byte[count];
+ Buffer.BlockCopy(buffer, 0, responseBytes, 0, count);
+ request.Feed(responseBytes);
+ }
+ }
+ }
+
+ private async Task SendKmsRequestAsync(KmsRequest request, CancellationToken cancellation)
+ {
+ var socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
+#if NETSTANDARD1_5
+ await socket.ConnectAsync(request.Endpoint, 443).ConfigureAwait(false);
+#else
+ await Task.Factory.FromAsync(socket.BeginConnect(request.Endpoint, 443, null, null), socket.EndConnect).ConfigureAwait(false);
+#endif
+
+ using (var networkStream = new NetworkStream(socket, ownsSocket: true))
+ using (var sslStream = new SslStream(networkStream, leaveInnerStreamOpen: false))
+ {
+ await sslStream.AuthenticateAsClientAsync(request.Endpoint).ConfigureAwait(false);
+
+ var requestBytes = request.Message.ToArray();
+ await sslStream.WriteAsync(requestBytes, 0, requestBytes.Length).ConfigureAwait(false);
+
+ var buffer = new byte[4096];
+ while (request.BytesNeeded > 0)
+ {
+ var count = await sslStream.ReadAsync(buffer, 0, buffer.Length).ConfigureAwait(false);
+ var responseBytes = new byte[count];
+ Buffer.BlockCopy(buffer, 0, responseBytes, 0, count);
+ request.Feed(responseBytes);
+ }
+ }
+ }
+ }
+}
diff --git a/src/MongoDB.Driver/Encryption/MongoEncryptionException.cs b/src/MongoDB.Driver/Encryption/MongoEncryptionException.cs
new file mode 100644
index 00000000000..3c54bfd4ff0
--- /dev/null
+++ b/src/MongoDB.Driver/Encryption/MongoEncryptionException.cs
@@ -0,0 +1,41 @@
+/* Copyright 2019-present MongoDB Inc.
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+#if NET452
+using System.Runtime.Serialization;
+#endif
+
+using System;
+
+namespace MongoDB.Driver.Encryption
+{
+ ///
+ /// Represents an encryption exception.
+ ///
+#if NET452
+ [Serializable]
+#endif
+ public class MongoEncryptionException : MongoClientException
+ {
+ ///
+ /// Initializes a new instance of the class.
+ ///
+ /// The inner exception.
+ public MongoEncryptionException(Exception innerException)
+ : base($"Encryption related exception: {innerException.Message}.", innerException)
+ {
+ }
+ }
+}
diff --git a/src/MongoDB.Driver/Encryption/MongocryptdFactory.cs b/src/MongoDB.Driver/Encryption/MongocryptdFactory.cs
new file mode 100644
index 00000000000..965dc239b4a
--- /dev/null
+++ b/src/MongoDB.Driver/Encryption/MongocryptdFactory.cs
@@ -0,0 +1,184 @@
+/* Copyright 2019-present MongoDB Inc.
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+using System;
+using System.Collections;
+using System.Collections.Generic;
+using System.Diagnostics;
+using System.IO;
+using System.Linq;
+using System.Reflection;
+using MongoDB.Driver.Core.Misc;
+
+namespace MongoDB.Driver.Encryption
+{
+ internal static class EncryptionExtraOptionsValidator
+ {
+ #region static
+ private static readonly Dictionary __supportedExtraOptions = new Dictionary
+ {
+ { "mongocryptdURI", new [] { typeof(string) } },
+ { "mongocryptdBypassSpawn", new [] { typeof(bool) } },
+ { "mongocryptdSpawnPath", new [] { typeof(string) } },
+ { "mongocryptdSpawnArgs", new [] { typeof(string), typeof(IEnumerable) } }
+ };
+ #endregion
+
+ public static void EnsureThatExtraOptionsAreValid(IReadOnlyDictionary extraOptions)
+ {
+ if (extraOptions == null)
+ {
+ return;
+ }
+
+ foreach (var extraOption in extraOptions)
+ {
+ if (__supportedExtraOptions.TryGetValue(extraOption.Key, out var validTypes))
+ {
+ var extraOptionValue = Ensure.IsNotNull(extraOption.Value, nameof(extraOptions));
+ var extraOptionValueType = extraOptionValue.GetType();
+ var isExtraOptionValueTypeValid = validTypes.Any(t => t.GetTypeInfo().IsAssignableFrom(extraOptionValueType));
+ if (!isExtraOptionValueTypeValid)
+ {
+ throw new ArgumentException($"Extra option {extraOption.Key} has invalid type: {extraOptionValueType}.", nameof(extraOptions));
+ }
+ }
+ else
+ {
+ throw new ArgumentException($"Invalid extra option key: {extraOption.Key}.", nameof(extraOptions));
+ }
+ }
+ }
+ }
+
+ internal class MongocryptdFactory
+ {
+ private readonly IReadOnlyDictionary _extraOptions;
+
+ public MongocryptdFactory(IReadOnlyDictionary extraOptions)
+ {
+ _extraOptions = extraOptions ?? new Dictionary();
+ }
+
+ // public methods
+ public IMongoClient CreateMongocryptdClient()
+ {
+ var connectionString = CreateMongocryptdConnectionString();
+ var clientSettings = MongoClientSettings.FromConnectionString(connectionString);
+ clientSettings.ServerSelectionTimeout = TimeSpan.FromMilliseconds(1000);
+ return new MongoClient(clientSettings);
+ }
+
+ public void SpawnMongocryptdProcessIfRequired()
+ {
+ if (ShouldMongocryptdBeSpawned(out var path, out var args))
+ {
+ StartProcess(path, args);
+ }
+ }
+
+ // private methods
+ private string CreateMongocryptdConnectionString()
+ {
+ if (_extraOptions.TryGetValue("mongocryptdURI", out var connectionString))
+ {
+ return (string)connectionString;
+ }
+ else
+ {
+ return "mongodb://localhost:27020";
+ }
+ }
+
+ private bool ShouldMongocryptdBeSpawned(out string path, out string args)
+ {
+ path = null;
+ args = null;
+ if (!_extraOptions.TryGetValue("mongocryptdBypassSpawn", out var mongoCryptBypassSpawn)
+ || !(bool)mongoCryptBypassSpawn)
+ {
+ if (_extraOptions.TryGetValue("mongocryptdSpawnPath", out var objPath))
+ {
+ path = (string)objPath;
+ }
+ else
+ {
+ path = string.Empty; // look at the PATH env variable
+ }
+
+ if (!Path.HasExtension(path))
+ {
+ string fileName = "mongocryptd.exe";
+ path = Path.Combine(path, fileName);
+ }
+
+ args = string.Empty;
+ if (_extraOptions.TryGetValue("mongocryptdSpawnArgs", out var mongocryptdSpawnArgs))
+ {
+ string trimStartHyphens(string str) => str.TrimStart('-').TrimStart('-');
+ switch (mongocryptdSpawnArgs)
+ {
+ case string str:
+ args += str;
+ break;
+ case IEnumerable enumerable:
+ foreach (var item in enumerable)
+ {
+ args += $"--{trimStartHyphens(item.ToString())} ";
+ }
+ break;
+ default:
+ throw new InvalidCastException($"Invalid type: {mongocryptdSpawnArgs.GetType().Name} of mongocryptdSpawnArgs option.");
+ }
+ }
+
+ args = args.Trim();
+ if (!args.Contains("idleShutdownTimeoutSecs"))
+ {
+ args += " --idleShutdownTimeoutSecs 60";
+ }
+ args = args.Trim();
+
+ return true;
+ }
+
+ return false;
+ }
+
+ private void StartProcess(string path, string args)
+ {
+ try
+ {
+ using (var process = new Process())
+ {
+ process.StartInfo.Arguments = args;
+ process.StartInfo.FileName = path;
+ process.StartInfo.CreateNoWindow = true;
+ process.StartInfo.UseShellExecute = false;
+
+ if (!process.Start())
+ {
+ // skip it. This case can happen if no new process resource is started
+ // (for example, if an existing process is reused)
+ }
+ }
+ }
+ catch (Exception ex)
+ {
+ throw new MongoClientException("Exception starting mongocryptd process. Is mongocryptd on the system path?", ex);
+ }
+ }
+ }
+}
diff --git a/src/MongoDB.Driver/Encryption/NoopBinaryDocumentFieldCryptor.cs b/src/MongoDB.Driver/Encryption/NoopBinaryDocumentFieldCryptor.cs
new file mode 100644
index 00000000000..4b080ae81e8
--- /dev/null
+++ b/src/MongoDB.Driver/Encryption/NoopBinaryDocumentFieldCryptor.cs
@@ -0,0 +1,44 @@
+/* Copyright 2019-present MongoDB Inc.
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+using System.Threading;
+using System.Threading.Tasks;
+using MongoDB.Driver.Core.WireProtocol;
+
+namespace MongoDB.Driver.Encryption
+{
+ internal class NoopBinaryDocumentFieldCryptor : IBinaryDocumentFieldDecryptor, IBinaryCommandFieldEncryptor
+ {
+ public byte[] DecryptFields(byte[] encryptedDocumentBytes, CancellationToken cancellationToken)
+ {
+ return encryptedDocumentBytes;
+ }
+
+ public Task DecryptFieldsAsync(byte[] encryptedDocumentBytes, CancellationToken cancellationToken)
+ {
+ return Task.FromResult(encryptedDocumentBytes);
+ }
+
+ public byte[] EncryptFields(string databaseName, byte[] unencryptedCommandBytes, CancellationToken cancellationToken)
+ {
+ return unencryptedCommandBytes;
+ }
+
+ public Task EncryptFieldsAsync(string databaseName, byte[] unencryptedCommandBytes, CancellationToken cancellationToken)
+ {
+ return Task.FromResult(unencryptedCommandBytes);
+ }
+ }
+}
diff --git a/src/MongoDB.Driver/IReadOnlyDictionaryExtensions.cs b/src/MongoDB.Driver/IReadOnlyDictionaryExtensions.cs
new file mode 100644
index 00000000000..98894290498
--- /dev/null
+++ b/src/MongoDB.Driver/IReadOnlyDictionaryExtensions.cs
@@ -0,0 +1,53 @@
+/* Copyright 2019-present MongoDB Inc.
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+using System;
+using System.Collections.Generic;
+
+namespace MongoDB.Driver
+{
+ internal static class IReadOnlyDictionaryExtensions
+ {
+ public static bool IsEquivalentTo(this IReadOnlyDictionary x, IReadOnlyDictionary y, Func equals)
+ {
+ if (object.ReferenceEquals(x, y))
+ {
+ return true;
+ }
+
+ if (object.ReferenceEquals(x, null) || object.ReferenceEquals(y, null))
+ {
+ return false;
+ }
+
+ if (x.Count != y.Count)
+ {
+ return false;
+ }
+
+ foreach (var keyValuePair in x)
+ {
+ var key = keyValuePair.Key;
+ var xValue = keyValuePair.Value;
+ if (!y.TryGetValue(key, out var yValue) || !equals(xValue, yValue))
+ {
+ return false;
+ }
+ }
+
+ return true;
+ }
+ }
+}
diff --git a/src/MongoDB.Driver/MongoClient.cs b/src/MongoDB.Driver/MongoClient.cs
index 0173b448078..a6f55d98617 100644
--- a/src/MongoDB.Driver/MongoClient.cs
+++ b/src/MongoDB.Driver/MongoClient.cs
@@ -29,6 +29,7 @@
using MongoDB.Driver.Core.Operations;
using MongoDB.Driver.Core.Servers;
using MongoDB.Driver.Core.WireProtocol.Messages.Encoders;
+using MongoDB.Driver.Encryption;
namespace MongoDB.Driver
{
@@ -53,6 +54,7 @@ private static IEnumerable SelectServersThatDetermineWhetherS
// private fields
private readonly ICluster _cluster;
+ private readonly AutoEncryptionLibMongoCryptController _libMongoCryptController;
private readonly IOperationExecutor _operationExecutor;
private readonly MongoClientSettings _settings;
@@ -78,6 +80,13 @@ public MongoClient(MongoClientSettings settings)
_settings = Ensure.IsNotNull(settings, nameof(settings)).FrozenCopy();
_cluster = ClusterRegistry.Instance.GetOrCreateCluster(_settings.ToClusterKey());
_operationExecutor = new OperationExecutor(this);
+ if (settings.AutoEncryptionOptions != null)
+ {
+ _libMongoCryptController = new AutoEncryptionLibMongoCryptController(
+ this,
+ _cluster.CryptClient,
+ settings.AutoEncryptionOptions);
+ }
}
///
@@ -128,10 +137,26 @@ public sealed override MongoClientSettings Settings
}
// internal properties
+ internal AutoEncryptionLibMongoCryptController LibMongoCryptController => _libMongoCryptController;
internal IOperationExecutor OperationExecutor => _operationExecutor;
+
+ // internal methods
+ internal void ConfigureAutoEncryptionMessageEncoderSettings(MessageEncoderSettings messageEncoderSettings)
+ {
+ var autoEncryptionOptions = _settings.AutoEncryptionOptions;
+ if (autoEncryptionOptions != null)
+ {
+ if (!autoEncryptionOptions.BypassAutoEncryption)
+ {
+ messageEncoderSettings.Add(MessageEncoderSettingsName.BinaryDocumentFieldEncryptor, _libMongoCryptController);
+ }
+ messageEncoderSettings.Add(MessageEncoderSettingsName.BinaryDocumentFieldDecryptor, _libMongoCryptController);
+ }
+ }
// private static methods
+
// public methods
///
public sealed override void DropDatabase(string name, CancellationToken cancellationToken = default(CancellationToken))
@@ -479,10 +504,10 @@ private IReadWriteBindingHandle CreateReadWriteBinding(IClientSessionHandle sess
ChangeStreamOptions options)
{
return ChangeStreamHelper.CreateChangeStreamOperation(
- pipeline,
- options,
- _settings.ReadConcern,
- GetMessageEncoderSettings(),
+ pipeline,
+ options,
+ _settings.ReadConcern,
+ GetMessageEncoderSettings(),
_settings.RetryReads);
}
@@ -520,12 +545,16 @@ private async Task ExecuteWriteOperationAsync(IClientSessionHa
private MessageEncoderSettings GetMessageEncoderSettings()
{
- return new MessageEncoderSettings
+ var messageEncoderSettings = new MessageEncoderSettings
{
{ MessageEncoderSettingsName.GuidRepresentation, _settings.GuidRepresentation },
{ MessageEncoderSettingsName.ReadEncoding, _settings.ReadEncoding ?? Utf8Encodings.Strict },
{ MessageEncoderSettingsName.WriteEncoding, _settings.WriteEncoding ?? Utf8Encodings.Strict }
};
+
+ ConfigureAutoEncryptionMessageEncoderSettings(messageEncoderSettings);
+
+ return messageEncoderSettings;
}
private IClientSessionHandle StartImplicitSession(bool areSessionsSupported)
diff --git a/src/MongoDB.Driver/MongoClientSettings.cs b/src/MongoDB.Driver/MongoClientSettings.cs
index 67be8219274..78e6f297810 100644
--- a/src/MongoDB.Driver/MongoClientSettings.cs
+++ b/src/MongoDB.Driver/MongoClientSettings.cs
@@ -21,6 +21,7 @@
using MongoDB.Bson;
using MongoDB.Driver.Core.Configuration;
using MongoDB.Driver.Core.Misc;
+using MongoDB.Driver.Encryption;
using MongoDB.Shared;
namespace MongoDB.Driver
@@ -33,6 +34,7 @@ public class MongoClientSettings : IEquatable, IInheritable
// private fields
private bool _allowInsecureTls;
private string _applicationName;
+ private AutoEncryptionOptions _autoEncryptionOptions;
private Action _clusterConfigurator;
private IReadOnlyList _compressors;
private ConnectionMode _connectionMode;
@@ -78,6 +80,7 @@ public MongoClientSettings()
{
_allowInsecureTls = false;
_applicationName = null;
+ _autoEncryptionOptions = null;
_compressors = new CompressorConfiguration[0];
_connectionMode = ConnectionMode.Automatic;
_connectTimeout = MongoDefaults.ConnectTimeout;
@@ -137,6 +140,19 @@ public string ApplicationName
}
}
+ ///
+ /// Gets or sets the auto encryption options.
+ ///
+ public AutoEncryptionOptions AutoEncryptionOptions
+ {
+ get { return _autoEncryptionOptions; }
+ set
+ {
+ if (_isFrozen) { throw new InvalidOperationException("MongoClientSettings is frozen."); }
+ _autoEncryptionOptions = value;
+ }
+ }
+
///
/// Gets or sets the compressors.
///
@@ -692,6 +708,7 @@ public static MongoClientSettings FromUrl(MongoUrl url)
var clientSettings = new MongoClientSettings();
clientSettings.AllowInsecureTls = url.AllowInsecureTls;
clientSettings.ApplicationName = url.ApplicationName;
+ clientSettings.AutoEncryptionOptions = null; // must be configured via code
clientSettings.Compressors = url.Compressors;
clientSettings.ConnectionMode = url.ConnectionMode;
clientSettings.ConnectTimeout = url.ConnectTimeout;
@@ -748,6 +765,7 @@ public MongoClientSettings Clone()
var clone = new MongoClientSettings();
clone._allowInsecureTls = _allowInsecureTls;
clone._applicationName = _applicationName;
+ clone._autoEncryptionOptions = _autoEncryptionOptions;
clone._compressors = _compressors;
clone._clusterConfigurator = _clusterConfigurator;
clone._connectionMode = _connectionMode;
@@ -808,6 +826,7 @@ public override bool Equals(object obj)
return
_allowInsecureTls == rhs._allowInsecureTls &&
_applicationName == rhs._applicationName &&
+ object.Equals(_autoEncryptionOptions, rhs._autoEncryptionOptions) &&
object.ReferenceEquals(_clusterConfigurator, rhs._clusterConfigurator) &&
_compressors.SequenceEqual(rhs._compressors) &&
_connectionMode == rhs._connectionMode &&
@@ -886,6 +905,7 @@ public override int GetHashCode()
return new Hasher()
.Hash(_allowInsecureTls)
.Hash(_applicationName)
+ .Hash(_autoEncryptionOptions)
.Hash(_clusterConfigurator)
.HashElements(_compressors)
.Hash(_connectionMode)
@@ -936,7 +956,10 @@ public override string ToString()
{
sb.AppendFormat("ApplicationName={0};", _applicationName);
}
-
+ if (_autoEncryptionOptions != null)
+ {
+ sb.AppendFormat("AutoEncryptionOptions={0};", _autoEncryptionOptions);
+ }
if (_compressors?.Any() ?? false)
{
sb.AppendFormat("Compressors=[{0}];", string.Join(",", _compressors));
@@ -1003,6 +1026,7 @@ internal ClusterKey ToClusterKey()
_heartbeatInterval,
_heartbeatTimeout,
_ipv6,
+ _autoEncryptionOptions?.KmsProviders,
_localThreshold,
_maxConnectionIdleTime,
_maxConnectionLifeTime,
@@ -1010,6 +1034,7 @@ internal ClusterKey ToClusterKey()
_minConnectionPoolSize,
MongoDefaults.TcpReceiveBufferSize, // TODO: add ReceiveBufferSize to MongoClientSettings?
_replicaSetName,
+ _autoEncryptionOptions?.SchemaMap,
_scheme,
_sdamLogFilename,
MongoDefaults.TcpSendBufferSize, // TODO: add SendBufferSize to MongoClientSettings?
diff --git a/src/MongoDB.Driver/MongoCollectionImpl.cs b/src/MongoDB.Driver/MongoCollectionImpl.cs
index c7bbcff4866..50befe8b82c 100644
--- a/src/MongoDB.Driver/MongoCollectionImpl.cs
+++ b/src/MongoDB.Driver/MongoCollectionImpl.cs
@@ -55,12 +55,7 @@ private MongoCollectionImpl(IMongoDatabase database, CollectionNamespace collect
_operationExecutor = Ensure.IsNotNull(operationExecutor, nameof(operationExecutor));
_documentSerializer = Ensure.IsNotNull(documentSerializer, nameof(documentSerializer));
- _messageEncoderSettings = new MessageEncoderSettings
- {
- { MessageEncoderSettingsName.GuidRepresentation, _settings.GuidRepresentation },
- { MessageEncoderSettingsName.ReadEncoding, _settings.ReadEncoding ?? Utf8Encodings.Strict },
- { MessageEncoderSettingsName.WriteEncoding, _settings.WriteEncoding ?? Utf8Encodings.Strict }
- };
+ _messageEncoderSettings = GetMessageEncoderSettings();
}
// properties
@@ -1071,6 +1066,23 @@ private IWriteBindingHandle CreateReadWriteBinding(IClientSessionHandle session)
return new ReadWriteBindingHandle(binding);
}
+ private MessageEncoderSettings GetMessageEncoderSettings()
+ {
+ var messageEncoderSettings = new MessageEncoderSettings
+ {
+ { MessageEncoderSettingsName.GuidRepresentation, _settings.GuidRepresentation },
+ { MessageEncoderSettingsName.ReadEncoding, _settings.ReadEncoding ?? Utf8Encodings.Strict },
+ { MessageEncoderSettingsName.WriteEncoding, _settings.WriteEncoding ?? Utf8Encodings.Strict }
+ };
+
+ if (_database.Client is MongoClient mongoClient)
+ {
+ mongoClient.ConfigureAutoEncryptionMessageEncoderSettings(messageEncoderSettings);
+ }
+
+ return messageEncoderSettings;
+ }
+
private IBsonSerializer GetValueSerializerForDistinct(RenderedFieldDefinition renderedField, IBsonSerializerRegistry serializerRegistry)
{
if (renderedField.UnderlyingSerializer != null)
diff --git a/src/MongoDB.Driver/MongoDB.Driver.csproj b/src/MongoDB.Driver/MongoDB.Driver.csproj
index 68f875bf2a3..0aa56c0272f 100644
--- a/src/MongoDB.Driver/MongoDB.Driver.csproj
+++ b/src/MongoDB.Driver/MongoDB.Driver.csproj
@@ -42,6 +42,7 @@
+
diff --git a/src/MongoDB.Driver/MongoDatabaseImpl.cs b/src/MongoDB.Driver/MongoDatabaseImpl.cs
index d11d179a4e9..31cb40c61d8 100644
--- a/src/MongoDB.Driver/MongoDatabaseImpl.cs
+++ b/src/MongoDB.Driver/MongoDatabaseImpl.cs
@@ -454,11 +454,12 @@ public override IMongoDatabase WithWriteConcern(WriteConcern writeConcern)
// private methods
private AggregateOperation CreateAggregateOperation(RenderedPipelineDefinition renderedPipeline, AggregateOptions options)
{
+ var messageEncoderSettings = GetMessageEncoderSettings();
return new AggregateOperation(
_databaseNamespace,
renderedPipeline.Documents,
renderedPipeline.OutputSerializer,
- _messageEncoderSettings)
+ messageEncoderSettings)
{
AllowDiskUse = options.AllowDiskUse,
BatchSize = options.BatchSize,
@@ -508,7 +509,12 @@ private FindOperation CreateAggregateToCollectionFindOperation
throw new ArgumentException($"Unexpected stage name: {stageName}.");
}
- return new FindOperation(outputCollectionNamespace, resultSerializer, _messageEncoderSettings)
+ // because auto encryption is not supported for non-collection commands.
+ // So, an error will be thrown in the previous CreateAggregateToCollectionOperation step.
+ // However, since we've added encryption configuration for CreateAggregateToCollectionOperation operation,
+ // it's not superfluous to also add it here
+ var messageEncoderSettings = GetMessageEncoderSettings();
+ return new FindOperation(outputCollectionNamespace, resultSerializer, messageEncoderSettings)
{
BatchSize = options.BatchSize,
Collation = options.Collation,
@@ -520,10 +526,11 @@ private FindOperation CreateAggregateToCollectionFindOperation
private AggregateToCollectionOperation CreateAggregateToCollectionOperation(RenderedPipelineDefinition renderedPipeline, AggregateOptions options)
{
+ var messageEncoderSettings = GetMessageEncoderSettings();
return new AggregateToCollectionOperation(
_databaseNamespace,
renderedPipeline.Documents,
- _messageEncoderSettings)
+ messageEncoderSettings)
{
AllowDiskUse = options.AllowDiskUse,
BypassDocumentValidation = options.BypassDocumentValidation,
@@ -751,12 +758,19 @@ private async Task ExecuteWriteOperationAsync(IClientSessionHandle session
private MessageEncoderSettings GetMessageEncoderSettings()
{
- return new MessageEncoderSettings
+ var messageEncoderSettings = new MessageEncoderSettings
{
{ MessageEncoderSettingsName.GuidRepresentation, _settings.GuidRepresentation },
{ MessageEncoderSettingsName.ReadEncoding, _settings.ReadEncoding ?? Utf8Encodings.Strict },
{ MessageEncoderSettingsName.WriteEncoding, _settings.WriteEncoding ?? Utf8Encodings.Strict }
};
+
+ if (_client is MongoClient mongoClient)
+ {
+ mongoClient.ConfigureAutoEncryptionMessageEncoderSettings(messageEncoderSettings);
+ }
+
+ return messageEncoderSettings;
}
private void UsingImplicitSession(Action func, CancellationToken cancellationToken)
diff --git a/tests/MongoDB.Bson.TestHelpers/JsonDrivenTests/AspectAsserter.cs b/tests/MongoDB.Bson.TestHelpers/JsonDrivenTests/AspectAsserter.cs
index f0dcab411f1..559ce62bac2 100644
--- a/tests/MongoDB.Bson.TestHelpers/JsonDrivenTests/AspectAsserter.cs
+++ b/tests/MongoDB.Bson.TestHelpers/JsonDrivenTests/AspectAsserter.cs
@@ -13,6 +13,7 @@
* limitations under the License.
*/
+using System.Collections.Generic;
using FluentAssertions;
namespace MongoDB.Bson.TestHelpers.JsonDrivenTests
@@ -21,6 +22,11 @@ public abstract class AspectAsserter
{
// public methods
public abstract void AssertAspects(object actualValue, BsonDocument aspects);
+
+ public virtual void ConfigurePlaceholders(KeyValuePair[] placeholders)
+ {
+ // do nothing by default
+ }
}
public abstract class AspectAsserter : AspectAsserter
diff --git a/tests/MongoDB.Bson.TestHelpers/JsonDrivenTests/EmbeddedResourceJsonFileReader.cs b/tests/MongoDB.Bson.TestHelpers/JsonDrivenTests/EmbeddedResourceJsonFileReader.cs
new file mode 100644
index 00000000000..0d36ba65877
--- /dev/null
+++ b/tests/MongoDB.Bson.TestHelpers/JsonDrivenTests/EmbeddedResourceJsonFileReader.cs
@@ -0,0 +1,75 @@
+/* Copyright 2019-present MongoDB Inc.
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using System.Reflection;
+using MongoDB.Bson.IO;
+using MongoDB.Bson.Serialization;
+using MongoDB.Bson.Serialization.Serializers;
+
+namespace MongoDB.Bson.TestHelpers.JsonDrivenTests
+{
+ public abstract class EmbeddedResourceJsonFileReader
+ {
+ // protected properties
+ protected virtual Assembly Assembly => this.GetType().GetTypeInfo().Assembly;
+
+ protected virtual string PathPrefix { get; } = null;
+ protected virtual string[] PathPrefixes { get; } = null;
+
+ protected virtual BsonDocument ReadJsonDocument(string path)
+ {
+ var jsonReaderSettings = new JsonReaderSettings { GuidRepresentation = GuidRepresentation.Unspecified };
+ using (var stream = Assembly.GetManifestResourceStream(path))
+ using (var streamReader = new StreamReader(stream))
+ using (var jsonReader = new JsonReader(streamReader, jsonReaderSettings))
+ {
+ var context = BsonDeserializationContext.CreateRoot(jsonReader);
+ var document = BsonDocumentSerializer.Instance.Deserialize(context);
+ document.InsertAt(0, new BsonElement("_path", path));
+ return document;
+ }
+ }
+
+ protected virtual IEnumerable ReadJsonDocuments()
+ {
+ return
+ Assembly.GetManifestResourceNames()
+ .Where(path => ShouldReadJsonDocument(path))
+ .Select(path => ReadJsonDocument(path));
+ }
+
+ protected virtual bool ShouldReadJsonDocument(string path)
+ {
+ var prefixes = GetPathPrefixes();
+ return prefixes.Any(path.StartsWith) && path.EndsWith(".json");
+ }
+
+ private string[] GetPathPrefixes()
+ {
+ var prefixes = !string.IsNullOrEmpty(PathPrefix) ? new[] { PathPrefix } : PathPrefixes;
+
+ if (prefixes == null || prefixes.Length == 0)
+ {
+ throw new NotImplementedException("At least one path prefix must be specified.");
+ }
+
+ return prefixes;
+ }
+ }
+}
diff --git a/tests/MongoDB.Bson.TestHelpers/JsonDrivenTests/JsonDrivenTest.cs b/tests/MongoDB.Bson.TestHelpers/JsonDrivenTests/JsonDrivenTest.cs
index 2fbb0a03caa..a3c1352d93a 100644
--- a/tests/MongoDB.Bson.TestHelpers/JsonDrivenTests/JsonDrivenTest.cs
+++ b/tests/MongoDB.Bson.TestHelpers/JsonDrivenTests/JsonDrivenTest.cs
@@ -71,9 +71,9 @@ public virtual void Arrange(BsonDocument document)
_expectedException = new BsonDocument(); // any exception will do
}
- if (document.Contains("result"))
+ if (document.TryGetValue("result", out var result) || document.TryGetValue("results", out result))
{
- ParseExpectedResult(document["result"]);
+ ParseExpectedResult(result);
}
}
diff --git a/tests/MongoDB.Bson.TestHelpers/JsonDrivenTests/JsonDrivenTestCase.cs b/tests/MongoDB.Bson.TestHelpers/JsonDrivenTests/JsonDrivenTestCase.cs
index 667d3540079..705cc74c776 100644
--- a/tests/MongoDB.Bson.TestHelpers/JsonDrivenTests/JsonDrivenTestCase.cs
+++ b/tests/MongoDB.Bson.TestHelpers/JsonDrivenTests/JsonDrivenTestCase.cs
@@ -13,6 +13,10 @@
* limitations under the License.
*/
+using MongoDB.Bson.IO;
+using MongoDB.Bson.Serialization;
+using MongoDB.Bson.Serialization.Serializers;
+using System.IO;
using Xunit.Abstractions;
namespace MongoDB.Bson.TestHelpers.JsonDrivenTests
@@ -47,20 +51,43 @@ public JsonDrivenTestCase(string name, BsonDocument shared, BsonDocument test)
public void Deserialize(IXunitSerializationInfo info)
{
_name = info.GetValue(nameof(_name));
- _shared = BsonDocument.Parse(info.GetValue(nameof(_shared)));
- _test = BsonDocument.Parse(info.GetValue(nameof(_test)));
+ _shared = DeserializeBsonDocument(info.GetValue(nameof(_shared)));
+ _test = DeserializeBsonDocument(info.GetValue(nameof(_test)));
}
public void Serialize(IXunitSerializationInfo info)
{
info.AddValue(nameof(_name), _name);
- info.AddValue(nameof(_shared), _shared.ToJson());
- info.AddValue(nameof(_test), _test.ToJson());
+ info.AddValue(nameof(_shared), SerializeBsonDocument(_shared));
+ info.AddValue(nameof(_test), SerializeBsonDocument(_test));
}
public override string ToString()
{
return _name;
}
+
+ // private methods
+ private BsonDocument DeserializeBsonDocument(string value)
+ {
+ var jsonReaderSettings = new JsonReaderSettings { GuidRepresentation = GuidRepresentation.Unspecified };
+ using (var jsonReader = new JsonReader(value, jsonReaderSettings))
+ {
+ var context = BsonDeserializationContext.CreateRoot(jsonReader);
+ return BsonDocumentSerializer.Instance.Deserialize(context);
+ }
+ }
+
+ private string SerializeBsonDocument(BsonDocument value)
+ {
+ var jsonWriterSettings = new JsonWriterSettings { GuidRepresentation = GuidRepresentation.Unspecified };
+ using (var stringWriter = new StringWriter())
+ using (var jsonWriter = new JsonWriter(stringWriter, jsonWriterSettings))
+ {
+ var context = BsonSerializationContext.CreateRoot(jsonWriter);
+ BsonDocumentSerializer.Instance.Serialize(context, value);
+ return stringWriter.ToString();
+ }
+ }
}
}
diff --git a/tests/MongoDB.Bson.TestHelpers/JsonDrivenTests/JsonDrivenTestCaseFactory.cs b/tests/MongoDB.Bson.TestHelpers/JsonDrivenTests/JsonDrivenTestCaseFactory.cs
index 65937a7c516..d67f5eb8c5c 100644
--- a/tests/MongoDB.Bson.TestHelpers/JsonDrivenTests/JsonDrivenTestCaseFactory.cs
+++ b/tests/MongoDB.Bson.TestHelpers/JsonDrivenTests/JsonDrivenTestCaseFactory.cs
@@ -16,22 +16,14 @@
using System;
using System.Collections;
using System.Collections.Generic;
-using System.IO;
using System.Linq;
-using System.Reflection;
using Xunit.Abstractions;
namespace MongoDB.Bson.TestHelpers.JsonDrivenTests
{
- public abstract class JsonDrivenTestCaseFactory : IEnumerable