# ArChIPelag — AggRegating multiple position weight matrices and ChIP-seq with machinE leArninG for prediction of transcription factor binding sites

In [None]:
import os
import re
import sys
import argparse
import subprocess
import matplotlib
import time
import pandas as pd
import numpy as np
from itertools import groupby
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, average_precision_score, accuracy_score
from sklearn.metrics import roc_curve
from sklearn.metrics import auc
from sklearn.metrics import precision_recall_curve
import matplotlib.pyplot as plt
from pylab import *
%matplotlib inline
from sklearn.decomposition import PCA
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import GridSearchCV
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, classification_report, confusion_matrix
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.metrics import accuracy_score, confusion_matrix
import pickle
import csv
from collections import defaultdict
import operator
import joblib
from sklearn.ensemble import BaggingClassifier
from sklearn import svm
from sklearn.model_selection import StratifiedKFold
from Bio import SeqIO
from sklearn.feature_selection import RFE
from sklearn import preprocessing
import pybedtools as pbt
import pyBigWig as pbw
import glob
from datetime import date
import random
from xgboost import XGBClassifier
from sklearn.model_selection import KFold
from sklearn.metrics import roc_auc_score
from itertools import chain
import string
import shlex
import shutil
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from sklearn.ensemble import AdaBoostClassifier
from adjustText import adjust_text
import matplotlib.ticker as ticker


In [None]:
# Move to the (redacted) project root and resolve all working directories from it.
os.chdir("...")
root = "..."
basicdir = os.path.abspath('GHTS/')
outputdir = os.path.abspath('GHTS/outputdir')
train_dir = os.path.abspath('GHTS/Train/')
test_dir = os.path.abspath('GHTS/Test/')

# Create each directory only if nothing exists at that path yet.
for _path in (basicdir, outputdir, train_dir, test_dir):
    if not os.path.exists(_path):
        os.makedirs(_path)
del _path

# Tab-separated table without a header row; its contents depend on the (redacted)
# input file. The bare name on the last line displays it in the notebook.
TFs_CHS_AFS = pd.read_csv("....txt", sep="\t", header=None)
TFs_CHS_AFS

In [None]:
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass, field

from typing import List
from sklearn.metrics import roc_auc_score, average_precision_score
import numpy as np

@dataclass
class Scorer(metaclass=ABCMeta):
    """Abstract base for a named metric; concrete subclasses implement ``score``."""

    name: str  # label carried by the scorer (ScorerInfo.make passes its alias here)

    @abstractmethod
    def score(self, *args, **kwargs) -> float:
        """Return the metric value for the given inputs."""


@dataclass
class ConstantScorer(Scorer):
    """Scorer that ignores its inputs and always reports ``const``."""

    const: float  # value returned by every call to score()

    def score(self, *args, **kwargs) -> float:
        """Return the stored constant regardless of the arguments."""
        return self.const


class BinaryScorer(Scorer):
    """Abstract scorer that compares continuous scores with binary labels."""

    @abstractmethod
    def score(self, y_score: List[float], y_real: List[int]) -> float:
        """Return the metric for predictions ``y_score`` against true labels ``y_real``."""
        raise NotImplementedError


class SklearnScorer(BinaryScorer):
    """Marker base for binary scorers backed by scikit-learn metrics."""


class SklearnROCAUC(SklearnScorer):
    """ROC AUC computed with ``sklearn.metrics.roc_auc_score``."""

    def score(self, y_score: List[float], y_real: List[int]) -> float:
        predicted, truth = np.array(y_score), np.array(y_real)
        return float(roc_auc_score(y_true=truth, y_score=predicted))


class SklearnPRAUC(SklearnScorer):
    """PR AUC (average precision) computed with ``sklearn.metrics.average_precision_score``."""

    def score(self, y_score: List[float], y_real: List[int]) -> float:
        predicted, truth = np.array(y_score), np.array(y_real)
        return float(average_precision_score(y_true=truth, y_score=predicted))


class PRROCScorer(BinaryScorer):
    """Marker base for binary scorers backed by the R package PRROC."""

