Skip to content

Commit

Permalink
Performance improvements, particularly for deeply nested shapes.
Browse files Browse the repository at this point in the history
  • Loading branch information
Benjamin-Dobell committed Jan 3, 2023
1 parent f01e7eb commit b8c20ac
Show file tree
Hide file tree
Showing 20 changed files with 218 additions and 130 deletions.
Expand Up @@ -43,7 +43,7 @@ class AssignTypeInspection : StrictInspection() {
// Get owner class
val assigneeOwnerType = assignee.guessParentType(context)

if (assigneeOwnerType is TyTable && resolvedValue is TyTable && assigneeOwnerType.table == resolvedValue.table) {
if (assigneeOwnerType is TyTable && resolvedValue is TyTable && assigneeOwnerType.psi == resolvedValue.psi) {
return
}

Expand Down Expand Up @@ -87,7 +87,7 @@ class AssignTypeInspection : StrictInspection() {

val variableType = assignee.guessType(context)

if (variableType == null || (variableType is TyTable && resolvedValue is TyTable && variableType.table == resolvedValue.table)) {
if (variableType == null || (variableType is TyTable && resolvedValue is TyTable && variableType.psi == resolvedValue.psi)) {
return
}

Expand Down
Expand Up @@ -182,6 +182,8 @@ class LuaCommentImpl(node: ASTNode) : ASTWrapperPsiElement(node), LuaComment {
val map = list.associateBy { it.varName }

object : TySubstitutor() {
override val name = "name substitutor"

override fun substitute(context: SearchContext, clazz: ITyClass): ITy {
return map[clazz.className] ?: super.substitute(context, clazz)
}
Expand Down
9 changes: 6 additions & 3 deletions src/main/java/com/tang/intellij/lua/psi/LuaParamInfo.kt
Expand Up @@ -23,7 +23,10 @@ import com.tang.intellij.lua.Constants
import com.tang.intellij.lua.search.SearchContext
import com.tang.intellij.lua.stubs.readTyNullable
import com.tang.intellij.lua.stubs.writeTyNullable
import com.tang.intellij.lua.ty.*
import com.tang.intellij.lua.ty.ITy
import com.tang.intellij.lua.ty.ITySubstitutor
import com.tang.intellij.lua.ty.Primitives
import com.tang.intellij.lua.ty.TyMultipleResults

/**
* parameter info
Expand All @@ -36,14 +39,14 @@ class LuaParamInfo(val name: String, val ty: ITy?) {
return other is LuaParamInfo && other.ty == ty
}

fun equals(context: SearchContext, other: LuaParamInfo): Boolean {
fun equals(context: SearchContext, other: LuaParamInfo, equalityFlags: Int): Boolean {
if (ty == null) {
return other.ty == null
} else if (other.ty == null) {
return false
}

return ty.equals(context, other.ty)
return ty.equals(context, other.ty, equalityFlags)
}

override fun hashCode(): Int {
Expand Down
Expand Up @@ -153,7 +153,7 @@ private class ScopedTypeTreeScope(override val psi: LuaTypeScope, override val t
val classTag = if (cls is TySerializedClass) {
LuaClassIndex.find(context, cls.className)
} else if (cls is TyPsiDocClass) {
cls.tagClass
cls.psi
} else null

// Need to ensure we don't check the same scope *without* beforeIndex
Expand Down
5 changes: 2 additions & 3 deletions src/main/java/com/tang/intellij/lua/psi/PsiExtension.kt
Expand Up @@ -104,9 +104,8 @@ private fun LuaExpression<*>.shouldBeInternal(context: SearchContext): ITy? {
var ret: ITy = Primitives.VOID
Ty.eachResolved(context, fTy) {
if (it is ITyFunction) {
var sig = it.matchSignature(context, p2)?.substitutedSignature ?: it.mainSignature
val substitutor = p2.createSubstitutor(context, sig)
sig = sig.substitute(context, substitutor)
val sig = it.matchSignature(context, p2)?.substitutedSignature
?: it.mainSignature.substitute(context, p2.createSubstitutor(context, it.mainSignature))
ret = ret.union(context, sig.getArgTy(idx))
}
}
Expand Down
14 changes: 10 additions & 4 deletions src/main/java/com/tang/intellij/lua/ty/ProblemUtil.kt
Expand Up @@ -90,6 +90,12 @@ object ProblemUtil {
return false
}

// We perform a non-structural (i.e. inheritance) check first as a happy path optimization. This is a *very*
// significant optimization when you've deeply nested shapes whose members are other (non-anonymous) shapes.
if (target.contravariantOf(context, source, varianceFlags or TyVarianceFlags.NON_STRUCTURAL)) {
return true;
}

val sourceSubstitutor = source.getMemberSubstitutor(context)
val targetSubstitutor = target.getMemberSubstitutor(context)

Expand Down Expand Up @@ -204,7 +210,7 @@ object ProblemUtil {
val targetMemberTy = (if (indexTy != null) {
val targetMember = target.findIndexer(context, indexTy)

if (targetMember?.guessIndexType(context)?.equals(context, indexTy) == true) {
if (targetMember?.guessIndexType(context)?.equals(context, indexTy, 0) == true) {
// If the target index type == source index type, then we have already checked compatibility of this member above.
return@processMembers true
}
Expand Down Expand Up @@ -232,7 +238,7 @@ object ProblemUtil {
// TODO: DRY
if (varianceFlags and TyVarianceFlags.STRICT_UNKNOWN != 0 || !sourceMemberTy.isUnknown) {
if (varianceFlags and TyVarianceFlags.WIDEN_TABLES == 0) {
if (!targetMemberTy.equals(context, sourceMemberTy)) {
if (!targetMemberTy.equals(context, sourceMemberTy, 0)) {
isContravariant = false

if (processProblem != null && sourceElement != null) {
Expand Down Expand Up @@ -311,7 +317,7 @@ object ProblemUtil {
resolvedSourceTy.lazyInit(context)

if ((varianceFlags and TyVarianceFlags.NON_STRUCTURAL == 0 || resolvedSourceTy.isAnonymousTable) && resolvedSourceTy.isShape(context)) {
val sourceIsInline = resolvedSourceTy is TyTable && resolvedSourceTy.table == sourceElement
val sourceIsInline = resolvedSourceTy is TyTable && resolvedSourceTy.psi == sourceElement
val indexes = sortedMapOf<Int, PsiElement>()
var foundNumberIndexer = false

Expand Down Expand Up @@ -453,7 +459,7 @@ object ProblemUtil {
base.lazyInit(context)
}

if (varianceFlags and TyVarianceFlags.NON_STRUCTURAL == 0 && resolvedSourceTy is TyTable && resolvedSourceTy.table == sourceElement && base.isShape(context)) {
if (varianceFlags and TyVarianceFlags.NON_STRUCTURAL == 0 && resolvedSourceTy is TyTable && resolvedSourceTy.psi == sourceElement && base.isShape(context)) {
isContravariant = contravariantOfShape(
context,
resolvedTargetTy,
Expand Down
58 changes: 41 additions & 17 deletions src/main/java/com/tang/intellij/lua/ty/Ty.kt
Expand Up @@ -29,9 +29,11 @@ import com.tang.intellij.lua.codeInsight.inspection.MatchFunctionSignatureInspec
import com.tang.intellij.lua.ext.recursionGuard
import com.tang.intellij.lua.project.LuaSettings
import com.tang.intellij.lua.psi.LuaCallExpr
import com.tang.intellij.lua.psi.LuaPsiElement
import com.tang.intellij.lua.psi.LuaTableExpr
import com.tang.intellij.lua.psi.argList
import com.tang.intellij.lua.search.SearchContext
import conditionallyCached
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.contract

Expand Down Expand Up @@ -68,6 +70,7 @@ class TyFlags {
const val UNKNOWN = 0x20 // Unless STRICT_UNKNOWN is enabled, this type is covariant of all other types.
}
}

class TyVarianceFlags {
companion object {
const val STRICT_UNKNOWN = 0x1 // When enabled UNKNOWN types are no longer treated as covariant of all types.
Expand All @@ -78,10 +81,28 @@ class TyVarianceFlags {
}
}

class TyEqualityFlags {
companion object {
const val NON_STRUCTURAL = 0x1 // Treat shapes as classes i.e. a shape is only covariant of another shape if it explicitly inherits from it.

fun fromVarianceFlags(varianceFlags: Int): Int {
return if (varianceFlags and TyVarianceFlags.NON_STRUCTURAL != 0) {
TyEqualityFlags.NON_STRUCTURAL
} else {
0
}
}
}
}

data class SignatureMatchResult(val signature: IFunSignature?, val substitutedSignature: IFunSignature?, val returnTy: ITy)

typealias ProcessTypeMember = (ownerTy: ITy, member: TypeMember) -> Boolean

interface IPsiTy<T : LuaPsiElement> {
val psi: T
}

interface ITy : Comparable<ITy> {
val kind: TyKind

Expand All @@ -91,7 +112,7 @@ interface ITy : Comparable<ITy> {

val booleanType: ITy

fun equals(context: SearchContext, other: ITy): Boolean
fun equals(context: SearchContext, other: ITy, equalityFlags: Int): Boolean

fun union(context: SearchContext, ty: ITy): ITy

Expand Down Expand Up @@ -501,8 +522,9 @@ abstract class Ty(override val kind: TyKind) : ITy {

final override var flags: Int = 0

override val displayName: String
get() = TyRenderer.SIMPLE.render(this)
override val displayName by conditionallyCached({ !(this is IPsiTy<*>) || !SearchContext.get(psi.project).isDumb }) {
TyRenderer.SIMPLE.render(this)
}

// Lazy initialization because Primitives.TRUE is itself a Ty that needs to be instantiated and refers to itself.
override val booleanType: ITy by lazy { Primitives.TRUE }
Expand Down Expand Up @@ -548,7 +570,19 @@ abstract class Ty(override val kind: TyKind) : ITy {

val resolvedOther = resolve(context, other)

if (this.equals(context, resolvedOther)) {
if (varianceFlags and TyVarianceFlags.NON_STRUCTURAL == 0 && isShape(context)) {
val isContravariant: Boolean? = recursionGuard(resolvedOther, {
// Note: ProblemUtil.contravariantOfShape will call back into this method with
// TyVarianceFlags.NON_STRUCTURAL set as a fast nominal check, before checking structurally.
ProblemUtil.contravariantOfShape(context, this, resolvedOther, varianceFlags)
})

if (isContravariant != null) {
return isContravariant
}
}

if (this.equals(context, resolvedOther, TyEqualityFlags.fromVarianceFlags(varianceFlags))) {
return true
}

Expand All @@ -566,16 +600,6 @@ abstract class Ty(override val kind: TyKind) : ITy {
return true
}

if ((varianceFlags and TyVarianceFlags.NON_STRUCTURAL == 0 || other.isAnonymousTable) && isShape(context)) {
val isContravariant: Boolean? = recursionGuard(resolvedOther, {
ProblemUtil.contravariantOfShape(context, this, resolvedOther, varianceFlags)
})

if (isContravariant != null) {
return isContravariant
}
}

val otherSuper = other.getSuperType(context)
return otherSuper != null && contravariantOf(context, otherSuper, varianceFlags)
}
Expand Down Expand Up @@ -871,7 +895,7 @@ class TyUnknown : Ty(TyKind.Unknown) {
this.flags = this.flags or TyFlags.UNKNOWN
}

override fun equals(context: SearchContext, other: ITy): Boolean {
override fun equals(context: SearchContext, other: ITy, equalityFlags: Int): Boolean {
if (other === Primitives.UNKNOWN) {
return true
}
Expand Down Expand Up @@ -904,7 +928,7 @@ class TyNil : Ty(TyKind.Nil) {

override val booleanType = Primitives.FALSE

override fun equals(context: SearchContext, other: ITy): Boolean {
override fun equals(context: SearchContext, other: ITy, equalityFlags: Int): Boolean {
if (other === Primitives.NIL) {
return true
}
Expand All @@ -931,7 +955,7 @@ class TyNil : Ty(TyKind.Nil) {

class TyVoid : Ty(TyKind.Void) {

override fun equals(context: SearchContext, other: ITy): Boolean {
override fun equals(context: SearchContext, other: ITy, equalityFlags: Int): Boolean {
if (other === Primitives.VOID) {
return true
}
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/com/tang/intellij/lua/ty/TyAlias.kt
Expand Up @@ -53,12 +53,12 @@ class TyAlias(override val name: String,
return other is ITyAlias && other.name == name && other.flags == flags
}

override fun equals(context: SearchContext, other: ITy): Boolean {
override fun equals(context: SearchContext, other: ITy, equalityFlags: Int): Boolean {
if (this === other) {
return true
}

return ty.equals(context, other)
return ty.equals(context, other, equalityFlags)
}

override fun hashCode(): Int {
Expand Down
18 changes: 9 additions & 9 deletions src/main/java/com/tang/intellij/lua/ty/TyArray.kt
Expand Up @@ -31,13 +31,13 @@ open class TyArray(override val base: ITy) : Ty(TyKind.Array), ITyArray {
return other is ITyArray && base == other.base
}

override fun equals(context: SearchContext, other: ITy): Boolean {
override fun equals(context: SearchContext, other: ITy, equalityFlags: Int): Boolean {
if (this === other) {
return true
}

val resolvedOther = Ty.resolve(context, other)
return resolvedOther is ITyArray && base.equals(context, resolvedOther.base)
return resolvedOther is ITyArray && base.equals(context, resolvedOther.base, equalityFlags)
}

override fun hashCode(): Int {
Expand All @@ -56,7 +56,7 @@ open class TyArray(override val base: ITy) : Ty(TyKind.Array), ITyArray {
val resolvedBase = Ty.resolve(context, base)

if (other is ITyArray) {
return resolvedBase.equals(context, other.base)
return resolvedBase.equals(context, other.base, TyEqualityFlags.fromVarianceFlags(varianceFlags))
|| (varianceFlags and TyVarianceFlags.WIDEN_TABLES != 0 && resolvedBase.contravariantOf(context, other.base, varianceFlags))
}

Expand Down Expand Up @@ -88,7 +88,7 @@ open class TyArray(override val base: ITy) : Ty(TyKind.Array), ITyArray {
}

return varianceFlags and TyVarianceFlags.WIDEN_TABLES != 0
|| Ty.resolve(context, resolvedBase).equals(context, indexedMemberType)
|| Ty.resolve(context, resolvedBase).equals(context, indexedMemberType, TyEqualityFlags.fromVarianceFlags(varianceFlags))
|| (resolvedBase.isUnknown && varianceFlags and TyVarianceFlags.STRICT_UNKNOWN == 0)
}

Expand Down Expand Up @@ -185,14 +185,14 @@ object TyArraySerializer : TySerializer<ITyArray>() {
}
}

class TyDocArray(val luaDocArrTy: LuaDocArrTy, base: ITy = luaDocArrTy.ty.getType()) : TyArray(base) {
class TyDocArray(override val psi: LuaDocArrTy, base: ITy = psi.ty.getType()) : TyArray(base), IPsiTy<LuaDocArrTy> {
override fun processIndexer(context: SearchContext, indexTy: ITy, exact: Boolean, deep: Boolean, process: ProcessTypeMember): Boolean {
if (exact) {
if (Primitives.NUMBER.equals(context, indexTy)) {
return process(this, luaDocArrTy)
if (Primitives.NUMBER.equals(context, indexTy, 0)) {
return process(this, psi)
}
} else if (Primitives.NUMBER.contravariantOf(context, indexTy, TyVarianceFlags.STRICT_UNKNOWN)) {
return process(this, luaDocArrTy)
return process(this, psi)
}

return true
Expand All @@ -202,7 +202,7 @@ class TyDocArray(val luaDocArrTy: LuaDocArrTy, base: ITy = luaDocArrTy.ty.getTyp
val substitutedBase = TyMultipleResults.getResult(context, base.substitute(context, substitutor))

return if (substitutedBase !== base) {
TyDocArray(luaDocArrTy, substitutedBase)
TyDocArray(psi, substitutedBase)
} else {
this
}
Expand Down

0 comments on commit b8c20ac

Please sign in to comment.