Skip to content

Commit

Permalink
[DOP-14025] Replace asyncio.gather with asyncio.TaskGroup
Browse files Browse the repository at this point in the history
  • Loading branch information
dolfinus committed Apr 19, 2024
1 parent f0ba198 commit c4bc7e1
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 71 deletions.
25 changes: 17 additions & 8 deletions syncmaster/backend/api/v1/connections.py
@@ -1,5 +1,6 @@
# SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems)
# SPDX-License-Identifier: Apache-2.0
import asyncio
from typing import get_args

from fastapi import APIRouter, Depends, Query, status
Expand Down Expand Up @@ -251,20 +252,28 @@ async def copy_connection(
current_user: User = Depends(get_user(is_active=True)),
unit_of_work: UnitOfWork = Depends(UnitOfWorkMarker),
) -> StatusResponseSchema:
resource_role = await unit_of_work.connection.get_resource_permission(
user=current_user,
resource_id=connection_id,
)
async with asyncio.TaskGroup() as tasks:
resource_role_task = tasks.create_task(
unit_of_work.connection.get_resource_permission(
user=current_user,
resource_id=connection_id,
)
)
target_group_role_task = tasks.create_task(
unit_of_work.connection.get_group_permission(
user=current_user,
group_id=copy_connection_data.new_group_id,
)
)

resource_role, target_group_role = resource_role_task.result(), target_group_role_task.result()

if resource_role == Permission.NONE:
raise ConnectionNotFoundError

if copy_connection_data.remove_source and resource_role < Permission.DELETE:
raise ActionNotAllowedError

target_group_role = await unit_of_work.connection.get_group_permission(
user=current_user,
group_id=copy_connection_data.new_group_id,
)
if target_group_role == Permission.NONE:
raise GroupNotFoundError

Expand Down
138 changes: 75 additions & 63 deletions syncmaster/backend/api/v1/transfers/router.py
@@ -1,6 +1,8 @@
# SPDX-FileCopyrightText: 2023-2024 MTS (Mobile Telesystems)
# SPDX-License-Identifier: Apache-2.0

import asyncio

from fastapi import APIRouter, Depends, Query, status
from kombu.exceptions import KombuError

