# Package Imports

In [1]:

import os 
import traceback
import pickle as pk
from pathlib import Path
from datasets import Dataset
import itertools
from sklearn.utils import shuffle
import seaborn as sns
from geneformer import TranscriptomeTokenizer
import copy
import argparse
import requests
import polars as pl
import matplotlib.pyplot as plt
import pandas as pd

# ML base imports
from base_utils.ML_base import *

# Properly sets up NCCV environment
GPU_NUMBER = [i for i in range(torch.cuda.device_count())] 
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
os.environ["NCCL_DEBUG"] = "INFO"
device = torch.device("cuda")

  from .autonotebook import tqdm as notebook_tqdm
2024-02-24 15:31:34,454	INFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.
2024-02-24 15:31:35,048	INFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


# Cell Data Class

In [2]:
# Custom cell dataset class
class CellData(Dataset):
    def __init__(self, test = None, train = None, 
                 label = "cell_type", ID = 'ENSEMBL', dataset_normalization = True):
        # Dataset normalization controls whether data is ranked based on just median normalized data, or if it is ranked based on median and dataset gene count normalization
        self.dataset_normalization = dataset_normalization
        try:
            self.label = label
            
            # Uses a train/test split if provided, otherwise creates the test split if only a test value is provided
            try:
                train_test_labels = [1 for _ in range(len(train))] + [0 for _ in range(len(test))]
                data = pl.concat((train, test), how = 'vertical')
                data, ranked_genes = self.convert(data, label, ID, train_test_labels)
            except:
                data, ranked_genes = self.convert(test, label, ID, train_test_labels = None)
                
            self.ranked_genes = ranked_genes
            super().__init__(data)
        except:
            print(traceback.format_exc())
            pass
        
    def convert(self, data, label, ID, train_test_labels, count_id = 'expression', gene_id = "genes", GF_limit = 2048):
        tokens = TranscriptomeTokenizer()
        token_dict = tokens.gene_token_dict
        median_dict = tokens.gene_median_dict
        if train_test_labels != None:
            ans_dict = {'input_ids':[], 'length':[], label:[], 'train':train_test_labels}
        else:
            ans_dict = {'input_ids':[], 'length':[], label:[]}
        
        # Normalizes for gene counts in sample
        gene_counts = {}
        for col_name in data.columns:
            if '_' not in col_name:
                gene_expression = data[col_name]
                gene_counts[col_name] = np.median(gene_expression)
                    
        genes = list(gene_counts.keys())
        expression = data.select(genes)
        label_index = data.columns.index(label)
    
        
        for row in expression.iter_rows():
            gexp = {genes[i]:exp for i, exp in enumerate(row[:-1])}
            
            # Normalizes each set of genes by median and/or dataset gene count
            for num, key in enumerate(list(gexp.keys())):
                try:
                    if self.dataset_normalization:
                        gexp[key] /= (median_dict[key] * gene_counts[key])
                    else:
                        gexp[key] /= (median_dict[key])
                except:
                    gexp.pop(key)
               
                
            gexp = sorted(gexp.items(), key = lambda x: x[1], reverse = True)
            ranked_genes = [gexp[i][0] for i in range(len(gexp))][:GF_limit]
            
            input_ids = self.tokenize_dataset(gene_set = ranked_genes, token_dict = token_dict, type = ID)
            ans_dict["input_ids"].append(input_ids)
            ans_dict["length"].append(len(input_ids))
            try:
                ans_dict[label].append(row[label_index])
            except:
                ans_dict[label].append(row[label_index - 1])
        
        # Creates pyarrow tabe out of the data
        data = pa.Table.from_arrays([ans_dict[key] for key in list(ans_dict.keys())], names=list(ans_dict.keys()))

        return data, ranked_genes
        
    # Function for tokenizing genes into ranked-value encodings from Geneformer
    def tokenize_dataset(self, gene_set, token_dict, type = None, species = 'human'):
        wrap = True

        if isinstance(gene_set[0], list) == False:
            gene_set = [gene_set]
            wrap = False
            
        pool = Pool()
        converted_set = []

        # Ensembl based searching
        def process_gene(gene):
             api_url = f"https://rest.ensembl.org/xrefs/symbol/{species}/{gene}?object_type=gene"
             response = requests.get(api_url, headers={"Content-Type": "application/json"})
             try:
                 data = response.json()
                 gene = data[0]['id']
             except:
                 gene = None
             return gene
             
        # HGNC ID searching
        def process_hgnc(gene):
            for gene in tqdm.tqdm(genes, total = len(genes)):
                api_url = f"https://rest.ensembl.org/xrefs/symbol/{species}/{hgnc_id}?object_type=gene"
                response = requests.get(api_url, headers={"Content-Type": "application/json"})
                try:
                    data = response.json()
                    gene = data[0]['id']
                except:
                    gene = None
                return gene
                        
        # GO ID searching
        def process_go(gene):
             mg = mygene.MyGeneInfo()
             results = mg.query(gene, scopes="go", species=species, fields="ensembl.gene")
    
             ensembl_ids = []
             max_score = 0
             for hit_num, hit in enumerate(results["hits"]):
                 if hit['_score'] > max_score:
                     max_score = hit['_score']
                     chosen_hit = hit
             try:
                 try:
                     gene = chosen_hit["ensembl"]["gene"]
                 except:
                     gene = chosen_hit["ensembl"][0]["gene"]
             except:
                 gene = None
             return gene
             
        # Selects the ID conversion to ensembl
        if type == None or type.upper() == 'ENSEMBL':
            converted_set = gene_set
        elif type.upper() == 'GENE':
            for genes in gene_set:
                converted_genes = []
                for result in tqdm.tqdm(pool.imap(process_gene, genes), total = len(genes)):
                    converted_genes.append(result)
                converted_set.append(converted_genes)
                
        elif type.upper() == 'GO':
            for genes in gene_set:
                converted_genes = []
                for result in tqdm.tqdm(pool.imap(process_go, genes), total = len(genes)):
                    converted_genes.append(result)
                converted_set.append(converted_genes)
                
        elif type.upper() == 'HGNC':
            for genes in gene_set:
                converted_genes = []
                for result in tqdm.tqdm(pool.imap(process_hgnc, genes), total = len(genes)):
                    converted_genes.append(result)
                converted_set.append(converted_genes)
                
        # Obtains Cheml ENSEMBL names for each gene if possible
        Chembl = []
        
        for set_num, set in enumerate(converted_set):
            Chembl.append([])
            for gene in set:
                if gene == None:
                    Chembl[set_num].append(None)
                else:
                    try:
                        Chembl[set_num].append(token_dict[gene])
                    
                    except:
                        print(f'{gene} not found in tokenized dataset!')
                        Chembl[set_num].append(None)
    
        if wrap == False:
            Chembl = Chembl[0]
        
        return Chembl    

