# Save the final dataframe with all the information regarding the clusters (ID, label, and category $\in \{\text{healthy}, \text{tumor}\}$) WITH RAW COUNTS (to be used later for scVI)

In [12]:
import os
import tempfile
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import numpy as np
import h5py
from tqdm import tqdm
from copy import deepcopy
from tqdm import tqdm


# Clone labels, in the order used when cloneIDs are regrouped by label in preprocess_data.
all_clonelabels = ['healthy','tumor']

# Available feature sets; index 4 ('geneOncoKB') is the one selected here.
gene_set_collections = np.array(['c6','hallmarks', 'c2_pid', 'gene','geneOncoKB'])
gene_set_collection = gene_set_collections[4]

# Available clone definitions; index 0 ('scatrex') is the one selected here.
clonemodes = np.array(['scatrex','phenograph'])
clonemode = clonemodes[0]
# Directory holding the per-sample "<sample>__scatrex.h5ad" files.
path_scatrex = '/data/users/04_share_reanalysis_results/melanoma_2025/new_scatrex/'
# Directory holding the per-sample GSVA tables; every gene-based feature set maps to the 'gene' subfolder.
pathgsva = '/data/users/04_share_reanalysis_results/02_melanoma/04_metacells_gsva/{0}/'.format('gene' if ('gene' in gene_set_collection) else gene_set_collection)
# Directory holding the per-sample metacell files ("<sample>_seacells.atypical_removed.h5", ...).
path_MC = '/data/users/04_share_reanalysis_results/02_melanoma/03_metacells_atypical_removed/'
# Output directory; it depends on the selected feature set and clone mode.
pathsave = '/data/users/04_share_reanalysis_results/melanoma_2025/02_atypical_removed_preprocessing/metacells_{0}_{1}_rawcounts/'.format(gene_set_collection, clonemode)
# Tab-separated cohort table; the columns 'sampleID' and 'part_cohort_analyses' are read from it.
path_info_cohort = '/data/users/melanoma_sample.txt'



def find_idx(string, array):
    """
    Return the index of the first element of `array` that contains `string`.

    Elements are tested in order with the `in` operator. When no element
    matches, the scan runs past the end of `array` and the indexing error
    raised by `array` itself (IndexError for a list) propagates to the caller.
    """
    position = 0
    while True:
        if string in array[position]:
            return position
        position += 1

In [6]:
# Substring whose presence in a cell-type annotation marks the cell as tumor.
KEY_TUMOR = 'Melanoma'
def celltype2tumor(celltype):
    """Return True when `celltype` contains KEY_TUMOR, False otherwise."""
    contains_tumor_key = KEY_TUMOR in celltype
    return contains_tumor_key

def celltype2category(celltype):
    """Map a cell-type annotation to 'tumor' (per celltype2tumor) or 'healthy'."""
    return 'tumor' if celltype2tumor(celltype) else 'healthy'
        
def label2category(label):
    """Map a clone label to its category: 'healthy' stays 'healthy', anything else is 'tumor'."""
    return 'healthy' if label == 'healthy' else 'tumor'
    
def celltype2label(celltype):
    """Return the clone label of a cell-type annotation: 'tumor' if it contains KEY_TUMOR, else 'healthy'."""
    if KEY_TUMOR not in celltype:
        return 'healthy'
    return 'tumor'

In [7]:
def filter_reorder(data, data_cellIDs, ref_cellIDs):
    """
    Select and reorder per-cell values so that they follow `ref_cellIDs`.

    `data` holds one value per entry of `data_cellIDs`; the returned flat array
    holds the values of the cells listed in `ref_cellIDs`, in that order.
    """
    # A one-row frame whose columns are the cell IDs lets pandas do the label-based selection.
    single_row = pd.DataFrame(data.reshape(1, -1), columns=data_cellIDs)
    selected = single_row[ref_cellIDs]
    return selected.to_numpy().reshape(-1)

