In [1]:
import os
import sys

import pandas as pd
import numpy as np
import sklearn
from hmmlearn import hmm
try: # version > 0.2.7
   from hmmlearn.hmm import CategoricalHMM as MultinomialHMM
except: # version <= 0.2.7
   from hmmlearn.hmm import MultinomialHMM
import matplotlib.pyplot as plt
from tqdm import tqdm
import seaborn as sns
from sklearn.model_selection import KFold
import pickle

sys.path.insert(0, '../scripts')
from flaHMM_functions import *

In [2]:
# Species to use for models
species_build=['Dyak.GCF_016746365',
              'Dsan.GCF_016746245',
              'Dsim.GCF_016746395',
              'Dmau.GCF_004382145',
              'Dmel.dm6',
              'Dsubp.GCF_014743375']

In [3]:
def train_combination_models(bin_size, threshold):
        species_train=[i for i in species_build]
        for threshold in [threshold]:
            threshold_num='threshold_'+str(threshold)
            
            get_emission_matrices={}
            for emission_matrix_file in os.listdir('../matrices/emissionprobs/'):
                if bin_size in emission_matrix_file:
                    if threshold_num in emission_matrix_file:
                        if emission_matrix_file.split('emissionmat_')[1].split('_threshold_')[0] in get_emission_matrices.keys():
                            continue
                        else:
                            if emission_matrix_file.split('emissionmat_')[1].split('_threshold_')[0] in species_train:
                                #print(emission_matrix_file)
                                get_emission_matrices[emission_matrix_file.split('emissionmat_')[1].split('_threshold_')[0]]=pd.read_csv('../matrices/emissionprobs/'+emission_matrix_file,sep='\t', index_col=0).values
                            else:
                                continue
            
            get_transition_matrices={}
            for transition_matrix_file in os.listdir('../matrices/transmats/'):
                if bin_size in transition_matrix_file:
                    if threshold_num in transition_matrix_file:
                        if transition_matrix_file.split('transmat_')[1].split('_threshold_')[0] in get_transition_matrices.keys():
                            continue
                        else:
                            if transition_matrix_file.split('transmat_')[1].split('_threshold_')[0] in species_train:
                                #print(transition_matrix_file)
                                get_transition_matrices[transition_matrix_file.split('transmat_')[1].split('_threshold_')[0]]=pd.read_csv('../matrices/transmats/'+transition_matrix_file,sep='\t', index_col=0).values
                            else:
                                continue
                                
            get_startprob_matrices={}
            for startprob_matrix_file in os.listdir('../matrices/startprobs/'):
                if bin_size in startprob_matrix_file:
                    if threshold_num in startprob_matrix_file:
                        if startprob_matrix_file.split('startprob_')[1].split('_threshold_')[0] in get_startprob_matrices.keys():
                            continue
                        else:
                            if startprob_matrix_file.split('startprob_')[1].split('_threshold_')[0] in species_train:
                                #print(startprob_matrix_file)
                                get_startprob_matrices[startprob_matrix_file.split('startprob_')[1].split('_threshold_')[0]]=pd.read_csv('../matrices/startprobs/'+startprob_matrix_file,sep='\t', index_col=0).values
                            else:
                                continue
            
            # Average across species
            emission_matrix=sum(get_emission_matrices.values())/len(get_emission_matrices.keys())
            transition_matrix=sum(get_transition_matrices.values())/len(get_transition_matrices.keys())
            starprob=sum(get_startprob_matrices.values())/len(get_startprob_matrices.keys())
            starprob = [item for row in starprob for item in row]
            
            #Train Model
            model=hmm.MultinomialHMM(n_components=3,n_iter=100000,random_state=13)
            model.n_features=3
            model.startprob_=starprob
            model.transmat_=transition_matrix
            model.emissionprob_=emission_matrix
            
            #Create folder if it doesn't exist    
            if not os.path.isdir('models_pkl'):
               os.mkdir('models_pkl')

            with open('models_pkl/Model_bin_'+bin_size+'k_threshold_'+str(threshold)+'.pkl', "wb") as file: pickle.dump(model, file)

## Create combination models and save them

In [4]:
for bin_size in tqdm(['10', '5', '2.5']):
    for threshold in [0.025, 0.05, 0.075, 0.1, 0.2]:
        train_combination_models(bin_size, threshold)

100%|██████████| 3/3 [00:00<00:00,  5.71it/s]


## Ensure that models can be loaded

In [5]:
with open("models_pkl/Model_bin_5k_threshold_0.075.pkl", "rb") as file: 
    model_test=pickle.load(file)

In [6]:
model_test.emissionprob_

array([[0.90565362, 0.0748475 , 0.01949888],
       [0.30424092, 0.6340643 , 0.06169479],
       [0.58739681, 0.30854445, 0.10405874]])