Update documentation that mentions predict_minibatch.
PiperOrigin-RevId: 555990930
bdu91 authored and LIT team committed Aug 11, 2023
1 parent 6fdcbfe commit ce38565
Showing 3 changed files with 28 additions and 26 deletions.
21 changes: 11 additions & 10 deletions docs/setup/index.html
@@ -175,22 +175,23 @@ <h2>Models</h2>
fields</li>
<li><code>output_spec()</code> should return a flat dict that describes the model's
predictions and any additional outputs</li>
-<li><code>predict_minibatch()</code> and/or <code>predict()</code> should take a sequence of inputs
-(satisfying <code>input_spec()</code>) and yields a parallel sequence of outputs
-matching <code>output_spec()</code>.</li>
+<li><code>predict()</code> should take a sequence of inputs (satisfying <code>input_spec()</code>)
+and yield a parallel sequence of outputs matching <code>output_spec()</code>.</li>
</ul>
<p>Implementations should subclass
<a href="https://github.com/PAIR-code/lit/tree/main/lit_nlp/api/model.py"><code>Model</code></a>. An example for
<a href="https://cims.nyu.edu/~sbowman/multinli/">MultiNLI</a> might look something like:</p>
<pre class="language-py"><code class="language-py">class NLIModel(Model):
  """Wrapper for a Natural Language Inference model."""

  NLI_LABELS = ['entailment', 'neutral', 'contradiction']

  def __init__(self, model_path, **kw):
    # Load the model into memory so we're ready for interactive use.
    self._model = _load_my_model(model_path, **kw)

  ##
  # LIT API implementations
  def predict(self, inputs: List[Input]) -> Iterable[Preds]:
    """Predict on a single minibatch of examples."""
    examples = [self._model.convert_dict_input(d) for d in inputs]  # any custom preprocessing
    return self._model.predict_examples(examples)  # returns a dict for each input

  def input_spec(self):
    """Describe the inputs to the model."""
    return {
        'premise': lit_types.TextSegment(),
        'hypothesis': lit_types.TextSegment(),
    }

  def output_spec(self):
    """Describe the model outputs."""
    return {
        # The 'parent' keyword tells LIT where to look for gold labels when computing metrics.
        'probas': lit_types.MulticlassPreds(vocab=self.NLI_LABELS, parent='label'),
    }
</code></pre>
<p>Unlike the dataset example, this model implementation is incomplete - you'll
-need to customize <code>predict()</code> (or <code>predict_minibatch()</code>) accordingly with any
-pre- or post-processing needed, such as tokenization.</p>
-<p>Note: The <code>Model</code> base class implements simple batching, aided by the
-<code>max_minibatch_size()</code> function. This is purely for convenience, since most deep
-learning models will want this behavior. But if you don't need it, you can
-simply override the <code>predict()</code> function directly and handle large inputs
-accordingly.</p>
+need to customize <code>predict()</code> with any pre- or post-processing you need, such
+as tokenization.</p>
+<p>Many deep learning models support batched prediction. To take advantage of this, LIT
+provides the <code>BatchedModel</code> class, which implements simple batching. Users of this class
+must implement the <code>predict_minibatch()</code> function, which should convert a <code>Sequence</code> of
+<code>JsonDict</code> objects into the appropriate batch representation (typically, a
+<code>Mapping</code> of strings to aligned <code>Sequences</code> or <code>Tensors</code>) before
+calling the model. Optionally, you may also override the <code>max_minibatch_size()</code>
+function, which determines the batch size.</p>
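<p>For illustration, a minimal sketch of a <code>BatchedModel</code> subclass might look like the
following. The feature names, the <code>_load_my_model()</code> helper, and the <code>predict_batch()</code>
call are hypothetical; only <code>predict_minibatch()</code> and <code>max_minibatch_size()</code> come from
the LIT API:</p>
<pre class="language-py"><code class="language-py">class BatchedNLIModel(BatchedModel):
  """Sketch of a batched wrapper around a model that supports batch prediction."""

  def __init__(self, model_path, **kw):
    self._model = _load_my_model(model_path, **kw)  # as in the NLIModel example

  # input_spec() and output_spec() would be defined as in NLIModel above.

  def max_minibatch_size(self):
    return 32  # process at most 32 examples per predict_minibatch() call

  def predict_minibatch(self, inputs):
    # Convert the sequence of JsonDicts into a column-wise batch representation.
    batch = {
        'premise': [ex['premise'] for ex in inputs],
        'hypothesis': [ex['hypothesis'] for ex in inputs],
    }
    batched_probas = self._model.predict_batch(batch)  # hypothetical batched call
    # Unbatch: one output dict per input example, matching output_spec().
    return [{'probas': p} for p in batched_probas]
</code></pre>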
<p>Note: there are a few additional methods in the model API - see
<a href="https://github.com/PAIR-code/lit/tree/main/lit_nlp/api/model.py"><code>Model</code></a> for details.</p>
<h1>Run LIT inside python notebooks</h1>
29 changes: 15 additions & 14 deletions documentation/api.md
@@ -193,9 +193,8 @@ of three methods:
fields
* `output_spec()` should return a flat dict that describes the model's
predictions and any additional outputs
-* `predict_minibatch()` and/or `predict()` should take a sequence of inputs
-  (satisfying `input_spec()`) and yields a parallel sequence of outputs
-  matching `output_spec()`.
+* `predict()` should take a sequence of inputs (satisfying `input_spec()`) and
+  yield a parallel sequence of outputs matching `output_spec()`.

Implementations should subclass
[`Model`](../lit_nlp/api/model.py). An example for
@@ -234,14 +233,16 @@ class NLIModel(Model):
```

Unlike the dataset example, this model implementation is incomplete - you'll
-need to customize `predict()` (or `predict_minibatch()`) accordingly with any
-pre- or post-processing needed, such as tokenization.
+need to customize `predict()` with any pre- or post-processing you need, such as
+tokenization.

-Note: The `Model` base class implements simple batching, aided by the
-`max_minibatch_size()` function. This is purely for convenience, since most deep
-learning models will want this behavior. But if you don't need it, you can
-simply override the `predict()` function directly and handle large inputs
-accordingly.
+Many deep learning models support batched prediction. To take advantage of this,
+LIT provides the `BatchedModel` class, which implements simple batching. Users of
+this class must implement the `predict_minibatch()` function, which should convert
+a `Sequence` of `JsonDict` objects into the appropriate batch representation
+(typically, a `Mapping` of strings to aligned `Sequences` or `Tensors`) before
+calling the model. Optionally, you may also override the `max_minibatch_size()`
+function, which determines the batch size.
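
Conceptually, the batching provided by `BatchedModel` works along these lines; this
is a simplified sketch of the idea, not the actual library code:

```python
def predict(self, inputs):
  """Chunk inputs into minibatches and delegate each one to predict_minibatch()."""
  batch_size = self.max_minibatch_size()
  outputs = []
  for start in range(0, len(inputs), batch_size):
    outputs.extend(self.predict_minibatch(inputs[start:start + batch_size]))
  return outputs
```

In other words, `predict_minibatch()` only ever needs to handle at most
`max_minibatch_size()` examples at a time.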

Note: there are a few additional methods in the model API - see
[`Model`](../lit_nlp/api/model.py) for details.
@@ -318,11 +319,11 @@ can accept pre-tokenized inputs might have the following spec:
}
```

-And in the model's `predict()` or `predict_minibatch()`, you would have logic to
-use these and bypass the tokenizer:
+And in the model's `predict()`, you would have logic to use these and bypass the
+tokenizer:

```python
-def predict_minibatch(inputs):
+def predict(inputs):
input_tokens = [ex.get('tokens') or self.tokenizer.tokenize(ex['text'])
for ex in inputs]
# ...rest of your predict logic...
@@ -818,7 +819,7 @@ Name | Description
`TextSegment` | Natural language text, untokenized. | `string`
`GeneratedText` | Untokenized text, generated from a model (such as seq2seq). | `string`
`URL` | TextSegment, but interpreted as a URL. | `string`
-`GeneratedURL` | Genrated TextSegment, but interpreted as a URL (i.e., it maye not be real/is inappropriate as a label). | `string`
+`GeneratedURL` | Generated TextSegment, but interpreted as a URL (i.e., it may not be real and may be inappropriate as a label). | `string`
`SearchQuery` | TextSegment, but interpreted as a search query. | `string`
`String` | Opaque string data; ignored by components such as perturbation methods that operate on natural language. | `string`
`ReferenceTexts` | Multiple texts, such as a set of references for summarization or MT. | `List[Tuple[string, float]]`
4 changes: 2 additions & 2 deletions documentation/components.md
@@ -28,8 +28,8 @@ every invocation.) Generally, you'll need to:

* In your model's `__init__()`, build the graph, create a persistent TF
session, and load the model weights.
-* In your `predict()` or `predict_minibatch()` function, build a feed dict and
-  call `session.run` directly.
+* In your `predict()` function, build a feed dict and call `session.run`
+  directly (see the sketch below).
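
For illustration, a graph-mode `predict()` might look roughly like the sketch
below; `_session`, `_input_ph`, and `_probas` are hypothetical attributes created
in `__init__()`:

```python
def predict(self, inputs):
  """Sketch of a graph-mode predict(): feed a batch of inputs, fetch probabilities."""
  # Map each placeholder to a batch of values drawn from the input dicts.
  feed_dict = {self._input_ph: [ex['text'] for ex in inputs]}
  # Run the persistent session created in __init__().
  probas = self._session.run(self._probas, feed_dict=feed_dict)
  # Return one output dict per input example, matching output_spec().
  return [{'probas': p} for p in probas]
```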

Alternatively, you can export to a `SavedModel` and load this in an eager mode
runtime. This leads to much simpler code (see
