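# Generate train, validation and test splits of the H5SC single-cell image dataset for ViTMAE and SCimilarity training

All cells inside a selected spatial region are placed in the test set. The test set is then topped up with randomly sampled cells until it holds 5% of the dataset. A further 5% of cells are randomly assigned to validation, and the remaining cells form the training set. For each split, the cell IDs and the corresponding H5SC image indexes are written to CSV files.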

In [None]:
import spatialdata
import scportrait
import os
import numpy as np
import csv
import pandas as pd

import random

# One fixed seed drives every random draw below, so re-running the notebook
# reproduces the same split.
seed = 42
rng = random.Random(seed)

  from pkg_resources import DistributionNotFound, get_distribution


In [2]:
# input locations: everything produced by the scPortrait xenium project lives under one directory
project_dir = "../processed_data/scPortrait_project_xenium"
path_sdata = f"{project_dir}/scportrait.sdata"
h5sc_path = f"{project_dir}/extraction/data/single_cells.h5sc"

# output location for the split CSV files (created if missing)
output_folder = "../processed_data/test_val_datasets"
os.makedirs(output_folder, exist_ok=True)

## Read SpatialData object generated by scPortrait

In [3]:
# load the full scPortrait SpatialData object; it supplies the `table` annotation
# and the `seg_all_cytosol` segmentation used further below
sdata = spatialdata.read_zarr(path_sdata)

version mismatch: detected: RasterFormatV02, requested: FormatV04
version mismatch: detected: RasterFormatV02, requested: FormatV04
version mismatch: detected: RasterFormatV02, requested: FormatV04
version mismatch: detected: RasterFormatV02, requested: FormatV04


## Read single-cell images

In [4]:
# load the extracted single-cell image collection; its `.obs` carries `scportrait_cell_id`
h5sc = scportrait.io.read_h5sc(h5sc_path)

## Get all cell-ids from dataset

In [5]:
# Cells too close to the image edges may never have been extracted to single-cell
# images, so only ids present in BOTH the annotation table and the h5sc file are usable.
table_cell_ids = set(sdata["table"].obs["scportrait_cell_id"].tolist())
extracted_cell_ids = set(h5sc.obs.scportrait_cell_id.tolist())
all_cell_ids = table_cell_ids & extracted_cell_ids

## Get cell-ids in selected test region

In [6]:
# centre coordinates of the held-out test region, and the size of the crop around it
select_region = (40300, 21650)
max_width = 1500

In [7]:
# crop the full SpatialData object to the selected region and write the crop to disk
from scportrait.tl.sdata.pp import get_bounding_box_sdata

overview_path = "../processed_data/overview_region.sdata.zarr"
_sdata = get_bounding_box_sdata(sdata, max_width, *select_region)
_sdata.write(overview_path, overwrite=True)

[34mINFO    [0m The SpatialData object is not self-contained [1m([0mi.e. it contains some elements that are Dask-backed from    
         locations outside ..[35m/processed_data/[0m[95moverview_region.sdata.zarr[0m[1m)[0m. Please see the documentation of          
         `[1;35mis_self_contained[0m[1m([0m[1m)[0m` to understand the implications of working with SpatialData objects that are not     
         self-contained.                                                                                           
[34mINFO    [0m The Zarr backing store has been changed from [3;35mNone[0m the new file path:                                      
         ..[35m/processed_data/[0m[95moverview_region.sdata.zarr[0m                                                              


In [8]:
# Ids of every cell with at least one pixel in the cropped segmentation mask.
# Label 0 is background; filter it explicitly instead of dropping the first unique
# value, which would discard a real cell if the crop contained no background pixel.
region_mask = _sdata["seg_all_cytosol"].scale0.image.compute().values
region_labels = np.unique(region_mask)
# .tolist() yields plain Python ints, matching the element type of all_cell_ids
region_cell_ids = set(region_labels[region_labels != 0].tolist())

## Generate Test, Val and Train sets

In [9]:
# Fraction of the usable cells held out for validation and test (each rounded up);
# everything else becomes training data.
val_percentage = 0.05
test_percentage = 0.05
# np.ceil returns a float; cast so the sizes are usable directly as counts
val_size = int(np.ceil(len(all_cell_ids) * val_percentage))
test_size = int(np.ceil(len(all_cell_ids) * test_percentage))
train_size = len(all_cell_ids) - val_size - test_size

print("val size:", val_size)
print("test_size:", test_size)
print("train_size:", train_size)

val size: 20344.0
test_size: 20344.0
train_size: 366187.0


In [10]:
# --- test set ---
# All cells of the selected region go into the test set, topped up with randomly drawn
# cells until it reaches test_size. Only cells that were actually extracted to
# single-cell images are eligible; the segmentation mask may contain others.
region_test_ids = region_cell_ids & all_cell_ids
n_test_random = int(test_size) - len(region_test_ids)
if n_test_random < 0:
    raise ValueError(
        f"Selected region contains {len(region_test_ids)} cells, more than the requested "
        f"test size of {int(test_size)}; shrink the region or increase test_percentage."
    )
# Sort before sampling so the drawn cells depend only on the seed, not on set iteration order.
remaining_cell_ids = all_cell_ids - region_test_ids
test_set_remaining = set(rng.sample(sorted(remaining_cell_ids), n_test_random))
test_set = region_test_ids | test_set_remaining

# --- validation set: random draw from the cells not already in the test set ---
remaining_cell_ids = all_cell_ids - test_set
val_set = set(rng.sample(sorted(remaining_cell_ids), int(val_size)))

# --- train set: everything else ---
train_set = all_cell_ids - test_set - val_set

# sanity checks: disjoint splits that together cover every usable cell, with the requested sizes
assert test_set.isdisjoint(val_set) and test_set.isdisjoint(train_set) and val_set.isdisjoint(train_set)
assert test_set | val_set | train_set == all_cell_ids
assert len(test_set) == int(test_size) and len(val_set) == int(val_size)

In [11]:
# translate each split's cell ids into image index locations within the h5sc file
test_set_indexes, val_set_indexes, train_set_indexes = (
    scportrait.tl.h5sc.get_image_index(h5sc, split_ids)
    for split_ids in (test_set, val_set, train_set)
)

In [12]:
# write out to csv files
def write_set_to_csv(cell_ids, filename):
    """Write every element of ``cell_ids`` to ``filename`` as a one-column CSV row.

    The file is overwritten. Rows appear in the iteration order of ``cell_ids``.
    """
    with open(filename, 'w', newline='') as handle:
        csv.writer(handle).writerows([value] for value in cell_ids)

# Write each set to its own CSV file: first the cell ids of every split, then their h5sc indexes
split_names = ("train", "val", "test")
for file_kind, split_collections in (
    ("cell_ids", (train_set, val_set, test_set)),
    ("indexes", (train_set_indexes, val_set_indexes, test_set_indexes)),
):
    for split_name, split_values in zip(split_names, split_collections):
        write_set_to_csv(split_values, f'{output_folder}/{split_name}_set_{file_kind}.csv')

## Check that `H5ScSingleCellDataset` loads the validation split from the saved index list

In [13]:
# reload the validation indexes from disk and build a dataset restricted to them
from scportrait.tools.ml.datasets import H5ScSingleCellDataset

val_index_file = f'{output_folder}/val_set_indexes.csv'
val_set_indexes = pd.read_csv(val_index_file, header = None)[0].tolist()
print("Number of val indexes:", len(val_set_indexes))

# NOTE(review): the positional `[0]` presumably assigns a label to this h5sc file — confirm
# against the H5ScSingleCellDataset signature
val_dataset = H5ScSingleCellDataset(
    [h5sc_path],
    [0],
    index_list=[val_set_indexes],
    select_channel=None,
)

Number of val indexes: 20344
Total single cell records: 20344
