In [1]:
from csv import DictReader
from random import Random
from collections import Counter
import numpy as np
from matplotlib import pyplot as plt
import pandas as pd
from tqdm import tqdm
from keras.utils.np_utils import to_categorical
import pickle as pkl
import warnings
warnings.filterwarnings('ignore')
import os

In [2]:
# Shared random source with a fixed seed. train_val_split and the batch
# providers call RANDOM.shuffle on it, so every shuffle draws from this one stream.
RANDOM = Random(42)

## Load Data

In [4]:
def train_val_split(samples, split=0.2):
    """Shuffle `samples` in place and split them into train and validation parts.

    Args:
        samples: list of samples. It is shuffled IN PLACE with the module-level
            RANDOM, so the caller's list order changes.
        split: fraction of the samples that goes to the validation part.

    Returns:
        (train, val): two new lists. `val` holds the last
        int(len(samples) * split) shuffled samples and `train` the rest.
    """
    RANDOM.shuffle(samples)
    # Clamp so that a validation size of 0 yields (everything, []). The
    # previous `samples[:-0]` / `samples[-0:]` slices returned ([], everything).
    n_val = min(max(int(len(samples) * split), 0), len(samples))
    cut = len(samples) - n_val
    return samples[:cut], samples[cut:]

In [5]:
import pandas as pd
# NOTE(review): both paths are absolute and machine-specific; consider a
# configurable data directory so the notebook runs elsewhere.
# Sequence table; the merge in the next cell joins on its 'Protein name' column.
dat_prot = pd.read_table('/Users/suhancho/data/Uniprot_metalbinding_challenge/sequence_df.tsv')
# Binding-site table; the merge in the next cell joins on its 'Accession' column.
chebi = pd.read_table('/Users/suhancho/data/Uniprot_metalbinding_challenge/POS_TRAIN_FULL.tsv')

In [6]:
# Join every sequence record with its binding-site annotations (outer join, so
# records without a match on either side are kept).
dat_prot_bindsite = pd.merge(dat_prot,chebi,left_on = 'Protein name',right_on='Accession',how = 'outer')