def get_sample_names(use_scatrex_clone=False, gene_feature=False):
    """
    Get the list of the name of samples that will be considered for the analysis.

    We take the intersection of the samples for which required data is available
    (metacell h5 files, plus GSVA tables and/or SCATrEX results when needed),
    and we remove some problematic samples.

    Parameters
    ----------
    use_scatrex_clone : bool
        When True, only keep samples that also have an entry in `path_scatrex`.
    gene_feature : bool
        When False, only keep samples that also have an entry in `pathgsva`.

    Returns
    -------
    list
        Sorted sample names that passed all the filters.
    """
    def _sample_prefix(filename):
        # The sample name is everything before the first underscore. Unlike
        # `filename[:filename.find('_')]`, partition keeps the full name when
        # there is no underscore instead of silently dropping its last character.
        return filename.partition('_')[0]

    samples_names = np.unique([_sample_prefix(name) for name in os.listdir(path_MC) if 'h5' in name])
    if not gene_feature:
        samples_gsva = [_sample_prefix(name) for name in os.listdir(pathgsva)]
        samples_names = np.intersect1d(samples_names, samples_gsva)
    if use_scatrex_clone:
        samples_scatrex = [_sample_prefix(name) for name in os.listdir(path_scatrex)]
        samples_names = np.intersect1d(samples_names, samples_scatrex)

    # Load name of samples that should be removed (data not reliable)
    dfsamples = pd.read_csv(path_info_cohort, sep="\t")
    not_in_cohort = dfsamples['part_cohort_analyses'] == False  # noqa: E712 (element-wise comparison)
    samples_removed = set(dfsamples.loc[not_in_cohort, 'sampleID'])
    # Also remove the sample that is an outlier (too many cells)
    samples_removed.add('MATIWAQ-T')

    sample_names = [sample for sample in samples_names if sample not in samples_removed]
    print('Total number of samples: ', len(sample_names))
    return sample_names

def get_scatrex_clone(sample, cellsIDs_rna):
    """
    Get the clone IDs from SCATrEX. Note that SCATrEX is only run on tumor cells and thus we assign the label 0
    to the non-malignant cells (and we make sure the SCATrEX subclones do not also have the label 0).

    Parameters
    ----------
    sample : str
        Sample name, used to locate "<sample>__scatrex.h5ad" in `path_scatrex` and the
        SEACell hard-assignment table in `path_MC`.
    cellsIDs_rna : iterable of str
        SEACell IDs to keep in the returned mapping.

    Returns
    -------
    dict
        Maps each SEACell ID present in both the assignment table and `cellsIDs_rna`
        to its subclone ID (SCATrEX code + 1), or to 0 when none of its barcodes
        is known to SCATrEX.
    """
    filename = os.path.join(path_scatrex, "{0}__scatrex.h5ad".format(sample))
    with h5py.File(filename, "r") as f:
        subclones = f['obs']['scatrex_obs_node']['codes'][()]
        cellnames = [cname.decode("utf-8") for cname in f['obs']['cell_names'][:]]

    # Index of the first occurrence of each SCATrEX cell name, so that each lookup
    # below is a dict access instead of a full scan of `cellnames`.
    cellname2idx = {}
    for idx, cname in enumerate(cellnames):
        cellname2idx.setdefault(cname, idx)

    # Loading the mapping from single cells to SEAcells
    data = pd.read_csv(os.path.join(path_MC, "{0}.genes_cells_filtered_seacells_hard_assignment.tsv".format(sample)), sep='\t')
    seacell2subclone = {seacell: 0 for seacell in np.unique(data['SEACell'])}
    # Rows are visited in file order, so for each SEACell the last barcode known to
    # SCATrEX decides the subclone.
    for seacell, cellID in zip(data['SEACell'].to_numpy(), data['archetype_barcode'].to_numpy()):
        idx = cellname2idx.get(cellID)
        if idx is not None:
            # +1 since only tumor SEAcell have been used in SCATrEX and 0 is for the non malignant cells
            seacell2subclone[seacell] = subclones[idx] + 1

    kept_cells = set(cellsIDs_rna)
    return {cell: subcloneID for cell, subcloneID in seacell2subclone.items() if cell in kept_cells}
    
    
