Skip to content

Commit

Permalink
Reduce unnecessary resolve in type providers
Browse files Browse the repository at this point in the history
GitOrigin-RevId: e870ae4e4f5ec206c2389c62f3d8f74c6ff824a3
  • Loading branch information
sproshev authored and intellij-monorepo-bot committed Nov 11, 2020
1 parent 3da8305 commit 7aef830
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 146 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package com.jetbrains.python.codeInsight.stdlib

import com.intellij.openapi.util.Ref
import com.intellij.psi.PsiElement
import com.intellij.psi.util.PsiTreeUtil
import com.intellij.util.containers.isNullOrEmpty
import com.jetbrains.python.PyNames
Expand All @@ -20,7 +21,23 @@ import one.util.streamex.StreamEx
class PyDataclassTypeProvider : PyTypeProviderBase() {

override fun getReferenceExpressionType(referenceExpression: PyReferenceExpression, context: TypeEvalContext): PyType? {
return getDataclassTypeForCallee(referenceExpression, context) ?: getDataclassesReplaceType(referenceExpression, context)
return getDataclassesReplaceType(referenceExpression, context)
}

override fun getReferenceType(referenceTarget: PsiElement, context: TypeEvalContext, anchor: PsiElement?): Ref<PyType>? {
val result = when {
referenceTarget is PyClass && anchor is PyCallExpression -> getDataclassTypeForClass(referenceTarget, context)
referenceTarget is PyParameter && referenceTarget.isSelf && anchor is PyCallExpression -> {
PsiTreeUtil.getParentOfType(referenceTarget, PyFunction::class.java)
?.takeIf { it.modifier == PyFunction.Modifier.CLASSMETHOD }
?.let {
it.containingClass?.let { getDataclassTypeForClass(it, context) }
}
}
else -> null
}

return PyTypeUtil.notNullToRef(result)
}

override fun getParameterType(param: PyNamedParameter, func: PyFunction, context: TypeEvalContext): Ref<PyType>? {
Expand Down Expand Up @@ -48,30 +65,6 @@ class PyDataclassTypeProvider : PyTypeProviderBase() {

companion object {

private fun getDataclassTypeForCallee(referenceExpression: PyReferenceExpression, context: TypeEvalContext): PyCallableType? {
if (PyCallExpressionNavigator.getPyCallExpressionByCallee(referenceExpression) == null) return null

val resolveContext = PyResolveContext.defaultContext().withTypeEvalContext(context)
val resolveResults = referenceExpression.getReference(resolveContext).multiResolve(false)

return PyUtil.filterTopPriorityResults(resolveResults)
.asSequence()
.map {
when {
it is PyClass -> getDataclassTypeForClass(it, context)
it is PyParameter && it.isSelf -> {
PsiTreeUtil.getParentOfType(it, PyFunction::class.java)
?.takeIf { it.modifier == PyFunction.Modifier.CLASSMETHOD }
?.let {
it.containingClass?.let { getDataclassTypeForClass(it, context) }
}
}
else -> null
}
}
.firstOrNull { it != null }
}

private fun getDataclassesReplaceType(referenceExpression: PyReferenceExpression, context: TypeEvalContext): PyCallableType? {
val call = PyCallExpressionNavigator.getPyCallExpressionByCallee(referenceExpression) ?: return null
val callee = call.callee as? PyReferenceExpression ?: return null
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import com.jetbrains.python.psi.*
import com.jetbrains.python.psi.impl.PyCallExpressionNavigator
import com.jetbrains.python.psi.impl.StubAwareComputation
import com.jetbrains.python.psi.impl.stubs.PyNamedTupleStubImpl
import com.jetbrains.python.psi.resolve.PyResolveContext
import com.jetbrains.python.psi.stubs.PyNamedTupleStub
import com.jetbrains.python.psi.types.*
import one.util.streamex.StreamEx
Expand All @@ -37,11 +36,6 @@ class PyNamedTupleTypeProvider : PyTypeProviderBase() {
return fieldTypeForTypingNTFunctionInheritor
}

val namedTupleTypeForCallee = getNamedTupleTypeForCallee(referenceExpression, context)
if (namedTupleTypeForCallee != null) {
return namedTupleTypeForCallee
}

val namedTupleReplaceType = getNamedTupleReplaceType(referenceExpression, context)
if (namedTupleReplaceType != null) {
return namedTupleReplaceType
Expand Down Expand Up @@ -71,6 +65,7 @@ class PyNamedTupleTypeProvider : PyTypeProviderBase() {
return when {
referenceTarget is PyFunction && anchor is PyCallExpression -> getNamedTupleFunctionType(referenceTarget, context, anchor)
referenceTarget is PyTargetExpression -> getNamedTupleTypeForTarget(referenceTarget, context)
referenceTarget is PyClass && anchor is PyCallExpression -> getNamedTupleTypeForNTInheritorAsCallee(referenceTarget, context)
else -> null
}
}
Expand All @@ -94,43 +89,6 @@ class PyNamedTupleTypeProvider : PyTypeProviderBase() {
)
}

private fun getNamedTupleTypeForCallee(referenceExpression: PyReferenceExpression, context: TypeEvalContext): PyType? {
if (PyCallExpressionNavigator.getPyCallExpressionByCallee(referenceExpression) == null) return null

val resolveContext = PyResolveContext.defaultContext().withTypeEvalContext(context)
val resolveResults = referenceExpression.getReference(resolveContext).multiResolve(false)

for (element in PyUtil.filterTopPriorityResults(resolveResults)) {
if (element is PyTargetExpression) {
val result = getNamedTupleTypeForTarget(element, context)
if (result != null) {
return result
}
}

if (element is PyClass) {
val result = getNamedTupleTypeForTypingNTInheritorAsCallee(element, context)
if (result != null) {
return result
}
}

if (element is PyTypedElement) {
val type = context.getType(element)
if (type is PyClassLikeType) {
val superClassTypes = type.getSuperClassTypes(context)

val superNTType = superClassTypes.asSequence().filterIsInstance<PyNamedTupleType>().firstOrNull()
if (superNTType != null) {
return PyCallableTypeImpl(superNTType.getParameters(context), type.toInstance())
}
}
}
}

return null
}

private fun getNamedTupleReplaceType(referenceExpression: PyReferenceExpression, context: TypeEvalContext): PyCallableType? {
val call = PyCallExpressionNavigator.getPyCallExpressionByCallee(referenceExpression) ?: return null

Expand Down Expand Up @@ -190,16 +148,21 @@ class PyNamedTupleTypeProvider : PyTypeProviderBase() {
.compute(context)
}

private fun getNamedTupleTypeForTypingNTInheritorAsCallee(cls: PyClass, context: TypeEvalContext): PyType? {
if (isTypingNamedTupleDirectInheritor(cls, context)) {
private fun getNamedTupleTypeForNTInheritorAsCallee(cls: PyClass, context: TypeEvalContext): PyType? {
val parameters = if (isTypingNamedTupleDirectInheritor(cls, context)) {
val name = cls.name ?: return null
val tupleClass = PyPsiFacade.getInstance(cls.project).createClassByQName(PyTypingTypeProvider.NAMEDTUPLE, cls) ?: return null
val namedTupleType = PyNamedTupleType(tupleClass, name, collectTypingNTInheritorFields(cls, context), true, true, cls)

return PyCallableTypeImpl(namedTupleType.getParameters(context), cls.getType(context)?.toInstance())
namedTupleType.getParameters(context)
}
else {
val superNTType = cls.getSuperClassTypes(context).firstOrNull(PyNamedTupleType::class.java::isInstance) ?: return null

return null
superNTType.getParameters(context)
}

return PyCallableTypeImpl(parameters, cls.getType(context)?.toInstance())
}

private fun getNamedTupleTypeFromStub(targetOrCall: PsiElement, stub: PyNamedTupleStub?, context: TypeEvalContext): PyNamedTupleType? {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import com.jetbrains.python.psi.impl.PyCallExpressionNavigator
import com.jetbrains.python.psi.impl.PyEvaluator
import com.jetbrains.python.psi.impl.StubAwareComputation
import com.jetbrains.python.psi.impl.stubs.PyTypedDictStubImpl
import com.jetbrains.python.psi.resolve.PyResolveContext
import com.jetbrains.python.psi.stubs.PyTypedDictStub
import com.jetbrains.python.psi.types.*
import com.jetbrains.python.psi.types.PyTypedDictType.Companion.TYPED_DICT_FIELDS_PARAMETER
Expand Down Expand Up @@ -105,53 +104,25 @@ class PyTypedDictTypeProvider : PyTypeProviderBase() {
private fun getTypedDictTypeForCallee(referenceExpression: PyReferenceExpression, context: TypeEvalContext): PyType? {
if (PyCallExpressionNavigator.getPyCallExpressionByCallee(referenceExpression) == null) return null

val resolveContext = PyResolveContext.defaultContext().withTypeEvalContext(context)
val resolveResults = referenceExpression.getReference(resolveContext).multiResolve(false)

for (element in PyUtil.filterTopPriorityResults(resolveResults)) {
if (element is PyTargetExpression) {
val result = getTypedDictTypeForTarget(element, context)
if (result != null) {
return result
}
}

if (element is PyClass) {
val result = getTypedDictTypeForTypingTDInheritorAsCallee(element, context, false)
if (result != null) {
return result
}
}

if (element is PyTypedElement) {
val type = context.getType(element)
if (type is PyClassType) {
if (isTypingTypedDictInheritor(type.pyClass, context)) {
return getTypedDictTypeForTypingTDInheritorAsCallee(type.pyClass, context, false)
}
}
}
if (isTypedDict(referenceExpression, context)) {
val parameters = mutableListOf<PyCallableParameter>()

if (isTypedDict(referenceExpression, context)) {
val parameters = mutableListOf<PyCallableParameter>()

val builtinCache = PyBuiltinCache.getInstance(referenceExpression)
val languageLevel = LanguageLevel.forElement(referenceExpression)
val generator = PyElementGenerator.getInstance(referenceExpression.project)

parameters.add(PyCallableParameterImpl.nonPsi(TYPED_DICT_NAME_PARAMETER, builtinCache.getStringType(languageLevel)))
val dictClassType = builtinCache.dictType
parameters.add(PyCallableParameterImpl.nonPsi(TYPED_DICT_FIELDS_PARAMETER,
if (dictClassType != null) PyCollectionTypeImpl(dictClassType.pyClass, false,
listOf(builtinCache.strType, null))
else dictClassType))
parameters.add(
PyCallableParameterImpl.nonPsi(TYPED_DICT_TOTAL_PARAMETER,
builtinCache.boolType,
generator.createExpressionFromText(languageLevel, PyNames.TRUE)))

return PyCallableTypeImpl(parameters, null)
}
val builtinCache = PyBuiltinCache.getInstance(referenceExpression)
val languageLevel = LanguageLevel.forElement(referenceExpression)
val generator = PyElementGenerator.getInstance(referenceExpression.project)

parameters.add(PyCallableParameterImpl.nonPsi(TYPED_DICT_NAME_PARAMETER, builtinCache.getStringType(languageLevel)))
val dictClassType = builtinCache.dictType
parameters.add(PyCallableParameterImpl.nonPsi(TYPED_DICT_FIELDS_PARAMETER,
if (dictClassType != null) PyCollectionTypeImpl(dictClassType.pyClass, false,
listOf(builtinCache.strType, null))
else dictClassType))
parameters.add(
PyCallableParameterImpl.nonPsi(TYPED_DICT_TOTAL_PARAMETER,
builtinCache.boolType,
generator.createExpressionFromText(languageLevel, PyNames.TRUE)))

return PyCallableTypeImpl(parameters, null)
}

return null
Expand Down Expand Up @@ -396,7 +367,7 @@ class PyTypedDictTypeProvider : PyTypeProviderBase() {
return typedDictFieldsFromKeysAndValues(fields, context)
}

private fun typedDictFieldsFromKeysAndValues(fields: Map<String, PyExpression?>, context: TypeEvalContext): TDFields? {
private fun typedDictFieldsFromKeysAndValues(fields: Map<String, PyExpression?>, context: TypeEvalContext): TDFields {
val result = TDFields()
for ((name, type) in fields) {
result[name] = if (type != null) PyTypedDictType.FieldTypeAndTotality(context.getType(type))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,9 @@

import com.intellij.openapi.util.Ref;
import com.intellij.psi.PsiElement;
import com.intellij.psi.ResolveResult;
import com.jetbrains.python.psi.*;
import com.jetbrains.python.psi.impl.PyCallExpressionNavigator;
import com.jetbrains.python.psi.impl.StubAwareComputation;
import com.jetbrains.python.psi.impl.stubs.PyTypingNewTypeStubImpl;
import com.jetbrains.python.psi.resolve.PyResolveContext;
import com.jetbrains.python.psi.stubs.PyTypingNewTypeStub;
import com.jetbrains.python.psi.types.*;
import org.jetbrains.annotations.NotNull;
Expand All @@ -24,11 +21,6 @@ public class PyTypingNewTypeTypeProvider extends PyTypeProviderBase {
: null;
}

@Override
public @Nullable PyType getReferenceExpressionType(@NotNull PyReferenceExpression referenceExpression, @NotNull TypeEvalContext context) {
return getNewTypeForCallee(referenceExpression, context);
}

@Override
public Ref<PyType> getReferenceType(@NotNull PsiElement referenceTarget, @NotNull TypeEvalContext context, @Nullable PsiElement anchor) {
if (referenceTarget instanceof PyTargetExpression) {
Expand All @@ -38,24 +30,6 @@ public Ref<PyType> getReferenceType(@NotNull PsiElement referenceTarget, @NotNul
return null;
}

@Nullable
private static PyTypingNewType getNewTypeForCallee(@NotNull PyReferenceExpression referenceExpression, @NotNull TypeEvalContext context) {
if (PyCallExpressionNavigator.getPyCallExpressionByCallee(referenceExpression) == null) return null;

final PyResolveContext resolveContext = PyResolveContext.defaultContext().withTypeEvalContext(context);
final ResolveResult[] resolveResults = referenceExpression.getReference(resolveContext).multiResolve(false);

for (PsiElement element : PyUtil.filterTopPriorityResults(resolveResults)) {
if (element instanceof PyTargetExpression) {
final PyTypingNewType typeForTarget = getNewTypeForTarget((PyTargetExpression)element, context);
if (typeForTarget != null) {
return typeForTarget;
}
}
}
return null;
}

@Nullable
private static PyTypingNewType getNewTypeForTarget(@NotNull PyTargetExpression target, @NotNull TypeEvalContext context) {
return StubAwareComputation.on(target)
Expand Down

0 comments on commit 7aef830

Please sign in to comment.