# Negative sequences get a synthetic "site" at the middle of the sequence and
# the placeholder label 'NB'.
# .copy() makes this an independent frame, so the two column assignments below
# are not chained assignments on a selection of dat_prot_bindsite.
neg_seqs = dat_prot_bindsite[dat_prot_bindsite.Type=='neg_sequence'].copy()
neg_seqs['Position'] = [len(p) // 2 for p in neg_seqs['Protein sequence'].tolist()]
neg_seqs['ChEBI-ID'] = 'NB'

# Positive rows followed by the relabelled negative rows, re-indexed from 0.
dat_prot_bindsite = pd.concat([dat_prot_bindsite[dat_prot_bindsite['Type']=='pos_sequence'],neg_seqs]).reset_index(drop = True)

# Integer class id per ChEBI-ID: the code of the categorical.
dat_prot_bindsite['ChEBI-ID'] = pd.Categorical(dat_prot_bindsite['ChEBI-ID'])
dat_prot_bindsite['target'] = dat_prot_bindsite['ChEBI-ID'].cat.codes
dat_prot_bindsite['Position'] = dat_prot_bindsite['Position'].astype(int)

# Row shuffle with a fixed seed.
dat_prot_bindsite_sampled = dat_prot_bindsite.sample(frac=1,random_state=9510)
# NOTE(review): a later cell reads data_before_windowing.tsv from disk, but the
# export that writes it is disabled here. On a fresh machine that file will not
# exist; re-enable the line below if it has to be regenerated.
# dat_prot_bindsite_sampled.to_csv('/Users/suhancho/data/Uniprot_metalbinding_challenge/data_before_windowing.tsv',sep='\t')

In [7]:
def _padded_window(protseq, bindsite, winsize):
    """Return protseq[bindsite-winsize : bindsite+winsize] with 'X' for every index outside the sequence."""
    seq_len = len(protseq)
    return ''.join(
        protseq[i] if 0 <= i < seq_len else 'X'
        for i in range(bindsite - winsize, bindsite + winsize)
    )


def get_windowdf(protfile,winsize=4):
    """Cut a fixed-length sequence window around the site of every row in a TSV file.

    Each data line is split on tabs; the third field is used as the protein
    sequence and the second-to-last field as the integer site position. Lines
    containing the text 'Protein sequence' (the header) and blank lines are
    skipped.

    The window covers indices [bindsite - winsize, bindsite + winsize) of the
    sequence. Indices that fall before the start or after the end of the
    sequence are filled with 'X', so every window has exactly 2 * winsize
    characters. Previously a negative start index wrapped around to the end of
    the sequence (dropping the residues), and windows touching a boundary
    exactly, or longer than the sequence, were returned unpadded.

    Args:
        protfile: path of the tab-separated file to read.
        winsize: number of positions taken on each side of the site index.

    Returns:
        list of window strings, one per data line, in file order.
    """
    window=[]
    with open(protfile,'r') as p:
        for line in tqdm(p):
            if 'Protein sequence' in line or not line.strip():
                continue
            fields = line.split('\t')
            protseq = fields[2].strip()
            bindsite = int(fields[-2].strip())
            window.append(_padded_window(protseq, bindsite, winsize))
    return window

In [8]:
# Windows are read back from the exported TSV, one per data row in file order.
# NOTE(review): this relies on the file rows being in the same order as
# dat_prot_bindsite_sampled, since the list is assigned positionally below.
test = get_windowdf('/Users/suhancho/data/Uniprot_metalbinding_challenge/data_before_windowing.tsv',winsize = 4)
dat_prot_bindsite_sampled['window_4'] = test
# The (shuffled) original row index becomes a column, renamed to 'qid' and
# prefixed with 'HASH_'; the window column is renamed to 'question_text'.
# NOTE(review): this cell rebinds and mutates dat_prot_bindsite_sampled, so it is
# not idempotent; re-run the preceding cells before running it again.
dat_prot_bindsite_sampled = dat_prot_bindsite_sampled.reset_index()
dat_prot_bindsite_sampled.rename(columns = {'index':'qid','window_4':'question_text'},inplace = True)
dat_prot_bindsite_sampled['qid'] = 'HASH_'+dat_prot_bindsite_sampled['qid'].astype(str)
# Replace the integer class id with its one-hot row.
one_hot_labels = to_categorical(dat_prot_bindsite_sampled['target'].tolist())
dat_prot_bindsite_sampled['target'] = list(one_hot_labels)
# One dict per row with keys 'qid', 'question_text', 'target'.
samples=dat_prot_bindsite_sampled[['qid','question_text','target']].to_dict('records')
# Default split fraction of train_val_split (0.2) is used; `samples` is shuffled in place.
train_samples, val_samples = train_val_split(samples)

603123it [00:01, 548661.26it/s]


## Preprocessing

In [9]:
def build_vocabulary(samples, vocab_min_freq=100):
    """Map every sufficiently frequent character to an integer id.

    Args:
        samples: iterable of dicts; the characters of each 'question_text'
            value are counted.
        vocab_min_freq: a character is kept only if it occurs at least this
            many times over all samples.

    Returns:
        dict from character to id, where ids are 0..k-1 assigned in sorted
        character order.
    """
    frequency = Counter()
    for sample in samples:
        for ch in sample['question_text']:
            frequency[ch] += 1

    kept = [ch for ch in frequency if frequency[ch] >= vocab_min_freq]
    kept.sort()
    return dict(zip(kept, range(len(kept))))

In [10]:
# The vocabulary is built from the training samples only, with the default
# minimum frequency of build_vocabulary.
vocabulary = build_vocabulary(train_samples)
# Inspect with len(vocabulary) / vocabulary when needed.

In [11]:
def transform(sample, vocabulary):
    """Add an 'encoded_text' list of vocabulary ids to `sample` and return it.

    Characters of sample['question_text'] that are not keys of `vocabulary`
    are skipped, so the encoded list can be shorter than the text. The sample
    dict is modified in place and the same object is returned.
    """
    encoded = []
    for ch in sample['question_text']:
        if ch in vocabulary:
            encoded.append(vocabulary[ch])
    sample['encoded_text'] = encoded
    return sample

In [12]:
# transform() adds 'encoded_text' to each dict in place and returns the same
# dict, so these lists hold the same sample objects as before, now encoded.
train_samples = [transform(sample, vocabulary) for sample in train_samples]
val_samples = [transform(sample, vocabulary) for sample in val_samples]

## Modeling

In [13]:
import keras
from keras import backend as K
from keras import initializers, regularizers, constraints
# from keras.engine import Layer
# from tensorflow.keras.layers import Layer
def dot_product(x, kernel):
    """
    Backend-agnostic dot product of `x` with `kernel`.

    On the TensorFlow backend the kernel is given a trailing axis before
    K.dot and that axis is squeezed out of the result; on any other backend
    K.dot is applied directly.
    Args:
        x (): input
        kernel (): weights
    Returns:
        The result of the backend dot product.
    """
    if K.backend() != 'tensorflow':
        return K.dot(x, kernel)

    product = K.dot(x, K.expand_dims(kernel))
    return K.squeeze(product, axis=-1)
# from keras.layers import InputLayer, Input

class AttentionWithContext(keras.layers.Layer):
    """
    Attention operation, with a context/query vector, for temporal data.
    Supports Masking.
    Follows the work of Yang et al. [https://www.cs.cmu.edu/~diyiy/docs/naacl16.pdf]
    "Hierarchical Attention Networks for Document Classification"
    by using a context vector to assist the attention
    # Input shape
        3D tensor with shape: `(samples, steps, features)`.
    # Output shape
        2D tensor with shape: `(samples, features)`.

    How to use:
    Just put it on top of an RNN Layer (GRU/LSTM/SimpleRNN) with return_sequences=True.
    The dimensions are inferred based on the output shape of the RNN.

    Note: The layer has been tested with Keras 2.0.6

    Example:
        model.add(LSTM(64, return_sequences=True))
        model.add(AttentionWithContext())
        # next add a Dense layer (for classification/regression) or whatever...

    NOTE(review): the class defines no get_config(), so the regularizer,
    constraint and bias arguments are not part of a serialized config; confirm
    this is acceptable for the way the model is saved and reloaded.
    """

    def __init__(self,
                 W_regularizer=None, u_regularizer=None, b_regularizer=None,
                 W_constraint=None, u_constraint=None, b_constraint=None,
                 bias=True, **kwargs):
        """Store the regularizers, constraints and bias flag for the three weights.

        Args:
            W_regularizer, u_regularizer, b_regularizer: passed through
                `regularizers.get` and applied to W, u and b respectively.
            W_constraint, u_constraint, b_constraint: passed through
                `constraints.get` and applied to W, u and b respectively.
            bias: when true, a bias vector b is created and added in `call`.
            **kwargs: forwarded to the base layer constructor.
        """

        self.supports_masking = True
        # W and u are both initialised with Glorot-uniform.
        self.init = initializers.get('glorot_uniform')

        self.W_regularizer = regularizers.get(W_regularizer)
        self.u_regularizer = regularizers.get(u_regularizer)
        self.b_regularizer = regularizers.get(b_regularizer)

        self.W_constraint = constraints.get(W_constraint)
        self.u_constraint = constraints.get(u_constraint)
        self.b_constraint = constraints.get(b_constraint)

        self.bias = bias
        super(AttentionWithContext, self).__init__(**kwargs)

    def build(self, input_shape):
        """Create the weights; all sizes come from the last input dimension.

        W is (features, features); b (only when bias is enabled) and the
        context vector u are both (features,).
        """
        assert len(input_shape) == 3

        self.W = self.add_weight(shape = (input_shape[-1], input_shape[-1],),
                                 initializer=self.init,
                                 name='{}_W'.format(self.name),
                                 regularizer=self.W_regularizer,
                                 constraint=self.W_constraint)
        if self.bias:
            self.b = self.add_weight(shape = (input_shape[-1],),
                                     initializer='zero',
                                     name='{}_b'.format(self.name),
                                     regularizer=self.b_regularizer,
                                     constraint=self.b_constraint)

        self.u = self.add_weight(shape = (input_shape[-1],),
                                 initializer=self.init,
                                 name='{}_u'.format(self.name),
                                 regularizer=self.u_regularizer,
                                 constraint=self.u_constraint)

        super(AttentionWithContext, self).build(input_shape)

    def compute_mask(self, input, input_mask=None):
        """Always return None."""
        # do not pass the mask to the next layers
        return None

    def call(self, x, mask=None):
        """Return the attention-weighted sum of `x` over axis 1.

        Steps: uit = tanh(x . W [+ b]); ait = uit . u; a = exp(ait), multiplied
        by the mask when one is given, then divided by its sum over axis 1
        (plus epsilon); the output is sum(x * a) over axis 1.
        """
        uit = dot_product(x, self.W)

        if self.bias:
            uit += self.b

        uit = K.tanh(uit)
        ait = dot_product(uit, self.u)

        # Unnormalised attention scores.
        a = K.exp(ait)

        # apply mask after the exp. will be re-normalized next
        if mask is not None:
            # Cast the mask to floatX to avoid float64 upcasting in theano
            a *= K.cast(mask, K.floatx())

        # in some cases especially in the early stages of training the sum may be almost zero
        # and this results in NaN's. A workaround is to add a very small positive number ε to the sum.
        # a /= K.cast(K.sum(a, axis=1, keepdims=True), K.floatx())
        a /= K.cast(K.sum(a, axis=1, keepdims=True) + K.epsilon(), K.floatx())

        # Give the weights a trailing axis so they broadcast against x.
        a = K.expand_dims(a)
        weighted_input = x * a
        return K.sum(weighted_input, axis=1)

    def compute_output_shape(self, input_shape):
        """Return (first input dimension, last input dimension)."""
        return input_shape[0], input_shape[-1]


In [14]:
from keras import models, Model
from keras.layers import Input, Embedding, Conv1D, Add, Dense, SpatialDropout1D

def _conv_block(x, filters, kernel_size):
    conv = Conv1D(filters, kernel_size, activation='relu', padding='same')(x)
    conv = Conv1D(filters, kernel_size, activation='relu', padding='same')(conv)
    return conv


def _resblock(x, filters, kernel_size):
    conv = _conv_block(x, filters, kernel_size)
    projection = Conv1D(filters, 1, padding='same')(x)
    return Add()([conv, projection])


def predict_bindsite(vocab_size,
                    char_embedding_size,
                    base_filters,
                    doc_embedding_size,
                    dropout,
                    n_classes=30,
                    output_activation='sigmoid'):
    """Build and compile the character-level classifier.

    Architecture: Embedding -> four residual conv blocks (kernel size 3, with
    base_filters * 1, 2, 4, 8 filters), each followed by SpatialDropout1D ->
    AttentionWithContext -> two ReLU Dense layers -> output Dense layer.

    Args:
        vocab_size: number of rows of the embedding table.
        char_embedding_size: embedding dimension.
        base_filters: filter count of the first residual block.
        doc_embedding_size: width of the two hidden Dense layers.
        dropout: rate of every SpatialDropout1D layer.
        n_classes: width of the output layer. It has to equal the width of the
            one-hot targets; the default keeps the previously hard-coded 30.
        output_activation: activation of the output layer. The default keeps
            the previous 'sigmoid'. NOTE(review): the model is compiled with
            categorical_crossentropy, for which 'softmax' is the usual pairing;
            confirm before changing, since it alters the prediction scores.

    Returns:
        The compiled Keras model (rmsprop, categorical_crossentropy, accuracy).
    """
    text = Input(shape=(None,))
    features = Embedding(vocab_size, char_embedding_size)(text)

    # Layers are created in the same order as before: block, dropout, block, ...
    for width_multiplier in (1, 2, 4, 8):
        features = _resblock(features, base_filters * width_multiplier, 3)
        features = SpatialDropout1D(dropout)(features)

    attention = AttentionWithContext()(features)

    fc_1 = Dense(doc_embedding_size, activation='relu')(attention)
    fc_2 = Dense(doc_embedding_size, activation='relu')(fc_1)
    prediction = Dense(n_classes, activation=output_activation)(fc_2)

    model = Model(text, prediction)
    model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])

    return model

