Skip to content

Commit

Permalink
Removes fit_transform_with_metadata() API, refactors ProjectorModel.
Browse files Browse the repository at this point in the history
* Removes fit_transform_with_metadata() API from LIT.
* Moves ProjectorModel from components/projector.py to api/model.py to make it accessible to lib/ and components/ modules.
* Updates components/projector.py type annotations to reference lit_nlp.api.model.ProjectorModel.
* Refactors PCAModel to inherit from lit_nlp.api.model.ProjectorModel.
* Refactors UmapModel to inherit from lit_nlp.api.model.ProjectorModel.
* Refactors CachingModelWrapper.fit_transform() to check inheritance from ProjectorModel before calling wrapped.fit_transform().
* Removes fit_transform() API from Model base class.

PiperOrigin-RevId: 552842861
  • Loading branch information
RyanMullins authored and LIT team committed Aug 1, 2023
1 parent 3f38adb commit e30e59a
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 58 deletions.
42 changes: 24 additions & 18 deletions lit_nlp/api/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,19 +194,6 @@ def get_embedding_table(self) -> tuple[list[str], np.ndarray]:
raise NotImplementedError('get_embedding_table() not implemented for ' +
self.__class__.__name__)

def fit_transform(
self, inputs: Iterable[types.JsonDict]
) -> Iterable[types.JsonDict]:
"""For internal use by UMAP and other sklearn-based models."""
raise NotImplementedError(
'fit_transform() not implemented for ' + self.__class__.__name__)

def fit_transform_with_metadata(
self, indexed_inputs: Iterable[types.IndexedInput]
) -> Iterable[types.JsonDict]:
"""For internal use by UMAP and other sklearn-based models."""
return self.fit_transform((ii['data'] for ii in indexed_inputs))

##
# Concrete implementations of common functions.
def predict(self, inputs: Iterable[JsonDict], **kw) -> Iterable[JsonDict]:
Expand Down Expand Up @@ -301,11 +288,6 @@ def output_spec(self) -> types.Spec:
def get_embedding_table(self) -> tuple[list[str], np.ndarray]:
return self.wrapped.get_embedding_table()

def fit_transform_with_metadata(
self, indexed_inputs: Iterable[types.IndexedInput]
):
return self.wrapped.fit_transform_with_metadata(indexed_inputs)


class BatchedRemoteModel(Model):
"""Generic base class for remotely-hosted models.
Expand Down Expand Up @@ -363,3 +345,27 @@ def predict_minibatch(self, inputs: list[JsonDict]) -> list[JsonDict]:
list of outputs, following model.output_spec()
"""
return


class ProjectorModel(Model, metaclass=abc.ABCMeta):
"""LIT Model API for dimensionality reduction."""

##
# Training methods
@abc.abstractmethod
def fit_transform(self, inputs: Iterable[JsonDict]) -> list[JsonDict]:
"""For internal use by SciKit Learn-based models."""
pass

##
# LIT model API
def input_spec(self):
# 'x' denotes input features
return {'x': types.Embeddings()}

def output_spec(self):
# 'z' denotes projected embeddings
return {'z': types.Embeddings()}

def max_minibatch_size(self, **unused_kw):
return 1000
6 changes: 3 additions & 3 deletions lit_nlp/components/pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@
"""Implementation of PCA as a dimensionality reduction model."""

from absl import logging
from lit_nlp.components import projection
from lit_nlp.api import model
from lit_nlp.lib import utils
import numpy as np


class PCAModel(projection.ProjectorModel):
class PCAModel(model.ProjectorModel):
"""LIT model API implementation for PCA."""

def __init__(self, **pca_kw):
Expand Down Expand Up @@ -57,7 +57,7 @@ def fit_transform(self, inputs):
# LIT model API
def predict_minibatch(self, inputs, **unused_kw):
if not self._fitted:
return ({"z": [0, 0, 0]} for i in inputs)
return ({"z": [0, 0, 0]} for _ in inputs)
x = np.stack([i["x"] for i in inputs])
x = x - self._mean
zs = np.dot(x, self._evecs)
Expand Down
36 changes: 3 additions & 33 deletions lit_nlp/components/projection.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,7 @@
projections).
"""

import abc
from collections.abc import Iterable, Hashable, Sequence
from collections.abc import Hashable, Sequence
import threading
from typing import Optional

Expand All @@ -51,35 +50,6 @@
Spec = types.Spec


class ProjectorModel(lit_model.Model, metaclass=abc.ABCMeta):
"""LIT model API implementation for dimensionality reduction."""

##
# Training methods
@abc.abstractmethod
def fit_transform(self, inputs: Iterable[JsonDict]) -> list[JsonDict]:
pass

##
# LIT model API
def input_spec(self):
# 'x' denotes input features
return {"x": types.Embeddings()}

def output_spec(self):
# 'z' denotes projected embeddings
return {"z": types.Embeddings()}

@abc.abstractmethod
def predict_minibatch(
self, inputs: Iterable[JsonDict], **unused_kw
) -> list[JsonDict]:
pass

def max_minibatch_size(self, **unused_kw):
return 1000


class ProjectionInterpreter(lit_components.Interpreter):
"""Interpreter API implementation for dimensionality reduction model."""

Expand All @@ -88,7 +58,7 @@ def __init__(
model: lit_model.Model,
inputs: Sequence[JsonDict],
model_outputs: Optional[list[JsonDict]],
projector: ProjectorModel,
projector: lit_model.ProjectorModel,
field_name: str,
name: str,
):
Expand Down Expand Up @@ -166,7 +136,7 @@ class ProjectionManager(lit_components.Interpreter):
this is not explicitly enforced.
"""

def __init__(self, model_class: type[ProjectorModel]):
def __init__(self, model_class: type[lit_model.ProjectorModel]):
self._lock = threading.RLock()
self._instances: dict[Hashable, ProjectionInterpreter] = {}
# Used to construct new instances, given config['proj_kw']
Expand Down
4 changes: 2 additions & 2 deletions lit_nlp/components/umap.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@
"""Implementation of UMAP as a dimensionality reduction model."""

from absl import logging
from lit_nlp.components import projection
from lit_nlp.api import model
from lit_nlp.lib import utils
import numpy as np
import umap


class UmapModel(projection.ProjectorModel):
class UmapModel(model.ProjectorModel):
"""LIT model API implementation for UMAP."""

def __init__(self, **umap_kw):
Expand Down
10 changes: 8 additions & 2 deletions lit_nlp/lib/caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,9 +227,15 @@ def key_fn(self, d) -> CacheKey:
##
# For internal use
def fit_transform(self, inputs: Iterable[types.JsonDict]):
"""For use with UMAP and other preprocessing transforms."""
"""Cache projections from ProjectorModel dimensionality reducers."""
wrapped = self.wrapped
if not isinstance(wrapped, lit_model.ProjectorModel):
raise TypeError(
"Attempted to call fit_transform() on a non-ProjectorModel."
)

inputs_as_list = list(inputs)
outputs = list(self.wrapped.fit_transform(inputs_as_list))
outputs = list(wrapped.fit_transform(inputs_as_list))
with self._cache.lock:
for inp, output in zip(inputs_as_list, outputs):
self._cache.put(output, self.key_fn(inp))
Expand Down

0 comments on commit e30e59a

Please sign in to comment.