In [1]:
import pandas as pd
from umap import UMAP
from sklearn.pipeline import make_pipeline 

# pip install "embetter[text]"
from embetter.text import SentenceEncoder

# Build a sentence encoder pipeline with UMAP at the end.
text_emb_pipeline = make_pipeline(
  SentenceEncoder('all-MiniLM-L6-v2'),
  UMAP()
)

# Load sentences
sentences = list(pd.read_csv("tests/data/text.csv")['text'])

# Calculate embeddings 
X_tfm = text_emb_pipeline.fit_transform(sentences)

# Write to disk. Note! Text column must be named "text"
df = pd.DataFrame({"text": sentences})
df['x'] = X_tfm[:, 0]
df['y'] = X_tfm[:, 1]

In [27]:
import jscatter
import numpy as np
import pandas as pd
from ipywidgets import HBox, VBox, HTML, Layout, Button
from IPython.display import display

class BaseTextExplorer:
    def __init__(self, dataf):
        self.dataf = dataf
        self.scatter = jscatter.Scatter(data=self.dataf, x="x", y="y", width=500, height=500)
        self.html = HTML(layout=Layout(width='600px', overflow_y='scroll', height='400px'))
        self.sample_btn = Button(description='resample')
        self.elem = HBox([scatter.show(), VBox([sample_btn, html])])
        
        self.scatter.widget.observe(lambda d: self.update(), ['selection'])
        self.sample_btn.on_click(lambda d: self.update())

    def update(self):
        texts = self.dataf.iloc[self.scatter.selection()].sample(10)["text"]
        self.html.value = ''.join([f'<p style="margin: 0px">{t}</p>' for t in texts])

    def observe(self, func):
        self.scatter.widget.observe(func, ['selection'])
        
    @property
    def selected_idx(self):
        return self.scatter.selection()

    @property
    def selected_texts(self):
        return list(self.dataf.iloc[self.selection_idx]["text"])

    @property
    def selected_dataframe(self):
        return self.dataf.iloc[self.selection_idx]

    def _repr_html_(self):
        return display(self.elem)

BaseTextExplorer(df)

HBox(children=(HBox(children=(VBox(children=(Button(button_style='primary', icon='arrows', layout=Layout(width…

<__main__.BaseTextExplorer at 0x7e42ea8ed000>