## Training Support

In [15]:
from math import ceil

class BatchProvider:
    """Turns encoded samples into padded (texts, targets) numpy batches."""

    def __init__(self, samples, batch_size, shuffle=False, run_forever=False):
        """Keep the samples and the batching options.

        Args:
            samples: sequence of dicts with 'encoded_text' and 'target'.
            batch_size: number of samples per full batch.
            shuffle: reshuffle the visiting order before every pass.
            run_forever: keep cycling over the samples instead of stopping
                after one pass.
        """
        self._samples = samples
        self._batch_size = batch_size
        self._shuffle = shuffle
        self._run_forever = run_forever

    def generate_batches(self):
        """Yield batches; a partial batch is carried into the next pass.

        A remaining partial batch is only yielded once the loop has ended,
        i.e. never when run_forever is set.
        """
        order = list(range(len(self._samples)))
        pending = []
        keep_going = True
        while keep_going:
            if self._shuffle:
                RANDOM.shuffle(order)
            for position in order:
                pending.append(self.get_item(position))
                if len(pending) != self._batch_size:
                    continue
                yield self.transform_batch(pending)
                pending = []
            keep_going = self._run_forever
        if pending:
            yield self.transform_batch(pending)

    def __len__(self):
        """Number of batches needed for one pass, counting a partial last batch."""
        n_batches = ceil(len(self._samples) / self._batch_size)
        return int(n_batches)

    def get_item(self, idx):
        """Return the (encoded_text, target) pair of sample `idx`."""
        entry = self._samples[idx]
        return entry['encoded_text'], entry['target']

    def transform_batch(self, items):
        """Zero-pad the texts to the longest one and stack the targets.

        Returns:
            (text_batch, target_batch): a (n, max_length) array and
            np.array of the targets.
        """
        texts, targets = zip(*items)
        width = max(map(len, texts))
        padded = np.zeros((len(texts), width))
        for row, text in zip(padded, texts):
            row[:len(text)] = text
        return padded, np.array(targets)

