In [1]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before each cell runs so edits to local packages are picked up live.
%load_ext autoreload
%autoreload 2

In [2]:
import json
import pathlib

import jscatter
import numpy as np
import pandas as pd

import embcomp as ec

# Location of the dataset (relative to the notebook's working directory) and
# the sample whose two UMAP embeddings are compared below.
data_dir = pathlib.Path.cwd().joinpath("..", "data", "mair-2022-ismb")
name = "TISSUE_138_samples_FM96_OM138_035_CD45_live_fcs_110595"

# Two parquet tables for the same sample: the plain and the annotated UMAP.
raw = pd.read_parquet(data_dir / f"{name}_umap.parquet")
annotated = pd.read_parquet(data_dir / f"{name}_umap_annotated.parquet")

# Neighbour lookup on the 2-D embedding coordinates of each table.
# NOTE(review): presumably returns, per row, the indices of its k=100 nearest
# neighbours - confirm against `embcomp.metrics.kneighbors`.
raw_knn_indices, ann_knn_indices = (
    ec.metrics.kneighbors(frame[["x", "y"]], k=100) for frame in (raw, annotated)
)

# Per-row labels taken from the raw table's `cellType` column.
labels = raw["cellType"]

In [3]:
import re
import uuid
from typing import Sequence, Union

import jinja2

# Jinja2 template that renders every (marker, score) pair of an `annotation`
# mapping as a <span> with a red background; the text itself is transparent and
# drawn only through a black text-shadow whose blur radius is `score` pixels.
# NOTE(review): nothing in the visible cells references this template (the
# chart-based HTML_TEMPLATE is what Annotation renders) - confirm before removing.
_HTML_TEMPLATE = jinja2.Template(
    """
<!doctype html>
<html>
    {% for (marker, score) in annotation.items() %}<span style="color: transparent;text-shadow: 0 0 {{ score }}px #000;background-color:red">{{ marker }}</span>{% endfor %}
</html>
"""
)

# JavaScript source (kept as a Python string) for a D3 diverging bar chart,
# adapted from the Observable example credited in the header below. The string
# is spliced into HTML_TEMPLATE, so the code expects `d3` to be in scope where
# it runs. `DivergingBarChart(data, options)` returns an SVG DOM node.
#
# Local change from the upstream snippet: the x-axis label used
# `text-anchor: "center"`, which is not a valid SVG value
# (start | middle | end); it is now "middle".
DivergingBarChart = """
// Copyright 2021 Observable, Inc.
// Released under the ISC license.
// https://observablehq.com/@d3/diverging-bar-chart
function DivergingBarChart(data, {
  x = d => d, // given d in data, returns the (quantitative) x-value
  y = (d, i) => i, // given d in data, returns the (ordinal) y-value
  title, // given d in data, returns the title text
  marginTop = 30, // top margin, in pixels
  marginRight = 40, // right margin, in pixels
  marginBottom = 10, // bottom margin, in pixels
  marginLeft = 40, // left margin, in pixels
  width = 640, // outer width of chart, in pixels
  height, // the outer height of the chart, in pixels
  xType = d3.scaleLinear, // type of x-scale
  xDomain, // [xmin, xmax]
  xRange = [marginLeft, width - marginRight], // [left, right]
  xFormat, // a format specifier string for the x-axis
  xLabel, // a label for the x-axis
  yPadding = 0.1, // amount of y-range to reserve to separate bars
  yDomain, // an array of (ordinal) y-values
  yRange, // [top, bottom]
  colors = d3.schemePiYG[3] // [negative, …, positive] colors
} = {}) {
  // Compute values.
  const X = d3.map(data, x);
  const Y = d3.map(data, y);

  // Compute default domains, and unique the y-domain.
  if (xDomain === undefined) xDomain = d3.extent(X);
  if (yDomain === undefined) yDomain = Y;
  yDomain = new d3.InternSet(yDomain);

  // Omit any data not present in the y-domain.
  // Lookup the x-value for a given y-value.
  const I = d3.range(X.length).filter(i => yDomain.has(Y[i]));
  const YX = d3.rollup(I, ([i]) => X[i], i => Y[i]);

  // Compute the default height.
  if (height === undefined) height = Math.ceil((yDomain.size + yPadding) * 25) + marginTop + marginBottom;
  if (yRange === undefined) yRange = [marginTop, height - marginBottom];

  // Construct scales, axes, and formats.
  const xScale = xType(xDomain, xRange);
  const yScale = d3.scaleBand(yDomain, yRange).padding(yPadding);
  const xAxis = d3.axisTop(xScale).ticks(width / 80, xFormat);
  const yAxis = d3.axisLeft(yScale).tickSize(0).tickPadding(6);
  const format = xScale.tickFormat(100, xFormat);

  // Compute titles.
  if (title === undefined) {
    title = i => `${Y[i]}\n${format(X[i])}`;
  } else if (title !== null) {
    const O = d3.map(data, d => d);
    const T = title;
    title = i => T(O[i], i, data);
  }

  const svg = d3.create("svg")
      .attr("width", width)
      .attr("height", height)
      .attr("viewBox", [0, 0, width, height])
      .attr("style", "max-width: 100%; height: auto; height: intrinsic;");

  svg.append("g")
      .attr("transform", `translate(0,${marginTop})`)
      .call(xAxis)
      .call(g => g.select(".domain").remove())
      .call(g => g.selectAll(".tick line").clone()
          .attr("y2", height - marginTop - marginBottom)
          .attr("stroke-opacity", 0.1))
      .call(g => g.append("text")
          .attr("x", xScale(0))
          .attr("y", -22)
          .attr("fill", "currentColor")
          .attr("text-anchor", "middle")
          .text(xLabel));

  const bar = svg.append("g")
    .selectAll("rect")
    .data(I)
    .join("rect")
      .attr("fill", i => colors[X[i] > 0 ? colors.length - 1 : 0])
      .attr("x", i => Math.min(xScale(0), xScale(X[i])))
      .attr("y", i => yScale(Y[i]))
      .attr("width", i => Math.abs(xScale(X[i]) - xScale(0)))
      .attr("height", yScale.bandwidth());

  if (title) bar.append("title")
      .text(title);

  svg.append("g")
      .attr("text-anchor", "end")
      .attr("font-family", "sans-serif")
      .attr("font-size", 10)
    .selectAll("text")
    .data(I)
    .join("text")
      .attr("text-anchor", i => X[i] < 0 ? "end" : "start")
      .attr("x", i => xScale(X[i]) + Math.sign(X[i] - 0) * 4)
      .attr("y", i => yScale(Y[i]) + yScale.bandwidth() / 2)
      .attr("dy", "0.35em")
      .text(i => format(X[i]));

  svg.append("g")
      .attr("transform", `translate(${xScale(0)},0)`)
      .call(yAxis)
      .call(g => g.selectAll(".tick text")
        .filter(y => YX.get(y) < 0)
          .attr("text-anchor", "start")
          .attr("x", 6));
          
  return svg.node();
}
"""

