# Load modules

In [2]:
import numpy as np
import os
import tensorflow as tf
from tensorflow import keras
from joblib import Parallel,delayed

2021-09-19 17:15:26.951778: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0


# Functions

## Helpers - Processing

In [46]:
# sequential sliding window
# converts N by M into N by (2*flank+1)*M
def window(matin, flank):
    """Stack every row with its `flank` neighbours on either side.

    Row i of the result is rows i-flank .. i+flank of `matin` laid side by
    side; neighbours that fall outside the matrix are left as zeros.

    Args:
        matin: 2-D matrix (N by M) exposing `.shape` and row slicing.
        flank: number of neighbouring rows taken on each side of a row.

    Returns:
        The N by (2*flank+1)*M float32 matrix, passed through
        tf.convert_to_tensor.
    """
    nrow = matin.shape[0]
    ncol = matin.shape[1]
    matout = np.zeros(shape=(nrow, (2*flank+1)*ncol), dtype=np.float32)
    # Fill one neighbour offset at a time: for a given offset, destination
    # row i takes source row i+offset, for every i where that row exists.
    for slot, offset in enumerate(range(-flank, flank+1)):
        first = max(0, -offset)
        last = min(nrow, nrow-offset)
        if first < last:
            k = slot*ncol
            # Source row 0 is a valid neighbour and must be copied as well.
            matout[first:last, k:k+ncol] = matin[first+offset:last+offset]
    return tf.convert_to_tensor(matout)

# one hot encoding
# e.g. input:
# -H---HHHH--EE--
# will output:
# [[0,0,1],[0,1,0],[0,0,1],...]
def encode_dssp(s):
    """One-hot encode a string of per-residue states.

    'E' maps to column 0, 'H' to column 1 and every other character to
    column 2.

    Args:
        s: string with one character per position.

    Returns:
        np.ndarray of shape (len(s), 3) and dtype np.byte with exactly one
        1 per row.
    """
    column = {'E': 0, 'H': 1}
    # Start from zeros: an integer array cannot hold a NaN sentinel, and every
    # row receives exactly one 1 below, so no "unset" marker is needed.
    res = np.zeros(shape=(len(s), 3), dtype=np.byte)
    for i, state in enumerate(s):
        res[i, column.get(state, 2)] = 1
    return res

## Helpers - Parsing

In [4]:
def parse_seqID(seqID, data_dir='/cluster/gjb_lab/2472402/data/retr231_raw_files/training/'):
    """Load the DSSP labels and the HMM / PSSM matrices of one sequence.

    Args:
        seqID: file stem; '<seqID>.hmm', '<seqID>.pssm' and '<seqID>.dssp'
            are read from `data_dir`.
        data_dir: directory holding the three files.

    Returns:
        [dssp, hmm, pssm], each converted to a float32 tensor.

    Raises:
        FileNotFoundError: if any of the three input files is missing.
    """
    hmm_path, pssm_path, dssp_path = (
        os.path.join(data_dir, seqID + ext) for ext in ('.hmm', '.pssm', '.dssp')
    )
    # Check all three inputs up front so a missing file is reported by name
    # before any parsing starts.
    missing = [p for p in (hmm_path, pssm_path, dssp_path) if not os.path.exists(p)]
    if missing:
        raise FileNotFoundError(
            'missing input file(s) for ' + seqID + ': ' + ', '.join(missing)
        )
    hmm = np.loadtxt(hmm_path,delimiter=' ',dtype=np.float32)
    pssm = np.loadtxt(pssm_path,delimiter=' ',dtype=np.float32)
    dssp = parse_dssp(dssp_path)
    return [tf.convert_to_tensor(x,dtype=tf.float32) for x in [dssp,hmm,pssm]]

def parse_dssp(dssp_path):
    """Read a DSSP file and return its one-hot encoding.

    The whole file is read as text, surrounding whitespace is stripped and
    the remaining string is handed to encode_dssp.
    """
    with open(dssp_path,'r') as handle:
        raw_text = handle.read()
    states = raw_text.strip()
    return encode_dssp(states)