## Inference Support

In [16]:
from math import ceil

class TestBatchProvider:
    """Turns encoded samples into padded text batches (no targets)."""

    def __init__(self, samples, batch_size, shuffle=False, run_forever=False):
        """Keep the samples and the batching options.

        Args:
            samples: sequence of dicts with an 'encoded_text' entry.
            batch_size: number of samples per full batch.
            shuffle: reshuffle the visiting order before every pass.
            run_forever: keep cycling over the samples instead of stopping
                after one pass.
        """
        self._samples = samples
        self._batch_size = batch_size
        self._shuffle = shuffle
        self._run_forever = run_forever

    def generate_batches(self):
        """Yield batches; a partial batch is carried into the next pass.

        A remaining partial batch is only yielded once the loop has ended,
        i.e. never when run_forever is set.
        """
        order = list(range(len(self._samples)))
        pending = []
        keep_going = True
        while keep_going:
            if self._shuffle:
                RANDOM.shuffle(order)
            for position in order:
                pending.append(self.get_item(position))
                if len(pending) != self._batch_size:
                    continue
                yield self.transform_batch(pending)
                pending = []
            keep_going = self._run_forever
        if pending:
            yield self.transform_batch(pending)

    def __len__(self):
        """Number of batches needed for one pass, counting a partial last batch."""
        n_batches = ceil(len(self._samples) / self._batch_size)
        return int(n_batches)

    def get_item(self, idx):
        """Return the encoded text of sample `idx`."""
        entry = self._samples[idx]
        return entry['encoded_text']

    def transform_batch(self, items):
        """Zero-pad the texts to the longest one; returns a (n, max_length) array."""
        width = max(map(len, items))
        padded = np.zeros((len(items), width))
        for row, text in zip(padded, items):
            row[:len(text)] = text
        return padded

