# Create datasets based on full Xenium data

Use all >400k cells of the Xenium ovarian cancer dataset to compute SCimilarity embeddings from both the transcriptome and the ViT-MAE image features.

In [None]:
import numpy as np
import anndata as ad
import scanpy as sc
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from sklearn.metrics import r2_score
import torch
from torch.utils.data import random_split, DataLoader

from scimilarity.utils import align_dataset, lognorm_counts
from scimilarity.cell_embedding import CellEmbedding
from scimilarity import CellAnnotation
from scimilarity import CellQuery

from tqdm.auto import tqdm

import sys
# Make the local MLP architecture module importable. The path is relative, so it is
# resolved against the kernel's working directory (run the notebook from its own folder).
sys.path.append("../data_processing/SCimilarity/training_mlp/model_archs")
import mlp_img_embed_to_scimilarity

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
scimilarity_model_path = '../SCimilarity/models/'
out_dir = '../../input_data/'
# TODO: fill in both paths before running (they were left blank, which is a SyntaxError).
# They are used as plain string prefixes (f'{data_dir}...', f'{sdata_path}tables/table'),
# so keep the trailing '/'.
data_dir = ''    # directory containing Xenium_Prime_Ovarian_Cancer_FFPE_XRrun_cell_groups.csv
sdata_path = ''  # presumably a SpatialData .zarr store; its AnnData table is read from tables/table

In [4]:
# Initialize SCimilarity: annotator and embedder both load the same v1.1 model directory
scimilarity_model_dir = f"{scimilarity_model_path}model_v1.1"
ca = CellAnnotation(model_path=scimilarity_model_dir)
ce = CellEmbedding(model_path=scimilarity_model_dir)

In [5]:
# Labels used to safelist SCimilarity predictions and to flag matching 10X cell groups
ovary_cell_types = [
    'B Cell',
    'T Cell',
    'macrophage',
    'monocyte',
    'epithelial cell',
    'fibroblast',
    'endothelial cell',
    'dendritic cell',
    'stromal cell of ovary',
    'leukocyte',
    'mast cell',
    'myofibroblast cell',
    'smooth muscle cell',
]

In [6]:
# Restrict the annotator to the ovary-relevant labels
# (presumably limits get_predictions_knn outputs to this list — confirm in SCimilarity docs)
ca.safelist_celltypes(ovary_cell_types)

## Load Transcriptome

In [7]:
# The AnnData table is stored inside the zarr store under tables/table
table_zarr_path = f'{sdata_path}tables/table'
full_dataset = ad.read_zarr(table_zarr_path)



In [8]:
# Display the loaded table: dimensions and available obs/var/obsm fields
full_dataset

AnnData object with n_obs × n_vars = 407124 × 5101
    obs: 'cell_id', 'transcript_counts', 'control_probe_counts', 'genomic_control_counts', 'control_codeword_counts', 'unassigned_codeword_counts', 'deprecated_codeword_counts', 'total_counts', 'cell_area', 'nucleus_area', 'nucleus_count', 'segmentation_method', 'region', 'z_level', 'cell_labels', 'cell_type', 'is_tumor_cell_type'
    var: 'gene_ids', 'feature_types', 'genome'
    uns: 'spatialdata_attrs'
    obsm: 'spatial'

## Filter Transcriptome

In [18]:
# QC runs on a copy so full_dataset keeps every cell and gene; `raw` holds the survivors.
raw = full_dataset.copy()
# Each threshold gets its own call, applied in this order; the cell filters therefore
# operate on the already gene-filtered matrix.
for gene_threshold in ({'min_counts': 2000}, {'min_cells': 1000}):
    sc.pp.filter_genes(raw, **gene_threshold)
for cell_threshold in ({'min_counts': 200}, {'min_genes': 100}):
    sc.pp.filter_cells(raw, **cell_threshold)



In [19]:
# Record on the unfiltered data which cells and genes survived the QC applied to `raw`
qc_cell_ids = raw.obs['cell_id']
qc_gene_ids = raw.var['gene_ids']
full_dataset.obs['transcriptome_passed_QC'] = full_dataset.obs['cell_id'].isin(qc_cell_ids)
full_dataset.var['gene_passed_QC'] = full_dataset.var['gene_ids'].isin(qc_gene_ids)