def parse_log_file(log_file_path, expected_total=1348):
    """Read the cross-validation splits listed in a shuffle log file.

    A line starting with '#SET' opens a new split; every other non-blank
    line is a path whose last component, with '.pssm' removed, is added to
    the current split as a sequence ID. Entries that appear before the first
    '#SET' line end up in the first split.

    Args:
        log_file_path: path of the log file.
        expected_total: number of sequence IDs the splits must hold in
            total; pass None to skip the check.

    Returns:
        List of sets of sequence IDs, one set per split, in file order.

    Raises:
        AssertionError: if the total differs from `expected_total`.
    """
    val_splits = []
    cur_set = set()
    seen_header = False
    with open(log_file_path,'r') as f:
        for raw_line in f:
            line = raw_line.strip()
            # Blank lines carry no sequence ID; skipping them keeps an empty
            # string out of the split.
            if not line:
                continue
            if line.startswith('#SET'):
                if seen_header:
                    val_splits.append(cur_set)
                    cur_set = set()
                seen_header = True
            else:
                seqID = line.split('/')[-1].replace('.pssm','')
                cur_set.add(seqID)
    # append last set which is not followed by another line '#SET...'
    val_splits.append(cur_set)
    if expected_total is not None:
        total = sum(len(s) for s in val_splits)
        assert total == expected_total, (
            'expected %d sequence IDs, found %d' % (expected_total, total)
        )
    return val_splits

## Jnet classifier class

In [41]:
class JnetClassifier(keras.Model):
    """Two-stage classifier made of four independently trained MLP blocks.

    Input columns 0-23 feed the HMM branch and columns 24-43 the PSSM branch.
    In each branch the first block sees a flank-8 window of its input columns
    (17*24 = 408 and 17*20 = 340 features) and the second block sees a
    flank-9 window of the first block's 3-column output (19*3 = 57 features).
    """

    def __init__(self):
        super(JnetClassifier, self).__init__()
        self.hmm1 = MLPBlock(408)
        self.hmm2 = MLPBlock(57)
        self.psi1 = MLPBlock(340)
        self.psi2 = MLPBlock(57)

    def call(self, data):
        """Return the mean of the HMM-branch and PSSM-branch predictions.

        Args:
            data: 2-D batch whose columns 0-23 and 24-43 are the two branch
                inputs.
        """
        x1 = data[:,:24]
        x1 = window(x1,flank=8)
        x1 = self.hmm1(x1)
        x1 = window(x1,flank=9)
        x1 = self.hmm2(x1)

        x2 = data[:,24:44]
        x2 = window(x2,flank=8)
        x2 = self.psi1(x2)
        x2 = window(x2,flank=9)
        x2 = self.psi2(x2)

        # Average the two branches; the parentheses and true division matter,
        # `x1+x2//2` would floor-divide x2 alone.
        return (x1+x2)/2

    def compile(self,loss,optimizer,metrics=None,**kwargs):
        """Compile the wrapper and each of the four blocks.

        Args:
            loss: loss given to every block.
            optimizer: optimizer given to every block (the same object is
                passed to all four).
            metrics: metrics given to every block; optional.
            **kwargs: forwarded to keras.Model.compile for the wrapper
                itself (for example run_eagerly).
        """
        super(JnetClassifier, self).compile(**kwargs)
        for block in (self.hmm1, self.hmm2, self.psi1, self.psi2):
            block.compile(loss=loss,optimizer=optimizer,metrics=metrics)

    @staticmethod
    def _train_block(block, features, labels):
        """Run one gradient step on `block`; return (predictions, loss)."""
        with tf.GradientTape() as tape:
            pred = block(features)
            loss = block.compiled_loss(labels,pred)
        grad = tape.gradient(loss, block.trainable_variables)
        block.optimizer.apply_gradients(zip(grad,block.trainable_variables))
        return pred, loss

    def train_step(self,inputs):
        """Train the four blocks on one batch, each against the labels.

        Args:
            inputs: pair (data, dssp) as yielded by the data generators.

        Returns:
            Dict with the four block losses under the keys 'hmm_loss1',
            'hmm_loss2', 'pssm_loss1' and 'pssm_loss2'.
        """
        data,dssp = inputs

        # first stage: windowed raw features
        x11 = window(data[:,:24],flank=8)
        x21 = window(data[:,24:44],flank=8)
        hmm_pred1, hmm_loss1 = self._train_block(self.hmm1, x11, dssp)
        psi_pred1, psi_loss1 = self._train_block(self.psi1, x21, dssp)

        # second stage: windowed first-stage predictions
        x12 = window(hmm_pred1,flank=9)
        x22 = window(psi_pred1,flank=9)
        hmm_pred2, hmm_loss2 = self._train_block(self.hmm2, x12, dssp)
        # The PSSM branch must go through psi2, so that the loss depends on
        # the variables whose gradients are applied.
        psi_pred2, psi_loss2 = self._train_block(self.psi2, x22, dssp)

        return {'hmm_loss1':hmm_loss1,'hmm_loss2':hmm_loss2,'pssm_loss1':psi_loss1,'pssm_loss2':psi_loss2}