## Training

In [17]:
# Build and compile the classifier. The embedding table has one row per
# vocabulary entry; the other values are hyper-parameters of predict_bindsite.
model = predict_bindsite(
    vocab_size=len(vocabulary),
    char_embedding_size=16,
    base_filters=32,  # filters of the first residual block; later blocks use 2x, 4x, 8x
    doc_embedding_size=300,
    dropout=0.1
)
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, None)]       0           []                               
                                                                                                  
 embedding (Embedding)          (None, None, 16)     336         ['input_1[0][0]']                
                                                                                                  
 conv1d (Conv1D)                (None, None, 32)     1568        ['embedding[0][0]']              
                                                                                                  
 conv1d_1 (Conv1D)              (None, None, 32)     3104        ['conv1d[0][0]']                 
                                                                                              

In [18]:
# Endless generators for training: run_forever=True, so they never stop on
# their own and the fit call bounds each epoch with len(provider) steps.
# Only the training provider reshuffles between passes.
train_batch_provider = BatchProvider(train_samples, batch_size=4096, shuffle=True, run_forever=True)
train_batches = train_batch_provider.generate_batches()
val_batch_provider = BatchProvider(val_samples, batch_size=4096, shuffle=False, run_forever=True)
val_batches = val_batch_provider.generate_batches()

