Skip to content

Commit

Permalink
Refactor to classmethods
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrielmbmb committed May 14, 2024
1 parent ac88df4 commit 3384de9
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 34 deletions.
38 changes: 19 additions & 19 deletions src/distilabel/llms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import logging
import sys
from abc import ABC, abstractmethod
from functools import cached_property
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from pydantic import BaseModel, ConfigDict, Field, PrivateAttr
Expand Down Expand Up @@ -112,17 +111,17 @@ def generate(
"""
pass

@property
def generate_parameters(self) -> List["inspect.Parameter"]:
@classmethod
def generate_parameters(cls) -> List["inspect.Parameter"]:
"""Returns the parameters of the `generate` method.
Returns:
A list containing the parameters of the `generate` method.
"""
return list(inspect.signature(self.generate).parameters.values())
return list(inspect.signature(cls.generate).parameters.values())

@property
def runtime_parameters_names(self) -> "RuntimeParametersNames":
@classmethod
def runtime_parameters_names(cls) -> "RuntimeParametersNames":
"""Returns the runtime parameters of the `LLM`, which are combination of the
attributes of the `LLM` type hinted with `RuntimeParameter` and the parameters
of the `generate` method other than `input`, `inputs` and `num_generations`.
Expand All @@ -131,19 +130,20 @@ def runtime_parameters_names(self) -> "RuntimeParametersNames":
A dictionary with the name of the runtime parameters as keys and a boolean
indicating if the parameter is optional or not.
"""
runtime_parameters = super().runtime_parameters_names
runtime_parameters = super().runtime_parameters_names()
runtime_parameters["generation_kwargs"] = {}

# runtime parameters from the `generate` method
for param in self.generate_parameters:
for param in cls.generate_parameters():
if param.name in ["input", "inputs", "num_generations"]:
continue
is_optional = param.default != inspect.Parameter.empty
runtime_parameters["generation_kwargs"][param.name] = is_optional

return runtime_parameters

def get_runtime_parameters_info(self) -> List[Dict[str, Any]]:
@classmethod
def get_runtime_parameters_info(cls) -> List[Dict[str, Any]]:
"""Gets the information of the runtime parameters of the `LLM` such as the name
and the description. This function is meant to include the information of the runtime
parameters in the serialized data of the `LLM`.
Expand All @@ -159,7 +159,7 @@ def get_runtime_parameters_info(self) -> List[Dict[str, Any]]:
if runtime_parameter_info["name"] == "generation_kwargs"
)

generate_docstring_args = self.generate_parsed_docstring["args"]
generate_docstring_args = cls.generate_parsed_docstring()["args"]

generation_kwargs_info["keys"] = []
for key, value in generation_kwargs_info["optional"].items():
Expand All @@ -172,14 +172,14 @@ def get_runtime_parameters_info(self) -> List[Dict[str, Any]]:

return runtime_parameters_info

@cached_property
def generate_parsed_docstring(self) -> "Docstring":
@classmethod
def generate_parsed_docstring(cls) -> "Docstring":
"""Returns the parsed docstring of the `generate` method.
Returns:
The parsed docstring of the `generate` method.
"""
return parse_google_docstring(self.generate)
return parse_google_docstring(cls.generate)

