In [2]:
import os
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, Model
from sklearn.model_selection import StratifiedKFold
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, matthews_corrcoef, roc_auc_score, average_precision_score
from sklearn.preprocessing import StandardScaler  # Import StandardScaler

# --- Hyper-parameters ---
num_classes = 32  # number of model output classes (softmax width)
num_folds = 10    # number of StratifiedKFold splits
batch_size = 64
epochs = 10

# --- Input files (all located under data_dir) ---
data_dir = '../Data/'
labels_file_path = os.path.join(data_dir, '31_Y_ratio1.txt')
combinations_file_path = os.path.join(data_dir, '31_XIndex_ratio1.txt')
small_molecule_matrix_file = os.path.join(data_dir, 'total_small_drugs.txt')
biotech_matrix_file = os.path.join(data_dir, 'total_biotech_drugs.txt')

# Load labels
def read_label_file(file_path):
    """Read one integer class label per line from *file_path*.

    Blank or whitespace-only lines (e.g. extra newlines at the end of the
    file) are skipped instead of crashing ``int('')``.

    Args:
        file_path: Path to a text file with one integer label per line.

    Returns:
        1-D integer ``np.ndarray`` of labels in file order. An empty file
        yields an empty *integer* array (not float).

    Raises:
        ValueError: If a non-blank line is not a valid integer.
    """
    with open(file_path, 'r') as file:
        # int() tolerates surrounding whitespace, so only blank lines need filtering.
        return np.array([int(line) for line in file if line.strip()], dtype=int)

labels = np.ravel(read_label_file(labels_file_path))  # 1-D array of integer class labels

# Load combinations
def read_combinations_file(file_path):
    """Read ``(biotech_idx, small_idx)`` integer pairs, one pair per line.

    Each non-blank line must hold exactly two whitespace-separated integers:
    the biotech index first, then the small-molecule index. Blank lines are
    skipped.

    Args:
        file_path: Path to the combinations text file.

    Returns:
        List of ``(biotech_idx, small_idx)`` tuples of ints, in file order.

    Raises:
        ValueError: If a non-blank line does not contain exactly two integers.
            The message names the offending line.
    """
    combinations = []
    with open(file_path, 'r') as file:
        for line_number, line in enumerate(file, start=1):
            fields = line.split()
            if not fields:
                continue  # tolerate blank lines, e.g. a trailing empty line
            if len(fields) != 2:
                raise ValueError(
                    f"{file_path}:{line_number}: expected 2 integers, got {len(fields)} fields"
                )
            combinations.append((int(fields[0]), int(fields[1])))
    return combinations

combinations = read_combinations_file(file_path=combinations_file_path)  # list of (biotech_idx, small_idx) pairs

# 'balanced' class weights (inversely proportional to class frequency),
# computed over the full label set and keyed by class label.
present_classes = np.unique(labels)
class_weights = compute_class_weight(class_weight='balanced', classes=present_classes, y=labels)
class_weights_dict = dict(zip(present_classes, class_weights))

# Feature matrices; rows are selected by the small/biotech indices read from
# the combinations file.
small_molecule_total_matrix = np.loadtxt(small_molecule_matrix_file, dtype=float)
biotech_total_matrix = np.loadtxt(biotech_matrix_file, dtype=float)



from sklearn.preprocessing import MinMaxScaler
# Generate features: one row per (biotech_idx, small_idx) combination, made of
# the small-molecule row followed by the biotech row. A single fancy-indexing
# gather per matrix replaces the per-pair Python loop.
pair_indices = np.asarray(combinations, dtype=int).reshape(-1, 2)
biotech_rows, small_rows = pair_indices[:, 0], pair_indices[:, 1]
features = np.hstack((small_molecule_total_matrix[small_rows], biotech_total_matrix[biotech_rows]))

# Validate dimensions with a real exception (an `assert` is stripped under
# `python -O`), and do it before spending time on scaling.
if features.shape[0] != len(labels):
    raise ValueError(
        f"Mismatch between number of features ({features.shape[0]}) "
        f"and labels ({len(labels)})"
    )

# Normalize features (per-column zero mean / unit variance).
# NOTE(review): the scaler is fit on every row, including rows that later fall
# into CV test folds; fitting it on each fold's training rows avoids that leakage.
scaler = StandardScaler()
features = scaler.fit_transform(features)

# Reshape for Conv1D: (n_samples, n_features, 1 channel).
features = features[..., np.newaxis]

# Output layout: <results_directory>/CNN/...
results_directory = '../CNN_results'
CNN_results = os.path.join(results_directory, 'CNN')
for directory in (results_directory, CNN_results):
    os.makedirs(directory, exist_ok=True)