In [20]:
from keras.callbacks import ModelCheckpoint, EarlyStopping
EPOCHS = 100
# The model is compiled with metrics=['accuracy'], so the validation metric is
# logged as 'val_accuracy'. The previous monitor 'val_acc' never appears in the
# logs, which means the checkpoint was never written.
# save_weights_only=True matches the 'weights.hdf5' file name and avoids a
# full-model HDF5 export, which would need a get_config() on the custom
# AttentionWithContext layer (none is defined).
model_checkpoint_callback = ModelCheckpoint('./weights.hdf5',monitor='val_accuracy',save_best_only=True,save_weights_only=True)
early_stopping_callback = EarlyStopping(monitor='val_loss', patience=5)

# Model.fit accepts Python generators directly; fit_generator is deprecated.
# Both generators are endless, so the step counts bound each epoch.
model.fit(
        train_batches,
        steps_per_epoch=len(train_batch_provider),
        epochs=EPOCHS,
        validation_data=val_batches,
        validation_steps=len(val_batch_provider),
        callbacks= [model_checkpoint_callback,early_stopping_callback]
    )

Epoch 1/100
Epoch 2/100
Epoch 3/100
Epoch 4/100
Epoch 5/100
Epoch 6/100
Epoch 7/100
Epoch 8/100
Epoch 9/100
Epoch 10/100
Epoch 11/100
Epoch 12/100
Epoch 13/100
Epoch 14/100
Epoch 15/100
Epoch 16/100
Epoch 17/100
Epoch 18/100
Epoch 19/100
Epoch 20/100
Epoch 21/100
Epoch 22/100
Epoch 23/100
Epoch 24/100
Epoch 25/100
Epoch 26/100
Epoch 27/100
Epoch 28/100
Epoch 29/100
Epoch 30/100
Epoch 31/100
Epoch 32/100


<keras.callbacks.History at 0x3735d9bb0>

In [21]:
# Persist the trained model, then immediately reload it from the same path and
# rebind `model` to the reloaded copy.
# NOTE(review): the path is absolute and machine-specific.
model.save('/Users/suhancho/script/uniprot/IonBind/saved_models/Ionbind_221214_9mer_epoch100')
# NOTE(review): the model contains the custom AttentionWithContext layer and
# load_model is called without custom_objects; confirm the reloaded model
# behaves like the in-memory one.
model = keras.models.load_model("/Users/suhancho/script/uniprot/IonBind/saved_models/Ionbind_221214_9mer_epoch100")



INFO:tensorflow:Assets written to: /Users/suhancho/script/uniprot/IonBind/saved_models/Ionbind_221214_9mer_epoch100/assets


INFO:tensorflow:Assets written to: /Users/suhancho/script/uniprot/IonBind/saved_models/Ionbind_221214_9mer_epoch100/assets


## Calibration

In [22]:
from sklearn.metrics import classification_report

# Single ordered pass over the validation samples (no shuffle, no repeat), so
# the predictions line up with val_samples row by row.
val_batch_provider = BatchProvider(val_samples, batch_size=2056, shuffle=False, run_forever=False)
# Model.predict accepts generators directly; predict_generator is deprecated.
val_predictions = model.predict(val_batch_provider.generate_batches(), steps=len(val_batch_provider))
val_targets = np.array([sample['target'] for sample in val_samples])

# Class index = position of the largest score / of the 1 in the one-hot target.
val_prediction_argmax = np.argmax(val_predictions, axis=1)
val_targets_argmax = np.argmax(val_targets, axis=1)

# Per-class precision / recall / f1 / support, one row per class or average.
report = classification_report(val_targets_argmax,val_prediction_argmax,output_dict = True)
report = pd.DataFrame(report).T

## Map class indices to ChEBI IDs and build the per-class report

