Skip to content

Commit

Permalink
SimpleSentimentModel should inherit from BatchedModel.
Browse files Browse the repository at this point in the history
This class implements the abstract method `predict_minibatch` in `BatchedModel`.

Addressing #1361.

PiperOrigin-RevId: 586041904
  • Loading branch information
bdu91 authored and LIT team committed Nov 28, 2023
1 parent 0dcb31d commit ac8ed59
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions lit_nlp/examples/simple_pytorch_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def _from_pretrained(cls, *args, **kw):
return cls.from_pretrained(*args, from_tf=True, **kw)


class SimpleSentimentModel(lit_model.Model):
class SimpleSentimentModel(lit_model.BatchedModel):
"""Simple sentiment analysis model."""

LABELS = ["0", "1"] # negative, positive
Expand All @@ -103,7 +103,7 @@ def __init__(self, model_name_or_path):
##
# LIT API implementation
def max_minibatch_size(self):
# This tells lit_model.Model.predict() how to batch inputs to
# This tells lit_model.BatchedModel.predict() how to batch inputs to
# predict_minibatch().
# Alternately, you can just override predict() and handle batching yourself.
return 32
Expand Down
4 changes: 2 additions & 2 deletions lit_nlp/examples/sst_pytorch_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def _from_pretrained(cls, *args, **kw):
return cls.from_pretrained(*args, from_tf=True, **kw)


class SimpleSentimentModel(lit_model.Model):
class SimpleSentimentModel(lit_model.BatchedModel):
"""Simple sentiment analysis model."""

LABELS = ["0", "1"] # negative, positive
Expand All @@ -95,7 +95,7 @@ def __init__(self, model_name_or_path):
##
# LIT API implementation
def max_minibatch_size(self):
# This tells lit_model.Model.predict() how to batch inputs to
# This tells lit_model.BatchedModel.predict() how to batch inputs to
# predict_minibatch().
# Alternately, you can just override predict() and handle batching yourself.
return 32
Expand Down

0 comments on commit ac8ed59

Please sign in to comment.