def get_last_hidden_states(self, inputs: List["ChatType"]) -> List["HiddenState"]:
"""Method to get the last hidden states of the model for a list of inputs.
Expand Down Expand Up @@ -208,23 +208,23 @@ class AsyncLLM(LLM):

_event_loop: "asyncio.AbstractEventLoop" = PrivateAttr(default=None)

@property
def generate_parameters(self) -> List[inspect.Parameter]:
@classmethod
def generate_parameters(cls) -> List[inspect.Parameter]:
"""Returns the parameters of the `agenerate` method.
Returns:
A list containing the parameters of the `agenerate` method.
"""
return list(inspect.signature(self.agenerate).parameters.values())
return list(inspect.signature(cls.agenerate).parameters.values())

@cached_property
def generate_parsed_docstring(self) -> "Docstring":
@classmethod
def generate_parsed_docstring(cls) -> "Docstring":
"""Returns the parsed docstring of the `agenerate` method.
Returns:
The parsed docstring of the `agenerate` method.
"""
return parse_google_docstring(self.agenerate)
return parse_google_docstring(cls.agenerate)

@property
def event_loop(self) -> "asyncio.AbstractEventLoop":
Expand Down
39 changes: 24 additions & 15 deletions src/distilabel/mixins/runtime_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ class RuntimeParametersMixin(BaseModel):

_runtime_parameters: Dict[str, Any] = PrivateAttr(default_factory=dict)

@property
def runtime_parameters_names(self) -> RuntimeParametersNames:
@classmethod
def runtime_parameters_names(cls) -> RuntimeParametersNames:
"""Returns a dictionary containing the name of the runtime parameters of the class
as keys and whether the parameter is optional or not as values.
Expand All @@ -56,19 +56,24 @@ def runtime_parameters_names(self) -> RuntimeParametersNames:

runtime_parameters = {}

for name, field_info in self.model_fields.items(): # type: ignore
for name, field_info in cls.model_fields.items(): # type: ignore
is_runtime_param, is_optional = _is_runtime_parameter(field_info)
if is_runtime_param:
runtime_parameters[name] = is_optional
continue

attr = getattr(self, name)
if isinstance(attr, RuntimeParametersMixin):
runtime_parameters[name] = attr.runtime_parameters_names
klass = field_info.annotation
if (
klass
and inspect.isclass(klass)
and issubclass(RuntimeParametersMixin, klass)
):
runtime_parameters[name] = klass.runtime_parameters_names()

return runtime_parameters

def get_runtime_parameters_info(self) -> List[Dict[str, Any]]:
@classmethod
def get_runtime_parameters_info(cls) -> List[Dict[str, Any]]:
"""Gets the information of the runtime parameters of the class such as the name and
the description. This function is meant to include the information of the runtime
parameters in the serialized data of the class.
Expand All @@ -77,21 +82,25 @@ def get_runtime_parameters_info(self) -> List[Dict[str, Any]]:
A list containing the information for each runtime parameter of the class.
"""
runtime_parameters_info = []
for name, field_info in self.model_fields.items(): # type: ignore
if name not in self.runtime_parameters_names:
for name, field_info in cls.model_fields.items(): # type: ignore
if name not in cls.runtime_parameters_names():
continue

attr = getattr(self, name)
if isinstance(attr, RuntimeParametersMixin):
klass = field_info.annotation
if (
klass
and inspect.isclass(klass)
and issubclass(RuntimeParametersMixin, klass)
):
runtime_parameters_info.append(
{
"name": name,
"runtime_parameters_info": attr.get_runtime_parameters_info(),
"runtime_parameters_info": klass.get_runtime_parameters_info(),
}
)
continue

info = {"name": name, "optional": self.runtime_parameters_names[name]}
info = {"name": name, "optional": cls.runtime_parameters_names()[name]}
if field_info.description is not None:
info["description"] = field_info.description
runtime_parameters_info.append(info)
Expand All @@ -106,9 +115,9 @@ def set_runtime_parameters(self, runtime_parameters: Dict[str, Any]) -> None:
runtime_parameters: A dictionary containing the values of the runtime parameters
to set.
"""
runtime_parameters_names = list(self.runtime_parameters_names.keys())
runtime_parameters_names = list(self.runtime_parameters_names().keys())
for name, value in runtime_parameters.items():
if name not in self.runtime_parameters_names:
if name not in self.runtime_parameters_names():
# Check done just to ensure the unit tests for the mixin run
if getattr(self, "pipeline", None):
closest = difflib.get_close_matches(
Expand Down

0 comments on commit 3384de9

Please sign in to comment.