# Function for filtering the sample of samples with much greater or fewer genes
def filter_samples(data):
    gene_columns = data.columns
    
    # Calculate the total count of genes for each sample by iterating through rows
    total_gene_counts = [row.sum() for row in data.select(gene_columns).to_numpy()]
   
    # Create a Series from the list of sums and add it as a new column
    data_with_total = data.with_columns(pl.Series("total_gene_count", total_gene_counts))

    # Calculate the mean and standard deviation of the total counts
    mean_count = data_with_total["total_gene_count"].mean()
    std_dev = data_with_total["total_gene_count"].std()

    # Filter out samples with total count outside the specified range
    lower_bound = mean_count - 3 * std_dev
    upper_bound = mean_count + 3 * std_dev
    filtered_data = data_with_total.filter((pl.col("total_gene_count") >= lower_bound) &
                                           (pl.col("total_gene_count") <= upper_bound))

    # Drop the 'total_gene_count' column if not needed
    filtered_data = filtered_data.drop("total_gene_count")

    return filtered_data

# Function for equalizing labels
def equalize_data(data, label = 'RA'):
    labels = data[label].to_list()
    label_set = list(set(labels))
    freq = {i:labels.count(i) for i in list(set(labels))}
    classes = sorted(freq.items(), key = lambda x: x[1])
    min_class, min_freq = classes[0][0], classes[0][1]
    labels = [i[0] for i in classes]
    labels = [i for i in labels if i != min_class]
    data_columns = data.columns
    
    class_numbers = {key:0 for key in label_set}
    data_rows = []
    for row in data.iter_rows():  
        row_label = row[-1]  
    
        if class_numbers[row_label] < min_freq:
            class_numbers[row_label] += 1
            data_rows.append(row)
    data = pl.DataFrame(data_rows, schema = data_columns)

    return data