# Define CNN model
def build_cnn_model(input_shape, num_classes):
    """Build and compile the 1-D CNN classifier.

    Architecture: three Conv1D(kernel_size=3, ReLU) stages with 128, 64 and 32
    filters, each followed by BatchNormalization and Dropout(0.5); then
    Flatten -> Dense(128, ReLU) -> Dropout(0.5) -> Dense(num_classes, softmax).

    Args:
        input_shape: Shape of one sample, e.g. ``(n_features, 1)``.
        num_classes: Number of softmax output units.

    Returns:
        A compiled Keras ``Model`` (Adam with learning rate 0.001, sparse
        categorical cross-entropy on integer labels, accuracy metric).
    """
    inputs = layers.Input(shape=input_shape)

    hidden = inputs
    for n_filters in (128, 64, 32):
        hidden = layers.Conv1D(n_filters, 3, activation='relu')(hidden)
        hidden = layers.BatchNormalization()(hidden)
        hidden = layers.Dropout(0.5)(hidden)

    hidden = layers.Flatten()(hidden)
    hidden = layers.Dense(128, activation='relu')(hidden)
    hidden = layers.Dropout(0.5)(hidden)
    outputs = layers.Dense(num_classes, activation='softmax')(hidden)

    model = Model(inputs, outputs)
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=['accuracy'],
    )
    return model
def _nan_average(values, weights=None):
    """Return the (optionally weighted) mean of the non-NaN entries of ``values``.

    Entries with zero weight are ignored too. Returns ``np.nan`` when no entry
    qualifies, instead of raising or warning.
    """
    values = np.asarray(values, dtype=float)
    keep = ~np.isnan(values)
    if weights is not None:
        weights = np.asarray(weights, dtype=float)
        keep &= weights > 0
    if not keep.any():
        return np.nan
    return float(np.average(values[keep], weights=None if weights is None else weights[keep]))


def compute_metrics(y_true, y_pred, y_probs, num_classes):
    """Compute classification metrics under micro, macro and weighted averaging.

    Args:
        y_true: Integer class labels of the test samples, expected in
            ``[0, num_classes)``.
        y_pred: Predicted integer class labels, same length as ``y_true``.
        y_probs: Predicted class probabilities, shape ``(n_samples, num_classes)``;
            column ``i`` scores class ``i``.
        num_classes: Number of classes (columns of ``y_probs``).

    Returns:
        Dict with keys ``"Micro"``, ``"Macro"``, ``"Weighted"``.
        - ``"Micro"`` holds Accuracy, Precision, Recall, F1, MCC, AUC and AUPR.
        - The other two hold AUC, Precision, Recall, F1, MCC and AUPR.
        - MCC is the single multiclass value, repeated in every group.
        - AUC is ``np.nan`` when the test fold contains fewer than two classes.
        - Per-class AUC/AUPR that are undefined for this fold (class absent, or
          no negatives) are excluded from the macro and weighted means.
    """
    y_true = np.asarray(y_true).astype(int).ravel()
    y_pred = np.asarray(y_pred).astype(int).ravel()
    y_probs = np.asarray(y_probs, dtype=float)
    metrics = {"Micro": {}, "Macro": {}, "Weighted": {}}

    # One-vs-rest targets over *all* num_classes, so column i lines up with
    # y_probs[:, i] even when some classes are missing from this fold.
    y_onehot = (y_true[:, np.newaxis] == np.arange(num_classes)).astype(int)
    support = y_onehot.sum(axis=0)  # number of test samples per class
    n_samples = len(y_true)

    # Label-based micro metrics.
    mcc = matthews_corrcoef(y_true, y_pred)
    metrics["Micro"]["Accuracy"] = accuracy_score(y_true, y_pred)
    metrics["Micro"]["Precision"], metrics["Micro"]["Recall"], metrics["Micro"]["F1"], _ = \
        precision_recall_fscore_support(y_true, y_pred, average='micro', zero_division=0)
    metrics["Micro"]["MCC"] = mcc

    # Per-class one-vs-rest scores. AP needs at least one positive; ROC-AUC also
    # needs at least one negative. Undefined entries stay NaN.
    auc_per_class = np.full(num_classes, np.nan)
    aupr_per_class = np.full(num_classes, np.nan)
    for c in range(num_classes):
        if support[c] == 0:
            continue
        aupr_per_class[c] = average_precision_score(y_onehot[:, c], y_probs[:, c])
        if support[c] < n_samples:
            auc_per_class[c] = roc_auc_score(y_onehot[:, c], y_probs[:, c])

    # AUC is only reported when the test fold contains at least two classes.
    if np.count_nonzero(support) > 1:
        # Micro AUC pools every (sample, class) decision into one binary problem.
        metrics["Micro"]["AUC"] = roc_auc_score(y_onehot.ravel(), y_probs.ravel())
        metrics["Macro"]["AUC"] = _nan_average(auc_per_class)
        metrics["Weighted"]["AUC"] = _nan_average(auc_per_class, weights=support)
    else:
        metrics["Micro"]["AUC"] = metrics["Macro"]["AUC"] = metrics["Weighted"]["AUC"] = np.nan

    # Indicator-matrix input is what average_precision_score supports for "micro".
    metrics["Micro"]["AUPR"] = average_precision_score(y_onehot, y_probs, average="micro")

    # Let sklearn do macro/weighted P/R/F1 so per-class arrays and support
    # weights always refer to the same label set.
    for group, average in (("Macro", "macro"), ("Weighted", "weighted")):
        precision, recall, f1, _ = precision_recall_fscore_support(
            y_true, y_pred, average=average, zero_division=0)
        metrics[group]["Precision"] = precision
        metrics[group]["Recall"] = recall
        metrics[group]["F1"] = f1
        metrics[group]["MCC"] = mcc

    metrics["Macro"]["AUPR"] = _nan_average(aupr_per_class)
    metrics["Weighted"]["AUPR"] = _nan_average(aupr_per_class, weights=support)

    return metrics