## MLP block class

In [6]:
class MLPBlock(keras.Sequential):
    """Small MLP: input layer, 100 sigmoid units, then 3 softmax units.

    Both dense layers draw their kernels from the same RandomUniform
    initializer over [-0.05, 0.05].
    """

    def __init__(self, input_shape):
        """Create the layers; `input_shape` is the number of input features."""
        super().__init__()
        self.kernit = keras.initializers.RandomUniform(minval=-0.05, maxval=0.05)
        self.layer1 = keras.layers.InputLayer(input_shape=[input_shape])
        self.layer2 = keras.layers.Dense(
            units=100, activation='sigmoid', kernel_initializer=self.kernit
        )
        self.layer3 = keras.layers.Dense(
            units=3, activation='softmax', kernel_initializer=self.kernit
        )

    def call(self, inputs):
        """Pass `inputs` through the three layers in order."""
        hidden = self.layer2(self.layer1(inputs))
        return self.layer3(hidden)

## Data Generators

In [47]:
class TrainDataGenerator(keras.utils.Sequence):
    """Fixed-size batches over row-aligned dssp / hmm / pssm matrices.

    Each batch is (features, labels): features are the hmm and pssm rows
    joined along axis 1, labels are the matching dssp rows. Rows beyond the
    last full batch are not served.
    """

    def __init__(self, dssp, hmm, pssm, batch_size, shuffle=True):
        self.n = 0
        self.dssp, self.hmm, self.pssm = dssp, hmm, pssm
        self.batch_size = batch_size
        self.shuffle = shuffle
        # shuffle once before the first epoch (no-op when shuffle is False)
        self.on_epoch_end()

    def __len__(self):
        """Number of full batches."""
        n_rows = len(self.dssp)
        return int(np.floor(n_rows/self.batch_size))

    def __getitem__(self, idx):
        """Return batch number `idx` as (features, labels)."""
        rows = slice(idx*self.batch_size, (idx+1)*self.batch_size)
        features = tf.concat([self.hmm[rows], self.pssm[rows]], axis=1)
        return (features, self.dssp[rows])

    def __next__(self):
        """Return the next batch, wrapping to batch 0 after the last one."""
        idx = 0 if self.n >= len(self) else self.n
        self.n = idx + 1
        return self[idx]

    def on_epoch_end(self):
        """Apply one common random row permutation to all three matrices."""
        if not self.shuffle:
            return
        row_ids = tf.range(start=0, limit=tf.shape(self.dssp)[0])
        order = tf.random.shuffle(row_ids)
        self.hmm = tf.gather(self.hmm, order)
        self.pssm = tf.gather(self.pssm, order)
        self.dssp = tf.gather(self.dssp, order)