def import_PRROC():
    """Import the R package PRROC through rpy2, installing it from CRAN if missing.

    Package page: https://cran.r-project.org/web/packages/PRROC/index.html

    Returns:
        The rpy2 package object exposing PRROC's R functions.
    """
    from rpy2.robjects.packages import importr, isinstalled

    if isinstalled("PRROC"):
        return importr("PRROC")

    # Not installed yet: pick CRAN mirror index 1 non-interactively and install quietly.
    r_utils = importr("utils")
    r_utils.chooseCRANmirror(ind=1)
    r_utils.install_packages("PRROC", quiet=True, verbose=False)
    return importr("PRROC")

@dataclass
class PRROC_PRAUC(PRROCScorer):
    """Precision-recall AUC computed with the R package PRROC (``pr.curve``).

    ``type`` selects which estimator is read from PRROC's result:
      * ``"integral"``      -> element 1 of the result (computed with ``dg.compute=FALSE``)
      * ``"davisgoadrich"`` -> element 2 of the result (requires ``dg.compute=TRUE``)
    """

    type: str  # "integral" or "davisgoadrich" (ScorerInfo.make lower-cases it)

    def score(self, y_score: List[float], y_real: List[int]) -> float:
        """Return the PR AUC of ``y_score`` against binary labels ``y_real``.

        The labels are passed as PRROC's ``weights.class0``; per the PRROC
        documentation, class 0 is the foreground (positive) class.

        Raises:
            ValueError: if ``self.type`` is not a supported estimator. This is
                checked before rpy2/PRROC are touched.
        """
        if self.type == "integral":
            dg_compute, result_index = False, 1
        elif self.type == "davisgoadrich":
            dg_compute, result_index = True, 2
        else:
            raise ValueError(
                f"Unknown PRROC PR AUC type: {self.type!r}; "
                "expected 'integral' or 'davisgoadrich'"
            )

        from rpy2.rinterface_lib import openrlib
        # The embedded R interpreter is not thread-safe; serialize access to it.
        with openrlib.rlock:
            pkg = import_PRROC()
            from rpy2.robjects.vectors import FloatVector
            labels = FloatVector(list(y_real))
            scores = FloatVector(y_score)
            curve = pkg.pr_curve(scores, weights_class0=labels, dg_compute=dg_compute)
            return float(curve[result_index][0])

class PRROC_ROCAUC(PRROCScorer):
    """ROC AUC computed with the R package PRROC (``roc.curve``)."""

    def score(self, y_score: List[float], y_real: List[int]) -> float:
        """Return the ROC AUC of ``y_score`` against binary labels ``y_real``.

        The labels are passed as PRROC's ``weights.class0``, and the AUC is read
        from element 1 of PRROC's result.
        """
        from rpy2.rinterface_lib import openrlib

        with openrlib.rlock:
            prroc = import_PRROC()
            from rpy2.robjects.vectors import FloatVector
            labels = FloatVector(y_real)
            scores = FloatVector(y_score)
            return prroc.roc_curve(scores, weights_class0=labels)[1][0]

@dataclass
class ScorerInfo:
    """Serializable description of a scorer that ``make()`` turns into a ``Scorer``.

    Supported ``name`` values: ``scikit_rocauc``, ``scikit_prauc``, ``prroc_rocauc``,
    ``prroc_prauc`` (needs ``params["type"]``) and ``constant_scorer``
    (needs ``params["cons"]``).
    """

    name: str  # scorer kind, one of the values listed above
    alias: str = ""  # name given to the built scorer; falls back to `name` when empty
    params: dict = field(default_factory=dict)  # kind-specific options

    @classmethod
    def from_dict(cls, dt: dict):
        """Build a ScorerInfo from a dict with keys ``name``, ``alias`` and ``params``."""
        return cls(**dt)

    def __post_init__(self):
        # dataclasses only invokes __post_init__. An attrs-style __attrs_post_init__
        # hook is never called by dataclasses, so the alias default is applied here.
        if not self.alias:
            self.alias = self.name

    # Kept so any code that calls the attrs-style hook directly still works.
    __attrs_post_init__ = __post_init__

    def make(self):
        """Instantiate the concrete scorer described by this record.

        Raises:
            ValueError: for an unknown ``name`` or when a required param is missing.
        """
        if self.name == "scikit_rocauc":
            return SklearnROCAUC(self.alias)
        elif self.name == "scikit_prauc":
            return SklearnPRAUC(self.alias)
        elif self.name == "prroc_rocauc":
            return PRROC_ROCAUC(self.alias)
        elif self.name == "prroc_prauc":
            tp = self.params.get("type")
            if tp is None:
                raise ValueError("type must be specified for prauc scorer from PRROC package")
            return PRROC_PRAUC(self.alias, tp.lower())
        elif self.name == "constant_scorer":
            cons = self.params.get("cons")
            if cons is None:
                raise ValueError("cons must be specified for constant scorer")
            return ConstantScorer(self.alias, float(cons))
        raise ValueError(f"Wrong scorer: {self.name}")

    def to_dict(self) -> dict:
        """Return a plain-dict form accepted by ``from_dict``.

        ``params`` is shallow-copied so callers cannot mutate this record through it.
        """
        return {"name": self.name, "alias": self.alias, "params": dict(self.params)}
    
import scorer_module


# Shared metric objects used by the plotting code: ROC AUC and integral PR AUC,
# both of which ScorerInfo.make maps to the PRROC-backed scorers.
score_calc_rocauc = ScorerInfo(name="prroc_rocauc", alias="rocauc").make()
score_calc_prauc = ScorerInfo(
    name="prroc_prauc",
    alias="prauc",
    params={"type": "integral"},
).make()


In [None]:
def model_building(X_train, X_test, Y_train, Y_test, model_name):
    """Fit and return a classifier of the requested kind on the training data.

    Args:
        X_train: training feature matrix.
        X_test: unused; kept for interface compatibility.
        Y_train: training labels.
        Y_test: unused; kept for interface compatibility.
        model_name: classifier identifier; only ``"RandomForestClassifier"`` is supported.

    Returns:
        The fitted estimator.

    Raises:
        ValueError: if ``model_name`` is not supported.
    """
    print(" ")
    print("Model_building is running ...")

    if model_name == "RandomForestClassifier":
        # Fixed hyper-parameters; n_jobs is the module-level setting.
        MODEL = RandomForestClassifier(
            max_depth=6,
            max_features=4,
            min_samples_leaf=4,
            min_samples_split=11,
            n_estimators=100,
            n_jobs=n_jobs,
        )
    else:
        raise ValueError(f"Unsupported model_name: {model_name!r}")

    MODEL.fit(X_train, Y_train)
    return MODEL


def _best_feature_index(X, y, scorer):
    """Return the positional index of the column of ``X`` whose values score highest against ``y``.

    Ties are resolved toward the larger column index, matching a descending sort
    of ``(score, index)`` pairs.
    """
    scores = [scorer.score(X[column], y) for column in X.columns]
    return max(range(len(scores)), key=lambda k: (scores[k], k))


def _save_curve(path, x_name, x_values, y_name, y_values):
    """Write two curve coordinate arrays as a two-column, tab-separated table."""
    table = pd.DataFrame()
    table[x_name] = x_values
    table[y_name] = y_values
    table.to_csv(path, sep="\t", index=False)


def _append_metrics_log(path, values):
    """Append ``values`` to ``path`` as one space-separated line."""
    with open(path, "a", encoding="utf-8") as handle:
        handle.write(" ".join(f"{value}" for value in values) + "\n")


def rocauc_plotting(X_train_H, X_test_H, Y_train_H, Y_test_H,
                    X_test_M, Y_test_M,
                    Y_train_predicted_proba_H,
                    Y_test_predicted_proba_H,
                    Y_test_predicted_proba_M,
                    training_obj, testing_obj,
                    features_c,
                    model_name,
                    TF, plot_slim, MODEL, mtrx_di=0, mtrx_mono=0):
    """Plot ROC and PR curves for the best single PWM column and the model, and log AUCs.

    The best single column is chosen on ``X_train_H`` only, separately for ROC AUC
    and PR AUC. The same positional column is then evaluated on ``X_test_H`` and
    ``X_test_M``. Model curves use the supplied predicted probabilities.

    Side effects:
      * curve coordinates are written as TSV files under the global ``new_dir_name``;
      * ROC and PR figures are saved as PDFs in the current working directory;
      * one line of 37 space-separated values is appended to
        ``{TF}_new_log_roc_pr_{mode}_{today_date}.txt``.

    Args:
        X_train_H, Y_train_H: training features / labels of the training experiment.
        X_test_H, Y_test_H: validation features / labels of the training experiment.
        X_test_M, Y_test_M: validation features / labels of the second experiment.
        Y_train_predicted_proba_H, Y_test_predicted_proba_H, Y_test_predicted_proba_M:
            model scores aligned with the matching label vectors.
        training_obj, testing_obj: experiment names used in labels and file names.
        features_c: feature count shown in titles and used in file names.
        model_name: model label used in titles and file names.
        TF: name used as the prefix of output file names.
        plot_slim, MODEL, mtrx_di, mtrx_mono: unused; kept for interface compatibility.

    Relies on the module-level globals ``mode``, ``new_dir_name``, ``today_date``,
    ``score_calc_rocauc`` and ``score_calc_prauc``.
    """
    print(" ")
    print("ROC_AUC_plotting is running ...")
    fig = plt.gcf()
    fig.set_size_inches(16, 10)
    linewidth = 5
    color_palette = sns.color_palette("tab10")
    out_prefix = f"{new_dir_name}/{TF}_{model_name}_{features_c}"
    title = '{model_name} besthits. {training_obj}/{testing_obj}. {N} features'.format(
        model_name=model_name, N=features_c, training_obj=training_obj, testing_obj=testing_obj)

    # Best single PWM column, selected on training data only (no test-set leakage).
    f_c_tr_h = _best_feature_index(X_train_H, Y_train_H, score_calc_rocauc)
    f_c_tr_h_pr = _best_feature_index(X_train_H, Y_train_H, score_calc_prauc)

    # ---------------- ROC: best single PWM ----------------
    train_pwm = X_train_H[X_train_H.columns[f_c_tr_h]]
    fpr, tpr, _ = roc_curve(Y_train_H, train_pwm)
    roc_auc_train_H_PWM_mono = score_calc_rocauc.score(train_pwm, Y_train_H)
    plt.plot(fpr, tpr, linewidth=linewidth, label=f'Train PWM {training_obj} ROC (by roc) (AUC = %0.3f)' % roc_auc_train_H_PWM_mono, color=color_palette[0])

    test_h_pwm = X_test_H[X_test_H.columns[f_c_tr_h]]
    fpr, tpr, _ = roc_curve(Y_test_H, test_h_pwm)
    roc_auc_test_H_PWM_mono = score_calc_rocauc.score(test_h_pwm, Y_test_H)
    plt.plot(fpr, tpr, linewidth=linewidth, label=f'Validation PWM {training_obj} ROC (AUC = %0.3f)' % roc_auc_test_H_PWM_mono, color=color_palette[1])

    test_m_pwm = X_test_M[X_test_M.columns[f_c_tr_h]]
    fpr, tpr, _ = roc_curve(Y_test_M, test_m_pwm)
    roc_auc_test_M_PWM_mono = score_calc_rocauc.score(test_m_pwm, Y_test_M)
    plt.plot(fpr, tpr, linewidth=linewidth, label=f'Validation PWM {testing_obj} ROC (AUC = %0.3f)' % roc_auc_test_M_PWM_mono, color=color_palette[2])

    # ---------------- ROC: trained model ----------------
    fpr, tpr, _ = roc_curve(Y_test_M, Y_test_predicted_proba_M)
    roc_auc_test_M = score_calc_rocauc.score(Y_test_predicted_proba_M, Y_test_M)
    plt.plot(fpr, tpr, linewidth=linewidth, label=f'Validation {model_name}_{mode} {testing_obj} ROC (AUC = %0.3f)' % roc_auc_test_M, color=color_palette[3])
    _save_curve(out_prefix + "_DF_ROC_PLOTS_M_mono_di_test.csv", "fpr", fpr, "tpr", tpr)

    fpr, tpr, _ = roc_curve(Y_train_H, Y_train_predicted_proba_H)
    roc_auc_train_H = score_calc_rocauc.score(Y_train_predicted_proba_H, Y_train_H)
    plt.plot(fpr, tpr, linewidth=linewidth, label=f'Train {model_name}_{mode} {training_obj} ROC (AUC = %0.3f)' % roc_auc_train_H, color=color_palette[4])
    _save_curve(out_prefix + "_DF_ROC_PLOTS_H_mono_di_train.csv", "fpr", fpr, "tpr", tpr)

    fpr, tpr, _ = roc_curve(Y_test_H, Y_test_predicted_proba_H)
    roc_auc_test_H = score_calc_rocauc.score(Y_test_predicted_proba_H, Y_test_H)
    plt.plot(fpr, tpr, linewidth=linewidth, label=f'Validation {model_name}_{mode} {training_obj} ROC (AUC = %0.3f)' % roc_auc_test_H, color=color_palette[5])
    _save_curve(out_prefix + "_DF_ROC_PLOTS_H_mono_di_test.csv", "fpr", fpr, "tpr", tpr)

    sns.set_context("paper", font_scale=3)
    plt.rc('legend', fontsize=12)  # size in points
    plt.plot([0, 1], [0, 1], 'k--')  # chance-level (random classifier) diagonal
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.0])
    plt.xlabel('False Positive Rate', fontsize=25)
    plt.ylabel('True Positive Rate', fontsize=25)
    plt.title(title, fontsize=20)
    plt.legend(loc="lower right")
    plt.savefig('ROC_AUC_{model_name}_{mode}_Train_{training_obj}_Test_{testing_obj}_{N}_features_{mode}.pdf'.format(mode=mode, model_name=model_name, N=features_c, training_obj=training_obj, testing_obj=testing_obj), dpi=100)
    plt.show()
    plt.close()

    print(" ")
    print("PR_plotting is running ...")
    fig = plt.gcf()
    fig.set_size_inches(16, 10)

    # ---------------- PR: best single PWM ----------------
    train_pwm = X_train_H[X_train_H.columns[f_c_tr_h_pr]]
    precision, recall, _ = precision_recall_curve(Y_train_H, train_pwm)
    pr_auc_train_H_PWM_mono = score_calc_prauc.score(train_pwm, Y_train_H)
    plt.plot(recall, precision, linewidth=linewidth, label=f'Train PWM {training_obj} PR (by pr) (AUC = %0.3f)' % pr_auc_train_H_PWM_mono, color=color_palette[0])

    test_h_pwm = X_test_H[X_test_H.columns[f_c_tr_h_pr]]
    precision, recall, _ = precision_recall_curve(Y_test_H, test_h_pwm)
    pr_auc_test_H_PWM_mono = score_calc_prauc.score(test_h_pwm, Y_test_H)
    plt.plot(recall, precision, linewidth=linewidth, label=f'Validation PWM {training_obj} PR (AUC = %0.3f)' % pr_auc_test_H_PWM_mono, color=color_palette[1])

    test_m_pwm = X_test_M[X_test_M.columns[f_c_tr_h_pr]]
    precision, recall, _ = precision_recall_curve(Y_test_M, test_m_pwm)
    pr_auc_test_M_PWM_mono = score_calc_prauc.score(test_m_pwm, Y_test_M)
    plt.plot(recall, precision, linewidth=linewidth, label=f'Validation PWM {testing_obj} PR (AUC = %0.3f)' % pr_auc_test_M_PWM_mono, color=color_palette[2])

    # ---------------- PR: trained model ----------------
    precision, recall, _ = precision_recall_curve(Y_test_M, Y_test_predicted_proba_M)
    pr_auc_test_M = score_calc_prauc.score(Y_test_predicted_proba_M, Y_test_M)
    plt.plot(recall, precision, linewidth=linewidth, label=f'Validation {testing_obj} PR (AUC = %0.3f )' % pr_auc_test_M, color=color_palette[3])
    _save_curve(out_prefix + "_DF_PR_PLOTS_M_mono_di_test.csv", "precision", precision, "recall", recall)

    precision, recall, _ = precision_recall_curve(Y_train_H, Y_train_predicted_proba_H)
    pr_auc_train_H = score_calc_prauc.score(Y_train_predicted_proba_H, Y_train_H)
    plt.plot(recall, precision, linewidth=linewidth, label=f'Train {training_obj} PR (AUC = %0.3f )' % pr_auc_train_H, color=color_palette[4])
    _save_curve(out_prefix + "_DF_PR_PLOTS_H_mono_di_train.csv", "precision", precision, "recall", recall)

    precision, recall, _ = precision_recall_curve(Y_test_H, Y_test_predicted_proba_H)
    pr_auc_test_H = score_calc_prauc.score(Y_test_predicted_proba_H, Y_test_H)
    plt.plot(recall, precision, linewidth=linewidth, label=f'Validation {training_obj} PR (AUC = %0.3f )' % pr_auc_test_H, color=color_palette[5])
    _save_curve(out_prefix + "_DF_PR_PLOTS_H_mono_di_test.csv", "precision", precision, "recall", recall)

    sns.set_context("paper", font_scale=3)
    plt.rc('legend', fontsize=12)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.0])
    plt.xlabel('Recall', fontsize=25)
    plt.ylabel('Precision', fontsize=25)
    plt.title(title, fontsize=20)
    plt.legend(loc="lower left")
    plt.savefig('PR_AUC_{model_name}_{mode}_Train_{training_obj}_Test_{testing_obj}_{N}_features_{mode}.pdf'.format(mode=mode, model_name=model_name, N=features_c, training_obj=training_obj, testing_obj=testing_obj), dpi=100)

    # The *_PWM_di metrics (presumably dinucleotide PWMs) and the extra per-model
    # slots are not computed here. Zeros keep the log's 37-column layout fixed.
    no_di = 0
    metrics = [
        features_c,
        roc_auc_train_H_PWM_mono, no_di, roc_auc_test_H_PWM_mono, no_di, roc_auc_test_M_PWM_mono, no_di,
        roc_auc_train_H, 0, 0, 0,
        roc_auc_test_H, 0, 0, 0,
        roc_auc_test_M, 0, 0, 0,
        pr_auc_train_H, 0, 0, 0,
        pr_auc_test_H, 0, 0, 0,
        pr_auc_test_M, 0, 0, 0,
        pr_auc_train_H_PWM_mono, no_di, pr_auc_test_H_PWM_mono, no_di, pr_auc_test_M_PWM_mono, no_di,
    ]
    # Written directly instead of via `echo ... >>` under shell=True, so TF names
    # are never interpreted by a shell.
    _append_metrics_log(f"{TF}_new_log_roc_pr_{mode}_{today_date}.txt", metrics)
    plt.show()
    plt.close()



def Scale_transform(X, scale_data):
    """Return a copy of ``X``, standardized column-wise when ``scale_data == True``.

    When scaling, a ``StandardScaler`` is fitted on ``X`` itself and its per-column
    mean and variance are printed. The result is then rebuilt as a DataFrame with
    ``X``'s column labels and a fresh default index. Otherwise a plain copy of
    ``X`` is returned with its index intact.
    """
    copied = X.copy()
    if not (scale_data == True):
        return copied

    fitted = StandardScaler().fit(copied)
    print("mean", fitted.mean_)
    print("var", fitted.var_)
    return pd.DataFrame(fitted.transform(copied), columns=copied.columns)

    
def scrambled(orig):
    """Return a shuffled copy of ``orig``, leaving ``orig`` itself untouched.

    ``copy.copy`` is used instead of ``orig[:]``. Slicing a NumPy array returns a
    view, so shuffling that view in place would also reorder the caller's array.
    For lists the result is the same shallow copy as before. ``shuffle`` is the
    module-level name, presumably numpy.random.shuffle from the pylab star
    import; it shuffles in place.
    """
    import copy

    dest = copy.copy(orig)
    shuffle(dest)
    return dest



In [None]:
from sklearn.preprocessing import StandardScaler
    

def _collect_scanning_table(pwmdir, flag, root, pwm_scanning_res_list):
    """Read the first column of each ``<name>_<flag>_cut.tab`` file under ``root/pwmdir``.

    Returns a DataFrame whose column ``i`` holds the first tab-separated column of
    the i-th file in ``pwm_scanning_res_list``. Rows are aligned by line number;
    shorter files are padded with NaN. An empty list yields an empty DataFrame.
    """
    columns = {}
    for i, name in enumerate(pwm_scanning_res_list):
        path = f"{root}/{pwmdir}/{name}_{flag}_cut.tab"
        columns[i] = pd.read_csv(path, header=None, sep='\t')[0]
    return pd.DataFrame(columns)


def collect_all_scanning_res(TF, exp, dataset, flag, root, mode, model_name, pwm_scanning_res_list):
    """Load scan results from ``{exp}/{dataset}/{TF}/pwm_scanning_results_addshift``.

    ``mode`` and ``model_name`` are unused; they are kept for interface compatibility.
    """
    pwmdir = f"{exp}/{dataset}/{TF}/pwm_scanning_results_addshift"
    return _collect_scanning_table(pwmdir, flag, root, pwm_scanning_res_list)


def collect_all_scanning_res_CHS_scanning_with_GHTS_pwm(TF, exp, dataset, flag, root, mode, model_name, pwm_scanning_res_list):
    """Load scan results from ``{exp}/{dataset}/{TF}/pwm_scanning_results_CHS_scanning_with_GHTS_pwm_addshift``.

    ``mode`` and ``model_name`` are unused; they are kept for interface compatibility.
    """
    pwmdir = f"{exp}/{dataset}/{TF}/pwm_scanning_results_CHS_scanning_with_GHTS_pwm_addshift"
    return _collect_scanning_table(pwmdir, flag, root, pwm_scanning_res_list)


def collect_all_scanning_res_GHTS_scanning_with_CHS_pwm(TF, exp, dataset, flag, root, mode, model_name, pwm_scanning_res_list):
    """Load scan results from ``{exp}/{dataset}/{TF}/pwm_scanning_results_GHTS_scanning_with_CHS_pwm_addshift``.

    ``mode`` and ``model_name`` are unused; they are kept for interface compatibility.
    """
    pwmdir = f"{exp}/{dataset}/{TF}/pwm_scanning_results_GHTS_scanning_with_CHS_pwm_addshift"
    return _collect_scanning_table(pwmdir, flag, root, pwm_scanning_res_list)



   
