# Load Basal Ganglia data

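Before running, set `path_base` in the settings cell below to a local directory where the downloaded data and cached results should be stored.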

In [None]:
from pathlib import Path
import anndata as ad
import pandas as pd
import functools as fct
import numpy as np
import numpy.typing as npt
from tqdm import tqdm
from IPython.core.magic import register_cell_magic
from IPython import get_ipython

@register_cell_magic
def skip_if(line, cell):
    """Cell magic that runs the cell body only when the condition is falsy.

    Usage: put ``%%skip_if <expression>`` on the first line of a cell.

    Args:
        line: Python expression written after ``%%skip_if``; it is evaluated
            with ``eval``.
        cell: Source of the cell body, executed through the active IPython
            shell when the expression evaluates to a falsy value.
    """
    should_skip = eval(line)
    if not should_skip:
        get_ipython().run_cell(cell)
    
def fetch_data(download_url : str, download_file_name : Path) -> None:
    """Download a file over HTTPS unless it already exists locally.

    The payload is streamed into a sibling ``<name>.part`` file and only
    renamed to ``download_file_name`` once the transfer has finished. An
    interrupted download therefore never leaves a truncated file behind that
    a later call would mistake for a complete one.

    Args:
        download_url: Source URL; must start with ``https://``.
        download_file_name: Target path. Missing parent directories are created.

    Raises:
        ValueError: If ``download_url`` is not an HTTPS URL.

    Download errors are reported via ``print`` and are not re-raised; callers
    that need the file should check that it exists afterwards.
    """
    from urllib.request import urlopen
    from urllib.error import HTTPError, URLError
    import ssl
    import shutil

    if not download_url.startswith("https://"):
        raise ValueError("Only HTTPS URLs are allowed.")

    if download_file_name.exists():
        print(f"Using existing file at: {download_file_name.resolve()}")
        return

    # Ensure parent directory exists
    download_file_name.parent.mkdir(parents=True, exist_ok=True)

    # Temporary target, so that the final name only ever refers to a complete file
    partial_file = download_file_name.with_name(download_file_name.name + ".part")

    try:
        print(f"Downloading file to: {download_file_name.resolve()}")

        context = ssl.create_default_context()

        # Stream the response body to disk
        with urlopen(download_url, context=context, timeout=10) as response, partial_file.open('wb') as download_out_file:
            shutil.copyfileobj(response, download_out_file)

        # Publish the finished download under its final name
        partial_file.replace(download_file_name)

        print(f"Downloaded file to {download_file_name.resolve()}")
    except HTTPError as e:
        print(f"HTTP error: {e.code} - {e.reason}")
    except URLError as e:
        print(f"URL error: {e.reason}")
    except Exception as e:
        print(f"Unexpected error: {e}")
    finally:
        # No-op after a successful rename; removes leftovers after a failure
        partial_file.unlink(missing_ok=True)

def load_csv(url: str, cache_path: Path = None) -> pd.DataFrame:
    """
    Load a CSV file from a URL, caching it locally.
    If the file already exists on disk, read from disk instead of downloading.
    """
    from urllib.request import urlopen

    if cache_path is not None and cache_path.exists():
        print(f"Loading cached file: {cache_path}")
        return pd.read_csv(cache_path)

    df_csv = pd.DataFrame()

    print(f"Downloading from {url}")
    with urlopen(url) as response:
        df_csv = pd.read_csv(response)

    if cache_path is not None:
        df_csv.to_csv(cache_path)

    return df_csv

def extract_rank_data(rank_name, rank_color, *argv):
    """Collect names, member indices and colors of one annotation rank.

    Args:
        rank_name: Column in ``.obs`` holding the rank label of each observation.
        rank_color: Column in ``.obs`` holding a hex color string per observation.
        *argv: Objects exposing an ``.obs`` table (here: the per-species data
            sets). Their columns are concatenated in the given order, so the
            returned indices refer to positions in the combined data.

    Returns:
        Tuple ``(rank_id_names, rank_indices, rank_colors)``: the category
        names, one 1-D index array per category, and one color per category
        as produced by ``hex_to_rgbf``, taken from the first member.
    """
    # Drop unused categories so every remaining category has at least one member
    d_ait_concat_rank     = pd.concat([df.obs[rank_name] for df in argv]).astype('category').cat.remove_unused_categories()
    d_ait_concat_rank_col = pd.concat([df.obs[rank_color] for df in argv]).astype('category')

    rank_id_names = d_ait_concat_rank.cat.categories.to_list()
    # np.where(...)[0] is already 1-D. Do not squeeze it: a single-member group
    # would collapse to a 0-d array and break the indexing below.
    rank_indices = [np.where(d_ait_concat_rank == rank_id_name)[0] for rank_id_name in rank_id_names]
    rank_colors  = [hex_to_rgbf(d_ait_concat_rank_col.iloc[members[0]]) for members in rank_indices]

    return rank_id_names, rank_indices, rank_colors

def compute_rank_means(rank_d_ait : ad.AnnData, rank_name : str, rank_var_idx: np.ndarray) -> tuple[np.ndarray, list[str], npt.NDArray[np.object_]]:
    """Compute the mean feature vector of every category of an ``.obs`` column.

    Args:
        rank_d_ait: Data set whose ``X`` matrix is loaded via ``to_memory()``.
        rank_name: Categorical ``.obs`` column that defines the groups.
        rank_var_idx: Column indices of ``X`` to keep.

    Returns:
        Tuple ``(means, rank_id_names, rank_indices)``:
        ``means`` is a float32 array with one row per category and one column
        per selected variable (rows of empty categories stay zero),
        ``rank_id_names`` lists the category names, and ``rank_indices`` is a
        1-D object array holding one 1-D row-index array per category.
    """
    import gc
    print(f"Load {rank_name} data for {rank_d_ait.filename}")
    rank_data = rank_d_ait.obs[rank_name]
    rank_id_names = rank_data.cat.categories.to_list()
    n_rows = len(rank_id_names)

    # Fill a pre-allocated object array: np.array(list_of_arrays, dtype=object)
    # would silently become 2-D if all groups happened to have equal length.
    # The index arrays are kept 1-D (no squeeze) so that single-member groups
    # still support len() and row indexing.
    rank_indices = np.empty(n_rows, dtype=object)
    for i, rank_id_name in enumerate(rank_id_names):
        rank_indices[i] = np.where(rank_data == rank_id_name)[0]

    print("Copy data to memory")
    features = rank_d_ait.X.to_memory()[:, rank_var_idx]

    n_cols = features.shape[1]
    print(f"Reserve dense matrix: {n_rows} x {n_cols}")
    means = np.zeros(shape=(n_rows, n_cols), dtype=np.float32)

    print("Compute means...")
    iterator = tqdm(enumerate(rank_indices), total=n_rows, disable=False)
    for i, rows in iterator:
        if len(rows) == 0:
            continue  # avoid division by zero
        features_sub = features[rows, :]  # presumably still sparse
        means_row = features_sub.mean(axis=0)  # (1, n) matrix
        means[i] = np.asarray(means_row).ravel()

    print("Cleaning up...")
    # Release the in-memory copy before returning
    del features
    gc.collect()

    return means, rank_id_names, rank_indices

def reverse_jagged_mapping(forward_map: npt.NDArray[np.object_]) -> npt.NDArray[np.object_]:
    """
    Reverses a jagged mapping where the index is the source and the values
    are the destinations.

    Example:
        forward_map = [[0, 1], [2, 3, 4], [5]]
        (0->0,1), (1->2,3,4), (2->5)

        Returns: [[0], [0], [1], [1], [1], [2]]
        (0->0), (1->0), (2->1), (3->1), (4->1), (5->2)

    Example code:
        forward_mapping = np.array([
            np.array([0, 1], dtype=np.int64),
            np.array([2, 3, 4], dtype=np.int64),
            np.array([5], dtype=np.int64),
        ], dtype=object)

        reverse_mapping = reverse_jagged_mapping(forward_mapping)

        for destination, sources in enumerate(reverse_mapping):
            print(f"  {destination} -> {sources}")

    Args:
        forward_map: A sequence (typically a NumPy array with dtype=object)
                     whose i-th element holds the integer destinations of
                     source i.

    Returns:
        An object array of shape (max destination + 1, 1). Row d holds the
        source index of destination d, or -1 when no source maps to d.
        If the input is empty or contains no destinations, an empty int64
        array is returned.
    """
    # Handle the edge case of an empty input map
    if len(forward_map) == 0:
        return np.array([], dtype=np.int64)

    # Normalise every group to a flat int64 array so the values can be used as indices
    destination_groups = [np.asarray(group, dtype=np.int64).ravel() for group in forward_map]

    # Concatenate all destination values into a single flat array
    # e.g., [0, 1, 2, 3, 4, 5]
    all_destinations = np.concatenate(destination_groups)

    # Handle case where there are no destinations
    if all_destinations.size == 0:
        return np.array([], dtype=np.int64)

    # Number of destinations per source, e.g., [2, 3, 1]
    lengths = np.array([group.size for group in destination_groups])

    # Parallel array of source indices
    # np.repeat([0, 1, 2], [2, 3, 1]) -> [0, 0, 1, 1, 1, 2]
    source_indices = np.repeat(np.arange(len(destination_groups)), lengths)

    # The output size is determined by the max destination value.
    # Start from -1 so destinations without a source get a defined value
    # instead of uninitialised memory.
    output_size = np.max(all_destinations) + 1
    reverse_map = np.full(output_size, -1, dtype=np.int64)

    # Populate the reverse map.
    reverse_map[all_destinations] = source_indices

    # One single-element row per destination, stored with dtype=object
    return reverse_map.astype(object).reshape(-1, 1)

def hex_to_rgb(hex_str : str):
    """Convert a hex color string such as ``'#1b6097'`` to RGB values.

    Leading ``#`` characters are ignored. Returns a float array with the
    three channel values parsed from consecutive two-character hex pairs.
    """
    digits = hex_str.lstrip('#')
    channels = []
    for start in range(0, 6, 2):
        channels.append(int(digits[start:start + 2], 16))
    return np.array(channels, dtype=float)

def hex_to_rgbf(hex_str : str):
    """Convert a hex color string to RGB values divided by 255."""
    rgb_values = hex_to_rgb(hex_str)
    return rgb_values / 255

def ensure_c_contiguous(arr: np.ndarray) -> np.ndarray:
    """Return ``arr`` itself if it is C-contiguous, otherwise a C-contiguous copy."""
    if arr.flags['C_CONTIGUOUS']:
        return arr
    return np.ascontiguousarray(arr)

In [None]:
# Where to store the data
path_base = Path("E:/avieth/Data/Allen/Basal/")

# Settings
use_cluster_mean = True # Whether to add a dataset with means of all shared vars per cluster (requires lots of memory to compute)
use_cache = True        # Cache values that require lots of memory or compute time in path_base
use_prep_umap = True    # Whether to download a precomputed UMAP

# Remote data, see https://alleninstitute.github.io/HMBA_BasalGanglia_Consensus_Taxonomy/
url_base = "https://released-taxonomies-802451596237-us-west-2.s3.us-west-2.amazonaws.com/HMBA/BasalGanglia"
data_version = "BICAN_05072025_pre-print_release"

# Species identifiers; they double as indices into the per-species lists below
HUMAN = 0
MACAQ = 1
MARMO = 2
SPECIES = [HUMAN, MACAQ, MARMO]

data_names = [""] * 3
data_names[HUMAN] = "Human_HMBA_basalganglia_AIT_pre-print.h5ad"
data_names[MACAQ] = "Macaque_HMBA_basalganglia_AIT_pre-print.h5ad"
data_names[MARMO] = "Marmoset_HMBA_basalganglia_AIT_pre-print.h5ad"

data_ait_urls = [f"{url_base}/{data_version}/{data_names[kind]}" for kind in SPECIES]
data_ait_paths = [path_base / data_names[kind] for kind in SPECIES]

print("Ensure data is available locally...")

# Download data
for kind in SPECIES:
    fetch_data(data_ait_urls[kind], data_ait_paths[kind])

# fetch_data only prints download errors, so verify the files before
# claiming that they are available
missing_files = [str(data_ait_paths[kind]) for kind in SPECIES if not data_ait_paths[kind].is_file()]
if missing_files:
    raise FileNotFoundError(f"Download failed, missing data files: {missing_files}")

print("Data is available locally. Loading data...")

# Load anndata [you'll need a good amount of RAM even though not all data is loaded to memory]
d_ait = [ad.io.read_h5ad(data_ait_paths[kind], backed='r') for kind in SPECIES]

# expected:
# (1'034'819, 36'601)
# (  548'281, 35'219)
# (  313'033, 35'787)
for kind in SPECIES:
    # Format each number on its own: replacing "," in the whole string would
    # also turn the ", " separator between the two numbers into "'"
    n_obs_text  = f"{d_ait[kind].n_obs:,}".replace(",", "'")
    n_vars_text = f"{d_ait[kind].n_vars:,}".replace(",", "'")
    print(f"({n_obs_text:>9}, {n_vars_text:>6})")

print("Finished loading data.")

In [None]:
# Prep some meta data
print("Compute shared variable names")
# Variables present in every species, sorted by name
shared_var_names = fct.reduce(np.intersect1d, [d_ait[kind].var_names.to_numpy() for kind in SPECIES])
shared_var_names = np.array(np.sort(shared_var_names), dtype=np.dtypes.StringDType())

all_obs_names    = np.concatenate([d_ait[kind].obs_names.to_numpy() for kind in SPECIES])

# Observation names that occur more than once in the combined data.
# (Reducing np.intersect1d over the individual strings of all_obs_names would
# compare single names pairwise, which is slow and says nothing about duplicates.)
unique_obs_names, obs_name_counts = np.unique(all_obs_names, return_counts=True)
shared_obs_names = unique_obs_names[obs_name_counts > 1]

assert len(shared_obs_names) == 0, f"Found {len(shared_obs_names)} duplicated observation names"

# Half-open [start, end) row ranges of each species in the concatenated data
species_indices_starts = [
    0,
    d_ait[HUMAN].n_obs,
    d_ait[HUMAN].n_obs + d_ait[MACAQ].n_obs
]
species_indices_ends   = [
    d_ait[HUMAN].n_obs,
    d_ait[HUMAN].n_obs + d_ait[MACAQ].n_obs,
    d_ait[HUMAN].n_obs + d_ait[MACAQ].n_obs + d_ait[MARMO].n_obs
]

print("Compute indices")
# Per species: positions of the shared variables within that species' var_names
var_indices = [np.array([d_ait[kind].var_names.get_loc(var) for var in shared_var_names if var in d_ait[kind].var_names]) for kind in SPECIES]

print(f"Num combined observations: {species_indices_ends[-1]}")
print(f"Num shared variables: {shared_var_names.shape[0]}")

In [None]:
%%skip_if not use_cluster_mean
# Mean feature vector per cluster (this cell only runs when use_cluster_mean is set)
print("Get cluster means")

def check_cluster_cache_exists(cluster_path_base : Path) -> bool :
    """Return True only if all four cluster cache files exist in the directory."""
    required_files = (
        "cluster_means_concat.npy",
        "cluster_names_concat.npy",
        "cluster_indices_concat.npy",
        "reverse_cluster_map.npy",
    )
    return all((cluster_path_base / file_name).is_file() for file_name in required_files)

def save_cluster_cache(cluster_path_base : Path, cluster_means_concat_cache, cluster_names_concat_cache, cluster_indices_concat_cache, reverse_cluster_map_cache) -> None:
    """Write the four cluster cache arrays as ``.npy`` files into the directory."""
    payload_by_file = {
        "cluster_means_concat.npy": cluster_means_concat_cache,
        "cluster_names_concat.npy": cluster_names_concat_cache,
        "cluster_indices_concat.npy": cluster_indices_concat_cache,
        "reverse_cluster_map.npy": reverse_cluster_map_cache,
    }
    for file_name, payload in payload_by_file.items():
        # allow_pickle is needed for the object-dtype (jagged) arrays
        np.save(cluster_path_base / file_name, payload, allow_pickle=True)

def load_cluster_cache(cluster_path_base : Path):
    """Read the four cluster cache arrays from the directory.

    Returns:
        Tuple ``(cluster means, cluster names, cluster indices, reverse
        cluster map)`` in that order.
    """
    file_names = (
        "cluster_means_concat.npy",
        "cluster_names_concat.npy",
        "cluster_indices_concat.npy",
        "reverse_cluster_map.npy",
    )
    # NOTE(review): allow_pickle=True unpickles file content; only load cache
    # files that were written by this notebook.
    return tuple(np.load(cluster_path_base / file_name, allow_pickle=True) for file_name in file_names)

cluster_cache_exists = check_cluster_cache_exists(path_base)

if use_cache and cluster_cache_exists:
    # Reuse previously computed results
    print("Load cluster means from cache")
    cluster_means_concat, cluster_names_concat, cluster_indices_concat, reverse_cluster_map = load_cluster_cache(path_base)
else:
    print("Compute cluster means")
    # One (means, names, indices) triple per species, computed in species order
    (
        (cluster_means_human, cluster_names_human, cluster_indices_human),
        (cluster_means_macaq, cluster_names_macaq, cluster_indices_macaq),
        (cluster_means_marmo, cluster_names_marmo, cluster_indices_marmo),
    ) = [compute_rank_means(d_ait[kind], "Cluster", var_indices[kind]) for kind in SPECIES]

    # Stack the per-species results in species order
    cluster_means_concat = np.concatenate((cluster_means_human, cluster_means_macaq, cluster_means_marmo), axis=0)
    cluster_names_concat = [*cluster_names_human, *cluster_names_macaq, *cluster_names_marmo]

    # Shift the per-species row indices into the concatenated index space
    cluster_indices_shifted = (
        cluster_indices_human + species_indices_starts[HUMAN],
        cluster_indices_macaq + species_indices_starts[MACAQ],
        cluster_indices_marmo + species_indices_starts[MARMO],
    )
    cluster_indices_concat = np.concatenate(cluster_indices_shifted, axis=0)

    # Reverse the mapping, from data to clusters
    reverse_cluster_map = reverse_jagged_mapping(cluster_indices_concat)

    if use_cache:
        print("Save cluster means to cache")
        save_cluster_cache(path_base, cluster_means_concat, cluster_names_concat, cluster_indices_concat, reverse_cluster_map)

# Every observation of the combined data must have an entry in the reverse map
assert reverse_cluster_map.shape[0] == species_indices_ends[-1]

print(f"Num combined clusters: {cluster_means_concat.shape[0]}")

In [None]:
print("Concatenated scVI features")
# Per-species scVI feature matrices (dense, 64 variables expected)
d_ait_scVI = [species_data.obsm['X_scVI'] for species_data in d_ait]

# Stack the per-species matrices row-wise, in species order
d_ait_concat_scVI = np.concatenate(d_ait_scVI, axis=0)

print(f"Concatenated size: {d_ait_concat_scVI.shape}")

In [None]:
# Extract and concatenate the classification of every taxonomy rank
print("Extract rank data")
neigh_names,    neigh_indices,    neigh_colors    = extract_rank_data('Neighborhood', 'color_hex_neighborhood', *d_ait)
class_names,    class_indices,    class_colors    = extract_rank_data('Class', 'color_hex_class', *d_ait)
subclass_names, subclass_indices, subclass_colors = extract_rank_data('Subclass', 'color_hex_subclass', *d_ait)
group_names,    group_indices,    group_colors    = extract_rank_data('Group', 'color_hex_group', *d_ait)

# Species act as one more grouping; all three lists are in the order human, macaque, marmoset
species_names   = ['Human', 'Macaque', 'Marmoset']
species_colors  = [hex_to_rgbf(color_hex) for color_hex in ('1b6097', '318e2d', 'db423f')]
species_indices = [
    np.arange(first_row, end_row)
    for first_row, end_row in zip(species_indices_starts, species_indices_ends)
]

print("Finished extracting rank data")

In [None]:
%%skip_if not use_prep_umap
# Download pre-computed UMAP
print("Load pre-computed UMAP")

# Remote source and local cache location of the embedding table
umap_cache_file = path_base / "cell_2d_embedding_coordinates.csv"
umap_url = "https://allen-brain-cell-atlas.s3-us-west-2.amazonaws.com/metadata/HMBA-BG-taxonomy-CCN20250428/20250531/cell_2d_embedding_coordinates.csv"

# Download / use cache
d_umap = load_csv(umap_url, umap_cache_file)

# Extract the coordinates as a C-contiguous float32 array
umap_xy = d_umap[["x", "y"]]
umap_coords = ensure_c_contiguous(umap_xy.to_numpy(dtype="float32"))

# The embedding rows must be in the same order as the concatenated observations
umap_label_column = d_umap[["cell_label"]].to_numpy()
umap_labels = umap_label_column.ravel()
assert np.array_equal(umap_labels, all_obs_names)

In [None]:
# Push the prepared data into ManiVault through the mvstudio data hierarchy
import mvstudio.data
dh = mvstudio.data.Hierarchy()

print("Add points to ManiVault")
mv_concat_scVI = dh.addPointsItem(d_ait_concat_scVI, "scVI")

print("Add clusters to ManiVault")
# Each cluster item is attached to the scVI point data via its dataset id
mv_concat_species  = dh.addClusterItem(mv_concat_scVI.datasetId, species_indices, "Species", names=species_names, colors=species_colors)
mv_concat_neigh    = dh.addClusterItem(mv_concat_scVI.datasetId, neigh_indices, "Neighbors", names=neigh_names, colors=neigh_colors)
mv_concat_class    = dh.addClusterItem(mv_concat_scVI.datasetId, class_indices, "Class", names=class_names, colors=class_colors)
mv_concat_subclass = dh.addClusterItem(mv_concat_scVI.datasetId, subclass_indices, "Subclass", names=subclass_names, colors=subclass_colors)
mv_concat_group    = dh.addClusterItem(mv_concat_scVI.datasetId, group_indices, "Group", names=group_names, colors=group_colors)

if use_cluster_mean:
    print("Add aggregated feature data to ManiVault")
    mv_concat_clusterM = dh.addPointsItem(cluster_means_concat, "Cluster expression (mean)", mv_concat_scVI.datasetId, shared_var_names)

    # Only link in one direction. Adding both mappings would cause an unfortunate back-and-forth selection loop
    #success_link = mv_concat_clusterM.setLinkedData(mv_concat_scVI, cluster_indices_concat)
    success_link = mv_concat_scVI.setLinkedData(mv_concat_clusterM, reverse_cluster_map)
    assert success_link

if use_prep_umap:
    print("Add pre-computed UMAP to ManiVault")
    # Keep the UMAP item in its own variable so the scVI handle is not overwritten
    mv_concat_umap = dh.addDerivedPointsItem(umap_coords, "UMAP embedding", mv_concat_scVI.datasetId, ["UMAP x", "UMAP y"])

print("Done adding data to ManiVault")