## Load 10X cell type annotation

In [11]:
# 10X cell-group export; 'cell_id' is renamed so it does not clash with obs['cell_id'] when joined
cell_types = (
    pd.read_csv(f'{data_dir}Xenium_Prime_Ovarian_Cancer_FFPE_XRrun_cell_groups.csv')
    .rename(columns={'group': '10X_cell_type', 'cell_id': 'cell_id_tmp'})
)

In [12]:
# Attach the 10X cell-group columns to obs, matched on cell_id.
# DataFrame.merge returns a fresh RangeIndex, which would silently replace obs_names when
# assigned back. join(on=...) keeps the caller's index and row order instead.
annotation = cell_types.set_index('cell_id_tmp')
if not annotation.index.is_unique:
    # A duplicated id would yield several rows per cell and break the one-row-per-cell obs.
    raise ValueError("cell group CSV contains duplicate cell_id entries; cannot assign one label per cell")
merged_obs = full_dataset.obs.join(annotation, on='cell_id', how='left')
full_dataset.obs = merged_obs

In [21]:
# NOTE(review): compares 10X group names against the SCimilarity-style labels in
# ovary_cell_types (e.g. 'stromal cell of ovary'); confirm the two vocabularies overlap.
ovary_label_mask = full_dataset.obs['10X_cell_type'].isin(ovary_cell_types)
full_dataset.obs['is_ovary_cell_type'] = ovary_label_mask

## Load ViT-MAE Xenium finetuned image embeddings

In [None]:
# NOTE(review): these live under '../input_data/' while out_dir is '../../input_data/';
# confirm both relative paths resolve to the intended directories.
vitmae_features_dir = '../input_data/ViTMAE_training_ovarian_cancer/'
image_dir_train, image_dir_val, image_dir_test = (
    f'{vitmae_features_dir}xenium_vitmae_{split}_features.h5ad'
    for split in ('train', 'val', 'test')
)

In [23]:
# Stack the train/val/test feature tables. ad.concat warns about duplicate obs names;
# cells are matched later via obs['cell_ids'], not obs_names.
image_embeds = ad.concat(
    [ad.read_h5ad(split_path) for split_path in (image_dir_train, image_dir_val, image_dir_test)]
)

  utils.warn_names_duplicates("obs")


In [24]:
# Display the concatenated image-embedding table
image_embeds

AnnData object with n_obs × n_vars = 406875 × 768
    obs: 'targets', 'cell_ids', 'leiden_res0_25'
    obsm: 'X_pca', 'X_umap'

In [25]:
# Re-display the transcriptome table with the QC and cell-type columns added above
full_dataset

AnnData object with n_obs × n_vars = 407124 × 5101
    obs: 'cell_id', 'transcript_counts', 'control_probe_counts', 'genomic_control_counts', 'control_codeword_counts', 'unassigned_codeword_counts', 'deprecated_codeword_counts', 'total_counts', 'cell_area', 'nucleus_area', 'nucleus_count', 'segmentation_method', 'region', 'z_level', 'cell_labels', 'cell_type', 'is_tumor_cell_type', '10X_cell_type', 'transcriptome_passed_QC', 'is_ovary_cell_type'
    var: 'gene_ids', 'feature_types', 'genome', 'gene_passed_QC'
    uns: 'spatialdata_attrs'
    obsm: 'spatial'

In [26]:
# Cell identifiers of the image embeddings, in image_embeds row order
image_cell_ids = image_embeds.obs.loc[:, 'cell_ids']

In [27]:
## Some cells are missing from the image feature files, likely edge cells that could not be extracted
transcriptome_cell_ids = full_dataset.obs['cell_labels']
intersection = set(transcriptome_cell_ids) & set(image_cell_ids)
n_transcriptome_cells = len(transcriptome_cell_ids)
intersection_count = len(intersection)
intersection_ratio = intersection_count / n_transcriptome_cells
intersection_diff = intersection_count - n_transcriptome_cells  # always <= 0

