diff --git a/python/python-psi-impl/src/com/jetbrains/python/codeInsight/stdlib/PyDataclassTypeProvider.kt b/python/python-psi-impl/src/com/jetbrains/python/codeInsight/stdlib/PyDataclassTypeProvider.kt index 76db699dadeda..5fe395c6272ff 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/codeInsight/stdlib/PyDataclassTypeProvider.kt +++ b/python/python-psi-impl/src/com/jetbrains/python/codeInsight/stdlib/PyDataclassTypeProvider.kt @@ -4,6 +4,7 @@ package com.jetbrains.python.codeInsight.stdlib import com.intellij.openapi.util.Ref +import com.intellij.psi.PsiElement import com.intellij.psi.util.PsiTreeUtil import com.intellij.util.containers.isNullOrEmpty import com.jetbrains.python.PyNames @@ -20,7 +21,23 @@ import one.util.streamex.StreamEx class PyDataclassTypeProvider : PyTypeProviderBase() { override fun getReferenceExpressionType(referenceExpression: PyReferenceExpression, context: TypeEvalContext): PyType? { - return getDataclassTypeForCallee(referenceExpression, context) ?: getDataclassesReplaceType(referenceExpression, context) + return getDataclassesReplaceType(referenceExpression, context) + } + + override fun getReferenceType(referenceTarget: PsiElement, context: TypeEvalContext, anchor: PsiElement?): Ref? { + val result = when { + referenceTarget is PyClass && anchor is PyCallExpression -> getDataclassTypeForClass(referenceTarget, context) + referenceTarget is PyParameter && referenceTarget.isSelf && anchor is PyCallExpression -> { + PsiTreeUtil.getParentOfType(referenceTarget, PyFunction::class.java) + ?.takeIf { it.modifier == PyFunction.Modifier.CLASSMETHOD } + ?.let { + it.containingClass?.let { getDataclassTypeForClass(it, context) } + } + } + else -> null + } + + return PyTypeUtil.notNullToRef(result) } override fun getParameterType(param: PyNamedParameter, func: PyFunction, context: TypeEvalContext): Ref? { @@ -48,30 +65,6 @@ class PyDataclassTypeProvider : PyTypeProviderBase() { companion object { - private fun getDataclassTypeForCallee(referenceExpression: PyReferenceExpression, context: TypeEvalContext): PyCallableType? { - if (PyCallExpressionNavigator.getPyCallExpressionByCallee(referenceExpression) == null) return null - - val resolveContext = PyResolveContext.defaultContext().withTypeEvalContext(context) - val resolveResults = referenceExpression.getReference(resolveContext).multiResolve(false) - - return PyUtil.filterTopPriorityResults(resolveResults) - .asSequence() - .map { - when { - it is PyClass -> getDataclassTypeForClass(it, context) - it is PyParameter && it.isSelf -> { - PsiTreeUtil.getParentOfType(it, PyFunction::class.java) - ?.takeIf { it.modifier == PyFunction.Modifier.CLASSMETHOD } - ?.let { - it.containingClass?.let { getDataclassTypeForClass(it, context) } - } - } - else -> null - } - } - .firstOrNull { it != null } - } - private fun getDataclassesReplaceType(referenceExpression: PyReferenceExpression, context: TypeEvalContext): PyCallableType? { val call = PyCallExpressionNavigator.getPyCallExpressionByCallee(referenceExpression) ?: return null val callee = call.callee as? PyReferenceExpression ?: return null diff --git a/python/python-psi-impl/src/com/jetbrains/python/codeInsight/stdlib/PyNamedTupleTypeProvider.kt b/python/python-psi-impl/src/com/jetbrains/python/codeInsight/stdlib/PyNamedTupleTypeProvider.kt index 09d86c108d7a0..487f9dbe284d5 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/codeInsight/stdlib/PyNamedTupleTypeProvider.kt +++ b/python/python-psi-impl/src/com/jetbrains/python/codeInsight/stdlib/PyNamedTupleTypeProvider.kt @@ -10,7 +10,6 @@ import com.jetbrains.python.psi.* import com.jetbrains.python.psi.impl.PyCallExpressionNavigator import com.jetbrains.python.psi.impl.StubAwareComputation import com.jetbrains.python.psi.impl.stubs.PyNamedTupleStubImpl -import com.jetbrains.python.psi.resolve.PyResolveContext import com.jetbrains.python.psi.stubs.PyNamedTupleStub import com.jetbrains.python.psi.types.* import one.util.streamex.StreamEx @@ -37,11 +36,6 @@ class PyNamedTupleTypeProvider : PyTypeProviderBase() { return fieldTypeForTypingNTFunctionInheritor } - val namedTupleTypeForCallee = getNamedTupleTypeForCallee(referenceExpression, context) - if (namedTupleTypeForCallee != null) { - return namedTupleTypeForCallee - } - val namedTupleReplaceType = getNamedTupleReplaceType(referenceExpression, context) if (namedTupleReplaceType != null) { return namedTupleReplaceType @@ -71,6 +65,7 @@ class PyNamedTupleTypeProvider : PyTypeProviderBase() { return when { referenceTarget is PyFunction && anchor is PyCallExpression -> getNamedTupleFunctionType(referenceTarget, context, anchor) referenceTarget is PyTargetExpression -> getNamedTupleTypeForTarget(referenceTarget, context) + referenceTarget is PyClass && anchor is PyCallExpression -> getNamedTupleTypeForNTInheritorAsCallee(referenceTarget, context) else -> null } } @@ -94,43 +89,6 @@ class PyNamedTupleTypeProvider : PyTypeProviderBase() { ) } - private fun getNamedTupleTypeForCallee(referenceExpression: PyReferenceExpression, context: TypeEvalContext): PyType? { - if (PyCallExpressionNavigator.getPyCallExpressionByCallee(referenceExpression) == null) return null - - val resolveContext = PyResolveContext.defaultContext().withTypeEvalContext(context) - val resolveResults = referenceExpression.getReference(resolveContext).multiResolve(false) - - for (element in PyUtil.filterTopPriorityResults(resolveResults)) { - if (element is PyTargetExpression) { - val result = getNamedTupleTypeForTarget(element, context) - if (result != null) { - return result - } - } - - if (element is PyClass) { - val result = getNamedTupleTypeForTypingNTInheritorAsCallee(element, context) - if (result != null) { - return result - } - } - - if (element is PyTypedElement) { - val type = context.getType(element) - if (type is PyClassLikeType) { - val superClassTypes = type.getSuperClassTypes(context) - - val superNTType = superClassTypes.asSequence().filterIsInstance().firstOrNull() - if (superNTType != null) { - return PyCallableTypeImpl(superNTType.getParameters(context), type.toInstance()) - } - } - } - } - - return null - } - private fun getNamedTupleReplaceType(referenceExpression: PyReferenceExpression, context: TypeEvalContext): PyCallableType? { val call = PyCallExpressionNavigator.getPyCallExpressionByCallee(referenceExpression) ?: return null @@ -190,16 +148,21 @@ class PyNamedTupleTypeProvider : PyTypeProviderBase() { .compute(context) } - private fun getNamedTupleTypeForTypingNTInheritorAsCallee(cls: PyClass, context: TypeEvalContext): PyType? { - if (isTypingNamedTupleDirectInheritor(cls, context)) { + private fun getNamedTupleTypeForNTInheritorAsCallee(cls: PyClass, context: TypeEvalContext): PyType? { + val parameters = if (isTypingNamedTupleDirectInheritor(cls, context)) { val name = cls.name ?: return null val tupleClass = PyPsiFacade.getInstance(cls.project).createClassByQName(PyTypingTypeProvider.NAMEDTUPLE, cls) ?: return null val namedTupleType = PyNamedTupleType(tupleClass, name, collectTypingNTInheritorFields(cls, context), true, true, cls) - return PyCallableTypeImpl(namedTupleType.getParameters(context), cls.getType(context)?.toInstance()) + namedTupleType.getParameters(context) } + else { + val superNTType = cls.getSuperClassTypes(context).firstOrNull(PyNamedTupleType::class.java::isInstance) ?: return null - return null + superNTType.getParameters(context) + } + + return PyCallableTypeImpl(parameters, cls.getType(context)?.toInstance()) } private fun getNamedTupleTypeFromStub(targetOrCall: PsiElement, stub: PyNamedTupleStub?, context: TypeEvalContext): PyNamedTupleType? { diff --git a/python/python-psi-impl/src/com/jetbrains/python/codeInsight/typing/PyTypedDictTypeProvider.kt b/python/python-psi-impl/src/com/jetbrains/python/codeInsight/typing/PyTypedDictTypeProvider.kt index 5864ac3e2c982..44be3e493d958 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/codeInsight/typing/PyTypedDictTypeProvider.kt +++ b/python/python-psi-impl/src/com/jetbrains/python/codeInsight/typing/PyTypedDictTypeProvider.kt @@ -12,7 +12,6 @@ import com.jetbrains.python.psi.impl.PyCallExpressionNavigator import com.jetbrains.python.psi.impl.PyEvaluator import com.jetbrains.python.psi.impl.StubAwareComputation import com.jetbrains.python.psi.impl.stubs.PyTypedDictStubImpl -import com.jetbrains.python.psi.resolve.PyResolveContext import com.jetbrains.python.psi.stubs.PyTypedDictStub import com.jetbrains.python.psi.types.* import com.jetbrains.python.psi.types.PyTypedDictType.Companion.TYPED_DICT_FIELDS_PARAMETER @@ -105,53 +104,25 @@ class PyTypedDictTypeProvider : PyTypeProviderBase() { private fun getTypedDictTypeForCallee(referenceExpression: PyReferenceExpression, context: TypeEvalContext): PyType? { if (PyCallExpressionNavigator.getPyCallExpressionByCallee(referenceExpression) == null) return null - val resolveContext = PyResolveContext.defaultContext().withTypeEvalContext(context) - val resolveResults = referenceExpression.getReference(resolveContext).multiResolve(false) - - for (element in PyUtil.filterTopPriorityResults(resolveResults)) { - if (element is PyTargetExpression) { - val result = getTypedDictTypeForTarget(element, context) - if (result != null) { - return result - } - } - - if (element is PyClass) { - val result = getTypedDictTypeForTypingTDInheritorAsCallee(element, context, false) - if (result != null) { - return result - } - } - - if (element is PyTypedElement) { - val type = context.getType(element) - if (type is PyClassType) { - if (isTypingTypedDictInheritor(type.pyClass, context)) { - return getTypedDictTypeForTypingTDInheritorAsCallee(type.pyClass, context, false) - } - } - } + if (isTypedDict(referenceExpression, context)) { + val parameters = mutableListOf() - if (isTypedDict(referenceExpression, context)) { - val parameters = mutableListOf() - - val builtinCache = PyBuiltinCache.getInstance(referenceExpression) - val languageLevel = LanguageLevel.forElement(referenceExpression) - val generator = PyElementGenerator.getInstance(referenceExpression.project) - - parameters.add(PyCallableParameterImpl.nonPsi(TYPED_DICT_NAME_PARAMETER, builtinCache.getStringType(languageLevel))) - val dictClassType = builtinCache.dictType - parameters.add(PyCallableParameterImpl.nonPsi(TYPED_DICT_FIELDS_PARAMETER, - if (dictClassType != null) PyCollectionTypeImpl(dictClassType.pyClass, false, - listOf(builtinCache.strType, null)) - else dictClassType)) - parameters.add( - PyCallableParameterImpl.nonPsi(TYPED_DICT_TOTAL_PARAMETER, - builtinCache.boolType, - generator.createExpressionFromText(languageLevel, PyNames.TRUE))) - - return PyCallableTypeImpl(parameters, null) - } + val builtinCache = PyBuiltinCache.getInstance(referenceExpression) + val languageLevel = LanguageLevel.forElement(referenceExpression) + val generator = PyElementGenerator.getInstance(referenceExpression.project) + + parameters.add(PyCallableParameterImpl.nonPsi(TYPED_DICT_NAME_PARAMETER, builtinCache.getStringType(languageLevel))) + val dictClassType = builtinCache.dictType + parameters.add(PyCallableParameterImpl.nonPsi(TYPED_DICT_FIELDS_PARAMETER, + if (dictClassType != null) PyCollectionTypeImpl(dictClassType.pyClass, false, + listOf(builtinCache.strType, null)) + else dictClassType)) + parameters.add( + PyCallableParameterImpl.nonPsi(TYPED_DICT_TOTAL_PARAMETER, + builtinCache.boolType, + generator.createExpressionFromText(languageLevel, PyNames.TRUE))) + + return PyCallableTypeImpl(parameters, null) } return null @@ -396,7 +367,7 @@ class PyTypedDictTypeProvider : PyTypeProviderBase() { return typedDictFieldsFromKeysAndValues(fields, context) } - private fun typedDictFieldsFromKeysAndValues(fields: Map, context: TypeEvalContext): TDFields? { + private fun typedDictFieldsFromKeysAndValues(fields: Map, context: TypeEvalContext): TDFields { val result = TDFields() for ((name, type) in fields) { result[name] = if (type != null) PyTypedDictType.FieldTypeAndTotality(context.getType(type)) diff --git a/python/python-psi-impl/src/com/jetbrains/python/codeInsight/typing/PyTypingNewTypeTypeProvider.java b/python/python-psi-impl/src/com/jetbrains/python/codeInsight/typing/PyTypingNewTypeTypeProvider.java index 8376f0d36708e..ca959f77dec76 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/codeInsight/typing/PyTypingNewTypeTypeProvider.java +++ b/python/python-psi-impl/src/com/jetbrains/python/codeInsight/typing/PyTypingNewTypeTypeProvider.java @@ -2,12 +2,9 @@ import com.intellij.openapi.util.Ref; import com.intellij.psi.PsiElement; -import com.intellij.psi.ResolveResult; import com.jetbrains.python.psi.*; -import com.jetbrains.python.psi.impl.PyCallExpressionNavigator; import com.jetbrains.python.psi.impl.StubAwareComputation; import com.jetbrains.python.psi.impl.stubs.PyTypingNewTypeStubImpl; -import com.jetbrains.python.psi.resolve.PyResolveContext; import com.jetbrains.python.psi.stubs.PyTypingNewTypeStub; import com.jetbrains.python.psi.types.*; import org.jetbrains.annotations.NotNull; @@ -24,11 +21,6 @@ public class PyTypingNewTypeTypeProvider extends PyTypeProviderBase { : null; } - @Override - public @Nullable PyType getReferenceExpressionType(@NotNull PyReferenceExpression referenceExpression, @NotNull TypeEvalContext context) { - return getNewTypeForCallee(referenceExpression, context); - } - @Override public Ref getReferenceType(@NotNull PsiElement referenceTarget, @NotNull TypeEvalContext context, @Nullable PsiElement anchor) { if (referenceTarget instanceof PyTargetExpression) { @@ -38,24 +30,6 @@ public Ref getReferenceType(@NotNull PsiElement referenceTarget, @NotNul return null; } - @Nullable - private static PyTypingNewType getNewTypeForCallee(@NotNull PyReferenceExpression referenceExpression, @NotNull TypeEvalContext context) { - if (PyCallExpressionNavigator.getPyCallExpressionByCallee(referenceExpression) == null) return null; - - final PyResolveContext resolveContext = PyResolveContext.defaultContext().withTypeEvalContext(context); - final ResolveResult[] resolveResults = referenceExpression.getReference(resolveContext).multiResolve(false); - - for (PsiElement element : PyUtil.filterTopPriorityResults(resolveResults)) { - if (element instanceof PyTargetExpression) { - final PyTypingNewType typeForTarget = getNewTypeForTarget((PyTargetExpression)element, context); - if (typeForTarget != null) { - return typeForTarget; - } - } - } - return null; - } - @Nullable private static PyTypingNewType getNewTypeForTarget(@NotNull PyTargetExpression target, @NotNull TypeEvalContext context) { return StubAwareComputation.on(target)