In [3]:
import dataclasses
import functools
import itertools
import pathlib
from typing import Iterable, Union
import uuid

import bioframe as bf
import hg
import ipywidgets
import jscatter
import matplotlib.colors as colors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import traitlets
from IPython.display import display

# Input files (parquet tables, BED clusters) are looked up in ./data relative
# to the kernel's current working directory.
data_dir = pathlib.Path.cwd().joinpath("data")

## Linked embeddings of single-locus interaction profiles

In [4]:
def load_eigs(stages: tuple[str, ...]) -> list[pd.DataFrame]:
    """Read the projected eigenvector table of each stage on a common index.

    One parquet file per entry of ``stages`` is read from ``data_dir`` and rows
    containing missing values are dropped. The returned list has one frame per
    stage (same order as ``stages``), each restricted to the index labels that
    are present in every frame.
    """
    frames = [
        pd.read_parquet(
            data_dir
            / f"eigvecs.proj__ESC-Trypsin-FA-DSG-HindIII.{name}-Trypsin-FA-DSG-HindIII.E1-E10.hg38.50000.pq"
        ).dropna()
        for name in stages
    ]

    # Fold the per-frame indexes down to the labels shared by all stages.
    shared = functools.reduce(pd.Index.intersection, (frame.index for frame in frames))
    return [frame.loc[shared] for frame in frames]

def load_track_metadata() -> tuple[pd.DataFrame, dict[str, dict]]:
    """Load the per-bin track table joined with cluster labels, plus color presets.

    Returns
    -------
    tuple
        ``(tracks, color_config)``.

        ``tracks`` is the track table read from ``data_dir`` joined (on its
        index) with every ``kmeans_sm<N>`` column of the cluster table, the
        cluster labels being cast to strings.

        ``color_config`` maps a column name to the keyword arguments used when
        coloring by that column: ``norm``/``map`` for the continuous fields and,
        for ``kmeans_sm8``, a ``relabel`` mapping from raw label to display name
        together with a ``map`` from display name to color.
    """
    tracks = pd.read_parquet(data_dir / "tracks.hg38.50000.pq")
    clusters = (
        pd.read_parquet(
            data_dir
            / "clusters.joint.proj__ESC-Trypsin-FA-DSG-HindIII.E1-E10.50000.kmeans_sm.pq"
        )
        # Raw string: "\d" inside a plain literal is an invalid escape sequence
        # and triggers a warning on current Python versions.
        .filter(regex=r"^kmeans_sm\d+$")
        # Labels are stored as strings so they match the string keys of "relabel".
        .astype(str)
    )
    color_config = {
        "GC": dict(norm=colors.Normalize(vmin=0.35, vmax=0.65), map="RdYlBu_r"),
        "centel_abs": dict(norm=colors.Normalize(vmin=0, vmax=149043529), map="Greys"),
        "kmeans_sm8": dict(
            # Raw label "8" is mapped to NaN, i.e. it gets no display name.
            relabel={ "7": "A1", "5": "A2", "6": "AB", "4": "A3", "3": "A3", "1": "B0", "0": "B0", "2": "B4", "8": np.nan },
            map={'A1': '#e23838', 'A2': '#f78200', 'AB': '#5ebd3e', 'A3': '#ffb900', 'B0': 'cornflowerblue', 'B4': '#973999'}
        )
    }
    return tracks.join(clusters), color_config

