In [1]:
import scanpy as sc
import numpy as np
import torch

# Verify that label_mapping and reverse_label_mapping are consistent with each other and with adata.obs

In [2]:
def label_mapping_test(
    data_name,
    label_keys,
    datasets_dir='../../Datasets/preprocessed_datasets',
    ):
    """Check the label-encoding structures built from a preprocessed dataset.

    Reads ``{datasets_dir}/{data_name}.h5ad`` and, for every key in
    ``label_keys``, builds the integer-code tensor, the code -> label
    dictionary and the label -> code dictionary, plus a per-cell dictionary
    of codes. It then asserts that all of them agree with the values stored
    in ``adata.obs``.

    Parameters
    ----------
    data_name : str
        File stem of the ``.h5ad`` file to load.
    label_keys : list of str
        Columns of ``adata.obs`` to encode and verify. Columns that are not
        already categorical are cast to ``category`` before being checked.
    datasets_dir : str, optional
        Directory containing the ``.h5ad`` files. The default reproduces the
        path that was previously hard-coded.

    Returns
    -------
    str
        ``'all unit tests succeeded for {data_name}'`` when every check passes.

    Raises
    ------
    AssertionError
        If any consistency check fails, or if a label column contains missing
        values (they have no entry in the category list, so they cannot be
        mapped to a label and back).
    KeyError
        If a key in ``label_keys`` is not a column of ``adata.obs``.
    """
    adata_path = f'{datasets_dir}/{data_name}.h5ad'
    adata = sc.read_h5ad(adata_path)

    # 'sc_cell_ids' is optional: cast it only when the dataset provides it,
    # instead of hiding every possible error behind a bare ``except``.
    if 'sc_cell_ids' in adata.obs.columns:
        adata.obs['sc_cell_ids'] = adata.obs['sc_cell_ids'].astype('category')

    n_cells = len(adata.obs)
    cell_ids = list(range(n_cells))
    adata.obs['cell_ids'] = cell_ids

    label_mapping = {}
    reverse_label_mapping = {}
    label_data_by_key = {}

    # Reference values taken from pandas once per key, so the assertion loops
    # below do not rebuild the codes array on every iteration.
    expected_codes_by_key = {}
    raw_labels_by_key = {}

    for key in label_keys:
        label_series = adata.obs[key].astype('category')
        assert not label_series.isna().any(), (
            f"column '{key}' of '{data_name}' contains missing values, "
            "which cannot be mapped to a label and back"
        )

        label_data_by_key[key] = torch.tensor(
            label_series.cat.codes.values,
            dtype=torch.long
        )

        label_mapping[key] = dict(
            enumerate(label_series.cat.categories.tolist())
        )
        reverse_label_mapping[key] = {
            label: code for code, label in label_mapping[key].items()
        }

        expected_codes_by_key[key] = label_series.cat.codes.tolist()
        raw_labels_by_key[key] = adata.obs[key].tolist()

    # The per-cell mapping is built from the tensors (the object under test),
    # converted to Python ints in one pass rather than one ``.item()`` per cell.
    tensor_codes_by_key = {
        key: label_data_by_key[key].tolist() for key in label_keys
    }
    cell_id_label_mapping = {
        cell_id: {key: tensor_codes_by_key[key][idx] for key in label_keys}
        for idx, cell_id in enumerate(cell_ids)
    }

    # ASSERTIONS ----------------
    for key in label_keys:
        expected_codes = expected_codes_by_key[key]

        # Assert label_mapping & reverse_label_mapping: the two dictionaries
        # are inverses, and every (raw label, code) pair observed in the data
        # maps to the right value in both directions.
        assert len(reverse_label_mapping[key]) == len(label_mapping[key]), (
            f"label_mapping['{key}'] has duplicated labels"
        )
        for raw_label, code in set(zip(raw_labels_by_key[key], expected_codes)):
            assert label_mapping[key][code] == raw_label, (
                f"label_mapping['{key}'][{code}] != {raw_label!r}"
            )
            assert reverse_label_mapping[key][raw_label] == code, (
                f"reverse_label_mapping['{key}'][{raw_label!r}] != {code}"
            )

        # Assert label_data_by_key
        assert label_data_by_key[key].tolist() == expected_codes, (
            f"label_data_by_key['{key}'] differs from the categorical codes"
        )

        # Assert cell_id_label_mapping
        for row_id in range(n_cells):
            assert cell_id_label_mapping[row_id][key] == expected_codes[row_id], (
                f"cell_id_label_mapping[{row_id}]['{key}'] differs from the categorical code"
            )

    # Assert cell_ids
    assert adata.obs['cell_ids'].values.tolist() == list(range(adata.shape[0]))

    return f'all unit tests succeeded for {data_name}'

In [3]:
# Run the label-mapping checks on the 'kang' dataset, including 'sc_cell_ids'.
label_mapping_test(
    data_name='kang',
    label_keys=[
        'condition',
        'cell_type',
        'sc_cell_ids',
    ],
)

'all unit tests succeeded for kang'

In [4]:
# Run the label-mapping checks on the 'liver' dataset.
label_mapping_test(
    data_name='liver',
    label_keys=[
        'status_control',
        'zone',
        'infected',
        'time',
        'coarse_time',
    ],
)

'all unit tests succeeded for liver'

In [5]:
# Second run of the 'liver' checks with the same keys.
# NOTE(review): this call uses exactly the same arguments as the preceding
# 'liver' cell; consider dropping one of the two.
label_mapping_test(
    data_name='liver',
    label_keys=[
        'status_control',
        'zone',
        'infected',
        'time',
        'coarse_time',
    ],
)

'all unit tests succeeded for liver'

In [6]:
# Run the label-mapping checks on the 'norman' dataset.
label_mapping_test(
    data_name='norman',
    label_keys=[
        'condition',
        'perturbation1',
        'perturbation2',
        'sc_cell_ids',
    ],
)

'all unit tests succeeded for norman'

In [8]:
# Run the label-mapping checks on the 'prostate' dataset.
label_mapping_test(
    data_name='prostate',
    label_keys=[
        'time',
        'batchID',
        'predType',
    ],
)

'all unit tests succeeded for prostate'

In [9]:
# Run the label-mapping checks on the 'tabula' dataset.
label_mapping_test(
    data_name='tabula',
    label_keys=[
        'specie',
        'cell_type',
        'sc_cell_ids',
        'donor_id',
        'sex',
    ],
)

'all unit tests succeeded for tabula'

In [10]:
# Run the label-mapping checks on the 'myocarditis' dataset.
label_mapping_test(
    data_name='myocarditis',
    label_keys=[
        'donor',
        'tissue',
        'on_steroids',
        'sc_cell_ids',
        'cell_type',
    ],
)

'all unit tests succeeded for myocarditis'

In [11]:
# Run the label-mapping checks on the 'leukemia' dataset.
label_mapping_test(
    data_name='leukemia',
    label_keys=[
        'Sample_id',
        'cell_type',
        'sc_cell_ids',
    ],
)

'all unit tests succeeded for leukemia'

In [12]:
# Run the label-mapping checks on the 'seurat' dataset.
label_mapping_test(
    data_name='seurat',
    label_keys=[
        'donor',
        'time',
        'celltype.l1',
        'celltype.l2',
        'celltype.l3',
    ],
)

'all unit tests succeeded for seurat'

In [13]:
# Run the label-mapping checks on the 'kang' dataset without 'sc_cell_ids'.
label_mapping_test(
    data_name='kang',
    label_keys=[
        'condition',
        'cell_type',
    ],
)

'all unit tests succeeded for kang'