diff --git a/python/python/bystro/ancestry/listener.py b/python/python/bystro/ancestry/listener.py
index 99a512082..8387cd28d 100644
--- a/python/python/bystro/ancestry/listener.py
+++ b/python/python/bystro/ancestry/listener.py
@@ -1,11 +1,11 @@
 """Provide a worker for the ancestry model."""
 import argparse
-from collections.abc import Callable
 import logging
 from pathlib import Path
 
 import boto3  # type: ignore
+from botocore.exceptions import ClientError  # type: ignore
 import msgspec
 import pandas as pd
 import pyarrow.dataset as ds  # type: ignore
@@ -17,6 +17,8 @@
 from bystro.beanstalkd.messages import BaseMessage, CompletedJobMessage, SubmittedJobMessage
 from bystro.beanstalkd.worker import ProgressPublisher, QueueConf, get_progress_reporter, listen
 
+from bystro.utils.timer import Timer
+
 logging.basicConfig(
     filename="ancestry_listener.log",
     level=logging.DEBUG,
@@ -30,19 +32,66 @@
 PCA_FILE = "pca.csv"
 RFC_FILE = "rfc.skop"
 
+models_cache: dict[str, AncestryModel] = {}
+
+
+def _get_model_from_s3(assembly: str) -> AncestryModel:
+    if assembly in models_cache:
+        logger.info("Model for assembly %s found in cache.", assembly)
+        return models_cache[assembly]
 
-def _get_model_from_s3() -> AncestryModel:
     s3_client = boto3.client("s3")
-    s3_client.download_file(Bucket=ANCESTRY_BUCKET, Key=PCA_FILE, Filename=PCA_FILE)
-    s3_client.download_file(Bucket=ANCESTRY_BUCKET, Key=RFC_FILE, Filename=RFC_FILE)
+
+    pca_local_key = f"{assembly}_pca.csv"
+    rfc_local_key = f"{assembly}_rfc.skop"
+
+    pca_file_key = f"{assembly}/{pca_local_key}"
+    rfc_file_key = f"{assembly}/{rfc_local_key}"
+
+    logger.info("Downloading PCA file %s", pca_file_key)
+
+    with Timer() as timer:
+        try:
+            s3_client.download_file(Bucket=ANCESTRY_BUCKET, Key=pca_file_key, Filename=pca_local_key)
+        except ClientError as e:
+            # download_file issues a HEAD request first, so a missing key
+            # surfaces as "404" rather than "NoSuchKey"
+            if e.response["Error"]["Code"] in ("404", "NoSuchKey"):
+                raise ValueError(
+                    f"{assembly} ancestry PCA file not found. This assembly is not supported."
+                )
+            raise  # Re-raise the exception if it's not a missing-key error
+
+        try:
+            s3_client.download_file(Bucket=ANCESTRY_BUCKET, Key=rfc_file_key, Filename=rfc_local_key)
+        except ClientError as e:
+            if e.response["Error"]["Code"] in ("404", "NoSuchKey"):
+                raise ValueError(
+                    f"{assembly} ancestry model not found. This assembly is not supported."
+                )
+            raise
+
+    logger.debug("Downloaded PCA file and RFC file in %f seconds", timer.elapsed_time)
+
+    with Timer() as timer:
+        logger.info("Loading PCA file %s", pca_local_key)
+        pca_loadings_df = pd.read_csv(pca_local_key, index_col=0)
+
+        logger.info("Loading RFC file %s", rfc_local_key)
+        rfc = skops_load(rfc_local_key)
+
+    logger.debug("Loaded PCA and RFC files in %f seconds", timer.elapsed_time)
 
-    logger.info("Loading PCA file %s", PCA_FILE)
-    pca_loadings_df = pd.read_csv(PCA_FILE, index_col=0)
-    logger.info("Loading RFC file %s", RFC_FILE)
-    rfc = skops_load(RFC_FILE)
     logger.info("Loaded ancestry models from S3")
-    return AncestryModel(pca_loadings_df, rfc)
+
+    model = AncestryModel(pca_loadings_df, rfc)
+
+    # Update the cache with the new model
+    if len(models_cache) >= 2:
+        # Remove the oldest loaded model to maintain cache size
+        oldest_assembly = next(iter(models_cache))
+        del models_cache[oldest_assembly]
+    models_cache[assembly] = model
+
+    return model
 
 
 class AncestryJobData(BaseMessage, frozen=True, rename="camel"):
@@ -57,10 +106,13 @@
         The path to the dosage matrix file.
     out_dir: str
         The directory to write the results to.
+    assembly: str
+        The genome assembly used for the dosage matrix.
     """
 
     dosage_matrix_path: str
     out_dir: str
+    assembly: str
 
 
 class AncestryJobCompleteMessage(CompletedJobMessage, frozen=True, kw_only=True, rename="camel"):
@@ -76,25 +128,20 @@ def _load_queue_conf(queue_conf_path: str) -> QueueConf:
     return QueueConf(addresses=beanstalk_conf["addresses"], tubes=beanstalk_conf["tubes"])
 
 
-def handler_fn_factory(
-    ancestry_model: AncestryModel,
-) -> Callable[[ProgressPublisher, AncestryJobData], AncestryResults]:
-    """Return partialed handler_fn with ancestry_model loaded."""
-
-    def handler_fn(publisher: ProgressPublisher, job_data: AncestryJobData) -> AncestryResults:
-        """Do ancestry job, wrapping infer_ancestry for beanstalk."""
-        # Separating _handler_fn from infer_ancestry in order to separate ML from infra concerns,
-        # and especially to keep infer_ancestry eager.
+def handler_fn(publisher: ProgressPublisher, job_data: AncestryJobData) -> AncestryResults:
+    """Do ancestry job, wrapping infer_ancestry for beanstalk."""
+    # Separating _handler_fn from infer_ancestry in order to separate ML from infra concerns,
+    # and especially to keep infer_ancestry eager.
 
-        # not doing anything with this reporter at the moment, we're
-        # simply threading it through for later.
-        _reporter = get_progress_reporter(publisher)
+    # not doing anything with this reporter at the moment, we're
+    # simply threading it through for later.
+    _reporter = get_progress_reporter(publisher)
 
-        dataset = ds.dataset(job_data.dosage_matrix_path, format="arrow")
+    dataset = ds.dataset(job_data.dosage_matrix_path, format="arrow")
 
-        return infer_ancestry(ancestry_model, dataset)
+    ancestry_model = _get_model_from_s3(job_data.assembly)
 
-    return handler_fn
+    return infer_ancestry(ancestry_model, dataset)
 
 
 def submit_msg_fn(ancestry_job_data: AncestryJobData) -> SubmittedJobMessage:
@@ -121,22 +168,6 @@
     )
 
 
-def main(ancestry_model: AncestryModel, queue_conf: QueueConf) -> None:
-    """Run ancestry listener."""
-    handler_fn_with_models = handler_fn_factory(ancestry_model)
-    logger.info(
-        "Ancestry worker is listening on addresses: %s, tube: %s...", queue_conf.addresses, ANCESTRY_TUBE
-    )
-    listen(
-        AncestryJobData,
-        handler_fn_with_models,
-        submit_msg_fn,
-        completed_msg_fn,
-        queue_conf,
-        ANCESTRY_TUBE,
-    )
-
-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description="Process some config files.")
     parser.add_argument(
@@ -146,8 +177,17 @@ def main(ancestry_model: AncestryModel, queue_conf: QueueConf) -> None:
         required=True,
     )
     args = parser.parse_args()
-
-    ancestry_model = _get_model_from_s3()
     queue_conf = _load_queue_conf(args.queue_conf)
-    main(ancestry_model, queue_conf)
+    logger.info(
+        "Ancestry worker is listening on addresses: %s, tube: %s...", queue_conf.addresses, ANCESTRY_TUBE
+    )
+
+    listen(
+        job_data_type=AncestryJobData,
+        handler_fn=handler_fn,
+        submit_msg_fn=submit_msg_fn,
+        completed_msg_fn=completed_msg_fn,
+        queue_conf=queue_conf,
+        tube=ANCESTRY_TUBE,
+    )
 
diff --git a/python/python/bystro/ancestry/tests/test_inference.py b/python/python/bystro/ancestry/tests/test_inference.py
index d8cc9af69..065adbe5a 100644
--- a/python/python/bystro/ancestry/tests/test_inference.py
+++ b/python/python/bystro/ancestry/tests/test_inference.py
@@ -84,7 +84,7 @@ def test_infer_ancestry():
 
 @pytest.mark.integration()
 def test_infer_ancestry_from_model():
-    ancestry_model = _get_model_from_s3()
+    ancestry_model = _get_model_from_s3("hg38")
 
     # Generate an arrow table that contains genotype dosages for 1000 samples
     variants = list(ancestry_model.pca_loadings_df.index)
diff --git a/python/python/bystro/ancestry/tests/test_listener.py b/python/python/bystro/ancestry/tests/test_listener.py
index 825d0efb5..0329dbe82 100644
--- a/python/python/bystro/ancestry/tests/test_listener.py
+++ b/python/python/bystro/ancestry/tests/test_listener.py
@@ -3,12 +3,12 @@
 
 from bystro.ancestry.listener import (
     AncestryJobData,
-    handler_fn_factory,
+    handler_fn,
     submit_msg_fn,
     completed_msg_fn,
     SubmittedJobMessage,
     AncestryJobCompleteMessage,
-    AncestryResults,
+    AncestryResults
 )
 from bystro.ancestry.tests.test_inference import (
     ANCESTRY_MODEL,
@@ -20,21 +20,20 @@
 from bystro.beanstalkd.worker import ProgressPublisher
 
 
-handler_fn = handler_fn_factory(ANCESTRY_MODEL)
-
-
 def test_submit_fn():
     ancestry_job_data = AncestryJobData(
         submission_id="my_submission_id2",
         dosage_matrix_path="some_dosage.feather",
         out_dir="/path/to/some/dir",
+        assembly="hg38",
     )
 
     submitted_job_message = submit_msg_fn(ancestry_job_data)
 
     assert isinstance(submitted_job_message, SubmittedJobMessage)
 
 
-def test_handler_fn_happy_path(tmpdir):
+def test_handler_fn_happy_path(mocker, tmpdir):
+    mocker.patch("bystro.ancestry.listener._get_model_from_s3", return_value=ANCESTRY_MODEL)
     dosage_path = "some_dosage.feather"
     f1 = tmpdir.join(dosage_path)
@@ -45,7 +44,7 @@
         host="127.0.0.1", port=1234, queue="my_queue", message=progress_message
     )
     ancestry_job_data = AncestryJobData(
-        submission_id="my_submission_id2", dosage_matrix_path=f1, out_dir=str(tmpdir)
+        submission_id="my_submission_id2", dosage_matrix_path=f1, out_dir=str(tmpdir), assembly="hg38"
     )
 
     ancestry_response = handler_fn(publisher, ancestry_job_data)
@@ -62,7 +61,10 @@
 
 def test_completion_fn(tmpdir):
     ancestry_job_data = AncestryJobData(
-        submission_id="my_submission_id2", dosage_matrix_path="some_dosage.feather", out_dir=str(tmpdir)
+        submission_id="my_submission_id2",
+        dosage_matrix_path="some_dosage.feather",
+        out_dir=str(tmpdir),
+        assembly="hg38",
     )
 
     ancestry_results, _ = _infer_ancestry()
@@ -93,7 +95,10 @@ def test_completion_message():
 
 def test_job_data_from_beanstalkd():
     ancestry_job_data = AncestryJobData(
-        submission_id="my_submission_id2", dosage_matrix_path="some_dosage.feather", out_dir="/foo"
+        submission_id="my_submission_id2",
+        dosage_matrix_path="some_dosage.feather",
+        out_dir="/foo",
+        assembly="hg38",
     )
 
     serialized_values = json.encode(ancestry_job_data)
@@ -101,6 +106,7 @@
         "submissionId": "my_submission_id2",
         "dosageMatrixPath": "some_dosage.feather",
         "outDir": "/foo",
+        "assembly": "hg38",
     }
     serialized_expected_value = json.encode(expected_value)
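
Note on the caching pattern introduced in _get_model_from_s3: the two-entry models_cache relies on Python dicts preserving insertion order, so next(iter(models_cache)) always names the oldest-loaded assembly. Eviction is therefore FIFO rather than LRU, because a cache hit returns the model without reinserting its key. A minimal standalone sketch of that eviction behavior follows; the cache_model helper, the MAX_CACHED_MODELS constant, and the string stand-ins for AncestryModel instances are illustrative assumptions, not code from this PR.

    # Standalone sketch of the FIFO eviction used by models_cache above.
    # cache_model, MAX_CACHED_MODELS, and the string values are hypothetical
    # stand-ins for AncestryModel instances.
    MAX_CACHED_MODELS = 2
    models_cache: dict[str, str] = {}

    def cache_model(assembly: str, model: str) -> None:
        """Insert a model, evicting the oldest-inserted assembly when full."""
        if len(models_cache) >= MAX_CACHED_MODELS:
            # dicts iterate in insertion order, so next(iter(...)) yields
            # the key that has been cached the longest
            oldest_assembly = next(iter(models_cache))
            del models_cache[oldest_assembly]
        models_cache[assembly] = model

    cache_model("hg19", "model_hg19")
    cache_model("hg38", "model_hg38")
    cache_model("t2t", "model_t2t")  # evicts hg19, the oldest entry
    assert list(models_cache) == ["hg38", "t2t"]

One consequence of FIFO here: a worker whose jobs alternate across three or more assemblies would re-download a model on every job; reinserting the key on each cache hit would make the policy LRU if that ever becomes a problem.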