Skip to content

Commit

Permalink
chore: add even more docs
Browse files Browse the repository at this point in the history
  • Loading branch information
BenTenmann committed Dec 4, 2021
1 parent 29328ad commit ec4320d
Showing 1 changed file with 36 additions and 6 deletions.
42 changes: 36 additions & 6 deletions src/setriq/modules/_distances.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def forward(self, sequences: List[str]) -> List[float]:

class TcrDistComponent(Metric):
"""
The TcrDistComponent class.
The TcrDistComponent class. Inherits from Metric.
Examples
--------
Expand Down Expand Up @@ -105,8 +105,8 @@ def forward(self, sequences: List[str]) -> List[float]:

class TcrDist(Metric):
"""
TcrDist class. This is a container class for individual TcrDistComponent instances. Components are executed
sequentially and their results aggregated at the end (summation).
TcrDist class. Inherits from Metric. It is a container class for individual TcrDistComponent instances. Components
are executed sequentially and their results aggregated at the end (summation).
Attributes
----------
Expand All @@ -120,7 +120,7 @@ class TcrDist(Metric):
... {'cdr_1': '', 'cdr_2': '', 'cdr_3': 'CASS-HIANY'},
... {'cdr_1': '', 'cdr_2': '', 'cdr_3': 'CASRGAT--Q'}
... ]
>>> metric = TcrDist() # will produce a warning stating default parameters (Dash et al)
>>> metric = TcrDist() # will produce a warning stating default configuration (Dash et al)
>>> distances = metric(sequences)
References
Expand Down Expand Up @@ -165,24 +165,31 @@ def __init__(self, **components):
>>>
>>> metric = TcrDist(cmp_1=component_1, cmp_2=component_2)
Keep in mind that the keys will be used to associate the components with the relevant input, i.e. in this case
the input should take the shape:
Keep in mind that the keys will be used to associate the components with the relevant input fields, i.e. in this
case the input should take the shape of:
>>> [{'cmp_1': '...', 'cmp_2': '...'}, ...]
additional keys will have no effect.
"""
parts = []

# user-defined configuration
if components:
for name, component in components.items():
# some type checking
if not isinstance(component, TcrDistComponent):
raise TypeError(f'{repr(name)} is not of type {TcrDistComponent}')

self.__setattr__(name, component)
parts.append(name)

# default configuration
else:
for name, definition in self._default:
self.__setattr__(name, TcrDistComponent(**definition))
parts.append(name)

# warn user that default has been initialised and inform required input format
warnings.warn(self._default_msg, UserWarning)

self.components = parts
Expand All @@ -196,22 +203,45 @@ def _check_input_format(self, ipt):

@property
def required_input_keys(self) -> List[str]:
    """
    List the keys (=fields) every input record must provide.

    Returns
    -------
    required_input_keys : List[str]
        the component names held in `self.components`; the same list object is returned, not a copy
    """
    keys = self.components
    return keys

@property
def default_definition(self) -> List[tuple]:
    """
    Expose the default TcrDistComponent schema (as defined by Dash et al).

    Returns
    -------
    default_schema : List[tuple]
        the `_default` attribute, i.e. the schema used to build the TcrDistComponent instances of the default
        configuration; the same object is returned, not a copy
    """
    schema = self._default
    return schema

def forward(self, sequences: List[Dict[str, str]]) -> List[float]:
    """
    Compute the aggregated distance over all registered components.

    Every component is executed on the values of its associated input field and the per-component results are
    summed element-wise.

    Parameters
    ----------
    sequences : List[Dict[str, str]]
        the input records; the value of every key in `self.components` is gathered from each record. Only the
        first record is passed to `_check_input_format`, i.e. all records are assumed to share its keys

    Returns
    -------
    distances : List[float]
        the element-wise sum of the component outputs; an empty list if `sequences` is empty
    """
    # nothing to compare -- without this guard `sequences[0]` below raises an IndexError
    # (len() rather than truthiness, so array-like inputs are not ambiguous)
    if len(sequences) == 0:
        return []

    # check the input keys provided -- assumes consistency
    self._check_input_format(sequences[0])

    # iterate through components and collect component output
    out = []
    for part in self.components:
        # gather sequences of the associated field into a list
        sqs = glom(sequences, [part])
        component = getattr(self, part)

        # execute component on list of associated sequences
        result = component(sqs)
        out.append(result)

    # aggregate the component outputs
    out = np.array(out).sum(axis=0)
    return out.tolist()

0 comments on commit ec4320d

Please sign in to comment.