From 2d89ee440206354c59ea2c993b48e81f8c278cc0 Mon Sep 17 00:00:00 2001
From: LucStr <25279790+LucStr@users.noreply.github.com>
Date: Sun, 14 Jan 2024 18:18:16 +0100
Subject: [PATCH 1/2] Add support for root interface serialization

---
 .../Serialization/BsonSerializationArgs.cs    |  16 ++-
 .../InterfaceDiscriminatorConvention.cs       | 108 +++++++++++++++
 .../Serializers/BsonClassMapSerializer.cs     |   8 +-
 .../DiscriminatedInterfaceSerializer.cs       |  23 +--
 ...nterfaceHierarchyWithoutAttributesTests.cs | 131 ++++++++++++++++++
 5 files changed, 264 insertions(+), 22 deletions(-)
 create mode 100644 src/MongoDB.Bson/Serialization/Conventions/InterfaceDiscriminatorConvention.cs
 create mode 100644 tests/MongoDB.Bson.Tests/Serialization/Serializers/AnimalInterfaceHierarchyWithoutAttributesTests.cs

diff --git a/src/MongoDB.Bson/Serialization/BsonSerializationArgs.cs b/src/MongoDB.Bson/Serialization/BsonSerializationArgs.cs
index b16d5a3e1b1..e02daf88828 100644
--- a/src/MongoDB.Bson/Serialization/BsonSerializationArgs.cs
+++ b/src/MongoDB.Bson/Serialization/BsonSerializationArgs.cs
@@ -14,6 +14,7 @@
 */
 
 using System;
+using MongoDB.Bson.Serialization.Conventions;
 
 namespace MongoDB.Bson.Serialization
 {
@@ -24,6 +25,7 @@ public struct BsonSerializationArgs
     {
         // private fields
         private Type _nominalType;
+        private IDiscriminatorConvention _discriminatorConvention;
         private bool _serializeAsNominalType;
         private bool _serializeIdFirst;
 
@@ -34,14 +36,17 @@ public struct BsonSerializationArgs
         /// <param name="nominalType">The nominal type.</param>
         /// <param name="serializeAsNominalType">Whether to serialize as the nominal type.</param>
         /// <param name="serializeIdFirst">Whether to serialize the id first.</param>
+        /// <param name="discriminatorConvention">The discriminator convention to be used.</param>
         public BsonSerializationArgs(
             Type nominalType,
             bool serializeAsNominalType,
-            bool serializeIdFirst)
+            bool serializeIdFirst,
+            IDiscriminatorConvention discriminatorConvention)
         {
             _nominalType = nominalType;
             _serializeAsNominalType = serializeAsNominalType;
             _serializeIdFirst = serializeIdFirst;
+            _discriminatorConvention = discriminatorConvention;
         }
 
         // public properties
@@ -57,6 +62,15 @@ public Type NominalType
             set { _nominalType = value; }
         }
 
+        /// <summary>
+        /// Gets or sets the discriminator convention.
+        /// </summary>
+        public IDiscriminatorConvention DiscriminatorConvention
+        {
+            get { return _discriminatorConvention; }
+            set { _discriminatorConvention = value; }
+        }
+
         /// <summary>
         /// Gets or sets a value indicating whether to serialize the value as if it were an instance of the nominal type.
         /// </summary>