def get_cellIDs(sample, use_scatrex_clone=False):
    """
    Get the set of MetaCells that we can consider for the analysis. Indeed, if we use SCATrEX subclones,
    SCATrEX has a pre-filtering step that remove some part of the tumor cells (and therefore we need to get rid of them).

    Parameters
    ----------
    sample : str
        Sample name, used to locate "<sample>_seacells.atypical_removed.h5" in `path_MC`.
    use_scatrex_clone : bool
        When True, tumor metacells that are absent from the SCATrEX file of the sample are dropped.

    Returns
    -------
    numpy.ndarray
        Metacell names, in the order they are stored in the h5 file.
    """
    filename = os.path.join(path_MC, "{0}_seacells.atypical_removed.h5".format(sample))
    with h5py.File(filename, "r") as f:
        cellnames = np.array([cname.decode("utf-8") for cname in f['cell_attrs']['cell_names'][:]])
        seacelltypes = np.array([el.decode("utf-8") for el in f['cell_attrs']['celltype_final'][:]])
    if not use_scatrex_clone:
        return cellnames

    filename = os.path.join(path_scatrex, "{0}__scatrex.h5ad".format(sample))
    with h5py.File(filename, "r") as f:
        cellnames_scatrex = {cname.decode("utf-8") for cname in f['obs']['cell_names'][:]}

    # SCATrEX is run only on Melanoma cells: non-tumor metacells are always kept, while
    # tumor metacells are kept only if they survived the SCATrEX filtering step.
    keep = np.array(
        [(KEY_TUMOR not in seatype) or (cname in cellnames_scatrex)
         for cname, seatype in zip(cellnames, seacelltypes)],
        dtype=bool,
    )
    # The filtered names were previously built but never returned (the function returned None).
    return cellnames[keep]

def _read_gene_ids_and_variance(sample):
    """Read the gene IDs and their residual variances from the metacell h5 file of `sample`."""
    filename = os.path.join(path_MC, f"{sample}_seacells.atypical_removed.h5")
    with h5py.File(filename, "r") as f:
        res_var = f['gene_attrs']['residual_variance'][:]
        gene_ids = np.array([gene_id.decode("utf-8") for gene_id in f['gene_attrs']['gene_ids'][:]])
    return gene_ids, res_var


def get_selected_genes(sample_names, gene_set_collection):
    """
    Return the set of features (genes) that will be used for the analysis.

    Parameters
    ----------
    sample_names : sequence of str
        Samples whose metacell h5 files are scanned.
    gene_set_collection : str
        "geneOncoKB" selects the genes of the OncoKB cancer gene list (read from
        '../cancerGeneList.tsv') that are present in at least 95% of the samples.
        Any other value selects the 500 genes with the largest residual variance
        observed in any sample.

    Returns
    -------
    list or numpy.ndarray
        Gene IDs (a list in the OncoKB case, an array otherwise).
    """
    if gene_set_collection == "geneOncoKB":
        # Extracting the ENSG codes of all genes in the OncoKB Cancer Gene List
        from pyensembl.shell import collect_all_installed_ensembl_releases
        from pyensembl import EnsemblRelease
        # Requires a prior `pyensembl install --release 104 --species homo_sapiens`
        collect_all_installed_ensembl_releases()
        genome = EnsemblRelease(release=104, species='homo_sapiens')
        data = pd.read_csv('../cancerGeneList.tsv', sep='\t')
        gene_idsOnco = []
        for gene_name in data['Hugo Symbol'].unique():
            try:
                # A symbol can map to several Ensembl IDs: extend keeps the list flat
                # (stacking the per-symbol lists into an array breaks on ragged input).
                gene_idsOnco.extend(genome.gene_ids_of_gene_name(gene_name))
            except Exception:
                # Best effort: report the symbols that cannot be resolved and move on.
                print(gene_name)

        # Counting for each gene the number of samples for which we have data
        gene2nb_samples = {}
        for sample in tqdm(sample_names):
            gene_ids, _ = _read_gene_ids_and_variance(sample)
            # set(): a gene counts once per sample even if listed twice in a file
            for gene in set(gene_ids):
                gene2nb_samples[gene] = gene2nb_samples.get(gene, 0) + 1

        # Defining the final gene set considered
        threshold = int(0.95 * len(sample_names))
        selected_genes = [
            gene for gene in gene_idsOnco
            if gene in gene2nb_samples and gene2nb_samples[gene] >= threshold
        ]
    else:
        # Largest residual variance observed for each gene across samples
        gene_variance_dict = {}
        for sample in tqdm(sample_names):
            gene_ids, res_var = _read_gene_ids_and_variance(sample)
            for gene_id, var in zip(gene_ids, res_var):
                best = gene_variance_dict.get(gene_id)
                # A stored NaN must be replaceable: `var > nan` is always False, so
                # without the explicit check a NaN seen first would stick forever.
                if best is None or np.isnan(best) or var > best:
                    gene_variance_dict[gene_id] = var

        # Convert the dictionary to a DataFrame
        df = pd.DataFrame(list(gene_variance_dict.items()), columns=['Gene', 'Max_Residual_Variance'])

        # Keep the genes with the largest maximum residual variance
        selected_genes = np.array(df.nlargest(500, 'Max_Residual_Variance')['Gene'])
    return selected_genes

