Skip to content

Commit

Permalink
Introduce an egregious source-compatibility hack for `AsyncSequence.f…
Browse files Browse the repository at this point in the history
…latMap`

Allow `AsyncSequence.flatMap` to be defined with "incorrect" availability,
meaning that the function can refer to the `Failure` associated type
in its where clause even though the function is back-deployed to
before the `Failure` associated type was introduced.

We believe this is safe, and that this notion can be generalized to any
use of an associated type in a same-type constraint of a function
(yes, it sounds weird), but for now introduce a narrower hack to see
how things work in practice and whether it addresses all of the
source-compatibility concerns we've uncovered.
  • Loading branch information
DougGregor committed Jan 26, 2024
1 parent a3ae340 commit 4c990dc
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 3 deletions.
26 changes: 26 additions & 0 deletions lib/Sema/TypeCheckAccess.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2014,6 +2014,25 @@ class DeclAvailabilityChecker : public DeclVisitor<DeclAvailabilityChecker> {
Where.withReason(reason), flags);
}

/// Identify the AsyncSequence.flatMap set of functions from the
/// _Concurrency module.
static bool isAsyncSequenceFlatMap(const GenericContext *gc) {
auto func = dyn_cast<FuncDecl>(gc);
if (!func)
return false;

auto proto = func->getDeclContext()->getSelfProtocolDecl();
if (!proto ||
!proto->isSpecificProtocol(KnownProtocolKind::AsyncSequence))
return false;

ASTContext &ctx = proto->getASTContext();
if (func->getModuleContext()->getName() != ctx.Id_Concurrency)
return false;

return !func->getName().isSimpleName("flatMap");
}

void checkGenericParams(const GenericContext *ownerCtx,
const ValueDecl *ownerDecl) {
if (!ownerCtx->isGenericContext())
Expand All @@ -2031,6 +2050,13 @@ class DeclAvailabilityChecker : public DeclVisitor<DeclAvailabilityChecker> {
}

if (ownerCtx->getTrailingWhereClause()) {
// Ignore the where clause for AsyncSequence.flatMap from the
// _Concurrency module. This is an egregious hack to allow us to
// use overloading tricks to retain the behavior previously
// afforded by rethrowing conformances.
if (isAsyncSequenceFlatMap(ownerCtx))
return;

forAllRequirementTypes(WhereClauseOwner(
const_cast<GenericContext *>(ownerCtx)),
[&](Type type, TypeRepr *typeRepr) {
Expand Down
3 changes: 0 additions & 3 deletions stdlib/public/Concurrency/AsyncFlatMapSequence.swift
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ extension AsyncSequence {
@preconcurrency
@_alwaysEmitIntoClient
@inlinable
@available(SwiftStdlib 5.11, *)
public __consuming func flatMap<SegmentOfResult: AsyncSequence>(
_ transform: @Sendable @escaping (Element) async -> SegmentOfResult
) -> AsyncFlatMapSequence<Self, SegmentOfResult>
Expand Down Expand Up @@ -111,7 +110,6 @@ extension AsyncSequence {
@preconcurrency
@_alwaysEmitIntoClient
@inlinable
@available(SwiftStdlib 5.11, *)
public __consuming func flatMap<SegmentOfResult: AsyncSequence>(
_ transform: @Sendable @escaping (Element) async -> SegmentOfResult
) -> AsyncFlatMapSequence<Self, SegmentOfResult>
Expand Down Expand Up @@ -148,7 +146,6 @@ extension AsyncSequence {
@preconcurrency
@_alwaysEmitIntoClient
@inlinable
@available(SwiftStdlib 5.11, *)
public __consuming func flatMap<SegmentOfResult: AsyncSequence>(
_ transform: @Sendable @escaping (Element) async -> SegmentOfResult
) -> AsyncFlatMapSequence<Self, SegmentOfResult>
Expand Down
21 changes: 21 additions & 0 deletions test/Concurrency/async_sequence_flatmap_overloading.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// RUN: %target-swift-frontend -typecheck %s -verify

// REQUIRES: concurrency

@available(SwiftStdlib 5.1, *)
struct MyAsyncSequence<Element>: AsyncSequence {
struct AsyncIterator: AsyncIteratorProtocol {
mutating func next() -> Element? { nil }
}

func makeAsyncIterator() -> AsyncIterator { .init() }
}

@available(SwiftStdlib 5.1, *)
func testMe(ms: MyAsyncSequence<String>) {
let flatMS = ms.flatMap { string in
return MyAsyncSequence<[Character]>()
}

let _: AsyncFlatMapSequence<MyAsyncSequence<String>, MyAsyncSequence<[Character]>> = flatMS
}

0 comments on commit 4c990dc

Please sign in to comment.