In [1]:
import keras
import numpy as np
import pickle
from sequence_attention import SeqAttModel, preprocess_data, preprocess_data_pickle
from sequence_attention import DataGenerator, DataGeneratorUnlabeled, DataGeneratorPickle, DataGeneratorUnlabeledPickle
from config import Config

Using TensorFlow backend.
  _np_qint8 = np.dtype([("qint8", np.int8, 1)])
  _np_quint8 = np.dtype([("quint8", np.uint8, 1)])
  _np_qint16 = np.dtype([("qint16", np.int16, 1)])
  _np_quint16 = np.dtype([("quint16", np.uint16, 1)])
  _np_qint32 = np.dtype([("qint32", np.int32, 1)])
  np_resource = np.dtype([("resource", np.ubyte, 1)])


# Preprocess the data

There are three different ways to prepare the data; this notebook describes the second and third:

2. Convert each fna file to a pickle file holding a Python dictionary (the key is the unique sequence identifier and the value is the actual sequence). The data generator can then load a pickle file, look up the reads it needs, and construct a batch of data.
3. The fastest way is, of course, to train and test your model without using the data generator. Instead, if the computer's memory permits, users can load all the training data into memory and fit the model directly. This way, there is no additional I/O.

When designing this tool, we made no assumptions about the amount of memory at the user's disposal. This demo therefore focuses on the second preprocessing method, since it covers the most use cases. Users who do have enough memory are more than welcome to fit the model without the generator.

In [2]:
# Load the configuration
opt = Config()

# Preprocess the data
preprocess_data_pickle(opt)

02-Jun-22 15:34:32 - Processing raw data: 0.0% completed.
02-Jun-22 15:34:32 - Processing raw data: 10.0% completed.
02-Jun-22 15:34:32 - Processing raw data: 20.0% completed.
02-Jun-22 15:34:33 - Processing raw data: 30.0% completed.
02-Jun-22 15:34:33 - Processing raw data: 40.0% completed.
02-Jun-22 15:34:33 - Processing raw data: 50.0% completed.
02-Jun-22 15:34:34 - Processing raw data: 60.0% completed.
02-Jun-22 15:34:34 - Processing raw data: 70.0% completed.
02-Jun-22 15:34:35 - Processing raw data: 80.0% completed.
02-Jun-22 15:34:35 - Processing raw data: 90.0% completed.
02-Jun-22 15:34:36 - Processing raw data: 100.0% completed.
02-Jun-22 15:34:37 - Processing raw data: 110.0% completed.


# Load metadata and initialize the deep learning model

In [3]:
# Load the pickles holding the data preprocessed in the previous step.
# `with` ensures each file handle is closed as soon as its contents are loaded.
# pickle.load can execute arbitrary code, so only load files produced by the
# preprocessing step of this notebook.
with open('{}/label_dict.pkl'.format(opt.out_dir), 'rb') as label_dict_file:
    label_dict = pickle.load(label_dict_file)
with open('{}/meta_data.pkl'.format(opt.out_dir), 'rb') as meta_data_file:
    sample_to_label, read_meta_data = pickle.load(meta_data_file)
# NOTE(review): the train/test split load below is disabled, yet the data generator
# cell further down still reads `partition` — confirm where it is meant to come from.
#partition = pickle.load(open('{}/train_test_split.pkl'.format(opt.out_dir), 'rb'))

In [None]:
# Display the label dictionary loaded from label_dict.pkl.
label_dict

In [None]:
# Display the sample_to_label object loaded from meta_data.pkl.
sample_to_label

In [None]:
# Display the read_meta_data object loaded from meta_data.pkl.
read_meta_data

In [None]:
# Create the model
seq_att_model = SeqAttModel(opt)
seq_att_model.model.summary() # model summary (layers of each type, etc.)

# Train and evaluate the model
###### Prepare the data generator

