Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix factories for python>=3.7 #3552

Merged
merged 3 commits into from
Nov 15, 2019
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 10 additions & 8 deletions aiida/plugins/factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from __future__ import print_function
from __future__ import absolute_import

from inspect import isclass
from aiida.common.exceptions import InvalidEntryPointTypeError

__all__ = (
Expand Down Expand Up @@ -55,14 +56,15 @@ def CalculationFactory(entry_point_name):
:return: sub class of :py:class:`~aiida.engine.processes.calcjobs.calcjob.CalcJob`
:raises aiida.common.InvalidEntryPointTypeError: if the type of the loaded entry point is invalid.
"""
from inspect import isclass
yakutovicha marked this conversation as resolved.
Show resolved Hide resolved
from aiida.engine import CalcJob, calcfunction, is_process_function
from aiida.orm import CalcFunctionNode

entry_point_group = 'aiida.calculations'
entry_point = BaseFactory(entry_point_group, entry_point_name)
valid_classes = (CalcJob, calcfunction)

if issubclass(entry_point, CalcJob):
if isclass(entry_point) and issubclass(entry_point, CalcJob):
return entry_point

if is_process_function(entry_point) and entry_point.node_class is CalcFunctionNode:
Expand All @@ -84,7 +86,7 @@ def DataFactory(entry_point_name):
entry_point = BaseFactory(entry_point_group, entry_point_name)
valid_classes = (Data,)

if issubclass(entry_point, Data):
if isclass(entry_point) and issubclass(entry_point, Data):
return entry_point

raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes)
Expand All @@ -103,7 +105,7 @@ def DbImporterFactory(entry_point_name):
entry_point = BaseFactory(entry_point_group, entry_point_name)
valid_classes = (DbImporter,)

if issubclass(entry_point, DbImporter):
if isclass(entry_point) and issubclass(entry_point, DbImporter):
return entry_point

raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes)
Expand All @@ -122,7 +124,7 @@ def OrbitalFactory(entry_point_name):
entry_point = BaseFactory(entry_point_group, entry_point_name)
valid_classes = (Orbital,)

if issubclass(entry_point, Orbital):
if isclass(entry_point) and issubclass(entry_point, Orbital):
return entry_point

raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes)
Expand All @@ -141,7 +143,7 @@ def ParserFactory(entry_point_name):
entry_point = BaseFactory(entry_point_group, entry_point_name)
valid_classes = (Parser,)

if issubclass(entry_point, Parser):
if isclass(entry_point) and issubclass(entry_point, Parser):
return entry_point

raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes)
Expand All @@ -160,7 +162,7 @@ def SchedulerFactory(entry_point_name):
entry_point = BaseFactory(entry_point_group, entry_point_name)
valid_classes = (Scheduler,)

if issubclass(entry_point, Scheduler):
if isclass(entry_point) and issubclass(entry_point, Scheduler):
return entry_point

raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes)
Expand All @@ -179,7 +181,7 @@ def TransportFactory(entry_point_name):
entry_point = BaseFactory(entry_point_group, entry_point_name)
valid_classes = (Transport,)

if issubclass(entry_point, Transport):
if isclass(entry_point) and issubclass(entry_point, Transport):
return entry_point

raise_invalid_type_error(entry_point_name, entry_point_group, valid_classes)
Expand All @@ -199,7 +201,7 @@ def WorkflowFactory(entry_point_name):
entry_point = BaseFactory(entry_point_group, entry_point_name)
valid_classes = (WorkChain, workfunction)

if issubclass(entry_point, WorkChain):
if isclass(entry_point) and issubclass(entry_point, WorkChain):
return entry_point

if is_process_function(entry_point) and entry_point.node_class is WorkFunctionNode:
Expand Down