Skip to content

Commit

Permalink
[WasmJs] Add support for external class reflection
Browse files Browse the repository at this point in the history
Fix #KT-64890
  • Loading branch information
igoriakovlev authored and qodana-bot committed Apr 30, 2024
1 parent 0872420 commit 1a630b5
Show file tree
Hide file tree
Showing 24 changed files with 332 additions and 81 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,6 @@ class JsIntrinsics(private val irBuiltIns: IrBuiltIns, val context: JsIrBackendC
override val getKClassFromExpression = getInternalWithoutPackage("getKClassFromExpression")
override val primitiveClassesObject = context.getIrClass(FqName("kotlin.reflect.js.internal.PrimitiveClasses"))
override val kTypeClass: IrClassSymbol = context.getIrClass(FqName("kotlin.reflect.KType"))
override val getClassData: IrSimpleFunctionSymbol get() = jsClass
}

internal val reflectionSymbols: JsReflectionSymbols = JsReflectionSymbols()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -758,8 +758,8 @@ private val jsClassUsageInReflectionPhase = makeBodyLoweringPhase(
)

private val classReferenceLoweringPhase = makeBodyLoweringPhase(
::ClassReferenceLowering,
name = "ClassReferenceLowering",
::JsClassReferenceLowering,
name = "JsClassReferenceLowering",
description = "Handle class references",
prerequisite = setOf(jsClassUsageInReflectionPhase)
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,4 +57,7 @@ class JsMapping : DefaultMapping() {

val wasmExternalClassToInstanceCheck =
DefaultDelegateFactory.newDeclarationToDeclarationMapping<IrClass, IrSimpleFunction>()

val wasmGetJsClass =
DefaultDelegateFactory.newDeclarationToDeclarationMapping<IrClass, IrSimpleFunction>()
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,10 @@ package org.jetbrains.kotlin.ir.backend.js

import org.jetbrains.kotlin.ir.symbols.IrClassSymbol
import org.jetbrains.kotlin.ir.symbols.IrSimpleFunctionSymbol
import org.jetbrains.kotlin.ir.types.IrType

interface ReflectionSymbols {
val getKClassFromExpression: IrSimpleFunctionSymbol
val getKClass: IrSimpleFunctionSymbol
val getClassData: IrSimpleFunctionSymbol
val createKType: IrSimpleFunctionSymbol?
val createDynamicKType: IrSimpleFunctionSymbol?
val createKTypeParameter: IrSimpleFunctionSymbol?
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,36 @@ import org.jetbrains.kotlin.ir.visitors.transformChildrenVoid
import org.jetbrains.kotlin.name.Name
import org.jetbrains.kotlin.types.*

class ClassReferenceLowering(val context: JsCommonBackendContext) : BodyLoweringPass {
class JsClassReferenceLowering(context: JsIrBackendContext) : ClassReferenceLowering(context) {
private val getClassData = context.intrinsics.jsClass

private val reflectionSymbols get() = context.reflectionSymbols
override fun callGetKClass(
returnType: IrType,
typeArgument: IrType
): IrCall {
val primitiveKClass =
getFinalPrimitiveKClass(returnType, typeArgument) ?: getOpenPrimitiveKClass(returnType, typeArgument)

if (primitiveKClass != null)
return primitiveKClass

return JsIrBuilder.buildCall(reflectionSymbols.getKClass, returnType, listOf(typeArgument))
.apply {
putValueArgument(0, callGetClassByType(typeArgument))
}
}

private fun callGetClassByType(type: IrType) =
JsIrBuilder.buildCall(
getClassData,
typeArguments = listOf(type),
origin = JsStatementOrigins.CLASS_REFERENCE
)
}

abstract class ClassReferenceLowering(val context: JsCommonBackendContext) : BodyLoweringPass {

protected val reflectionSymbols get() = context.reflectionSymbols

private val primitiveClassProperties by lazy(LazyThreadSafetyMode.NONE) {
reflectionSymbols.primitiveClassesObject.owner.declarations.filterIsInstance<IrProperty>()
Expand Down Expand Up @@ -96,7 +123,7 @@ class ClassReferenceLowering(val context: JsCommonBackendContext) : BodyLowering
)
}

private fun getFinalPrimitiveKClass(returnType: IrType, typeArgument: IrType): IrCall? {
protected fun getFinalPrimitiveKClass(returnType: IrType, typeArgument: IrType): IrCall? {
for ((typePredicate, v) in finalPrimitiveClasses) {
if (typePredicate(typeArgument))
return getPrimitiveClass(v, returnType)
Expand All @@ -106,7 +133,7 @@ class ClassReferenceLowering(val context: JsCommonBackendContext) : BodyLowering
}


private fun getOpenPrimitiveKClass(returnType: IrType, typeArgument: IrType): IrCall? {
protected fun getOpenPrimitiveKClass(returnType: IrType, typeArgument: IrType): IrCall? {
for ((typePredicate, v) in openPrimitiveClasses) {
if (typePredicate(typeArgument))
return getPrimitiveClass(v, returnType)
Expand All @@ -123,28 +150,10 @@ class ClassReferenceLowering(val context: JsCommonBackendContext) : BodyLowering
return null
}

private fun callGetKClass(
abstract fun callGetKClass(
returnType: IrType = reflectionSymbols.getKClass.owner.returnType,
typeArgument: IrType
): IrCall {
val primitiveKClass =
getFinalPrimitiveKClass(returnType, typeArgument) ?: getOpenPrimitiveKClass(returnType, typeArgument)

if (primitiveKClass != null)
return primitiveKClass

return JsIrBuilder.buildCall(reflectionSymbols.getKClass, returnType, listOf(typeArgument))
.apply {
putValueArgument(0, callGetClassByType(typeArgument))
}
}

private fun callGetClassByType(type: IrType) =
JsIrBuilder.buildCall(
reflectionSymbols.getClassData,
typeArguments = listOf(type),
origin = JsStatementOrigins.CLASS_REFERENCE
)
): IrCall

private fun buildCall(name: IrSimpleFunctionSymbol, vararg args: IrExpression): IrExpression =
JsIrBuilder.buildCall(name).apply {
Expand All @@ -154,7 +163,6 @@ class ClassReferenceLowering(val context: JsCommonBackendContext) : BodyLowering
}

private fun createKType(type: IrType, visitedTypeParams: MutableSet<IrTypeParameter>): IrExpression {

if (type is IrSimpleType)
return createSimpleKType(type, visitedTypeParams)
if (type is IrDynamicType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -489,8 +489,8 @@ private val staticMembersLoweringPhase = makeWasmModulePhase(
)

private val classReferenceLoweringPhase = makeWasmModulePhase(
::ClassReferenceLowering,
name = "ClassReferenceLowering",
::WasmClassReferenceLowering,
name = "WasmClassReferenceLowering",
description = "Handle class references"
)

Expand Down Expand Up @@ -681,6 +681,8 @@ val wasmPhases = SameTypeNamedCompilerPhase(

wasmStringSwitchOptimizerLowering then

associatedObjectsLowering then

complexExternalDeclarationsToTopLevelFunctionsLowering then
complexExternalDeclarationsUsagesLowering then

Expand Down Expand Up @@ -727,8 +729,6 @@ val wasmPhases = SameTypeNamedCompilerPhase(
eraseVirtualDispatchReceiverParametersTypes then
bridgesConstructionPhase then

associatedObjectsLowering then

objectDeclarationLoweringPhase then
genericReturnTypeLowering then
unitToVoidLowering then
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,9 @@ class WasmSymbols(

internal inner class WasmReflectionSymbols : ReflectionSymbols {
override val createKType: IrSimpleFunctionSymbol = getInternalFunction("createKType")
override val getClassData: IrSimpleFunctionSymbol = getInternalFunction("wasmGetTypeInfoData")
override val getKClass: IrSimpleFunctionSymbol = getInternalFunction("getKClass")
override val getKClassFromExpression: IrSimpleFunctionSymbol = getInternalFunction("getKClassFromExpression")
override val createDynamicKType: IrSimpleFunctionSymbol get() = error("Dynamic type is not supported by WASM")
override val createDynamicKType: IrSimpleFunctionSymbol get() = error("Dynamic type is not supported by Wasm")
override val createKTypeParameter: IrSimpleFunctionSymbol = getInternalFunction("createKTypeParameter")
override val getStarKTypeProjection = getInternalFunction("getStarKTypeProjection")
override val createCovariantKTypeProjection = getInternalFunction("createCovariantKTypeProjection")
Expand All @@ -67,6 +66,7 @@ class WasmSymbols(

val getTypeInfoTypeDataByPtr: IrSimpleFunctionSymbol = getInternalFunction("getTypeInfoTypeDataByPtr")
val wasmTypeInfoData: IrClassSymbol = getInternalClass("TypeInfoData")
val kClassImpl: IrClassSymbol = getInternalClass("KClassImpl")
}

internal val reflectionSymbols: WasmReflectionSymbols = WasmReflectionSymbols()
Expand Down Expand Up @@ -359,6 +359,8 @@ class WasmSymbols(

internal val throwAsJsException: IrSimpleFunctionSymbol =
getInternalFunction("throwAsJsException")

val kExternalClassImpl: IrClassSymbol = getInternalClass("KExternalClassImpl")
}

private val wasmExportClass = getIrClass(FqName("kotlin.wasm.WasmExport"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class AssociatedObjectsLowering(val context: WasmBackendContext) : FileLoweringP
for (klassAnnotation in declaration.annotations) {
val annotationClass = klassAnnotation.symbol.owner.parentClassOrNull ?: continue
if (klassAnnotation.valueArgumentsCount != 1) continue
if (declaration.isEffectivelyExternal()) continue
val associatedObject = klassAnnotation.associatedObject() ?: continue

val builder = cachedBuilder ?: context.createIrBuilder(context.wasmSymbols.initAssociatedObjects)
Expand Down Expand Up @@ -104,7 +105,7 @@ private fun IrBuilderWithScope.createAssociatedObjectAdd(
)
addCall.putValueArgument(
3,
irGetObjectValue(irBuiltIns.anyType, associatedObject)
irGetObjectValue(associatedObject.defaultType, associatedObject)
)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,19 @@ import org.jetbrains.kotlin.backend.common.lower.createIrBuilder
import org.jetbrains.kotlin.backend.wasm.WasmBackendContext
import org.jetbrains.kotlin.config.AnalysisFlags
import org.jetbrains.kotlin.config.languageVersionSettings
import org.jetbrains.kotlin.descriptors.ClassKind
import org.jetbrains.kotlin.ir.UNDEFINED_OFFSET
import org.jetbrains.kotlin.ir.backend.js.lower.calls.EnumIntrinsicsUtils
import org.jetbrains.kotlin.ir.backend.js.utils.erasedUpperBound
import org.jetbrains.kotlin.ir.backend.js.utils.isEqualsInheritedFromAny
import org.jetbrains.kotlin.ir.builders.*
import org.jetbrains.kotlin.ir.declarations.IrConstructor
import org.jetbrains.kotlin.ir.declarations.IrFile
import org.jetbrains.kotlin.ir.declarations.IrSimpleFunction
import org.jetbrains.kotlin.ir.expressions.IrCall
import org.jetbrains.kotlin.ir.expressions.IrExpression
import org.jetbrains.kotlin.ir.expressions.impl.IrConstructorCallImpl
import org.jetbrains.kotlin.ir.expressions.putClassTypeArgument
import org.jetbrains.kotlin.ir.util.toIrConst
import org.jetbrains.kotlin.ir.symbols.IrSimpleFunctionSymbol
import org.jetbrains.kotlin.ir.types.*
Expand Down Expand Up @@ -157,35 +162,35 @@ class BuiltInsLowering(val context: WasmBackendContext) : FileLoweringPass {
val newSymbol = irBuiltins.suspendFunctionN(arity).getSimpleFunction("invoke")!!
return irCall(call, newSymbol, argumentsAsReceivers = true)
}
symbols.reflectionSymbols.getClassData -> {
context.reflectionSymbols.getKClass -> {
val type = call.getTypeArgument(0)!!
val klass = type.classOrNull?.owner ?: error("Invalid type")

val typeId = builder.irCall(symbols.wasmTypeId).also {
it.putTypeArgument(0, type)
val constructorArgument: IrExpression
val kclassConstructor: IrConstructor
if (klass.isEffectivelyExternal()) {
check(context.configuration.get(JSConfigurationKeys.WASM_TARGET, WasmTarget.JS) == WasmTarget.JS) { "External classes reflection in WASI mode are not supported" }
kclassConstructor = symbols.jsRelatedSymbols.kExternalClassImpl.owner.constructors.first()
constructorArgument = getExternalKClassCtorArgument(type, builder)
} else {
kclassConstructor = symbols.reflectionSymbols.kClassImpl.owner.constructors.first()
constructorArgument = getKClassCtorArgument(type, builder)
}

if (!klass.isInterface) {
return builder.irCall(context.wasmSymbols.reflectionSymbols.getTypeInfoTypeDataByPtr).also {
it.putValueArgument(0, typeId)
}
} else {
val infoDataCtor = symbols.reflectionSymbols.wasmTypeInfoData.constructors.first()
val fqName = type.classFqName!!
val fqnShouldBeEmitted =
context.configuration.languageVersionSettings.getFlag(AnalysisFlags.allowFullyQualifiedNameInKClass)
val packageName = if (fqnShouldBeEmitted) fqName.parentOrNull()?.asString() ?: "" else ""
val typeName = fqName.shortName().asString()

return with(builder) {
irCallConstructor(infoDataCtor, emptyList()).also {
it.putValueArgument(0, typeId)
it.putValueArgument(1, packageName.toIrConst(context.irBuiltIns.stringType))
it.putValueArgument(2, typeName.toIrConst(context.irBuiltIns.stringType))
}
}
return IrConstructorCallImpl(
startOffset = UNDEFINED_OFFSET,
endOffset = UNDEFINED_OFFSET,
type = kclassConstructor.returnType,
symbol = kclassConstructor.symbol,
typeArgumentsCount = 1,
valueArgumentsCount = 1,
constructorTypeArgumentsCount = 0
).also {
it.putClassTypeArgument(0, type)
it.putValueArgument(0, constructorArgument)
}
}

symbols.enumValueOfIntrinsic ->
return EnumIntrinsicsUtils.transformEnumValueOfIntrinsic(call)
symbols.enumValuesIntrinsic ->
Expand All @@ -197,6 +202,40 @@ class BuiltInsLowering(val context: WasmBackendContext) : FileLoweringPass {
return call
}

private fun getKClassCtorArgument(type: IrType, builder: DeclarationIrBuilder): IrExpression {
val klass = type.classOrNull?.owner ?: error("Invalid type")

val typeId = builder.irCall(symbols.wasmTypeId).also {
it.putTypeArgument(0, type)
}

if (!klass.isInterface) {
return builder.irCall(context.wasmSymbols.reflectionSymbols.getTypeInfoTypeDataByPtr).also {
it.putValueArgument(0, typeId)
}
} else {
val fqName = type.classFqName!!
val fqnShouldBeEmitted =
context.configuration.languageVersionSettings.getFlag(AnalysisFlags.allowFullyQualifiedNameInKClass)
val packageName = if (fqnShouldBeEmitted) fqName.parentOrNull()?.asString() ?: "" else ""
val typeName = fqName.shortName().asString()

return builder.irCallConstructor(symbols.reflectionSymbols.wasmTypeInfoData.constructors.first(), emptyList()).also {
it.putValueArgument(0, typeId)
it.putValueArgument(1, packageName.toIrConst(context.irBuiltIns.stringType))
it.putValueArgument(2, typeName.toIrConst(context.irBuiltIns.stringType))
}
}
}

private fun getExternalKClassCtorArgument(type: IrType, builder: DeclarationIrBuilder): IrExpression {
val klass = type.classOrNull?.owner ?: error("Invalid type")
check(klass.kind != ClassKind.INTERFACE) { "External interface must not be a class literal" }
val classGetClassFunction = context.mapping.wasmGetJsClass[klass]!!
val wrappedGetClassIfAny = context.mapping.wasmJsInteropFunctionToWrapper[classGetClassFunction] ?: classGetClassFunction
return builder.irCall(wrappedGetClassIfAny)
}

override fun lower(irFile: IrFile) {
val builder = context.createIrBuilder(irFile.symbol)
irFile.transformChildrenVoid(object : IrElementTransformerVoidWithContext() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,7 @@ import org.jetbrains.kotlin.backend.wasm.utils.getWasmImportDescriptor
import org.jetbrains.kotlin.descriptors.ClassKind
import org.jetbrains.kotlin.ir.IrElement
import org.jetbrains.kotlin.ir.UNDEFINED_OFFSET
import org.jetbrains.kotlin.ir.backend.js.utils.getJsModule
import org.jetbrains.kotlin.ir.backend.js.utils.getJsNameOrKotlinName
import org.jetbrains.kotlin.ir.backend.js.utils.getJsQualifier
import org.jetbrains.kotlin.ir.backend.js.utils.realOverrideTarget
import org.jetbrains.kotlin.ir.backend.js.utils.*
import org.jetbrains.kotlin.ir.builders.declarations.addValueParameter
import org.jetbrains.kotlin.ir.builders.declarations.buildFun
import org.jetbrains.kotlin.ir.builders.irCallConstructor
Expand All @@ -28,6 +25,7 @@ import org.jetbrains.kotlin.ir.expressions.*
import org.jetbrains.kotlin.ir.expressions.impl.IrCallImpl
import org.jetbrains.kotlin.ir.expressions.impl.IrConstImpl
import org.jetbrains.kotlin.ir.types.IrType
import org.jetbrains.kotlin.ir.types.makeNullable
import org.jetbrains.kotlin.ir.util.*
import org.jetbrains.kotlin.ir.visitors.*
import org.jetbrains.kotlin.name.Name
Expand Down Expand Up @@ -96,8 +94,10 @@ class ComplexExternalDeclarationsToTopLevelFunctionsLowering(val context: WasmBa
if (klass.kind == ClassKind.OBJECT)
generateExternalObjectInstanceGetter(klass)

if (klass.kind != ClassKind.INTERFACE)
if (klass.kind != ClassKind.INTERFACE) {
generateInstanceCheckForExternalClass(klass)
generateGetClassForExternalClass(klass)
}
}

fun processExternalProperty(property: IrProperty) {
Expand Down Expand Up @@ -374,6 +374,18 @@ class ComplexExternalDeclarationsToTopLevelFunctionsLowering(val context: WasmBa
}
}

fun generateGetClassForExternalClass(klass: IrClass) {
context.mapping.wasmGetJsClass[klass] = createExternalJsFunction(
klass.name,
"_\$external_class_get",
resultType = context.wasmSymbols.jsRelatedSymbols.jsAnyType.makeNullable(),
jsCode = buildString {
append("() => ")
appendExternalClassReference(klass)
}
)
}

private fun createExternalJsFunction(
originalName: Name,
suffix: String,
Expand Down

0 comments on commit 1a630b5

Please sign in to comment.