In [None]:
import os
import k3d
import copy

import numpy as np
import anndata as ad
import seaborn as sns

from typing import Optional

In [None]:
def rgb_to_hex(rgb):
    """Convert a color with float channels in [0, 1] to a lowercase hex string.

    Each channel is scaled to 0-255, rounded to the nearest integer (matching
    ``matplotlib.colors.to_hex``) and clamped to the valid byte range, then
    written as two hex digits. A 3-channel input yields ``'#rrggbb'``; any
    extra channels (e.g. alpha) are appended the same way.

    Args:
        rgb: iterable of numeric channel values, nominally in [0, 1].

    Returns:
        str: ``'#'`` followed by two lowercase hex digits per channel.
    """
    # Rounding (not truncation) keeps k/255 floats mapping back to k despite
    # floating-point error; clamping prevents out-of-range inputs from
    # producing more than two hex digits per channel.
    channels = (min(255, max(0, round(float(c) * 255))) for c in rgb)
    return '#' + ''.join(f'{c:02x}' for c in channels)

def _label_sort_key(label):
    """Sort key: numeric-looking labels first in numeric order, then the rest as text."""
    try:
        return (0, float(label), '')
    except (TypeError, ValueError):
        return (1, 0.0, str(label))


def _ordered_labels(adata_list, anno):
    """Union of labels of ``obs[anno]`` over all slices, sorted with ``_label_sort_key``."""
    labels = {}
    for adata in adata_list:
        column = adata.obs[anno]
        # Categorical columns contribute all categories (as the original palette
        # sizing did); other columns contribute their non-missing values.
        values = column.cat.categories if hasattr(column, 'cat') else column.dropna().unique()
        labels.update(dict.fromkeys(values))
    return sorted(labels, key=_label_sort_key)


def _color_to_uint32(color):
    """Pack a '#rrggbb[aa]' hex string or an RGB float tuple into a 0xRRGGBB np.uint32."""
    if not isinstance(color, str):
        color = rgb_to_hex(color)
    # Only the first six hex digits are used so an alpha suffix cannot leak
    # into the packed RGB value.
    return np.uint32(int(color.lstrip('#')[:6], 16))


def _build_color_lookup(adata_list, anno, color_anno, color_map):
    """Resolve the coloring options into a dict mapping label -> packed RGB color."""
    if color_map is None and color_anno is not None:
        # Per-slice palettes in adata.uns are aligned with that slice's categories.
        color_map = {}
        for adata in adata_list:
            categories = adata.obs[anno].values.categories.to_list()
            color_map.update(zip(categories, adata.uns[color_anno]))
    elif color_map is None:
        labels = _ordered_labels(adata_list, anno)
        n_labels = len(labels)
        if n_labels > 20:
            palette = list(sns.cubehelix_palette(n_colors=n_labels))
        elif n_labels > 10:
            palette = sns.color_palette('tab20')
        else:
            palette = sns.color_palette('tab10')
        color_map = dict(zip(labels, palette))
    elif not isinstance(color_map, dict):
        palette = list(color_map)
        labels = _ordered_labels(adata_list, anno)
        if not palette:
            raise ValueError("color_map must contain at least one color")
        if len(palette) < len(labels):
            print(f"Warning: Not enough colors ({len(palette)}) for all labels ({len(labels)}).")
        # Cycle the palette when it is shorter than the label list.
        color_map = {label: palette[i % len(palette)] for i, label in enumerate(labels)}
    return {label: _color_to_uint32(color) for label, color in color_map.items()}


def _lookup_color(color_lookup, label):
    """Return ``label``'s color trying the raw label, then its int and str forms; None if absent."""
    candidates = [label]
    try:
        candidates.append(int(label))
    except (TypeError, ValueError):
        pass
    candidates.append(str(label))
    for candidate in candidates:
        if candidate in color_lookup:
            return color_lookup[candidate]
    return None


def _group_points(adata_list, spatial_key, anno, z_intervel):
    """Map label -> float32 (N, 3) positions; slice ``z`` is placed at height ``z_intervel * z``."""
    chunks = {}
    for z, adata in enumerate(adata_list):
        coords = np.asarray(adata.obsm[spatial_key])
        if coords.ndim != 2 or coords.shape[1] < 2:
            raise ValueError(
                f"obsm['{spatial_key}'] must be a 2D array with at least 2 columns, "
                f"got shape {coords.shape}"
            )
        positions = np.empty((coords.shape[0], 3), dtype=np.float32)
        positions[:, :2] = coords[:, :2]
        # Any stored third coordinate is replaced by the slice height.
        positions[:, 2] = z_intervel * z

        column = adata.obs[anno]
        missing = column.isna().to_numpy()
        rows_by_label = {}
        for row, (label, is_missing) in enumerate(zip(column, missing)):
            # All missing labels share a single None group.
            rows_by_label.setdefault(None if is_missing else label, []).append(row)
        for label, rows in rows_by_label.items():
            chunks.setdefault(label, []).append(positions[rows])
    return {label: np.concatenate(parts) for label, parts in chunks.items()}


