In [22]:
import os
import torch
import numpy as np
import scanpy as sc
import rootutils
from torch_geometric.data import Data
from sklearn.neighbors import NearestNeighbors
rootutils.setup_root(os.getcwd(), indicator=".project-root", pythonpath=True)
from src.data.spatial_omics_datamodule import SpatialOmicsDataModule
from src.utils.preprocess_helpers import (
    read_samples_into_dict,
    save_sample,
    preprocess_sample,
    euclid_dist,
    create_graph,
    SpatialOmicsDataset
)

In [2]:
def print_sample(sample_name, adata):
    """
    Visualizes spatial embeddings of a single AnnData object and prints its details.

    Parameters:
    ----------
    sample_name : str
        The name of the sample.
    adata : AnnData
        The AnnData object containing spatial and gene expression data.

    Behavior:
    --------
    - Prints the sample name and the corresponding AnnData object.
    - Determines the color key (`domain`, `region`, or `layer`) based on the sample name prefix.
    - Plots the spatial embedding of the AnnData object using `scanpy.pl.embedding`.

    Notes:
    -----
    - The `domain` is used for samples starting with "MERFISH".
    - The `region` is used for samples starting with "STARmap".
    - The `layer` is used for samples starting with "BaristaSeq".
    - The spatial embedding is visualized using the `spatial` basis.
    """
    if sample_name.startswith("MERFISH"):
        domain = "domain"
    elif sample_name.startswith("STARmap"):
        domain = "region"
    elif sample_name.startswith("BaristaSeq"):
        domain = "layer"
    
    print(f"Sample: {sample_name}")
    print(adata)
    sc.pl.embedding(adata, basis="spatial", color=domain)

In [3]:
# Single source of truth for the samples used in this notebook. The raw and
# processed path lists are both derived from it, so they cannot drift apart.
DOMAIN_DATA_DIR = "../data/domain"
SAMPLE_NAMES = [
    "MERFISH_small1",
    "MERFISH_small2",
    "MERFISH_small3",
    "MERFISH_small4",
    "MERFISH_small5",
    "STARmap1",
    "STARmap2",
    "STARmap3",
    "STARmap4",
    "BaristaSeq1",
    "BaristaSeq2",
    "BaristaSeq3",
]

raw_file_paths = [f"{DOMAIN_DATA_DIR}/raw/{name}.h5ad" for name in SAMPLE_NAMES]
processed_file_paths = [f"{DOMAIN_DATA_DIR}/processed/{name}.h5ad" for name in SAMPLE_NAMES]


raw_samples = read_samples_into_dict(raw_file_paths)
print(len(raw_samples))

processed_samples = read_samples_into_dict(processed_file_paths)
print(len(processed_samples))

# SECURITY: weights_only=False makes torch.load unpickle arbitrary objects, which can
# execute code. Presumably this is needed because the file stores torch_geometric `Data`
# objects. Only load dataset files produced by this project.
graphs = torch.load(f"{DOMAIN_DATA_DIR}/processed/dataset.pt", weights_only=False)
print(len(graphs))

# Every sample should exist in raw form, in processed form, and as a graph.
assert len(raw_samples) == len(processed_samples) == len(graphs) == len(SAMPLE_NAMES), (
    "sample counts differ between raw files, processed files and the graph dataset"
)

12
12
12


In [30]:
# Print one summary line per graph, showing its stored tensors and its sample name.
for graph_data in graphs:
    print(f"{graph_data}")

Data(x=[1512, 50], edge_index=[2, 50170], edge_weight=[50170], sample_name='BaristaSeq1')
Data(x=[5488, 50], edge_index=[2, 178606], edge_weight=[178606], sample_name='MERFISH_small1')
Data(x=[1088, 50], edge_index=[2, 36254], edge_weight=[36254], sample_name='STARmap4')
Data(x=[1049, 50], edge_index=[2, 34960], edge_weight=[34960], sample_name='STARmap2')
Data(x=[1053, 50], edge_index=[2, 35386], edge_weight=[35386], sample_name='STARmap3')
Data(x=[5543, 50], edge_index=[2, 180856], edge_weight=[180856], sample_name='MERFISH_small5')
Data(x=[5803, 50], edge_index=[2, 188964], edge_weight=[188964], sample_name='MERFISH_small4')
Data(x=[1207, 50], edge_index=[2, 40370], edge_weight=[40370], sample_name='STARmap1')
Data(x=[1627, 50], edge_index=[2, 53984], edge_weight=[53984], sample_name='BaristaSeq3')
Data(x=[5926, 50], edge_index=[2, 193204], edge_weight=[193204], sample_name='MERFISH_small3')
Data(x=[5557, 50], edge_index=[2, 182046], edge_weight=[182046], sample_name='MERFISH_small2