Skip to content

Commit

Permalink
feat(refactor):
Browse files Browse the repository at this point in the history
  • Loading branch information
henrikstranneheim committed Sep 14, 2023
1 parent e8e916c commit 23083f0
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 50 deletions.
37 changes: 35 additions & 2 deletions tests/store/crud/test_update.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import subprocess
from datetime import datetime
from typing import List, Optional
from typing import Dict, List, Optional

import pytest

from tests.mocks.store_mock import MockStore
from trailblazer.apps.slurm.api import get_squeue_result
from trailblazer.apps.slurm.models import SqueueResult
from trailblazer.constants import TrailblazerStatus
from trailblazer.constants import CharacterFormat, TrailblazerStatus
from trailblazer.exc import MissingAnalysis, TrailblazerError
from trailblazer.store.filters.user_filters import UserFilter, apply_user_filter
from trailblazer.store.models import Analysis, User
Expand Down Expand Up @@ -242,3 +243,35 @@ def test_update_analysis_comment_when_existing(analysis_store: MockStore, case_i

# THEN comments should have been added
assert analysis.comment == f"{first_comment} {second_comment}"


def test_update_analysis_from_slurm_run_status(
analysis_store: MockStore,
squeue_stream_jobs: str,
mocker,
ongoing_analysis_case_id: str,
slurm_squeue_output: Dict[str, str],
):
"""Test updating analysis jobs when given squeue results."""
# GIVEN an analysis and a squeue stream
analysis: Analysis = analysis_store.get_query(table=Analysis).first()
assert not analysis.jobs

# GIVEN SLURM squeue output for an analysis
mocker.patch(
"trailblazer.store.crud.update.get_slurm_squeue_output",
return_value=subprocess.check_output(
["cat", slurm_squeue_output.get(ongoing_analysis_case_id)]
).decode(CharacterFormat.UNICODE_TRANSFORMATION_FORMAT_8),
)

# WHEN updating the analysis
analysis_store.update_analysis_from_slurm_output(
analysis_id=analysis.id, analysis_host="a_host"
)
updated_analysis: Analysis = analysis_store.get_analysis(
case_id=analysis.family, started_at=analysis.started_at, status=TrailblazerStatus.RUNNING
)

# THEN it should update the analysis jobs
assert updated_analysis.jobs
32 changes: 0 additions & 32 deletions tests/store/test_store_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,38 +20,6 @@ def test_setup_db(store: MockStore):
assert store.engine.table_names()


def test_update_analysis_from_slurm_run_status(
analysis_store: MockStore,
squeue_stream_jobs: str,
mocker,
ongoing_analysis_case_id: str,
slurm_squeue_output: Dict[str, str],
):
"""Test updating analysis jobs when given squeue results."""
# GIVEN an analysis and a squeue stream
analysis: Analysis = analysis_store.get_query(table=Analysis).first()
assert not analysis.jobs

# GIVEN SLURM squeue output for an analysis
mocker.patch(
FUNC_GET_SLURM_SQUEUE_OUTPUT_PATH,
return_value=subprocess.check_output(
["cat", slurm_squeue_output.get(ongoing_analysis_case_id)]
).decode(CharacterFormat.UNICODE_TRANSFORMATION_FORMAT_8),
)

# WHEN updating the analysis
analysis_store.update_analysis_from_slurm_output(
analysis_id=analysis.id, analysis_host="a_host"
)
updated_analysis: Analysis = analysis_store.get_analysis(
case_id=analysis.family, started_at=analysis.started_at, status=TrailblazerStatus.RUNNING
)

# THEN it should update the analysis jobs
assert updated_analysis.jobs


@pytest.mark.parametrize(
"case_id, status",
[
Expand Down
16 changes: 0 additions & 16 deletions trailblazer/store/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,22 +100,6 @@ def update_run_status(self, analysis_id: int, analysis_host: Optional[str] = Non
analysis_id=analysis_id, analysis_host=analysis_host
)

def update_analysis_from_slurm_output(
self, analysis_id: int, analysis_host: Optional[str] = False
) -> None:
"""Query SLURM for entries related to given analysis, and update the analysis in the database."""
analysis: Optional[Analysis] = self.get_analysis_with_id(analysis_id=analysis_id)
try:
self._update_analysis_from_slurm_squeue_output(
analysis=analysis, analysis_host=analysis_host
)
except Exception as exception:
LOG.error(
f"Error updating analysis for: case - {analysis.family} : {exception.__class__.__name__}"
)
analysis.status = TrailblazerStatus.ERROR
self.commit()

@staticmethod
def query_tower(config_file: str, case_id: str) -> TowerAPI:
"""Parse a config file to extract a NF Tower workflow ID and return a TowerAPI.
Expand Down
16 changes: 16 additions & 0 deletions trailblazer/store/crud/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,22 @@ def _update_analysis_from_slurm_squeue_output(
analysis.logged_at = datetime.now()
self.commit()

def update_analysis_from_slurm_output(
self, analysis_id: int, analysis_host: Optional[str] = False
) -> None:
"""Query SLURM for entries related to given analysis, and update the analysis in the database."""
analysis: Optional[Analysis] = self.get_analysis_with_id(analysis_id=analysis_id)
try:
self._update_analysis_from_slurm_squeue_output(
analysis=analysis, analysis_host=analysis_host
)
except Exception as exception:
LOG.error(
f"Error updating analysis for: case - {analysis.family} : {exception.__class__.__name__}"
)
analysis.status = TrailblazerStatus.ERROR
self.commit()

def update_case_analyses_as_deleted(self, case_id: str) -> Optional[List[Analysis]]:
"""Mark analyses connected to a case as deleted."""
analyses: Optional[List[Analysis]] = self.get_analyses_for_case(case_id=case_id)
Expand Down

0 comments on commit 23083f0

Please sign in to comment.