# Imports

In [1]:
import os

import numpy as np
import tensorflow as tf
from tensorflow.keras import Input, layers, Model
import h5py
from sklearn.metrics import auc, precision_recall_curve

from tfomics.layers import MultiHeadAttention

# Load data

In [2]:
# The dataset path is resolved relative to the notebook's current working directory.
file_path = os.path.abspath(os.path.join('..', '..', 'Datasets', 'Task4', 'complex_synthetic_dataset.h5'))

with h5py.File(file_path, 'r') as dataset:
    # Materialise each HDF5 dataset as an in-memory array so the file can be closed.
    X = np.array(dataset['X'])
    Y = np.array(dataset['Y'])
    L = np.array(dataset['L'])  # not used in the cells below; presumably per-sample metadata — confirm

assert len(X) == len(Y), f'X has {len(X)} samples but Y has {len(Y)}'

# Contiguous 70% / 10% / 20% train / valid / test split in file order.
# No shuffling happens here, so this assumes the file was shuffled when written — TODO confirm.
# The test split runs to the end of the arrays: truncating each fraction with
# int() independently would otherwise silently drop trailing samples.
train = int(len(X) * 0.7)
valid = train + int(len(X) * 0.1)
test = len(X)

x_train = X[:train]
x_valid = X[train:valid]
x_test = X[valid:test]

y_train = Y[:train]
y_valid = Y[train:valid]
y_test = Y[valid:test]

# Define model

In [3]:
# Conv stem -> multi-head self-attention -> dense head with one sigmoid output per label.
# Input shape is taken from one sample of X (presumably (sequence_length, channels) — confirm).
inputs = Input(shape=X.shape[1:])

# 48 filters of width 19; 'same' padding keeps the sequence length. The bias is
# disabled because the BatchNormalization right after it supplies its own shift.
nn = layers.Conv1D(48, 19, use_bias=False, padding='same', kernel_regularizer=None)(inputs)
nn = layers.BatchNormalization()(nn)
# The only explicitly named layer — presumably so first-layer activations can be
# fetched by name later (e.g. model.get_layer('conv_activation')); confirm.
nn = layers.Activation('relu', name='conv_activation')(nn)
# Pool size 20 (stride defaults to the pool size), shrinking the length ~20x before attention.
nn = layers.MaxPool1D(20)(nn)
nn = layers.Dropout(0.1)(nn)

# Disabled option: pre-attention LayerNormalization — nn = layers.LayerNormalization()(nn)
# Self-attention: the same pooled features are passed for all three inputs.
# The layer returns a pair; `att` presumably holds the attention weights and is
# not used in this cell. d_model = 4*96 = 384, presumably 96 per head — confirm.
nn, att = MultiHeadAttention(num_heads=4, d_model=4*96)(nn, nn, nn)
nn = layers.Dropout(0.1)(nn)

# Flatten the per-position features into one vector for the dense head.
nn = layers.Flatten()(nn)

nn = layers.Dense(512)(nn)
nn = layers.BatchNormalization()(nn)
nn = layers.Activation('relu')(nn)
nn = layers.Dropout(0.5)(nn)

# One independent sigmoid unit per label column of Y (multi-label output).
outputs = layers.Dense(Y.shape[1], activation='sigmoid')(nn)

model = Model(inputs=inputs, outputs=outputs)

# Train model

In [4]:
# Binary cross-entropy matches the independent sigmoid outputs. ROC-AUC and
# PR-AUC are tracked on both the training and validation data.
auroc = tf.keras.metrics.AUC(curve='ROC', name='auroc')
aupr = tf.keras.metrics.AUC(curve='PR', name='aupr')
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    loss='binary_crossentropy',
    metrics=[auroc, aupr],
)
batch_size = 200

# Cut the learning rate 5x after 5 epochs without val_loss improvement (floor 1e-7).
# NB: the keyword is `patience` — ReduceLROnPlateau swallows unknown **kwargs,
# so a misspelling silently falls back to the default patience of 10.
lr_decay = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss', factor=0.2, patience=5, verbose=1, min_lr=1e-7, mode='min'
)
# Stop after 15 stagnant epochs and roll the model back to its best val_loss weights.
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', patience=15, verbose=1, mode='min', restore_best_weights=True
)

# Keep the History object so the learning curves can be inspected afterwards.
history = model.fit(
    x_train, y_train,
    epochs=100,
    validation_data=(x_valid, y_valid),
    callbacks=[lr_decay, early_stop],
    verbose=1,
    batch_size=batch_size,
)

Epoch 1/100
 29/350 [=>............................] - ETA: 1:03 - loss: 0.5403 - auroc: 0.4992 - aupr: 0.1355

KeyboardInterrupt: 

In [None]:
# Score the held-out test split. Keras returns the loss followed by the compiled
# metrics in the order they were passed to compile(): ROC-AUC, then PR-AUC.
loss, auc_roc, auc_pr = model.evaluate(x=x_test, y=y_test)

In [None]:
# Per-class precision-recall AUC on the test split, printed next to each
# class's overall positive rate ("Weightage") for context.
classes = ['ELF', 'SIX', 'FOSL', 'FOXN', 'CEBPB', 'YY', 'GATA', 'MEF', 'SP', 'NFIB', 'TEAD', 'TAL']

# Fraction of positive labels per class over the whole dataset. Counting along
# axis 0 yields exactly one entry per label column, even for a class with no
# positives (np.unique over positive indices would drop that class and shift
# every later entry onto the wrong class).
class_balance = (Y == 1).sum(axis=0) / len(Y)

y_preds = model.predict(x_test)
assert y_preds.shape[1] == len(classes) == len(class_balance), (
    f'{len(classes)} class names for {y_preds.shape[1]} model outputs'
)

for i, class_name in enumerate(classes):
    # auc() integrates precision over recall (trapezoidal rule) along the PR curve.
    precision, recall, _ = precision_recall_curve(y_test[:, i], y_preds[:, i])
    pr_auc = auc(recall, precision)
    print(f'{class_name} \t AUC PR {pr_auc:.3f} \t Weightage {class_balance[i]:.3f}')