def init_dropdowns(
    x: str,
    y: str,
    color: str,
    scatters: Iterable[jscatter.Scatter],
):
    """Create x / y / color dropdowns that drive every scatter in ``scatters``.

    Parameters
    ----------
    x, y
        Initially selected eigenvector columns (one of ``E1`` .. ``E10``).
    color
        Initially selected track column used for coloring. It is moved to the
        front of the color options.
    scatters
        The scatter plots to keep in sync. Consumed once and stored, so any
        iterable (including a generator) is accepted.

    Returns
    -------
    tuple
        ``(x_dropdown, y_dropdown, c_dropdown)``.

    Raises
    ------
    ValueError
        If ``color`` is not one of the selectable track columns.
    """
    # Materialize once: the callbacks below iterate over the scatters on every
    # dropdown change, which would silently do nothing after the first pass if
    # a one-shot iterator had been passed in.
    scatters = tuple(scatters)

    tracks, color_config = load_track_metadata()
    xy_options = [f"E{i}" for i in range(1, 11)]
    color_options = [c for c in tracks.columns if c not in ["start", "end"]]
    if color not in color_options:
        raise ValueError(
            f"color={color!r} is not a selectable track column; "
            f"expected one of {color_options}"
        )
    # Put the initial color first in the list of options.
    color_options.remove(color)
    color_options.insert(0, color)

    x_dropdown = ipywidgets.Dropdown(options=xy_options, value=x, description="x:")

    def on_change_x(change):
        for scatter in scatters:
            scatter.x(change.new)

    y_dropdown = ipywidgets.Dropdown(options=xy_options, value=y, description="y:")

    def on_change_y(change):
        for scatter in scatters:
            scatter.y(change.new)

    c_dropdown = ipywidgets.Dropdown(
        options=color_options, value=color, description="color:"
    )

    def extract_color_series(
        track_data: pd.DataFrame,
        field: str,
        color_kwargs: Union[dict, None],
    ):
        """Return the values to color by and the matching color keyword arguments."""
        data = track_data[field]
        if color_kwargs is None:
            # No preset for this field: derive a color mapping from the data.
            color_kwargs = {}
            if data.dtype.name in ("object", "category"):
                data = data.astype("category") # ensure categorical
                color_kwargs["map"] = dict(zip(data.cat.categories, jscatter.glasbey_dark))
            else:
                color_kwargs["norm"] = colors.Normalize(vmin=data.min(), vmax=data.max())
                color_kwargs["map"] = "viridis_r"
            return data, color_kwargs
        if color_kwargs and "relabel" in color_kwargs:
            # Translate raw labels into the display names the preset's map expects.
            data = data.map(color_kwargs["relabel"])

        return data, color_kwargs

    def on_change_color(change):
        # Item access (not attribute access) because this is also invoked below
        # with a plain dict for the initial coloring.
        field = change["new"]
        for scatter in scatters:
            track_data = tracks.loc[scatter._data.index]
            data, kwargs = extract_color_series(track_data, field, color_config.get(field))
            scatter._data["_color"] = data
            scatter.color(by="_color", **kwargs)

    x_dropdown.observe(on_change_x, names=["value"])
    y_dropdown.observe(on_change_y, names=["value"])
    c_dropdown.observe(on_change_color, names=["value"])
    # Apply the initial coloring without waiting for a user interaction.
    on_change_color(dict(new=color))

    return x_dropdown, y_dropdown, c_dropdown

def init_scatters(
    stages: tuple[str, ...] = ("ESC", "DE", "HB", "iHEP", "mHEP"),
    x: str = "E1",
    y: str = "E2",
    color: str = "kmeans_sm8",
):
    """Build one scatter per stage, link their selections and add dropdowns.

    Returns a ``VBox`` holding the labelled scatters and the x / y / color
    dropdowns. The box carries an extra ``coords`` trait containing the
    ``chrom`` / ``start`` / ``end`` rows of the current selection of the first
    scatter.
    """
    frames = load_eigs(stages=stages)
    scatters = [
        jscatter.Scatter(x=x, y=y, data=frame, opacity=0.5) for frame in frames
    ]

    # Link the first scatter's selection to each of the other scatters.
    primary = scatters[0]
    for other in scatters[1:]:
        ipywidgets.jslink((primary.widget, "selection"), (other.widget, "selection"))

    controls = init_dropdowns(x=x, y=y, color=color, scatters=scatters)

    panels = []
    for label, scatter in zip(stages, scatters):
        panels.append(ipywidgets.VBox([ipywidgets.Label(label), scatter.show()]))

    component = ipywidgets.VBox([ipywidgets.HBox(panels), ipywidgets.HBox(controls)])

    def selected_coords(ind):
        # Positional row indices -> genomic coordinates of those rows.
        return scatters[0]._data.iloc[ind][["chrom", "start", "end"]]

    # A single "coords" trait is exposed on the component for observers;
    # it starts out as the (empty) coordinates of an empty selection.
    component.add_traits(coords=traitlets.Any(selected_coords([])))

    # One-way link: selection of the first scatter -> component.coords.
    ipywidgets.dlink(
        source=(scatters[0].widget, "selection"),
        target=(component, "coords"),
        transform=selected_coords,
    )

    return component