def get_feat_celltype_genefeat(sample, cellsIDs, selected_genes):
    """
    Return the feature matrix and the celltypes when working with genes as features.

    Parameters
    ----------
    sample : str
        Sample name, used to locate "<sample>_seacells.atypical_removed.h5" in `path_MC`.
    cellsIDs : sequence of str
        Metacell names to keep, in the desired row order.
    selected_genes : sequence of str
        Gene IDs to keep, in the desired column order.

    Returns
    -------
    tuple of numpy.ndarray
        The raw-count matrix restricted to `cellsIDs` x `selected_genes`, and the
        cell type of each of these metacells.
    """
    filename = os.path.join(path_MC, "{0}_seacells.atypical_removed.h5".format(sample))

    with h5py.File(filename, "r") as f:
        X = f['raw_counts'][:]
        cell_names = np.array([cname.decode("utf-8") for cname in f['cell_attrs']['cell_names'][:]])
        gene_ids = np.array([gene_id.decode("utf-8") for gene_id in f['gene_attrs']['gene_ids'][:]])
        # Read once (the dataset used to be loaded twice, the first result being discarded).
        seacelltypes = np.array([el.decode("utf-8") for el in f['cell_attrs']['celltype_final'][:]])

    dfsample = pd.DataFrame(X, index=cell_names, columns=gene_ids)
    # NOTE(review): reindex inserts NaN columns for selected genes that are absent
    # from this sample -- confirm that downstream consumers expect this.
    dfsample = dfsample.loc[cellsIDs].reindex(columns=selected_genes)
    dfcelltypes = pd.DataFrame(seacelltypes.reshape(1, -1), columns=cell_names)
    dfcelltypes = dfcelltypes[cellsIDs]
    return dfsample.to_numpy(), dfcelltypes.to_numpy().reshape(-1)
    
    