# Cross-validation
kfolder = StratifiedKFold(n_splits=num_folds, shuffle=True, random_state=42)
all_metrics = {"Micro": [], "Macro": [], "Weighted": []}

for fold_num, (train_indices, test_indices) in enumerate(kfolder.split(features, labels), start=1):
    print(f"Training Fold {fold_num}...")
    fold_directory = os.path.join(CNN_results, f'fold_{fold_num}')
    os.makedirs(fold_directory, exist_ok=True)

    # Release the previous fold's Keras state so memory does not grow across
    # folds; a fresh model is built below.
    tf.keras.backend.clear_session()

    # Split labels
    y_train, y_test = labels[train_indices], labels[test_indices]

    # Re-standardize using this fold's training rows only. `features` was
    # standardized on the whole dataset before the split, which leaks test-fold
    # statistics into training. StandardScaler is a per-column affine map, so
    # refitting it on the already-standardized training rows equals a train-only
    # fit on the raw features (except columns constant within the training fold).
    fold_scaler = StandardScaler()
    X_train = fold_scaler.fit_transform(features[train_indices][..., 0])[..., np.newaxis]
    X_test = fold_scaler.transform(features[test_indices][..., 0])[..., np.newaxis]

    # Balanced class weights from this fold's training labels only. The mapping
    # keeps one entry per class id 0..num_classes-1; ids absent from the training
    # fold get a neutral 1.0.
    unique_train_labels = np.unique(y_train)
    fold_weights = compute_class_weight(class_weight='balanced', classes=unique_train_labels, y=y_train)
    fold_weight_by_class = dict(zip(unique_train_labels.tolist(), fold_weights))
    train_class_weights = {i: fold_weight_by_class.get(i, 1.0) for i in range(num_classes)}

    # Build and train model
    model = build_cnn_model((features.shape[1], 1), num_classes)
    model.fit(X_train, y_train, batch_size=batch_size, epochs=epochs, verbose=1,
              class_weight=train_class_weights)

    # Predict on test set
    y_probs = model.predict(X_test)
    y_pred = np.argmax(y_probs, axis=1)

    # Compute and save metrics
    metrics = compute_metrics(y_test, y_pred, y_probs, num_classes)
    for fmt in ["Micro", "Macro", "Weighted"]:
        with open(os.path.join(fold_directory, f'metrics_{fmt}.txt'), 'w') as fold_log_file:
            for metric, value in metrics[fmt].items():
                fold_log_file.write(f"{metric}: {value:.4f}\n")
        all_metrics[fmt].append(metrics[fmt])

    # Save model
    model.save(os.path.join(fold_directory, 'model.h5'))

# Aggregate per-fold metrics into mean ± std for each averaging scheme.
# nanmean/nanstd skip folds where a metric was undefined (NaN, e.g. AUC on a
# single-class test fold), so one such fold does not turn the summary into NaN.
metric_groups = ["Micro", "Macro", "Weighted"]

avg_metrics = {fmt: {metric: np.nanmean([fold_metrics[metric] for fold_metrics in all_metrics[fmt]])
                     for metric in all_metrics[fmt][0].keys()} for fmt in metric_groups}

std_metrics = {fmt: {metric: np.nanstd([fold_metrics[metric] for fold_metrics in all_metrics[fmt]])
                     for metric in all_metrics[fmt][0].keys()} for fmt in metric_groups}

# Save average metrics. Each group gets a header because the groups share
# metric names (Precision, Recall, F1, ...). UTF-8 is explicit because of "±".
with open(os.path.join(CNN_results, 'average_metrics.txt'), 'w', encoding='utf-8') as avg_file:
    for fmt in metric_groups:
        avg_file.write(f"[{fmt}]\n")
        for metric, value in avg_metrics[fmt].items():
            avg_file.write(f"{metric}: {value:.4f} ± {std_metrics[fmt][metric]:.4f}\n")


Training Fold 1...

Epoch 1/10


  52/1234 [>.............................] - ETA: 16:10 - loss: 6.7951 - accuracy: 0.0922

KeyboardInterrupt: 