## Linked RNA-seq violin plot

In [5]:
def init_violinplot(scatter):
    """Create an output widget showing a violin plot for the current selection.

    Parameters
    ----------
    scatter
        Object exposing an observable ``coords`` trait whose value has an
        ``index`` usable to select rows of the track table.

    Returns
    -------
    ipywidgets.Output
        Output area that is redrawn whenever ``coords`` changes. If redrawing
        fails, the error message is printed into the same output area.
    """
    rna = pd.read_parquet(data_dir / "tracks.hg38.50000.pq").filter(regex="R1.fwd$")
    fig, ax = plt.subplots()
    output = ipywidgets.Output()

    def label_axes():
        ax.set(xlabel="stage", ylabel='log(n + 1)')

    label_axes()

    def on_selection_change(change):
        try:
            ax.clear()
            data = np.log10(rna.loc[change.new.index] + 1).melt()
            sns.violinplot(x="variable", y="value", data=data, ax=ax)
            # Re-apply the labels on every redraw: ax.clear() resets them and
            # seaborn then labels the axes "variable" / "value".
            label_axes()
            output.clear_output()
            with output:
                display(fig)
        except Exception as e:
            # Best effort: surface the problem in the widget instead of raising
            # inside the observer callback.
            with output:
                print(e)

    scatter.observe(on_selection_change, names=["coords"])
    return output

## HiGlass view

In [14]:
def _encode(obj):
    if isinstance(obj, np.integer):
        return int(obj)
    elif isinstance(obj, np.floating):
        return float(obj)
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    return obj

@dataclasses.dataclass
class DynamicTileset:
    """In-memory "bedlike" tileset whose single tile can be replaced at runtime.

    Attributes
    ----------
    name
        Display name of the tileset.
    chromsizes
        Chromosome lengths indexed by chromosome name; their order defines the
        absolute (concatenated) coordinate system.
    uid
        Identifier reported in ``info()`` and used as the tile id prefix;
        defaults to a random hex string.
    """

    name: str
    chromsizes: pd.Series
    uid: str = dataclasses.field(default_factory=lambda: uuid.uuid4().hex)
    # Unannotated on purpose: a plain class attribute, not a dataclass field.
    datatype = "bedlike"

    def __post_init__(self):
        # Absolute offset of each chromosome = total length of the chromosomes
        # before it. Computed as (cumulative sum - own size) so that an empty
        # chromsizes series yields an empty mapping instead of raising.
        starts = self.chromsizes.cumsum() - self.chromsizes
        self._starts = dict(starts)
        self._tiles = []

    def update(self, df):
        """Replace the tile contents with the clustered intervals of ``df``.

        ``df`` is expected to carry the nine columns aggregated below. An empty
        frame simply clears the tile.
        """
        if len(df) == 0:
            # Nothing to cluster; avoid handing an empty frame to bioframe.
            self._tiles = []
            return

        starts = self._starts
        # NOTE(review): bf.cluster is called with its defaults; presumably this
        # also merges directly adjacent bins that carry different names, keeping
        # only the first name/rgb of each run -- confirm that this is intended.
        merged = bf.cluster(df).groupby("cluster").agg({
            "chrom": "first",
            "start": "min",
            "end": "max",
            "name": "first",
            "score": "first",
            "strand": "first",
            "thickStart": "min",
            "thickEnd": "max",
            "rgb": "first",
        })
        # Record layout follows the aggregation above: r[0]=chrom, r[1]=start,
        # r[2]=end. Values are converted to plain Python types.
        self._tiles = [{
            "chrOffset": int(starts[r[0]]),
            "xStart": int(starts[r[0]] + r[1]),
            "xEnd": int(starts[r[0]] + r[2]),
            "importance": 0,
            "uid": uuid.uuid4().hex,
            "fields": tuple(map(_encode, r))
        } for r in merged.to_records(index=False)]

    def info(self):
        """Return the tileset description: a single zoom level spanning the genome."""
        genome_length = int(np.sum(self.chromsizes.values))
        return {
            "uuid": self.uid,
            "max_width": genome_length,
            "min_pos": [1],
            "max_pos": [genome_length],
            "max_zoom": 0,
        }

    def tiles(self, _tileids):
        """Return the one tile held by this tileset, whatever ids were requested."""
        return [(f"{self.uid}.0.0", self._tiles)]