Expand Down Expand Up @@ -81,9 +83,18 @@ async def create_transfer(
if group_permission < Permission.WRITE:
raise ActionNotAllowedError

target_connection = await unit_of_work.connection.read_by_id(transfer_data.target_connection_id)
source_connection = await unit_of_work.connection.read_by_id(transfer_data.source_connection_id)
queue = await unit_of_work.queue.read_by_id(transfer_data.queue_id)
async with asyncio.TaskGroup() as tasks:
target_connection_task = tasks.create_task(
unit_of_work.connection.read_by_id(transfer_data.target_connection_id)
)
source_connection_task = tasks.create_task(
unit_of_work.connection.read_by_id(transfer_data.source_connection_id)
)
queue_task = tasks.create_task(unit_of_work.queue.read_by_id(transfer_data.queue_id))

target_connection = target_connection_task.result()
source_connection = source_connection_task.result()
queue = queue_task.result()

if (
target_connection.group_id != source_connection.group_id
Expand All @@ -92,6 +103,9 @@ async def create_transfer(
):
raise DifferentTransferAndConnectionsGroupsError

if transfer_data.group_id != queue.group_id:
raise DifferentTransferAndQueueGroupError

if target_connection.data["type"] != transfer_data.target_params.type:
raise DifferentTypeConnectionsAndParamsError(
connection_type=target_connection.data["type"],
Expand All @@ -106,9 +120,6 @@ async def create_transfer(
params_type=transfer_data.source_params.type,
)

if transfer_data.group_id != queue.group_id:
raise DifferentTransferAndQueueGroupError

transfer_data = process_file_transfer_directory_path(transfer_data) # type: ignore

async with unit_of_work:
Expand Down Expand Up @@ -152,46 +163,53 @@ async def copy_transfer(
current_user: User = Depends(get_user(is_active=True)),
unit_of_work: UnitOfWork = Depends(UnitOfWorkMarker),
) -> StatusCopyTransferResponseSchema:
resource_role = await unit_of_work.transfer.get_resource_permission(
user=current_user,
resource_id=transfer_id,
)
if resource_role == Permission.NONE:
raise TransferNotFoundError
async with asyncio.TaskGroup() as tasks:
resource_role_task = tasks.create_task(
unit_of_work.transfer.get_resource_permission(
user=current_user,
resource_id=transfer_id,
)
)
target_group_role_task = tasks.create_task(
unit_of_work.transfer.get_group_permission(
user=current_user,
group_id=transfer_data.new_group_id,
)
)
transfer_task = tasks.create_task(unit_of_work.transfer.read_by_id(transfer_id))

resource_role = resource_role_task.result()
target_group_role = target_group_role_task.result()
transfer = transfer_task.result()

# Check: user can delete transfer
if transfer_data.remove_source and resource_role < Permission.DELETE:
raise ActionNotAllowedError

target_group_role = await unit_of_work.transfer.get_group_permission(
user=current_user,
group_id=transfer_data.new_group_id,
)
if target_group_role < Permission.WRITE:
raise ActionNotAllowedError

transfer = await unit_of_work.transfer.read_by_id(transfer_id=transfer_id)

# Check: user can copy connection
source_connection_role = await unit_of_work.connection.get_resource_permission(
user=current_user,
resource_id=transfer.source_connection_id,
)
if source_connection_role == Permission.NONE:
raise ConnectionNotFoundError
async with asyncio.TaskGroup() as tasks:
source_connection_role_task = tasks.create_task(
unit_of_work.connection.get_resource_permission(
user=current_user,
resource_id=transfer.source_connection_id,
)
)
target_connection_role_task = tasks.create_task(
unit_of_work.connection.get_resource_permission(
user=current_user,
resource_id=transfer.target_connection_id,
)
)
target_queue_task = tasks.create_task(unit_of_work.queue.read_by_id(transfer_data.new_queue_id))

target_connection_role = await unit_of_work.connection.get_resource_permission(
user=current_user,
resource_id=transfer.target_connection_id,
)
if target_connection_role == Permission.NONE:
if Permission.NONE in [source_connection_role_task.result(), target_connection_role_task.result()]:
raise ConnectionNotFoundError

# Check: new queue exists
new_queue = await unit_of_work.queue.read_by_id(queue_id=transfer_data.new_queue_id)

# Acheck: new_queue_id and new_group_id are similar
if new_queue.group_id != transfer_data.new_group_id:
if target_queue_task.result().group_id != transfer_data.new_group_id:
raise DifferentTransferAndQueueGroupError

async with unit_of_work:
Expand Down Expand Up @@ -240,46 +258,38 @@ async def update_transfer(
unit_of_work: UnitOfWork = Depends(UnitOfWorkMarker),
) -> ReadTransferSchema:
# Check: user can update transfer
resource_role = await unit_of_work.transfer.get_resource_permission(
user=current_user,
resource_id=transfer_id,
)
async with asyncio.TaskGroup() as tasks:
resource_role_task = tasks.create_task(
unit_of_work.transfer.get_resource_permission(
user=current_user,
resource_id=transfer_id,
)
)
transfer_task = tasks.create_task(unit_of_work.transfer.read_by_id(transfer_id))

resource_role, transfer = resource_role_task.result(), transfer_task.result()

if resource_role == Permission.NONE:
raise TransferNotFoundError

if resource_role < Permission.WRITE:
raise ActionNotAllowedError

transfer = await unit_of_work.transfer.read_by_id(
transfer_id=transfer_id,
)

target_connection = await unit_of_work.connection.read_by_id(
connection_id=transfer_data.target_connection_id or transfer.target_connection_id,
)
source_connection = await unit_of_work.connection.read_by_id(
connection_id=transfer_data.source_connection_id or transfer.source_connection_id,
)

queue = await unit_of_work.queue.read_by_id(
transfer_data.new_queue_id or transfer.queue_id,
)

# Check: user can read new connections
target_connection_resource_role = await unit_of_work.connection.get_resource_permission(
user=current_user,
resource_id=target_connection.id,
)
async with asyncio.TaskGroup() as tasks:
target_connection_task = tasks.create_task(
unit_of_work.connection.read_by_id(transfer_data.target_connection_id or transfer.target_connection_id)
)
source_connection_task = tasks.create_task(
unit_of_work.connection.read_by_id(transfer_data.source_connection_id or transfer.source_connection_id)
)
queue_task = tasks.create_task(unit_of_work.queue.read_by_id(transfer_data.new_queue_id or transfer.queue_id))

source_connection_resource_role = await unit_of_work.connection.get_resource_permission(
user=current_user,
resource_id=source_connection.id,
target_connection, source_connection, queue = (
target_connection_task.result(),
source_connection_task.result(),
queue_task.result(),
)

if source_connection_resource_role == Permission.NONE or target_connection_resource_role == Permission.NONE:
raise ConnectionNotFoundError

# Check: connections and transfer group
if (
target_connection.group_id != source_connection.group_id
Expand Down Expand Up @@ -420,6 +430,7 @@ async def start_run(

async with unit_of_work:
run = await unit_of_work.run.create(transfer_id=create_run_data.transfer_id)

try:
celery.send_task("run_transfer_task", kwargs={"run_id": run.id}, queue=transfer.queue.name)
except KombuError as e:
Expand All @@ -429,6 +440,7 @@ async def start_run(
status=Status.FAILED,
)
raise CannotConnectToTaskQueueError(run_id=run.id) from e

return ReadRunSchema.from_orm(run)


Expand Down

0 comments on commit c4bc7e1

Please sign in to comment.