In [1]:
# ! pip install bin2cell
import json
import tifffile
import numpy as np
import cv2
import bin2cell as b2c
from pathlib import Path
import scanpy as sc 
import matplotlib.pyplot as plt
import napari

In [2]:

# Parameters
IMMUNEXID = "IMMUNEX001"
config_name = "standard_LowHE_LowGEX"
mpp = 0.2

he_stardist_params = {
    "stardist_model": "2D_versatile_he",
    "prob_thresh": 0.1,
    'block_size':256,         # ✅ Small block size
    "min_overlap": 64,     # smaller than smallest cell
    "context": 16         # must still satisfy: min_overlap + 2×context < block_size
}

gex_stardist_params = {
    "image_path": "stardist/gex.tiff",
    "labels_npz_path": "stardist/gex.npz",
    "stardist_model": "2D_versatile_fluo",
    "prob_thresh": 0.1,
    "nms_thresh": 0.1
}



# Patched image loader
def patched_load_image(image_path, **kwargs):
    if not isinstance(image_path, (str, Path)):
        raise ValueError(f"Expected a path, got {type(image_path)} instead.")
    print(f"Loading image via tifffile: {image_path}")
    img = tifffile.imread(image_path)
    if img.ndim == 2:
        img = np.stack([img] * 3, axis=-1)
    elif img.shape[0] == 3 and img.ndim == 3:
        img = np.moveaxis(img, 0, -1)
    return img

In [3]:
import numpy as np
import napari

def napari_spots(adata, spot_size=1, colormap='viridis', spot_layers=None, image_layers=None):
    """
    Launch Napari with one or more image layers and spot layers from adata.

    Parameters:
    - adata: AnnData object
    - spot_size: float, size of each spot
    - colormap: str, colormap for numeric features
    - spot_layers: list of adata.obs column names to plot as separate point layers
    - image_layers: list of image sources to include, e.g. ['hires', 'spatial_cropped_150_buffer']
        - 'hires' refers to adata.uns['spatial'][sample_key]['images']['hires']
        - anything else looks in adata.obsm
    """
    if spot_layers is None:
        spot_layers = []
    if image_layers is None:
        image_layers = ['hires']  # default to hires tissue image

    sample_key = list(adata.uns['spatial'].keys())[0]
    scale = adata.uns['spatial'][sample_key]['scalefactors']['tissue_hires_scalef']
    coords = adata.obsm['spatial'] * scale
    coords = coords[:, [1, 0]]

    # Start viewer
    viewer = napari.Viewer()

    # Add requested image layers
    for img_key in image_layers:
        if img_key == 'hires':
            image = adata.uns['spatial'][sample_key]['images']['hires']
            viewer.add_image(image, name='Tissue image (hires)')
        elif img_key in adata.obsm:
            image = adata.obsm[img_key]
            if image.ndim == 2 or (image.ndim == 3 and image.shape[-1] in [1, 3]):
                viewer.add_image(image, name=f'{img_key}')
            else:
                print(f"[Warning] '{img_key}' in adata.obsm is not a valid 2D or RGB image. Skipping.")
        else:
            print(f"[Warning] Image key '{img_key}' not found in adata.uns or obsm. Skipping.")

    # Base UMI spot layer
    umi_values = np.nan_to_num(adata.obs['n_counts'].values)
    base_properties = {'UMI': umi_values}
    base_layer = viewer.add_points(
        coords,
        properties=base_properties,
        face_color='UMI',
        face_colormap=colormap,
        size=spot_size,
        name='UMI-colored spots'
    )

    base_layer.border_width = 0
    base_layer.face_color = 'UMI'
    base_layer.face_color_mode = 'colormap'
    # base_layer.face_color_mode = 'direct'
    
    base_layer.face_contrast_limits = (umi_values.min(), umi_values.max())
    base_layer.show_colorbar = True

    print("Layer properties keys:", base_layer.properties.keys())
    print("Layer properties values:", base_layer.properties.values())
    print("Face color mode:", base_layer.face_color_mode)
    print("Face color source:", base_layer.face_color)
    
    # Additional spot layers
    for layer_name in spot_layers:
        if layer_name in adata.obs.columns:
            values = np.nan_to_num(adata.obs[layer_name].values)
            properties = {layer_name: values}
            layer = viewer.add_points(
                coords,
                properties=properties,
                face_color=layer_name,
                face_colormap=colormap,
                size=spot_size,
                name=f'{layer_name} layer'
            )
            layer.face_color_mode = 'colormap'
            layer.face_contrast_limits = (values.min(), values.max())
            layer.border_width = 0
            layer.show_colorbar = True

        else:
            print(f"[Warning] '{layer_name}' not found in adata.obs. Skipping.")

    print("Napari viewer created.")
    return viewer



