In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import torch
import numpy as np
import embedders
import anndata

In [4]:
# Load the scRNA embeddings

data = torch.tensor(np.load("/teamspace/studios/this_studio/embedders/data/blood_cell_scrna/embeddings_s2_e2_h2_3.npy"))
data.shape
idx = np.random.choice(data.shape[0], 10_000, replace=False)
data = data[idx]  # Take it easy

# Also, let's add that dummy dimension for E2
# data = torch.hstack([data[:, :3], torch.ones(data.shape[0], 1), data[:, 3:]])
data[0]

tensor([-0.9965,  0.0600,  0.0579, -0.0310, -0.0046,  1.0001,  0.0134,  0.0055,
         1.0000,  0.0060,  0.0057,  1.0000, -0.0021, -0.0024])

In [5]:
classes = torch.tensor(
    [
        int(x)
        for x in anndata.read_h5ad("/teamspace/studios/this_studio/embedders/data/blood_cell_scrna/adata.h5ad").obs[
            "cell_type"
        ]
    ]
)
classes = classes[idx]  # Take it easy
classes.shape

torch.Size([10000])

In [17]:
# Device management
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
data = data.to(device)
classes = classes.to(device)

In [18]:
# Initialize the appropriate product manifold, which we'll use for indexing

pm = embedders.manifolds.ProductManifold(signature=[(1, 2), (0, 2), (-1, 2), (-1, 2), (-1, 2)], device=device)

In [25]:
# First, we can compute the angles of all 2-d projections

angle_vals = torch.zeros(data.shape[0], pm.dim, device=device)

for i, M in enumerate(pm.P):
    dims = pm.man2dim[i]
    dims_target = pm.man2intrinsic[i]
    if M.type in ["H", "S"]:
        angle_vals[:, dims_target] = torch.atan2(data[:, dims[0]].view(-1, 1), data[:, dims[1:]])
    elif M.type == "E":
        angle_vals[:, dims_target] = torch.atan2(torch.tensor(1), data[:, dims])

angle_vals.shape  # Note that we have gone from (1000, 14) to (1000, 10), the number of intrinsic dimensions

torch.Size([10000, 10])

In [27]:
def circular_greater(angles, threshold):
    """
    Check if angles are in the half-circle (threshold, threshold + pi)
    """
    return ((angles - threshold + torch.pi) % (2 * torch.pi)) - torch.pi > 0


def calculate_info_gain(values, labels):
    batch_size, n_dim = values.shape
    n_classes = labels.max().item() + 1

    # Calculate total Gini impurity without bincount
    label_one_hot = torch.nn.functional.one_hot(labels, n_classes).float()
    class_probs = label_one_hot.mean(dim=0)
    total_gini = 1 - (class_probs**2).sum()

    # Initialize arrays for left and right counts
    left_counts = torch.zeros((batch_size, n_dim, n_classes), device=values.device)
    right_counts = torch.zeros((batch_size, n_dim, n_classes), device=values.device)

    # Calculate left and right counts for each potential split
    for i in range(batch_size):
        mask = circular_greater(values, values[i].unsqueeze(0))
        for j in range(n_dim):
            left_mask = ~mask[:, j]
            right_mask = mask[:, j]
            left_counts[i, j] = label_one_hot[left_mask].sum(dim=0)
            right_counts[i, j] = label_one_hot[right_mask].sum(dim=0)

    # Calculate Gini impurities for left and right partitions
    left_total = left_counts.sum(dim=-1, keepdim=True).clamp(min=1)
    right_total = right_counts.sum(dim=-1, keepdim=True).clamp(min=1)
    left_gini = 1 - ((left_counts / left_total) ** 2).sum(dim=-1)
    right_gini = 1 - ((right_counts / right_total) ** 2).sum(dim=-1)

    # Calculate weighted Gini impurity
    left_weight = left_total.squeeze(-1) / batch_size
    right_weight = right_total.squeeze(-1) / batch_size
    weighted_gini = left_weight * left_gini + right_weight * right_gini

    # Calculate information gain
    info_gain = total_gini - weighted_gini

    return info_gain


ig = calculate_info_gain(angle_vals, classes)

# What's the index?
best_idx = torch.argmax(ig)