def init_dynamic_tileset():
    """Create the cluster tileset, initially filled with every bin.

    Returns ``(tileset, bins)`` where ``bins`` is the BED9 table read from
    ``data_dir`` and ``tileset`` is a ``DynamicTileset`` already updated with it.
    """
    bed_path = data_dir / "clusters.proj__all.JOINT.50000.E1-E10.kmeans_sm8.bed"
    bins = bf.read_table(bed_path, schema="bed9", schema_is_strict=True)

    # Label-based slice of the hg38 chromosome sizes, ending at 'chrY'.
    chromsizes = bf.fetch_chromsizes('hg38')[:'chrY']

    tileset = DynamicTileset(name="IPG clusters", chromsizes=chromsizes)
    tileset.update(bins)
    return tileset, bins

def init_dynamic_track():
    """Register the dynamic tileset and return its track plus an updater.

    Returns ``(track, update_tileset)``. Calling ``update_tileset(coords)``
    restricts the tileset to the bins overlapping ``coords``; an empty
    ``coords`` restores all bins.
    """
    tileset, bins = init_dynamic_tileset()
    track = hg.server.add(tileset).track(height=30).opts(fillOpacity=1)

    def update_tileset(coords):
        if len(coords) == 0:
            # Empty selection: show every bin again.
            tileset.update(bins)
            return
        hits = bf.overlap(bins, coords, how="inner", return_index=True, return_input=False)
        tileset.update(bins.iloc[hits["index"]])

    return track, update_tileset

def init_viewer(scatter: Union[None, traitlets.HasTraits] = None):
    """Create the HiGlass widget with the dynamic cluster track on top.

    The view configuration is parsed from ``./viewconf.json`` and the dynamic
    track is appended to the top tracks of its first view. When ``scatter`` is
    given, every change of its ``coords`` trait updates the tileset and reloads
    that track in the viewer.
    """
    conf = hg.Viewconf.parse_file("./viewconf.json")
    track, update_tileset = init_dynamic_track()
    view = conf.views[0]
    view.tracks.top.append(track)

    hg_viewer = conf.widget()

    if scatter is not None:

        def refresh_track(change):
            update_tileset(change.new)
            hg_viewer.reload({"trackId": track.uid, "viewId": view.uid})

        scatter.observe(refresh_track, names=["coords"])

    return hg_viewer

## Demo

In [15]:
def init():
    """Build the demo widgets.

    Returns ``(hg_viewer, scatter)``: the HiGlass viewer and the linked scatter
    component whose ``coords`` trait the viewer observes.
    """
    linked_scatters = init_scatters()
    return init_viewer(linked_scatters), linked_scatters

# Stack the HiGlass viewer above the linked scatter component; being the last
# expression of the cell, the box is what gets displayed.
ipywidgets.VBox(children=init())

VBox(children=(HiGlassWidget(), VBox(children=(HBox(children=(VBox(children=(Label(value='ESC'), HBox(children…