In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import sklearn.metrics

import random
random.seed(7)


def generator(X, Y, batch_size=32, train=True):
    while True:
        for offset in range(0, len(X), batch_size):
            X_batch = np.stack(X[offset:offset+batch_size], axis=0)
            Y_batch = np.stack(Y[offset:offset+batch_size], axis=0)

            Y_batch_ = np.empty((Y_batch.shape[0], 2,2,10))
            for m in range(Y_batch.shape[0]):
                for i in range(10):
                    Y_batch_[m, :,:,i] = [[np.sum(Y_batch[m, :5, :5, i]), np.sum(Y_batch[m, :5, 5:, i])], [np.sum(Y_batch[m, 5:, :5, i]), np.sum(Y_batch[m, 5:, 5:, i])]]
            Y_batch_[Y_batch_ > 0] = 1

            if train:
                Y_f = np.array([Y_batch_[e].flatten() for e in range(Y_batch_.shape[0])])
                yield (X_batch, Y_f)
            else:
                yield X_batch

structure_ids = []
for line in open('./structures lists/structures human.txt', 'r'):
    line = line.strip('\n')
    structure_ids.append(line)
# for line in open('./structures lists/structures ecoli.txt', 'r'):
#     line = line.strip('\n')
#     structure_ids.append(line)
structure_ids.remove('4pkd')
structure_ids.remove('1a9n')
structure_ids.remove('2adc')
random.shuffle(structure_ids)
print(len(structure_ids))

X_train = []
X_test = []
Y_train = []
Y_test = []
num_aa_train = 0
num_aa_test = 0
num_train = int(len(structure_ids)*0.7)
for i, structure_id in enumerate(structure_ids):
    protein = np.load('../data/voxelized data 10x10x10/' + structure_id + '_protein.npy', mmap_mode='r')
    rna = np.load('../data/voxelized data 10x10x10/' + structure_id + '_rna_3D.npy', mmap_mode='r')
    na = 0
    pos = 0
    while (np.sum(rna[na]) > 0) and (na < len(rna)-1):
        pos +=1
        na +=1
    

    if i <= num_train:
        if pos > len(rna)/2:
            X_train.extend(protein[:, :, :, :, :3])
            Y_train.extend(rna)
            num_aa_train +=len(rna)
        else:
            X_train.extend(protein[:pos, :, :, :, :3])
            X_train.extend(protein[-pos:, :, :, :, :3])
            Y_train.extend(rna[:pos])
            Y_train.extend(rna[-pos:])
            num_aa_train +=2*pos
    else:
        if pos > len(rna)/2:
            X_test.extend(protein[:, :, :, :, :3])
            Y_test.extend(rna)
            num_aa_test +=len(rna)
        else:
            X_test.extend(protein[:pos, :, :, :, :3])
            X_test.extend(protein[-pos:, :, :, :, :3])
            Y_test.extend(rna[:pos])
            Y_test.extend(rna[-pos:])
            num_aa_test +=2*pos

Y_test = np.stack(Y_test, axis=0)
Y_test_ = np.empty((len(Y_test), 2,2,10))
for m in range(len(Y_test)):
    for i in range(10):
        Y_test_[m, :,:,i] = [[np.sum(Y_test[m, :5, :5, i]), np.sum(Y_test[m, :5, 5:, i])], [np.sum(Y_test[m, 5:, :5, i]), np.sum(Y_test[m, 5:, 5:, i])]]
Y_test_[Y_test_ > 0] = 1
   
n_steps_train = int(num_aa_train/400) 
n_steps_test = int(num_aa_test/400)

print(num_aa_train, num_aa_test)

generator_train = generator(X_train, Y_train, 400, True)
generator_test = generator(X_test, Y_test, 400, False)

