In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import sklearn.metrics

import random
random.seed(7)


def generator(X, Y, batch_size=32, train=True):
    while True:
        for offset in range(0, len(X), batch_size):
            X_batch = np.stack(X[offset:offset+batch_size], axis=0)
            Y_batch = np.stack(Y[offset:offset+batch_size], axis=0)
            
            if train:
                Y_f = np.array(list(map(lambda x:np.sum(np.sum(x, axis=0), axis=0), Y_batch)))
                Y_f[Y_f < 25] = 0
                Y_f[Y_f >= 25] = 1
                
            
                yield (X_batch, Y_f)
            else:
                yield X_batch

structure_ids = []   
for line in open('./structures lists/stucture ids homo sapiens.txt', 'r'):
    line = line.strip('\n').lower()
    structure_ids.append(line)
for line in open('./structures lists/stucture ids synthetic construct.txt', 'r'):
    line = line.strip('\n').lower()
    structure_ids.append(line)
for line in open('./structures lists/stucture ids virus.txt', 'r'):
    line = line.strip('\n').lower()
    structure_ids.append(line)

structure_ids.remove('1a9n')
structure_ids.remove('2adc')
random.shuffle(structure_ids)
print(len(structure_ids))

num_test = int(len(structure_ids)*0.3)
pp_train = 0
pp_test = 0
pos = 0
neg = 0
for i, structure_id in enumerate(structure_ids):
    rna = np.load('../data/voxelized data 14x14x17 2/' + structure_id + '_rna_3D.npy', mmap_mode='r')

    pp = 0
    p = 0
    n = 0
    for bb in rna:
        n_nucleotides = np.sum(bb)
        if n_nucleotides >= 100:
            pp +=1
        if 100 > n_nucleotides > 0:
            p +=1
        if n_nucleotides == 0:
            n +=1
    pos +=p
    neg +=n
    if i <= num_test:
        pp_test +=pp
    else:
        pp_train +=pp

proc_test = pp_test/(pp_test+pp_train)
pos_test = int(pos*proc_test)
neg_test = int(neg*proc_test)

print(pp_train, pp_test, (pp_train + pp_test), pos, neg)
print(proc_test, pos_test, neg_test)

X_train = []
X_test = []
Y_train = []
Y_test = []
num_aa_train = 0
num_aa_test = 0
num_pp = 0
num_p = 0
num_n = 0
for j, structure_id in enumerate(structure_ids):
    protein = np.load('../data/voxelized data 14x14x17 2/' + structure_id + '_protein.npy', mmap_mode='r')
    rna = np.load('../data/voxelized data 14x14x17 2/' + structure_id + '_rna_3D.npy', mmap_mode='r')

    k = 0
    pp_ = 0
    while (np.sum(rna[k]) >= 100) and (k < len(rna)-1):
        pp_ +=1
        k +=1

    if j <= num_test:
        X_test.extend(protein[:pp_, :, :, :, :3])
        Y_test.extend(rna[:pp_])
        num_aa_test +=pp_
        num_pp +=pp_

    else:
        X_train.extend(protein[:pp_, :, :, :, :3])
        Y_train.extend(rna[:pp_])
        num_aa_train +=pp_

    p_= 0 
    while (100 > np.sum(rna[k]) > 0)  and (k < len(rna)-1):
        if (pos_test > 0):
            p_ +=1
            pos_test -=1
        k +=1


    X_test.extend(protein[pp_:(pp_+p_), :, :, :, :3])
    Y_test.extend(rna[pp_:(pp_+p_)])
    num_aa_test +=p_
    num_p +=p_

    n_= 0
    while (np.sum(rna[k]) == 0) and (neg_test > 0) and (k < len(rna)-1):
        n_ +=1
        neg_test -=1
        k +=1

    X_test.extend(protein[(pp_+p_):(pp_+p_+n_), :, :, :, :3])
    Y_test.extend(rna[(pp_+p_):(pp_+p_+n_)])
    num_aa_test +=n_
    num_n +=n_


Y_test_ = np.array(list(map(lambda x:np.sum(np.sum(x, axis=0), axis=0), Y_test)))
Y_test_[Y_test_ < 25] = 0
Y_test_[Y_test_ >= 25] = 1


print(Y_test_.shape)

n_steps_train = int(num_aa_train/1024) 
n_steps_test = int(num_aa_test/1024)

print(num_aa_train, num_aa_test)

print(pp_train, pp_test, pos_test, neg_test)
print(num_pp, num_p, num_n)