# Jinja2 template for the HTML representation of a marker chart. It is built by
# concatenating three pieces: a target <div> plus the d3 module import, the
# DivergingBarChart JavaScript source, and a driver script.
#
# Template variables:
#   id      - used both as the <div> id and in the getElementById lookup that
#             appends the finished SVG to that <div>.
#   markers - substituted into `let markers = ...`, so its rendered text must be
#             a valid JavaScript expression: an array of objects with `name`
#             (used for the y axis) and `score` (used for the x axis).
HTML_TEMPLATE = jinja2.Template(
    """
<div id="{{ id }}"></div>
<script type="module">
import * as d3 from "https://esm.sh/d3";
"""
    + DivergingBarChart
    + """
let markers = {{ markers }}
let width = 500;

let svg = DivergingBarChart(markers, {
  x: d => d.score,
  y: d => d.name,
  yDomain: markers.map(d => d.name),
  // xFormat: ".2f",
  xLabel: "- +",
  width,
  marginRight: 70,
  marginLeft: 70,
  colors: d3.schemeRdBu[3]
});

document.getElementById("{{ id }}").appendChild(svg);
</script>
"""
)


class Annotation(dict[str, float]):
    """Mapping from marker name to a signed score.

    An annotation is parsed from a label string made of ``<marker><sign>``
    tokens, e.g. ``"CD4+CD8-"``; a ``+`` contributes a positive score and a
    ``-`` a negative one. In a notebook the object renders as a diverging bar
    chart through ``_repr_mimebundle_``.
    """

    @classmethod
    def from_str(cls, s: str, n: int = 1):
        """Parse a label string into an annotation.

        Args:
            s: Label made of ``<marker>+`` / ``<marker>-`` tokens, where a
                marker is a run of word characters.
            n: Magnitude given to every marker: ``+n`` for ``+`` tokens and
                ``-n`` for ``-`` tokens.

        Returns:
            A new instance keyed by marker name (sign stripped). Text that is
            not a ``<marker><sign>`` token (for example ``"0_0_0"``) is
            skipped. If a marker occurs more than once, the last occurrence
            wins.
        """
        # Raw string: "\w" / "\+" in a normal literal are invalid escapes. The
        # sign class is exactly "+" or "-"; a "|" inside [...] would be a
        # literal pipe, not alternation.
        return cls(
            (name, -n if sign == "-" else n)
            for name, sign in re.findall(r"(\w+)([+-])", s)
        )

    def __str__(self):
        # NOTE(review): keys carry no sign, so this is the marker names
        # concatenated without "+"/"-" and does not round-trip via from_str.
        return "".join(self.keys())

    def __repr__(self):
        return str(self)

    def _repr_mimebundle_(self, include=None, exclude=None):
        """Return the rich (HTML) representation used by Jupyter.

        The markers are serialised with ``json.dumps`` so the text placed in
        the template's script is a JavaScript literal regardless of the
        scores' Python type.
        """
        markers = [dict(name=k, score=float(v)) for k, v in self.items()]
        return {
            "text/html": HTML_TEMPLATE.render(
                id=uuid.uuid4().hex,
                markers=json.dumps(markers),
            )
        }

    def add(self, label, n: int = 1):
        """Accumulate the scores parsed from ``label`` into this annotation.

        Args:
            label: Label string in the format accepted by ``from_str``.
            n: Magnitude added per marker (sign taken from the label).

        Returns:
            ``self``, to allow chaining.
        """
        for name, value in self.from_str(label, n).items():
            # Markers not seen before start from zero instead of raising
            # KeyError.
            self[name] = self.get(name, 0) + value
        return self


# Worked example for a single label.
label = "CD4+CD8-CD3+CD45RA-CD27+CD19-CD103-CD28+CD69-PD1+HLADR-GranzymeB-CD25-ICOS-TCRgd-CD38-CD127-Tim3-"

# Split the label on its +/- signs to recover the marker names, then map each
# to its "<marker>_Windsorized" column. The pattern is a raw string because
# "\-" / "\+" in a normal literal are invalid escape sequences.
columns = [f"{marker}_Windsorized" for marker in re.split(r"[+-]", label) if marker]

# Mean of those columns over the rows whose label equals `label`.
# NOTE(review): `scores` is not used by any later visible cell.
scores = annotated[columns][labels == label].mean()

anno = Annotation.from_str(label)
anno

In [4]:
from dataclasses import dataclass 
import functools
import numpy.typing as npt

@dataclass
class NeighborSet:
    """A label array together with the positions of one set of neighbours."""

    # Label values; `indices` index into this array.
    labels: npt.NDArray
    # Integer positions into `labels` selecting the neighbours (any shape).
    indices: npt.NDArray

    def counts(self):
        """Count how often each distinct label occurs among the neighbours.

        Returns:
            A 1-D integer array with one entry per distinct value in
            ``labels``, ordered like ``np.unique(labels)`` (sorted). Entry
            ``j`` is the number of positions in ``indices`` whose label is the
            ``j``-th distinct value; labels with no neighbour get 0.
        """
        flat_labels = np.asarray(self.labels).ravel()
        categories, codes = np.unique(flat_labels, return_inverse=True)
        # Explicit integer dtype so an empty `indices` still indexes cleanly.
        positions = np.asarray(self.indices, dtype=np.intp).ravel()
        return np.bincount(codes.ravel()[positions], minlength=len(categories))


@dataclass(frozen=True)
class KNNGrouper:
    """Group k-nearest-neighbour indices by the category of the query point.

    ``knn_indices`` holds one row of neighbour positions per point and
    ``labels`` the categorical label of each point, in the same order. Both
    derived views are computed on first access and then cached.
    """

    knn_indices: npt.NDArray
    labels: pd.Series  # categorical series

    def __post_init__(self):
        # One label per row of neighbours, and the labels must be categorical
        # because `groups` / `counts` rely on `.cat.categories` and `.cat.codes`.
        assert len(self.labels) == len(self.knn_indices)
        assert isinstance(self.labels.dtype, pd.CategoricalDtype)

    @functools.cached_property
    def groups(self):
        """Map each category to the sorted, de-duplicated neighbour positions
        of all points carrying that category."""
        return {
            label: np.unique(self.knn_indices[(self.labels == label).to_numpy()])
            for label in self.labels.cat.categories
        }

    @functools.cached_property
    def counts(self):
        """Category-by-category table of neighbour counts.

        Row ``a``, column ``b`` is the number of distinct neighbours of
        category-``a`` points that are themselves labelled ``b``. Neighbours
        whose label is missing (code ``-1``) are not counted.
        """
        categories = self.labels.cat.categories
        # Neighbour indices are positions, so look the codes up in a plain
        # array; indexing the codes Series with `[]` would go through the
        # Series index (label-based) and break on a non-default index.
        codes = self.labels.cat.codes.to_numpy()
        out = np.zeros((len(categories), len(categories)))
        for i, knn_ind in enumerate(self.groups.values()):
            neighbor_codes = codes[knn_ind]
            # Code -1 marks a missing label; used as a column index it would
            # silently land in the last category's column.
            values, counts = np.unique(
                neighbor_codes[neighbor_codes >= 0], return_counts=True
            )
            out[i, values] = counts
        return pd.DataFrame(out, index=categories, columns=categories)

In [5]:
# Number of rows per `complete_faust_label` in a random sample of 10,000 rows
# of `raw`, most frequent label first.
# `sample()` already returns the sampled rows; feeding its index back through
# `.iloc` would treat index labels as positions, which is only correct for a
# default RangeIndex. A fixed `random_state` keeps the result reproducible
# across kernel restarts.
counts = (
    raw.sample(10000, random_state=0)
    .groupby("complete_faust_label")
    .size()
    .sort_values(ascending=False)
)

In [6]:
# Walk the labels that actually occur in the sample as (label, count) pairs.
it = iter(counts.loc[counts > 0].items())

# Seed the aggregate annotation from the first pair, weighted by its count ...
label, count = next(it)
anno = Annotation.from_str(label, count)

# ... then accumulate every remaining label, each weighted by its own count.
for label, count in it:
    anno.add(label, count)

anno