class ValidDataGenerator(keras.utils.Sequence):
    """One batch per dictionary entry, so batch sizes vary.

    D maps a key to [dssp, hmm, pssm]; keys are visited in the dictionary's
    iteration order at construction time.
    """

    def __init__(self, D):
        self.D = D
        self.len = len(D)
        self.seqIDs = list(self.D.keys())

    def __len__(self):
        """Number of entries in D when the generator was built."""
        return self.len

    def __getitem__(self, idx):
        """Return (features, labels) for the idx-th key."""
        entry = self.D[self.seqIDs[idx]]
        labels = entry[0]
        # features: hmm and pssm joined along axis 1
        features = tf.concat([entry[1], entry[2]], axis=1)
        return (features, labels)

## Data wrappers

In [8]:
def load_data():
    """Parse every split of the shuffle log, one joblib task per split.

    Returns:
        List with one dict per split, mapping seqID to the value returned by
        parse_seqID for it.
    """
    log_file = '/cluster/gjb_lab/2472402/data/retr231_shuffles/shuffle02/best_shuffle_th_1.log'

    def parse_split(split):
        return {seqID: parse_seqID(seqID) for seqID in split}

    splits = parse_log_file(log_file)
    runner = Parallel(n_jobs=-1,verbose=0)
    return runner(delayed(parse_split)(split_i) for split_i in splits)

def get_train_generator(d_train,batch_size):
    """Build a TrainDataGenerator from a dict of (dssp, hmm, pssm) triples.

    The triples are concatenated along axis 0 per matrix, the rows are
    shuffled with one common permutation, and the result is wrapped in a
    TrainDataGenerator with the given batch size.
    """
    dsspL = [dssp for dssp, _, _ in d_train.values()]
    hmmL = [hmm for _, hmm, _ in d_train.values()]
    pssmL = [pssm for _, _, pssm in d_train.values()]

    dssp = tf.concat(dsspL,axis=0)
    hmm = tf.concat(hmmL,axis=0)
    pssm = tf.concat(pssmL,axis=0)

    # one permutation, applied to all three matrices so rows stay aligned
    order = tf.random.shuffle(tf.range(start=0, limit=tf.shape(dssp)[0]))
    dssp = tf.gather(dssp,order)
    hmm = tf.gather(hmm,order)
    pssm = tf.gather(pssm,order)

    return TrainDataGenerator(dssp,hmm,pssm,batch_size)

def get_valid_generator(d_valid):
    """Wrap the validation dictionary in a ValidDataGenerator."""
    generator = ValidDataGenerator(d_valid)
    return generator

## Cross validation

In [9]:
def run_cross_val(DEBUG=False,**params):
    """Train one JnetClassifier per cross-validation fold.

    Args:
        DEBUG: when True, use built-in settings (batch size 128, categorical
            cross-entropy, SGD with learning rate 1e-2, 10 epochs) and run
            only the fold that holds out split 4.
        **params: when DEBUG is False, must provide 'batch_size', 'loss_fn',
            'optimizer' and 'epochs'.

    Returns:
        List with the value returned by `fit` for each fold that was run.
    """
    # load params
    if DEBUG:
        batch_size = 128
        loss_fn = keras.losses.CategoricalCrossentropy()
        optim = keras.optimizers.SGD(learning_rate=1e-2)
        epochs = 10
    else:
        batch_size = params['batch_size']
        loss_fn = params['loss_fn']
        optim = params['optimizer']
        epochs = params['epochs']

    splits = load_data()

    def split_data(val_idx):
        d_train={}
        for idx, d_split in enumerate(splits):
            if idx!=val_idx:
                d_train.update(d_split)
        return d_train, splits[val_idx]

    # one fold per split found in the log file
    krange = [4] if DEBUG else range(len(splits))

    histories = []
    for k in krange:
        jnet = JnetClassifier()

        # `metrics` has to be supplied to JnetClassifier.compile.
        # run_eagerly=True: window() copies rows into a NumPy array, and the
        # debugging cell of this notebook records that training only works
        # in eager mode.
        # NOTE(review): the same optimizer object is reused for every fold;
        # confirm that carrying its state across folds is intended.
        jnet.compile(loss = loss_fn, optimizer = optim, metrics = [], run_eagerly=True)

        d_train,d_valid = split_data(k)
        train_generator = get_train_generator(d_train,batch_size)
        valid_generator = get_valid_generator(d_valid)

        fit_params = {
            'x' : train_generator,
            'validation_data' : valid_generator,
            'verbose' : 2,
            'callbacks' : None,
            'workers' : 4,
            'use_multiprocessing' : True,
            'epochs':epochs
        }

        histories.append(jnet.fit(**fit_params))

    return histories
        
        
        