In [4]:

base_sample_path = Path(f"/Users/mounim/Documents/IMMUNEX_data/") # spaceranger output folder 
sample_folder = base_sample_path / f"OUTPUT/spaceranger_output//Visium_NSCLC_{IMMUNEXID}"
path_visium = f'{sample_folder}/outs/binned_outputs/square_002um/'

# Load suffix mapping
with open("../data/metadata/he_mapping_suffix.json", "r") as f:
    nanozoomer_tif = json.load(f)
source_image_path = Path(f"{base_sample_path}/IMAGE/IMAGE/HE_nanozoomer_tif/") / f"{IMMUNEXID}{nanozoomer_tif[IMMUNEXID]}.tif"

path_spaceranger = Path(base_sample_path / f"OUTPUT/spaceranger_output//Visium_NSCLC_{IMMUNEXID}/outs/spatial")

# Output paths

# Paths
base_output_dir = Path(f"./data/intermediate/segmentation/bin2cell/{IMMUNEXID}/")
output_dir = base_output_dir / f"{IMMUNEXID}__{config_name}"
output_dir.mkdir(parents=True, exist_ok=True)

# Load and inspect file 

In [5]:
# Start processing
adata = b2c.read_visium(
    path= path_visium, 
    count_file='./filtered_feature_bc_matrix.h5',
    source_image_path= source_image_path,
    spaceranger_image_path= path_spaceranger
)
adata.var_names_make_unique()
print(adata)

adata.obs['n_counts'] = np.sum(adata.X, axis=1).A1 if hasattr(adata.X, 'A1') else np.sum(adata.X, axis=1)
print('n counted')
sc.pp.filter_genes(adata, min_cells=0)
sc.pp.filter_cells(adata, min_counts=0)
print('filtred')
print(adata)

anndata.py (1758): Variable names are not unique. To make them unique, call `.var_names_make_unique`.
anndata.py (1758): Variable names are not unique. To make them unique, call `.var_names_make_unique`.


AnnData object with n_obs × n_vars = 10822530 × 18536
    obs: 'in_tissue', 'array_row', 'array_col'
    var: 'gene_ids', 'feature_types', 'genome'
    uns: 'spatial'
    obsm: 'spatial'
n counted
filtred
AnnData object with n_obs × n_vars = 10822530 × 18536
    obs: 'in_tissue', 'array_row', 'array_col', 'n_counts'
    var: 'gene_ids', 'feature_types', 'genome', 'n_cells'
    uns: 'spatial'
    obsm: 'spatial'


In [6]:
# napari_spots(adata)

In [7]:
b2c.bin2cell.load_image = patched_load_image

# # Destripe and scale image
b2c.destripe(adata)
print('adata destripped', adata)
he_img_out = output_dir / "he.tiff"

# napari_spots(adata, spot_layers=['n_counts_adjusted'])

_construct.py (163): Series.__getitem__ treating keys as positions is deprecated. In a future version, integer keys will always be treated as labels (consistent with DataFrame behavior). To access a value by position, use `ser.iloc[pos]`


adata destripped AnnData object with n_obs × n_vars = 10822530 × 18536
    obs: 'in_tissue', 'array_row', 'array_col', 'n_counts', 'destripe_factor', 'n_counts_adjusted'
    var: 'gene_ids', 'feature_types', 'genome', 'n_cells'
    uns: 'spatial'
    obsm: 'spatial'