generator_train = generator(X_train, Y_train, 1024, True)
generator_validation = generator(X_test, Y_test, 1024, True)
generator_test = generator(X_test, Y_test, 1024, False)

ins = tf.keras.layers.Input((14, 14, 17, 3))
con1 = tf.keras.layers.Conv3D(filters=64, kernel_size=(3, 3, 3), padding='same', activation='relu')(ins)
con2 = tf.keras.layers.Conv3D(filters=32, kernel_size=(3, 3, 3), padding='same', activation='relu')(con1)
con3 = tf.keras.layers.Conv3D(filters=32, kernel_size=(3, 3, 3), padding='same', activation='relu')(con2)
maxp1 = tf.keras.layers.MaxPool3D(pool_size=(2, 2, 2))(con3)
con4 = tf.keras.layers.Conv3D(filters=32, kernel_size=(3, 3, 3), padding='same', activation='relu')(maxp1)
con5 = tf.keras.layers.Conv3D(filters=16, kernel_size=(3, 3, 3), padding='same', activation='relu')(con4)
con6 = tf.keras.layers.Conv3D(filters=16, kernel_size=(3, 3, 3), padding='same', activation='relu')(con5)
maxp2 = tf.keras.layers.MaxPool3D(pool_size=(2, 2, 2))(con6)
con7 = tf.keras.layers.Conv3D(filters=16, kernel_size=(3, 3, 3), padding='same', activation='relu')(maxp2)
con8 = tf.keras.layers.Conv3D(filters=8, kernel_size=(3, 3, 3), padding='same', activation='relu')(con7)
con9 = tf.keras.layers.Conv3D(filters=4, kernel_size=(3, 3, 3), padding='same', activation='relu')(con8)
maxp3 = tf.keras.layers.MaxPool3D(pool_size=(2, 2, 2))(con9)
batch = tf.keras.layers.BatchNormalization()(maxp3)
flat = tf.keras.layers.Flatten()(batch)
dens2 = tf.keras.layers.Dense(units=256, activation='relu')(flat)
drop2 = tf.keras.layers.Dropout(0.6)(dens2)
outs = tf.keras.layers.Dense(units=17, activation='sigmoid')(drop2)
model = tf.keras.models.Model(inputs=ins, outputs=outs)
model.compile(loss='binary_crossentropy', optimizer=tf.keras.optimizers.Adadelta(), metrics=['accuracy', 'mse'])

model.summary()

# checkpoint
es = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=100, mode='min', min_delta=0.0001)
# mc = tf.keras.callbacks.ModelCheckpoint("weights_best.hdf5", monitor='val_loss', verbose=0, save_best_only=True, mode='min')

# model.fit(X_train, Y_train_f, validation_split=0.33, epochs=1, batch_size=200, callbacks=callbacks_list, verbose=0)
history = model.fit_generator(generator_train, steps_per_epoch=n_steps_train, epochs=500, 
                    validation_data = generator_validation, validation_steps=n_steps_test, callbacks=[es], verbose=1)

acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs = range(1, len(acc) + 1)
plt.plot(epochs, acc, 'r', label='Training acc')
plt.plot(epochs, val_acc, 'b', label='Validation acc')
plt.title('Training and validation accuracy')
plt.legend()
plt.savefig('accuracy')
plt.figure()
plt.plot(epochs, loss, 'r', label='Training loss')
plt.plot(epochs, val_loss, 'b', label='Validation loss')
plt.title('Training and validation loss')
plt.legend()
plt.savefig('loss')


# model_best = model
# model_best.load_weights("weights_best.hdf5")
# print(model.evaluate(X_test, Y_test, verbose=0, batch_size=100))
# model_best.save('model_cnn_15_2.h5')
# Y_pred = model_best.predict(X_test, batch_size=200)
Y_pred_ = model.predict_generator(generator_test, steps=n_steps_test)
# print(Y_pred.shape)
# Y_pred_ = np.array([Y_pred[i].reshape((2,2,17)) for i in range(Y_pred.shape[0])])

#CNN
Y_pred_prob = Y_pred_
Y_pred_[Y_pred_ >= 0.5] = 1
Y_pred_[Y_pred_ < 0.5] = 0

print(Y_pred_.shape)
print(Y_test_.shape)
Y_test_ = Y_test_[:Y_pred_.shape[0]]


print('CNN: \n')