def save_data(sample_names, gene_set_collection, use_scatrex_clone=False):
    """
    Save the data file for each sample.

    For every sample, a table indexed by metacell ID is written to
    `pathsave`/sample2data/<sample>.csv with: one column per feature
    ('dim_<i>_<feature>'), the cell type and cell category of the metacell, its
    initial clone ID, and the type / label / category / re-indexed ID of its clone.

    Parameters
    ----------
    sample_names : sequence of str
        Samples to process.
    gene_set_collection : str
        Name of the feature set; gene features are used when it contains 'gene',
        GSVA pathway scores otherwise.
    use_scatrex_clone : bool
        When True the clones come from SCATrEX (clone 0 = non-malignant cells),
        otherwise from the phenograph clusters stored in the metacell h5 file.
    """
    print('Step 1/3: computing set of features')
    gene_feature = 'gene' in gene_set_collection
    if gene_feature:
        selected_genes = get_selected_genes(sample_names, gene_set_collection)
        feature_names = selected_genes
    else:
        set_pathways = []
        for sample in sample_names:
            df_gsva = pd.read_csv(os.path.join(pathgsva, '{0}_seacells_celltype_GSVA.tsv'.format(sample)), sep='\t')
            if len(set_pathways)==0:
                # np.unique: the 'gene_set' column repeats each pathway for every barcode,
                # so without it a single-sample run would keep duplicated pathways.
                set_pathways = np.unique(df_gsva['gene_set'].to_numpy())
            else:
                set_pathways = np.union1d(set_pathways, df_gsva['gene_set'].to_numpy())
        pathway2index = {pathway:i for i, pathway in enumerate(set_pathways)}
        feature_names = set_pathways

    print('Step 2/3: preprocessing data for each sample')
    for id_sample, sample in tqdm(enumerate(sample_names)):
        # NOTE(review): get_cellIDs is called without use_scatrex_clone, so in SCATrEX
        # mode the tumor metacells filtered out by SCATrEX are kept and end up in
        # clone 0 -- confirm this is intended.
        cellsIDs_rna = get_cellIDs(sample)
        if use_scatrex_clone:
            seacell2subclone = get_scatrex_clone(sample, cellsIDs_rna)
            clustersini = []
            for seacellID in cellsIDs_rna:
                clustersini.append(int(seacell2subclone[seacellID]))
            clustersini = np.array(clustersini)
        else:
            with h5py.File(os.path.join(path_MC, '{0}_seacells.atypical_removed.h5'.format(sample)), 'r') as f:
                cell_names = f['cell_attrs']['cell_names'][:]
                cell_names = np.array([cname.decode("utf-8") for cname in cell_names])
                clustersini = np.array(f['cell_attrs']["phenograph_clusters"][:]).astype(int)
                clustersini = filter_reorder(clustersini, cell_names, cellsIDs_rna)

        if gene_feature:
            features, celltypes = get_feat_celltype_genefeat(sample, cellsIDs_rna, selected_genes)
        else:
            with h5py.File(os.path.join(path_MC, '{0}_seacells.atypical_removed.h5'.format(sample)), 'r') as f:
                cell_names = f['cell_attrs']['cell_names'][:]
                cell_names = np.array([cname.decode("utf-8") for cname in cell_names])
                celltypes = np.array([el.decode("utf-8") for el in f['cell_attrs']['celltype_final'][:]])
                celltypes = filter_reorder(celltypes, cell_names, cellsIDs_rna)

            df_gsva = pd.read_csv(os.path.join(pathgsva, '{0}_seacells_celltype_GSVA.tsv'.format(sample)), sep='\t')
            features = np.zeros((len(cellsIDs_rna), len(set_pathways)))
            for idx_cell, cellID in enumerate(cellsIDs_rna):
                df_gsva_cell = df_gsva[df_gsva['barcode']==cellID]
                gene_sets = np.unique(df_gsva_cell['gene_set'].to_numpy())
                # Pathways without a score for this cell get the value 0
                missing_pathways = list(filter(lambda x: x not in gene_sets, set_pathways))
                missing_barcode = [cellID for l in range(len(missing_pathways))]
                missing_value = [0 for l in range(len(missing_pathways))]
                df_missing = pd.DataFrame({'gene_set':missing_pathways,
                                           'barcode': missing_barcode,
                                           'value': missing_value
                                          })
                df_gsva_cell = pd.concat([df_gsva_cell, df_missing], axis=0)
                df_gsva_cell = df_gsva_cell.set_index('gene_set')
                # Order the rows like `set_pathways` so that they line up with the feature columns
                df_gsva_cell = df_gsva_cell.sort_index(key=lambda x: x.map(pathway2index))
                features[idx_cell,:] = df_gsva_cell['value'].to_numpy()

        # The table is assembled by concatenating with string arrays, so every cell of
        # `data` (including the clone IDs) is stored as a string.
        data = np.array(cellsIDs_rna)
        data = np.concatenate((data[:,None], features), axis=1)
        data = np.concatenate((data, celltypes[:,None]), axis=1)
        cellcategories = np.array([celltype2category(celltype) for celltype in celltypes])
        data = np.concatenate((data, cellcategories[:,None]), axis=1)
        data = np.concatenate((data, clustersini[:,None]), axis=1)

        colnames = ['cell_id'] + ['dim_{0}_{1}'.format(i+1,gene_set) for i,gene_set in enumerate(feature_names)]
        colnames += ['celltype', 'cellcategory', 'initial_cloneID']
        df = pd.DataFrame(data, columns=colnames)
        df = df.set_index('cell_id')

        # Number of metacells per (clone, cell type) pair
        temp_df = df[['initial_cloneID','celltype']].groupby(['initial_cloneID','celltype']).size()
        all_cloneIDs = np.unique(temp_df.index.get_level_values('initial_cloneID').values)

        if not(use_scatrex_clone):
            cloneID2clonetype = {}
            for cloneID in all_cloneIDs:
                # Get the subset for the current cloneID
                clone_counts = temp_df.loc[cloneID]

                # Separate into two groups
                melanoma_group = clone_counts[clone_counts.index.str.contains(KEY_TUMOR, case=False, na=False)]
                other_group = clone_counts[~clone_counts.index.str.contains(KEY_TUMOR, case=False, na=False)]

                # Sum the counts in each group
                melanoma_count = melanoma_group.sum() if not melanoma_group.empty else 0
                other_count = other_group.sum() if not other_group.empty else 0

                # Determine the assigned cell type: tumor if tumor metacells are a strict
                # majority, otherwise the most frequent non-tumor cell type
                if melanoma_count > other_count:
                    cloneID2clonetype[cloneID] = KEY_TUMOR
                else:
                    cloneID2clonetype[cloneID] = other_group.idxmax() if not other_group.empty else 'Unknown'
        else:
            # Clone 0 gathers the non-malignant cells. The clone IDs are strings at this
            # point, so the numeric value is compared: adding the integer key 0 to the
            # dict would never match the '0' used for the lookups below, and clone 0
            # would be labelled as tumor.
            cloneID2clonetype = {
                cloneID: ('non malignant' if int(cloneID) == 0 else KEY_TUMOR)
                for cloneID in all_cloneIDs
            }
        clonetype_of_cells = [cloneID2clonetype[cloneID] for cloneID in (df['initial_cloneID'].values.copy())]
        # Re-index the clones of the sample as 0..n-1
        inicloneID2cluster = {inicloneID:i for i, inicloneID in enumerate(np.unique(df['initial_cloneID'].values))}
        df_c = pd.DataFrame({
                            'cell_id': cellsIDs_rna,
                            'clonetype': clonetype_of_cells,
                            'clonelabel': [celltype2label(clonetype) for clonetype in clonetype_of_cells],
                             'clonecategory': [celltype2category(clonetype) for clonetype in clonetype_of_cells],
                             'cloneID': [inicloneID2cluster[inicloneID] for inicloneID in df['initial_cloneID'].values]
                            })
        df_c = df_c.set_index('cell_id')
        df = pd.concat([df, df_c], axis=1)

        if (not(use_scatrex_clone) and id_sample==0):
            # just a sanity check: The way I compute the celltype with the largest number of cells in each cluster should match the
            # major celltypes in each cluster provided by Anne and Franziska
            df_clusters_data = pd.read_csv(os.path.join(path_MC,'{0}_seacells.atypical_removed.phenograph_celltype_association.txt'.format(sample)), sep='\t')
            print(df_clusters_data)
            df_clusters_data = df_clusters_data.set_index('Cluster')
            print(df_clusters_data.index)
            for cloneID in np.unique(df['initial_cloneID'].values):
                print(cloneID, ' ', cloneID2clonetype[cloneID], ' ', df_clusters_data.loc[str(cloneID),]['Dominant.celltype'])
        df.to_csv(os.path.join(pathsave,'sample2data/{0}.csv'.format(sample)), index=True)