organism = "HUMAN_CHS_GHTS"
n_jobs = 100  # read as a global by model_building
scale_data = True
print_GC = False
verbose = True

# Fixed run date used in log file names (read as a global by rocauc_plotting).
today_date = "23.12.2023"

mode = "mono"  # read as a global by rocauc_plotting for labels and file names
model_name = "RandomForestClassifier"
# Call the function (assigning to sys.setrecursionlimit would replace it with an int).
# max() ensures an already-higher limit is never lowered.
sys.setrecursionlimit(max(sys.getrecursionlimit(), 10**3))


root = "..."
os.chdir(root)

model_dir = os.path.abspath('...')
basicdir = model_dir
new_dir_name = model_dir  # read as a global by rocauc_plotting for curve TSV output


if not os.path.exists(model_dir):
    os.makedirs(model_dir)

os.chdir(model_dir)


# NOTE(review): files_pwm_HUMAN_mono is not defined in this section; it must come
# from an earlier cell.
TF_calc = 0
for TF in files_pwm_HUMAN_mono:
    print(" ")
    print(f"{TF} # {TF_calc} of {len(files_pwm_HUMAN_mono)}")
    TF_calc += 1

    # Gather PWM base names (file name before ".pwm") from every sub-directory for this TF.
    pwmdir_mono = "/home/ivankozin/projects/best_20_motif_CHS_AFS/AFS"
    dir_list = os.listdir(pwmdir_mono + "/" + TF)
    pwm_names = set()
    for exp_pwm_dir in dir_list:
        pwm_local_dir = pwmdir_mono + "/" + TF + "/" + exp_pwm_dir
        pwm_local_list = [x.split(".pwm")[0] for x in os.listdir(pwm_local_dir) if "pwm" in x]
        pwm_names.update(pwm_local_list)
    # Sorted (not raw set order, which depends on hash seeding) so the feature-column
    # order, and therefore the pickled model's expected inputs, is reproducible.
    pwm_scanning_res_list = sorted(pwm_names)

    # Training set: "positives" files labelled 1, "random" files labelled 0.
    df_pos = collect_all_scanning_res(TF, "GHTS", "Train", "positives", root, mode, model_name, pwm_scanning_res_list)
    df_neg = collect_all_scanning_res(TF, "GHTS", "Train", "random", root, mode, model_name, pwm_scanning_res_list)
    df_pos["ind"] = 1
    df_neg["ind"] = 0
    df = pd.concat([df_pos, df_neg])

    X_train_GHTS, Y_train_GHTS = df.iloc[:, :-1], df["ind"]  # "ind" is the last column

    # Standardize with statistics fitted on the GHTS training set only; the same
    # scaler is reused for every test set below.
    X_out = X_train_GHTS.copy()
    scaler = StandardScaler()
    scaler_fit = scaler.fit(X_out)
    X_train_GHTS = pd.DataFrame(scaler_fit.transform(X_train_GHTS), columns=X_train_GHTS.columns)

    df_pos = collect_all_scanning_res(TF, "GHTS", "Test", "positives", root, mode, model_name, pwm_scanning_res_list)
    df_neg = collect_all_scanning_res(TF, "GHTS", "Test", "random", root, mode, model_name, pwm_scanning_res_list)
    df_pos["ind"] = 1
    df_neg["ind"] = 0
    df = pd.concat([df_pos, df_neg])

    X_test_GHTS, Y_test_GHTS = df.iloc[:, :-1], df["ind"]
    X_test_GHTS = pd.DataFrame(scaler_fit.transform(X_test_GHTS), columns=X_test_GHTS.columns)

    df_pos = collect_all_scanning_res_CHS_scanning_with_GHTS_pwm(TF, "CHS", "Test", "positives", root, mode, model_name, pwm_scanning_res_list)
    df_neg = collect_all_scanning_res_CHS_scanning_with_GHTS_pwm(TF, "CHS", "Test", "random", root, mode, model_name, pwm_scanning_res_list)
    df_pos["ind"] = 1
    df_neg["ind"] = 0
    df = pd.concat([df_pos, df_neg])

    X_test_CHS, Y_test_CHS = df.iloc[:, :-1], df["ind"]
    X_test_CHS = pd.DataFrame(scaler_fit.transform(X_test_CHS), columns=X_test_CHS.columns)

    MODEL = model_building(X_train_GHTS, X_test_GHTS, Y_train_GHTS, Y_test_GHTS, model_name)

    filename = f"{basicdir}/{TF}_trained_model.sav"
    with open(filename, 'wb') as model_file:  # closed deterministically, unlike a bare open()
        pickle.dump(MODEL, model_file)

    # Probability of the positive class (column 1 of predict_proba).
    Y_train_predicted_proba_GHTS = MODEL.predict_proba(X_train_GHTS)[:, 1]
    Y_test_predicted_proba_GHTS = MODEL.predict_proba(X_test_GHTS)[:, 1]
    Y_test_predicted_proba_CHS = MODEL.predict_proba(X_test_CHS)[:, 1]

    rocauc_plotting(X_train_GHTS, X_test_GHTS, Y_train_GHTS, Y_test_GHTS,
                    X_test_CHS, Y_test_CHS,
                    Y_train_predicted_proba_GHTS,
                    Y_test_predicted_proba_GHTS,
                    Y_test_predicted_proba_CHS,
                    "GHTS",
                    "CHS",
                    X_test_CHS.shape[1],
                    model_name + " on " + TF,
                    TF, False, MODEL)


print("Done!")