Skip to content

Commit

Permalink
Prefetch TotalCount (#2480)
Browse files Browse the repository at this point in the history
Co-authored-by: Michael Staib <michael@chillicream.com>
  • Loading branch information
PascalSenn and michaelstaib committed Oct 25, 2020
1 parent 650c80b commit ec06012
Show file tree
Hide file tree
Showing 9 changed files with 207 additions and 18 deletions.
Expand Up @@ -4,7 +4,7 @@
<Nullable>enable</Nullable>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
</PropertyGroup>

<PropertyGroup>
<PackageId>HotChocolate.Types.OffsetPagination</PackageId>
<AssemblyName>HotChocolate.Types.OffsetPagination</AssemblyName>
Expand Down Expand Up @@ -36,6 +36,6 @@
<AutoGen>True</AutoGen>
<DependentUpon>OffsetResources.resx</DependentUpon>
</Compile>
</ItemGroup>

</ItemGroup>
</Project>
Expand Up @@ -5,9 +5,9 @@
namespace HotChocolate.Types.Pagination
{
/// <summary>
/// Represents an offset paging handler, which can be implemented to
/// create optimized pagination for data sources.
///
/// Represents an offset paging handler, which can be implemented to
/// create optimized pagination for data sources.
///
/// The paging handler will be used by the paging middleware to slice the data.
/// </summary>
public abstract class OffsetPagingHandler : IPagingHandler
Expand All @@ -16,6 +16,7 @@ protected OffsetPagingHandler(PagingOptions options)
{
DefaultPageSize = options.DefaultPageSize ?? PagingDefaults.DefaultPageSize;
MaxPageSize = options.MaxPageSize ?? PagingDefaults.MaxPageSize;
IncludeTotalCount = options.IncludeTotalCount ?? PagingDefaults.IncludeTotalCount;

if (MaxPageSize < DefaultPageSize)
{
Expand All @@ -35,8 +36,13 @@ protected OffsetPagingHandler(PagingOptions options)
protected int MaxPageSize { get; }

/// <summary>
/// Ensures that the arguments passed in by the user are valid and
/// do not try to consume more items per page as specified by
/// Result should include total count.
/// </summary>
protected bool IncludeTotalCount { get; }

/// <summary>
/// Ensures that the arguments passed in by the user are valid and
/// do not try to consume more items per page as specified by
/// <see cref="MaxPageSize"/>.
/// </summary>
/// <param name="context">
Expand Down Expand Up @@ -76,7 +82,7 @@ public void ValidateContext(IResolverContext context)
/// The paging arguments provided by the user.
/// </param>
/// <returns>
/// The <see cref="CollectionSegment"/> representing
/// The <see cref="CollectionSegment"/> representing
/// the slice of items belonging to the requested page.
/// </returns>
protected abstract ValueTask<CollectionSegment> SliceAsync(
Expand Down
@@ -1,3 +1,4 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
Expand Down Expand Up @@ -52,10 +53,41 @@ await ExecuteQueryableAsync(queryable, context.RequestAborted)
items.RemoveAt(arguments.Take);
}

return new CollectionSegment((IReadOnlyCollection<object>)items, pageInfo, CountAsync);
Func<CancellationToken, ValueTask<int>> getTotalCount =
ct => throw new InvalidOperationException();

async ValueTask<int> CountAsync(CancellationToken cancellationToken) =>
await Task.Run(original.Count, cancellationToken).ConfigureAwait(false);
// TotalCount is one of the heaviest operations. It is only necessary to load totalCount
// when it is enabled (IncludeTotalCount) and when it is contained in the selection set.
if (IncludeTotalCount &&
context.Field.Type is ObjectType objectType &&
context.FieldSelection.SelectionSet is {} selectionSet)
{
IReadOnlyList<IFieldSelection> selections = context
.GetSelections(objectType, selectionSet, true);

var includeTotalCount = false;
for (var i = 0; i < selections.Count; i++)
{
if (selections[i].Field.Name.Value is "totalCount")
{
includeTotalCount = true;
break;
}
}

// When totalCount is included in the selection set we prefetch it, then capture the
// count in a variable, to pass it into the clojure
if (includeTotalCount)
{
var captureCount = original.Count();
getTotalCount = ct => new ValueTask<int>(captureCount);
}
}

return new CollectionSegment(
(IReadOnlyCollection<object>)items,
pageInfo,
getTotalCount);
}

protected virtual async ValueTask<List<TItemType>> ExecuteQueryableAsync(
Expand Down Expand Up @@ -90,7 +122,6 @@ await foreach (TItemType item in enumerable.WithCancellation(cancellationToken)
}

return list;

}
}
}
@@ -1,6 +1,5 @@
using System;
using System.Reflection;
using HotChocolate.Data.Properties;
using HotChocolate.Types;
using HotChocolate.Types.Descriptors;
using Microsoft.EntityFrameworkCore;
Expand All @@ -12,7 +11,8 @@ public class UseDbContextAttribute : ObjectFieldDescriptorAttribute
{
private static readonly MethodInfo _useDbContext =
typeof(EntityFrameworkObjectFieldDescriptorExtensions)
.GetMethod(nameof(EntityFrameworkObjectFieldDescriptorExtensions.UseDbContext),
.GetMethod(
nameof(EntityFrameworkObjectFieldDescriptorExtensions.UseDbContext),
BindingFlags.Public | BindingFlags.Static)!;

private readonly Type _dbContext;
Expand Down
Expand Up @@ -8,6 +8,7 @@
<ItemGroup>
<ProjectReference Include="..\..\..\Core\test\Types.Tests\HotChocolate.Types.Tests.csproj" />
<ProjectReference Include="..\..\src\EntityFramework\HotChocolate.Data.EntityFramework.csproj" />
<ProjectReference Include="..\..\src\Data\HotChocolate.Data.csproj" />
</ItemGroup>

<ItemGroup>
Expand Down
@@ -1,5 +1,6 @@
using System.Linq;
using System.Threading.Tasks;
using HotChocolate.Types;
using Microsoft.EntityFrameworkCore;

namespace HotChocolate.Data
Expand All @@ -13,6 +14,13 @@ public class Query
[UseDbContext(typeof(BookContext))]
public async Task<Author> GetAuthor([ScopedService]BookContext context) =>
await context.Authors.FirstOrDefaultAsync();

[UseDbContext(typeof(BookContext))]
[UseOffsetPaging(IncludeTotalCount = true)]
[UseFiltering]
[UseSorting]
public IQueryable<Author> GetAuthorOffsetPaging([ScopedService]BookContext context) =>
context.Authors;
}

public class InvalidQuery
Expand Down
Expand Up @@ -18,8 +18,11 @@ public async Task Execute_Queryable()
IServiceProvider services =
new ServiceCollection()
.AddPooledDbContextFactory<BookContext>(
b => b.UseInMemoryDatabase("Data Source=books.db"))
b => b.UseInMemoryDatabase("Data Source=books1.db"))
.AddGraphQL()
.AddFiltering()
.AddSorting()
.AddProjections()
.AddQueryType<Query>()
.Services
.BuildServiceProvider();
Expand All @@ -44,15 +47,115 @@ await using (BookContext context = contextFactory.CreateDbContext())
result.ToJson().MatchSnapshot();
}

