In [17]:
import sys
import os
import anndata as ad

sys.path.append(os.path.abspath(os.path.join("..", "tissuenarrator")))

from utils import benchmark_expression_conversion, reconstruct_expression_from_cell_sentence
from preprocess import generate_vocabulary, build_spatial_df, build_cell_df, generate_cell_sentences, generate_vocabulary


### Step 1: Read raw data

Make sure the AnnData object is **log₁₀-transformed** (using `log1p` with base 10).  
The `.obs` table should include the following required columns:

- `x` — spatial coordinate (x-axis)  
- `y` — spatial coordinate (y-axis)  
- `section` — section or sample identifier  

You can also include any additional metadata columns you want to preserve for downstream analysis.


In [15]:
adata = ad.read_h5ad("/home/sizheliu/spatial-text/data/merfish/merfish_preprocessed.h5ad")

### Step 2: Fit linear model for reconstruction

Use Cell2Sentence’s linear model fitting method to prepare for future reconstruction of cells. 

In [4]:
import numpy as np
import scipy.sparse as sp
from tqdm import tqdm

N = 10000
adata_sub = adata[adata.obs["split"] == "train"].copy()
np.random.seed(42)
idx = np.random.choice(adata_sub.n_obs, size=N, replace=False)
adata_sub = adata_sub[idx].copy()
adata_sub.X = adata_sub.X.astype(np.float32)

top_k_gene_count = 100

for cell_idx in tqdm(range(adata_sub.X.shape[0])):
    ind = np.argpartition(adata_sub.X[cell_idx], -top_k_gene_count)[-top_k_gene_count:]
    ind.sort()
    all_but_ind = np.setdiff1d(np.array(list(range(adata_sub.X.shape[1])), dtype=np.int64), ind)
    adata_sub.X[cell_idx, all_but_ind] = 0

adata_sub.X = sp.csr_matrix(adata_sub.X)
transformation_benchmarking_save_name = "inverse_transformation_top100"
output_path = "./"
benchmark_expression_conversion(
    benchmark_output_dir=output_path,
    save_name=transformation_benchmarking_save_name,
    normalized_expression_matrix=adata_sub.X,
    sample_size=N,
)

100%|██████████| 10000/10000 [00:01<00:00, 8431.13it/s]


Benchmarking with a sample dataset of size 10000




### Step 3: Construct cell sentences

Build **cell sentences** from the processed data and store them in a `cell_df` DataFrame.


In [7]:
label_col_names = ["x", "y", "section", "class", "subclass", "split", "spatial_domain"]
vocabulary = generate_vocabulary(adata)
sentences = generate_cell_sentences(adata, vocabulary, delimiter=' ')
cell_names = adata.obs_names.tolist()
cell_df = build_cell_df(
    cell_names=cell_names, 
    sentences=sentences,
    adata=adata,
    label_col_names=label_col_names
)

100%|██████████| 2616328/2616328 [04:31<00:00, 9635.16it/s] 


In [16]:
print(cell_df.iloc[0]["sentence"])

GAD2 KIT DLX1AS NDNF DDR2 RGS5 SV2C GLIS3 PCP4L1 PROX1 ZMAT4 HCRTR2 EGFR FAM107A PLD5 NXPH1 KCNS3 MAFB KCNAB3 GABBR2 GRIK1 TOX2 CNR1 NFIX MAF CALN1 POU3F3 LAMP5


In [None]:
# cell_df.to_parquet("merfish_cell_sentences.parquet", index=False)

### Step 4: Construct spatial sentences

In [12]:
spatial_df = build_spatial_df(
    cell_df=cell_df,
    bin_width=200,              # group cells within 200×200 regions
    traversal_methods=["nn"],
    truncate=50,
    meta=["class"],             # include cell type in <meta>
    n_repeats=2,                # two traversal augmentations per bin
    random_state=42
)

Processing sections: 100%|██████████| 129/129 [08:14<00:00,  3.83s/it]


In [None]:
# spatial_df.to_parquet("merfish_spatial_sentences.parquet", index=False)

In [14]:
print(spatial_df.iloc[0]["sentence"])

<pos> X: 3773, Y: 3999 <meta> class: OEC <cs> ACSBG1 CLDN5 FBLN2 SOX10 GRIK3 GJA1 NR2F1 IGFBP4 PLCE1 FNDC1 SNTB1 COL18A1 AGT FMO1 TNS1 ZEB2 COL11A1 NFIB GDA TSPAN18 LRP4 RPRM CCND2 ISYNA1 FRAS1 ZIC2 EPB41L4A ABCA8A SLC9A3R1 YIF1B SPON1 COL27A1 ZFHX4 ARHGEF26 CPNE2 KIRREL SYNPO2 GFAP ALDH1L1 NR2F2 BMP6 SEMA3D ADAM12 RXRG FAT1 B9D2 ZFP521 INHBA NTRK3 MAF </cs> <pos> X: 3787, Y: 3980 <meta> class: OEC <cs> AGT GJA1 BCL6 AQP1 ACSBG1 NR2F2 FNDC1 SERPINE2 COL11A1 MYBPC1 FMO1 TSPAN18 FBLN2 IGFBP4 GRIK3 SOX10 SNTB1 TEAD1 KIRREL ROR1 ZFHX4 COL23A1 SLC9A3R1 NFIB DCDC2A MAGED2 MET ZFP521 LRATD2 RNH1 EPHA7 STXBP6 CCDC80 ZFP536 MAFA IQGAP2 SLC25A13 NUDT14 CLSPN BMPR1B ADRA2A ALDH1L1 HOMER3 BMP2 FAT1 TCF7L2 PLSCR4 LHX9 RXRG ST3GAL1 </cs> <pos> X: 3774, Y: 3945 <meta> class: Vascular <cs> ZIC2 GJA1 FN1 IGF2 ITIH5 SERPINF1 DCN RANBP3L ISYNA1 ZIC1 SLC7A11 FMO1 IGFBP4 NBL1 NR2F1 BMP6 EYA1 EYA2 CCN3 IGFBP6 NR2F2 ACSBG1 MAF TSPAN18 LAMA1 AQP4 NTRK3 EPHA8 GFAP SOX10 CLDN11 GPRC5B UCP2 AQP1 LPAR1 NFIB SERPI