# IDC_Conv1D Model Training and Evaluation

This notebook reproduces the experiments for the IDC_Conv1D model used for drug–disease link prediction. It covers environment configuration, data loading, model definition, training with stratified K-fold cross-validation, and evaluation on a held-out test set, including ROC/PR curves and detailed metrics.

## 1. Environment Setup and Dependencies

The main Python package dependencies are listed below. Install them into a fresh environment (e.g., one created with `conda` or `venv`).

**Example (pip):**
```bash
pip install numpy pandas matplotlib scikit-learn tensorflow
```

Make sure you also have a GPU-enabled TensorFlow installation if you want to reproduce the GPU-based experiments.

## 2. Import Libraries and Configure GPU

This cell imports all required libraries (TensorFlow/Keras, NumPy, Pandas, scikit-learn, and Matplotlib) and configures the GPU for controlled memory growth, as used in the original experiments.

In [None]:
# Import Libraries
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.layers import (Input, Dense, Dropout, Conv1D, MaxPooling1D,
                                     Flatten, Reshape, UpSampling1D, Concatenate)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
from sklearn.metrics import (roc_auc_score, precision_recall_curve, auc, brier_score_loss,
                             cohen_kappa_score, matthews_corrcoef, f1_score, recall_score,
                             accuracy_score, precision_score, roc_curve)
from sklearn.model_selection import StratifiedKFold

# Basic GPU configuration (if a GPU is available).
# Memory growth has to be requested before TensorFlow initializes the devices,
# so this block runs before any call that touches the GPU.
physical_devices = tf.config.list_physical_devices('GPU')
if len(physical_devices) > 0:
    try:
        tf.config.experimental.set_memory_growth(physical_devices[0], True)
    except Exception as e:
        print('Could not set memory growth:', e)

# (Optional) List local devices including GPUs
from tensorflow.python.client import device_lib
print(device_lib.list_local_devices())

# Additional GPU configuration (TensorFlow v1-style session, optional)
gpu_options = tf.compat.v1.GPUOptions(per_process_gpu_memory_fraction=0.05)
sess = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(gpu_options=gpu_options))


## 3. Data Loading

This cell loads the preprocessed training and test datasets from `.npy` files. `X` arrays contain concatenated drug and disease feature vectors, and `y` arrays contain the corresponding labels.

- `X_train_0`, `y_train_0`: Training features and labels
- `X_test_0`, `Y_test_0`: Held-out test features and labels

Update the file paths if your data is stored in a different directory.
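
Before training, it can help to confirm that the loaded arrays have the layout the later cells assume. The sketch below uses small synthetic arrays in place of the real `.npy` files. The helper name `check_dataset` is hypothetical, and the 650-features-per-half width is taken from the model definition later in this notebook.

```python
import numpy as np

def check_dataset(X, y, feature_dim=650):
    """Hypothetical helper: validate the layout and split X into drug/disease halves."""
    X = np.asarray(X)
    y = np.asarray(y).ravel()
    if X.ndim != 2 or X.shape[1] != 2 * feature_dim:
        raise ValueError(f'expected X with {2 * feature_dim} columns, got shape {X.shape}')
    if X.shape[0] != y.shape[0]:
        raise ValueError('X and y have different numbers of samples')
    if not np.isin(y, [0, 1]).all():
        raise ValueError('labels must be binary (0 or 1)')
    drug, disease = X[:, :feature_dim], X[:, feature_dim:]
    return drug, disease, y

# Synthetic stand-in for the real arrays (assumption: 8 samples, 1300 features)
rng = np.random.default_rng(0)
X_demo = rng.random((8, 1300))
y_demo = np.array([0, 1, 0, 0, 1, 0, 0, 0])

drug_demo, disease_demo, y_clean = check_dataset(X_demo, y_demo)
print(drug_demo.shape, disease_demo.shape, y_clean.mean())
```

The printed positive rate (`y_clean.mean()`) is also a quick way to see how imbalanced the labels are before choosing a loss weighting.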

In [None]:
# Data loading
X_test_0 = np.load('./data/X_me.npy')
Y_test_0 = np.load('./data/Y_me.npy')
X_train_0 = np.load('./data/X.npy')
y_train_0 = np.load('./data/y.npy')

print('X_train shape:', X_train_0.shape)
print('y_train shape:', y_train_0.shape)
print('X_test shape:', X_test_0.shape)
print('Y_test shape:', Y_test_0.shape)


## 4. Loss Function and Basic Plotting Utilities

This cell defines:
- `calculate_class_weights`: Computes positive and negative class weights based on label distribution.
- `weighted_binary_crossentropy_loss`: A weighted binary cross-entropy loss that handles class imbalance.
- `plot_roc_pr_curves`: A helper function to plot ROC and Precision–Recall curves for a single prediction run.
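
To make the weighting concrete, here is a minimal NumPy sketch of the same idea on four made-up samples. It is an illustration of the arithmetic, not the TensorFlow code used for training. The names `class_weights` and `weighted_bce` and the clipping constant `eps` are assumptions introduced for this example.

```python
import numpy as np

def class_weights(y):
    """Positive weight = (#negatives / #positives); negative weight = 1."""
    y = np.asarray(y, dtype=float)
    return (1.0 - y).sum() / y.sum(), 1.0

def weighted_bce(y, p, pos_weight, neg_weight, eps=1e-7):
    """Mean of per-sample binary cross-entropy scaled by a class-dependent weight."""
    y = np.asarray(y, dtype=float)
    p = np.clip(np.asarray(p, dtype=float), eps, 1.0 - eps)
    bce = -(y * np.log(p) + (1.0 - y) * np.log(1.0 - p))
    weights = y * pos_weight + (1.0 - y) * neg_weight
    return float(np.mean(bce * weights))

y_demo = np.array([1, 0, 0, 0])          # one positive, three negatives
p_demo = np.array([0.5, 0.5, 0.5, 0.5])  # uninformative predictions

pos_w, neg_w = class_weights(y_demo)
loss_demo = weighted_bce(y_demo, p_demo, pos_w, neg_w)
print(pos_w, loss_demo)
```

With one positive and three negatives the positive weight is 3, so the single positive sample carries the same total weight as the three negatives combined.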

In [None]:
# Define weighted loss and basic ROC/PR plotting
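def calculate_class_weights(y_true):
    # Weight positives by the negative-to-positive ratio; negatives keep weight 1.
    y_true = tf.cast(y_true, dtype=tf.float64)
    pos_weight = tf.reduce_sum(1 - y_true) / tf.reduce_sum(y_true)
    neg_weight = 1.0
    return pos_weight, neg_weight

def weighted_binary_crossentropy_loss(y_true, y_pred, pos_weight=None, neg_weight=None):
    # Use shape (batch, 1) for both tensors so the per-sample loss and weights align.
    y_true = tf.reshape(tf.cast(y_true, dtype=tf.float64), (-1, 1))
    y_pred = tf.reshape(tf.cast(y_pred, dtype=tf.float64), (-1, 1))

    if pos_weight is None or neg_weight is None:
        pos_weight, neg_weight = calculate_class_weights(y_true)
    pos_weight = tf.cast(pos_weight, dtype=tf.float64)
    neg_weight = tf.cast(neg_weight, dtype=tf.float64)

    # Binary cross-entropy loss per sample, shape (batch,)
    bce_loss = tf.keras.losses.binary_crossentropy(y_true, y_pred)

    # Per-sample weights flattened to (batch,), so the product below is
    # element-wise and does not broadcast to (batch, batch)
    sample_weights = tf.reshape(y_true * pos_weight + (1 - y_true) * neg_weight, (-1,))

    return tf.reduce_mean(bce_loss * sample_weights)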

def plot_roc_pr_curves(y_true, y_pred, file_prefix):
    fpr, tpr, _ = roc_curve(y_true, y_pred)
    plt.figure(figsize=(12, 6))
    
    # ROC Curve
    plt.subplot(1, 2, 1)
    plt.plot(fpr, tpr, label='ROC curve (AUC = %0.2f)' % roc_auc_score(y_true, y_pred))
    plt.plot([0, 1], [0, 1], linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC)')
    plt.legend(loc='lower right')
    
    # Precision-Recall Curve
    precision, recall, _ = precision_recall_curve(y_true, y_pred)
    plt.subplot(1, 2, 2)
    plt.plot(recall, precision, label='Precision-Recall (AUC = %0.2f)' % auc(recall, precision))
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('Precision-Recall Curve')
    plt.legend(loc='lower left')
    
    plt.savefig(f'{file_prefix}_roc_pr_curve.png')
    plt.close()


## 5. Cross-Validation Plotting Utilities

This cell defines helper functions to visualize model performance across the cross-validation folds:
- `plot_roc_pr_curves_all_folds`: Plots individual and mean ROC/PR curves over all folds.
- `model_history`: Plots training and validation accuracy/loss per epoch for each fold.
- `plot_fold_pr_curves`: An additional function to visualize individual and mean PR curves across folds.
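
The legend labels in these plots report areas computed with scikit-learn's `auc`, which applies the trapezoidal rule to the curve points. The sketch below reproduces that calculation in plain NumPy on made-up ROC points. The function name `trapezoid_area` is hypothetical, and the numbers are not results from the model.

```python
import numpy as np

def trapezoid_area(x, y):
    """Area under a piecewise-linear curve; x must be sorted in increasing order."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    return float(np.sum(np.diff(x) * (y[1:] + y[:-1]) / 2.0))

# Made-up ROC points for illustration only
fpr_demo = np.array([0.0, 0.0, 0.5, 1.0])
tpr_demo = np.array([0.0, 0.5, 1.0, 1.0])

roc_area = trapezoid_area(fpr_demo, tpr_demo)
print(roc_area)
```

Note that `precision_recall_curve` returns recall in decreasing order, so a helper like this one would need the PR arrays reversed first.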

In [None]:
from sklearn.metrics import roc_curve, precision_recall_curve, roc_auc_score, auc

def plot_roc_pr_curves_all_folds(folds_roc, folds_pr, mean_fpr, mean_tpr, mean_recall, mean_precision, file_prefix):
    plt.figure(figsize=(18, 6))
    cmap = plt.get_cmap('tab10')  
    # Plot all ROC Curves
    plt.subplot(1, 2, 1)
    for fold, (fpr, tpr) in enumerate(folds_roc, 1):
        plt.plot(fpr, tpr, lw=1.2, color=cmap(fold % 10), alpha=0.6,
                 label=f'Fold {fold} ROC (AUC = {auc(fpr, tpr):.2f})')
    plt.plot(mean_fpr, mean_tpr, linestyle='--', lw=1.5,
             label=f'Mean ROC (AUC = {auc(mean_fpr, mean_tpr):.2f})', color='black')
    plt.plot([0, 1], [0, 1], linestyle='--', color='gray')
    plt.xlim([-0.05, 1.05])
    plt.ylim([-0.05, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC)')
    plt.legend(loc='lower right')

    # Plot all Precision-Recall Curves
    plt.subplot(1, 2, 2)
    for fold, (precision, recall) in enumerate(folds_pr, 1):
        plt.plot(recall, precision, lw=1.2, alpha=0.6, color=cmap(fold % 10),
                 label=f'Fold {fold} PR (AUC = {auc(recall, precision):.2f})')
    plt.plot(mean_recall, mean_precision, linestyle='--', lw=1.5,
             label=f'Mean PR (AUC = {auc(mean_recall, mean_precision):.2f})', color='black')
    plt.xlim([-0.05, 1.0])
    plt.ylim([-0.05, 1.05])
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('Precision-Recall Curve')
    plt.legend(loc='lower left')

    plt.savefig(f'{file_prefix}_all_folds_roc_pr_curves.png', bbox_inches='tight')
    plt.close()

def model_history(histories, n):
    for i, history in enumerate(histories):
        plt.figure(figsize=(12, 4))
    
        plt.subplot(1, 2, 1)
        plt.plot(history.history['accuracy'])
        plt.plot(history.history['val_accuracy'])
        plt.title(f'Model accuracy for fold {i+1}')
        plt.ylabel('Accuracy')
        plt.xlabel('Epoch')
        plt.legend(['Train', 'Validation'], loc='upper left')
    
        plt.subplot(1, 2, 2)
        plt.plot(history.history['loss'])
        plt.plot(history.history['val_loss'])
        plt.title(f'Model loss for fold {i+1}')
        plt.ylabel('Loss')
        plt.xlabel('Epoch')
        plt.legend(['Train', 'Validation'], loc='upper left')
    
        plt.savefig(f'./output/{n}_fold_{i+1}_accuracy_loss.png')
        plt.close()

def plot_fold_pr_curves(folds_pr, mean_pr, file_prefix):
    plt.figure(figsize=(10, 8))
    
    for fold_pr in folds_pr:
        plt.plot(fold_pr[1], fold_pr[0], lw=1, alpha=0.7)  # recall, precision
    
    plt.plot(mean_pr[1], mean_pr[0], linestyle='--', lw=2, label='Mean Precision-Recall', color='k')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title('Precision-Recall Curves for All Folds and Mean PR')
    plt.legend(loc='lower left')
    
    plt.savefig(f'{file_prefix}_folds_pr_curves.png')
    plt.close()


## 6. IDC_Conv1D Model Definition

This cell defines the convolutional model used in the experiments:
- `first_conv_layer`: A feature extractor that applies multi-kernel convolutions, pooling, and dropout to a 650-dimensional input vector.
- `create_model`: Builds the full model by applying the convolutional block to both drug and disease feature vectors, concatenating them, and passing the result through fully connected layers to predict a binary label.

Each input (drug and disease) is a 650-dimensional vector; together they form the 1300-dimensional `X` input.
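
The tensor shapes inside one branch can be traced by hand from the layer definitions in the cell below. The following plain-Python sketch does that arithmetic without building the model. The function name `branch_output_shape` is hypothetical, and the trace assumes Keras defaults (`'valid'` padding for `MaxPooling1D`, stride equal to the pool size).

```python
def branch_output_shape(length=650, filters=32):
    """Trace (steps, channels) through one convolutional branch."""
    conv_shape = (length, filters)             # 'same'-padded Conv1D keeps the length
    pooled_steps = length // 2                 # MaxPooling1D(2) on the reshaped input
    upsampled = (pooled_steps * 2, filters)    # Conv1D(32, 1) followed by UpSampling1D(2)
    if upsampled[0] != conv_shape[0]:
        raise ValueError('Concatenate needs equal lengths; use an even input length')
    merged = (length, 4 * filters)             # Concatenate of x1, x2, x3, x5
    after_pool = (length // 2, filters)        # three Conv1D layers, then MaxPooling1D(2)
    flat = after_pool[0] * after_pool[1]       # Flatten
    return merged, after_pool, flat

merged, after_pool, flat = branch_output_shape()
dense_inputs = 2 * flat                        # drug and disease branches concatenated
print(merged, after_pool, flat, dense_inputs)
```

Under these assumptions each branch flattens to 10,400 values, so the first `Dense(128)` layer receives 20,800 inputs. The length check also shows why the per-entity feature length needs to be even for the pooled-then-upsampled path to line up with the other paths.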

In [None]:
def first_conv_layer(input_tensor):
    x = Reshape((650, 1))(input_tensor)
    x1 = Conv1D(32, 1, activation='relu', padding='same')(x)
    x2 = Conv1D(32, 3, activation='relu', padding='same')(x1)
    x3 = Conv1D(32, 5, activation='relu', padding='same')(x1)
    x4 = MaxPooling1D(pool_size=2)(x)
    x5 = Conv1D(32, 1, activation='relu', padding='same')(x4)
    x5 = UpSampling1D(size=2)(x5)
    x = Concatenate()([x1, x2, x3, x5])
    
    x = Conv1D(32, 5, activation='relu', padding='same')(x)
    x = Conv1D(32, 3, activation='relu', padding='same')(x)
    x = Conv1D(32, 1, activation='relu', padding='same')(x)
    x = Dropout(0.1)(x)
    x = MaxPooling1D(pool_size=2)(x)
    x = Flatten()(x)
    
    return x

def create_model():
    input1 = Input(shape=(650,))  # Drug features
    input2 = Input(shape=(650,))  # Disease features
    
    x1 = first_conv_layer(input1)
    x2 = first_conv_layer(input2)
    
    concatenated = Concatenate(axis=-1)([x1, x2])
    x = Reshape((int(concatenated.shape[1]), 1))(concatenated)
    x = Flatten()(x)
    
    x = Dense(128, activation='relu')(x)
    x = Dense(128, activation='relu')(x)
    x = Dense(16, activation='relu')(x)
    output = Dense(1, activation='sigmoid')(x)
    
    model = Model(inputs=[input1, input2], outputs=output)
    return model


## 7. Training and Evaluation with Stratified K-Fold Cross-Validation

This cell defines the `train_and_evaluate` function, which:
- Performs stratified K-fold cross-validation on the training data.
- Trains a new `IDC_Conv1D` model on each fold using the weighted loss.
- Evaluates the model trained in each fold on the held-out test set `X_test_0`, `Y_test_0`.
- Collects per-fold metrics (AUC-ROC, AUC-PR, accuracy, precision, recall, F1, Brier score, Cohen's kappa, MCC).
- Saves ROC/PR plots and per-fold metric summaries to the `./output` directory.

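The cell creates the `./output` directory if it does not already exist. Note that the summary metrics returned by `train_and_evaluate` are computed from the last fold's test-set predictions, not averaged over folds; the per-fold values are written to `fold_results.csv`.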
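
The mean ROC curve is built by interpolating each fold's curve onto a shared grid of false positive rates and averaging the results. A minimal sketch of that step, using two made-up fold curves and a 5-point grid instead of the 650-point grid used in the function:

```python
import numpy as np

# Made-up per-fold ROC points (FPR increasing); not results from the model
folds_demo = [
    (np.array([0.0, 0.5, 1.0]), np.array([0.0, 1.0, 1.0])),
    (np.array([0.0, 0.5, 1.0]), np.array([0.0, 0.5, 1.0])),
]

grid = np.linspace(0.0, 1.0, 5)            # shared FPR grid
mean_tpr_demo = np.zeros_like(grid)
for fpr, tpr in folds_demo:
    # np.interp expects increasing x values, which ROC output satisfies
    mean_tpr_demo += np.interp(grid, fpr, tpr)
mean_tpr_demo /= len(folds_demo)
print(mean_tpr_demo)
```

The PR curves are handled the same way, except that the recall and precision arrays are reversed first, because `precision_recall_curve` returns recall in decreasing order.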

In [None]:
import os
os.makedirs('./output', exist_ok=True)

def train_and_evaluate(X, y, X_test, Y_test, learning_rate, number, n_splits=10):
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
    histories = []
    fold_roc_curves = []
    fold_pr_curves = []
    mean_fpr = np.linspace(0, 1, 650)
    mean_tpr = 0.0
    mean_recall = np.linspace(0, 1, 650)
    mean_precision = 0.0
    y_test_preds = []
    fold_metrics = []

    for fold, (train_index, val_index) in enumerate(skf.split(X, y), 1):
        print(f'Training fold {fold}...')
        
        x_train, x_val = X[train_index], X[val_index]
        y_train, y_val = y[train_index], y[val_index]
        
        l_train, r_train = x_train[:, :650], x_train[:, 650:]
        l_val, r_val = x_val[:, :650], x_val[:, 650:]
        
        model = create_model()
        optimizer = Adam(learning_rate=learning_rate)
        
        # Class weights for training data
        pos_weight, neg_weight = calculate_class_weights(y_train)
        
        # Compile with weighted binary cross-entropy loss
        model.compile(
            loss=lambda y_true, y_pred: weighted_binary_crossentropy_loss(
                y_true, y_pred, pos_weight, neg_weight
            ),
            optimizer=optimizer,
            metrics=['accuracy']
        )
        
        callbacks = [
            EarlyStopping(monitor='val_accuracy', patience=10, verbose=1, restore_best_weights=True),
            ReduceLROnPlateau(monitor='val_accuracy', factor=0.1, patience=3, verbose=1),
            ModelCheckpoint('./output/best_weights.keras', monitor='val_accuracy', save_best_only=True)
        ]
        
        history = model.fit(
            [l_train, r_train], y_train,
            batch_size=128,
            epochs=100,
            validation_data=([l_val, r_val], y_val),
            callbacks=callbacks,
            verbose=1
        )
        
        histories.append(history)
        
        # Test set prediction for current fold
        l_test, r_test = X_test[:, :650], X_test[:, 650:]
        y_test_pred = model.predict([l_test, r_test]).ravel()  # flatten (n, 1) to (n,)
        y_test_preds.append(y_test_pred)

        # ROC and PR curves on test set for this fold
        fpr, tpr, _ = roc_curve(Y_test, y_test_pred)
        precision, recall, _ = precision_recall_curve(Y_test, y_test_pred)
        
        fold_roc_curves.append((fpr, tpr))
        fold_pr_curves.append((precision, recall))

        # Accumulate for mean curves
        mean_tpr += np.interp(mean_fpr, fpr, tpr)
        mean_precision += np.interp(mean_recall, recall[::-1], precision[::-1])

        # Metrics for this fold
        auc_roc = roc_auc_score(Y_test, y_test_pred)
        precision_score_value = precision_score(Y_test, (y_test_pred >= 0.5).astype(int))
        recall_score_value = recall_score(Y_test, (y_test_pred >= 0.5).astype(int))
        accuracy = accuracy_score(Y_test, (y_test_pred >= 0.5).astype(int))
        f1 = f1_score(Y_test, (y_test_pred >= 0.5).astype(int))
        brier = brier_score_loss(Y_test, y_test_pred)
        kappa = cohen_kappa_score(Y_test, (y_test_pred >= 0.5).astype(int))
        mcc = matthews_corrcoef(Y_test, (y_test_pred >= 0.5).astype(int))

        fold_metric = {
            'fold': fold,
            'auc-roc': auc_roc,
            'auc-pr': auc(recall, precision),
            'overall_recall': recall_score_value,
            'accuracy': accuracy,
            'precision': precision_score_value,
            'Brier score': brier,
            'Cohen Kappa': kappa,
            'MCC': mcc,
            'F1 score': f1
        }
        fold_metrics.append(fold_metric)

    # Save per-fold results
    fold_results_df = pd.DataFrame(fold_metrics)
    fold_results_df.to_csv('./output/fold_results.csv', index=False)
    print('Fold results saved to ./output/fold_results.csv')

    # Mean ROC/PR curves
    mean_tpr /= n_splits
    mean_tpr[-1] = 1.0
    mean_precision /= n_splits
    mean_precision[-1] = mean_precision[-2]

    model_history(histories, number)

    # Plot ROC/PR for all folds (test data predictions)
    plot_roc_pr_curves_all_folds(
        fold_roc_curves, fold_pr_curves,
        mean_fpr, mean_tpr, mean_recall, mean_precision,
        f'./output/{number}_test_set'
    )

    # Evaluate using the last fold's predictions on test set
    plot_roc_pr_curves(Y_test, y_test_preds[-1], f'./output/{number}_roc_pr')
    
    auc_roc = roc_auc_score(Y_test, y_test_preds[-1])
    precision_vals, recall_vals, _ = precision_recall_curve(Y_test, y_test_preds[-1])
    auc_pr = auc(recall_vals, precision_vals)
    print(f'auc_roc: {auc_roc}')
    print(f'auc_pr: {auc_pr}')

    overall_recall = recall_score(Y_test, (y_test_preds[-1] >= 0.5).astype(int))
    accuracy = accuracy_score(Y_test, (y_test_preds[-1] >= 0.5).astype(int))
    precision_val = precision_score(Y_test, (y_test_preds[-1] >= 0.5).astype(int))

    print(f'Overall Recall: {overall_recall}')
    print(f'Accuracy: {accuracy}')
    print(f'Precision: {precision_val}')

    brier = brier_score_loss(Y_test, y_test_preds[-1])
    print(f'Brier score: {brier}')
    
    y_hat_e = (y_test_preds[-1] >= 0.5).astype(int)
    
    kappa = cohen_kappa_score(Y_test, y_hat_e)
    print(f'Cohen Kappa score: {kappa}')
    
    mcc = matthews_corrcoef(Y_test, y_hat_e)
    print(f'MCC score: {mcc}')
    
    f1 = f1_score(Y_test, y_hat_e)
    print(f'F1 score: {f1}')
    
    return histories, auc_roc, auc_pr, overall_recall, accuracy, precision_val, brier, kappa, mcc, f1


## 8. Run the Full Experiment

This final cell runs the full training and evaluation pipeline:
- Iterates over learning rates specified in `param_grid` (here, a single value `0.0001`).
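- Calls `train_and_evaluate` to perform 10-fold cross-validation and to evaluate each fold's model on the test set.
- Stores the summary metrics for each configuration in `Results_df.csv` in the `./output` directory. These summary values come from the model trained in the last fold; per-fold metrics are written to `fold_results.csv`.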

You can adjust the learning rate(s), number of folds, and other hyperparameters as needed.
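
If more than one hyperparameter is varied, the grid can be expanded into individual configurations with `itertools.product`. The sketch below is one way to do that. The grid values and the `n_splits` entry are hypothetical, since the original `param_grid` contains only `learning_rate`.

```python
from itertools import product

# Hypothetical grid; only 'learning_rate' appears in the original param_grid
param_grid_demo = {'learning_rate': [1e-3, 1e-4], 'n_splits': [5, 10]}

names = list(param_grid_demo)
configs = [dict(zip(names, values)) for values in product(*param_grid_demo.values())]

for number, config in enumerate(configs, 1):
    print(number, config)
```

Each configuration could then be passed on as `train_and_evaluate(X_train_0, y_train_0, X_test_0, Y_test_0, config['learning_rate'], number, n_splits=config['n_splits'])`. Keep in mind that `fold_results.csv` and `best_weights.keras` use fixed paths, so each run overwrites the files from the previous one.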

In [None]:
# Main execution
start_time = time.time()

results_df = pd.DataFrame(columns=[
    'number', 'learning_rate', 'auc-roc', 'auc-pr', 'overall_recall',
    'accuracy', 'precision', 'Brier score', 'Cohen Kappa', 'MCC', 'F1 score'
])

param_grid = {'learning_rate': [0.0001]}  # You can add more learning rates here

for i, learning_rate in enumerate(param_grid['learning_rate']):
    histories, auc_roc, auc_pr, overall_recall, accuracy, precision_val, brier, kappa, mcc, f1 = train_and_evaluate(
        X_train_0, y_train_0, X_test_0, Y_test_0, learning_rate, i + 1
    )
    
    new_row = pd.DataFrame({
        'number': [i + 1],
        'learning_rate': [learning_rate],
        'auc-roc': [auc_roc],
        'auc-pr': [auc_pr],
        'overall_recall': [overall_recall],
        'accuracy': [accuracy],
        'precision': [precision_val],
        'Brier score': [brier],
        'Cohen Kappa': [kappa],
        'MCC': [mcc],
        'F1 score': [f1]
    })
    print(new_row)

    results_df = pd.concat([results_df, new_row], ignore_index=True)

results_df.to_csv('./output/Results_df.csv', index=False)

end_time = time.time()
execution_time = end_time - start_time
print(f'Execution time: {execution_time} seconds')