[Fact]
public async Task Execute_Queryable_OffsetPaging_TotalCount()
{
// arrange
IServiceProvider services =
new ServiceCollection()
.AddPooledDbContextFactory<BookContext>(
b => b.UseInMemoryDatabase("Data Source=books2.db"))
.AddGraphQL()
.AddFiltering()
.AddSorting()
.AddProjections()
.AddQueryType<Query>()
.Services
.BuildServiceProvider();

IRequestExecutor executor =
await services.GetRequiredService<IRequestExecutorResolver>()
.GetRequestExecutorAsync();

IDbContextFactory<BookContext> contextFactory =
services.GetRequiredService<IDbContextFactory<BookContext>>();

await using (BookContext context = contextFactory.CreateDbContext())
{
await context.Authors.AddAsync(new Author { Name = "foo" });
await context.Authors.AddAsync(new Author { Name = "bar" });
await context.SaveChangesAsync();
}

// act
IExecutionResult result = await executor.ExecuteAsync(
@"query Test {
authorOffsetPaging {
items {
name
}
pageInfo {
hasNextPage
hasPreviousPage
}
totalCount
}
}");

// assert
result.ToJson().MatchSnapshot();
}

[Fact]
public async Task Execute_Queryable_OffsetPaging()
{
// arrange
IServiceProvider services =
new ServiceCollection()
.AddPooledDbContextFactory<BookContext>(
b => b.UseInMemoryDatabase("Data Source=books3.db"))
.AddGraphQL()
.AddFiltering()
.AddSorting()
.AddProjections()
.AddQueryType<Query>()
.Services
.BuildServiceProvider();

IRequestExecutor executor =
await services.GetRequiredService<IRequestExecutorResolver>()
.GetRequestExecutorAsync();

IDbContextFactory<BookContext> contextFactory =
services.GetRequiredService<IDbContextFactory<BookContext>>();

await using (BookContext context = contextFactory.CreateDbContext())
{
await context.Authors.AddAsync(new Author { Name = "foo" });
await context.Authors.AddAsync(new Author { Name = "bar" });
await context.SaveChangesAsync();
}

// act
IExecutionResult result = await executor.ExecuteAsync(
@"query Test {
authorOffsetPaging {
items {
name
}
pageInfo {
hasNextPage
hasPreviousPage
}
}
}");

// assert
result.ToJson().MatchSnapshot();
}

[Fact]
public async Task Execute_Single()
{
// arrange
IServiceProvider services =
new ServiceCollection()
.AddPooledDbContextFactory<BookContext>(
b => b.UseInMemoryDatabase("Data Source=books.db"))
b => b.UseInMemoryDatabase("Data Source=books4.db"))
.AddGraphQL()
.AddFiltering()
.AddSorting()
.AddProjections()
.AddQueryType<Query>()
.Services
.BuildServiceProvider();
Expand Down Expand Up @@ -85,8 +188,11 @@ public async Task DbContextType_Is_Object()
async Task CreateSchema() =>
await new ServiceCollection()
.AddPooledDbContextFactory<BookContext>(
b => b.UseInMemoryDatabase("Data Source=books.db"))
b => b.UseInMemoryDatabase("Data Source=books5.db"))
.AddGraphQL()
.AddFiltering()
.AddSorting()
.AddProjections()
.AddQueryType<InvalidQuery>()
.BuildSchemaAsync();

Expand Down
@@ -0,0 +1,18 @@
{
"data": {
"authorOffsetPaging": {
"items": [
{
"name": "foo"
},
{
"name": "bar"
}
],
"pageInfo": {
"hasNextPage": false,
"hasPreviousPage": false
}
}
}
}
@@ -0,0 +1,19 @@
{
"data": {
"authorOffsetPaging": {
"items": [
{
"name": "foo"
},
{
"name": "bar"
}
],
"pageInfo": {
"hasNextPage": false,
"hasPreviousPage": false
},
"totalCount": 2
}
}
}

0 comments on commit ec06012

Please sign in to comment.