In [2]:
import tensorflow as tf
from tensorflow.contrib import rnn
import os
import numpy as np
from keras.utils import np_utils
from keras.preprocessing import sequence

Using TensorFlow backend.


# Data preparation

Load train and test data

In [3]:
def read_file(path):
    with open(path) as f:
        a = f.readlines()
    return a

In [4]:
def load_data(dir_path):
    data = []
    
    for f in set(list(map(lambda x: ".".join(x.split('.')[:-1]), os.listdir(dir_path)))):
        dssp = read_file(os.path.join(dir_path, f+".dssp"))[1].strip("\n")
        fasta = read_file(os.path.join(dir_path, f+".fasta"))[1].strip("\n")
        data.append([fasta, dssp])
        
    return np.array(data)

In [5]:
train = load_data("train/")

test = load_data("test/")

In [6]:
train[2]

array([ 'VSKDLDYISTANHDQPPRHLGSRFSAEGEFLPEPGNTVVCHLVEGSQTESAIVSTRQRFLDMPEASQLAFTPVSSLHMTVFQGVIESRRALPYWPQTLPLDTPIDAVTDYYRDRLSTFPTLPAFNMRVTGLRPVGMVMKGATAEDDSIVALWRDTFADFFGYRHPDHDTYEFHITLSYIVSWFEPECLPRWQAMLDEELEKLRVAAPVIQMRPPAFCEFKDMNHFKELVVFD',
       '--------E----------E---E-----E----EEEEEEEE----HHHHHHHHHHHHHH-------EEE------EEEEEEEEE------------------HHHHHHHHHHH-----------EEEEEEE--EEEEEE--HHHHHHHHHHHHHHHHHH-----------EEEE-EEE---E----HHHHHHHHHHHHHHHHHH-----E---EEEEE------EEEEE--'],
      dtype='<U759')

Tokenizing

In [7]:
#acid_table = list(np.unique(list("".join(train[:,0]))))

#class_table = list(np.unique(list("".join(train[:,1]))))

acid_table = list(np.load('acid.npz.npy'))
class_table = list(np.load('class.npz.npy'))

In [8]:
np.save('acid.npz', np.array(acid_table))
np.save('class.npz', np.array(class_table))

In [9]:
def translate_data(data, acid_table, class_table):
    
    acid_code = lambda x: np.array([acid_table.index(y) for y in x])
    class_code = lambda x: np.array([class_table.index(y) + 1 for y in x])
    
    return np.array(list(map(lambda x: [acid_code(x[0]), class_code(x[1])], data)))

In [10]:
train = translate_data(train, acid_table, class_table)

test = translate_data(test, acid_table, class_table)

In [11]:
train[0]

array([ array([ 3, 11,  9,  3,  2, 16,  4,  9,  8, 10, 17,  8, 17,  9, 14,  0,  3,
       14,  3,  4,  9,  3,  3,  3,  0, 13,  8,  9, 20,  8,  3, 17,  8,  8,
        5, 11,  1,  9,  2, 17,  8,  8,  9,  8,  3,  8, 12,  9,  0,  9, 13,
       14, 14, 17,  7, 14,  8,  4,  7,  5,  3,  8,  2, 20,  3,  8, 17,  3,
        9, 17, 14, 15,  9,  9,  3,  8,  5,  5,  3, 17, 11,  9,  5,  8,  5,
        8, 17,  9,  8, 14,  8,  3, 14, 18,  9]),
       array([1, 1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
       3, 3, 3, 3, 3, 3, 3, 3, 3, 2, 2, 1, 1, 2, 2, 1, 3, 3, 3, 3, 1, 1, 1,
       1, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 1, 1, 1, 1, 3, 3, 3, 3, 3, 3,
       3, 3, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2, 2, 2, 1, 1, 2,
       2, 2, 1])], dtype=object)

Data utils

In [341]:

def split_seq(seq, maxlen, stride):
    seq_len = seq.shape[0]
    
    temp = []
    
    if seq_len < maxlen:
        temp.append(seq)
        
    else:
        full_batches = seq_len // maxlen

        for i in range(full_batches):
            idx = slice(i * maxlen, (i+1) * maxlen)
            temp.append(seq[idx])
        
        elements_left = seq_len % maxlen
        
        if elements_left:
            start_element = elements_left + stride
            
            if start_element > maxlen:
                start_element = maxlen
                
            temp.append(seq[-start_element:])
    return temp


In [352]:
test[0][0]

array([12,  9,  7, 17, 12, 20, 11,  9, 12,  9, 12,  5,  5, 17, 17, 12, 14,
       10,  9,  7, 16,  7,  9,  5, 16, 17,  8, 12, 11,  0, 11, 14,  7,  0,
        9,  2,  4, 13, 14,  5, 11,  2, 17,  0,  4,  6,  4, 11, 12, 14,  4,
       11,  3, 11, 11, 14, 14, 17,  7, 17,  1, 11, 16,  8,  9,  2, 11, 11,
       18,  5, 14,  3,  3, 14, 13, 15, 17,  4, 12,  4,  3, 15,  5,  8, 12,
        4,  8,  7, 13, 17,  9, 17,  3, 12,  2,  6,  4,  8, 17,  0, 17, 11,
        2,  0,  6,  9,  9, 13, 20, 11,  6, 14, 17,  8,  8,  9, 11,  3,  7,
       15,  8,  9,  5,  7, 15,  5,  2,  7,  2,  9, 16, 15,  0, 15, 20, 16,
       10,  7])

