In [20]:
import os
import random
import numpy as np
import pandas as pd
import scanpy as sc
import anndata as ad
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize
from scipy.spatial import distance
from scipy.spatial import distance_matrix
from scipy.optimize import linear_sum_assignment
from sklearn.impute import KNNImputer
from sklearn.cluster import KMeans
from sklearn.metrics import mean_squared_error, f1_score
from sklearn.metrics.pairwise import cosine_similarity
from skimage.metrics import structural_similarity as ssim

import torch

In [21]:
def seed_everything(seed=2024):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU, current CUDA device and all CUDA devices), exports
    ``PYTHONHASHSEED`` and switches cuDNN to deterministic, non-benchmarking
    mode. A confirmation line is printed at the end.

    Args:
        seed: Integer seed applied to all generators (default 2024).
    """
    # Exported for completeness; the running interpreter's hash seed is
    # fixed at startup, so this only reaches processes spawned afterwards.
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        random.seed,                 # Python random module
        np.random.seed,              # NumPy global generator
        torch.manual_seed,           # Torch
        torch.cuda.manual_seed,      # CUDA, current device
        torch.cuda.manual_seed_all,  # CUDA, multi-GPU setups
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Trade cuDNN autotuning speed for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    print(f'Seeding all randomness with seed={seed}')

class GroundTruth:
    """Held-out data for one hole: the removed cells and their expression.

    Attributes are stored exactly as passed in.
    """

    def __init__(self, hole_cells, gene_expression):
        # hole_cells: metadata rows of the cells that fall inside the hole.
        # gene_expression: expression values belonging to those same cells.
        self.hole_cells, self.gene_expression = hole_cells, gene_expression

class TestArea:
    """Axis-aligned rectangle describing one hole, plus its dominant tissue.

    Attributes are stored exactly as passed in.
    """

    def __init__(self, hole_min_x, hole_max_x, hole_min_y, hole_max_y, dominant_tissue):
        # Horizontal extent of the hole.
        self.hole_min_x, self.hole_max_x = hole_min_x, hole_max_x
        # Vertical extent of the hole.
        self.hole_min_y, self.hole_max_y = hole_min_y, hole_max_y
        # Tissue label associated with the hole by the caller.
        self.dominant_tissue = dominant_tissue

class TestItem:
    """One benchmark case: input data, its held-out truth and the hole area.

    Attributes are stored exactly as passed in.
    """

    def __init__(self, adata, ground_truth, test_area):
        # adata: the dataset handed to the method under test.
        # ground_truth: what was removed from it (see GroundTruth).
        # test_area: where it was removed (see TestArea).
        self.adata, self.ground_truth, self.test_area = adata, ground_truth, test_area

In [46]:
# Fix all random number generators before any sampling takes place.
seed_everything()

# Location of the input data: <fold_dir><platform>/<dataset>/<file>.h5ad
fold_dir = "/extra/zhanglab0/SpatialTranscriptomicsData/"
platform, dataset = "MERFISH", "MouseBrainAging"

# hole_size: edge length of each square hole, in the coordinate units of
#            the center_x / center_y columns.
# num_holes: number of holes attempted per tissue slice.
hole_size, num_holes = 200, 5

Seeding all randomness with seed=2024


In [84]:
# Container that the benchmark-construction cell fills with TestItem objects.
all_test_items = []

# Read the dataset from <fold_dir><platform>/<dataset>/<file>.h5ad.
h5ad_path = f"{fold_dir}{platform}/{dataset}/2330673b-b5dc-4690-bbbe-8f409362df31.h5ad"
adata = ad.read_h5ad(h5ad_path)

# Per-cell metadata table and the distinct donors found in it.
obs = adata.obs
donor_id_list = list(obs['donor_id'].unique())

In [87]:
# Build the hole-imputation benchmark: for every donor and every slice, cut up
# to `num_holes` square holes of edge `hole_size` out of the slice and store
# the remaining cells (model input) together with the removed cells and their
# expression (ground truth).

# Start from an empty list on every execution of this cell, so re-running it
# (e.g. after an interrupt) does not append duplicate test items.
all_test_items = []

# Columns converted to float before boundaries / hole membership are computed.
coord_columns = ['min_x', 'max_x', 'min_y', 'max_y', 'center_x', 'center_y']

for donor_id in donor_id_list:
    print(f'Donor_id: {donor_id}')
    # One mask per donor, shared by the metadata table and the AnnData object.
    donor_mask = obs['donor_id'] == donor_id
    donor_obs = obs[donor_mask]
    donor_x = adata[donor_mask]
    slice_list = sorted(donor_obs['slice'].unique())
    for slice_id in slice_list:
        print(f'Slice_id: {slice_id}')
        slice_mask = donor_obs['slice'] == slice_id
        slice_obs = donor_obs[slice_mask]
        slice_x = donor_x[slice_mask]

        # astype/assign return a new frame, so the conversions below never
        # write into a filtered slice of `obs` (no chained-assignment hazard).
        slice_obs_df = slice_obs.astype({col: float for col in coord_columns}).assign(
            # Drop FOV categories that do not occur in this slice so the
            # groupby below only yields FOVs that actually contain cells.
            fov=lambda frame: frame['fov'].cat.remove_unused_categories()
        )

        # Bounding box of every field of view present in this slice.
        # observed=True is explicit; with unused categories already removed
        # it produces the same groups as the implicit default.
        fov_boundaries = slice_obs_df.groupby('fov', observed=True).agg(
            min_x=('min_x', 'min'),
            max_x=('max_x', 'max'),
            min_y=('min_y', 'min'),
            max_y=('max_y', 'max')
        ).reset_index()

        fov_boundaries['center_x'] = (fov_boundaries['min_x'] + fov_boundaries['max_x']) / 2
        fov_boundaries['center_y'] = (fov_boundaries['min_y'] + fov_boundaries['max_y']) / 2

        for _ in range(num_holes):
            # Pick a random FOV, then jitter the hole centre by up to a
            # quarter of the hole size around that FOV's centre.
            # The order of the three random draws is kept fixed so results
            # stay reproducible for a given seed.
            fov = fov_boundaries.sample(1).iloc[0]
            rand_center_x = fov['center_x'] + random.uniform(-hole_size / 4, hole_size / 4)
            rand_center_y = fov['center_y'] + random.uniform(-hole_size / 4, hole_size / 4)

            half_size = hole_size / 2
            hole_min_x = rand_center_x - half_size
            hole_max_x = rand_center_x + half_size
            hole_min_y = rand_center_y - half_size
            hole_max_y = rand_center_y + half_size

            # Cells whose centre lies inside the hole (both bounds inclusive).
            in_hole = (
                slice_obs_df['center_x'].between(hole_min_x, hole_max_x)
                & slice_obs_df['center_y'].between(hole_min_y, hole_max_y)
            )
            hole_cells = slice_obs_df[in_hole]

            # A hole that contains no cells yields no test item, so a slice
            # may contribute fewer than `num_holes` items.
            if hole_cells.empty:
                continue

            dominant_tissue = hole_cells['tissue'].value_counts().idxmax()

            # Row mask over the slice, computed once and used for both the
            # held-out expression and the remaining cells.
            hole_row_mask = slice_obs.index.isin(hole_cells.index)
            gene_expression = slice_x[hole_row_mask].X

            # NOTE: despite the name this is the result of AnnData indexing
            # (a subset of `slice_x` without the hole cells), not an explicit
            # `.copy()`; call `.copy()` downstream if it must be modified.
            adata_copy = slice_x[~hole_row_mask]

            ground_truth = GroundTruth(hole_cells=hole_cells, gene_expression=gene_expression)

            test_area = TestArea(
                hole_min_x=hole_min_x,
                hole_max_x=hole_max_x,
                hole_min_y=hole_min_y,
                hole_max_y=hole_max_y,
                dominant_tissue=dominant_tissue
            )

            test_item = TestItem(
                adata=adata_copy,
                ground_truth=ground_truth,
                test_area=test_area
            )

            all_test_items.append(test_item)

Donor_id: MsBrainAgingSpatialDonor_1
Slice_id: 0
Donor_id: MsBrainAgingSpatialDonor_2
Slice_id: 0
Slice_id: 1
Donor_id: MsBrainAgingSpatialDonor_3
Slice_id: 0


KeyboardInterrupt: 