In [1]:
import keras
from keras import layers
from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Load the feature table (X) and the target table (Y). In both files the first
# CSV column is consumed as the row index instead of being kept as data.
X, Y = (pd.read_csv(csv_name, index_col=0) for csv_name in ('X.csv', 'Y.csv'))

In [2]:
# Hold out 20% of the rows for testing (the fixed random_state keeps the split
# repeatable), then turn each of the four pieces into a NumPy array.
# Nothing is transposed here: the row/column layout of X and Y is kept as is.
x_train, x_test, y_train, y_test = map(
    np.array,
    train_test_split(X, Y, test_size=0.2, random_state=1),
)

In [3]:
def transformer_encoder(inputs, head_size, num_heads, ff_dim, dropout=0):
    """Apply one attention + feed-forward encoder block to ``inputs``.

    The block has two stages, each closed by a skip connection:

    1. ``MultiHeadAttention`` called with ``inputs`` for both positional
       arguments, then dropout, then layer normalization. The block input is
       added to that result.
    2. A kernel-size-1 ``Conv1D`` with ReLU and ``ff_dim`` filters, dropout, a
       second kernel-size-1 ``Conv1D`` whose filter count equals the last
       dimension of ``inputs``, then layer normalization. The stage-1 output
       is added to that result.

    Args:
        inputs: Tensor fed to the block; ``inputs.shape[-1]`` sets the number
            of filters of the final convolution.
        head_size: Passed to ``MultiHeadAttention`` as ``key_dim``.
        num_heads: Number of attention heads.
        ff_dim: Number of filters of the first (hidden) convolution.
        dropout: Rate handed to the attention layer and to both ``Dropout``
            layers.

    Returns:
        The normalized feed-forward output plus the stage-1 output.
    """
    # ---- Stage 1: attention, dropout, normalization, skip connection ----
    attention = layers.MultiHeadAttention(
        num_heads=num_heads, key_dim=head_size, dropout=dropout
    )
    attended = attention(inputs, inputs)
    attended = layers.Dropout(dropout)(attended)
    attended = layers.LayerNormalization(epsilon=1e-6)(attended)
    residual = attended + inputs

    # ---- Stage 2: pointwise feed-forward, normalization, skip connection ----
    expand = layers.Conv1D(filters=ff_dim, kernel_size=1, activation="relu")
    hidden = layers.Dropout(dropout)(expand(residual))
    # Project back to the input's channel count so the final sum is shape-compatible.
    contract = layers.Conv1D(filters=inputs.shape[-1], kernel_size=1)
    projected = layers.LayerNormalization(epsilon=1e-6)(contract(hidden))
    return projected + residual

In [4]:
def build_model(
    input_shape,
    head_size,
    num_heads,
    ff_dim,
    num_transformer_blocks,
    mlp_units,
    dropout=0,
    mlp_dropout=0,
):
    """Assemble the regression network: encoder stack, pooling, MLP head.

    Args:
        input_shape: Per-sample shape handed to ``keras.Input``.
        head_size: Forwarded to every ``transformer_encoder`` call.
        num_heads: Forwarded to every ``transformer_encoder`` call.
        ff_dim: Forwarded to every ``transformer_encoder`` call.
        num_transformer_blocks: How many encoder blocks are chained.
        mlp_units: Iterable of layer widths; each entry adds a ReLU ``Dense``
            layer followed by ``Dropout(mlp_dropout)``.
        dropout: Dropout rate forwarded to the encoder blocks.
        mlp_dropout: Dropout rate used after each MLP ``Dense`` layer.

    Returns:
        A ``keras.Model`` from the input tensor to a single linear output unit.
    """
    model_input = keras.Input(shape=input_shape)

    # Encoder stack: every block consumes the output of the previous one.
    encoded = model_input
    for _block in range(num_transformer_blocks):
        encoded = transformer_encoder(
            encoded,
            head_size=head_size,
            num_heads=num_heads,
            ff_dim=ff_dim,
            dropout=dropout,
        )

    # Average over the sequence axis, then run the dense regression head.
    hidden = layers.GlobalAveragePooling1D(data_format="channels_last")(encoded)
    for width in mlp_units:
        hidden = layers.Dense(width, activation="relu")(hidden)
        hidden = layers.Dropout(mlp_dropout)(hidden)

    prediction = layers.Dense(1, activation="linear")(hidden)
    return keras.Model(model_input, prediction)

In [5]:
# ---- Hyperparameters ----
head_size = 64
num_heads = 1
ff_dim = 64
num_transformer_blocks = 1
mlp_units = [128, 64]
mlp_dropout = 0.4
dropout = 0.25
random_seed = 1  # same value as the random_state used for the train/test split

# Seed Python, NumPy and the Keras backend so weight initialisation, dropout
# masks and batch shuffling are repeatable from one run to the next.
keras.utils.set_random_seed(random_seed)

# The model consumes (samples, steps, channels) tensors: append a trailing
# channel axis of size 1 to the 2-D feature arrays.
X_train = x_train.reshape((x_train.shape[0], x_train.shape[1], 1))
X_test = x_test.reshape((x_test.shape[0], x_test.shape[1], 1))

# Take the per-sample input shape from the data instead of hard-coding the
# feature count, so the model always matches whatever X.csv contains.
input_shape = X_train.shape[1:]

# NOTE(review): every sample reaches the model with a single channel (last axis
# of size 1). LayerNormalization normalizes over the last axis by default, so
# with one channel its output presumably no longer depends on its input, and
# the channels_last average pooling then leaves a single feature for the MLP
# head. Worth confirming; if so, project the input to a wider embedding before
# the encoder blocks.
model = build_model(
    input_shape,
    head_size,
    num_heads,
    ff_dim,
    num_transformer_blocks,
    mlp_units,
    dropout,
    mlp_dropout,
)

model.compile(optimizer='adam', loss='mean_squared_error', metrics=['mean_absolute_error'])

# Train the model; the last 20% of the training rows are held out by Keras as
# validation data.
history = model.fit(X_train, y_train, epochs=30, batch_size=32, validation_split=0.2)

# Evaluate the model on the untouched test split
loss, mae = model.evaluate(X_test, y_test)
print(f'Test Mean Absolute Error: {mae}')

# Make predictions
y_pred = model.predict(X_test)

# End points of the line of equality. NumPy reductions are used instead of the
# built-in min()/max(): the built-ins walk the array along its first axis and
# compare whole rows, which is slow and raises as soon as a row holds more than
# one value.
lowest, highest = float(np.min(y_test)), float(np.max(y_test))

fig, ax = plt.subplots(figsize=(10, 6))
ax.scatter(y_test, y_pred, alpha=0.5)
ax.set_xlabel('Actual pIC50 Values')
ax.set_ylabel('Predicted pIC50 Values')
ax.set_title('Actual vs Predicted pIC50 Values')
ax.plot([lowest, highest], [lowest, highest], color='red')  # Line of equality
plt.show()

Epoch 1/30