# Main

In [10]:
%%time
# Set TensorFlow's C++ log-level environment variable.
# NOTE(review): tensorflow is already imported in the first cell - confirm
# the setting still takes effect when assigned this late.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
# Parse every cross-validation split once; later cells read this global.
splits = load_data()
def split_data(val_idx):
    """Split the global `splits` into training data and one held-out split.

    Args:
        val_idx: index of the split to hold out; negative values count from
            the end, as in list indexing.

    Returns:
        (d_train, d_valid): d_train merges every split except the held-out
        one into a single dict, d_valid is the held-out split itself.

    Raises:
        IndexError: if `val_idx` is out of range for `splits`.
    """
    n_splits = len(splits)
    if not -n_splits <= val_idx < n_splits:
        raise IndexError('val_idx %d out of range for %d splits' % (val_idx, n_splits))
    # Normalise negative indices; otherwise `idx != val_idx` is true for every
    # split and the held-out split would also be merged into d_train.
    val_idx %= n_splits
    d_train={}
    for idx, d_split in enumerate(splits):
        if idx!=val_idx:
            d_train.update(d_split)
    return d_train, splits[val_idx]
# Hold out split 4 for validation; the remaining splits form the training set.
d_train, d_valid = split_data(4)

2021-09-19 17:15:47.450123: E tensorflow/stream_executor/cuda/cuda_driver.cc:328] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
2021-09-19 17:15:47.508781: E tensorflow/stream_executor/cuda/cuda_driver.cc:328] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
2021-09-19 17:15:47.554582: E tensorflow/stream_executor/cuda/cuda_driver.cc:328] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
2021-09-19 17:15:47.757955: E tensorflow/stream_executor/cuda/cuda_driver.cc:328] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
2021-09-19 17:15:47.866470: E tensorflow/stream_executor/cuda/cuda_driver.cc:328] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
2021-09-19 17:15:47.917904: E tensorflow/stream_executor/cuda/cuda_driver.cc:328] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
2021-09-19 17:15:47.959730: 

CPU times: user 431 ms, sys: 1.33 s, total: 1.76 s
Wall time: 6.42 s


# Debugging

In [48]:
def get_train_generator(d_train,batch_size):
    """Debug variant that trains on a subsample of the sequences.

    Only entries whose key, converted with int(), is a multiple of 100 are
    kept; they are concatenated per matrix, shuffled with one common
    permutation and wrapped in a TrainDataGenerator.

    NOTE: this redefines the get_train_generator from the "Data wrappers"
    section, so every cell run after this one gets the subsampling version.
    """
    kept = [
        (dssp, hmm, pssm)
        for key, (dssp, hmm, pssm) in d_train.items()
        if int(key) % 100 == 0
    ]
    dsspL = [rec[0] for rec in kept]
    hmmL = [rec[1] for rec in kept]
    pssmL = [rec[2] for rec in kept]

    dssp = tf.concat(dsspL,axis=0)
    hmm = tf.concat(hmmL,axis=0)
    pssm = tf.concat(pssmL,axis=0)

    # one permutation, applied to all three matrices so rows stay aligned
    order = tf.random.shuffle(tf.range(start=0, limit=tf.shape(dssp)[0]))
    dssp = tf.gather(dssp,order)
    hmm = tf.gather(hmm,order)
    pssm = tf.gather(pssm,order)

    return TrainDataGenerator(dssp,hmm,pssm,batch_size)

