Permalink
Browse files

Allow navigation properties to be typed as IEnumerable so long as bac…

…ked by ICollection

Issues #5771 and #1481
  • Loading branch information...
ajcvickers committed Jun 13, 2016
1 parent b7e0ee2 commit 78476005b30ff201d4523d24da090ce6c84fb248

Large diffs are not rendered by default.

Oops, something went wrong.
@@ -2189,7 +2189,7 @@ public static class EntityFrameworkQueryableExtensions
/// type of entity being queried (<typeparamref name="TEntity" />). If you wish to include additional types based on the navigation
/// properties of the type being included, then chain a call to
/// <see
/// cref="ThenInclude{TEntity, TPreviousProperty, TProperty}(IIncludableQueryable{TEntity, ICollection{TPreviousProperty}}, Expression{Func{TPreviousProperty, TProperty}})" />
/// cref="ThenInclude{TEntity, TPreviousProperty, TProperty}(IIncludableQueryable{TEntity, IEnumerable{TPreviousProperty}}, Expression{Func{TPreviousProperty, TProperty}})" />
/// after this call.
/// </summary>
/// <example>
@@ -2289,7 +2289,7 @@ public static class EntityFrameworkQueryableExtensions
/// A new query with the related data included.
/// </returns>
public static IIncludableQueryable<TEntity, TProperty> ThenInclude<TEntity, TPreviousProperty, TProperty>(
[NotNull] this IIncludableQueryable<TEntity, ICollection<TPreviousProperty>> source,
[NotNull] this IIncludableQueryable<TEntity, IEnumerable<TPreviousProperty>> source,
[NotNull] Expression<Func<TPreviousProperty, TProperty>> navigationPropertyPath)
where TEntity : class
=> new IncludableQueryable<TEntity, TProperty>(
@@ -38,26 +38,26 @@ public virtual IClrCollectionAccessor Create([NotNull] INavigation navigation)
}
var property = navigation.GetPropertyInfo();
var elementType = property.PropertyType.TryGetElementType(typeof(ICollection<>));
var elementType = property.PropertyType.TryGetElementType(typeof(IEnumerable<>));
// TODO: Only ICollections supported; add support for enumerables with add/remove methods
// Issue #752
if (elementType == null)
{
throw new InvalidOperationException(
CoreStrings.NavigationBadType(
navigation.Name, navigation.DeclaringEntityType.Name, property.PropertyType.FullName, navigation.GetTargetType().Name));
navigation.Name, navigation.DeclaringEntityType.Name, property.PropertyType.Name, navigation.GetTargetType().Name));
}
if (property.PropertyType.IsArray)
{
throw new InvalidOperationException(
CoreStrings.NavigationArray(navigation.Name, navigation.DeclaringEntityType.Name, property.PropertyType.FullName));
CoreStrings.NavigationArray(navigation.Name, navigation.DeclaringEntityType.DisplayName(), property.PropertyType.Name));
}
if (property.GetMethod == null)
{
throw new InvalidOperationException(CoreStrings.NavigationNoGetter(navigation.Name, navigation.DeclaringEntityType.Name));
throw new InvalidOperationException(CoreStrings.NavigationNoGetter(navigation.Name, navigation.DeclaringEntityType.DisplayName()));
}
var boundMethod = _genericCreate.MakeGenericMethod(
@@ -69,7 +69,7 @@ public virtual IClrCollectionAccessor Create([NotNull] INavigation navigation)
[UsedImplicitly]
private static IClrCollectionAccessor CreateGeneric<TEntity, TCollection, TElement>(PropertyInfo property)
where TEntity : class
where TCollection : class, ICollection<TElement>
where TCollection : class, IEnumerable<TElement>
{
var getterDelegate = (Func<TEntity, TCollection>)property.GetMethod.CreateDelegate(typeof(Func<TEntity, TCollection>));
@@ -14,7 +14,7 @@ namespace Microsoft.EntityFrameworkCore.Metadata.Internal
/// </summary>
public class ClrICollectionAccessor<TEntity, TCollection, TElement> : IClrCollectionAccessor
where TEntity : class
where TCollection : class, ICollection<TElement>
where TCollection : class, IEnumerable<TElement>
{
private readonly string _propertyName;
private readonly Func<TEntity, TCollection> _getCollection;
@@ -89,10 +89,10 @@ public virtual object Create(IEnumerable<object> values)
if (_createCollection == null)
{
throw new InvalidOperationException(CoreStrings.NavigationCannotCreateType(
_propertyName, typeof(TEntity).FullName, typeof(TCollection).FullName));
_propertyName, typeof(TEntity).Name, typeof(TCollection).Name));
}
var collection = _createCollection();
var collection = (ICollection<TElement>)_createCollection();
foreach (TElement value in values)
{
collection.Add(value);
@@ -107,25 +107,42 @@ public virtual object Create(IEnumerable<object> values)
/// </summary>
public virtual object GetOrCreate(object instance) => GetOrCreateCollection(instance);
private TCollection GetOrCreateCollection(object instance)
private ICollection<TElement> GetOrCreateCollection(object instance)
{
var collection = _getCollection((TEntity)instance);
var collection = GetCollection(instance);
if (collection == null)
{
if (_setCollection == null)
{
throw new InvalidOperationException(CoreStrings.NavigationNoSetter(_propertyName, typeof(TEntity).FullName));
throw new InvalidOperationException(CoreStrings.NavigationNoSetter(_propertyName, typeof(TEntity).Name));
}
if (_createAndSetCollection == null)
{
throw new InvalidOperationException(CoreStrings.NavigationCannotCreateType(
_propertyName, typeof(TEntity).FullName, typeof(TCollection).FullName));
_propertyName, typeof(TEntity).Name, typeof(TCollection).Name));
}
collection = _createAndSetCollection((TEntity)instance, _setCollection);
collection = (ICollection<TElement>)_createAndSetCollection((TEntity)instance, _setCollection);
}
return collection;
}
private ICollection<TElement> GetCollection(object instance)
{
var enumerable = _getCollection((TEntity)instance);
var collection = enumerable as ICollection<TElement>;
if (enumerable != null
&& collection == null)
{
throw new InvalidOperationException(
CoreStrings.NavigationBadType(
_propertyName, typeof(TEntity).Name, enumerable.GetType().Name, typeof(TElement).Name));
}
return collection;
}
@@ -135,7 +152,7 @@ private TCollection GetOrCreateCollection(object instance)
/// </summary>
public virtual bool Contains(object instance, object value)
{
var collection = _getCollection((TEntity)instance);
var collection = GetCollection((TEntity)instance);
return (collection != null) && collection.Contains((TElement)value);
}
@@ -145,6 +162,6 @@ public virtual bool Contains(object instance, object value)
/// directly from your code. This API may change or be removed in future releases.
/// </summary>
public virtual void Remove(object instance, object value)
=> _getCollection((TEntity)instance)?.Remove((TElement)value);
=> GetCollection((TEntity)instance)?.Remove((TElement)value);
}
}
@@ -597,7 +597,7 @@ public void Customer_collections_materialize_properly_3758()
var query4 = ctx.Customers.Select(c => c.Orders4);
Assert.Equal(CoreStrings.NavigationCannotCreateType("Orders4", typeof(Customer3758).FullName, typeof(MyInvalidCollection3758<Order3758>).FullName),
Assert.Equal(CoreStrings.NavigationCannotCreateType("Orders4", typeof(Customer3758).Name, typeof(MyInvalidCollection3758<Order3758>).Name),
Assert.Throws<InvalidOperationException>(() => query4.ToList()).Message);
}
}
@@ -2,6 +2,7 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.
using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
@@ -14,7 +15,6 @@
// ReSharper disable ConvertToAutoPropertyWhenPossible
// ReSharper disable ConvertToAutoPropertyWithPrivateSetter
// ReSharper disable UnusedMember.Local
namespace Microsoft.EntityFrameworkCore.Tests.Metadata.Internal
{
public class ClrCollectionAccessorFactoryTest
@@ -30,6 +30,12 @@ public void Navigation_is_returned_if_it_implements_IClrCollectionAccessor()
Assert.Same(accessorMock.Object, source.Create(navigationMock.Object));
}
[Fact]
public void Delegate_accessor_is_returned_for_IEnumerable_navigation()
{
AccessorTest("AsIEnumerable", e => e.AsIEnumerable);
}
[Fact]
public void Delegate_accessor_is_returned_for_ICollection_navigation()
{
@@ -135,19 +141,52 @@ public void Creating_accessor_for_navigation_without_getter_throws()
var navigation = CreateNavigation("WithNoGetter");
Assert.Equal(
CoreStrings.NavigationNoGetter("WithNoGetter", typeof(MyEntity).FullName),
CoreStrings.NavigationNoGetter("WithNoGetter", typeof(MyEntity).Name),
Assert.Throws<InvalidOperationException>(() => new ClrCollectionAccessorFactory().Create(navigation)).Message);
}
[Fact]
public void Creating_accessor_for_enumerable_navigation_throws()
public void Add_for_enumerable_backed_by_non_collection_throws()
{
var navigation = CreateNavigation("AsIEnumerable");
Enumerable_backed_by_non_collection_throws((a, e, v) => a.Add(e, v));
}
[Fact]
public void AddRange_for_enumerable_backed_by_non_collection_throws()
{
Enumerable_backed_by_non_collection_throws((a, e, v) => a.AddRange(e, new[] { v }));
}
[Fact]
public void Contains_for_enumerable_backed_by_non_collection_throws()
{
Enumerable_backed_by_non_collection_throws((a, e, v) => a.Contains(e, v));
}
[Fact]
public void Remove_for_enumerable_backed_by_non_collection_throws()
{
Enumerable_backed_by_non_collection_throws((a, e, v) => a.Remove(e, v));
}
[Fact]
public void GetOrCreate_for_enumerable_backed_by_non_collection_throws()
{
Enumerable_backed_by_non_collection_throws((a, e, v) => a.GetOrCreate(e));
}
private void Enumerable_backed_by_non_collection_throws(Action<IClrCollectionAccessor, MyEntity, MyOtherEntity> test)
{
var accessor = new ClrCollectionAccessorFactory().Create(CreateNavigation("AsIEnumerableNotCollection"));
var entity = new MyEntity();
var value = new MyOtherEntity();
entity.InitializeCollections();
Assert.Equal(
CoreStrings.NavigationBadType(
"AsIEnumerable", typeof(MyEntity).FullName, typeof(IEnumerable<MyOtherEntity>).FullName, typeof(MyOtherEntity).FullName),
Assert.Throws<InvalidOperationException>(() => new ClrCollectionAccessorFactory().Create(navigation)).Message);
"AsIEnumerableNotCollection", typeof(MyEntity).Name, typeof(MyEnumerable).Name, typeof(MyOtherEntity).Name),
Assert.Throws<InvalidOperationException>(() => test(accessor, entity, value)).Message);
}
[Fact]
@@ -156,7 +195,7 @@ public void Creating_accessor_for_array_navigation_throws()
var navigation = CreateNavigation("AsArray");
Assert.Equal(
CoreStrings.NavigationArray("AsArray", typeof(MyEntity).FullName, typeof(MyOtherEntity[]).FullName),
CoreStrings.NavigationArray("AsArray", typeof(MyEntity).Name, typeof(MyOtherEntity[]).Name),
Assert.Throws<InvalidOperationException>(() => new ClrCollectionAccessorFactory().Create(navigation)).Message);
}
@@ -166,7 +205,7 @@ public void Initialization_for_navigation_without_setter_throws()
var accessor = new ClrCollectionAccessorFactory().Create(CreateNavigation("WithNoSetter"));
Assert.Equal(
CoreStrings.NavigationNoSetter("WithNoSetter", typeof(MyEntity).FullName),
CoreStrings.NavigationNoSetter("WithNoSetter", typeof(MyEntity).Name),
Assert.Throws<InvalidOperationException>(() => accessor.Add(new MyEntity(), new MyOtherEntity())).Message);
}
@@ -176,7 +215,7 @@ public void Initialization_for_navigation_with_private_constructor_throws()
var accessor = new ClrCollectionAccessorFactory().Create(CreateNavigation("AsMyPrivateCollection"));
Assert.Equal(
CoreStrings.NavigationCannotCreateType("AsMyPrivateCollection", typeof(MyEntity).FullName, typeof(MyPrivateCollection).FullName),
CoreStrings.NavigationCannotCreateType("AsMyPrivateCollection", typeof(MyEntity).Name, typeof(MyPrivateCollection).Name),
Assert.Throws<InvalidOperationException>(() => accessor.Add(new MyEntity(), new MyOtherEntity())).Message);
}
@@ -186,7 +225,7 @@ public void Initialization_for_navigation_with_internal_constructor_throws()
var accessor = new ClrCollectionAccessorFactory().Create(CreateNavigation("AsMyInternalCollection"));
Assert.Equal(
CoreStrings.NavigationCannotCreateType("AsMyInternalCollection", typeof(MyEntity).FullName, typeof(MyInternalCollection).FullName),
CoreStrings.NavigationCannotCreateType("AsMyInternalCollection", typeof(MyEntity).Name, typeof(MyInternalCollection).Name),
Assert.Throws<InvalidOperationException>(() => accessor.Add(new MyEntity(), new MyOtherEntity())).Message);
}
@@ -196,7 +235,7 @@ public void Initialization_for_navigation_without_parameterless_constructor_thro
var accessor = new ClrCollectionAccessorFactory().Create(CreateNavigation("AsMyUnavailableCollection"));
Assert.Equal(
CoreStrings.NavigationCannotCreateType("AsMyUnavailableCollection", typeof(MyEntity).FullName, typeof(MyUnavailableCollection).FullName),
CoreStrings.NavigationCannotCreateType("AsMyUnavailableCollection", typeof(MyEntity).Name, typeof(MyUnavailableCollection).Name),
Assert.Throws<InvalidOperationException>(() => accessor.Add(new MyEntity(), new MyOtherEntity())).Message);
}
@@ -223,6 +262,7 @@ private class MyEntity
// ReSharper disable once NotAccessedField.Local
private ICollection<MyOtherEntity> _withNoGetter;
private IEnumerable<MyOtherEntity> _enumerable;
private IEnumerable<MyOtherEntity> _enumerableNotCollection;
private MyOtherEntity[] _array;
private MyPrivateCollection _privateCollection;
private MyInternalCollection _internalCollection;
@@ -237,6 +277,7 @@ public void InitializeCollections()
_withNoSetter = new HashSet<MyOtherEntity>();
_withNoGetter = new HashSet<MyOtherEntity>();
_enumerable = new HashSet<MyOtherEntity>();
_enumerableNotCollection = new MyEnumerable();
_array = new MyOtherEntity[0];
_privateCollection = MyPrivateCollection.Create();
_internalCollection = new MyInternalCollection();
@@ -280,6 +321,12 @@ internal IEnumerable<MyOtherEntity> AsIEnumerable
set { _enumerable = value; }
}
internal IEnumerable<MyOtherEntity> AsIEnumerableNotCollection
{
get { return _enumerableNotCollection; }
set { _enumerableNotCollection = value; }
}
internal MyOtherEntity[] AsArray
{
get { return _array; }
@@ -339,5 +386,15 @@ public MyUnavailableCollection(bool _)
{
}
}
private class MyEnumerable : IEnumerable<MyOtherEntity>
{
public IEnumerator<MyOtherEntity> GetEnumerator()
{
throw new NotImplementedException();
}
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
}
}
}

0 comments on commit 7847600

Please sign in to comment.