Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

@DisallowLambdaCapture checker example #22

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
1 change: 1 addition & 0 deletions arrow-reflect-annotations/src/main/kotlin/MetaModule.kt
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@ interface MetaModule: Module {
val decorator: Decorator
val pure: Pure
val immutable: Immutable
val disallowLambdaCapture: DisallowLambdaCapture
}

7 changes: 4 additions & 3 deletions arrow-reflect-annotations/src/main/kotlin/arrow/meta/Meta.kt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package arrow.meta

import org.jetbrains.kotlin.fir.FirAnnotationContainer
import org.jetbrains.kotlin.fir.FirLabel
import org.jetbrains.kotlin.fir.FirSession
import org.jetbrains.kotlin.fir.contracts.*
import org.jetbrains.kotlin.fir.declarations.*
import org.jetbrains.kotlin.fir.expressions.*
Expand All @@ -27,7 +28,7 @@ annotation class Meta {
fun <In, Out> intercept(args: List<In>, func: (List<In>) -> Out): Out

override fun FirMetaCheckerContext.functionCall(functionCall: FirFunctionCall): FirStatement {
val newCall = if (isDecorated(functionCall)) {
val newCall = if (session.isDecorated(functionCall)) {
//language=kotlin
val call: FirCall = decoratedCall(functionCall)
call
Expand All @@ -36,14 +37,14 @@ annotation class Meta {
}

@OptIn(SymbolInternals::class)
private fun isDecorated(newElement: FirFunctionCall): Boolean =
private fun FirSession.isDecorated(newElement: FirFunctionCall): Boolean =
newElement.toResolvedCallableSymbol()?.fir?.annotations?.hasAnnotation(
ClassId.topLevel(
FqName(
annotation.java.canonicalName
)
)
) == true
, this) == true

private fun FirMetaContext.decoratedCall(
newElement: FirFunctionCall
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package arrow.meta.samples

import arrow.meta.Diagnostics
import arrow.meta.FirMetaCheckerContext
import arrow.meta.Meta
import arrow.meta.samples.DisallowLambdaCaptureErrors.UnsafeCaptureDetected
import org.jetbrains.kotlin.fir.FirSession
import org.jetbrains.kotlin.fir.declarations.FirAnonymousFunction
import org.jetbrains.kotlin.fir.declarations.InlineStatus
import org.jetbrains.kotlin.fir.declarations.findArgumentByName
import org.jetbrains.kotlin.fir.declarations.getAnnotationByClassId
import org.jetbrains.kotlin.fir.expressions.FirAnnotation
import org.jetbrains.kotlin.fir.expressions.FirConstExpression
import org.jetbrains.kotlin.fir.expressions.FirFunctionCall
import org.jetbrains.kotlin.fir.expressions.toResolvedCallableSymbol
import org.jetbrains.kotlin.name.ClassId
import org.jetbrains.kotlin.name.FqName
import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlin.types.ConstantValueKind

object DisallowLambdaCaptureErrors : Diagnostics.Error {
val UnsafeCaptureDetected by error1()
}

@Meta
@Target(AnnotationTarget.FUNCTION)
annotation class DisallowLambdaCapture(val msg: String = "") {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would this not be more precise as msg: String??

companion object : Meta.Checker.Expression<FirFunctionCall>,
Diagnostics(UnsafeCaptureDetected) {

val annotation = DisallowLambdaCapture::class.java

override fun FirMetaCheckerContext.check(expression: FirFunctionCall) {
val nameArg = expression
.disallowLambdaCaptureAnnotation(session)?.findArgumentByName(Name.identifier(DisallowLambdaCapture::msg.name))
val userMsg =
if (nameArg is FirConstExpression<*> && nameArg.kind == ConstantValueKind.String) nameArg.value as? String
else null
scopeDeclarations.filterIsInstance<FirAnonymousFunction>().forEach { scope ->
if (scope.inlineStatus != InlineStatus.Inline) {
expression.report(
UnsafeCaptureDetected,
userMsg
?: "detected call to member @DisallowLambdaCapture `${+expression}` in non-inline anonymous function"
)
}
}
}

private fun FirFunctionCall.disallowLambdaCaptureAnnotation(session: FirSession): FirAnnotation? =
toResolvedCallableSymbol()?.fir?.getAnnotationByClassId(
ClassId(
FqName(annotation.`package`.name),
Name.identifier(annotation.simpleName)
),
session
)
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -58,25 +58,42 @@ class FirMetaAdditionalCheckersExtension(
}

private inline fun <reified E : FirElement> invokeChecker(
superType: KClass<*>,
element: E,
session: FirSession,
context: CheckerContext,
reporter: DiagnosticReporter
superType: KClass<*>,
element: E,
session: FirSession,
context: CheckerContext,
reporter: DiagnosticReporter
) {
if (element is FirAnnotationContainer && element.isMetaAnnotated(session)) {
val annotations = element.metaAnnotations(session)
val metaContext = FirMetaCheckerContext(templateCompiler, session, context, reporter)
invokeMeta<E, Unit>(
false,
metaContext,
annotations,
superType = superType,
methodName = "check",
element
)
}
if ((element is FirAnnotationContainer && element.isMetaAnnotated(session)) || (element is FirFunctionCall && element.isCallToAnnotatedFunction(
session
))
) {
val annotations =
when (element) {
is FirFunctionCall ->
element.metaAnnotations(session) + element.toResolvedCallableSymbol()?.fir?.metaAnnotations(
session
).orEmpty()
is FirAnnotationContainer -> element.metaAnnotations(session)
else -> emptyList()
}
val metaContext = FirMetaCheckerContext(templateCompiler, session, context, reporter)
invokeMeta<E, Unit>(
false,
metaContext,
annotations,
superType = superType,
methodName = "check",
element
)
}
}

private inline fun <reified E : FirFunctionCall> E.isCallToAnnotatedFunction(
session: FirSession
): Boolean {
return toResolvedCallableSymbol()?.fir?.isMetaAnnotated(session) == true
}

override val typeCheckers: TypeCheckers
get() = super.typeCheckers
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
FILE: capture_test.kt
package foo.bar

public abstract interface Raise<in E> : R|kotlin/Any| {
@R|arrow/meta/samples/DisallowLambdaCapture|(msg = String(It's unsafe to capture `raise` inside non-inline anonymous functions)) public abstract fun raise(e: R|E|): R|kotlin/Nothing|

}
context(R|foo/bar/Raise<kotlin/String>|)
public final fun shouldNotCapture(): R|() -> kotlin/Unit| {
^shouldNotCapture fun <anonymous>(): R|kotlin/Unit| <inline=Unknown> {
this@R|foo/bar/shouldNotCapture|.R|SubstitutionOverride<foo/bar/Raise.raise: R|kotlin/Nothing|>|(String(boom))
}

}
context(R|foo/bar/Raise<kotlin/String>|)
public final fun inlineCaptureOk(): R|kotlin/Unit| {
R|kotlin/collections/listOf|<R|kotlin/Int|>(vararg(Int(1), Int(2), Int(3))).R|kotlin/collections/map|<R|kotlin/Int|, R|kotlin/Nothing|>(<L> = map@fun <anonymous>(it: R|kotlin/Int|): R|kotlin/Nothing| <inline=Inline, kind=UNKNOWN> {
this@R|foo/bar/inlineCaptureOk|.R|SubstitutionOverride<foo/bar/Raise.raise: R|kotlin/Nothing|>|(String(boom))
}
)
}
context(R|foo/bar/Raise<kotlin/String>|)
public final fun leakedNotOk(): R|() -> kotlin/Unit| {
^leakedNotOk fun <anonymous>(): R|kotlin/Unit| <inline=Unknown> {
R|kotlin/collections/listOf|<R|kotlin/Int|>(vararg(Int(1), Int(2), Int(3))).R|kotlin/collections/map|<R|kotlin/Int|, R|kotlin/Nothing|>(<L> = map@fun <anonymous>(it: R|kotlin/Int|): R|kotlin/Nothing| <inline=Inline, kind=UNKNOWN> {
this@R|foo/bar/leakedNotOk|.R|SubstitutionOverride<foo/bar/Raise.raise: R|kotlin/Nothing|>|(String(boom))
}
)
}

}
context(R|foo/bar/Raise<kotlin/String>|)
public final fun ok(): R|kotlin/Unit| {
this@R|foo/bar/ok|.R|SubstitutionOverride<foo/bar/Raise.raise: R|kotlin/Nothing|>|(String(boom))
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
package foo.bar

import arrow.meta.samples.DisallowLambdaCapture

interface Raise<in E> {
@DisallowLambdaCapture("It's unsafe to capture `raise` inside non-inline anonymous functions") fun raise(e: E): Nothing
}

context(Raise<String>)
fun shouldNotCapture(): () -> Unit {
return { <!UnsafeCaptureDetected!>raise("boom")<!> }
}

context(Raise<String>)
fun inlineCaptureOk(): Unit {
listOf(1, 2, 3).map { raise("boom") }
}

context(Raise<String>)
fun leakedNotOk(): () -> Unit = {
listOf(1, 2, 3).map { <!UnsafeCaptureDetected!>raise("boom")<!> }
}

context(Raise<String>)
fun ok(): Unit {
raise("boom")
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@ public void testAllFilesPresentInDiagnostics() throws Exception {
KtTestUtil.assertAllTestsPresentByMetadataWithExcluded(this.getClass(), new File("src/testData/diagnostics"), Pattern.compile("^(.+)\\.kt$"), null, true);
}

@Test
@TestMetadata("capture_test.kt")
public void testCapture_test() throws Exception {
runTest("src/testData/diagnostics/capture_test.kt");
}

@Test
@TestMetadata("immutable_test.kt")
public void testImmutable_test() throws Exception {
Expand Down
40 changes: 33 additions & 7 deletions sandbox/src/main/kotlin/Sample.kt
Original file line number Diff line number Diff line change
@@ -1,12 +1,38 @@
package example
package foo.bar

import arrow.meta.samples.Product
import arrow.meta.samples.DisallowLambdaCapture
import kotlin.contracts.*

@Product
data class Sample(val name: String, val age: Int)
interface Raise<in E> {
@DisallowLambdaCapture("It's unsafe to capture `raise` inside non-inline anonymous functions")
fun raise(e: E): Nothing
}

context(Raise<String>)
fun shouldNotCapture(): () -> Unit {
return { raise("boom") }
}

context(Raise<String>)
fun inlineCaptureOk(): Unit {
listOf(1, 2, 3).map { raise("boom") }
}

@OptIn(ExperimentalContracts::class)
fun exactlyOne(f: () -> Unit): Unit {
contract {
callsInPlace(f, InvocationKind.EXACTLY_ONCE)
}
}

@OptIn(ExperimentalContracts::class)
fun exactlyOnce(f: () -> Unit): Unit {
contract {
callsInPlace(f, InvocationKind.EXACTLY_ONCE)
}
Comment on lines +30 to +32
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure where this is checked in DisallowLambdaCapture.kt? 🤔

}

fun main() {
val properties = Sample("j", 12).product()
println(properties)
context(Raise<String>)
fun ok(): () -> Unit = {
exactlyOnce { raise("boom") }
}