for label, value in (
    ("Intersection count:", intersection_count),
    ("Intersection ratio:", intersection_ratio),
    ("Intersection difference:", intersection_diff),
):
    print(label, value)

common_cell_ids = list(intersection)

Intersection count: 406875
Intersection ratio: 0.9993883927255578
Intersection difference: -249


In [None]:
# The set intersection is unordered, so rebuild the shared ids in image_embeds row order
common_cell_ids_set = set(common_cell_ids)
subset_cell_ids = list(filter(common_cell_ids_set.__contains__, image_cell_ids))

In [None]:
# Keep only transcriptome cells that have ViT-MAE features (the image data lack a few, likely edge, cells)
has_image_features = full_dataset.obs['cell_labels'].isin(common_cell_ids)
full_dataset = full_dataset[has_image_features].copy()



In [30]:
# Transcriptome table after restricting to cells that have image features
full_dataset

AnnData object with n_obs × n_vars = 406875 × 5101
    obs: 'cell_id', 'transcript_counts', 'control_probe_counts', 'genomic_control_counts', 'control_codeword_counts', 'unassigned_codeword_counts', 'deprecated_codeword_counts', 'total_counts', 'cell_area', 'nucleus_area', 'nucleus_count', 'segmentation_method', 'region', 'z_level', 'cell_labels', 'cell_type', 'is_tumor_cell_type', '10X_cell_type', 'transcriptome_passed_QC', 'is_ovary_cell_type'
    var: 'gene_ids', 'feature_types', 'genome', 'gene_passed_QC'
    uns: 'spatialdata_attrs'
    obsm: 'spatial'

In [31]:
# Map each image cell id to its row in image_embeds (the rows image_embeds.X is indexed by).
# Enumerating image_cell_ids itself, not the filtered subset_cell_ids, keeps the positions
# valid even when the image files contain cells that are absent from the transcriptome table.
index_map = {cell_id: i for i, cell_id in enumerate(image_cell_ids)}
if len(index_map) != len(image_cell_ids):
    raise ValueError("image_embeds.obs['cell_ids'] contains duplicates; row lookup would be ambiguous")
# Row in image_embeds for every cell of full_dataset, in full_dataset order
adata_indices = [index_map[cell_id] for cell_id in full_dataset.obs["cell_labels"]]

In [32]:
# Reorder image embeddings so row i belongs to the i-th cell of full_dataset
img_rows_for_cells = image_embeds.X[adata_indices]
full_dataset.obsm['X_vitmae_finetuned_img_features'] = img_rows_for_cells

### Scale embeddings to [0,1] per row

In [34]:
# Min-max scale each cell's ViT-MAE feature vector to [0, 1].
# row_mins / row_maxs (shape (n_cells, 1)) stay in scope: the prediction step reuses them
# to undo this scaling.
vitmae_feats = full_dataset.obsm['X_vitmae_finetuned_img_features']
row_mins = vitmae_feats.min(axis=1, keepdims=True)
row_maxs = vitmae_feats.max(axis=1, keepdims=True)
# Constant rows would divide 0 by 0 (NaN); give them a unit range so they scale to 0 instead.
row_ranges = np.where(row_maxs > row_mins, row_maxs - row_mins, 1)
scaled_img_features = (vitmae_feats - row_mins) / row_ranges
full_dataset.obsm['X_vitmae_finetuned_img_features_minmax'] = scaled_img_features

### Get ViT-MAE test set

In [35]:
# Reload the test split on its own to recover which cells it contains
test_set = ad.read_h5ad(filename=image_dir_test)

In [36]:
# Mark cells that belong to the ViT-MAE test split
vitmae_test_ids = test_set.obs['cell_ids']
full_dataset.obs['is_in_vitmae_test_set'] = full_dataset.obs['cell_labels'].isin(vitmae_test_ids)

## Get SCimilarity for Transcriptome

Because these are probe-based spatial transcriptomics counts rather than stochastically sampled counts as in dissociated single-cell transcriptomics, we only normalize each cell to 10,000 total counts and do not apply log1p.