diff --git a/src/MongoDB.Bson/Serialization/Conventions/InterfaceDiscriminatorConvention.cs b/src/MongoDB.Bson/Serialization/Conventions/InterfaceDiscriminatorConvention.cs
new file mode 100644
index 00000000000..e1f2aee5f37
--- /dev/null
+++ b/src/MongoDB.Bson/Serialization/Conventions/InterfaceDiscriminatorConvention.cs
@@ -0,0 +1,108 @@
+/* Copyright 2010-present MongoDB Inc.
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+
+namespace MongoDB.Bson.Serialization.Conventions
+{
+    /// <summary>
+    /// Represents a discriminator convention where the discriminator is an array of all the discriminators provided by the class maps of the root class down to the actual type.
+    /// </summary>
+    public class InterfaceDiscriminatorConvention<TInterface> : StandardDiscriminatorConvention
+    {
+        private readonly IDictionary<Type, BsonValue> _discriminators = new Dictionary<Type, BsonValue>();
+        // constructors
+        /// <summary>
+        /// Initializes a new instance of the HierarchicalDiscriminatorConvention class.
+        /// </summary>
+        /// <param name="elementName">The element name.</param>
+        public InterfaceDiscriminatorConvention(string elementName)
+            : base(elementName)
+        {
+            PrecomputeDiscriminators();
+        }
+
+        // public methods
+        /// <summary>
+        /// Gets the discriminator value for an actual type.
+        /// </summary>
+        /// <param name="nominalType">The nominal type.</param>
+        /// <param name="actualType">The actual type.</param>
+        /// <returns>The discriminator value.</returns>
+        public override BsonValue GetDiscriminator(Type nominalType, Type actualType)
+        {
+            if (nominalType != typeof(TInterface))
+            {
+                return null;
+            }
+
+            return _discriminators.TryGetValue(actualType, out var discriminator) ? discriminator : null;
+        }
+
+        private void PrecomputeDiscriminators()
+        {
+            var interfaceType = typeof(TInterface);
+
+            if (!interfaceType.IsInterface)
+            {
+                throw new ArgumentException("<TInterface> must be an interface", nameof(TInterface));
+            }
+
+            var dependents = interfaceType.Assembly.GetTypes().Where(x => interfaceType.IsAssignableFrom(x) && x.IsClass);
+
+            foreach (var dependent in dependents)
+            {
+                var interfaces = OrderInterfaces(dependent.GetInterfaces().ToList());
+                var discriminator = new BsonArray(interfaces.Select(x => x.Name))
+                {
+                    dependent.Name
+                };
+
+                _discriminators.Add(dependent, discriminator);
+            }
+        }
+
+        private IEnumerable<Type> OrderInterfaces(List<Type> interfaces)
+        {
+            var sorted = new List<Type>();
+            while (interfaces.Any())
+            {
+                var allParentInterfaces = interfaces.SelectMany(t => t.GetInterfaces()).ToList();
+
+                foreach (var interfaceType in interfaces)
+                {
+                    var newInterfaces = new List<Type>();
+
+                    if (allParentInterfaces.Contains(interfaceType))
+                    {
+                        newInterfaces.Add(interfaceType);
+                    }
+                    else
+                    {
+                        sorted.Add(interfaceType);
+                    }
+
+                    interfaces = newInterfaces;
+                }
+            }
+
+            sorted.Reverse();
+
+            return sorted;
+        }
+    }
+}
diff --git a/src/MongoDB.Bson/Serialization/Serializers/BsonClassMapSerializer.cs b/src/MongoDB.Bson/Serialization/Serializers/BsonClassMapSerializer.cs
index e2c7d6dfe6b..57af65f29b5 100644
--- a/src/MongoDB.Bson/Serialization/Serializers/BsonClassMapSerializer.cs
+++ b/src/MongoDB.Bson/Serialization/Serializers/BsonClassMapSerializer.cs
@@ -18,6 +18,7 @@
 using System.ComponentModel;
 using System.Reflection;
 using MongoDB.Bson.IO;
+using MongoDB.Bson.Serialization.Conventions;
 using MongoDB.Bson.Serialization.Serializers;
 
 namespace MongoDB.Bson.Serialization
@@ -567,7 +568,7 @@ private void SerializeClass(BsonSerializationContext context, BsonSerializationA
 
             if (ShouldSerializeDiscriminator(args.NominalType))
             {
-                SerializeDiscriminator(context, args.NominalType, document);
+                SerializeDiscriminator(context, args.NominalType, document, args.DiscriminatorConvention);
             }
 
             foreach (var memberMap in _classMap.AllMemberMaps)
@@ -611,9 +612,10 @@ private void SerializeExtraElements(BsonSerializationContext context, object obj
             }
         }
 
-        private void SerializeDiscriminator(BsonSerializationContext context, Type nominalType, object obj)
+        private void SerializeDiscriminator(BsonSerializationContext context, Type nominalType, object obj, IDiscriminatorConvention discriminatorConvention)
         {
-            var discriminatorConvention = _classMap.GetDiscriminatorConvention();
+            discriminatorConvention ??= _classMap.GetDiscriminatorConvention();
+
             if (discriminatorConvention != null)
             {
                 var actualType = obj.GetType();
diff --git a/src/MongoDB.Bson/Serialization/Serializers/DiscriminatedInterfaceSerializer.cs b/src/MongoDB.Bson/Serialization/Serializers/DiscriminatedInterfaceSerializer.cs
index 91ca46ef2d0..ff26505a4ce 100644
--- a/src/MongoDB.Bson/Serialization/Serializers/DiscriminatedInterfaceSerializer.cs
+++ b/src/MongoDB.Bson/Serialization/Serializers/DiscriminatedInterfaceSerializer.cs
@@ -56,7 +56,6 @@ private static IBsonSerializer<TInterface> CreateInterfaceSerializer()
         private readonly Type _interfaceType;
         private readonly IDiscriminatorConvention _discriminatorConvention;
         private readonly IBsonSerializer<TInterface> _interfaceSerializer;
-        private readonly IBsonSerializer<object> _objectSerializer;
 
         // constructors
         /// <summary>
@@ -96,18 +95,6 @@ public DiscriminatedInterfaceSerializer(IDiscriminatorConvention discriminatorCo
 
             _interfaceType = typeof(TInterface);
             _discriminatorConvention = discriminatorConvention ?? BsonSerializer.LookupDiscriminatorConvention(typeof(TInterface));
-            _objectSerializer = BsonSerializer.LookupSerializer<object>();
-            if (_objectSerializer is ObjectSerializer standardObjectSerializer)
-            {
-                _objectSerializer = standardObjectSerializer.WithDiscriminatorConvention(_discriminatorConvention);
-            }
-            else
-            {
-                if (discriminatorConvention != null)
-                {
-                    throw new BsonSerializationException("Can't set discriminator convention on custom object serializer.");
-                }
-            }
 
             _interfaceSerializer = interfaceSerializer;
         }
@@ -164,12 +151,12 @@ public override void Serialize(BsonSerializationContext context, BsonSerializati
             if (value == null)
             {
                 bsonWriter.WriteNull();
+                return;
             }
-            else
-            {
-                args.NominalType = typeof(object);
-                _objectSerializer.Serialize(context, args, value);
-            }
+
+            args.DiscriminatorConvention = _discriminatorConvention;
+            var serializer = BsonSerializer.LookupSerializer(value.GetType());
+            serializer.Serialize(context, args, value);
         }
 
         /// <inheritdoc/>
diff --git a/tests/MongoDB.Bson.Tests/Serialization/Serializers/AnimalInterfaceHierarchyWithoutAttributesTests.cs b/tests/MongoDB.Bson.Tests/Serialization/Serializers/AnimalInterfaceHierarchyWithoutAttributesTests.cs
new file mode 100644
index 00000000000..e2e077ffe99
--- /dev/null
+++ b/tests/MongoDB.Bson.Tests/Serialization/Serializers/AnimalInterfaceHierarchyWithoutAttributesTests.cs
@@ -0,0 +1,131 @@
+/* Copyright 2010-present MongoDB Inc.
+*
+* Licensed under the Apache License, Version 2.0 (the "License");
+* you may not use this file except in compliance with the License.
+* You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+using System.Linq;
+using MongoDB.Bson;
+using MongoDB.Bson.Serialization;
+using MongoDB.Bson.Serialization.Conventions;
+using MongoDB.Bson.Serialization.Options;
+using MongoDB.Bson.Serialization.Serializers;
+using Xunit;
+
+namespace MongoDB.Bson.Tests.Serialization
+{
+    public class AnimalInterfaceHierarchyWithoutAttributesTests
+    {
+        public interface IAnimal
+        {
+            public ObjectId Id { get; set; }
+            public int Age { get; set; }
+            public string Name { get; set; }
+        }
+
+        public class Bear : IAnimal
+        {
+            public ObjectId Id { get; set; }
+            public int Age { get; set; }
+            public string Name { get; set; }
+        }
+
+        public interface  ICat : IAnimal
+        {
+        }
+
+        public class Tiger : ICat
+        {
+            public ObjectId Id { get; set; }
+            public int Age { get; set; }
+            public string Name { get; set; }
+        }
+
+        public class Lion : ICat
+        {
+            public ObjectId Id { get; set; }
+            public int Age { get; set; }
+            public string Name { get; set; }
+        }
+
+        static AnimalInterfaceHierarchyWithoutAttributesTests()
+        {
+            BsonSerializer.RegisterSerializer(new DiscriminatedInterfaceSerializer<IAnimal>(new InterfaceDiscriminatorConvention<IAnimal>("_t")));
+            BsonClassMap.RegisterClassMap<Bear>();
+            BsonClassMap.RegisterClassMap<Tiger>();
+            BsonClassMap.RegisterClassMap<Lion>();
+        }
+
+        [Fact]
+        public void TestDeserializeBear()
+        {
+            var document = new BsonDocument
+            {
+                { "_id", ObjectId.Empty },
+                { "_t", new BsonArray { "IAnimal", "Bear" } },
+                { "Age", 123 },
+                { "Name", "Panda Bear" }
+            };
+
+            var bson = document.ToBson();
+            var rehydrated = (Bear)BsonSerializer.Deserialize<IAnimal>(bson);
+            Assert.IsType<Bear>(rehydrated);
+
+            var json = rehydrated.ToJson<IAnimal>(args: new BsonSerializationArgs { SerializeIdFirst = true });
+            var expected = "{ '_id' : ObjectId('000000000000000000000000'), '_t' : ['IAnimal', 'Bear'], 'Age' : 123, 'Name' : 'Panda Bear' }".Replace("'", "\"");
+            Assert.Equal(expected, json);
+            Assert.True(bson.SequenceEqual(rehydrated.ToBson<IAnimal>(args: new BsonSerializationArgs { SerializeIdFirst = true })));
+        }
+
+        [Fact]
+        public void TestDeserializeTiger()
+        {
+            var document = new BsonDocument
+            {
+                { "_id", ObjectId.Empty },
+                { "_t", new BsonArray { "IAnimal", "ICat", "Tiger" } },
+                { "Age", 234 },
+                { "Name", "Striped Tiger" }
+            };
+
+            var bson = document.ToBson();
+            var rehydrated = (Tiger)BsonSerializer.Deserialize<IAnimal>(bson);
+            Assert.IsType<Tiger>(rehydrated);
+
+            var json = rehydrated.ToJson<IAnimal>(args: new BsonSerializationArgs { SerializeIdFirst = true });
+            var expected = "{ '_id' : ObjectId('000000000000000000000000'), '_t' : ['IAnimal', 'ICat', 'Tiger'], 'Age' : 234, 'Name' : 'Striped Tiger' }".Replace("'", "\"");
+            Assert.Equal(expected, json);
+            Assert.True(bson.SequenceEqual(rehydrated.ToBson<IAnimal>(args: new BsonSerializationArgs { SerializeIdFirst = true })));
+        }
+
+        [Fact]
+        public void TestDeserializeLion()
+        {
+            var document = new BsonDocument
+            {
+                { "_id", ObjectId.Empty },
+                { "_t", new BsonArray { "IAnimal", "ICat", "Lion" } },
+                { "Age", 234 },
+                { "Name", "King Lion" }
+            };
+
+            var bson = document.ToBson();
+            var rehydrated = (Lion)BsonSerializer.Deserialize<IAnimal>(bson);
+            Assert.IsType<Lion>(rehydrated);
+
+            var json = rehydrated.ToJson<IAnimal>(args: new BsonSerializationArgs { SerializeIdFirst = true });
+            var expected = "{ '_id' : ObjectId('000000000000000000000000'), '_t' : ['IAnimal', 'ICat', 'Lion'], 'Age' : 234, 'Name' : 'King Lion' }".Replace("'", "\"");
+            Assert.Equal(expected, json);
+            Assert.True(bson.SequenceEqual(rehydrated.ToBson<IAnimal>(args: new BsonSerializationArgs { SerializeIdFirst = true })));
+        }
+    }
+}

From 1e31c8ad2a4402e2363f60dd71a44554beb53e2e Mon Sep 17 00:00:00 2001
From: LucStr <25279790+LucStr@users.noreply.github.com>
Date: Sun, 14 Jan 2024 22:18:54 +0100
Subject: [PATCH 2/2] CSHARP-1907 Make interface discriminator convention
 default for interfaces.

---
 src/MongoDB.Bson/Serialization/BsonSerializer.cs         | 9 +++++++--
 .../Conventions/InterfaceDiscriminatorConvention.cs      | 2 +-
 2 files changed, 8 insertions(+), 3 deletions(-)

diff --git a/src/MongoDB.Bson/Serialization/BsonSerializer.cs b/src/MongoDB.Bson/Serialization/BsonSerializer.cs
index 5c027786d87..b0b7c9ac9f2 100644
--- a/src/MongoDB.Bson/Serialization/BsonSerializer.cs
+++ b/src/MongoDB.Bson/Serialization/BsonSerializer.cs
@@ -391,8 +391,7 @@ public static IDiscriminatorConvention LookupDiscriminatorConvention(Type type)
                     }
                     else if (typeInfo.IsInterface)
                     {
-                        // TODO: should convention for interfaces be inherited from parent interfaces?
-                        convention = LookupDiscriminatorConvention(typeof(object));
+                        convention = CreateInterfaceDiscriminatorConvention(type);
                         RegisterDiscriminatorConvention(type, convention);
                     }
                     else
@@ -432,6 +431,12 @@ public static IDiscriminatorConvention LookupDiscriminatorConvention(Type type)
             }
         }
 
+        private static IDiscriminatorConvention CreateInterfaceDiscriminatorConvention(Type type)
+        {
+            var discriminatorConventionType = typeof(InterfaceDiscriminatorConvention<>).MakeGenericType(type);
+            return (IDiscriminatorConvention) Activator.CreateInstance(discriminatorConventionType, "_t");
+        }
+
         /// <summary>
         /// Looks up an IdGenerator.
         /// </summary>
diff --git a/src/MongoDB.Bson/Serialization/Conventions/InterfaceDiscriminatorConvention.cs b/src/MongoDB.Bson/Serialization/Conventions/InterfaceDiscriminatorConvention.cs
index e1f2aee5f37..a3b2777bc21 100644
--- a/src/MongoDB.Bson/Serialization/Conventions/InterfaceDiscriminatorConvention.cs
+++ b/src/MongoDB.Bson/Serialization/Conventions/InterfaceDiscriminatorConvention.cs
@@ -62,7 +62,7 @@ private void PrecomputeDiscriminators()
                 throw new ArgumentException("<TInterface> must be an interface", nameof(TInterface));
             }
 
-            var dependents = interfaceType.Assembly.GetTypes().Where(x => interfaceType.IsAssignableFrom(x) && x.IsClass);
+            var dependents = interfaceType.Assembly.GetTypes().Where(x => interfaceType.IsAssignableFrom(x));
 
             foreach (var dependent in dependents)
             {