83 changes: 72 additions & 11 deletions src/fmcore/framework/_algorithm.py
@@ -44,20 +44,62 @@
from pydantic import ConfigDict, conint, model_validator

from fmcore.constants import DataSplit, MLType, MLTypeSchema

from ._dataset import Dataset, Datasets
from ._metric import Metric, Metrics
from ._predictions import Predictions
from ._task_mixins import TaskOrStr, TaskRegistryMixin

MODEL_PARAMS_FILE_NAME: str = "__model_params__.pkl"


class Algorithm(TaskRegistryMixin, Registry, ABC):
"""
Base class for all algorithm implementations including GenerativeLM, Classifier, etc.

This class provides a common interface and set of utilities for training and predicting using
Machine Learning algorithms.
It is task-agnostic, and concrete subclasses are required to define the following attributes:
- tasks: A list of tasks the algorithm supports.
- inputs: The Dataset type that the algorithm can process during train_step.
- outputs: The Predictions type that the algorithm produces during predict_step.

The class encapsulates functionality for training, prediction, evaluation, and parameter management.
Its design ensures that any algorithm subclass adheres to a consistent usage pattern within the framework.
Subclasses must implement the abstract methods (at minimum: initialize(), predict_step(), and
_create_predictions()). Often, the framework extends this class with a task-specific subclass
(e.g. Classifier, Embedder), which adds utilities that make it easier to implement concrete
algorithms for that task.

Example usage:
class MyClassifier(Algorithm):
tasks = ["classification"]
inputs = MyDataset ## Replace with an appropriate Dataset subclass
outputs = MyPredictions ## Replace with an appropriate Predictions subclass

def initialize(self, model_dir=None):
## Initialize the classifier model; for example, load pre-trained weights from model_dir.
pass

def predict_step(self, batch, **kwargs):
## Process the batch and perform a prediction operation.
return ...

def _create_predictions(self, batch, predictions, **kwargs):
## Convert raw predictions into an instance of MyPredictions.
return MyPredictions(predictions)

## Instantiate and use MyClassifier:
classifier = MyClassifier.of(name="MyClassifier", task="classification")
predictions = classifier.predict(dataset)
"""

## Allows multiple subclasses registered to the same task.
_allow_multiple_subclasses: ClassVar[bool] = True
## Allows replacement of subclass with same name.
_allow_subclass_override: ClassVar[bool] = True

## Class variables for algorithm metadata
dataset_statistics: ClassVar[Tuple[Union[str, Dict], ...]] = ()
namespace: ClassVar[Optional[str]] = None
description: ClassVar[Optional[str]] = None
@@ -67,25 +109,31 @@ class Algorithm(TaskRegistryMixin, Registry, ABC):
num_steps_trained: int = 0
num_rows_trained: int = 0

## Input and output data types for the algorithm
inputs: ClassVar[Type[Dataset]]
feature_mltypes: ClassVar[Optional[Tuple[MLType, ...]]] = None
outputs: ClassVar[Type[Predictions]]

## Default parameters for batching
default_batching_params: ClassVar[Dict[str, Any]] = {}

## Configuration for the model
model_config = ConfigDict(
## Mutable+Extra = allows dynamically adding new items.
extra="allow",
)

@classmethod
def _pre_registration_hook(cls):
## Validate inputs and outputs before registration
cls.inputs = cls._validate_inputs(cls.inputs)
cls.outputs = cls._validate_outputs_type(cls.outputs)
## Ensure 'row_count' is included in dataset statistics
cls.dataset_statistics = tuple(as_set("row_count") | as_set(cls.dataset_statistics))
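## For example (illustrative, not from this diff): a subclass declaring
## `dataset_statistics = ("num_features",)` is registered with both "row_count"
## and "num_features" as its dataset statistics.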

@classmethod
def _registry_keys(cls) -> Optional[Union[List[Any], Any]]:
## Generate registry keys based on tasks and class names
tasks: List = as_list(cls.tasks)
return (
tasks
@@ -95,6 +143,7 @@ def _registry_keys(cls) -> Optional[Union[List[Any], Any]]:

@classmethod
def _validate_inputs(cls, inputs: Type[Dataset]) -> Type[Dataset]:
## Ensure the input dataset supports all tasks required by the algorithm
for task in as_list(cls.tasks):
if task not in as_list(inputs.tasks):
raise ValueError(
@@ -106,6 +155,7 @@ def _validate_inputs(cls, inputs: Type[Dataset]) -> Type[Dataset]:

@classmethod
def _validate_outputs_type(cls, outputs: Type[Predictions]) -> Type[Predictions]:
## Ensure the output Predictions type supports all tasks required by the algorithm
for task in as_list(cls.tasks):
if task not in as_list(outputs.tasks):
raise ValueError(
@@ -116,15 +166,18 @@ def _validate_outputs_type(cls, outputs: Type[Predictions]) -> Type[Predictions]
return outputs
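## Illustrative failure mode (the names below are assumptions, not from this diff):
## registering a subclass whose `inputs` type does not list one of the subclass's
## tasks raises at registration time:
##
##     class MyAlgo(Algorithm):
##         tasks = ["classification"]
##         inputs = RegressionDataset    ## Hypothetical Dataset whose `tasks` lack "classification"
##         outputs = ClassificationPredictions
##     ## => ValueError from `_validate_inputs`, via `_pre_registration_hook`.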

def __init__(self, *, stats: Optional[Metrics] = None, **kwargs):
## Initialize the algorithm with optional statistics
super(Algorithm, self).__init__(stats=stats, **kwargs)
self.stats = stats

def __str__(self):
## String representation of the algorithm with its parameters
params_str: str = self.json(indent=4, include={"hyperparams"})
out: str = f"{self.class_name} with params:\n{params_str}"
return out

class Hyperparameters(Parameters):
## Hyperparameters for the algorithm
seed: Optional[int] = None ## Seed used for randomization.
batch_size: Optional[conint(ge=1)] = None ## Training batch size. None allows inference-only models
epochs: Optional[conint(ge=1)] = None ## Number of epochs to train. None allows inference-only models
@@ -137,6 +190,7 @@ class Hyperparameters(Parameters):
@model_validator(mode="before")
@classmethod
def check_params(cls, params: Dict) -> Dict:
## Ensure that only one of 'epochs' or 'steps' is provided
if all_are_not_none(params.get("epochs"), params.get("steps")):
raise ValueError("Must pass at most one of `epochs` and `steps`; both were passed.")
return params
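## Illustrative behavior: `epochs` and `steps` are mutually exclusive.
##     Algorithm.Hyperparameters(epochs=3)              ## OK
##     Algorithm.Hyperparameters(steps=1000)            ## OK
##     Algorithm.Hyperparameters(epochs=3, steps=1000)  ## Raises: at most one may be passed.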
@@ -146,13 +200,15 @@ def dict(
include: Optional[Union[Tuple[str, ...], Set[str], Callable]] = None,
**kwargs,
) -> Dict:
## Convert hyperparameters to a dictionary
if is_function(include):
include: Tuple[str, ...] = get_fn_args(include)
if include is not None:
include: Set[str] = as_set(include)
return super().dict(include=include, **kwargs)

def __str__(self) -> str:
## String representation of the hyperparameters
params_str: str = self.json(indent=4)
out: str = f"{self.class_name}:\n{params_str}"
return out
@@ -161,16 +217,19 @@ def __str__(self) -> str:

@property
def hyperparams_str(self) -> str:
## String representation of hyperparameters for logging or display
return ";".join([f"{k}={v}" for k, v in self.hyperparams.dict().items()])

@classmethod
def create_hyperparams(cls, hyperparams: Optional[Dict] = None) -> Hyperparameters:
## Create an instance of Hyperparameters with default values
hyperparams: Dict = get_default(hyperparams, {})
return cls.Hyperparameters(**hyperparams)
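## Illustrative usage (the subclass name is an assumption): fields omitted from the
## dict fall back to the subclass's `Hyperparameters` defaults.
##     hp = MyClassifier.create_hyperparams({"batch_size": 16})  ## seed/epochs/steps stay None.
##     hp = MyClassifier.create_hyperparams()                    ## All defaults.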

@model_validator(mode="before")
@classmethod
def convert_params(cls, params: Dict) -> Dict:
## Convert and validate parameters for the algorithm
cls.set_default_param_values(params)
## This allows us to create a new Algorithm instance without specifying `hyperparams`.
## If it is specified, we will pick cls.Hyperparameters, which can be overridden by the subclass.
@@ -195,6 +254,7 @@ def of(
model_dir: Optional[Union[FileMetadata, Dict, str]] = None,
**kwargs,
) -> "Algorithm":
## Factory method to create an instance of the algorithm
kwargs: Dict = remove_nulls(kwargs)
if "algorithm" in kwargs and name is None:
name = kwargs.pop("algorithm")
@@ -208,7 +268,7 @@ def of(
cache_dir: str = tempfile.TemporaryDirectory().name ## Does not exist yet
cache_dir: FileMetadata = FileMetadata.of(cache_dir)
if all_are_none(name, task) and model_dir is not None:
# print(f'(pid={os.getpid()}) Loading "{AlgorithmClass}" model from dir: "{model_dir.path}"')
## Load model parameters from the specified directory
model_params: Dict = {
**get_default(
Algorithm.load_params(model_dir=model_dir, raise_error=False, tmpdir=cache_dir.path), {}
@@ -238,7 +298,7 @@ def of(
else:
AlgorithmClass: Type[Algorithm] = cls
if is_abstract(AlgorithmClass):
## Throw an error:
## Throw an error if the class is abstract
abstract_task_subclasses: Set[str] = set()
concrete_subclasses: Set[str] = set()
for key, subclasses_dict in cls._registry.items():
@@ -282,13 +342,13 @@
)
task: TaskOrStr = AlgorithmClass.tasks[0]
if model_dir is None:
# print(f'(pid={os.getpid()}) Creating "{AlgorithmClass}" model from scratch.')
## Create a new model from scratch
model: Algorithm = AlgorithmClass(
task=task,
**kwargs,
)
else:
# print(f'(pid={os.getpid()}) Loading "{AlgorithmClass}" model from dir: "{model_dir.path}"')
## Load model parameters from the specified directory
model_params: Dict = {
**get_default(
Algorithm.load_params(model_dir=model_dir, raise_error=False, tmpdir=cache_dir.path), {}
@@ -331,6 +391,7 @@ def calculate_dataset_stats(
batch_size: Optional[conint(ge=1)] = None,
**kwargs,
) -> Optional[Metrics]:
## Calculate statistics for the given dataset
data_split: DataSplit = get_default(data_split, dataset.data_split)
if data_split is None:
raise ValueError(f"Must pass data_split in either {Dataset.class_name} or explicitly.")
17 changes: 17 additions & 0 deletions src/fmcore/framework/_dataset.py
@@ -31,6 +31,11 @@


class Dataset(InputOutputDataMixin, Registry, ABC):
"""
A dataset is a collection of data used to train, validate, and test machine learning models.
This class subclasses `InputOutputDataMixin` and `Registry`.
"""

_allow_multiple_subclasses: ClassVar[bool] = False
_allow_subclass_override: ClassVar[bool] = True
_allow_empty_features_schema: ClassVar[bool] = False
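## Illustrative sketch of a concrete subclass (the name and task are assumptions):
##
##     class MyTextDataset(Dataset):
##         tasks = ["classification"]  ## Must cover the tasks of any Algorithm using it as `inputs`.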
@@ -243,6 +248,18 @@ def load_dataset(
task: Optional[TaskOrStr] = None,
**kwargs,
) -> Optional[DatasetSubclass]:
"""
Ensures that dataset metadata is read consistently by verifying or creating a .dataset-params.json file
if needed. This avoids incorrect merges when several parameter files exist with conflicting settings.

Example usage:
>>> ds = load_dataset("/path/to/dataset")
## Prevents accidental merges by checking for consistent .dataset-params.json params:

>>> # `FileMetadata` can also be passed to specify source options:
>>> ds = load_dataset(FileMetadata(path="/tmp/myfile.parquet", format="parquet"))

"""
if dataset_source is None:
return
## Don't want to mistake with similar params used for prediction:
64 changes: 1 addition & 63 deletions src/fmcore/framework/_evaluator/AccelerateEvaluator.py
@@ -40,6 +40,7 @@ class AccelerateEvaluator(LocalEvaluator):
@model_validator(mode="before")
@classmethod
def set_accelerate_evaluator_params(cls, params: Dict) -> Dict:
params: Dict = cls._set_common_evaluator_params(params)
set_param_from_alias(
params, param="model_weights_dtype", alias=["weights_dtype", "model_dtype", "torch_dtype"]
)
@@ -69,8 +70,6 @@ def _load_model(
cache_dir: Optional[Union[FileMetadata, Dict, str]] = None,
**kwargs,
) -> PyTorch:
from fmcore.algorithm.alexa_teacher_models import ALEXA_TM_SEQ2SEQ_MODEL_NAMES

kwargs.pop("device", None) ## We manage the device-allocation in the rest of this function.
kwargs.pop("model_dir", None) ## Do not allow overriding model_dir
kwargs.pop("num_devices", None) ## Use the one passed to evaluator.
@@ -88,14 +87,6 @@ def _load_model(
# print(f'cuda_visible_devices: {EnvUtil.cuda_visible_devices()}')
# print(f'num_devices: {num_devices}')

alexa_tm_model_name: Optional[str] = self._create_hyperparams().dict().get("model_name")
if alexa_tm_model_name is None:
alexa_tm_model_name: Optional[str] = (
self._create_hyperparams().dict().get("lm", {}).get("hyperparams", {}).get("model_name")
)
# print(f'alexa_tm_model_name: {alexa_tm_model_name}')
if alexa_tm_model_name in ALEXA_TM_SEQ2SEQ_MODEL_NAMES:
return self._load_alexa_tm_model_copy(cache_dir=cache_dir, num_devices=num_devices, **kwargs)
if self.use_hf_from_pretrained:
return self._load_hf_auto_model_class(cache_dir=cache_dir, num_devices=num_devices, **kwargs)
return self._load_model_copy_accelerate(cache_dir=cache_dir, num_devices=num_devices, **kwargs)
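## With the AlexaTM special-casing removed, model loading is a two-way dispatch:
## `use_hf_from_pretrained=True` goes through the HuggingFace auto-model classes
## (`_load_hf_auto_model_class`); otherwise a copy of the model is loaded via
## accelerate (`_load_model_copy_accelerate`).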
@@ -273,56 +264,3 @@ def _checkpoint_file_path_in_pt_model_snapshot_dir(
f"{snapshot_matching_files}"
)
return snapshot_matching_files[0]

def _load_alexa_tm_model_copy(
self,
cache_dir: FileMetadata,
num_devices: conint(ge=0),
**kwargs,
) -> PyTorch:
from fmcore.algorithm.alexa_teacher_models import AlexaTMSeq2Seq

if num_devices % 2 != 0:
raise ValueError("AlexaTM 20B can only be distributed across an even number of devices.")
## Load the model into CPU memory:
assert cache_dir is not None
model: Algorithm = Algorithm.of(
**{
**dict(
task=self.task,
algorithm=self.AlgorithmClass,
hyperparams=self.hyperparams,
model_dir=self.model_dir,
),
**kwargs,
**dict(
cache_dir=cache_dir,
post_init=False, ## When using accelerate, first init an empty model, then split.
),
}
)
if isinstance(model, LanguageModelTaskMixin):
pt_model: AlexaTMSeq2Seq = model.lm
else:
pt_model: AlexaTMSeq2Seq = model
if not isinstance(pt_model, AlexaTMSeq2Seq):
raise ValueError(f"Expected AlexaTM 20B model, found: {type_str(pt_model)}")
## Move to GPU:
if self.model_weights_dtype == torch.float32:
pt_model.model.float()
elif self.model_weights_dtype == torch.float16:
pt_model.model.half()
elif self.model_weights_dtype == torch.bfloat16:
pt_model.model.bfloat16()
else:
raise NotImplementedError(
f"Unsupported value for `model_weights_dtype`: {self.model_weights_dtype}"
)

pt_model.model.parallelize(num_devices)
pt_model.device = "cuda"
if isinstance(model, LanguageModelTaskMixin):
model.lm = pt_model
else:
model: AlexaTMSeq2Seq = pt_model
return model
7 changes: 3 additions & 4 deletions src/fmcore/framework/_evaluator/Evaluator.py
@@ -39,7 +39,7 @@
stop_daemon,
)
from bears.util.aws import S3Util
from pydantic import ConfigDict, confloat, conint, model_validator
from pydantic import ConfigDict, confloat, conint

from fmcore import _LIBRARY_NAME
from fmcore.constants import Storage, Task
@@ -58,7 +58,7 @@ class Evaluator(MutableParameters, Registry, ABC):
_allow_subclass_override: ClassVar[bool] = True ## Allows replacement of subclass with same name.

model_config = ConfigDict(
extra="ignore",
extra="allow",
)

class RunConfig(Parameters):
Expand Down Expand Up @@ -88,9 +88,8 @@ class RunConfig(Parameters):
## Logging verbosity. 0 = zero logging, 1 = Basic logging, 2 = verbose logging, 3 = super verbose logging.
verbosity: conint(ge=0) = 1

@model_validator(mode="before")
@classmethod
def evaluator_params(cls, params: Dict):
def _set_common_evaluator_params(cls, params: Dict):
Alias.set_AlgorithmClass(params)
Alias.set_model_dir(params)
Alias.set_cache_dir(params)
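## A subclass wires this shared helper into its own `mode="before"` validator, as
## `AccelerateEvaluator.set_accelerate_evaluator_params` does above. A minimal sketch
## (the subclass name and extra param are illustrative; assumes `model_validator` is
## imported from pydantic and `Dict` from typing, as in AccelerateEvaluator):
class MyEvaluator(Evaluator):
    @model_validator(mode="before")
    @classmethod
    def set_my_evaluator_params(cls, params: Dict) -> Dict:
        params: Dict = cls._set_common_evaluator_params(params)  ## Shared alias/dir setup.
        params.setdefault("my_extra_param", 42)  ## Hypothetical evaluator-specific default.
        return params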