In [48]:
# Normalise a copy so full_dataset keeps its raw counts
norm = full_dataset.copy()
# Total-count normalisation to 1e4 per cell only; deliberately no log1p for probe-based counts.
# NOTE(review): lognorm_counts is imported but unused; confirm skipping log1p matches the
# input the SCimilarity model expects.
sc.pp.normalize_total(norm, target_sum=1e4)
# Reorder genes to the model's gene order, requiring an overlap threshold of 1000 genes
aligned = align_dataset(
    norm,
    ce.gene_order,
    gene_overlap_threshold=1000,
)
embeddings = ce.get_embeddings(aligned.X)

In [49]:
# Keep the transcriptome-derived SCimilarity embeddings next to the cells
transcriptome_embed_key = 'X_SCimilarity_transcriptome_embeds'
full_dataset.obsm[transcriptome_embed_key] = embeddings

### Get SCimilarity cell types from transcriptome embeddings

In [52]:
predictions, nn_idxs, nn_dists, stats = ca.get_predictions_knn(
    full_dataset.obsm['X_SCimilarity_transcriptome_embeds'], weighting=True
)
# to_numpy() drops the returned index, so values are assigned to obs by position
for obs_column, values in (
    ('SCimilarity_transcriptome_cell_type', predictions),
    ('SCimilarity_transcriptome_min_dist', stats['min_dist']),
):
    full_dataset.obs[obs_column] = values.to_numpy()

Get nearest neighbors finished in: 1.5862017671267192 min


100%|██████████| 406875/406875 [00:27<00:00, 14690.24it/s]


## Save for Training ViT-MAE-to-SCimilarity prediction

In [56]:
# Snapshot used as training input for the image-to-SCimilarity MLP
vitmae_feats_path = f'{out_dir}xenium_ovarian_cancer_vitmae_feats.h5ad'
full_dataset.write_h5ad(vitmae_feats_path)

## Predict SCimilarity from ViT-MAE image embeddings

The MLP checkpoint below was trained on all cells, with transcriptomes normalized to 10k counts per cell.

In [None]:
epoch_to_load = 2342
# NOTE(review): absolute cluster path; consider making it configurable.
scalers_dir = (
    "/fs/gpfs41/lv03/fileset01/pool/pool-mann-maedler-shared/niklas_workspace/scPortrait_vit_250726/"
    "SCimilarity/training_mlp/checkpoints/predict_scim_v1_1/scaled_norm10k/all_cells/"
)
# The checkpoint sits in the same directory as the scaler arrays
ckpt_path = f"{scalers_dir}ColNorm-epoch={epoch_to_load}-end.ckpt"

In [61]:
device = 'cuda:1'  # NOTE(review): hard-coded GPU index
model = mlp_img_embed_to_scimilarity.get_mlp(output_size=128)
# weights_only=False unpickles arbitrary objects: only load checkpoints from a trusted source.
# map_location='cpu' avoids restoring tensors onto whatever GPU the checkpoint was saved from.
checkpoint = torch.load(ckpt_path, map_location='cpu', weights_only=False)
# Parameter names carry a 'model.' prefix (presumably from the training wrapper). Strip only
# that prefix instead of blindly dropping the first 6 characters of every key.
prefix = 'model.'
pruned_state_dict = {
    (key[len(prefix):] if key.startswith(prefix) else key): value
    for key, value in checkpoint['state_dict'].items()
}
model.load_state_dict(pruned_state_dict)
model.eval()
model.to(device)

Sequential(
  (0): Linear(in_features=768, out_features=512, bias=True)
  (1): ReLU()
  (2): Linear(in_features=512, out_features=512, bias=True)
  (3): ReLU()
  (4): Linear(in_features=512, out_features=256, bias=True)
  (5): ReLU()
  (6): Linear(in_features=256, out_features=128, bias=True)
  (7): ReLU()
  (8): Linear(in_features=128, out_features=128, bias=True)
)

In [62]:
# Scaler statistics stored next to the checkpoint; used to standardise and un-standardise SCimilarity embeddings
means, stds = (np.load(f"{scalers_dir}{stat}.npy") for stat in ("means", "stds"))

