This notebook requires a pair of untransformed and annotation-based transformed embeddings. The Snakemake workflow in this repository contains rules for generating these embedding pairs with UMAP for FAUST annotations (`{name}_embedding_output.parquet`). An example pair can be generated via:

```bash
snakemake --cores all \
    data/mair-2022-ismb/TISSUE_138_samples_FM96_OM138_035_CD45_live_fcs_110595_umap.parquet \
    data/mair-2022-ismb/TISSUE_138_samples_FM96_OM138_035_CD45_live_fcs_110595_umap_annotated.parquet
```

which downloads [sample annotations](https://figshare.com/articles/dataset/ISMB_BioVis_2022_Data/20301639) from figshare and executes the embedding rules automatically.


In [None]:
import pathlib

import pandas as pd
import numpy as np
import jscatter

# Directory holding the precomputed embeddings, relative to the notebook's
# working directory.
data_dir = pathlib.Path.cwd() / ".." / "data"

# Both frames are read through `data_dir` so the location is defined once.
# NOTE(review): the file names look like two differently seeded UMAP runs
# rather than an untransformed/annotated pair -- confirm these are the
# intended inputs.
raw = pd.read_parquet(data_dir / "seed42_umap.pq")
annotated = pd.read_parquet(data_dir / "seed123_umap.pq")

In [None]:
import colors 

# Palette: `colors.gray_dark` first, followed by `colors.glasbey_light`
# repeated three times (presumably so there are enough entries for every
# category -- confirm against the data).
color_map = [
    colors.gray_dark,
    *colors.glasbey_light,
    *colors.glasbey_light,
    *colors.glasbey_light,
]

# Keyword arguments unpacked into each `jscatter.Scatter` created in this
# notebook.
view_config = {
    "x": "x",
    "y": "y",
    "color_by": "cellType",
    "color_map": color_map,
    "background_color": "black",
    "axes": False,
    "opacity_unselected": 0.05,
}

# Layout/synchronisation options; not referenced by the code visible here.
compose_config = {
    "sync_selection": True,
    "sync_hover": True,
    "row_height": 640,
}


In [None]:
import ipywidgets
from sklearn.neighbors import NearestNeighbors

import numpy.typing as npt

def kneighbors(X: npt.ArrayLike, k: int) -> npt.NDArray:
    """Return the indices of the ``k`` nearest neighbors of every row of ``X``.

    Each point is excluded from its own neighbor list.

    Args:
        X: Samples to index and query, one row per point.
        k: Number of neighbors to return per point.

    Returns:
        Integer index array with one row per point and ``k`` columns.
    """
    if k == 0:
        # scikit-learn rejects a zero-neighbor query; keep returning an empty
        # (n_samples, 0) index array for this degenerate request.
        return np.empty((np.shape(X)[0], 0), dtype=np.intp)

    nbrs = NearestNeighbors(n_neighbors=k).fit(X)
    # Querying without an explicit X makes scikit-learn leave each point out
    # of its own neighbor list. Dropping the first column of a self-query is
    # not equivalent when X contains duplicate rows, because a duplicate may
    # be ranked ahead of the point itself.
    return nbrs.kneighbors(return_distance=False)

def link_neighbors(
    s1: jscatter.Scatter,
    s2: jscatter.Scatter,
    neighbors: npt.NDArray,
):
    """Show two scatter plots side by side, linked through neighbor lists.

    Selecting exactly one point in ``s1`` replaces the selection in both plots
    with the first ``k`` entries of that point's row in ``neighbors``, where
    ``k`` is the current slider value. Any other selection size in ``s1``
    clears the selection in ``s2``.

    Args:
        s1: Scatter whose selection is observed.
        s2: Scatter that receives the same neighbor selection.
        neighbors: Index array with one row of neighbor indices per point; its
            second dimension sets the slider maximum.

    Returns:
        An ``ipywidgets.VBox`` with the slider above the two plots.
    """
    # The slider starts at 0, so a click selects an empty neighbor list until
    # the slider is moved. Moving the slider alone does not refresh an
    # existing selection; only a new selection in ``s1`` does.
    slider = ipywidgets.IntSlider(
        value=0,
        min=0,
        max=neighbors.shape[1],
        description='k:',
    )

    # True while the handler itself is writing to ``s1.widget.selection``.
    updating = False

    def selection_handler(change):
        nonlocal updating
        if updating:
            # Ignore the notification caused by our own assignment below;
            # otherwise a one-element neighbor list would be treated as a new
            # user selection and the handler would keep re-triggering itself.
            return

        selected = change['new']
        if selected is None or len(selected) != 1:
            s2.widget.selection = []
            return

        # ``selected`` is a one-element sequence: take its only element
        # rather than converting the whole container to an int.
        idx = int(selected[0])
        linked = neighbors[idx][:slider.value]

        updating = True
        try:
            s1.widget.selection = linked
        finally:
            updating = False
        s2.widget.selection = linked

    s1.widget.observe(selection_handler, names="selection")

    # Two equally wide columns, one per scatter plot.
    scatters = ipywidgets.GridBox(
        children=[s1.show(), s2.show()],
        layout=ipywidgets.Layout(
            grid_template_columns='1fr 1fr',
            grid_gap='2px',
        ),
    )

    return ipywidgets.VBox([slider, scatters])

# One scatter per embedding; both use the same view settings.
plot_raw_embed_umap = jscatter.Scatter(data=raw, **view_config)
plot_ann_embed_umap = jscatter.Scatter(data=annotated, **view_config)

# Neighbors are computed in the annotated embedding, whose plot is passed
# first. Presumably `raw` and `annotated` share row order so that an index
# refers to the same point in both plots -- TODO confirm.
# Kept as the last expression so the notebook renders the returned widget.
link_neighbors(
    plot_ann_embed_umap,
    plot_raw_embed_umap,
    kneighbors(annotated[["x", "y"]], k=50),
)