# Smoke test of the full JnetClassifier on the subsampled training data.
loss_fn = keras.losses.CategoricalCrossentropy()
optim = keras.optimizers.SGD(learning_rate=1e-2)
jnet = JnetClassifier()
# Relies on the global `splits` / split_data defined in the "Main" cell.
d_train,d_valid = split_data(4)
train_generator = get_train_generator(d_train,batch_size=256)
# Built for completeness; not passed to fit below.
valid_generator = get_valid_generator(d_valid)

# works (run_eagerly=True)
# run_eagerly reaches keras.Model.compile through JnetClassifier.compile's **kwargs.
jnet.compile(loss = loss_fn, optimizer = optim, metrics = [],run_eagerly=True)
jnet.fit(train_generator, epochs=2)

# does not work (run_eagerly=False)
# NOTE(review): presumably because window() copies tensor rows into a NumPy
# array, which needs concrete (eager) values - confirm.
# jnet.compile(loss = loss_fn, optimizer = optim, metrics = [],run_eagerly=False)
# jnet.fit(train_generator, epochs=2)


Epoch 1/2
Epoch 2/2


<tensorflow.python.keras.callbacks.History at 0x2b4559996670>

In [43]:
# Smoke test: fit a single MLPBlock(408) directly with the Keras fit loop.
batch_size = 256
loss_fn = keras.losses.CategoricalCrossentropy()
optim = keras.optimizers.SGD(learning_rate=1e-2)
# Only referenced by fit_params below; the fit call passes epochs=2 itself.
epochs = 10

hmm1 = MLPBlock(408)
hmm1.compile(loss = loss_fn, optimizer = optim)

# Relies on the global `splits` / split_data defined in the "Main" cell.
d_train,d_valid = split_data(4)
# NOTE(review): the generator yields the hmm and pssm columns joined together,
# without windowing - confirm this matches the 408-wide input declared above.
train_generator = get_train_generator(d_train,batch_size)
valid_generator = get_valid_generator(d_valid)

# NOTE(review): fit_params is never passed to fit, so the validation data,
# workers and epochs set here have no effect on the run below.
fit_params = {
    'data_generator' : train_generator,
    'validation_data' : valid_generator,
    'verbose' : 2,
    'callbacks' : None,
    'workers' : 4,
    'use_multiprocessing' : True,
    'epochs':epochs
}

hmm1.fit(train_generator, epochs=2)

Epoch 1/2
Epoch 2/2


<tensorflow.python.keras.callbacks.History at 0x2ba8eb9f25b0>

In [57]:
# Smoke test: fit a single MLPBlock(340) directly with the Keras fit loop.
batch_size = 256
loss_fn = keras.losses.CategoricalCrossentropy()
optim = keras.optimizers.SGD(learning_rate=1e-2)
# Only referenced by fit_params below; the fit call passes epochs=2 itself.
epochs = 10

psi1 = MLPBlock(340)
psi1.compile(loss = loss_fn, optimizer = optim)

# Relies on the global `splits` / split_data defined in the "Main" cell.
d_train,d_valid = split_data(4)
# NOTE(review): the generator yields the hmm and pssm columns joined together,
# without windowing - confirm this matches the 340-wide input declared above.
train_generator = get_train_generator(d_train,batch_size)
valid_generator = get_valid_generator(d_valid)

# NOTE(review): fit_params is never passed to fit, so the validation data,
# workers and epochs set here have no effect on the run below.
fit_params = {
    'data_generator' : train_generator,
    'validation_data' : valid_generator,
    'verbose' : 2,
    'callbacks' : None,
    'workers' : 4,
    'use_multiprocessing' : True,
    'epochs':epochs
}

psi1.fit(train_generator, epochs=2)

Epoch 1/2
Epoch 2/2


<tensorflow.python.keras.callbacks.History at 0x2abee3a254f0>