# Primary Function

In [3]:
# Primary function for running Geneformer analysis
def format_sci(data, 
               token_dictionary = Path('geneformer/token_dictionary.pkl'), 
               augment = False, 
               noise = None, 
               save = 'Genes.dataset', 
               gene_conversion = Path("geneformer/gene_name_id_dict.pkl"), 
               target_label = "RA", 
               GF_samples = 20000, 
               equalize = True,
               save_img = 'Scipher_Roc_data.png',
               ensembl_convert = True,
               epochs = 50,
               filter_data = False, 
               normalize = True, 
               augment_combine = False,
               keyword  = None):
               
    '''
    KEY FUNCTION PARAMETERS
    ------------------------------------------
    data : csv
        CSV file containing expression/labelled data to be loaded into the model

    augment: bool, default = False
        Chooses whether to create augmented data and use it as a training set (with the true dataset as the test set) or not.
        
    noise : None, float, default = None
        If set to a float, noise equivalent to the noise * the original gene mean for each gene will be applied to the dataset.
        
    save : str, path, default = 'Genes.dataset'
        Save name for the GF-compatible saved dataset created when converting to the proper dataset format.
        
    target_label : str, default = 'RA'
        The name of the column in the csv dataset that contains class labels.
        
    GF_samples : int, default = 20000
        The number of samples to augment in total. Each class is represented equally
    
    equalize : bool, default = True
        Equalizes the dataset so that all classes are represented in equal amounts
        
    save_file : bool, None, default = Stats.png
        Save file for the PR/ROC curve generated
    
    finetuned_model_location : str, default = 'Geneformer-finetuned'
        Location where the finetuned model weights are saved

    epochs : int
        Number of epochs to train Geneformer
        
    filter_data : bool, default = False
        Whether data should be filters for samples falling outside of a standard deviation of gene counts (by default, -3 to 3)
        
    augment_combine : bool, default = True
        If data is augmented, the data is mxixed into the train/test set. If set to false, it will PURELY be used as training data, and the og dataset will be used for testing
        
    '''
    
    cols = []
    conversion = {}
    data = pl.read_csv(data)
        
    token_dict = pk.load(open(token_dictionary, 'rb'))
    gene_dict = pk.load(open(gene_conversion, 'rb'))
    
    if ensembl_convert == True:
        # Converts from gene symbol to ensembl ID
        for column in data.columns:
            try:
                ensembl = gene_dict[column.strip()]
            except:
                continue
            try:
                token_dict[ensembl]
            except:
                continue
            cols.append(column)
            conversion[column] = ensembl

        keep = cols + [target_label]
        data = data.select(keep)
        data = data.rename(conversion)
    
    if equalize == True:
        data = equalize_data(data)
  
    labels = [int(i) for i in list(data[target_label])]
    data = data.sample(fraction = 1.0, shuffle = True)
  
    # Calculates dataset bias
    dataset_bias = 1 - (labels.count(0)/labels.count(1))/2
    
    # Augments data
    augmented_data = None
    if augment == True:
        #augmented_data = augment_data(data = data, selected_label = 'all', num_samples = GF_samples, polars = True, normalize = False)    
        augmented_data = augment_data(data = data, selected_label = 0, num_samples = labels.count(1) - labels.count(0), polars= True, normalize = False)
        augmented_data = augmented_data.sample(fraction = 1.0, shuffle = True)
        
    if augment_combine:
        data = pd.concat((data.to_pandas(), augmented_data.to_pandas()), axis = 0)
        data = pl.from_pandas(data)
        augmented_data = None
        
    # Normalizes data if indicated
    if normalize:
        data = normalize_data(data, polars = True)
        
    if filter_data:
        data = filter_samples(data)
    
    # Converts data to GF-applicable format
    try:
        cell_data = CellData(train = augmented_data, test = data, label = target_label)
    except:
        cell_data = CellData(train = None, test = data, label = target_label)

    cell_data.save_to_disk(save)
    '''
    # If you want to load data instead of creating data for GeneFormer (for replicability), move the 2 lines of save code below to the prior section. 
    cell_data = CellData()
    cell_data.load_from_disk(save)
    '''

    # Selects only genes that are exposed to GeneFormer
    data = data.select(cell_data.ranked_genes + [target_label])
    
    try:
        augmented_data = augmented_data.select(cell_data.ranked_genes + [target_label])
    except:
        pass
    
    # Adds types of noise to data if indicated
    if noise:
        data = add_noise(data, noise = noise)     
        
    data = data.sample(fraction = 1.0, shuffle = True)
    
    # Differential expression analysis
    fpr_de, tpr_de, auc_de = de_analysis(data, label_column = 'RA')
    print(f'Differential Expression AUC: {auc_de}')
    
    # Calculates ROC curve for GeneFormer
    fpr_gf, tpr_gf, auc_gf = finetune_cells(model_location = "/work/ccnr/GeneFormer/GeneFormer_repo", dataset = 'Scipher.dataset', 
                                            epochs = epochs, geneformer_batch_size = 100,
            skip_training = False, label = "RA", inference = False, optimize_hyperparameters = False, device = device,
            emb_extract = False, freeze_layers = 0, output_dir = 'GF-finetuned', max_lr = 1e-3)
    
    # Ensemble models
    fpr_svc, tpr_svc, auc_svc = SVC_model(data,)
    fpr_rf, tpr_rf, auc_rf = RandomForest(data)
    fpr_xg, tpr_xg, auc_xg = XGBoost(data)
    fpr_ffn, tpr_ffn, auc_ffn = FFN(test_data = data)
        
    plt.figure(figsize=(8, 6))
    plt.plot(fpr_rf, tpr_rf, color='darkorange', lw=2, label=f'Random Forest (RF) AUC = {round(auc_rf, 2)}')
    plt.plot(fpr_svc, tpr_svc, color='green', lw=2, label=f'Support Vector Machine (SVM) AUC = {round(auc_svc, 2)}')
    plt.plot(fpr_ffn, tpr_ffn, color = 'blue', lw=2, label = f'Feedforward Neural Network (FFN) AUC = {round(auc_ffn, 2)}')
    plt.plot(fpr_gf, tpr_gf, color = 'red', lw=2, label = f'GeneFormer (GF) AUC = {round(auc_gf, 2)}')
    plt.plot(fpr_de, tpr_de, color = 'black', lw=2, label = f'Differential Expression (DE) AUC = {round(auc_de, 2)}')
    plt.plot(fpr_xg, tpr_xg, color = 'orange', lw=2, label = 'XGBoost (XG) AUC = {round(auc_xg, 2)}')
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='Chance')
    
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curve Comparison')
    plt.legend(loc="lower right")
    
    # Saves FPR/TPR data
    img_data = {
        'Geneformer FPR': fpr_gf,
        'Geneformer TPR':tpr_gf,
        'Feed Forward FPR':fpr_ffn,
        'Feed Forward TPR':tpr_ffn,
        'SVM FPR':fpr_svc,
        'SVM TPR':tpr_svc,
        'RF FPR':fpr_rf,
        'RF TPR':tpr_rf
        
        }
    img_data = pd.DataFrame.from_dict(img_data, orient='index')
    img_data = img_data.transpose()
    
    if keyword == None:
        img_data.to_csv('data_ROC.csv')
        plt.show()
    else:
        img_data.to_csv(f'{keyword}_data_ROC.csv')
        plt.show()
    