In [8]:
b2c.bin2cell.load_image = patched_load_image

b2c.scaled_he_image(
    adata,
    save_path=he_img_out, 
    mpp=mpp
    )
adata

Loading image via tifffile: /Users/mounim/Documents/IMMUNEX_data/IMAGE/IMAGE/HE_nanozoomer_tif/IMMUNEX001_Visium_HE_x40_z0.tif
Cropped spatial coordinates key: spatial_cropped_150_buffer
Image key: 0.2_mpp_150_buffer


AnnData object with n_obs × n_vars = 10822530 × 18536
    obs: 'in_tissue', 'array_row', 'array_col', 'n_counts', 'destripe_factor', 'n_counts_adjusted'
    var: 'gene_ids', 'feature_types', 'genome', 'n_cells'
    uns: 'spatial'
    obsm: 'spatial', 'spatial_cropped_150_buffer'

In [9]:
napari_spots(adata, spot_layers=['n_counts_adjusted'])

Layer properties keys: dict_keys(['UMI'])
Layer properties values: dict_values([array([ 0., 16.,  5., ...,  0.,  3.,  1.], dtype=float32)])
Face color mode: ColorMode.COLORMAP
Face color source: [[0.26700401 0.004874   0.32941499 1.        ]
 [0.2832185  0.12204475 0.441678   1.        ]
 [0.27563746 0.04186161 0.36814395 1.        ]
 ...
 [0.26700401 0.004874   0.32941499 1.        ]
 [0.27257386 0.02547517 0.353002   1.        ]
 [0.26898054 0.01125219 0.33737999 1.        ]]
Napari viewer created.