confusion_matrix = [sklearn.metrics.confusion_matrix(Y_test_[:,i], Y_pred_[:, i]) for i in range(17)]
accuracy = [np.sum(np.trace(cm))/np.sum(cm) for cm in confusion_matrix]
auc = [sklearn.metrics.roc_auc_score(Y_test_[:,i], Y_pred_prob[:,i]) for i in range(17)]

for q in range(17):
    print(f'level {q+1}')
    print(confusion_matrix[q], np.round(accuracy[q], 2), np.round(auc[q], 2))

# baseline model
# predict all zeros; at least 50% correct predictions because there are 1/2 of negative examples
Y_pred_base = np.zeros(Y_test_.shape)

# po = np.sum(Y_train, axis=0)/Y_train.shape[0]
# po[po >= 0.5] = 1
# po[po < 0.5] = 0
# Y_pred_base = np.tile(po, (Y_test.shape[0],1))

print(Y_pred_base.shape)
print(f'\n BASELINE MODEL: \n')

confusion_matrix_base = [sklearn.metrics.confusion_matrix(Y_test_[:,i], Y_pred_base[:,i]) for i in range(17)]
accuracy_base = [np.sum(np.trace(cm))/np.sum(cm) for cm in confusion_matrix_base]
auc_base = [sklearn.metrics.roc_auc_score(Y_test_[:, i], Y_pred_base[:,i]) for i in range(17)]

for q in range(17):
    print(f'level {q+1}')
    print(confusion_matrix_base[q], np.round(accuracy_base[q], 2), np.round(auc_base[q], 2))

#random model
Y_pred_random = np.random.random(Y_test_.shape)
Y_pred_random[Y_pred_random >= 0.5] = 1
Y_pred_random[Y_pred_random < 0.5] = 0

print(f'\n RANDOM MODEL: \n')

confusion_matrix_random = [sklearn.metrics.confusion_matrix(Y_test_[:,i], Y_pred_random[:,i]) for i in range(17)]
accuracy_random = [np.sum(np.trace(cm))/np.sum(cm) for cm in confusion_matrix_random]
auc_random = [sklearn.metrics.roc_auc_score(Y_test_[:,i], Y_pred_random[:,i]) for i in range(17)]

for q in range(17):
    print(f'level {q+1}')
    print(confusion_matrix_random[q], np.round(accuracy_random[q], 2), np.round(auc_random[q], 2))

549
24676 13143 37819 94223 334320
0.3475237314577329 32744 116184
(162070, 17)
24673 162070
24676 13143 0 0
13142 32744 116184
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 14, 14, 17, 3)     0         
_________________________________________________________________
conv3d (Conv3D)              (None, 14, 14, 17, 64)    5248      
_________________________________________________________________
conv3d_1 (Conv3D)            (None, 14, 14, 17, 32)    55328     
_________________________________________________________________
conv3d_2 (Conv3D)            (None, 14, 14, 17, 32)    27680     
_________________________________________________________________
max_pooling3d (MaxPooling3D) (None, 7, 7, 8, 32)       0         
_________________________________________________________________
conv3d_3 (Conv3D)            (None, 7, 7, 8, 32)       27680     
______________

Epoch 29/500
Epoch 30/500
Epoch 31/500
Epoch 32/500
Epoch 33/500
Epoch 34/500
Epoch 35/500
Epoch 36/500
Epoch 37/500
Epoch 38/500
Epoch 39/500
Epoch 40/500
Epoch 41/500
Epoch 42/500
Epoch 43/500
Epoch 44/500
Epoch 45/500
Epoch 46/500
Epoch 47/500
Epoch 48/500
Epoch 49/500
Epoch 50/500
Epoch 51/500
Epoch 52/500
Epoch 53/500
Epoch 54/500
Epoch 55/500
Epoch 56/500
Epoch 57/500
Epoch 58/500
Epoch 59/500
Epoch 60/500
Epoch 61/500
Epoch 62/500
Epoch 63/500
Epoch 64/500
Epoch 65/500
Epoch 66/500
Epoch 67/500
Epoch 68/500
Epoch 69/500
Epoch 70/500
Epoch 71/500