In [63]:
ids = full_dataset.obs['cell_labels']
# Inputs: row-min-max-scaled image features; targets: SCimilarity embeddings
# standardised with the training means/stds.
img_features = full_dataset.obsm['X_vitmae_finetuned_img_features_minmax']
scim_embeds, _, _ = mlp_img_embed_to_scimilarity.scale(
    full_dataset.obsm['X_SCimilarity_transcriptome_embeds'], means, stds
)
dataset = mlp_img_embed_to_scimilarity.Embeddings_transcripts_dataset_ids(img_features, scim_embeds, ids)

In [None]:
torch.manual_seed(920924)
num_samples = len(dataset)
# (train, val, test) = (0, 0, 1) sums to 1, so random_split treats the values as fractions
# and places every cell in test_dataset, in a seeded random order.
train_size, val_size, test_size = 0, 0, 1
train_dataset, val_dataset, test_dataset = random_split(
    dataset, [train_size, val_size, test_size]
)
test_loader = DataLoader(test_dataset, batch_size=2048, shuffle=False, num_workers=0)



In [65]:
# Run the MLP over every cell, collecting inputs, predictions, targets and ids in loader order
img_reprs = []
preds = []
origs = []
ids = []
with torch.no_grad():
    for batch in tqdm(test_loader):
        img_feats = batch[0].to(device)
        img_reprs.append(img_feats.detach().cpu().numpy())
        pred = model(img_feats)
        preds.append(pred.detach().cpu().numpy())
        origs.append(batch[1].detach().cpu().numpy())
        ids.append(batch[2].detach().cpu().numpy())

# random_split shuffled the cells with torch's RNG. Because test_loader uses shuffle=False,
# the loader order is exactly the Subset's `indices`, so use those to realign the per-cell
# min/max. A numpy permutation with the same seed does NOT reproduce torch's permutation.
shuffled_indices = np.asarray(test_dataset.indices)
row_mins_shuffled = row_mins[shuffled_indices]
row_maxs_shuffled = row_maxs[shuffled_indices]

img_reprs_scaled = np.concatenate(img_reprs)
img_reprs = mlp_img_embed_to_scimilarity.undo_row_norm(
    np.concatenate(img_reprs), row_mins_shuffled, row_maxs_shuffled
)
# Targets and predictions back on the original (unstandardised) SCimilarity scale
origs = mlp_img_embed_to_scimilarity.undo_standard_scaling(
    np.concatenate(origs),
    means=means,
    stds=stds,
)
preds = mlp_img_embed_to_scimilarity.undo_standard_scaling(
    np.concatenate(preds),
    means=means,
    stds=stds,
)
ids = np.concatenate(ids)

  return self.embeddings[idx], self.gene_expressions[idx], self.ids[idx]
100%|██████████| 199/199 [00:08<00:00, 24.38it/s]


In [66]:
# Position of each cell id within the loader-ordered prediction arrays
index_map = dict(zip(ids, range(len(ids))))
# ...looked up for every cell of full_dataset, in full_dataset order
adata_indices = list(map(index_map.__getitem__, full_dataset.obs["cell_labels"]))

In [67]:
# preds follows loader order; adata_indices reorders it to full_dataset's cell order
image_scim_embeds = preds[adata_indices]
full_dataset.obsm['X_SCimilarity_image_embeds'] = image_scim_embeds

### Get SCimilarity cell types from image-predicted embeddings

In [68]:
predictions, nn_idxs, nn_dists, stats = ca.get_predictions_knn(
    full_dataset.obsm['X_SCimilarity_image_embeds'], weighting=True
)
# to_numpy() drops the returned index, so values are assigned to obs by position
for obs_column, values in (
    ('SCimilarity_image_cell_type', predictions),
    ('SCimilarity_image_min_dist', stats['min_dist']),
):
    full_dataset.obs[obs_column] = values.to_numpy()

Get nearest neighbors finished in: 0.058790818850199385 min


100%|██████████| 406875/406875 [00:25<00:00, 16105.43it/s]


## Save Dataset

In [69]:
# Final dataset: transcriptome and image-predicted SCimilarity embeddings plus cell types
full_output_path = f'{out_dir}xenium_ovarian_cancer_full.h5ad'
full_dataset.write_h5ad(full_output_path)