In [8]:
def preprocess_data(sample_names, gene_set_collection, use_scatrex_clone=False):
    """
    Main function to preprocess data.

    Runs save_data, then renumbers the clone IDs of every sample so that the IDs
    are grouped by clone label (in the order of `all_clonelabels`) and each label
    owns the same ID range in every sample. The per-sample CSV files are
    rewritten in place and a summary table is saved to `pathsave`/clone_infos.csv.

    Parameters
    ----------
    sample_names : sequence of str
        Samples to process.
    gene_set_collection : str
        Name of the feature set, forwarded to save_data.
    use_scatrex_clone : bool
        Forwarded to save_data.
    """
    save_data(sample_names, gene_set_collection, use_scatrex_clone=use_scatrex_clone)
    sample2data_dir = os.path.join(pathsave, 'sample2data')

    # Find the max number of clusters within each labels across all samples
    clonelabel2max_nb_cluster = {label: 0 for label in all_clonelabels}
    for sample in tqdm(sample_names):
        df = pd.read_csv(os.path.join(sample2data_dir, '{0}.csv'.format(sample)), index_col=0)
        for label in all_clonelabels:
            sample_nb_clusters_with_label = len(np.unique(df[df['clonelabel']==label]['cloneID']))
            clonelabel2max_nb_cluster[label] = max(clonelabel2max_nb_cluster[label], sample_nb_clusters_with_label)
    print(clonelabel2max_nb_cluster)

    print('Step 3/3: grouping cloneIDs by label and saving the clone_infos file')
    sample2cloneID2clonetype = {}
    for sample in tqdm(sample_names):
        sample2cloneID2clonetype[sample] = {}
        sample_file = os.path.join(sample2data_dir, '{0}.csv'.format(sample))
        df = pd.read_csv(sample_file, index_col=0)
        start_cloneID = 0
        for label in all_clonelabels:
            has_label = df['clonelabel'] == label
            # Snapshot taken before renumbering: it still carries the old clone IDs
            df_label = df[has_label].copy()
            oldcloneIDs = np.unique(df_label['cloneID'])
            # correct cloneID
            oldcloneID2newcloneID = {oldcloneID:(start_cloneID+i) for i,oldcloneID in enumerate(oldcloneIDs)}
            df.loc[has_label, 'cloneID'] = df.loc[has_label, 'cloneID'].map(oldcloneID2newcloneID)

            for oldcloneID,cloneID in oldcloneID2newcloneID.items():
                sample2cloneID2clonetype[sample][cloneID] = df_label[df_label['cloneID']==oldcloneID]['clonetype'].iloc[0]
            # Each label reserves as many IDs as the largest sample needs, so that a
            # given ID always carries the same label in every sample
            start_cloneID += clonelabel2max_nb_cluster[label]
        df.to_csv(sample_file, index=True)

    # saving clone infos
    col_cloneID, col_label, col_cat = [], [], []
    count = 0
    for label in all_clonelabels:
        for _ in range(clonelabel2max_nb_cluster[label]):
            col_cloneID.append(count)
            col_label.append(label)
            col_cat.append(label2category(label))
            count += 1

    dic_df = {'cloneID':col_cloneID, 'clonelabel':col_label, 'clonecategory':col_cat}
    for sample in tqdm(sample_names):
        # None marks the clone IDs that are not used by this sample
        dic_df["clonetype_{0}".format(sample)] = [
            sample2cloneID2clonetype[sample].get(cloneID, None) for cloneID in col_cloneID
        ]

    df = pd.DataFrame(dic_df)
    df = df.set_index('cloneID')
    # The file name does not depend on any sample (the former `.format(sample)` was a
    # no-op that raised a NameError when `sample_names` was empty).
    df.to_csv(os.path.join(pathsave, 'clone_infos.csv'), index=True)