Epoch 72/500
Epoch 73/500
Epoch 74/500
Epoch 75/500
Epoch 76/500
Epoch 77/500
Epoch 78/500
Epoch 79/500
Epoch 80/500
Epoch 81/500
Epoch 82/500
Epoch 83/500
Epoch 84/500
Epoch 85/500
Epoch 86/500
Epoch 87/500
Epoch 88/500
Epoch 89/500
Epoch 90/500
Epoch 91/500
Epoch 92/500
Epoch 93/500
Epoch 94/500
Epoch 95/500
Epoch 96/500
Epoch 97/500
Epoch 98/500
Epoch 99/500
Epoch 100/500
Epoch 101/500
Epoch 102/500
Epoch 103/500
Epoch 104/500
Epoch 105/500
Epoch 106/500
Epoch 107/500
Epoch 108/500
Epoch 109/500
Epoch 110/500
Epoch 111/500
Epoch 112/500
Epoch 113/500


Epoch 114/500
Epoch 115/500
Epoch 116/500
Epoch 117/500
Epoch 118/500
Epoch 119/500
Epoch 120/500
Epoch 121/500
Epoch 122/500
Epoch 123/500
Epoch 124/500
Epoch 125/500
Epoch 126/500
Epoch 127/500
Epoch 128/500
Epoch 129/500
Epoch 130/500
Epoch 131/500
Epoch 132/500
Epoch 133/500
Epoch 134/500
Epoch 135/500
Epoch 136/500
Epoch 137/500
Epoch 138/500
Epoch 139/500
Epoch 140/500
Epoch 141/500
Epoch 142/500
Epoch 143/500
Epoch 144/500
Epoch 145/500
Epoch 146/500
Epoch 147/500
Epoch 148/500
Epoch 149/500
Epoch 150/500
Epoch 151/500
Epoch 152/500
Epoch 153/500
Epoch 154/500
Epoch 155/500


Epoch 156/500
Epoch 157/500
Epoch 158/500
Epoch 159/500
Epoch 160/500
Epoch 161/500
Epoch 162/500
Epoch 163/500
Epoch 164/500
Epoch 165/500
Epoch 166/500
Epoch 167/500
Epoch 168/500
Epoch 169/500
Epoch 170/500
Epoch 171/500
Epoch 172/500
Epoch 173/500
Epoch 174/500
Epoch 175/500
Epoch 176/500
Epoch 177/500
Epoch 178/500
Epoch 179/500
Epoch 180/500
Epoch 181/500
Epoch 182/500
Epoch 183/500
Epoch 184/500
Epoch 185/500
Epoch 186/500
Epoch 187/500
Epoch 188/500
Epoch 189/500
Epoch 190/500
Epoch 191/500
Epoch 192/500
Epoch 193/500
Epoch 194/500
(161792, 17)
(162070, 17)
CNN: 

level 1
[[142832  13472]
 [  3062   2426]] 0.9 0.68
level 2
[[142593  14198]
 [  2633   2368]] 0.9 0.69
level 3
[[144236  12669]
 [  2587   2300]] 0.91 0.69
level 4
[[145128  11717]
 [  2571   2376]] 0.91 0.7
level 5
[[145665  10949]
 [  2549   2629]] 0.92 0.72
level 6
[[144886  11249]
 [  2460   3197]] 0.92 0.75
level 7
[[143490  11996]
 [  2443   3863]] 0.91 0.77
level 8
[[142264  12829]
 [  2208   4491]] 0.91 0.79


level 1
[[156304      0]
 [  5488      0]] 0.97 0.5
level 2
[[156791      0]
 [  5001      0]] 0.97 0.5
level 3
[[156905      0]
 [  4887      0]] 0.97 0.5
level 4
[[156845      0]
 [  4947      0]] 0.97 0.5
level 5
[[156614      0]
 [  5178      0]] 0.97 0.5
level 6
[[156135      0]
 [  5657      0]] 0.97 0.5
level 7
[[155486      0]
 [  6306      0]] 0.96 0.5
level 8
[[155093      0]
 [  6699      0]] 0.96 0.5
level 9
[[154637      0]
 [  7155      0]] 0.96 0.5
level 10
[[154517      0]
 [  7275      0]] 0.96 0.5
level 11
[[154585      0]
 [  7207      0]] 0.96 0.5
level 12
[[154734      0]
 [  7058      0]] 0.96 0.5
level 13
[[154966      0]
 [  6826      0]] 0.96 0.5
level 14
[[155267      0]
 [  6525      0]] 0.96 0.5
level 15
[[155543      0]
 [  6249      0]] 0.96 0.5
level 16
[[155892      0]
 [  5900      0]] 0.96 0.5
level 17
[[156278      0]
 [  5514      0]] 0.97 0.5

 RANDOM MODEL: 

level 1
[[78204 78100]
 [ 2773  2715]] 0.5 0.5
level 2
[[78562 78229]
 [ 2472  2529]] 0.5 