In [23]:
# Lookup table from class index to ChEBI-ID.
# .copy() gives an independent frame, so the column assignment below is not a
# chained assignment on a selection of dat_prot_bindsite_sampled.
target_ion_info = dat_prot_bindsite_sampled[['target','ChEBI-ID']].copy()
# 'target' holds one-hot rows at this point; argmax recovers the class index.
target_ion_info['target'] = target_ion_info['target'].apply(np.argmax).astype(int)
# One row per (class index, ChEBI-ID) pair.
target_ion_info = target_ion_info.drop_duplicates()

In [31]:
# Recompute the report as a frame with one column per class / average.
report = classification_report(val_targets_argmax,val_prediction_argmax,output_dict = True)
report = pd.DataFrame(report)
# Keep only the per-class columns, then transpose to one row per class.
report_class = report.drop(['accuracy','macro avg','weighted avg'],axis=1).T
report_class = report_class.reset_index()
# The class labels come back as strings from the report dict; cast them to int
# so they can be joined on target_ion_info['target'].
report_class['index'] = report_class['index'].astype(int)
report_class = pd.merge(report_class,target_ion_info,left_on = 'index',right_on = 'target').drop('index',axis=1)
# Final column order of the exported table.
report_class = report_class[['target','ChEBI-ID','precision','recall','f1-score','support']]

In [33]:
# Write the per-class report as TSV without the index.
# NOTE(review): the file name says 'epoch1000' while EPOCHS is 100, and the
# directory is spelled 'Ionbind' here but 'IonBind' in the model path; confirm both.
report_class.to_csv("/Users/suhancho/script/uniprot/Ionbind/results/result.epoch1000.tsv",sep='\t',index = None)

## Submission

In [None]:
from keras.utils import GeneratorEnqueuer
import csv

# Inference

In [None]:
# Residues taken on each side of the centre residue: windows are 2 * 4 + 1 = 9 long.
# NOTE(review): get_windowdf(winsize=4) slices [site-4, site+4), i.e. 8
# characters, for training; confirm which window definition the model should see.
INFERENCE_HALF_WINDOW = 4

# One entry per protein: [protein id, predicted class per residue, score per residue].
inference_result=[]
with open('/Users/suhancho/data/Uniprot_metalbinding_challenge/test_sequence_df.tsv','r') as testfile:
    for t in tqdm(testfile):
        # Skip the header row (contains 'Protein') and blank lines.
        if t.count('Protein') or not t.strip():
            continue

        fields = t.split('\t')
        testseq_tmp = fields[1].strip()
        # NOTE(review): the id is read from the third column while the sequence
        # comes from the second; confirm against the file layout.
        protein_id = str(fields[2].strip())
        seq_len = len(testseq_tmp)

        # One window per residue, centred on it and padded with 'X' at both ends.
        # Fixes: the C-terminal check used len() of the id string instead of the
        # sequence, and the rear padding was one 'X' too long.
        test_input_lst=[]
        for i in range(seq_len):
            lo = i-INFERENCE_HALF_WINDOW
            hi = i+INFERENCE_HALF_WINDOW+1
            windowed_tmp = 'X'*max(0,-lo) + testseq_tmp[max(0,lo):hi] + 'X'*max(0,hi-seq_len)
            test_input_lst.append({'qid': 'HASH_'+protein_id+'.'+str(i),
                                   'question_text': windowed_tmp})

        if not test_input_lst:
            continue  # empty sequence: nothing to predict

        # transform() adds 'encoded_text' to each dict in place.
        test_input_lst = [transform(sample, vocabulary) for sample in test_input_lst]
        # batch_size == number of residues, so the whole protein is one batch.
        test_batch_provider = TestBatchProvider(test_input_lst, batch_size=len(test_input_lst), shuffle=False, run_forever=False)
        # The generator is consumed directly; the GeneratorEnqueuer that wrapped
        # it before was started for every protein and never stopped.
        for batch in test_batch_provider.generate_batches():
            test_predictions = model.predict_on_batch(batch)
            test_prediction_argmax = [np.argmax(p) for p in test_predictions]
            test_prediction_proba = [prob[idx] for prob,idx in zip(test_predictions,test_prediction_argmax)]
            # The id is used as read from the file. The previous
            # .lstrip('HASH_') removed any leading H/A/S/_ characters and so
            # mangled ids that start with one of them.
            inference_result.append([protein_id,test_prediction_argmax,test_prediction_proba])

# inference_result = pd.DataFrame(inference_result,columns =['Protein_ID','Ion','Binding Score'])
# inference_result['protein length'] = [len(inference_result.Ion.tolist()[i]) for i in range(len(inference_result))]