# Runtime

In [4]:
run_type = 'arthritis'
if 'arthritis' in run_type:
    format_sci(data = Path("Training_Datasets/GSE97476_unique.csv"), save_img = 'RAROC.png', save = 'Scipher.dataset',
               equalize = True, augment = False, epochs = 30, noise = 0.0, filter_data = True,
                            normalize = False, ensembl_convert = False, keyword = 'arthritis') 
                            
if 'arthritis2' in run_type:
    format_sci(data = Path("Training_Datasets/GSE97810_unique.csv"), save_img = 'RAROC.png', save = 'Scipher.dataset', 
               equalize = True, augment = False, epochs = 30, noise = 0.0,
                            normalize = False, ensembl_convert = False, keyword = 'arthritis2') 
elif 'lung_cancer' in run_type:
    format_sci(data = Path("Training_Datasets/lungCancer.csv"), save_img = 'CancerROC.png', save = 'Scipher.dataset', equalize = True, filter_data = False, noise = 0, 
                                                                                    normalize = False, epochs = 8, ensembl_convert = False,  keyword = 'lung_cancer') 
    
elif 'breast_cancer' in run_type:
    format_sci(data = Path("Training_Datasets/breastCancer.csv"), ensembl_convert = False, 
               save_img = 'CancerROC.png', save = 'Scipher.dataset', equalize = True, filter_data = False, noise = 0, keyword = 'breast_cancer', 
                                                                                    normalize = False, epochs = 20) 