ins = tf.keras.layers.Input((10, 10, 10, 3))
con1 = tf.keras.layers.Conv3D(filters=64, kernel_size=(3, 3, 3), padding='same', activation='relu')(ins)
con2 = tf.keras.layers.Conv3D(filters=32, kernel_size=(3, 3, 3), padding='same', activation='relu')(con1)
con3 = tf.keras.layers.Conv3D(filters=32, kernel_size=(3, 3, 3), padding='same', activation='relu')(con2)
maxp1 = tf.keras.layers.MaxPool3D(pool_size=(2, 2, 2))(con3)
con4 = tf.keras.layers.Conv3D(filters=32, kernel_size=(3, 3, 3), padding='same', activation='relu')(maxp1)
con5 = tf.keras.layers.Conv3D(filters=16, kernel_size=(3, 3, 3), padding='same', activation='relu')(con4)
con6 = tf.keras.layers.Conv3D(filters=16, kernel_size=(3, 3, 3), padding='same', activation='relu')(con5)
maxp2 = tf.keras.layers.MaxPool3D(pool_size=(2, 2, 2))(con6)
con7 = tf.keras.layers.Conv3D(filters=16, kernel_size=(3, 3, 3), padding='same', activation='relu')(maxp2)
con8 = tf.keras.layers.Conv3D(filters=8, kernel_size=(3, 3, 3), padding='same', activation='relu')(con7)
con9 = tf.keras.layers.Conv3D(filters=4, kernel_size=(3, 3, 3), padding='same', activation='relu')(con8)
maxp3 = tf.keras.layers.MaxPool3D(pool_size=(2, 2, 2))(con9)
batch = tf.keras.layers.BatchNormalization()(maxp3)
flat = tf.keras.layers.Flatten()(batch)
dens2 = tf.keras.layers.Dense(units=256, activation='relu')(flat)
drop2 = tf.keras.layers.Dropout(0.6)(dens2)
outs = tf.keras.layers.Dense(units=40, activation='sigmoid')(drop2)
model = tf.keras.models.Model(inputs=ins, outputs=outs)
model.compile(loss='binary_crossentropy', optimizer=tf.keras.optimizers.Adam(lr=0.00001), metrics=['accuracy'])

model.summary()

# checkpoint
# filepath="weights_best.hdf5"
# checkpoint = tf.keras.callbacks.ModelCheckpoint(filepath, monitor='val_acc', verbose=0, save_best_only=True, mode='max')
# callbacks_list = [checkpoint]

# model.fit(X_train, Y_train_f, validation_split=0.33, epochs=1, batch_size=200, callbacks=callbacks_list, verbose=0)
model.fit_generator(generator_train, steps_per_epoch=n_steps_train, epochs=500, callbacks=None, verbose=1, max_queue_size=2)

# model_best = model
# model_best.load_weights("weights_best.hdf5")
# print(model.evaluate(X_test, Y_test, verbose=0, batch_size=100))
# model_best.save('model_cnn_15_2.h5')
# Y_pred = model_best.predict(X_test, batch_size=200)
Y_pred = model.predict_generator(generator_test, steps=n_steps_test)
print(Y_pred.shape)
Y_pred_ = np.array([Y_pred[i].reshape((2,2,10)) for i in range(Y_pred.shape[0])])

#CNN
Y_pred_[Y_pred_ >= 0.5] = 1
Y_pred_[Y_pred_ < 0.5] = 0

print(Y_pred_.shape)
print(Y_test_.shape)
Y_test_ = Y_test_[:Y_pred_.shape[0]]

print('CNN: \n')
for i in range(10):
    confusion_matrix = [sklearn.metrics.confusion_matrix(Y_test_[:,l , c, i], Y_pred_[:,l , c, i]) for l in range(2) for c in range(2)]
    accuracy = [np.sum(np.trace(cm))/np.sum(cm) for cm in confusion_matrix]
    auc = [sklearn.metrics.roc_auc_score(Y_test_[:,l , c, i], Y_pred_[:,l , c, i]) for l in range(2) for c in range(2)]

    print(f'level {i}')
    for q in range(len(confusion_matrix)):
        print(confusion_matrix[q], np.round(accuracy[q], 2), np.round(auc[q], 2))

# baseline model
# predict all zeros; at least 50% correct predictions because there are 1/2 of negative examples
Y_pred_base = np.zeros(Y_test_.shape)

# po = np.sum(Y_train, axis=0)/Y_train.shape[0]
# po[po >= 0.5] = 1
# po[po < 0.5] = 0
# Y_pred_base = np.tile(po, (Y_test.shape[0],1))

Y_pred_base[Y_pred_base >= 0.5] = 1
Y_pred_base[Y_pred_base < 0.5] = 0
print(Y_pred_base.shape)
print(f'\n BASELINE MODEL: \n')
for i in range(10):
    confusion_matrix_base = [sklearn.metrics.confusion_matrix(Y_test_[:,l, c, i], Y_pred_base[:,l, c, i]) for l in range(2) for c in range(2)]
    accuracy_base = [np.sum(np.trace(cm))/np.sum(cm) for cm in confusion_matrix_base]
    auc_base = [sklearn.metrics.roc_auc_score(Y_test_[:,l, c, i], Y_pred_base[:,l, c, i]) for l in range(2) for c in range(2)]
    
    print(f'level {i}')
    for q in range(len(confusion_matrix_base)):
        print(confusion_matrix_base[q], np.round(accuracy_base[q], 2), np.round(auc_base[q], 2))

#random model
Y_pred_random = np.random.random(Y_test_.shape)
Y_pred_random[Y_pred_random >= 0.5] = 1
Y_pred_random[Y_pred_random < 0.5] = 0

print(f'\n RANDOM MODEL: \n')
for i in range(10):
    confusion_matrix_random = [sklearn.metrics.confusion_matrix(Y_test_[:,l, c, i], Y_pred_random[:,l, c, i]) for l in range(2) for c in range(2)]
    accuracy_random = [np.sum(np.trace(cm))/np.sum(cm) for cm in confusion_matrix_random]
    auc_random = [sklearn.metrics.roc_auc_score(Y_test_[:,l, c, i], Y_pred_random[:,l, c, i]) for l in range(2) for c in range(2)]

    print(f'level {i}')
    for q in range(len(confusion_matrix_random)):
        print(confusion_matrix_random[q], np.round(accuracy_random[q], 2), np.round(auc_random[q], 2))

230
83862 44057
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 10, 10, 10, 3)     0         
_________________________________________________________________
conv3d (Conv3D)              (None, 10, 10, 10, 64)    5248      
_________________________________________________________________
conv3d_1 (Conv3D)            (None, 10, 10, 10, 32)    55328     
_________________________________________________________________
conv3d_2 (Conv3D)            (None, 10, 10, 10, 32)    27680     
_________________________________________________________________
max_pooling3d (MaxPooling3D) (None, 5, 5, 5, 32)       0         
_________________________________________________________________
conv3d_3 (Conv3D)            (None, 5, 5, 5, 32)       27680     
_________________________________________________________________
conv3d_4 (Conv3D)            (None, 5, 5, 5, 16)       13840

Epoch 56/500
Epoch 57/500
Epoch 58/500
Epoch 59/500
Epoch 60/500
Epoch 61/500
Epoch 62/500
Epoch 63/500
Epoch 64/500
Epoch 65/500
Epoch 66/500
Epoch 67/500
Epoch 68/500
Epoch 69/500
Epoch 70/500
Epoch 71/500
Epoch 72/500
Epoch 73/500
Epoch 74/500
Epoch 75/500
Epoch 76/500
Epoch 77/500
Epoch 78/500
Epoch 79/500
Epoch 80/500
Epoch 81/500
Epoch 82/500
Epoch 83/500
Epoch 84/500
Epoch 85/500
Epoch 86/500
Epoch 87/500
Epoch 88/500
Epoch 89/500
Epoch 90/500
Epoch 91/500
Epoch 92/500
Epoch 93/500
Epoch 94/500
Epoch 95/500
Epoch 96/500
Epoch 97/500
Epoch 98/500
Epoch 99/500
Epoch 100/500
Epoch 101/500
Epoch 102/500
Epoch 103/500
Epoch 104/500
Epoch 105/500
Epoch 106/500
Epoch 107/500
Epoch 108/500
Epoch 109/500
Epoch 110/500
Epoch 111/500
Epoch 112/500
Epoch 113/500
Epoch 114/500
Epoch 115/500
Epoch 116/500
Epoch 117/500
Epoch 118/500
Epoch 119/500
Epoch 120/500
Epoch 121/500
Epoch 122/500
Epoch 123/500
Epoch 124/500
Epoch 125/500
Epoch 126/500
Epoch 127/500
Epoch 128/500
Epoch 129/500
Epoch 13

Epoch 136/500
Epoch 137/500
Epoch 138/500
Epoch 139/500
Epoch 140/500
Epoch 141/500
Epoch 142/500
Epoch 143/500
Epoch 144/500
Epoch 145/500
Epoch 146/500
Epoch 147/500
Epoch 148/500
Epoch 149/500
Epoch 150/500
Epoch 151/500
Epoch 152/500
Epoch 153/500
Epoch 154/500
Epoch 155/500
Epoch 156/500
Epoch 157/500
Epoch 158/500
Epoch 159/500
Epoch 160/500
Epoch 161/500
Epoch 162/500
Epoch 163/500
Epoch 164/500
Epoch 165/500
Epoch 166/500
Epoch 167/500
Epoch 168/500
Epoch 169/500
Epoch 170/500
Epoch 171/500
Epoch 172/500
Epoch 173/500
Epoch 174/500
Epoch 175/500
Epoch 176/500
Epoch 177/500
Epoch 178/500
Epoch 179/500
Epoch 180/500
Epoch 181/500
Epoch 182/500
Epoch 183/500
Epoch 184/500
Epoch 185/500
Epoch 186/500
Epoch 187/500
Epoch 188/500
Epoch 189/500
Epoch 190/500
Epoch 191/500
Epoch 192/500
Epoch 193/500
Epoch 194/500
Epoch 195/500
Epoch 196/500
Epoch 197/500
Epoch 198/500
Epoch 199/500
Epoch 200/500
Epoch 201/500
Epoch 202/500
Epoch 203/500
Epoch 204/500
Epoch 205/500
Epoch 206/500
Epoch 

Epoch 217/500
Epoch 218/500
Epoch 219/500
Epoch 220/500
Epoch 221/500
Epoch 222/500
Epoch 223/500
Epoch 224/500
Epoch 225/500
Epoch 226/500
Epoch 227/500
Epoch 228/500
Epoch 229/500
Epoch 230/500
Epoch 231/500
Epoch 232/500
Epoch 233/500
Epoch 234/500
Epoch 235/500
Epoch 236/500
Epoch 237/500
Epoch 238/500
Epoch 239/500
Epoch 240/500
Epoch 241/500
Epoch 242/500
Epoch 243/500
Epoch 244/500
Epoch 245/500
Epoch 246/500
Epoch 247/500
Epoch 248/500
Epoch 249/500
Epoch 250/500
Epoch 251/500
Epoch 252/500
Epoch 253/500
Epoch 254/500
Epoch 255/500
Epoch 256/500
Epoch 257/500
Epoch 258/500
Epoch 259/500
Epoch 260/500
Epoch 261/500
Epoch 262/500
Epoch 263/500
Epoch 264/500
Epoch 265/500
Epoch 266/500
Epoch 267/500
Epoch 268/500
Epoch 269/500
Epoch 270/500
Epoch 271/500
Epoch 272/500
Epoch 273/500
Epoch 274/500
Epoch 275/500
Epoch 276/500
Epoch 277/500
Epoch 278/500
Epoch 279/500
Epoch 280/500
Epoch 281/500
Epoch 282/500
Epoch 283/500
Epoch 284/500
Epoch 285/500
Epoch 286/500
Epoch 287/500
Epoch 

Epoch 297/500
Epoch 298/500
Epoch 299/500
Epoch 300/500
Epoch 301/500
Epoch 302/500
Epoch 303/500
Epoch 304/500
Epoch 305/500
Epoch 306/500
Epoch 307/500
Epoch 308/500
Epoch 309/500
Epoch 310/500
Epoch 311/500
Epoch 312/500
Epoch 313/500
Epoch 314/500
Epoch 315/500
Epoch 316/500
Epoch 317/500
Epoch 318/500
Epoch 319/500
Epoch 320/500
Epoch 321/500
Epoch 322/500
Epoch 323/500
Epoch 324/500
Epoch 325/500
Epoch 326/500
Epoch 327/500
Epoch 328/500
Epoch 329/500
Epoch 330/500
Epoch 331/500
Epoch 332/500
Epoch 333/500
Epoch 334/500
Epoch 335/500
Epoch 336/500
Epoch 337/500
Epoch 338/500
Epoch 339/500
Epoch 340/500
Epoch 341/500
Epoch 342/500
Epoch 343/500
Epoch 344/500
Epoch 345/500
Epoch 346/500
Epoch 347/500
Epoch 348/500
Epoch 349/500
Epoch 350/500
Epoch 351/500
Epoch 352/500
Epoch 353/500
Epoch 354/500
Epoch 355/500
Epoch 356/500
Epoch 357/500
Epoch 358/500
Epoch 359/500
Epoch 360/500
Epoch 361/500
Epoch 362/500
Epoch 363/500
Epoch 364/500
Epoch 365/500
Epoch 366/500
Epoch 367/500
Epoch 

