Assign tasks to workers #525

Merged · 9 commits · Apr 19, 2023
Changes from all commits
14 changes: 13 additions & 1 deletion azimuth/modules/base_classes/dask_module.py
@@ -7,6 +7,7 @@
import threading
import time
import uuid
from enum import IntEnum
from functools import partial
from os.path import join as pjoin
from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar, cast
@@ -25,6 +26,11 @@
ConfigScope = TypeVar("ConfigScope", bound=CommonFieldsConfig)


class Worker(IntEnum):
model = 0
encoder = 0


class DaskModule(HDF5CacheMixin, Generic[ConfigScope]):
"""Abstract class that define an item of work to be computed on the cluster.

@@ -37,6 +43,7 @@ class DaskModule(HDF5CacheMixin, Generic[ConfigScope]):
"""

allowed_splits = {DatasetSplitName.train, DatasetSplitName.eval}
worker: Optional[Worker] = None

def __init__(
self,
@@ -118,6 +125,7 @@ def start_task_on_dataset_split(
pure=False,
dependencies=deps,
key=f"{self.task_id}_{uuid.uuid4()}", # Unique identifier
workers=self.worker,
)
# Record which indices this future is used on.
self.future.indices = self.get_caching_indices()
@@ -148,7 +156,11 @@ def start_task(self, client: Client, custom_query: Dict[str, Any]) -> "DaskModul
log.info(f"Starting custom query {self.name}")
# pure=False to be sure that everything is rerun.
self.future = client.submit(
self.compute, custom_query, key=self.custom_query_task_id(custom_query), pure=False
self.compute,
custom_query,
key=self.custom_query_task_id(custom_query),
pure=False,
workers=self.worker,
)
# Tell that this future is for custom use only.
self.future.is_custom = True
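For context, here is a standalone sketch (not part of this PR) of the scheduling mechanism these changes rely on: Dask's client.submit(..., workers=...) restricts a task to a named worker, and a LocalCluster names its workers by integer index by default, which is why an IntEnum value can be passed straight through. Note that as written, model and encoder share the value 0, so Worker.encoder is an alias of Worker.model and both task families land on the same worker.

```python
# Standalone sketch (not from the PR): pinning a task to a named Dask worker.
# Assumes a LocalCluster, whose workers are named by integer index by default.
from enum import IntEnum

from distributed import Client, LocalCluster


class Worker(IntEnum):
    model = 0
    encoder = 0  # duplicate value: Worker.encoder is an alias of Worker.model


def heavy_inference(x: int) -> int:
    return x * 2  # stand-in for a computation that needs the loaded model


if __name__ == "__main__":
    cluster = LocalCluster(n_workers=2, threads_per_worker=1)
    client = Client(cluster)
    # `workers=` restricts scheduling to the matching worker; an IntEnum
    # member compares equal to its integer value, so this targets worker 0.
    future = client.submit(heavy_inference, 21, workers=Worker.model, pure=False)
    print(future.result())  # 42
    client.close()
    cluster.close()
```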
2 changes: 2 additions & 0 deletions azimuth/modules/base_classes/indexable_module.py
@@ -12,6 +12,7 @@
from azimuth.config import ModelContractConfig
from azimuth.dataset_split_manager import DatasetSplitManager
from azimuth.modules.base_classes import ConfigScope, Module
from azimuth.modules.base_classes.dask_module import Worker
from azimuth.types import (
DatasetColumn,
DatasetSplitName,
@@ -76,6 +77,7 @@ def save_result(self, res: List[ModuleResponse], dm: DatasetSplitManager):
class ModelContractModule(DatasetResultModule[ModelContractConfig], abc.ABC):
required_mod_options: Set[str] = {"pipeline_index", "model_contract_method_name"}
optional_mod_options: Set[str] = DatasetResultModule.optional_mod_options | {"threshold"}
worker = Worker.model

def compute(self, batch: Dataset) -> List[ModuleResponse]:
my_func = self.route_request(assert_not_none(self.model_contract_method_name))
3 changes: 3 additions & 0 deletions azimuth/modules/base_classes/module.py
@@ -10,6 +10,7 @@
from azimuth.config import ModelContractConfig, PipelineDefinition
from azimuth.dataset_split_manager import DatasetSplitManager, PredictionTableKey
from azimuth.modules.base_classes import ArtifactManager, ConfigScope, DaskModule
from azimuth.modules.base_classes.dask_module import Worker
from azimuth.types import DatasetColumn, DatasetSplitName, ModuleOptions, ModuleResponse
from azimuth.types.general.module_arguments import ModuleEffectiveArguments
from azimuth.utils.conversion import md5_hash
@@ -164,6 +165,8 @@ def get_model(self):
Raises:
ValueError if no valid pipeline exists.
"""
if self.worker != Worker.model:
raise RuntimeError("This module cannot load the model. Modify self.worker.")
_ = self.get_pipeline_definition() # Validate current pipeline exists
return self.artifact_manager.get_model(self.config, self.mod_options.pipeline_index)

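The guard added above fails fast when a module that is not pinned to the model worker tries to load the model, instead of silently pulling the model onto the wrong Dask worker. A standalone sketch of the pattern, with a hypothetical PredictionModule subclass:

```python
# Standalone sketch of the fail-fast guard; PredictionModule is hypothetical,
# not a class from the repo.
from enum import IntEnum
from typing import Optional


class Worker(IntEnum):
    model = 0
    encoder = 0


class Module:
    worker: Optional[Worker] = None  # subclasses that need an artifact opt in

    def get_model(self) -> str:
        if self.worker != Worker.model:
            raise RuntimeError("This module cannot load the model. Modify self.worker.")
        return "loaded-model"  # stand-in for the real artifact load


class PredictionModule(Module):
    worker = Worker.model  # pinned to the model worker, so loading is allowed


print(PredictionModule().get_model())  # loaded-model
try:
    Module().get_model()  # worker is None, so the guard refuses to load
except RuntimeError as err:
    print(err)
```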
5 changes: 5 additions & 0 deletions azimuth/modules/dataset_analysis/similarity_analysis.py
@@ -16,6 +16,7 @@
from azimuth.config import SimilarityConfig, SimilarityOptions
from azimuth.dataset_split_manager import FEATURE_FAISS, DatasetSplitManager
from azimuth.modules.base_classes import DatasetResultModule, IndexableModule
from azimuth.modules.base_classes.dask_module import Worker
from azimuth.modules.task_execution import get_task_result
from azimuth.types import Array, DatasetColumn, DatasetSplitName, ModuleOptions
from azimuth.types.similarity_analysis import FAISSResponse
@@ -28,6 +29,8 @@
class FAISSModule(IndexableModule[SimilarityConfig]):
"""Compute the FAISS features for a dataset split."""

worker = Worker.encoder

def __init__(
self,
dataset_split_name: DatasetSplitName,
Expand All @@ -45,6 +48,8 @@ def get_encoder_name_or_path(self):
return model_name_or_path

def get_encoder(self):
if self.worker != Worker.encoder:
raise RuntimeError("This module cannot load the encoder. Modify self.worker.")
if self.encoder is None:
with FileLock(os.path.join(self.cache_dir, "st.lock")):
self.encoder = SentenceTransformer(self.get_encoder_name_or_path())
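For reference, a standalone sketch of the lock-guarded lazy load that get_encoder performs above; the cache directory and encoder name below are illustrative assumptions, and the FileLock keeps concurrent processes from initializing the same cache at once:

```python
# Hedged sketch of the lazy, lock-guarded encoder load; the model name and
# cache directory are assumptions for illustration.
import os

from filelock import FileLock
from sentence_transformers import SentenceTransformer

_encoder = None  # loaded at most once per process


def get_encoder(cache_dir: str = "/tmp/azimuth_cache",
                name_or_path: str = "all-MiniLM-L6-v2") -> SentenceTransformer:
    """Load the sentence encoder once, guarding the shared cache with a file lock."""
    global _encoder
    if _encoder is None:
        os.makedirs(cache_dir, exist_ok=True)
        # The lock keeps several worker processes from downloading or
        # initializing the same model into the same cache directory at once.
        with FileLock(os.path.join(cache_dir, "st.lock")):
            _encoder = SentenceTransformer(name_or_path)
    return _encoder
```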
4 changes: 4 additions & 0 deletions azimuth/modules/perturbation_testing/perturbation_testing.py
@@ -10,6 +10,7 @@
from azimuth.config import PerturbationTestingConfig
from azimuth.dataset_split_manager import DatasetSplitManager
from azimuth.modules.base_classes import DatasetResultModule
from azimuth.modules.base_classes.dask_module import Worker
from azimuth.modules.model_contract_task_mapping import model_contract_task_mapping
from azimuth.types import (
DatasetColumn,
@@ -58,6 +59,9 @@ class PerturbationTestingModule(DatasetResultModule[PerturbationTestingConfig]):
"""

required_mod_options = {"pipeline_index"}
# This module doesn't call self.get_model() but requires the model (predict_task.compute(batch))
# TODO Find a more robust way to determine when modules require models.
worker = Worker.model

def __init__(
self,
2 changes: 2 additions & 0 deletions azimuth/modules/validation/validation.py
@@ -9,6 +9,7 @@

from azimuth.config import ModelContractConfig
from azimuth.modules.base_classes import AggregationModule
from azimuth.modules.base_classes.dask_module import Worker
from azimuth.modules.model_contract_task_mapping import model_contract_task_mapping
from azimuth.types import ModuleOptions, SupportedMethod, SupportedModelContract
from azimuth.types.validation import ValidationResponse
@@ -36,6 +37,7 @@ def try_calling_function(self, fn, *args, **kwargs) -> Optional[Any]:

class ValidationModule(AggregationModule[ModelContractConfig]):
optional_mod_options = {"pipeline_index"}
worker = Worker.model

def compute_on_dataset_split(self) -> List[ValidationResponse]: # type: ignore
cuda_available = torch.cuda.is_available()
2 changes: 2 additions & 0 deletions azimuth/modules/word_analysis/top_words.py
@@ -9,6 +9,7 @@

from azimuth.config import TopWordsConfig
from azimuth.modules.base_classes import FilterableModule
from azimuth.modules.base_classes.dask_module import Worker
from azimuth.modules.task_execution import get_task_result
from azimuth.modules.word_analysis.tokens_to_words import TokensToWordsModule
from azimuth.types import ModuleOptions
@@ -33,6 +34,7 @@ class TopWordsModule(FilterableModule[TopWordsConfig]):
"th_importance",
"force_no_saliency",
}
worker = Worker.model

@staticmethod
def count_words(list_of_words: List[str], top_x: int) -> List[TopWordsResult]:
17 changes: 16 additions & 1 deletion poetry.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -88,6 +88,7 @@ bokeh = "<3"
# Documentation
mkdocs = "^1.2.3"
mkdocs-material = "^8.1.7"
memory-profiler = "^0.61.0"

[build-system]
requires = ["poetry-core>=1.0.0"]
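The new memory-profiler dev dependency suggests memory usage was measured while validating the worker assignment; a minimal usage sketch (assumed, not from the PR):

```python
# Minimal memory-profiler sketch (assumed usage, not from the PR); the
# @profile decorator prints line-by-line memory usage when the function runs.
from memory_profiler import profile


@profile
def allocate() -> int:
    data = [0] * 10_000_000  # roughly 80 MB of list slots
    return sum(data)


if __name__ == "__main__":
    allocate()
```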