elif 'carcinoma' in run_type:
    format_sci(data = Path("Training_Datasets/Carcinoma.csv"), save_img = 'CancerROC.png', 
                                                                save = 'Scipher.dataset', equalize = True, filter_data = False, noise = 0, keyword = 'carcinoma',
                                                                                    normalize = True, epochs = 20) 
elif 'covid' in run_type:
    format_sci(data = Path("Training_Datasets/COVID_half.csv"), save_img = 'CV_ROC.png', keyword = 'covid',
               save = 'Scipher.dataset', equalize = True, epochs = 5) 
elif 'carcinoma_single' in args.type:
    format_sci(data = Path("Training_Datasets/carcinoma_sample.csv"), save_img = 'CV_ROC.png', keyword = 'carcinoma_single',
               save = 'Scipher.dataset', equalize = True, epochs = 10) 
elif 'amd' in run_type:
    format_sci(data = Path("Training_Datasets/AMD_frac.csv"), save_img = 'CV_ROC.png', keyword = 'macular degen',
               save = 'Scipher.dataset', equalize = True, epochs = 60) 
else:
    format_sci(data = Path("Training_Datasets/singleRA.csv"), save_img = 'ArtitROC.png', save = 'Scipher.dataset', equalize = True, normalize = True, filter_data = True, keyword = 'RA-large',
               epochs = 15) 
        

  gexp[key] /= (median_dict[key] * gene_counts[key])
Saving the dataset (1/1 shards): 100%|██████████| 897/897 [00:00<00:00, 18230.10 examples/s]
Fitting size factors...
... done in 0.15 seconds.

Fitting dispersions...
... done in 1.40 seconds.

Fitting dispersion trend curve...
... done in 0.15 seconds.

Fitting MAP dispersions...
... done in 1.55 seconds.

Fitting LFCs...
... done in 0.65 seconds.

Refitting 0 outliers.

Running Wald tests...
... done in 0.61 seconds.



Log2 fold change & Wald test p-value: condition 1 vs 0
                 baseMean  log2FoldChange     lfcSE      stat    pvalue  \
ENSG00000244115  5.504495        0.018674  0.041086  0.454499  0.649470   
ENSG00000204020  4.891277       -0.001145  0.043583 -0.026282  0.979032   
ENSG00000197769  4.250484       -0.000325  0.046752 -0.006949  0.994456   
ENSG00000141179  6.350873        0.012082  0.038248  0.315881  0.752093   
ENSG00000105173  4.695315       -0.006137  0.044480 -0.137970  0.890264   
...                   ...             ...       ...       ...       ...   
ENSG00000213625  8.941217       -0.005180  0.032236 -0.160707  0.872324   
ENSG00000138386  6.198562        0.002943  0.038714  0.076028  0.939397   
ENSG00000161921  8.496854        0.006874  0.033069  0.207860  0.835338   
ENSG00000184922  5.798247       -0.006533  0.040029 -0.163205  0.870357   
ENSG00000170836  6.164849       -0.006460  0.038819 -0.166412  0.867832   

                     padj  
ENSG00000244115 

Filter:   0%|          | 0/897 [00:00<?, ? examples/s]


RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx