# Tabular Transformer for DDI Prediction on Top-1000 Gene Subset
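This notebook holds out five randomly chosen drugs so that every test pair contains at least one drug never seen during training, splits the remaining Top-1000 gene-feature pairs into stratified training and validation sets, and then builds, trains, and evaluates a Keras-based Tabular Transformer for drug–drug interaction (DDI) prediction.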

In [1]:
# ✅ Cell 1: Import libraries
import os
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, roc_auc_score
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Reshape, MultiHeadAttention, LayerNormalization, Add, Flatten, Dropout
from tensorflow.keras.callbacks import EarlyStopping
import tensorflow as tf


In [2]:
# ✅ Cell 2: Load Top-1000 pair-feature matrix
# Default location of the labeled Top-1000 pair-feature matrix. To run the
# notebook on another machine, set the DDI_TOP1000_CSV environment variable
# instead of editing this cell; the original path is kept as the fallback.
DEFAULT_FILE_PATH = r"C:\project grad linex\project\models\last update for Deep Learning\new datasplitting\pair_feature_matrix_labeled_top_1000_gene.csv"
file_path = os.environ.get("DDI_TOP1000_CSV", DEFAULT_FILE_PATH)

# Drop incomplete rows by reassignment (not in place) so re-running the cell
# always starts from the file contents.
df = pd.read_csv(file_path).dropna()

# Ensure consistent drug column names: map the legacy "*_name" headers onto
# the names the rest of the notebook expects, skipping any that are absent or
# whose target column already exists.
legacy_names = {'drugA_name': 'drugA', 'drugB_name': 'drugB'}
df = df.rename(columns={
    old: new for old, new in legacy_names.items()
    if old in df.columns and new not in df.columns
})

# Fail here, with a clear message, rather than in a later cell.
missing_cols = {'drugA', 'drugB', 'label'} - set(df.columns)
if missing_cols:
    raise KeyError(f"Pair-feature matrix is missing required columns: {sorted(missing_cols)}")

print(f"Total pairs in Top-1000 dataset: {df.shape[0]}")
df.head(2)

Total pairs in Top-1000 dataset: 489


Unnamed: 0,drugA,sig_id_A,drugB,sig_id_B,gene1_A,gene2_A,gene3_A,gene4_A,gene5_A,gene6_A,...,gene992_B,gene993_B,gene994_B,gene995_B,gene996_B,gene997_B,gene998_B,gene999_B,gene1000_B,label
0,allopurinol,CPD003_MCF7_24H:BRD-K86307448-001-09-2:10,altretamine,CPC004_HA1E_24H:BRD-K67043667-001-15-7:10,-0.000486,1.039225,-0.457621,1.259835,0.655588,-0.716319,...,-0.137224,0.097937,-0.057878,-0.764311,-0.77865,-0.324212,0.757073,-1.060788,-1.642454,0
1,allopurinol,CPD003_MCF7_24H:BRD-K86307448-001-09-2:10,anastrozole,CPC010_A549_24H:BRD-K52172416-001-07-2:10,-0.000486,1.039225,-0.457621,1.259835,0.655588,-0.716319,...,-1.05965,0.1535,0.945,0.5798,2.3737,-0.12595,-0.2901,1.3091,1.03975,0


In [3]:
# ✅ Cell 3: Unseen-drug split: hold out 5 drugs for test
# Every distinct drug name that appears on either side of a pair.
drug_name_grid = df[['drugA','drugB']].values
all_drugs = pd.unique(drug_name_grid.ravel('K'))

# Fixed seed so the same five drugs are held out on every run.
np.random.seed(42)
held_out = np.random.choice(all_drugs, size=5, replace=False)
test_drugs = held_out.tolist()

# A pair goes to the test set as soon as either of its drugs is held out;
# only pairs made entirely of non-held-out drugs remain for train/validation.
mask_test = df[['drugA', 'drugB']].isin(test_drugs).any(axis=1)
df_test = df.loc[mask_test].reset_index(drop=True)
df_train_val = df.loc[~mask_test].reset_index(drop=True)

print('Held-out test drugs:', test_drugs)
print(f'Train+Val pairs: {len(df_train_val)}')
print(f'Test pairs:      {len(df_test)}')

Held-out test drugs: ['floxuridine', 'allopurinol', 'thiotepa', 'azacitidine', 'crizotinib']
Train+Val pairs: 363
Test pairs:      126


In [4]:
# ✅ Cell 4: Train/Validation split and prepare arrays
# Model inputs are all columns whose name starts with 'gene'.
feature_cols = [col for col in df.columns if col.startswith('gene')]

# Stratified 80/20 split of the non-held-out pairs.
# NOTE(review): this split is by pair, not by drug, so validation pairs share
# drugs with training pairs while test pairs do not.
df_train, df_val = train_test_split(
    df_train_val,
    test_size=0.2,
    stratify=df_train_val['label'],
    random_state=42,
)


def _features_and_labels(frame):
    """Return (feature matrix, label vector) as NumPy arrays for one split."""
    return frame[feature_cols].values, frame['label'].values


X_train, y_train = _features_and_labels(df_train)
X_val, y_val = _features_and_labels(df_val)
X_test, y_test = _features_and_labels(df_test)

print(f"X_train: {X_train.shape}, y_train: {y_train.shape}")
print(f"X_val:   {X_val.shape}, y_val:   {y_val.shape}")
print(f"X_test:  {X_test.shape}, y_test:  {y_test.shape}")

X_train: (290, 2000), y_train: (290,)
X_val:   (73, 2000), y_val:   (73,)
X_test:  (126, 2000), y_test:  (126,)


In [5]:
# ✅ Cell 5: Build the Tabular Transformer model
def build_tabular_transformer(input_dim,
                              num_tokens=40,
                              token_dim=50,
                              num_heads=4,
                              ff_dim=128,
                              num_layers=2,
                              dropout_rate=0.2):
    """Build and compile a transformer-style binary classifier for flat feature vectors.

    The flat input is projected by a ReLU Dense layer to
    ``num_tokens * token_dim`` units and reshaped into ``num_tokens`` tokens of
    width ``token_dim``. It then passes through ``num_layers`` encoder blocks.
    Each block applies self-attention, a residual add and layer normalization,
    then a two-layer feed-forward network, a second residual add and layer
    normalization, and finally dropout. The result is flattened and fed to a
    single sigmoid unit.

    Args:
        input_dim: Number of input features per sample.
        num_tokens: Number of tokens the projected features are reshaped into.
        token_dim: Width of each token; also passed to the attention layer as
            ``key_dim``.
        num_heads: Number of attention heads per block.
        ff_dim: Hidden width of the feed-forward sub-layer.
        num_layers: Number of stacked encoder blocks.
        dropout_rate: Dropout rate applied at the end of every block.

    Returns:
        A Keras ``Model`` named "TabularTransformer_Top1000", compiled with the
        Adam optimizer, binary cross-entropy loss, and accuracy plus AUC
        (reported as ``auroc``) metrics.
    """

    def encoder_block(tokens):
        # Self-attention sub-layer: query and value are the same tensor.
        self_attention = MultiHeadAttention(num_heads=num_heads, key_dim=token_dim)
        attended = self_attention(tokens, tokens)
        residual = Add()([tokens, attended])
        normed = LayerNormalization()(residual)

        # Feed-forward sub-layer: expand to ff_dim, project back to token_dim.
        expanded = Dense(ff_dim, activation="relu")(normed)
        projected = Dense(token_dim)(expanded)
        residual = Add()([normed, projected])
        normed = LayerNormalization()(residual)

        return Dropout(dropout_rate)(normed)

    features = Input(shape=(input_dim,), name="features")

    # Tokenize the flat feature vector: project, then split into tokens.
    projected_features = Dense(num_tokens * token_dim, activation="relu")(features)
    tokens = Reshape((num_tokens, token_dim))(projected_features)

    layer_index = 0
    while layer_index < num_layers:
        tokens = encoder_block(tokens)
        layer_index += 1

    flattened = Flatten()(tokens)
    probability = Dense(1, activation="sigmoid", name="synergy")(flattened)

    transformer = Model(inputs=features, outputs=probability,
                        name="TabularTransformer_Top1000")
    transformer.compile(
        optimizer="adam",
        loss="binary_crossentropy",
        metrics=["accuracy", tf.keras.metrics.AUC(name='auroc')],
    )
    return transformer

# Seed the Python, NumPy and TensorFlow random generators before any layer is
# created, so weight initialization, dropout masks and batch shuffling repeat
# when this cell and the training cell are re-run. (Some GPU kernels may still
# introduce small run-to-run differences.)
tf.keras.utils.set_random_seed(42)

# The model input width is the number of feature columns in the training matrix.
input_dim = X_train.shape[1]
model = build_tabular_transformer(input_dim)
model.summary()

In [6]:
# ✅ Cell 6: Train the model
# Stop once validation loss has not improved for 5 consecutive epochs and roll
# the model back to the weights from its best epoch.
early_stop = EarlyStopping(
    monitor="val_loss",
    patience=5,
    restore_best_weights=True,
)

# Training configuration gathered in one place for easy tweaking.
fit_options = {
    "validation_data": (X_val, y_val),
    "epochs": 50,
    "batch_size": 32,
    "callbacks": [early_stop],
    "verbose": 2,  # one summary line per epoch
}
history = model.fit(X_train, y_train, **fit_options)

Epoch 1/50
10/10 - 7s - 699ms/step - accuracy: 0.7103 - auroc: 0.5438 - loss: 1.7640 - val_accuracy: 0.6438 - val_auroc: 0.6822 - val_loss: 0.9783
Epoch 2/50
10/10 - 1s - 70ms/step - accuracy: 0.8241 - auroc: 0.8148 - loss: 0.3987 - val_accuracy: 0.8082 - val_auroc: 0.6737 - val_loss: 0.5025
Epoch 3/50
10/10 - 1s - 74ms/step - accuracy: 0.8172 - auroc: 0.8182 - loss: 0.3888 - val_accuracy: 0.7945 - val_auroc: 0.6834 - val_loss: 0.4937
Epoch 4/50
10/10 - 1s - 57ms/step - accuracy: 0.7793 - auroc: 0.7493 - loss: 0.4476 - val_accuracy: 0.7945 - val_auroc: 0.6786 - val_loss: 0.5124
Epoch 5/50
10/10 - 1s - 60ms/step - accuracy: 0.8379 - auroc: 0.8991 - loss: 0.3173 - val_accuracy: 0.7671 - val_auroc: 0.6743 - val_loss: 0.6128
Epoch 6/50
10/10 - 1s - 56ms/step - accuracy: 0.8276 - auroc: 0.8906 - loss: 0.3277 - val_accuracy: 0.6986 - val_auroc: 0.6646 - val_loss: 0.8871
Epoch 7/50
10/10 - 0s - 49ms/step - accuracy: 0.8483 - auroc: 0.8985 - loss: 0.3439 - val_accuracy: 0.7671 - val_auroc: 0.6

In [7]:
# ✅ Cell 7: Evaluate on Test Set
# Predicted probabilities for the held-out-drug pairs, flattened to 1-D.
test_probabilities = model.predict(X_test)
y_pred = test_probabilities.ravel()

# Hard labels at the 0.5 decision threshold.
above_threshold = y_pred > 0.5
y_class = above_threshold.astype(int)

report_text = classification_report(y_test, y_class)
test_auroc = roc_auc_score(y_test, y_pred)

print("Classification Report:")
print(report_text)
print("Test AUROC:", test_auroc)

[1m4/4[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 98ms/step
Classification Report:
              precision    recall  f1-score   support

           0       0.60      0.99      0.74        71
           1       0.89      0.15      0.25        55

    accuracy                           0.62       126
   macro avg       0.74      0.57      0.50       126
weighted avg       0.73      0.62      0.53       126

Test AUROC: 0.5549295774647887
