In [1]:
import os
import sys

import pandas as pd
import numpy as np
import sklearn
from hmmlearn import hmm
try:  # hmmlearn > 0.2.7 exposes the categorical model as CategoricalHMM
    from hmmlearn.hmm import CategoricalHMM as MultinomialHMM
except ImportError:  # hmmlearn <= 0.2.7 only provides MultinomialHMM
    # Only a failed import should trigger the fallback; any other error
    # (e.g. a broken installation) is allowed to surface.
    from hmmlearn.hmm import MultinomialHMM
import matplotlib.pyplot as plt
from tqdm import tqdm
import seaborn as sns
from sklearn.model_selection import KFold
import pickle

sys.path.insert(0, '../scripts')
from flaHMM_functions import *

In [2]:
# Load the NX statistics table; the first CSV column is used as the row index.
NX_all = pd.read_csv(
    '../data/NX_stats.csv',
    index_col=0,
)

In [3]:
# First list of external test assemblies, written as '<species>.<build>'.
# One row per species prefix; the order is kept because results are
# produced in list order.
test = [
    'Dbia.GCF_018148935', 'Dbia.d101g', 'Dbia.d15genomes',
    'Dere.GCF_003286155', 'Dere.d101g', 'Dere.droEre1', 'Dere.d15genomes',
    'Dsuz.GCF_013340165',
    'Dtei.GCF_016746235', 'Dtei.d101g_2733', 'Dtei.d101g_CT02',
]

In [4]:
# Second list of external test assemblies ('<species>.<build>').
# It is concatenated with `test` when predictions are made, so the order
# here determines the order of the species in the results.
test2 = [
    'Dfic.GCF_018152265', 'Dfic.d101g', 'Dfic.GCF_000220665',
    'Dosh.d101g',
    'Dath.GCA_008121215',
    'Dazt.GCA_005876895',
    'Dmir.GCF_003369915',
    'Dper.GCF_003286085', 'Dper.d101g', 'Dper.d15genomes',
    'Dpse.d15genomes', 'Dpse.GCF_009870125',
    'Dinn.GCF_004354385',
    'Damb.d101g',
    'Dbif.GCA_009664405',
    'Dobs.d101g', 'Dobs.GCF_018151105',
    'Dtris.d101g',
]

In [5]:
def _lookup_strand_ref(species_test):
    """Return the reference strand recorded for ``species_test``, or None.

    The key is ``species_test`` with '.' replaced by '_build_'. The lookup is
    best-effort: ``build2coords`` is tried first, then
    ``build2coords_flamlike``; any failure (missing table, missing key,
    malformed entry) falls through to the next option.
    """
    build_key = species_test.replace('.', '_build_')
    try:
        return build2coords[build_key][1]
    except Exception:
        pass
    try:
        return build2coords_flamlike[build_key][1]
    except Exception:
        return None


def _decode_strand(model, observations, strand):
    """Return (state path, state posteriors) for one chromosome on one strand.

    Minus-strand observations are decoded in reverse order and the results
    are flipped back, so the outputs always line up with the input rows.
    """
    if strand == 'plus':
        return model.predict(observations), model.predict_proba(observations)
    if strand == 'minus':
        # Swap and re-swap to retain orientation
        flipped = np.flip(observations, axis=0)
        return model.predict(flipped)[::-1], model.predict_proba(flipped)[::-1]
    raise ValueError("Error: Unknown strand")