Viewer(camera=Camera(center=(0.0, 2315.5, 2999.5), zoom=0.0931, angles=(0.0, 0.0, 90.0), perspective=0.0, mouse_pan=True, mouse_zoom=True, orientation=(<DepthAxisOrientation.TOWARDS: 'towards'>, <VerticalAxisOrientation.DOWN: 'down'>, <HorizontalAxisOrientation.RIGHT: 'right'>)), cursor=Cursor(position=(1.0, 1.0), scaled=True, style=<CursorStyle.STANDARD: 'standard'>, size=1.0), dims=Dims(ndim=2, ndisplay=2, order=(0, 1), axis_labels=('0', '1'), rollable=(True, True), range=(RangeTuple(start=0.0, stop=4631.0, step=1.0), RangeTuple(start=0.0, stop=5999.0, step=1.0)), margin_left=(0.0, 0.0), margin_right=(0.0, 0.0), point=(2315.0, 2999.0), last_used=0), grid=GridCanvas(stride=1, shape=(-1, -1), enabled=False, spacing=0.0), layers=[<Image layer 'Tissue image (hires)' at 0x559ec4490>, <Points layer 'UMI-colored spots' at 0x559e4a230>, <Points layer 'n_counts_adjusted layer' at 0x55bb35210>], help='use <5> for transform, use <2> for add points, use <3> for select points', status='Ready', to

In [None]:
# import scanpy as sc
# sc.set_figure_params(figsize=[5, 5], dpi=100)

# # Use explicit keys and object
# sc.pl.spatial(
#     adata,
#     color=[None, "n_counts", "n_counts_adjusted"],
#     img_key=f"{mpp}_mpp_150_buffer",
#     basis="spatial_cropped_150_buffer",
#     cmap='Reds',
#     show=False
# )
# plt.savefig(output_dir / "spatial_destriping.pdf")
# plt.close()

# print('HE preview exported')

In [10]:
sample_key = list(adata.uns['spatial'].keys())[0]
adata.uns['spatial'][sample_key]['images']['0.2_mpp_150_buffer'].shape

(34177, 34175, 3)

# Segmentation

In [None]:
from stardist.models import StarDist2D
import napari
import numpy as np
from tqdm import tqdm
import tifffile


# Load model
model = StarDist2D.from_pretrained('2D_versatile_he')

# Extract a crop from adata
sample_key = list(adata.uns['spatial'].keys())[0]
crop = adata.uns['spatial'][sample_key]['images']['0.2_mpp_150_buffer'][1000:10000, 1000:10000]

# tifffile.imwrite("HE_image.tif", crop = adata.uns['spatial'][sample_key]['images']['0.2_mpp_150_buffer'])
tifffile.imwrite("crop_HE_image.tif", crop)

In [10]:
from stardist.models import StarDist2D
import napari
import numpy as np
from tqdm import tqdm
import tifffile
from skimage.measure import regionprops, label

# Save and reload crop (if needed)
crop = tifffile.imread("crop_HE_image.tif")

# Set up Napari
viewer = napari.Viewer()
viewer.add_image(crop, name='Original H&E Image')

# Load model
model = StarDist2D.from_pretrained('2D_versatile_he')

# Parameters to test
prob_values = [0.1, 0.5]
params = [
    {'block_size': (1000, 1000, 1), 'min_overlap': (100, 100, 0), 'context': (25, 25, 0)},
]

# Color list (adjust length if needed)
color_list = ['red', 'green', 'blue', 'yellow', 'cyan', 'magenta']

# Loop
color_index = 0
for p in tqdm(prob_values):
    for nms in [0.3]:
        for param in params:
            print(f"Segmenting with prob_thresh={p}, nms_thresh={nms}, block={param['block_size']}")

            labels, _ = model.predict_instances_big(
                crop,
                axes='YXC',
                prob_thresh=p,
                nms_thresh=nms,
                block_size=param['block_size'],
                min_overlap=param['min_overlap'],
                context=param['context'],
                verbose=False,
                show_tile_progress=False
            )

            layer_name = f"p={p:.2f} | block={param['block_size'][0]}"
            layer = viewer.add_labels(labels, name=layer_name)
            layer.opacity = 0.5
            layer.rendering = 'translucent'
            print(f" → Cells identified: {np.max(labels)}")

            color_index += 1

napari.run()


Found model '2D_versatile_he' for 'StarDist2D'.
Loading network weights from 'weights_best.h5'.
Loading thresholds from 'thresholds.json'.
Using default values: prob_thresh=0.692478, nms_thresh=0.3.


  0%|          | 0/2 [00:00<?, ?it/s]

Segmenting with prob_thresh=0.1, nms_thresh=0.3, block=(1000, 1000, 1)
effective: block_size=(1008, 1008, 3), min_overlap=(112, 112, 0), context=(32, 32, 0)


functional.py (238): The structure of `inputs` doesn't match the expected structure.
Expected: ['input']
Received: inputs=Tensor(shape=(1, 256, 256, 3))


Y: context of 32 is small, recommended to use at least 94
X: context of 32 is small, recommended to use at least 94
changing 'show_tile_progress' from False to False


base.py (406): Predicting on non-float input... ( forgot to normalize? )
functional.py (238): The structure of `inputs` doesn't match the expected structure.
Expected: ['input']
Received: inputs=Tensor(shape=(1, 1008, 1008, 3))
100%|██████████| 4/4 [00:01<00:00,  3.08it/s]
 50%|█████     | 1/2 [00:01<00:01,  1.87s/it]

 → Cells identified: 41
Segmenting with prob_thresh=0.5, nms_thresh=0.3, block=(1000, 1000, 1)
effective: block_size=(1008, 1008, 3), min_overlap=(112, 112, 0), context=(32, 32, 0)
Y: context of 32 is small, recommended to use at least 94
X: context of 32 is small, recommended to use at least 94
changing 'show_tile_progress' from False to False


base.py (406): Predicting on non-float input... ( forgot to normalize? )
100%|██████████| 4/4 [00:01<00:00,  3.39it/s]
100%|██████████| 2/2 [00:03<00:00,  1.55s/it]

 → Cells identified: 35





In [None]:
he_stardist_params = {
    "stardist_model": "2D_versatile_he",
    "prob_thresh": 0.1,
    "block_size": 256,
    "min_overlap": 96,    # ↑ increase overlap
    "context": 16         # keep this small enough to satisfy constraint
}


from tifffile import imread

he_img_out = output_dir / "he.tiff"

img = imread(he_img_out)
print(f"Image shape: {img.shape}")

he_seg_out = output_dir / "he.npz"
# # H&E segmentation
b2c.stardist(
    image_path=str(he_img_out),
    labels_npz_path=str(he_seg_out), 
    **he_stardist_params)

print("Stardist H&E done")


print("Available obsm keys:", adata.obsm.keys())
sample_id = list(adata.uns["spatial"].keys())[0]
print("Image keys:", adata.uns["spatial"][sample_id]["images"].keys())


Image shape: (34177, 34175, 3)
Loading image via tifffile: data/intermediate/segmentation/bin2cell/IMMUNEX001/IMMUNEX001__standard_LowHE_LowGEX/he.tiff
Found model '2D_versatile_he' for 'StarDist2D'.
Loading network weights from 'weights_best.h5'.
Loading thresholds from 'thresholds.json'.
Using default values: prob_thresh=0.692478, nms_thresh=0.3.
effective: block_size=(256, 256, 3), min_overlap=(96, 96, 0), context=(16, 16, 0)


functional.py (238): The structure of `inputs` doesn't match the expected structure.
Expected: ['input']
Received: inputs=Tensor(shape=(1, 256, 256, 3))


Y: context of 16 is small, recommended to use at least 94
X: context of 16 is small, recommended to use at least 94


  1%|          | 666/71022 [01:00<1:47:05, 10.95it/s]


KeyboardInterrupt: 

: 

In [None]:
adata

Collecting tensorflow
  Downloading tensorflow-2.19.0-cp310-cp310-macosx_12_0_arm64.whl.metadata (4.0 kB)
Collecting absl-py>=1.0.0 (from tensorflow)
  Using cached absl_py-2.2.2-py3-none-any.whl.metadata (2.6 kB)
Collecting astunparse>=1.6.0 (from tensorflow)
  Using cached astunparse-1.6.3-py2.py3-none-any.whl.metadata (4.4 kB)
Collecting flatbuffers>=24.3.25 (from tensorflow)
  Using cached flatbuffers-25.2.10-py2.py3-none-any.whl.metadata (875 bytes)
Collecting gast!=0.5.0,!=0.5.1,!=0.5.2,>=0.2.1 (from tensorflow)
  Using cached gast-0.6.0-py3-none-any.whl.metadata (1.3 kB)
Collecting google-pasta>=0.1.1 (from tensorflow)
  Using cached google_pasta-0.2.0-py3-none-any.whl.metadata (814 bytes)
Collecting libclang>=13.0.0 (from tensorflow)
  Using cached libclang-18.1.1-1-py2.py3-none-macosx_11_0_arm64.whl.metadata (5.2 kB)
Collecting opt-einsum>=2.3.2 (from tensorflow)
  Using cached opt_einsum-3.4.0-py3-none-any.whl.metadata (6.3 kB)
Collecting protobuf!=4.21.0,!=4.21.1,!=4.21.2,!=

In [None]:
b2c.insert_labels(adata, str(he_seg_out), basis="spatial", spatial_key="spatial_cropped_150_buffer", mpp=mpp, labels_key="labels_he")

print("Inserted H&E labels")

In [None]:
b2c.expand_labels(
    adata,
    labels_key="labels_he",
    expanded_labels_key="labels_he_expanded"
)
print("Expanded H&E labels")
print(adata)

In [None]:

# Compute the crop from the spatial coordinates
crop = b2c.get_crop(adata, basis="spatial_cropped_150_buffer", mpp=mpp)
he_seg_out = output_dir / "he.npz"
b2c.view_stardist_labels(
    adata,
    labels_npz_path=str(he_seg_out),
    crop=crop,
    labels_key="labels_he_expanded",
    save=output_dir / "stardist_labels_he.png"
)
print("View of H&E exported")

In [None]:
import importlib
importlib.reload(b2c.bin2cell)
grid_image_sigma = 5
# 1. Generate GEX grid image and export
gex_img_out = output_dir / "gex.tiff"
img = b2c.grid_image(adata, "n_counts_adjusted", mpp=mpp, sigma=grid_image_sigma)
cv2.imwrite(str(gex_img_out), img)
print(f"Exported GEX image to: {gex_img_out}")


In [None]:

# 2. Run Stardist on the GEX TIFF
gex_seg_out = output_dir / "gex.npz"
b2c.stardist(
    image_path=str(gex_img_out),
    labels_npz_path=str(gex_seg_out),
    **gex_stardist_params
)
print("Stardist GEX done")

In [None]:
# 3. Insert GEX labels into adata
b2c.insert_labels(
    adata,
    labels_npz_path=str(gex_seg_out),
    basis="spatial",  # or "array" if you prefer
    spatial_key="spatial_cropped_150_buffer",
    mpp=mpp,
    labels_key="labels_gex"
)
print("Inserted GEX labels")

In [None]:


# 4. GEX segmentation overlay (label visualization)
bdata = adata[mask].copy()
bdata = bdata[bdata.obs["labels_gex"] > 0]
bdata.obs["labels_gex"] = bdata.obs["labels_gex"].astype(str)

sc.pl.spatial(
    bdata,
    color=[None, "labels_gex"],
    img_key=f"{mpp}_mpp_150_buffer",
    basis="spatial_cropped_150_buffer",
    show=False
)
plt.savefig(output_dir / "gex_segmentation_labels_gex_overlay.png", dpi=300, bbox_inches='tight', pad_inches=0)
plt.close()
print("Exported GEX segmentation label overlay")

In [None]:

# 5. Normalized GEX overlay image
crop = b2c.get_crop(bdata, basis="spatial", mpp=mpp)
rendered = b2c.view_labels(
    image_path=gex_img_out,
    labels_npz_path=gex_seg_out,
    crop=crop,
    stardist_normalize=True
)
plt.imshow(rendered)
plt.axis("off")
plt.tight_layout()
plt.savefig(output_dir / "gex_segmentation_overlay_normalized.png", dpi=300, bbox_inches='tight', pad_inches=0)
plt.close()
print("Exported normalized GEX overlay")


In [None]:
# 6. Merge H&E and GEX labels
b2c.salvage_secondary_labels(
    adata,
    primary_label="labels_he_expanded",
    secondary_label="labels_gex",
    labels_key="labels_joint"
)
print("Salvaged H&E + GEX labels into labels_joint")


In [None]:

# 7. Plot joint labels
bdata = adata[mask].copy()
bdata = bdata[bdata.obs["labels_joint"] > 0]
bdata.obs["labels_joint"] = bdata.obs["labels_joint"].astype(str)

sc.pl.spatial(
    bdata,
    color=[None, "labels_joint_source", "labels_joint"],
    img_key=f"{mpp}_mpp_150_buffer",
    basis="spatial_cropped_150_buffer",
    show=False
)
plt.savefig(output_dir / "labels_joint_overlay.png", dpi=300, bbox_inches='tight', pad_inches=0)
plt.close()
print("Exported joint label overlay")


In [None]:

# 8. Convert bins to cells
cdata = b2c.bin_to_cell(
    adata,
    labels_key="labels_joint",
    spatial_keys=["spatial", "spatial_cropped_150_buffer"]
)
sc.pl.spatial(
    cdata,
    color=["bin_count"],
    basis="spatial_cropped_150_buffer",
    img_key=f"{mpp}_mpp_150_buffer",
    show=False
)
plt.savefig(output_dir / "spatial_cell_density.pdf")
plt.close()
print("Exported spatial cell density plot")



In [None]:
# Export statistics
stats = {
    "num_bins": adata.shape[0],
    "num_genes": adata.shape[1],
    "total_umis": np.sum(adata.X),
    "avg_umis_per_bin": np.mean(adata.X.sum(axis=1)),
    "num_cells": len(np.unique(adata.obs["labels_joint"])),
    "avg_bin_per_cell": adata.shape[0] / len(np.unique(adata.obs["labels_joint"]))
}

stats