# In [None]:
"""
This script implements Phase I of the BMGI algorithm. More details can be found in the BMGI paper.

A Random Forest is used to score each gene by its classification power,
using the expression values of the gene's isoforms as features.

Required Inputs:
1- Annotated gene sets. Each gene is represented by an Ensembl ID (e.g. ENSG00000139618).
2- TEM: processed single-cell transcript expression data, as a CSV file (rows are samples, e.g. cells; columns are transcripts).
3- GEM: processed single-cell gene expression data, as a CSV file (rows are samples, e.g. cells; columns are genes).
4- Sample labels (each sample, e.g. cell, belongs to a class, so every sample must have a label).

Optional inputs:
5- P-value threshold (default = 0.001).
6- Percentage threshold (default = 0.6).
7- Number of trees for the Random Forest (default = 100).
8- Number of folds for cross-validation (default = 10).

Outputs:
1- TODO: document the outputs produced by this phase.



# Example:

python Model/BMGI_Phase_I.py \
    --gene_sets    /home/maburid/AMIA_project/Transc_stem_cell.csv \
    --TEM_dataset  /home/maburid/AMIA_project/Transc_stem_cell.csv \
    --GEM_dataset  /home/maburid/AMIA_project/Gene_stem_cell.csv \
    --labels       /home/maburid/AMIA_project/labels.csv \
    --num_trees    50 \
    --cv           5

"""


import gc
gc.collect()
import numpy as np
import sklearn
import sklearn.ensemble
import sklearn.mixture
import scipy

import random
import sys
import json
import time
import mygene
import gseapy
from gseapy import parser 
import pandas as pd                        
import xlrd        
import os
from sklearn import preprocessing
import operator




if __name__ == "__main__":
    # `argparse` is not imported at module level (and the module-level name
    # `parser` imported from gseapy is shadowed below), so import it here;
    # without this the script fails immediately with a NameError.
    import argparse

    # parse command-line arguments
    parser = argparse.ArgumentParser(description="Select genes which perform significantly better than an equivalent random gene for each annotated set.")
    parser.add_argument("--gene_sets", help="list of annotated gene sets", required=True)
    parser.add_argument("--TEM_dataset", help="TEM input dataset (samples x transcripts)", required=True)
    parser.add_argument("--GEM_dataset", help="GEM input dataset (samples x genes)", required=True)
    parser.add_argument("--labels", help="list of sample labels", required=True)

    parser.add_argument("--num_trees", help="number of trees in random forest", type=int, default=100)
    parser.add_argument("--cv", help="number of folds for k-fold cross validation", type=int, default=10)
    parser.add_argument("--p_threshold", help="maximum p-value required for a gene to be selected", type=float, default=0.001)
    parser.add_argument(
        "--percent_threshold",
        # Implicit string concatenation rather than a backslash continuation
        # inside the literal, which leaked the next line's indentation into
        # the --help output.
        help="minimum percentage of filtered genes in a set relative to the "
             "original genes; if below this threshold, the set is not counted "
             "as significant",
        type=float,
        default=0.6,
    )

    args = parser.parse_args()

    # load input data
    print("loading input datasets...")
    filename_TEM = args.TEM_dataset
    filename_GEM = args.GEM_dataset

    # Per the --help text above: TEM is samples x transcripts, GEM is
    # samples x genes. NOTE(review): read_csv uses default options, so any
    # leading sample-ID column is kept as a data column — confirm intended.
    Trans_data = pd.read_csv(filename_TEM)
    Gene_data = pd.read_csv(filename_GEM)
    print("Genes_level_data_shape:  ", Gene_data.shape)
    # Previously mislabelled as "Genes_level_data_shape".
    print("Transcripts_level_data_shape:  ", Trans_data.shape)

    
    