def train_combination_models(data, model_name, threshold):
    """Predict HMM states for every external test species with a saved model.

    Despite its name, this function does not train anything: it loads the
    pickled model ``models_pkl/<model_name>`` and applies it to each species
    listed in ``test + test2``.

    Parameters
    ----------
    data : str
        Path to a tab-separated table (first column is the index). The columns
        'Data', 'chr', 'bin_start', 'bin_end' and 'region_binary' are read.
    model_name : str
        File name of the pickled model inside ``models_pkl/``.
    threshold : float
        Passed to ``calculate_3emissions`` and used as the suffix of the
        output column names.

    Returns
    -------
    pandas.DataFrame
        Rows for all species, indexed by (chr, bin_start, bin_end, strand,
        species_test, region_binary), with the columns 'emission_<t>',
        'pred_<t>', 'proba_NoCluster_<t>', 'proba_Cluster_<t>' and
        'proba_Centromere_<t>' (model states 0, 1 and 2 respectively).

    Raises
    ------
    ValueError
        If the number of emissions does not equal twice the number of bins.
    """
    all_data_species = pd.read_csv(data, sep='\t', index_col=0)

    # The model does not depend on the species, so it is loaded once per call.
    # NOTE(review): pickle.load can execute arbitrary code; only load model
    # files that come from a trusted source.
    with open('models_pkl/' + model_name, "rb") as file:
        model = pickle.load(file)

    suffix = str(threshold)
    index_columns = ['chr', 'bin_start', 'bin_end', 'strand', 'species_test', 'region_binary']
    per_species_results = []

    for species_test in test + test2:
        # Read Test Data
        all_data = all_data_species[all_data_species['Data'] == species_test]
        strand_ref = _lookup_strand_ref(species_test)

        # Calculate emission for the test set.
        # NOTE(review): the emissions are assumed to list all plus-strand bins
        # followed by all minus-strand bins, matching new_index below - confirm
        # against calculate_3emissions.
        emission_new = calculate_3emissions(all_data, threshold)

        bins = all_data[['chr', 'bin_start', 'bin_end']]
        new_index = pd.concat(
            [bins.assign(strand='plus'), bins.assign(strand='minus')]
        ).reset_index(drop=True)
        emission_frame = pd.DataFrame(emission_new)
        if len(emission_frame) != len(new_index):
            raise ValueError(
                'Expected %d emissions for %s but got %d'
                % (len(new_index), species_test, len(emission_frame))
            )
        X_test = pd.concat([new_index, emission_frame], axis=1)

        # Results are written back by row position, so they stay aligned with
        # X_test even when the bins of a chromosome are not contiguous.
        n_rows = len(X_test)
        predictions = np.full(n_rows, -1, dtype=int)
        probabilities = np.full((n_rows, 3), np.nan)
        emissions = X_test[0].to_numpy()

        # One pass over the table yields the row positions of every
        # (strand, chromosome) sequence; each is decoded independently.
        groups = X_test.groupby(['strand', 'chr'], sort=False).indices
        for (strand, _chromosome), positions in groups.items():
            observations = emissions[positions].reshape(-1, 1)
            states, posteriors = _decode_strand(model, observations, strand)
            predictions[positions] = states
            probabilities[positions] = posteriors[:, :3]

        X_test['pred_' + suffix] = predictions
        X_test['proba_NoCluster_' + suffix] = probabilities[:, 0]
        X_test['proba_Cluster_' + suffix] = probabilities[:, 1]
        X_test['proba_Centromere_' + suffix] = probabilities[:, 2]
        X_test = X_test.rename(columns={0: 'emission_' + suffix})

        X_test['species_test'] = species_test

        # Labels for the reference strand are kept; on the other strand every
        # 1 is replaced by 0. Without strand information both strands keep the
        # original labels.
        labelled = all_data['region_binary'].tolist()
        cleared = all_data['region_binary'].replace(1, 0).tolist()
        if strand_ref == 'plus':
            region_binary_test = labelled + cleared
        elif strand_ref == 'minus':
            region_binary_test = cleared + labelled
        elif strand_ref == 'both':
            region_binary_test = labelled + labelled
        else:
            print("Error: Stand information failed.")
            region_binary_test = labelled + labelled

        X_test['region_binary'] = region_binary_test

        per_species_results.append(X_test.set_index(index_columns))

    if not per_species_results:
        return pd.DataFrame()
    # Concatenate once instead of growing a DataFrame inside the loop.
    return pd.concat(per_species_results, axis=0)

In [6]:
# List the pickled models available in models_pkl/; the prediction loop
# further down builds file names of the form
# Model_bin_<bin size>_threshold_<threshold>.pkl.
! ls models_pkl/

Model_bin_10k_threshold_0.025.pkl   Model_bin_2.5k_threshold_0.1.pkl
Model_bin_10k_threshold_0.05.pkl    Model_bin_2.5k_threshold_0.2.pkl
Model_bin_10k_threshold_0.075.pkl   Model_bin_5k_threshold_0.025.pkl
Model_bin_10k_threshold_0.1.pkl     Model_bin_5k_threshold_0.05.pkl
Model_bin_10k_threshold_0.2.pkl     Model_bin_5k_threshold_0.075.pkl
Model_bin_2.5k_threshold_0.025.pkl  Model_bin_5k_threshold_0.1.pkl
Model_bin_2.5k_threshold_0.05.pkl   Model_bin_5k_threshold_0.2.pkl
Model_bin_2.5k_threshold_0.075.pkl


In [7]:
# Create the output folders if they don't exist.
# os.makedirs also creates the parent 'results_combinations', and exist_ok=True
# replaces the separate isdir checks (no failure if the folder already exists
# or is created between the check and the mkdir).
os.makedirs('results_combinations/ext', exist_ok=True)