In [90]:
from hyperdt.torch.product_space_DT import ProductSpaceDT
from hyperdt.torch.tree import DecisionNode
from hyperdt.torch.hyperbolic_trig import _hyperbolic_midpoint


class TorchProductSpaceDT(ProductSpaceDT):
    def __init__(self, signature):
        sig_r = [(x[1], x[0]) for x in signature]
        super().__init__(signature=sig_r)
        self.pm = embedders.manifolds.ProductManifold(signature=signature)

    def _get_angle_vals(self, X):
        angle_vals = torch.zeros((X.shape[0], self.pm.dim), device=X.device)

        for i, M in enumerate(self.pm.P):
            dims = self.pm.man2dim[i]
            dims_target = self.pm.man2intrinsic[i]
            if M.type in ["H", "S"]:
                angle_vals[:, dims_target] = torch.atan2(X[:, dims[0]].view(-1, 1), X[:, dims[1:]])
            elif M.type == "E":
                angle_vals[:, dims_target] = torch.atan2(torch.tensor(1), X[:, dims])

        return angle_vals

    def fit(self, X, y):
        """Fit a decision tree to the data. Modified from HyperbolicDecisionTreeClassifier
        to remove multiple timelike dimensions in product space."""
        # Find all dimensions in product space (including timelike dimensions)
        self.all_dims = list(range(sum([space[0] + 1 for space in self.signature])))

        # Find indices of timelike dimensions in product space
        self.timelike_dims = [0]
        for i in range(len(self.signature) - 1):
            self.timelike_dims.append(sum([space[0] + 1 for space in self.signature[: i + 1]]))

        # Remove timelike dimensions from list of dimensions
        # self.dims_ex_time = list(np.delete(np.array(self.all_dims), self.timelike_dims))
        self.dims_ex_time = [dim for dim in self.all_dims if dim not in self.timelike_dims]

        # Get array of classes
        self.classes_ = torch.unique(y)

        # First, we can compute the angles of all 2-d projections
        angle_vals = self._get_angle_vals(X)
        self.tree = self._fit_node(X=angle_vals, y=y, depth=0)

    def _fit_node(self, X, y, depth):
        print(f"Depth {depth} with {X.shape} samples")
        # Base case
        if depth == self.max_depth or len(X) < self.min_samples_split or len(torch.unique(y)) == 1:
            value, probs = self._leaf_values(y)
            return DecisionNode(value=value, probs=probs)

        # Recursively find the best split:
        ig = calculate_info_gain(X, y)
        best_idx = torch.argmax(ig)
        best_row, best_dim = best_idx // X.shape[1], best_idx % X.shape[1]
        best_ig = ig[best_row, best_dim]

        # Since we're evaluating greater than, we need to also find the next-largest value and take the midpoint
        next_largest = torch.max(X[~circular_greater(X[:, best_dim], X[best_row, best_dim]), best_dim])

        # Midpoint computation will depend on manifold; TODO: actually do this
        # best_theta = (X[best_row, best_dim] + next_largest) / 2
        best_manifold = self.pm.P[self.pm.intrinsic2man[best_dim.item()]]
        if best_manifold.type == "H":
            best_theta = _hyperbolic_midpoint(X[best_row, best_dim], next_largest)
        elif best_manifold.type == "S":
            best_theta = (X[best_row, best_dim] + next_largest) / 2
        else:
            best_theta = torch.arctan2(torch.tensor([2.0], device=X.device), X[best_row, best_dim] + next_largest)

        # Fallback case:
        if best_ig <= 0:
            print(f"Fallback triggered at depth {depth}")
            value, probs = self._leaf_values(y)
            return DecisionNode(value=value, probs=probs)

        # Populate:
        node = DecisionNode(feature=best_dim, theta=best_theta)
        node.score = best_ig
        left, right = circular_greater(X[:, best_dim], best_theta), ~circular_greater(X[:, best_dim], best_theta)
        node.left = self._fit_node(X=X[left], y=y[left], depth=depth + 1)
        node.right = self._fit_node(X=X[right], y=y[right], depth=depth + 1)
        return node

    def predict(self, X):
        angle_vals = self._get_angle_vals(X)
        return torch.tensor([self._traverse(x).value for x in angle_vals], device=X.device)

    def _left(self, x, node):
        """Boolean: Go left?"""
        return circular_greater(x[node.feature], node.theta)

In [95]:
# Let's test it out

tpsdt = TorchProductSpaceDT(signature=[(1, 2), (0, 2), (-1, 2), (-1, 2), (-1, 2)])
tpsdt.pm = tpsdt.pm.to(device)
tpsdt.fit(data, classes)

Depth 0 with 10000 samples
X shape, torch.Size([10000, 10]), y shape torch.Size([10000])
Depth 1 with 7838 samples
X shape, torch.Size([7838, 10]), y shape torch.Size([7838])
Depth 2 with 5825 samples
X shape, torch.Size([5825, 10]), y shape torch.Size([5825])
Depth 3 with 1605 samples
X shape, torch.Size([1605, 10]), y shape torch.Size([1605])
Depth 3 with 4220 samples
X shape, torch.Size([4220, 10]), y shape torch.Size([4220])
Depth 2 with 2013 samples
X shape, torch.Size([2013, 10]), y shape torch.Size([2013])
Depth 3 with 2013 samples
X shape, torch.Size([2013, 10]), y shape torch.Size([2013])
Depth 3 with 0 samples
X shape, torch.Size([0, 10]), y shape torch.Size([0])
Depth 1 with 2162 samples
X shape, torch.Size([2162, 10]), y shape torch.Size([2162])
Depth 2 with 2162 samples
X shape, torch.Size([2162, 10]), y shape torch.Size([2162])
Depth 3 with 2162 samples
X shape, torch.Size([2162, 10]), y shape torch.Size([2162])
Depth 3 with 0 samples
X shape, torch.Size([0, 10]), y shape

In [96]:
tpsdt.score(data, classes).sum() / data.shape[0]

  return torch.tensor(self.predict(X) == y)


tensor(0.2224, device='cuda:0')

In [97]:
tpsdt.predict(data).unique()

tensor([1, 4, 7, 9], device='cuda:0')

In [109]:
# Confirm this is the same thing we get with the non-torch version
# Hmm, not exactly. I wonder why...

from hyperdt.product_space_DT import ProductSpaceDT

tpsdt = TorchProductSpaceDT(signature=[(1, 2), (0, 2), (-1, 2), (-1, 2), (-1, 2)])
tpsdt.fit(data[:500], classes[:500])
y_torch = tpsdt.predict(data[500:1000])

psdt = ProductSpaceDT(signature=[(2, 1), (2, 0), (2, -1), (2, -1), (2, -1)])
data_stacked = np.hstack([data[:, :3].cpu().numpy(), torch.ones(data.shape[0], 1), data[:, 3:].cpu().numpy()])
psdt.fit(data_stacked[:500], classes[:500].cpu().numpy())
y_numpy = psdt.predict(data_stacked[500:1000])

print((y_torch.cpu().numpy() == y_numpy).sum() / y_torch.shape[0])
print((y_torch.cpu().numpy() == classes[500:1000].cpu().numpy()).sum() / y_torch.shape[0])
print((y_numpy == classes[500:1000].cpu().numpy()).sum() / y_torch.shape[0])

Depth 0 with 500 samples
X shape, torch.Size([500, 10]), y shape torch.Size([500])
Depth 1 with 376 samples
X shape, torch.Size([376, 10]), y shape torch.Size([376])
Depth 2 with 269 samples
X shape, torch.Size([269, 10]), y shape torch.Size([269])
Depth 3 with 122 samples
X shape, torch.Size([122, 10]), y shape torch.Size([122])
Depth 3 with 147 samples
X shape, torch.Size([147, 10]), y shape torch.Size([147])
Depth 2 with 107 samples
X shape, torch.Size([107, 10]), y shape torch.Size([107])
Depth 3 with 107 samples
X shape, torch.Size([107, 10]), y shape torch.Size([107])
Depth 3 with 0 samples
X shape, torch.Size([0, 10]), y shape torch.Size([0])
Depth 1 with 124 samples
X shape, torch.Size([124, 10]), y shape torch.Size([124])
Depth 2 with 124 samples
X shape, torch.Size([124, 10]), y shape torch.Size([124])
Depth 3 with 124 samples
X shape, torch.Size([124, 10]), y shape torch.Size([124])
Depth 3 with 0 samples
X shape, torch.Size([0, 10]), y shape torch.Size([0])
Depth 2 with 0 s