Skip to content

Commit

Permalink
Notebook widget improvements
Browse files Browse the repository at this point in the history
- Possible to now pass examples directly to widget.render for ad-hoc analysis.
- Fix an issue with UI state syncing back to Python on generated/newly-added
  examples.

PiperOrigin-RevId: 607757916
  • Loading branch information
iftenney authored and LIT team committed Feb 16, 2024
1 parent 77583e7 commit cdf79eb
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 8 deletions.
18 changes: 16 additions & 2 deletions lit_nlp/lib/ui_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"""
from typing import Optional

from absl import logging
import attr
from lit_nlp.api import dataset as lit_dataset
from lit_nlp.api import types
Expand Down Expand Up @@ -64,14 +65,27 @@ def update_state(self,
self._state.dataset_name = dataset_name
self._state.dataset = dataset

# This may contain 'added' datapoints not in the base dataset.
input_index = {ex["data"]["_id"]: ex for ex in indexed_inputs}

def get_example(example_id):
ex = input_index.get(example_id)
if ex is None:
ex = dataset.index.get(example_id)
return ex

if primary_id:
self._state.primary = dataset.index[primary_id]
self._state.primary = get_example(primary_id)
if self._state.primary is None:
logging.warn("State tracker: unable to find primary_id %s", primary_id)
else:
self._state.primary = None

self._state.selection = indexed_inputs

if pinned_id:
self._state.pinned = dataset.index[pinned_id]
self._state.pinned = get_example(pinned_id)
if self._state.pinned is None:
logging.warn("State tracker: unable to find pinned_id %s", pinned_id)
else:
self._state.pinned = None
34 changes: 28 additions & 6 deletions lit_nlp/notebook.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,24 @@
through the render() method. Use the stop() method to stop the server when done.
"""

from collections.abc import Sequence
from collections.abc import Mapping, Sequence
import html
import json
import os
import pathlib
import random
from typing import cast, Optional
from typing import Any, Optional, cast
import urllib.parse

import attr
from IPython import display
from lit_nlp import dev_server
from lit_nlp import server_config
from lit_nlp.api import layout
from lit_nlp.lib import wsgi_serving

JsonDict = Mapping[str, Any]

is_colab = False
try:
import google.colab # pylint: disable=g-import-not-at-top,unused-import
Expand Down Expand Up @@ -66,6 +69,7 @@ class RenderConfig(object):
layout: Optional[str] = None
dataset: Optional[str] = None
models: Optional[Sequence[str]] = None
datapoints: Optional[Sequence[JsonDict]] = None

def get_query_str(self):
"""Convert config object to query string for LIT URL."""
Expand All @@ -75,8 +79,15 @@ def _encode(v):
return v

string_params = {
k: _encode(v) for k, v in attr.asdict(self).items() if v is not None
k: _encode(v)
for k, v in attr.asdict(self).items()
if (v is not None and k != 'datapoints')
}
if self.datapoints:
for i, ex in enumerate(self.datapoints):
for field in ex:
string_params[f'data{i}_{field}'] = _encode(ex[field])

return '?' + urllib.parse.urlencode(string_params)


Expand Down Expand Up @@ -134,21 +145,32 @@ def stop(self):
"""Stop the LIT server."""
self._server.stop()

def render(self, height=None, open_in_new_tab=False,
ui_params: Optional[RenderConfig] = None):
def render(
self,
height=None,
open_in_new_tab=False,
ui_params: Optional[RenderConfig] = None,
data: Optional[Sequence[JsonDict]] = None,
):
"""Render the LIT UI in the output cell.
To immediately analyze specifiic example(s), use the data= parameter:
widget.render(..., data=[{"prompt": "Hello world "}])
Args:
height: Optional height to display the LIT UI in pixels. If not specified,
then the height specified in the constructor is used.
then the height specified in the constructor is used.
open_in_new_tab: Whether to show the UI in a new tab instead of in the
output cell. Defaults to false.
ui_params: Optional configuration options for the LIT UI's state.
data: Optional examples to load directly to the UI (via URL params).
"""
if not height:
height = self._height
if not ui_params:
ui_params = RenderConfig()
if data:
ui_params.datapoints = data
if is_colab:
_display_colab(self._server.port, height, open_in_new_tab, ui_params)
else:
Expand Down

0 comments on commit cdf79eb

Please sign in to comment.