In [8]:
# Make predictions using the saved models for the external test set.
# One tab-separated results file is written per (threshold, bin size) pair.
# X_test_all_results is reassigned on every pass, so after the loop it only
# holds the results of the last combination.
for threshold in tqdm([0.025, 0.05, 0.075, 0.1], desc='threshold'):
    for data_ext in tqdm(['all_data_species_extendedList10k.txt',
                          'all_data_species_extendedList5k.txt',
                          'all_data_species_extendedList2.5k.txt'],
                         desc='bin size'):
        # The bin size ('10k', '5k', '2.5k') is the part of the file name
        # between the common prefix and the '.txt' extension.
        get_bin_size = data_ext.split('all_data_species_extendedList')[1].split('.txt')[0]
        model_make_pred = 'Model_bin_' + get_bin_size + '_threshold_' + str(threshold) + '.pkl'

        X_test_all_results = train_combination_models(data_ext, model_make_pred, threshold)

        output_path = ('results_combinations/ext/X_test_all_extendedList_Bin_'
                       + get_bin_size + '_threshold_' + str(threshold) + '.txt')
        X_test_all_results.reset_index().to_csv(output_path, sep='\t')
        

  0%|          | 0/4 [00:00<?, ?it/s]
  0%|          | 0/3 [00:00<?, ?it/s][A
 33%|███▎      | 1/3 [11:14<22:27, 674.00s/it][A
 67%|██████▋   | 2/3 [22:14<11:05, 665.99s/it][A
100%|██████████| 3/3 [42:27<00:00, 849.30s/it][A
 25%|██▌       | 1/4 [42:27<2:07:23, 2547.90s/it]
  0%|          | 0/3 [00:00<?, ?it/s][A
 33%|███▎      | 1/3 [11:14<22:28, 674.13s/it][A
 67%|██████▋   | 2/3 [22:23<11:11, 671.18s/it][A
100%|██████████| 3/3 [42:41<00:00, 853.83s/it][A
 50%|█████     | 2/4 [1:25:09<1:25:11, 2555.90s/it]
  0%|          | 0/3 [00:00<?, ?it/s][A
 33%|███▎      | 1/3 [11:08<22:16, 668.24s/it][A
 67%|██████▋   | 2/3 [22:11<11:05, 665.34s/it][A
100%|██████████| 3/3 [41:50<00:00, 836.68s/it][A
 75%|███████▌  | 3/4 [2:06:59<42:14, 2534.96s/it]  
  0%|          | 0/3 [00:00<?, ?it/s][A
 33%|███▎      | 1/3 [10:34<21:09, 634.93s/it][A
 67%|██████▋   | 2/3 [21:12<10:36, 636.24s/it][A
100%|██████████| 3/3 [40:31<00:00, 810.46s/it][A
100%|██████████| 4/4 [2:47:30<00:00, 2512.71

In [9]:
# Display the predictions left over from the final iteration of the prediction
# loop: X_test_all_results is overwritten on every pass, so only the last
# (threshold, bin size) combination is shown here.
X_test_all_results

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,Unnamed: 3_level_0,Unnamed: 4_level_0,Unnamed: 5_level_0,emission_0.1,pred_0.1,proba_NoCluster_0.1,proba_Cluster_0.1,proba_Centromere_0.1
chr,bin_start,bin_end,strand,species_test,region_binary,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1
chrUn_025319170,0,2500,plus,Dbia.GCF_018148935,0,0,0,0.991009,6.117517e-04,0.008379
chrUn_025319170,2500,4719,plus,Dbia.GCF_018148935,0,0,0,0.990990,6.105655e-04,0.008400
chrUn_025319171,0,2383,plus,Dbia.GCF_018148935,0,1,0,0.906055,2.893880e-02,0.065006
chrUn_025319172,0,2500,plus,Dbia.GCF_018148935,0,0,0,0.993884,2.513108e-04,0.005865
chrUn_025319172,2500,5000,plus,Dbia.GCF_018148935,0,0,0,0.993874,2.479901e-04,0.005878
...,...,...,...,...,...,...,...,...,...,...
contig_276,107500,110000,minus,Dtris.d101g,0,0,0,0.999995,2.400912e-07,0.000004
contig_276,110000,112500,minus,Dtris.d101g,0,0,0,0.999993,5.502535e-07,0.000006
contig_276,112500,115000,minus,Dtris.d101g,0,0,0,0.999990,1.322094e-06,0.000009
contig_276,115000,117500,minus,Dtris.d101g,0,0,0,0.999984,3.242824e-06,0.000013
