Dependencies: Update requirement mypy~=1.7 (#6188)
This makes it possible to get rid of many exclude statements, since those
corresponded to bugs in `mypy` that have now been fixed.
sphuber authored Nov 21, 2023
1 parent d3788ad commit c2fcad4
Showing 17 changed files with 75 additions and 162 deletions.
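For context, `mypy` can itself report `# type: ignore[...]` comments that no longer suppress anything once the underlying checker bug is fixed. The snippet below is a generic, hypothetical illustration of that mechanism, not aiida-core code (aiida-core's actual mypy configuration may differ): running a recent mypy with `--warn-unused-ignores` (or `warn_unused_ignores = true` in the config) flags each obsolete ignore, which is how comments like the ones removed in this commit can be located mechanically.

```python
# Hypothetical example (not aiida-core code): run `mypy --warn-unused-ignores`
# on this file and the ignore on the last line is reported as
# 'Unused "type: ignore" comment', i.e. it suppresses nothing and can go.
from typing import Optional


def greet(name: Optional[str]) -> str:
    if name is None:
        name = 'world'
    # An ignore that no longer (or never) suppresses a real error is itself
    # reported under --warn-unused-ignores and can simply be deleted, which is
    # what this commit does in bulk across the code base.
    return 'Hello, ' + name  # type: ignore[operator]
```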
2 changes: 1 addition & 1 deletion aiida/engine/processes/process.py
@@ -366,7 +366,7 @@ def kill(self, msg: Union[str, None] = None) -> Union[bool, plumpy.futures.Futur

          def done(done_future: plumpy.futures.Future):
              is_all_killed = all(done_future.result())
-             result.set_result(is_all_killed)  # type: ignore[union-attr]
+             result.set_result(is_all_killed)

          kill_future.add_done_callback(done)
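The `union-attr` ignore removed above guarded an attribute access on a value typed as a union, where the attribute exists on only one member. A minimal, self-contained sketch of that error code with invented names (not the aiida implementation):

```python
# Minimal sketch of mypy's `union-attr` error code (hypothetical names):
# accessing an attribute that exists on only one union member is an error
# unless the checker can prove which member the value is at that point.
import asyncio
from typing import Union


def resolve(result: Union[bool, asyncio.Future]) -> None:
    if isinstance(result, asyncio.Future):
        # Inside this branch `result` is narrowed to Future, so `.set_result`
        # type-checks without a `# type: ignore[union-attr]` comment.
        result.set_result(True)
```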
4 changes: 1 addition & 3 deletions aiida/engine/processes/workchains/utils.py
@@ -112,9 +112,7 @@ def wrapper(wrapped, instance, args, kwargs):
          # When the handler will be called by the `BaseRestartWorkChain` it will pass the node as the only argument
          node = args[0]

-         if exit_codes is not None and node.exit_status not in [
-             exit_code.status for exit_code in exit_codes  # type: ignore[union-attr]
-         ]:
+         if exit_codes is not None and node.exit_status not in [exit_code.status for exit_code in exit_codes]:
              result = None
          else:
              result = wrapped(*args, **kwargs)
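Here the removed ignore sat on a comprehension that iterates `exit_codes` right after the `is not None` check in the same expression. The example below reproduces the shape of that check with generic names (not the actual `exit_codes`/`ExitCode` objects):

```python
# Generic reproduction of the pattern (invented names): an Optional sequence is
# checked for None and then iterated in a comprehension within the same `and`
# expression. A checker that keeps the narrowing needs no ignore on the
# comprehension.
from typing import Optional, Sequence


def is_unhandled(status: int, allowed: Optional[Sequence[int]]) -> bool:
    return allowed is not None and status not in [value for value in allowed]
```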
2 changes: 1 addition & 1 deletion aiida/engine/transports.py
@@ -84,7 +84,7 @@ async def transport_task(transport_queue, authinfo):

          def do_open():
              """ Actually open the transport """
-             if transport_request and transport_request.count > 0:
+             if transport_request.count > 0:
                  # The user still wants the transport so open it
                  _LOGGER.debug('Transport request opening transport for %s', authinfo)
                  try:
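The dropped `transport_request and` guard added nothing at runtime once the surrounding code guarantees the request object exists; per the commit message it presumably only served to satisfy the older mypy. A simplified, hypothetical version of the pattern (names invented):

```python
# Hypothetical, simplified version of the pattern (names invented): `request`
# is bound to a real object before the nested callback is defined, so testing
# `request and request.count > 0` is redundant and the truthiness check can go.
class TransportRequest:
    def __init__(self) -> None:
        self.count = 0


def make_opener():
    request = TransportRequest()
    request.count += 1

    def do_open() -> None:
        if request.count > 0:  # formerly `if request and request.count > 0:`
            print('opening transport')

    return do_open


make_opener()()  # prints 'opening transport'
```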
2 changes: 1 addition & 1 deletion aiida/orm/nodes/data/singlefile.py
@@ -78,7 +78,7 @@ def open(self, path: str, mode: t.Literal['rb']) -> t.Iterator[t.BinaryIO]:
      def open(self, path: None, mode: t.Literal['rb']) -> t.Iterator[t.BinaryIO]:
          ...

-     @contextlib.contextmanager  # type: ignore[misc]
+     @contextlib.contextmanager
      def open(self, path: str | None = None, mode: t.Literal['r', 'rb'] = 'r') -> t.Iterator[t.BinaryIO | t.TextIO]:
          """Return an open file handle to the content of this data node.
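The `misc` ignore sat on the `@contextlib.contextmanager` decorator of an `@overload`-ed implementation. Below is a stripped-down, hypothetical class showing that combination; the ignore was evidently required under the old mypy pin in aiida's setup and is no longer needed, though the exact version behaviour is not spelled out in the commit.

```python
# Stripped-down, hypothetical analogue of SinglefileData.open: typing.overload
# signatures followed by an implementation wrapped in contextlib.contextmanager.
from __future__ import annotations

import contextlib
import typing as t


class File:
    content = 'hello'

    @t.overload
    def open(self, mode: t.Literal['r']) -> t.Iterator[str]:
        ...

    @t.overload
    def open(self, mode: t.Literal['rb']) -> t.Iterator[bytes]:
        ...

    @contextlib.contextmanager
    def open(self, mode: t.Literal['r', 'rb'] = 'r') -> t.Iterator[str | bytes]:
        # Yield the content as text or bytes depending on the requested mode.
        yield self.content if mode == 'r' else self.content.encode()


with File().open('rb') as handle:
    assert handle == b'hello'
```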
23 changes: 9 additions & 14 deletions aiida/storage/psql_dos/backend.py
@@ -214,9 +214,8 @@ def _clear(self) -> None:

          with self.transaction() as session:
              session.execute(
-                 DbSetting.__table__.update().where(
-                     DbSetting.key == REPOSITORY_UUID_KEY  # type: ignore[attr-defined]
-                 ).values(val=repository_uuid)
+                 DbSetting.__table__.update().where(DbSetting.key == REPOSITORY_UUID_KEY
+                                                    ).values(val=repository_uuid)
              )

      def get_repository(self) -> 'DiskObjectStoreRepositoryBackend':
@@ -358,17 +357,14 @@ def delete_nodes_and_connections(self, pks_to_delete: Sequence[int]) -> None:

          session = self.get_session()
          # Delete the membership of these nodes to groups.
-         session.query(DbGroupNode).filter(DbGroupNode.dbnode_id.in_(list(pks_to_delete))  # type: ignore[attr-defined]
+         session.query(DbGroupNode).filter(DbGroupNode.dbnode_id.in_(list(pks_to_delete))
                                            ).delete(synchronize_session='fetch')
          # Delete the links coming out of the nodes marked for deletion.
-         session.query(DbLink).filter(DbLink.input_id.in_(list(pks_to_delete))
-                                      ).delete(synchronize_session='fetch')  # type: ignore[attr-defined]
+         session.query(DbLink).filter(DbLink.input_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch')
          # Delete the links pointing to the nodes marked for deletion.
-         session.query(DbLink).filter(DbLink.output_id.in_(list(pks_to_delete))
-                                      ).delete(synchronize_session='fetch')  # type: ignore[attr-defined]
+         session.query(DbLink).filter(DbLink.output_id.in_(list(pks_to_delete))).delete(synchronize_session='fetch')
          # Delete the actual nodes
-         session.query(DbNode).filter(DbNode.id.in_(list(pks_to_delete))
-                                      ).delete(synchronize_session='fetch')  # type: ignore[attr-defined]
+         session.query(DbNode).filter(DbNode.id.in_(list(pks_to_delete))).delete(synchronize_session='fetch')

      def get_backend_entity(self, model: base.Base) -> BackendEntity:
          """
@@ -386,10 +382,9 @@ def set_global_variable(

          session = self.get_session()
          with (nullcontext() if self.in_transaction else self.transaction()):
-             if session.query(DbSetting).filter(DbSetting.key == key).count():  # type: ignore[attr-defined]
+             if session.query(DbSetting).filter(DbSetting.key == key).count():
                  if overwrite:
-                     session.query(DbSetting).filter(DbSetting.key == key
-                                                     ).update(dict(val=value))  # type: ignore[attr-defined]
+                     session.query(DbSetting).filter(DbSetting.key == key).update(dict(val=value))
                  else:
                      raise ValueError(f'The setting {key} already exists')
              else:
@@ -400,7 +395,7 @@ def get_global_variable(self, key: str) -> Union[None, str, int, float]:

          session = self.get_session()
          with (nullcontext() if self.in_transaction else self.transaction()):
-             setting = session.query(DbSetting).filter(DbSetting.key == key).one_or_none()  # type: ignore[attr-defined]
+             setting = session.query(DbSetting).filter(DbSetting.key == key).one_or_none()
              if setting is None:
                  raise KeyError(f'No setting found with key {key}')
              return setting.val
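Most of the ignores dropped in this file were `attr-defined` errors on ORM expressions such as `DbSetting.key == key` or `DbNode.id.in_(...)`. The commit does not spell out exactly why they became unnecessary; as a general, hypothetical illustration, with typed SQLAlchemy 2.0-style models (and a sufficiently recent mypy) such comparisons type-check out of the box:

```python
# Generic SQLAlchemy 2.0-style sketch (hypothetical models, not aiida's): column
# comparisons and `.in_()` clauses on typed mapped attributes are understood by
# the type checker, so no `# type: ignore[attr-defined]` is required.
from sqlalchemy import String, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column


class Base(DeclarativeBase):
    pass


class Setting(Base):
    __tablename__ = 'setting'

    id: Mapped[int] = mapped_column(primary_key=True)
    key: Mapped[str] = mapped_column(String(255), unique=True)


# Both expressions are valid SQL expression objects and type-check cleanly.
by_key = select(Setting).where(Setting.key == 'some_key')
by_ids = select(Setting).where(Setting.id.in_([1, 2, 3]))
```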
Another changed file (file name not shown in this view):
@@ -126,9 +126,8 @@ def migrate_infer_calculation_entry_point(alembic_op):
              fallback_cases.append([uuid, type_string, entry_point_string])

          connection.execute(
-             DbNode.update().where(
-                 DbNode.c.type == alembic_op.inline_literal(type_string)  # type: ignore[attr-defined]
-             ).values(process_type=alembic_op.inline_literal(entry_point_string))
+             DbNode.update().where(DbNode.c.type == alembic_op.inline_literal(type_string)
+                                   ).values(process_type=alembic_op.inline_literal(entry_point_string))
          )

      if fallback_cases:
Another changed file (file name not shown in this view):
@@ -50,15 +50,15 @@ def upgrade():
          column('attributes', JSONB),
      )

-     nodes = connection.execute(  # type: ignore[var-annotated]
+     nodes = connection.execute(
          select(DbNode.c.id,
                 DbNode.c.uuid).where(DbNode.c.type == op.inline_literal('node.data.array.trajectory.TrajectoryData.'))
      ).fetchall()

      for pk, uuid in nodes:
          symbols = load_numpy_array_from_repository(repo_path, uuid, 'symbols').tolist()
          connection.execute(
-             DbNode.update().where(DbNode.c.id == pk).values(  # type: ignore[attr-defined]
+             DbNode.update().where(DbNode.c.id == pk).values(
                  attributes=func.jsonb_set(DbNode.c.attributes, op.inline_literal('{"symbols"}'), cast(symbols, JSONB))
              )
          )
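This migration and the next few all drop the same `var-annotated` ignore on `connection.execute(...)` assignments. Roughly, that error code asks for an explicit variable annotation when the checker cannot infer a useful type for the assigned value. A generic, hypothetical illustration, unrelated to the migration code:

```python
# Generic illustration of mypy's `var-annotated` error code (invented example):
# an assignment whose right-hand side gives mypy no element type triggers
# 'Need type annotation for "rows"'.
rows = []  # error: Need type annotation for "rows" (hint: "rows: list[<type>] = ...")

# Adding the annotation, or improving what the right-hand side is typed as
# (presumably what happened for `connection.execute(...)` after the upgrade),
# resolves it without a `# type: ignore[var-annotated]` comment.
typed_rows: list[tuple[int, str]] = []
```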
Another changed file (file name not shown in this view):
@@ -45,7 +45,7 @@ def upgrade():
          column('attributes', JSONB),
      )

-     nodes = connection.execute(  # type: ignore[var-annotated]
+     nodes = connection.execute(
          select(DbNode.c.id,
                 DbNode.c.uuid).where(DbNode.c.type == op.inline_literal('node.data.array.trajectory.TrajectoryData.'))
      ).fetchall()
Another changed file (file name not shown in this view):
@@ -42,7 +42,7 @@ def upgrade():
          sa.column('type', sa.String),
      )

-     nodes = connection.execute(  # type: ignore[var-annotated]
+     nodes = connection.execute(
          sa.select(node_model.c.id, node_model.c.uuid).where(
              node_model.c.type == op.inline_literal('node.data.array.trajectory.TrajectoryData.')
          )
Another changed file (file name not shown in this view):
@@ -44,7 +44,7 @@ def upgrade():
          # sa.column('attributes', JSONB),
      )

-     nodes = connection.execute(  # type: ignore[var-annotated]
+     nodes = connection.execute(
          sa.select(node_tbl.c.id, node_tbl.c.uuid).where(
              node_tbl.c.type == op.inline_literal('node.data.array.trajectory.TrajectoryData.')
          )
Another changed file (file name not shown in this view):
@@ -92,18 +92,12 @@ def upgrade():
      if node_count:
          with get_progress_reporter()(total=node_count, desc='Updating attributes and extras') as progress:
              for node in conn.execute(select(node_tbl)).all():
-                 attr_list = conn.execute(  # type: ignore[var-annotated]
-                     select(attr_tbl).where(attr_tbl.c.dbnode_id == node.id)
-                 ).all()
+                 attr_list = conn.execute(select(attr_tbl).where(attr_tbl.c.dbnode_id == node.id)).all()
                  attributes, _ = attributes_to_dict(sorted(attr_list, key=lambda a: a.key))
-                 extra_list = conn.execute(  # type: ignore[var-annotated]
-                     select(extra_tbl).where(extra_tbl.c.dbnode_id == node.id)
-                 ).all()
+                 extra_list = conn.execute(select(extra_tbl).where(extra_tbl.c.dbnode_id == node.id)).all()
                  extras, _ = attributes_to_dict(sorted(extra_list, key=lambda a: a.key))
                  conn.execute(
-                     node_tbl.update().where(  # type: ignore[attr-defined]
-                         node_tbl.c.id == node.id
-                     ).values(attributes=attributes, extras=extras)
+                     node_tbl.update().where(node_tbl.c.id == node.id).values(attributes=attributes, extras=extras)
                  )
                  progress.update()

@@ -135,9 +129,8 @@ def upgrade():
              else:
                  val = setting.dval
              conn.execute(
-                 setting_tbl.update().where(  # type: ignore[attr-defined]
-                     setting_tbl.c.id == setting.id
-                 ).values(val=cast(val, postgresql.JSONB(astext_type=sa.Text())))
+                 setting_tbl.update().where(setting_tbl.c.id == setting.id
+                                            ).values(val=cast(val, postgresql.JSONB(astext_type=sa.Text())))
              )
              progress.update()
16 changes: 8 additions & 8 deletions aiida/storage/sqlite_zip/migrations/v1_db_schema.py
@@ -56,8 +56,8 @@ class DbAuthInfo(ArchiveV1Base):
          nullable=True,
          index=True
      )
-     _metadata = Column('metadata', JSON, default=dict, nullable=True)  # type: ignore[var-annotated]
-     auth_params = Column(JSON, default=dict, nullable=True)  # type: ignore[misc]
+     _metadata = Column('metadata', JSON, default=dict, nullable=True)
+     auth_params = Column(JSON, default=dict, nullable=True)
      enabled = Column(Boolean, default=True, nullable=True)


@@ -96,7 +96,7 @@ class DbComputer(ArchiveV1Base):
      description = Column(Text, default='', nullable=True)
      scheduler_type = Column(String(255), default='', nullable=True)
      transport_type = Column(String(255), default='', nullable=True)
-     _metadata = Column('metadata', JSON, default=dict, nullable=True)  # type: ignore[var-annotated]
+     _metadata = Column('metadata', JSON, default=dict, nullable=True)


  class DbGroupNodes(ArchiveV1Base):
@@ -126,7 +126,7 @@ class DbGroup(ArchiveV1Base):
      type_string = Column(String(255), default='', nullable=True, index=True)
      time = Column(DateTime(timezone=True), default=timezone.now, nullable=True)
      description = Column(Text, default='', nullable=True)
-     extras = Column(JSON, default=dict, nullable=False)  # type: ignore[misc]
+     extras = Column(JSON, default=dict, nullable=False)
      user_id = Column(
          Integer,
          ForeignKey('db_dbuser.id', ondelete='CASCADE', deferrable=True, initially='DEFERRED'),
@@ -152,7 +152,7 @@ class DbLog(ArchiveV1Base):
          index=True
      )
      message = Column(Text(), default='', nullable=True)
-     _metadata = Column('metadata', JSON, default=dict, nullable=True)  # type: ignore[var-annotated]
+     _metadata = Column('metadata', JSON, default=dict, nullable=True)


  class DbNode(ArchiveV1Base):
@@ -168,9 +168,9 @@ class DbNode(ArchiveV1Base):
      description = Column(Text(), default='', nullable=True)
      ctime = Column(DateTime(timezone=True), default=timezone.now, nullable=True, index=True)
      mtime = Column(DateTime(timezone=True), default=timezone.now, nullable=True, index=True)
-     attributes = Column(JSON)  # type: ignore[var-annotated]
-     extras = Column(JSON)  # type: ignore[var-annotated]
-     repository_metadata = Column(JSON, nullable=False, default=dict, server_default='{}')  # type: ignore[var-annotated]
+     attributes = Column(JSON)
+     extras = Column(JSON)
+     repository_metadata = Column(JSON, nullable=False, default=dict, server_default='{}')
      dbcomputer_id = Column(
          Integer,
          ForeignKey('db_dbcomputer.id', deferrable=True, initially='DEFERRED', ondelete='RESTRICT'),