[1m 1/60[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m2:33[0m 3s/step - loss: 31.2820 - mean_absolute_error: 5.5219

[1m 2/60[0m [37m━━━━━━━━━━━━━━━━━━━━[0m [1m20s[0m 360ms/step - loss: 31.7203 - mean_absolute_error: 5.5410

[1m 3/60[0m [32m━[0m[37m━━━━━━━━━━━━━━━━━━━[0m [1m19s[0m 342ms/step - loss: 31.9436 - mean_absolute_error: 5.5504

[1m 4/60[0m [32m━[0m[37m━━━━━━━━━━━━━━━━━━━[0m [1m21s[0m 388ms/step - loss: 31.7459 - mean_absolute_error: 5.5274

[1m 5/60[0m [32m━[0m[37m━━━━━━━━━━━━━━━━━━━[0m [1m20s[0m 371ms/step - loss: 31.4485 - mean_absolute_error: 5.4972

[1m 6/60[0m [32m━━[0m[37m━━━━━━━━━━━━━━━━━━[0m [1m19s[0m 356ms/step - loss: 31.1481 - mean_absolute_error: 5.4686

[1m 7/60[0m [32m━━[0m[37m━━━━━━━━━━━━━━━━━━[0m [1m18s[0m 352ms/step - loss: 30.9022 - mean_absolute_error: 5.4453

[1m 8/60[0m [32m━━[0m[37m━━━━━━━━━━━━━━━━━━[0m [1m19s[0m 379ms/step - loss: 30.7184 - mean_absolute_error: 5.4278

[1m 9/60[0m [32m━━━[0m[37m━━━━━━━━━━━━━━━━━[0m [1m20s[0m 403ms/step - loss: 30.5462 - mean_absolute_error: 5.4111

[1m10/60[0m [32m━━━[0m[37m━━━━━━━━━━━━━━━━━[0m [1m20s[0m 410ms/step - loss: 30.3770 - mean_absolute_error: 5.3950

[1m11/60[0m [32m━━━[0m[37m━━━━━━━━━━━━━━━━━[0m [1m20s[0m 421ms/step - loss: 30.2459 - mean_absolute_error: 5.3823

[1m12/60[0m [32m━━━━[0m[37m━━━━━━━━━━━━━━━━[0m [1m20s[0m 420ms/step - loss: 30.1264 - mean_absolute_error: 5.3704

[1m13/60[0m [32m━━━━[0m[37m━━━━━━━━━━━━━━━━[0m [1m19s[0m 424ms/step - loss: 30.0015 - mean_absolute_error: 5.3583

[1m14/60[0m [32m━━━━[0m[37m━━━━━━━━━━━━━━━━[0m [1m19s[0m 432ms/step - loss: 29.8730 - mean_absolute_error: 5.3460

[1m15/60[0m [32m━━━━━[0m[37m━━━━━━━━━━━━━━━[0m [1m19s[0m 428ms/step - loss: 29.7536 - mean_absolute_error: 5.3347

[1m16/60[0m [32m━━━━━[0m[37m━━━━━━━━━━━━━━━[0m [1m18s[0m 428ms/step - loss: 29.6366 - mean_absolute_error: 5.3236

[1m17/60[0m [32m━━━━━[0m[37m━━━━━━━━━━━━━━━[0m [1m18s[0m 437ms/step - loss: 29.5112 - mean_absolute_error: 5.3116

[1m18/60[0m [32m━━━━━━[0m[37m━━━━━━━━━━━━━━[0m [1m19s[0m 458ms/step - loss: 29.3893 - mean_absolute_error: 5.3000

[1m19/60[0m [32m━━━━━━[0m[37m━━━━━━━━━━━━━━[0m [1m19s[0m 466ms/step - loss: 29.2638 - mean_absolute_error: 5.2878

[1m20/60[0m [32m━━━━━━[0m[37m━━━━━━━━━━━━━━[0m [1m19s[0m 492ms/step - loss: 29.1353 - mean_absolute_error: 5.2755

[1m21/60[0m [32m━━━━━━━[0m[37m━━━━━━━━━━━━━[0m [1m20s[0m 514ms/step - loss: 29.0069 - mean_absolute_error: 5.2631

[1m22/60[0m [32m━━━━━━━[0m[37m━━━━━━━━━━━━━[0m [1m20s[0m 539ms/step - loss: 28.8762 - mean_absolute_error: 5.2503

[1m23/60[0m [32m━━━━━━━[0m[37m━━━━━━━━━━━━━[0m [1m21s[0m 586ms/step - loss: 28.7513 - mean_absolute_error: 5.2381

[1m24/60[0m [32m━━━━━━━━[0m[37m━━━━━━━━━━━━[0m [1m23s[0m 654ms/step - loss: 28.6282 - mean_absolute_error: 5.2259

[1m25/60[0m [32m━━━━━━━━[0m[37m━━━━━━━━━━━━[0m [1m23s[0m 669ms/step - loss: 28.5065 - mean_absolute_error: 5.2139

[1m26/60[0m [32m━━━━━━━━[0m[37m━━━━━━━━━━━━[0m [1m23s[0m 699ms/step - loss: 28.3790 - mean_absolute_error: 5.2010

[1m27/60[0m [32m━━━━━━━━━[0m[37m━━━━━━━━━━━[0m [1m23s[0m 711ms/step - loss: 28.2532 - mean_absolute_error: 5.1883

[1m28/60[0m [32m━━━━━━━━━[0m[37m━━━━━━━━━━━[0m [1m23s[0m 721ms/step - loss: 28.1305 - mean_absolute_error: 5.1758

[1m29/60[0m [32m━━━━━━━━━[0m[37m━━━━━━━━━━━[0m [1m22s[0m 726ms/step - loss: 28.0081 - mean_absolute_error: 5.1632

[1m30/60[0m [32m━━━━━━━━━━[0m[37m━━━━━━━━━━[0m [1m22s[0m 740ms/step - loss: 27.8839 - mean_absolute_error: 5.1504

[1m31/60[0m [32m━━━━━━━━━━[0m[37m━━━━━━━━━━[0m [1m22s[0m 770ms/step - loss: 27.7578 - mean_absolute_error: 5.1372

[1m32/60[0m [32m━━━━━━━━━━[0m[37m━━━━━━━━━━[0m [1m21s[0m 774ms/step - loss: 27.6304 - mean_absolute_error: 5.1237

[1m33/60[0m [32m━━━━━━━━━━━[0m[37m━━━━━━━━━[0m [1m21s[0m 779ms/step - loss: 27.5018 - mean_absolute_error: 5.1100

[1m34/60[0m [32m━━━━━━━━━━━[0m[37m━━━━━━━━━[0m [1m20s[0m 772ms/step - loss: 27.3739 - mean_absolute_error: 5.0963

[1m35/60[0m [32m━━━━━━━━━━━[0m[37m━━━━━━━━━[0m [1m19s[0m 763ms/step - loss: 27.2451 - mean_absolute_error: 5.0824

[1m36/60[0m [32m━━━━━━━━━━━━[0m[37m━━━━━━━━[0m [1m18s[0m 754ms/step - loss: 27.1152 - mean_absolute_error: 5.0682

[1m37/60[0m [32m━━━━━━━━━━━━[0m[37m━━━━━━━━[0m [1m17s[0m 745ms/step - loss: 26.9852 - mean_absolute_error: 5.0539

[1m38/60[0m [32m━━━━━━━━━━━━[0m[37m━━━━━━━━[0m [1m16s[0m 734ms/step - loss: 26.8533 - mean_absolute_error: 5.0392

KeyboardInterrupt: 