In [None]:
import gc
import tensorflow as tf

In [None]:
def inference_fasta(infile):
    """Predict a class and score for every residue of one sequence file.

    The first line of `infile` is taken as the sequence. A 9-character window
    (4 residues on each side, padded with 'X' beyond the ends) is built around
    every residue, encoded with the module-level `vocabulary`, and scored by
    the saved model. The result is written to
    '<outpath><name>.result.txt', where <name> is the file name up to its
    first '.', as a header line plus one line holding the name, the
    comma-joined class indices and the comma-joined scores (rounded to 4
    decimals).

    Fixes relative to the previous version: the C-terminal check compared
    against the length of the id string instead of the sequence, the rear
    padding was one 'X' too long, the input file was only closed inside the
    batch loop, and an empty file or sequence raised an error before any
    output was written.

    Args:
        infile: path of the sequence file; '/' is assumed as path separator.

    Returns:
        None. The result file is the only output.
    """
    half_window = 4
    # NOTE(review): the model is reloaded on every call and from a different
    # path than the one used by model.save in this notebook; confirm both.
    model = keras.models.load_model("Saved_Ionbind_NLP_221210")
    outpath='/Users/suhancho/data/Uniprot_metalbinding_challenge/inference_result/'
    testprot_og = infile.split('/')[-1].split('.')[0]

    with open(infile,'r') as testseq_tmp_file:
        testseq_tmp = testseq_tmp_file.readline().strip()
    seq_len = len(testseq_tmp)

    # One window per residue, centred on it and padded with 'X' at both ends.
    test_input_lst=[]
    for i in range(seq_len):
        lo = i-half_window
        hi = i+half_window+1
        windowed_tmp = 'X'*max(0,-lo) + testseq_tmp[max(0,lo):hi] + 'X'*max(0,hi-seq_len)
        test_input_lst.append({'qid': 'HASH_'+str(testprot_og)+'.'+str(i),
                               'question_text': windowed_tmp})

    # transform() adds 'encoded_text' to each dict in place.
    test_input_lst = [transform(sample, vocabulary) for sample in test_input_lst]

    test_prediction_argmax = []
    test_prediction_proba = []
    if test_input_lst:
        # batch_size == number of residues, so the whole sequence is one batch.
        test_batch_provider = TestBatchProvider(test_input_lst, batch_size=len(test_input_lst), shuffle=False, run_forever=False)
        # Consumed directly; no GeneratorEnqueuer worker is left running.
        for batch in test_batch_provider.generate_batches():
            test_predictions = model.predict_on_batch(batch)
            best_classes = [np.argmax(p) for p in test_predictions]
            test_prediction_proba.extend(str(round(prob[idx],4)) for prob,idx in zip(test_predictions,best_classes))
            test_prediction_argmax.extend(str(idx) for idx in best_classes)

    # Release the Keras/TensorFlow state held by this call.
    tf.keras.backend.clear_session()
    gc.collect()

    with open(outpath+testprot_og+'.result.txt','w') as outfile:
        outfile.write('Protein_Name\tIon\tPrediction_Score\n')
        outfile.write('\t'.join([testprot_og,','.join(test_prediction_argmax),','.join(test_prediction_proba)]))


In [None]:
# Every entry of the directory becomes an input path (no filtering by name or type).
# NOTE(review): absolute, machine-specific path.
inpath = '/Users/suhancho/data/Uniprot_metalbinding_challenge/neg_sequence/'
infiles = [inpath+f for f in os.listdir(inpath)]


In [None]:
# Sequential run: one result file is written per input file.
# NOTE(review): a later cell runs the same files again through joblib; running
# both repeats the work.
for i in infiles:
    inference_fasta(i)

In [None]:
from joblib import Parallel, delayed
# Same file list as the earlier cell, rebuilt here so this cell can run on its own.
inpath = '/Users/suhancho/data/Uniprot_metalbinding_challenge/neg_sequence/'
infiles = [inpath+f for f in os.listdir(inpath)]

# Parallel run over the same files with 8 jobs; inference_fasta returns nothing,
# its output is the result file written for each input.
Parallel(n_jobs=8)(delayed(inference_fasta)(i) for i in infiles)