In [361]:

def split_pair_seq(pair_seq, maxlen, stride):
        
    assert pair_seq[0].shape == pair_seq[1].shape

    return list(map(list, zip(split_seq(pair_seq[0], maxlen, stride), split_seq(pair_seq[1], maxlen, stride))))


In [410]:

def build_prediction(part_predictions, original_length, stride):
    max_len = part_predictions.shape[1]
    
    final_prediction = []
    
    if len(part_predictions) != 1:
        final_prediction = np.concatenate(part_predictions[:-1], axis=0)

    el_left = original_length % max_len
    start_element = stride
    
    if el_left + stride > max_len:
        start_element = max_len - el_left

    final_part_prediction = part_predictions[-1][start_element:start_element + el_left]

    final_prediction = np.concatenate([final_prediction, final_part_prediction], axis=0)
    
    return final_prediction


In [457]:

def predict(x, session, method, maxlen=400, stride=100):
    prepared = []; idx = []; original_len = []; model_len = []
    
    curr_idx = 0
    
    for seq in x:
        original_len.append(len(seq))
        
        splited = split_seq(seq, maxlen, stride)
        
        model_len.extend(list(map(len, splited)))
        splited = sequence.pad_sequences(splited, maxlen, value=-1, padding="post")
        
        prepared.extend(splited)
        idx.append([curr_idx, curr_idx+len(splited)])
        curr_idx += len(splited)
    
    predictions = session.run(method, feed_dict={
        x_raw: np.array(prepared),
        seq_len: np.array(model_len)
    })
    
    final_predictions = []
    for i, seq_idx in enumerate(idx):
        seq_predictions = predictions[slice(seq_idx[0], seq_idx[1])]

        final_predictions.append(build_prediction(seq_predictions, original_len[i], stride))
        
    return np.array(list(map(lambda x:make_pred(x, len(x), class_table), final_predictions)))


In [None]:
def score(x, y, session, method, maxlen=400, stride=100):
    

In [397]:
s = tf.InteractiveSession(); s.run(tf.global_variables_initializer())

In [458]:
c = predict(test[:10, 0], s, pred)

In [446]:

def to_normal(seq, seq_len, table):
    #print(seq)
    return "".join(list(map(lambda x: table[int(x)], seq)))[:seq_len]


In [14]:

def make_pred(pred_raw, seq_len, table):
    pred = pred_raw - 1
    pred = np.array(list(map(lambda x: 0 if x == -1 else x, pred)))
    return to_normal(pred, seq_len, table)


In [267]:

def dynamic_iter(data, batchsize, maxlen=300, stride=100, shuffle=True):
    
    splited = list(map(lambda x: split_pair_seq(x, maxlen, stride), data))
    
    prep_data = []
    
    for seq in splited:
        prep_data.extend(seq)
    
    prep_data = np.array(prep_data)
    
    # Batching
    index = list(range(len(prep_data)))
    
    if shuffle:
        np.random.shuffle(index)
    for i in range(0, len(prep_data) - batchsize + 1, batchsize):
        
        x = prep_data[index[i:i+batchsize], 0]
        y = prep_data[index[i:i+batchsize], 1]

        seq_len = np.array(list(map(lambda z: 400 if len(z) > 400 else len(z), x)), dtype="int32")
        
        x = sequence.pad_sequences(x, padding="post", value=-1, maxlen=maxlen)
        y = sequence.pad_sequences(y, padding="post", maxlen=maxlen)
        
        yield x, y, seq_len
        

# Model

In [15]:
input_class = len(acid_table)
output_class = len(class_table) + 1 # plus one for padding

learning_rate = 0.01
seq_max_len = 400
n_units = 64

In [16]:
x_raw = tf.placeholder(tf.int32, [None, seq_max_len], name="x_raw")
y_raw = tf.placeholder(tf.int32, [None, seq_max_len], name="y_raw")
seq_len = tf.placeholder(tf.int32, [None], name="seq_len")

x = tf.one_hot(x_raw, input_class, dtype=tf.float32, name="one_hot")

In [269]:
for a, b, c in dynamic_iter(test, 10):
    print(a[0])

[ 6  6  6  6 10 17 10  3 20  3  9 14 16 12  9 17  8  2 13  7  9  8  9  8 17
  5  2 17 17 20  7 16  5  3  7  4 16  0 14  2  3  0  6  0 14  0  9  3 18 10
  3  3  5  8  3  9 12  4 15  4  2  8  5 17 17 20  6  1  5 12  9 17  8  8 11
  2  3 18 14 17 17 15  0  5 12 16 16 15  0 14 10 11 12  4 16 12  8  7  9  3
  8 17  3  1 10  5  7  7  5  8  5  5 10 15  3  3 17 17  3  0 10 14  5  8  0
  0 20  4  0  4 16  5  5  0  5  0  9  0  0 10 15  7  8  8 17  8  5 17 17 18
  3  2  9  5 10 12  3  0 17 18  9  9  3 17  3 14  4  5 12  1  7 17  0  7  2
  0  6  5 11 15  9 20 14 14 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1
 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1 -1]
