## A3 - Data preprocessing for training and benchmarking

This notebook demonstrates how to prepare paired spatial transcriptomics (ST) and image data for model training, benchmarking, and inference with ST. If you only have H&E images and wish to run inference, follow the steps outlined in the README file directly.

In [None]:
import pandas as pd
import numpy as np
import h5py
import scanpy as sc
import os
from scipy.stats import zscore

The input data for PASTA is fully compatible with the formats used in the [HEST](https://github.com/mahmoodlab/HEST) library. You can download the pre-processed data directly from HEST or follow its processing pipeline to prepare your own datasets. The data directory should be organized as follows:
```
data_train/
├── meta.csv                    # Sample metadata
├── patches/                    # H5 files with patches
│   ├── sample001.h5
│   ├── sample002.h5
│   └── ...
├── st/
│   ├── sample001.h5ad          # ST h5ad files
│   ├── sample002.h5ad 
│   └── ...
├── gene/
│   ├── sample001.csv           # Gene expression
│   ├── sample002.csv 
│   └── ...
├── pathway/
│   ├── sample001.csv           # Pathway scores
│   ├── sample002.csv 
│   └── ...
├── wsis/
│   ├── sample001.tif           # H&E images
│   ├── sample002.tif 
│   └── ...
```
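
Before training, it can help to confirm that every sample has a file in each subdirectory. The following is a minimal sketch that assumes the layout above; the helper name `find_missing_files` and the `EXPECTED` table are hypothetical and not part of PASTA or HEST.

```python
from pathlib import Path

# Expected file extension per subdirectory, taken from the layout above.
EXPECTED = {
    "patches": ".h5",
    "st": ".h5ad",
    "gene": ".csv",
    "pathway": ".csv",
    "wsis": ".tif",
}


def find_missing_files(root, sample_ids):
    """Return a sorted list of expected paths (relative to root) that do not exist."""
    root = Path(root)
    missing = []
    for sample_id in sample_ids:
        for subdir, ext in EXPECTED.items():
            rel = Path(subdir) / f"{sample_id}{ext}"
            if not (root / rel).is_file():
                missing.append(rel.as_posix())
    return sorted(missing)
```

For example, `find_missing_files('data_train', ['sample001', 'sample002'])` returns an empty list when both samples are complete. If you only train on one of `gene/` or `pathway/`, remove the other entry from `EXPECTED`.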

#### (Optional) Extract auxiliary tiles

In [None]:
tile_h5_path = f'/workspace/data/HEST/patches/TENX156.h5'

In [None]:
with h5py.File(tile_h5_path, 'r') as f:
    print(f.keys())
    coords = f['coords'][:]
    x_min, y_min = coords.min(axis=0)
    x_max, y_max = coords.max(axis=0)
    print(f'x min:{x_min}, y min:{y_min}, x max:{x_max}, y max:{y_max}')
    
    img_attrs = dict(f['img'].attrs)
    if 'patch_size' in img_attrs.keys():
        patch_size = img_attrs['patch_size']
    elif 'patch_size_src' in img_attrs.keys():
        patch_size = img_attrs['patch_size_src']
    else:
        print('No patch size attrs in h5 file, please set patch_size')

<KeysViewHDF5 ['barcode', 'coords', 'img']>
x min:555, y min:12360, x max:24397, y max:36669


In [None]:
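import openslide  # assumption: `wsi` is an OpenSlide handle; read_region below follows its API

# Assumed location of the matching H&E slide (wsis/ in the layout above); adjust to your data
wsi = openslide.OpenSlide('/workspace/data/HEST/wsis/TENX156.tif')

coords_new = []
imgs_new = []
patch_size = int(patch_size)
step = patch_size // 2  # half-patch stride, so neighboring tiles overlap by 50%
for i in range(x_min, x_max + step, step):
    for j in range(y_min, y_max + step, step):
        coords_new.append([i, j])
        # The region is read from (i - step, j - step) at level 0, so (i, j) is its center
        patch_tmp = wsi.read_region((int(i - step), int(j - step)), 0, (patch_size, patch_size)).convert('RGB').resize((224, 224))
        imgs_new.append(np.array(patch_tmp))
coords_new = np.array(coords_new)
imgs_new = np.stack(imgs_new)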

In [None]:
with h5py.File(f'/workspace/data/HEST/patches/TENX156_aux.h5', 'w') as f:
    f.create_dataset('coords', data=coords_new)
    f.create_dataset('img', data=imgs_new)
    for key, value in img_attrs.items():
        f['img'].attrs[key] = value
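
To make the sampling pattern explicit: the loop above places tile centers on a regular grid with a stride of half the patch size, so adjacent tiles overlap by 50%, and each tile is read starting from its center minus that stride. The sketch below reproduces only the coordinate arithmetic on toy numbers; `aux_tile_grid` is a hypothetical helper name and no image is read.

```python
def aux_tile_grid(x_min, y_min, x_max, y_max, patch_size):
    """Return (centers, top_lefts) for a tile grid with a half-patch stride."""
    step = patch_size // 2
    centers = [
        (x, y)
        for x in range(x_min, x_max + step, step)
        for y in range(y_min, y_max + step, step)
    ]
    # Each tile is read from its top-left corner, half a patch before the center.
    top_lefts = [(x - step, y - step) for x, y in centers]
    return centers, top_lefts


centers, top_lefts = aux_tile_grid(0, 0, 4, 2, patch_size=4)
print(centers)       # [(0, 0), (0, 2), (2, 0), (2, 2), (4, 0), (4, 2)]
print(top_lefts[0])  # (-2, -2)
```

Note that tiles on the border of the grid can start before `x_min`/`y_min` or extend past `x_max`/`y_max`, because the grid is anchored on tile centers rather than corners.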

#### Pathway scores file preparation

If you wish to train a model focusing on specific pathways, you can refer to the following code.

In [None]:
!pip install gseapy

In [None]:
import gseapy as gp

In [None]:
gmt_path = '/workspace/pasta/configs/gmt/Tumor_gene_sets.gmt'
gmt = gp.base.GSEAbase()
gmt = gmt.parse_gmt(gmt_path)
pathway_names = list(gmt.keys())

In [None]:
def ssgsea_pathway(exp: pd.DataFrame, gmt, pathway_name: list) -> pd.DataFrame:
    """Run ssGSEA and return pathway scores (pathways x samples)."""
    if 'Name' not in exp.columns:
        exp.insert(0, 'Name', exp.index)
    exp = exp.drop_duplicates()
    ss = gp.ssgsea(data=exp,
                   gene_sets=gmt,
                   min_size=1,
                   outdir=None,
                   sample_norm_method='rank',  # 'custom' uses the raw values of `data` instead
                   no_plot=True,
                   processes=10)
    ss_df = ss.res2d.pivot(index='Term', columns='Name', values='NES')
    # Report pathways for which ssGSEA returned no score (pathways are the rows of ss_df)
    if ss_df.shape[0] < len(pathway_name):
        no_results = [name for name in pathway_name if name not in ss_df.index]
        print(f'No results of {no_results}')
    if exp.shape[1] > 1:
        # Restore the original sample order (all columns after 'Name')
        ori_idx = exp.columns[1:]
        return ss_df.loc[:, ori_idx]
    return ss_df

def calculate_pathways_score(exp: pd.DataFrame, pathway_name:list, gmt, agg=False, file_id=None):
    '''Calculate curated pathways scores'''
    pathway_df = ssgsea_pathway(exp, gmt, pathway_name)
    pathway_df = pathway_df.astype(np.float32)
    if agg:
        map_dict = {'Cell_Cycle':['Cell_Cycle_G1/S', 'Cell_Cycle_G2/M']}

        map_reverse_dict = {v: k for k, values in map_dict.items() for v in values}
        
        pathway_df = pathway_df.rename(index=map_reverse_dict).groupby(level=0).mean()
    
    return pathway_df.T
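
The `agg=True` branch above merges related pathways into a single score by renaming the member rows to a shared group name and averaging rows with the same name. A toy illustration of that step follows; the scores, the `Hypoxia` pathway, and the spot labels are made up for the example.

```python
import pandas as pd

# Toy pathway-by-spot scores (values are made up for illustration).
scores = pd.DataFrame(
    [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
    index=["Cell_Cycle_G1/S", "Cell_Cycle_G2/M", "Hypoxia"],
    columns=["spot_a", "spot_b"],
)

map_dict = {"Cell_Cycle": ["Cell_Cycle_G1/S", "Cell_Cycle_G2/M"]}
# Invert the mapping so each member pathway points to its group name.
member_to_group = {member: group for group, members in map_dict.items() for member in members}

# Rows that share a name after renaming are averaged; other rows are unchanged.
aggregated = scores.rename(index=member_to_group).groupby(level=0).mean().T
print(aggregated)
```

The result has one row per spot and one column per aggregated pathway, which matches the orientation returned by `calculate_pathways_score`.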

In [None]:
file_list = ['NCBI785']

In [None]:
os.makedirs('/workspace/data/HEST/pathway/', exist_ok=True)
for sample in file_list:
    adata = sc.read_h5ad(f'/workspace/data/HEST/st/{sample}.h5ad')
    # ssGSEA takes genes x spots; the scores come back as spots x pathways
    pathway = calculate_pathways_score(adata.to_df().T, pathway_name=pathway_names, gmt=gmt)
    # Align to all spots and all pathways; entries without a score are filled with 0
    pathway_df = pathway.reindex(index=adata.obs.index, columns=pathway_names, fill_value=0)
    pathway_df.to_csv(f'/workspace/data/HEST/pathway/{sample}.csv')

You can set info_dir in train.py to either the gene/ or pathway/ directory, depending on whether you want to use gene expression or pathway scores as input.

#### Aggregate Xenium or Visium HD data at different granularities

Functions for pooling [Xenium](https://github.com/mahmoodlab/HEST/blob/3c39afa2bd8d407db904e285021466292a2be09e/src/hest/readers.py#L1068) or [Visium HD](https://github.com/mahmoodlab/HEST/blob/3c39afa2bd8d407db904e285021466292a2be09e/src/hest/readers.py#L1160) data are available in the HEST library. Here, we provide a basic example of that process for Xenium data.

In [None]:
import hest
from hest import iter_hest
from hest.readers import XeniumReader
from hest.readers import pool_transcripts_xenium,pool_bins_visiumhd
from hest.utils import find_pixel_size_from_spot_coords

file_list = ['NCBI785']
tot = len(file_list)
for idx, st in enumerate(iter_hest('/workspace/data/HEST/', id_list=file_list, load_transcripts=True)):
    for spot_size in [16, 32, 64]:
        adata_hd_um = pool_transcripts_xenium(st.transcript_df, pixel_size_he=st.meta["pixel_size_um_estimated"], spot_size_um=spot_size)
        os.makedirs(f"/workspace/data/benchmark/st_agg/{spot_size}um/", exist_ok=True)
        adata_hd_um.write_h5ad(f"/workspace/data/benchmark/st_agg/{spot_size}um/{file_list[idx]}.h5ad")
        print(f"[{idx + 1}/{tot}] Finish {file_list[idx]} of {spot_size} um spots!")
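
Conceptually, pooling assigns each transcript to a square pseudo-spot of the requested size and counts transcripts per gene within each spot. The sketch below illustrates that idea in plain Python; it is not HEST's implementation. It assumes transcript coordinates are given in H&E pixels, that the pixel size is in micrometers per pixel, and that the grid starts at the origin. `pool_to_spots` is a hypothetical helper name.

```python
from collections import Counter, defaultdict


def pool_to_spots(transcripts, pixel_size_um, spot_size_um):
    """Count transcripts per gene in square bins of side spot_size_um.

    transcripts: iterable of (x_px, y_px, gene) tuples.
    Returns {(col, row): Counter({gene: count})}.
    """
    spot_size_px = spot_size_um / pixel_size_um  # bin side length in pixels
    spots = defaultdict(Counter)
    for x_px, y_px, gene in transcripts:
        key = (int(x_px // spot_size_px), int(y_px // spot_size_px))
        spots[key][gene] += 1
    return dict(spots)


# Toy transcripts: with 0.5 um/pixel, a 32 um spot spans 64 pixels.
toy = [(10, 10, "A"), (20, 30, "A"), (70, 10, "B")]
pooled = pool_to_spots(toy, pixel_size_um=0.5, spot_size_um=32)
print(pooled)
```

Larger values of `spot_size_um` merge more transcripts into each bin, which is why the loop above writes a separate file per spot size.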