Skip to content

Commit

Permalink
Add final batch of class-based LitTypes and replace existing referenc…
Browse files Browse the repository at this point in the history
…es to these classes.

PiperOrigin-RevId: 460490040
  • Loading branch information
cjqian authored and LIT team committed Jul 12, 2022
1 parent 4b777d7 commit c020d25
Show file tree
Hide file tree
Showing 4 changed files with 343 additions and 21 deletions.
14 changes: 7 additions & 7 deletions lit_nlp/api/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,13 +289,13 @@ class CategoryLabel(String):


@attr.s(auto_attribs=True, frozen=True, kw_only=True)
class _Tensor(LitType):
class _Tensor1D(LitType):
"""A tensor type."""
default: Sequence[float] = None


@attr.s(auto_attribs=True, frozen=True, kw_only=True)
class MulticlassPreds(_Tensor):
class MulticlassPreds(_Tensor1D):
"""Multiclass predicted probabilities, as <float>[num_labels]."""
# Vocabulary is required here for decoding model output.
# Usually this will match the vocabulary in the corresponding label field.
Expand Down Expand Up @@ -371,13 +371,13 @@ class MultiSegmentAnnotations(_List):


@attr.s(auto_attribs=True, frozen=True, kw_only=True)
class Embeddings(_Tensor):
class Embeddings(_Tensor1D):
"""Embeddings or model activations, as fixed-length <float>[emb_dim]."""
pass


@attr.s(auto_attribs=True, frozen=True, kw_only=True)
class _GradientsBase(_Tensor):
class _GradientsBase(_Tensor1D):
"""Shared gradient attributes."""
align: Optional[Text] = None # name of a Tokens field
grad_for: Optional[Text] = None # name of Embeddings field
Expand All @@ -393,13 +393,13 @@ class Gradients(_GradientsBase):


@attr.s(auto_attribs=True, frozen=True, kw_only=True)
class _InfluenceEncodings(_Tensor):
class _InfluenceEncodings(_Tensor1D):
"""A single vector of <float>[enc_dim]."""
grad_target: Optional[Text] = None # class for computing gradients (string)


@attr.s(auto_attribs=True, frozen=True, kw_only=True)
class TokenEmbeddings(_Tensor):
class TokenEmbeddings(_Tensor1D):
"""Per-token embeddings, as <float>[num_tokens, emb_dim]."""
align: Optional[Text] = None # name of a Tokens field

Expand All @@ -417,7 +417,7 @@ class ImageGradients(_GradientsBase):


@attr.s(auto_attribs=True, frozen=True, kw_only=True)
class AttentionHeads(_Tensor):
class AttentionHeads(_Tensor1D):
"""One or more attention heads, as <float>[num_heads, num_tokens, num_tokens]."""
# input and output Tokens fields; for self-attention these can be the same
align_in: Text
Expand Down

0 comments on commit c020d25

Please sign in to comment.