In [None]:
import sys, os, tarfile
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
import tensorflow.keras as keras
from tiberius.write_tfrecord_species import get_species_data_hmm
from tiberius.models import lstm_model, custom_cce_f1_loss, add_hmm_layer

In [None]:
# Run-wide constants shared by the cells below.
batch_size, seq_len = 20, 9999
strand = '+'  # NOTE(review): not referenced in the visible cells

# NOTE(review): this name is not used below; the HMM layer is configured with
# config['hmm_factor'] (99) instead — confirm which value is intended.
hmm_factor = 1

# Extract the bundled test data if necessary and prepare the output directory.
inp_data_dir = 'inp/'
# os.path.join avoids the doubled separator ('inp//genome.fa') that string
# formatting produced with the trailing slash above.
genome_path = os.path.join(inp_data_dir, 'genome.fa')
annot_path = os.path.join(inp_data_dir, 'annot.gtf')

# Re-extract whenever an expected input file is missing, not only when the
# directory is absent: an interrupted extraction would otherwise leave an
# empty/partial 'inp/' behind and every later run would silently skip it.
if not (os.path.exists(genome_path) and os.path.exists(annot_path)):
    os.makedirs(inp_data_dir, exist_ok=True)
    with tarfile.open("inp.tar.gz", "r:gz") as tar:
        # The 'data' extraction filter rejects absolute paths, '..' components
        # and links escaping the destination. It is available on Python 3.12+
        # and recent security releases of older versions; otherwise fall back
        # to the legacy (unfiltered) call.
        if hasattr(tarfile, 'data_filter'):
            tar.extractall(path=inp_data_dir, filter='data')
        else:
            tar.extractall(path=inp_data_dir)

out_dir = 'test_train/'
os.makedirs(out_dir, exist_ok=True)

# Build the training arrays from the genome FASTA and its GTF annotation.
# NOTE(review): presumably x_seq holds the encoded input windows and y_seq the
# per-position labels (they are passed as x / y to model.fit below) — confirm
# against get_species_data_hmm.
x_seq, y_seq = get_species_data_hmm(
    genome_path=genome_path,
    annot_path=annot_path,
    seq_len=seq_len,
    transition=True,
)

In [None]:
# Training / architecture settings. Only a subset is forwarded to lstm_model
# (see relevant_keys); the hmm_* entries configure add_hmm_layer, and the rest
# drive the optimiser, loss and fit. See lstm_model documentation for more
# arguments.
config = dict(
    num_epochs=10,
    stride=0,
    units=100,
    filter_size=32,
    numb_lstm=2,
    numb_conv=3,
    dropout_rate=0.0,
    pool_size=9,
    lr=1e-4,
    batch_size=batch_size,
    w_size=seq_len,
    output_size=15,
    hmm_share_intron_parameters=False,
    hmm_nucleotides_at_exons=False,
    hmm_trainable_transitions=False,
    hmm_trainable_starting_distribution=False,
    hmm_trainable_emissions=False,
    hmm_factor=99,
    loss_f1_factor=2.0,
)

# Hyper-parameters forwarded to lstm_model. Names missing from config
# (currently 'kernel_size' and 'multi_loss') are simply not passed, presumably
# leaving lstm_model's own defaults in effect.
relevant_keys = [
    'units', 'filter_size', 'kernel_size',
    'numb_conv', 'numb_lstm', 'dropout_rate',
    'pool_size', 'stride',
    'output_size', 'multi_loss',
]

relevant_args = dict(
    (name, config[name]) for name in relevant_keys if name in config
)
model = lstm_model(**relevant_args)
# Wrap the LSTM model with a single HMM layer; its size, parallelisation
# factor and trainability switches are all read from config. Only the HMM
# output is kept (include_lstm_in_output=False).
model = add_hmm_layer(
    model,
    output_size=config['output_size'],
    num_hmm=1,
    hmm_factor=config['hmm_factor'],
    include_lstm_in_output=False,
    share_intron_parameters=config['hmm_share_intron_parameters'],
    trainable_emissions=config['hmm_trainable_emissions'],
    trainable_transitions=config['hmm_trainable_transitions'],
    trainable_starting_distribution=config['hmm_trainable_starting_distribution'],
    trainable_nucleotides_at_exons=config['hmm_nucleotides_at_exons'],
)
              
# Optimiser and tiberius' custom CCE/F1 loss, then compile and show the model.
adam = Adam(learning_rate=config["lr"])
# NOTE(review): from_logits=True assumes the HMM layer emits unnormalised
# scores rather than probabilities — confirm against add_hmm_layer.
f1loss = custom_cce_f1_loss(
    config["loss_f1_factor"],
    batch_size=config["batch_size"],
    from_logits=True,
)

model.compile(
    optimizer=adam,
    loss=f1loss,
    metrics=['accuracy'],
)
model.summary()


In [None]:
# Train on the in-memory dataset; batch size and epoch count come from config.
model.fit(
    x_seq,
    y_seq,
    batch_size=config["batch_size"],
    epochs=config["num_epochs"],
)

In [None]:
# Save the trained model: the native '.keras' format on TensorFlow newer than
# 2.12, the SavedModel directory format (without traces) otherwise.
# tf_version is derived here so the cell survives Restart & Run All (it was
# previously never defined), and the comparison is numeric: a plain string
# comparison treats e.g. '2.9.0' as newer than '2.12'.
tf_version = tf.__version__
tf_version_tuple = tuple(int(part) for part in tf_version.split('.')[:2])

if tf_version_tuple > (2, 12):
    model.save(os.path.join(out_dir, "test_train_hmm.keras"))
else:
    model.save(os.path.join(out_dir, "test_train_hmm"), save_traces=False)

# the trained model can be used with tiberius using the --model_old option