In [4]:
# One-hot encoding of the sequences
# NOTE(review): `partition` is not assigned in any cell shown above (its pickle load is
# commented out) — confirm where it should come from before running on a fresh kernel.
generator = DataGeneratorPickle(partition, sample_to_label, label_dict, 
                                   dim=(opt.SEQLEN,opt.BASENUM), batch_size=opt.batch_size, shuffle=opt.shuffle)

In [7]:
# Train batch by batch if the data does not fit in memory.

# X = one-hot encoded sequences, y = the sequences' labels.
# The batch is taken from `generator`, the only data generator this notebook builds;
# `training_generator` is never assigned, so using it fails on a fresh kernel.
X, y = generator[0]  # generate one batch of data

# Shapes noted from earlier runs. All reads must have the same length: trim them to a
# common length or pad zeros at the end.

#X.shape  # 1359 batch size (training sequences in the batch), 100 sequence len, 4 bases (unique characters of your data)
          # TEST RUN  ->  12 batch size, 100 seq len, 4 bases

#y.shape  # 1359 labels for the training data in the batch
          # TEST RUN  ->  12

# ---------------------------------------------

#X.shape  # 1024 batch size (training sequences in the batch), 100 sequence len, 4 bases (unique characters of your data)
          # TEST RUN  ->  6 batch size, 100 seq len, 4 bases

#y.shape  # 1024 labels for the training data in the batch
          # TEST RUN  ->  6

In [None]:
# NO  (presumably marks this cell as one not to run — confirm)
# Train the model batch by batch and evaluate the accuracy on the training set.
# NOTE(review): `training_generator` is not assigned in any visible cell — confirm before running.
seq_att_model.train_generator(training_generator, n_workers=opt.n_workers)

21-May-22 15:49:32 - Training started:


Epoch 1/1
 3643/78927 [>.............................] - ETA: 14:36:16 - loss: 0.5809 - acc: 0.6846

In [11]:
# NO  (presumably marks this cell as one not to run — confirm)
# Evaluate the accuracy on the test set.
# NOTE(review): `testing_generator` is not assigned in any visible cell — confirm before running.
seq_att_model.evaluate_generator(testing_generator, n_workers=opt.n_workers)

In [None]:
# Fit the model to the data.
# NOTE(review): X and y hold the single batch fetched from the generator in an earlier
# cell, so this fits on that one batch only — confirm that is intended.
seq_att_model.model.fit(X, y, batch_size=opt.batch_size, epochs=opt.epochs)

###### Training without the generator (if you have enough memory)

# Model interpretation and sequence visualization
This step is exploratory and completely depends on you. To get started, please review the following data requirements:

1. Prepare the X_visual (N by SEQ_LEN by NUMBASE) in *numpy array*, see also ***Note: training without the generator (1)***
2. y_visual (phenotypic labels in integers) in *numpy array*, use `label_dict` as the label to integer map.
3. a list of taxonomic labels of those sequences (e.g., genus level labels as python strings). 

Once you have the data ready, run the following commands to plot embedding and attention weights visualization figures.

###### Note: we provide a toy example of how it works and what results it produces in the appendix section below.

In [None]:
from sequence_attention import SeqVisualUnit

# Run the input sequences X_visual through the model to obtain the predictions,
# the sequence attention weights and the sequence embeddings.
prediction, attention_weights, sequence_embedding = seq_att_model.extract_weigths(X_visual)

# Invert label_dict (label -> integer) into an integer -> label lookup.
idx_to_label = {idx: label for label, idx in label_dict.items()}

seq_visual_unit = SeqVisualUnit(
    X_visual, y_visual, idx_to_label, taxa_label_list,
    prediction, attention_weights, sequence_embedding, 'Figures',
)

# Plot the embedding and attention-weight visualization figures.
seq_visual_unit.plot_embedding()
seq_visual_unit.plot_attention('CD')

The code snippet above also needs *label_dict* (the phenotypic-label-to-integer dictionary saved in `opt.out_dir/label_dict.pkl` by the previous steps).

In [None]:
############# TESTS