def plot_cloud_point(adata_list: list,
                     spatial_key: str = 'spatial',
                     anno: str = 'annotation',
                     color_anno: str = None,
                     color_map: Optional[dict] = None,
                     point_size: float = 0.01,
                     z_intervel: float = 0.008,
                     opacity: float = 0.8
                     ):
    """Build an interactive k3d 3D point cloud from a stack of spatial sections.

    Section ``i`` of ``adata_list`` is placed at height ``i * z_intervel``, so
    list order determines stacking order. Cells are grouped by ``obs[anno]``
    and each group becomes one named ``k3d.points`` layer.

    Args:
        adata_list: AnnData sections to stack, in z order.
        spatial_key: ``obsm`` key with per-cell coordinates; the first two
            columns are used as x/y (a third column, if any, is replaced by
            the section height).
        anno: ``obs`` column holding the label used for grouping and coloring.
        color_anno: ``uns`` key with a per-section palette aligned with the
            categories of ``obs[anno]``; used only when ``color_map`` is None.
        color_map: explicit colors, either a dict label -> color or a sequence
            of colors assigned to the sorted labels (cycled if too short).
            Colors may be '#rrggbb' strings or RGB float tuples; dict keys may
            be the labels themselves or their int/str forms. When both this
            and ``color_anno`` are None, a seaborn palette is used (tab10 for
            up to 10 labels, tab20 for up to 20, cubehelix beyond that).
        point_size: forwarded to ``k3d.points``.
        z_intervel: spacing between consecutive sections along z.
        opacity: forwarded to ``k3d.points``.

    Returns:
        The ``k3d.plot()`` object with one points layer per label, ordered
        numerically where labels look numeric. Labels without a color (and
        cells with a missing label, grouped as ``None``) are drawn gray
        (0x808080) after a printed warning.

    Raises:
        ValueError: if the spatial coordinates are not 2D with at least two
            columns, or if an empty color sequence is given.
    """
    color_lookup = _build_color_lookup(adata_list, anno, color_anno, color_map)
    pts_map = _group_points(adata_list, spatial_key, anno, z_intervel)

    plot = k3d.plot()
    for label in sorted(pts_map, key=_label_sort_key):
        positions = pts_map[label]
        color = _lookup_color(color_lookup, label)
        if color is None:
            print(f"Warning: No color for label {label}, using default.")
            color = np.uint32(0x808080)
        plot += k3d.points(
            positions=positions,
            colors=np.full(positions.shape[0], color, dtype=np.uint32),
            point_size=point_size,
            shader='3dSpecular',
            opacity=opacity,
            name=str(label),
        )

    return plot

In [8]:
# Directory containing the per-section h5ad files. Set the H5AD_DIR environment
# variable to point elsewhere instead of editing this hard-coded default.
h5ad_dir = os.environ.get('H5AD_DIR', r"D:\RD_Data\3D\demo_output_20250416\06.color")

# Collect the h5ad files sorted by name. The sort is NOT optional:
# plot_cloud_point stacks sections along z in list order, so the file-name
# order determines each section's z position.
h5ad_files = sorted(
    os.path.join(h5ad_dir, name)
    for name in os.listdir(h5ad_dir)
    if name.lower().endswith('.h5ad')
    and os.path.isfile(os.path.join(h5ad_dir, name))
)

# Read every file. Unreadable files are reported and skipped so the remaining
# sections can still be plotted; note that a skipped file moves every later
# section down one z step.
adata_list = []
for file_path in h5ad_files:
    try:
        adata = ad.read_h5ad(file_path)
    except Exception as e:
        print(f"Failed to read: {os.path.basename(file_path)}, Error: {e}")
        continue
    adata_list.append(adata)
    print(f"Successfully read: {os.path.basename(file_path)}")

# Build the 3D point cloud. `plot` stays None when nothing could be read so the
# display cell can detect that case instead of failing on an undefined name.
plot = None
if not adata_list:
    print("No readable h5ad files found!")
else:
    # Pass color_anno='leiden_colors' to reuse the palette stored in adata.uns
    # instead of the default seaborn palette.
    plot = plot_cloud_point(
        adata_list=adata_list,
        spatial_key='spatial_mm',
        anno='leiden',
        z_intervel=0.02,
        opacity=1,
    )

Successfully read: A02183A1.h5ad
Successfully read: A02183A2.h5ad
Successfully read: A02183A3.h5ad
Successfully read: A02183A4.h5ad
Successfully read: A02183A5.h5ad
Successfully read: A02183A6.h5ad
Successfully read: A02183A7.h5ad
Successfully read: A02183A8.h5ad
Successfully read: A02183A9.h5ad
Successfully read: A02183B1.h5ad
Successfully read: A02183B2.h5ad
Successfully read: A02183B3.h5ad
Successfully read: A02183B4.h5ad
Successfully read: A02183B5.h5ad
Successfully read: A02183B6.h5ad
Successfully read: A02183B7.h5ad


In [9]:
# Render the interactive k3d widget (needs a Jupyter front end with the k3d
# extension). Skip with a message when no plot was built, e.g. because no
# h5ad file could be read.
plot_to_show = globals().get('plot')
if plot_to_show is not None:
    plot_to_show.display()
else:
    print("Nothing to display: no point cloud was built.")

Output()