Epoch 377/500
Epoch 378/500
Epoch 379/500
Epoch 380/500
Epoch 381/500
Epoch 382/500
Epoch 383/500
Epoch 384/500
Epoch 385/500
Epoch 386/500
Epoch 387/500
Epoch 388/500
Epoch 389/500
Epoch 390/500
Epoch 391/500
Epoch 392/500
Epoch 393/500
Epoch 394/500
Epoch 395/500
Epoch 396/500
Epoch 397/500
Epoch 398/500
Epoch 399/500
Epoch 400/500
Epoch 401/500
Epoch 402/500
Epoch 403/500
Epoch 404/500
Epoch 405/500
Epoch 406/500
Epoch 407/500
Epoch 408/500
Epoch 409/500
Epoch 410/500
Epoch 411/500
Epoch 412/500
Epoch 413/500
Epoch 414/500
Epoch 415/500
Epoch 416/500
Epoch 417/500
Epoch 418/500
Epoch 419/500
Epoch 420/500
Epoch 421/500
Epoch 422/500
Epoch 423/500
Epoch 424/500
Epoch 425/500
Epoch 426/500
Epoch 427/500
Epoch 428/500
Epoch 429/500
Epoch 430/500
Epoch 431/500
Epoch 432/500
Epoch 433/500
Epoch 434/500
Epoch 435/500
Epoch 436/500
Epoch 437/500
Epoch 438/500
Epoch 439/500
Epoch 440/500
Epoch 441/500
Epoch 442/500
Epoch 443/500
Epoch 444/500
Epoch 445/500
Epoch 446/500
Epoch 447/500
Epoch 

Epoch 458/500
Epoch 459/500
Epoch 460/500
Epoch 461/500
Epoch 462/500
Epoch 463/500
Epoch 464/500
Epoch 465/500
Epoch 466/500
Epoch 467/500
Epoch 468/500
Epoch 469/500
Epoch 470/500
Epoch 471/500
Epoch 472/500
Epoch 473/500
Epoch 474/500
Epoch 475/500
Epoch 476/500
Epoch 477/500
Epoch 478/500
Epoch 479/500
Epoch 480/500
Epoch 481/500
Epoch 482/500
Epoch 483/500
Epoch 484/500
Epoch 485/500
Epoch 486/500
Epoch 487/500
Epoch 488/500
Epoch 489/500
Epoch 490/500
Epoch 491/500
Epoch 492/500
Epoch 493/500
Epoch 494/500
Epoch 495/500
Epoch 496/500
Epoch 497/500
Epoch 498/500
Epoch 499/500
Epoch 500/500
(44000, 40)
(44000, 2, 2, 10)
(44057, 2, 2, 10)
CNN: 

level 0
[[39470     0]
 [ 4530     0]] 0.9 0.5
[[39064   178]
 [ 4675    83]] 0.89 0.51
[[39576     0]
 [ 4424     0]] 0.9 0.5
[[39354     0]
 [ 4646     0]] 0.89 0.5
level 1
[[38992   152]
 [ 4747   109]] 0.89 0.51
[[38873   236]
 [ 4763   128]] 0.89 0.51
[[39444     0]
 [ 4556     0]] 0.9 0.5
[[39187     0]
 [ 4813     0]] 0.89 0.5
level 2

level 2
[[19475 19382]
 [ 2524  2619]] 0.5 0.51
[[19376 19388]
 [ 2634  2602]] 0.5 0.5
[[19566 19631]
 [ 2422  2381]] 0.5 0.5
[[19435 19545]
 [ 2516  2504]] 0.5 0.5
level 3
[[19329 19105]
 [ 2750  2816]] 0.5 0.5
[[19050 19452]
 [ 2749  2749]] 0.5 0.5
[[19315 19537]
 [ 2559  2589]] 0.5 0.5
[[19000 19603]
 [ 2705  2692]] 0.49 0.5
level 4
[[18891 19023]
 [ 3067  3019]] 0.5 0.5
[[18770 19128]
 [ 3099  3003]] 0.49 0.49
[[19210 19309]
 [ 2768  2713]] 0.5 0.5
[[18952 19030]
 [ 2958  3060]] 0.5 0.5
level 5
[[18512 18597]
 [ 3454  3437]] 0.5 0.5
[[18694 18576]
 [ 3435  3295]] 0.5 0.5
[[18894 18841]
 [ 3155  3110]] 0.5 0.5
[[18697 18477]
 [ 3461  3365]] 0.5 0.5
level 6
[[18198 18037]
 [ 3912  3853]] 0.5 0.5
[[17928 18475]
 [ 3755  3842]] 0.49 0.5
[[18345 18466]
 [ 3636  3553]] 0.5 0.5
[[18130 18146]
 [ 3887  3837]] 0.5 0.5
level 7
[[17760 17847]
 [ 4177  4216]] 0.5 0.5
[[17680 17844]
 [ 4293  4183]] 0.5 0.5
[[18200 17791]
 [ 3955  4054]] 0.51 0.51
[[17809 17737]
 [ 4180  4274]] 0.5 0.5
level 8
[