[16  3  3 13  8 14 10  9  9  9 12  0  3  0 17 11 17 15 14  8  5  3  4 16  9
  8 17  5  

In [17]:
def lstm_cell(reuse=tf.get_variable_scope().reuse):
    cell = tf.contrib.rnn.NASCell(n_units, reuse=reuse)
    return tf.contrib.rnn.DropoutWrapper(cell, output_keep_prob=0.9)

In [18]:
fw_cell = lstm_cell(); bw_cell = lstm_cell()

output, state = tf.nn.bidirectional_dynamic_rnn(fw_cell, bw_cell, x, seq_len, dtype=tf.float32)

In [19]:
con_out = tf.concat(output, 2)

In [20]:
flatten = tf.reshape(con_out, [-1, 2 * n_units * seq_max_len])

In [21]:
weights = {
    "out" : tf.Variable(tf.truncated_normal(shape=[2 * n_units * seq_max_len, output_class * seq_max_len], stddev=0.1))
}
bias = {
    "out" : tf.Variable(tf.truncated_normal(shape=[output_class * seq_max_len]))
}

In [22]:
flat_logits = tf.matmul(flatten, weights["out"]) + bias["out"]

In [23]:
logits = tf.reshape(flat_logits, shape=[-1, seq_max_len, output_class])

In [24]:
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y_raw, logits=logits)

In [25]:
mask = tf.to_float(tf.not_equal(y_raw, 0))

In [26]:
norm_loss = loss * mask

In [27]:
mean_loss = tf.reduce_mean(tf.reduce_sum(norm_loss) / tf.to_float(seq_len))

In [56]:
pred = tf.arg_max(tf.nn.softmax(logits), 2)

In [29]:
train_step = tf.train.AdamOptimizer(learning_rate).minimize(mean_loss)

In [75]:
num_epochs = 10
batch_size = 50
disp_batch = 10

In [77]:
with tf.Session() as sess:
    
    sess.run(tf.global_variables_initializer())
    
    
    for e in range(num_epochs):
        avg_loss = 0.; num_iter = 0
        for x_batch, y_batch, len_batch in dynamic_iter(train, batch_size, maxlen=seq_max_len):
            
            _, c = sess.run([train_step, mean_loss], feed_dict={
                x_raw: x_batch,
                y_raw: y_batch,
                seq_len: len_batch
            })
            
            num_iter += 1; avg_loss += c
            
            if num_iter % disp_batch == 0:
                
                prediction = sess.run(pred, feed_dict={
            
                x_raw: x_batch[:1], y_raw: y_batch[:1], seq_len: len_batch[:1] 
                    
                })
                
                print("Batch {}".format(num_iter))
                print("Example: ")

                print("Real:      " + make_pred(y_batch[0], len_batch[0], class_table))
                print("Predicted: " + make_pred(prediction[0], len_batch[0], class_table))
                
                print("Loss: " + str(c))
                
        print("Epoch {} done! Average loss: ".format(str(e)) + str(avg_loss / num_iter))

Batch 10
Example: 
Real: HHHHHHHHHHHHHHHHH---HHHHHHH----HHHHHHH-----E----E----------HHHHHHHEEEEE---------EEEEE-----EEEE-HHHHHHH--------HHHHHHHHHHHHHH-E-HHHHHHHHHHHHHHHH------------HHHHHHHHHHHH---HHHHHHH---HHHHHHHH-------EE--EE---------HHHHHHHHH------------HHHHH-------------HHHHHHHHHHHHH--HHHHH----HHHHHH-----HHHHHHHHHH-----HHHHHHHHHHHHHHHHHHHHHH--------------HHHH--------------HHHHHHHHHHHHHHH-E-----E-HHHHHH-----HHHHHH------
Predicted: ---HEHHHH-H-H-HH--HH----E------HHHHH---H-------H-HH-HHHH-EHHHHE---E-------H-H-H------------HH---HHHH---H--------------EE-----HHHHHHHEH------EEE-H------------HHEEEEEE---E-E---------EEHHHHH---H-HHHH--HH--H-E-----HHHHHHHHHHEEH----------H-H--E--H-HHHHHHH-------HH-HHHHHHHHHHHH---EEEHE-----------HH--H---------E-E-----EE-----HHHH-EE--EE---E---EE-------HHHHHHHHHHH-EH--------------------HHHHHHHHHHHHHHH----
Loss: 85.3606
Batch 20
Example: 
Real: ------------EEEE------EEEEEE----EEEEE-----EEEEEEE--EEE----EEEEE--EEEEEE----EEEEEEE
Predicted: -----EHH-HHH-HHHH--HHHHHH--HH

KeyboardInterrupt: 