In [1]:
from tensorflow.keras.models import Sequential, clone_model
from tensorflow.keras.layers import (
    Conv1D,
    MaxPooling1D,
    LSTM,
    Bidirectional,
    Dropout,
    Flatten,
    Dense,
    Input,
)
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.optimizers import Adam
from tensorflow.keras import metrics

from utils import load_dataset, train
from einops import rearrange
from sklearn.metrics import RocCurveDisplay, auc, plot_roc_curve, roc_curve

<IPython.core.display.Javascript object>

Model architecture adapted from the DanQ reference training script: [DanQ_train.py](https://github.com/uci-cbcl/DanQ/blob/master/DanQ_train.py).

In [2]:
from tensorflow.config import list_physical_devices

# Sanity check: confirm TensorFlow can see a GPU before launching long training runs.
print(list_physical_devices(device_type="GPU"))

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


<IPython.core.display.Javascript object>

In [3]:
# DanQ-style hybrid CNN + bidirectional LSTM producing one binary prediction per
# input sequence. Input is (1000, 4) per sample — presumably one-hot encoded DNA
# (1000 positions x 4 bases); confirm against load_dataset.
model = Sequential(
    [
        Input(shape=(1000, 4)),
        Conv1D(320, kernel_size=8, activation="relu", name="conv1d_1"),
        MaxPooling1D(pool_size=4, strides=4),
        Dropout(0.2),
        # NOTE(review): the DanQ reference uses 320 LSTM units per direction;
        # 2 looks like a debugging value — confirm before trusting results.
        Bidirectional(LSTM(2, return_sequences=True)),
        Dropout(0.5),
        Flatten(),
        Dense(512, activation="relu"),
        # Sigmoid gives a probability in (0, 1), which binary_crossentropy and
        # the AUC metric expect. The previous relu could output values > 1 or
        # get stuck at exactly 0 (zero gradient), breaking training.
        Dense(1, activation="sigmoid"),
    ]
)

<IPython.core.display.Javascript object>

In [4]:
# Stop once the training loss stops improving. "loss" (not "val_loss") is
# monitored, presumably because validation only runs every `validation_freq`
# epochs inside `train`.
# NOTE(review): patience=100 exceeds epochs=50, so this callback can never stop
# training early — lower patience (or drop the callback) if early stopping is wanted.
es = EarlyStopping(monitor="loss", patience=100)

# `learning_rate` replaces the deprecated `lr` keyword (rejected by Keras 3).
optimizer = Adam(learning_rate=1e-3)

epochs = 50
validation_freq = 5  # evaluate on the held-out split every 5th epoch

<IPython.core.display.Javascript object>

In [None]:
# Train and evaluate an independent, freshly initialised copy of `model` on each
# dataset, then plot its ROC curve on the test split.
DATA_DIR = "/home/victor/Documents/datasets/"  # TODO: make relative/configurable
DATASET_IDS = [0, 1, 2, 3, 5, 8]

trained_models = {}
for n in DATASET_IDS:
    dataset_name = f"m{n}"
    X_train, y_train, X_test, y_test = load_dataset(
        file=dataset_name,
        directory=DATA_DIR,
        labels="binlabels",
        download=False,
    )
    # Swap the last two axes so samples match the model's (1000, 4) input.
    # NOTE(review): assumes load_dataset returns (samples, 4, 1000) — confirm.
    X_train = rearrange(X_train, "w h c -> w c h")
    X_test = rearrange(X_test, "w h c -> w c h")

    # Clone the template so every dataset starts from fresh random weights,
    # instead of continuing from the previous dataset's trained weights.
    dataset_model = clone_model(model)
    # One optimizer per model, built from the configured one's hyperparameters.
    # An optimizer instance keeps state tied to the variables it first trained,
    # so it must not be shared across models; the same instance is used for
    # compile and for `train` so the two cannot disagree.
    dataset_optimizer = type(optimizer).from_config(optimizer.get_config())
    dataset_model.compile(
        optimizer=dataset_optimizer,
        loss="binary_crossentropy",
        metrics=["accuracy", metrics.AUC()],
    )

    dataset_model = train(
        dataset=(X_train, y_train, X_test, y_test),
        model=dataset_model,
        epochs=epochs,
        verbose=1,
        validation_freq=validation_freq,
        optimizer=dataset_optimizer,
        callbacks=[es],
    )
    trained_models[dataset_name] = dataset_model

    # sklearn's plot_roc_curve only accepts fitted sklearn classifiers, so build
    # the curve from the Keras model's predicted scores instead.
    y_score = dataset_model.predict(X_test).ravel()
    fpr, tpr, _ = roc_curve(y_test, y_score)
    RocCurveDisplay(
        fpr=fpr, tpr=tpr, roc_auc=auc(fpr, tpr), estimator_name=dataset_name
    ).plot()

Epoch 1/50
 225/3125 [=>............................] - ETA: 1:07 - loss: 0.2657