diff --git a/arrow-reflect-annotations/src/main/kotlin/MetaModule.kt b/arrow-reflect-annotations/src/main/kotlin/MetaModule.kt index d872a73..f368d04 100644 --- a/arrow-reflect-annotations/src/main/kotlin/MetaModule.kt +++ b/arrow-reflect-annotations/src/main/kotlin/MetaModule.kt @@ -10,5 +10,6 @@ interface MetaModule: Module { val decorator: Decorator val pure: Pure val immutable: Immutable + val disallowLambdaCapture: DisallowLambdaCapture } diff --git a/arrow-reflect-annotations/src/main/kotlin/arrow/meta/Meta.kt b/arrow-reflect-annotations/src/main/kotlin/arrow/meta/Meta.kt index 18a65ce..cbdb17c 100644 --- a/arrow-reflect-annotations/src/main/kotlin/arrow/meta/Meta.kt +++ b/arrow-reflect-annotations/src/main/kotlin/arrow/meta/Meta.kt @@ -2,6 +2,7 @@ package arrow.meta import org.jetbrains.kotlin.fir.FirAnnotationContainer import org.jetbrains.kotlin.fir.FirLabel +import org.jetbrains.kotlin.fir.FirSession import org.jetbrains.kotlin.fir.contracts.* import org.jetbrains.kotlin.fir.declarations.* import org.jetbrains.kotlin.fir.expressions.* @@ -28,7 +29,7 @@ annotation class Meta { fun intercept(args: List, func: (List) -> Out): Out override fun FirMetaCheckerContext.functionCall(functionCall: FirFunctionCall): FirStatement { - val newCall = if (isDecorated(functionCall, session)) { + val newCall = if (session.isDecorated(functionCall)) { //language=kotlin val call: FirCall = decoratedCall(functionCall) call @@ -37,15 +38,14 @@ annotation class Meta { } @OptIn(SymbolInternals::class) - private fun isDecorated(newElement: FirFunctionCall, session: FirSession): Boolean = + private fun FirSession.isDecorated(newElement: FirFunctionCall): Boolean = newElement.toResolvedCallableSymbol()?.fir?.annotations?.hasAnnotation( classId = ClassId.topLevel( FqName( annotation.java.canonicalName ) - ), - session = session - ) == true + ) + , this) == true private fun FirMetaContext.decoratedCall( newElement: FirFunctionCall diff --git a/arrow-reflect-annotations/src/main/kotlin/arrow/meta/samples/DisallowLambdaCapture.kt b/arrow-reflect-annotations/src/main/kotlin/arrow/meta/samples/DisallowLambdaCapture.kt new file mode 100644 index 0000000..431f467 --- /dev/null +++ b/arrow-reflect-annotations/src/main/kotlin/arrow/meta/samples/DisallowLambdaCapture.kt @@ -0,0 +1,60 @@ +package arrow.meta.samples + +import arrow.meta.Diagnostics +import arrow.meta.FirMetaCheckerContext +import arrow.meta.Meta +import arrow.meta.samples.DisallowLambdaCaptureErrors.UnsafeCaptureDetected +import org.jetbrains.kotlin.fir.FirSession +import org.jetbrains.kotlin.fir.declarations.FirAnonymousFunction +import org.jetbrains.kotlin.fir.declarations.InlineStatus +import org.jetbrains.kotlin.fir.declarations.findArgumentByName +import org.jetbrains.kotlin.fir.declarations.getAnnotationByClassId +import org.jetbrains.kotlin.fir.expressions.FirAnnotation +import org.jetbrains.kotlin.fir.expressions.FirConstExpression +import org.jetbrains.kotlin.fir.expressions.FirFunctionCall +import org.jetbrains.kotlin.fir.expressions.toResolvedCallableSymbol +import org.jetbrains.kotlin.name.ClassId +import org.jetbrains.kotlin.name.FqName +import org.jetbrains.kotlin.name.Name +import org.jetbrains.kotlin.types.ConstantValueKind + +object DisallowLambdaCaptureErrors : Diagnostics.Error { + val UnsafeCaptureDetected by error1() +} + +@Meta +@Target(AnnotationTarget.FUNCTION) +annotation class DisallowLambdaCapture(val msg: String = "") { + companion object : Meta.Checker.Expression, + Diagnostics(UnsafeCaptureDetected) { + + val annotation = DisallowLambdaCapture::class.java + + override fun FirMetaCheckerContext.check(expression: FirFunctionCall) { + val nameArg = expression + .disallowLambdaCaptureAnnotation(session)?.findArgumentByName(Name.identifier(DisallowLambdaCapture::msg.name)) + val userMsg = + if (nameArg is FirConstExpression<*> && nameArg.kind == ConstantValueKind.String) nameArg.value as? String + else null + scopeDeclarations.filterIsInstance().forEach { scope -> + if (scope.inlineStatus != InlineStatus.Inline) { + expression.report( + UnsafeCaptureDetected, + userMsg + ?: "detected call to member @DisallowLambdaCapture `${+expression}` in non-inline anonymous function" + ) + } + } + } + + private fun FirFunctionCall.disallowLambdaCaptureAnnotation(session: FirSession): FirAnnotation? = + toResolvedCallableSymbol()?.fir?.getAnnotationByClassId( + ClassId( + FqName(annotation.`package`.name), + Name.identifier(annotation.simpleName) + ), + session + ) + } +} + diff --git a/arrow-reflect-compiler-plugin/src/main/kotlin/arrow/reflect/compiler/plugin/fir/checkers/FirMetaAdditionalCheckersExtension.kt b/arrow-reflect-compiler-plugin/src/main/kotlin/arrow/reflect/compiler/plugin/fir/checkers/FirMetaAdditionalCheckersExtension.kt index 5d0a35c..b6abb4f 100644 --- a/arrow-reflect-compiler-plugin/src/main/kotlin/arrow/reflect/compiler/plugin/fir/checkers/FirMetaAdditionalCheckersExtension.kt +++ b/arrow-reflect-compiler-plugin/src/main/kotlin/arrow/reflect/compiler/plugin/fir/checkers/FirMetaAdditionalCheckersExtension.kt @@ -60,25 +60,42 @@ class FirMetaAdditionalCheckersExtension( } private inline fun invokeChecker( - superType: KClass<*>, - element: E, - session: FirSession, - context: CheckerContext, - reporter: DiagnosticReporter + superType: KClass<*>, + element: E, + session: FirSession, + context: CheckerContext, + reporter: DiagnosticReporter ) { - if (element is FirAnnotationContainer && element.isMetaAnnotated(session)) { - val annotations = element.metaAnnotations(session) - val metaContext = FirMetaCheckerContext(templateCompiler, session, context, reporter) - invokeMeta( - false, - metaContext, - annotations, - superType = superType, - methodName = "check", - element - ) - } + if ((element is FirAnnotationContainer && element.isMetaAnnotated(session)) || (element is FirFunctionCall && element.isCallToAnnotatedFunction( + session + )) + ) { + val annotations = + when (element) { + is FirFunctionCall -> + element.metaAnnotations(session) + element.toResolvedCallableSymbol()?.fir?.metaAnnotations( + session + ).orEmpty() + is FirAnnotationContainer -> element.metaAnnotations(session) + else -> emptyList() + } + val metaContext = FirMetaCheckerContext(templateCompiler, session, context, reporter) + invokeMeta( + false, + metaContext, + annotations, + superType = superType, + methodName = "check", + element + ) } + } + + private inline fun E.isCallToAnnotatedFunction( + session: FirSession + ): Boolean { + return toResolvedCallableSymbol()?.fir?.isMetaAnnotated(session) == true + } override val typeCheckers: TypeCheckers get() = super.typeCheckers diff --git a/arrow-reflect-compiler-plugin/src/testData/diagnostics/capture_test.fir.txt b/arrow-reflect-compiler-plugin/src/testData/diagnostics/capture_test.fir.txt new file mode 100644 index 0000000..1712eb1 --- /dev/null +++ b/arrow-reflect-compiler-plugin/src/testData/diagnostics/capture_test.fir.txt @@ -0,0 +1,35 @@ +FILE: capture_test.kt + package foo.bar + + public abstract interface Raise : R|kotlin/Any| { + @R|arrow/meta/samples/DisallowLambdaCapture|(msg = String(It's unsafe to capture `raise` inside non-inline anonymous functions)) public abstract fun raise(e: R|E|): R|kotlin/Nothing| + + } + context(R|foo/bar/Raise|) + public final fun shouldNotCapture(): R|() -> kotlin/Unit| { + ^shouldNotCapture fun (): R|kotlin/Unit| { + this@R|foo/bar/shouldNotCapture|.R|SubstitutionOverride|(String(boom)) + } + + } + context(R|foo/bar/Raise|) + public final fun inlineCaptureOk(): R|kotlin/Unit| { + R|kotlin/collections/listOf|(vararg(Int(1), Int(2), Int(3))).R|kotlin/collections/map|( = map@fun (it: R|kotlin/Int|): R|kotlin/Nothing| { + this@R|foo/bar/inlineCaptureOk|.R|SubstitutionOverride|(String(boom)) + } + ) + } + context(R|foo/bar/Raise|) + public final fun leakedNotOk(): R|() -> kotlin/Unit| { + ^leakedNotOk fun (): R|kotlin/Unit| { + R|kotlin/collections/listOf|(vararg(Int(1), Int(2), Int(3))).R|kotlin/collections/map|( = map@fun (it: R|kotlin/Int|): R|kotlin/Nothing| { + this@R|foo/bar/leakedNotOk|.R|SubstitutionOverride|(String(boom)) + } + ) + } + + } + context(R|foo/bar/Raise|) + public final fun ok(): R|kotlin/Unit| { + this@R|foo/bar/ok|.R|SubstitutionOverride|(String(boom)) + } diff --git a/arrow-reflect-compiler-plugin/src/testData/diagnostics/capture_test.kt b/arrow-reflect-compiler-plugin/src/testData/diagnostics/capture_test.kt new file mode 100644 index 0000000..39f69a8 --- /dev/null +++ b/arrow-reflect-compiler-plugin/src/testData/diagnostics/capture_test.kt @@ -0,0 +1,27 @@ +package foo.bar + +import arrow.meta.samples.DisallowLambdaCapture + +interface Raise { + @DisallowLambdaCapture("It's unsafe to capture `raise` inside non-inline anonymous functions") fun raise(e: E): Nothing +} + +context(Raise) +fun shouldNotCapture(): () -> Unit { + return { raise("boom") } +} + +context(Raise) +fun inlineCaptureOk(): Unit { + listOf(1, 2, 3).map { raise("boom") } +} + +context(Raise) +fun leakedNotOk(): () -> Unit = { + listOf(1, 2, 3).map { raise("boom") } +} + +context(Raise) +fun ok(): Unit { + raise("boom") +} diff --git a/arrow-reflect-compiler-plugin/src/testGenerated/arrow/reflect/compiler/plugin/runners/DiagnosticTestGenerated.java b/arrow-reflect-compiler-plugin/src/testGenerated/arrow/reflect/compiler/plugin/runners/DiagnosticTestGenerated.java index a50cc34..a84ba2b 100644 --- a/arrow-reflect-compiler-plugin/src/testGenerated/arrow/reflect/compiler/plugin/runners/DiagnosticTestGenerated.java +++ b/arrow-reflect-compiler-plugin/src/testGenerated/arrow/reflect/compiler/plugin/runners/DiagnosticTestGenerated.java @@ -21,6 +21,12 @@ public void testAllFilesPresentInDiagnostics() throws Exception { KtTestUtil.assertAllTestsPresentByMetadataWithExcluded(this.getClass(), new File("src/testData/diagnostics"), Pattern.compile("^(.+)\\.kt$"), null, true); } + @Test + @TestMetadata("capture_test.kt") + public void testCapture_test() throws Exception { + runTest("src/testData/diagnostics/capture_test.kt"); + } + @Test @TestMetadata("immutable_test.kt") public void testImmutable_test() throws Exception { diff --git a/sandbox/src/main/kotlin/Sample.kt b/sandbox/src/main/kotlin/Sample.kt index feae105..ef16dd7 100644 --- a/sandbox/src/main/kotlin/Sample.kt +++ b/sandbox/src/main/kotlin/Sample.kt @@ -1,12 +1,38 @@ -package example +package foo.bar -import arrow.meta.samples.Product +import arrow.meta.samples.DisallowLambdaCapture +import kotlin.contracts.* -@Product -data class Sample(val name: String, val age: Int) +interface Raise { + @DisallowLambdaCapture("It's unsafe to capture `raise` inside non-inline anonymous functions") + fun raise(e: E): Nothing +} + +context(Raise) +fun shouldNotCapture(): () -> Unit { + return { raise("boom") } +} + +context(Raise) +fun inlineCaptureOk(): Unit { + listOf(1, 2, 3).map { raise("boom") } +} + +@OptIn(ExperimentalContracts::class) +fun exactlyOne(f: () -> Unit): Unit { + contract { + callsInPlace(f, InvocationKind.EXACTLY_ONCE) + } +} +@OptIn(ExperimentalContracts::class) +fun exactlyOnce(f: () -> Unit): Unit { + contract { + callsInPlace(f, InvocationKind.EXACTLY_ONCE) + } +} -fun main() { - val properties = Sample("j", 12).product() - println(properties) +context(Raise) +fun ok(): () -> Unit = { + exactlyOnce { raise("boom") } }