In [None]:
# Driver: run the preprocessing for the selected clone mode(s) and feature set(s).
# `clonemodes[:1]` restricts the run to 'scatrex' and `gene_set_collections[[4]]` to 'geneOncoKB'.
for clonemode in clonemodes[:1]:
    for gene_set_collection in gene_set_collections[[4]]:
        # The functions above read `pathsave` and `pathgsva` as module-level globals, so
        # they must be rebound here, before any of them is called for this configuration.
        pathsave = '/data/users/04_share_reanalysis_results/melanoma_2025/02_atypical_removed_preprocessing/metacells_{0}_{1}_rawcounts/'.format(gene_set_collection, clonemode)
        pathgsva = '/data/users/04_share_reanalysis_results/02_melanoma/04_metacells_gsva/{0}/'.format('gene' if ('gene' in gene_set_collection) else gene_set_collection)
        import os
        # Create the output directory and its 'sample2data' subfolder if they do not exist yet.
        if not os.path.exists(pathsave):
            os.makedirs(pathsave)
        if not os.path.exists(os.path.join(pathsave, 'sample2data')):
            os.makedirs(os.path.join(pathsave, 'sample2data'))
        print('Clone mode:', clonemode)
        print('Features:', gene_set_collection)
        sample_names = get_sample_names(use_scatrex_clone=(clonemode=='scatrex'), gene_feature=('gene' in gene_set_collection))
        preprocess_data(sample_names, gene_set_collection, use_scatrex_clone=(clonemode=='scatrex'))