# Dataset Splitting

In [None]:
import os
import torch
import numpy as np
from sklearn.model_selection import train_test_split

# Set working directory and define paths for input and output data
work_dir = os.getcwd()  # Use the current directory as work_dir
input_data_dir = os.path.join(work_dir, '../Data')  # Set ../Data as input data location
output_dir = os.path.join(work_dir, '../Data')  # Set ../Data as output data location

# Load the merged graph dataset with labels.
# The file stores pickled graph objects (each exposes `.y`, used below), not just
# tensors, so weights_only=False is required: recent PyTorch releases default to
# weights_only=True and would refuse to unpickle these objects.
# SECURITY: weights_only=False unpickles arbitrary Python objects — only load
# files from a trusted source.
merged_file = os.path.join(input_data_dir, 'all_graphs_with_labels-train.pt')
merged_graphs = torch.load(merged_file, weights_only=False)

# Extract labels from each graph in the dataset, converting to NumPy array
# (tensor labels are converted with .numpy(); anything else is used as-is).
labels = np.array([graph.y.numpy() if isinstance(graph.y, torch.Tensor) else graph.y for graph in merged_graphs])

# Function to randomly split data into training and testing sets
def random_train_test_split(graphs, labels, test_size=0.3, random_state=42):
    """Shuffle, then partition ``graphs`` and ``labels`` into train and test parts.

    Thin wrapper around scikit-learn's ``train_test_split`` with shuffling on.

    Args:
        graphs: Sequence of graph objects.
        labels: Labels aligned with ``graphs``.
        test_size: Fraction of samples assigned to the test part.
        random_state: Seed controlling the shuffle.

    Returns:
        Tuple ``(train_graphs, test_graphs, train_labels, test_labels)``.
    """
    return tuple(
        train_test_split(
            graphs,
            labels,
            test_size=test_size,
            random_state=random_state,
            shuffle=True,
        )
    )

# Two-stage random split: first 70% train / 30% hold-out, then the hold-out is
# divided 67/33 into validation (~20% of all graphs) and test (~10% of all graphs).
train_graphs, temp_graphs, train_labels, temp_labels = random_train_test_split(
    merged_graphs,
    labels,
    test_size=0.3,
    random_state=42,
)
val_graphs, test_graphs, val_labels, test_labels = random_train_test_split(
    temp_graphs,
    temp_labels,
    test_size=0.33,
    random_state=42,
)

# Function to calculate the proportion of '1's in each label across the labels dataset
def calculate_label_proportions(labels):
    """Return the fraction of entries equal to 1, averaged over axis 0.

    Args:
        labels: NumPy array of labels; presumably (num_samples, num_labels) so the
            result has one proportion per label — confirm against the data.

    Returns:
        Array of per-column proportions of ones.
    """
    positive_mask = labels == 1
    return np.mean(positive_mask, axis=0)

# Positive-label rate of every split, computed while the labels are still NumPy arrays
train_proportions, val_proportions, test_proportions = (
    calculate_label_proportions(split_labels)
    for split_labels in (train_labels, val_labels, test_labels)
)

# PyTorch code downstream expects tensors, so convert each split's labels
train_labels, val_labels, test_labels = (
    torch.tensor(split_labels)
    for split_labels in (train_labels, val_labels, test_labels)
)

# Report split sizes
print(f"Training set: {len(train_graphs)} graphs")
print(f"Validation set: {len(val_graphs)} graphs")
print(f"Test set: {len(test_graphs)} graphs")

# Report per-label positive rates
print("Proportion of '1's for each label in training set:", train_proportions)
print("Proportion of '1's for each label in validation set:", val_proportions)
print("Proportion of '1's for each label in test set:", test_proportions)


# GAT Model Architecture

In [2]:
import torch
import torch.nn as nn
from torch_geometric.nn import GATConv, global_mean_pool
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Set random seed for reproducibility
def set_seed(seed):
    """Seed Python's, NumPy's and PyTorch's (CPU and CUDA) random generators.

    Args:
        seed: Integer seed applied to every generator.
    """
    import random
    import numpy as np

    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed)
    for seeder in seeders:
        seeder(seed)


set_seed(42)  # Seed everything before the model is built

# Define the GAT-based model
class GATModel(nn.Module):
    """Three-layer graph attention network with a linear graph-level head.

    Node features go through GATConv(in_dim -> hidden_dim, num_heads heads),
    GATConv(hidden_dim * num_heads -> hidden_dim, num_heads heads) and
    GATConv(hidden_dim * num_heads -> hidden_dim, 1 head). They are then
    mean-pooled per graph and mapped to ``out_dim`` logits.

    Args:
        in_dim: Number of input node features.
        hidden_dim: Output size per attention head of each GAT layer.
        out_dim: Number of output logits.
        num_heads: Attention heads in the first two GAT layers.
        dropout_rate: Passed to every GATConv as its ``dropout`` argument.
        dosage_weight: Multiplier applied to the dosage feature column.
        dosage_index: Column of ``data.x`` holding the dosage feature
            (default 90, i.e. the 91st feature — the previously hard-coded value).
        dosage_range: ``(min, max)`` bounds the scaled dosage is clamped to
            (default ``(0, 10)`` — the previously hard-coded range).
    """

    def __init__(self, in_dim, hidden_dim, out_dim, num_heads, dropout_rate=0.3, dosage_weight=1.0,
                 dosage_index=90, dosage_range=(0, 10)):
        super(GATModel, self).__init__()
        self.dosage_weight = dosage_weight  # Controls amplification of the dosage feature
        self.dosage_index = dosage_index
        self.dosage_min, self.dosage_max = dosage_range
        self.layer1 = GATConv(in_dim, hidden_dim, heads=num_heads, dropout=dropout_rate)
        self.layer2 = GATConv(hidden_dim * num_heads, hidden_dim, heads=num_heads, dropout=dropout_rate)
        self.layer3 = GATConv(hidden_dim * num_heads, hidden_dim, heads=1, dropout=dropout_rate)
        self.fc = nn.Linear(hidden_dim, out_dim)

        # Initialize weights
        self._initialize_weights()

    def _initialize_weights(self):
        """Xavier-initialise the linear projection(s) of every GAT layer; zero any bias."""
        for layer in (self.layer1, self.layer2, self.layer3):
            # NOTE(review): depending on the torch_geometric version the projection is
            # exposed as `lin` or as `lin_src`/`lin_dst`; initialise whichever exist
            # instead of assuming `lin` is present.
            for attr in ('lin', 'lin_src', 'lin_dst'):
                lin = getattr(layer, attr, None)
                if lin is None:
                    continue
                nn.init.xavier_uniform_(lin.weight)
                if lin.bias is not None:
                    nn.init.zeros_(lin.bias)

    def forward(self, data):
        """Run the network on a (batched) graph.

        Args:
            data: Graph object exposing ``x``, ``edge_index`` and ``batch``.
                ``data.x`` is not modified.

        Returns:
            ``(out, hg, (attn_1, attn_2, attn_3))``:
            - ``out``: logits, one row per pooled graph.
            - ``hg``: the mean-pooled graph embeddings.
            - ``(attn_1, attn_2, attn_3)``: the attention weights returned by each
              GATConv layer.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch

        # Scale and clamp the dosage feature on a copy. Writing into data.x in place
        # would alter the caller's graph, and repeated passes over the same graph
        # object (e.g. per-graph evaluation) would compound the scaling.
        x = x.clone()
        dosage = x[:, self.dosage_index] * self.dosage_weight
        x[:, self.dosage_index] = torch.clamp(dosage, min=self.dosage_min, max=self.dosage_max)

        # GAT layers; each also returns its attention weights
        h, attn_weights_1 = self.layer1(x, edge_index, return_attention_weights=True)
        h = torch.relu(h)

        h, attn_weights_2 = self.layer2(h, edge_index, return_attention_weights=True)
        h = torch.relu(h)

        h, attn_weights_3 = self.layer3(h, edge_index, return_attention_weights=True)

        # Mean-pool node embeddings into one vector per graph (grouped by `batch`)
        hg = global_mean_pool(h, batch)
        out = self.fc(hg)

        return out, hg, (attn_weights_1, attn_weights_2, attn_weights_3)


# Model hyper-parameters
in_dim = 91         # Node feature count (the dosage feature at index 90 is the last one)
hidden_dim = 64     # Per-head hidden size of the GAT layers
out_dim = 5         # One logit per label (5 labels)
num_heads = 4       # Attention heads in the first two GAT layers
dropout_rate = 0.5  # Passed to each GATConv as its `dropout` argument
dosage_weight = 1   # Multiplier for the dosage feature (1 leaves it unscaled before clamping)

# Build the model and show its layer summary
model = GATModel(
    in_dim=in_dim,
    hidden_dim=hidden_dim,
    out_dim=out_dim,
    num_heads=num_heads,
    dropout_rate=dropout_rate,
    dosage_weight=dosage_weight,
)
print(model)


GATModel(
  (layer1): GATConv(91, 64, heads=4)
  (layer2): GATConv(256, 64, heads=4)
  (layer3): GATConv(256, 64, heads=1)
  (fc): Linear(in_features=64, out_features=5, bias=True)
)


# Model Training

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, recall_score
from torch_geometric.loader import DataLoader
from tqdm import tqdm

# Class weighting for label imbalance: per label, pos_weight = #negatives / #positives
# over the training split (the 1e-6 keeps labels with no positives from dividing by zero).
num_classes = train_labels.size(1)
pos_counts = train_labels.sum(dim=0)
neg_counts = train_labels.size(0) - pos_counts
pos_weight = neg_counts / (pos_counts + 1e-6)

# Optimisation setup:
# - BCE on logits, weighted by pos_weight;
# - Adam with weight decay;
# - a scheduler that tolerates 2 non-improving validation epochs before halving the LR.
learning_rate = 0.0001
loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
optimizer = optim.Adam(
    model.parameters(),
    lr=learning_rate,
    weight_decay=1e-4,
)
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=2,
)

# Early stopping class
class EarlyStopping:
    """Signal when the validation loss has stopped improving.

    A loss counts as an improvement only if it beats the best loss seen so far by
    more than ``min_delta``. After ``patience`` consecutive non-improving calls,
    ``check_early_stop`` returns True.
    """

    def __init__(self, patience=5, min_delta=0.001):
        self.best_loss = np.inf
        self.counter = 0
        self.patience = patience
        self.min_delta = min_delta

    def check_early_stop(self, val_loss):
        """Record ``val_loss``; return True once training should stop."""
        improved = (self.best_loss - val_loss) > self.min_delta
        if improved:
            self.best_loss, self.counter = val_loss, 0
        else:
            self.counter += 1
        return self.counter >= self.patience

# Print current learning rate for tracking adjustments
def print_learning_rate(optimizer):
    """Print the current learning rate of each parameter group of ``optimizer``."""
    current_rates = (group['lr'] for group in optimizer.param_groups)
    for lr in current_rates:
        print(f"Current Learning Rate: {lr}")

# Identity collate function (kept for backward compatibility; not used for batching)
def custom_collate(batch):
    """Return ``batch`` unchanged.

    torch_geometric's ``DataLoader`` removes any ``collate_fn`` keyword and always
    batches graphs with its own collater, so passing this function to it has no
    effect. It is kept only so existing references to the name keep working.
    """
    return batch


# Create data loader for batching graphs and labels
def create_batches(graphs, labels, batch_size, shuffle=True):
    """Attach labels to graphs and wrap them in a torch_geometric DataLoader.

    Args:
        graphs: Sequence of graph objects. Each one's ``y`` attribute is
            overwritten in place with its label (side effect on the caller's objects).
        labels: Labels aligned with ``graphs`` (``labels[i]`` belongs to ``graphs[i]``).
        batch_size: Number of graphs per mini-batch.
        shuffle: Reshuffle every epoch. Defaults to True (the previous behaviour);
            pass False for deterministic evaluation loaders.

    Returns:
        A ``DataLoader`` yielding batched graphs.

    Raises:
        ValueError: If ``graphs`` and ``labels`` differ in length.
    """
    if len(graphs) != len(labels):
        raise ValueError(
            f"graphs ({len(graphs)}) and labels ({len(labels)}) must have the same length"
        )
    for graph, label in zip(graphs, labels):
        graph.y = label  # Attach labels to graph data
    # No collate_fn: the PyG DataLoader would discard it anyway.
    return DataLoader(graphs, batch_size=batch_size, shuffle=shuffle)

# Model evaluation function
def evaluate_model(val_loader, model, loss_fn):
    """Compute the loss and micro-averaged recall/F1 of ``model`` over ``val_loader``.

    The model is put in eval mode (and left there) and run under ``torch.no_grad()``.
    Predictions are the rounded sigmoid of the logits.

    NOTE: the evaluation cell later defines another ``evaluate_model`` with a
    different signature. Re-run this cell before training again, or ``train_model``
    will pick up the wrong function.

    Returns:
        ``(avg_val_loss, recall, f1, all_attn_weights)``:
        - ``avg_val_loss``: mean of the per-batch losses.
        - ``recall``, ``f1``: micro-averaged scores.
        - ``all_attn_weights``: the attention tuple returned by the model for every batch.
    """
    model.eval()
    running_loss = 0
    pred_chunks, label_chunks, attention_per_batch = [], [], []

    with torch.no_grad():
        for batch in val_loader:
            logits, _, attention = model(batch)

            # Reshape the batched labels to line up with the logits
            targets = batch.y.view(logits.shape)
            running_loss += loss_fn(logits, targets).item()

            pred_chunks.append(torch.round(torch.sigmoid(logits)).cpu().numpy())
            label_chunks.append(targets.cpu().numpy())
            attention_per_batch.append(attention)

    avg_val_loss = running_loss / len(val_loader)
    y_pred = np.vstack(pred_chunks)
    y_true = np.vstack(label_chunks)

    recall = recall_score(y_true, y_pred, average='micro')
    f1 = f1_score(y_true, y_pred, average='micro')
    return avg_val_loss, recall, f1, attention_per_batch

# Training function for the model
def train_model(train_graphs, train_labels, val_graphs, val_labels, model, loss_fn, optimizer, scheduler, num_epochs=50, batch_size=16, early_stopping_patience=5, grad_clip_value=1.0):
    """Train ``model`` with per-epoch validation, LR scheduling and early stopping.

    Each epoch runs these steps in order:
    - one pass over the training batches: forward, ``loss_fn``, backward,
      gradient-norm clipping to ``grad_clip_value``, optimizer step;
    - validation via ``evaluate_model``;
    - ``scheduler.step`` on the validation loss.

    Training stops early once ``EarlyStopping`` reports no improvement for
    ``early_stopping_patience`` epochs.

    Notes:
        - ``create_batches`` writes the labels onto each graph's ``y`` attribute.
        - The model keeps the weights of the last epoch run, not of the best
          validation epoch.
        - This relies on the training-cell ``evaluate_model(val_loader, model, loss_fn)``.
          The evaluation cell redefines that name with a different signature.

    Returns:
        List of validation losses, one per completed epoch (including the epoch
        that triggered early stopping).
    """
    train_loader = create_batches(train_graphs, train_labels, batch_size)
    val_loader = create_batches(val_graphs, val_labels, batch_size)

    early_stopping = EarlyStopping(patience=early_stopping_patience)
    val_losses = []
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for batch in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}"):
            # Forward pass (attention weights are not needed during training)
            output, _, _ = model(batch)

            # Reshape the batched labels to line up with the logits
            batch_labels = batch.y.view(output.shape)
            loss = loss_fn(output, batch_labels)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_value)

            optimizer.step()
            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch + 1}/{num_epochs}, Training Loss: {avg_loss:.4f}")

        # Validate the model and record the loss history (previously never filled)
        val_loss, val_recall, val_f1, attn_weights = evaluate_model(val_loader, model, loss_fn)
        val_losses.append(val_loss)
        print(f"Epoch {epoch + 1}/{num_epochs}, Validation Loss: {val_loss:.4f}, Validation Recall: {val_recall:.4f}, Validation F1: {val_f1:.4f}")

        # Step the learning rate scheduler and print the current learning rate
        scheduler.step(val_loss)
        print_learning_rate(optimizer)

        # Check early stopping
        if early_stopping.check_early_stop(val_loss):
            print(f"Early stopping at epoch {epoch + 1}")
            break

    return val_losses


# Train the model; keep the per-epoch validation losses for later inspection
val_loss_history = train_model(train_graphs, train_labels, val_graphs, val_labels, model, loss_fn, optimizer, scheduler, num_epochs=30, batch_size=32)


# Model Evaluation

In [None]:
import torch
import numpy as np
from sklearn.metrics import roc_curve, auc, precision_recall_fscore_support, accuracy_score, confusion_matrix
import pandas as pd
import os
from matplotlib import rcParams

# Make Arial the global matplotlib font so any figures built from these results match.
# NOTE(review): no plotting happens in this cell; if Arial is not installed,
# matplotlib will fall back to another font — confirm it is available.
rcParams.update({'font.family': 'Arial'})

# Function to evaluate the model on a given dataset and optionally output attention weights
def evaluate_model(graphs, labels, model, output_dir, data_name, cpm_id=None):
    """Score ``model`` one graph at a time, save metrics, optionally dump attention.

    Each graph goes through the model on its own (no batching) under
    ``torch.no_grad()``. Logits and labels are stacked row-wise and passed to
    ``compute_and_save_metrics``. When ``cpm_id`` is given, the attention weights
    of the matching graph are written via ``output_attention_weights``.

    NOTE: this shadows the ``evaluate_model(val_loader, model, loss_fn)`` defined
    in the training cell.

    Args:
        graphs: Graphs to evaluate.
        labels: Per-graph label tensors, indexed in step with ``graphs``.
        model: Model returning ``(logits, pooled, attention)``.
        output_dir: Directory for the CSV outputs.
        data_name: Prefix for the output file names.
        cpm_id: Optional id of a graph whose attention weights should be saved.
    """
    model.eval()
    logits_rows, label_rows, attention_per_graph = [], [], []

    with torch.no_grad():
        for idx, graph in enumerate(graphs):
            logits, _, attention = model(graph)
            logits_rows.append(logits.cpu().numpy())
            label_rows.append(labels[idx].cpu().numpy())
            attention_per_graph.append(attention)

    stacked_logits = np.vstack(logits_rows)
    stacked_labels = np.vstack(label_rows)
    compute_and_save_metrics(stacked_labels, stacked_logits, output_dir, data_name)

    if cpm_id is None:
        return
    output_attention_weights(attention_per_graph, graphs, cpm_id, output_dir)

# Function to save attention weights for a specific `cpm_id`
def _save_layer_attention_csv(layer_attn, node_names, cpm_id, layer_tag, output_dir):
    """Write one layer's attention as CSV rows ``name_a,name_b,alpha...`` and report it.

    Args:
        layer_attn: Pair whose first item is an edge-index tensor, presumably
            (2, num_edges), and whose second is the matching attention
            coefficients (one row per edge).
        node_names: Indexable node names; edge indices are looked up in it.
        cpm_id: Graph identifier used in the file name.
        layer_tag: Layer label used in the file name and message, e.g. ``'attn_weights_1'``.
        output_dir: Directory the CSV is written to.
    """
    arrays = [t.cpu().numpy() for t in layer_attn]
    edge_index, alpha = arrays[0], arrays[1]

    # Transpose so each row is one edge's index pair (num_edges rows), then map
    # the indices to node names.
    name_pairs = np.array([[node_names[int(idx)] for idx in pair] for pair in edge_index.T])

    merged = np.column_stack((name_pairs, alpha))
    filename = f'{cpm_id}_{layer_tag}-multi_attention.csv'
    np.savetxt(os.path.join(output_dir, filename), merged, delimiter=',', fmt='%s')
    print(f"Attention weights {layer_tag} saved as {filename}")


# Function to save attention weights for a specific `cpm_id`
def output_attention_weights(all_attn_weights, graphs, cpm_id, output_dir):
    """Save layer-1 and layer-2 attention weights of the graph(s) matching ``cpm_id``.

    Writes ``{cpm_id}_attn_weights_1-multi_attention.csv`` and
    ``{cpm_id}_attn_weights_2-multi_attention.csv`` to ``output_dir``.
    Layer-3 attention is not saved. Every matching graph is written, so a later
    match overwrites an earlier one's files. A message is printed when nothing matches.

    Args:
        all_attn_weights: Per-graph attention tuples aligned with ``graphs``; each is
            ``(layer1, layer2, layer3)`` with every entry an ``(edge_index, alpha)`` pair.
        graphs: The evaluated graphs, in the same order. Matching graphs must have
            a ``node_names`` attribute.
        cpm_id: Identifier compared against each graph's ``cpm_id`` attribute.
        output_dir: Directory for the CSV files.
    """
    found = False
    for i, graph in enumerate(graphs):
        if not (hasattr(graph, 'cpm_id') and graph.cpm_id == cpm_id):
            continue
        found = True
        attn_layer1, attn_layer2, _ = all_attn_weights[i]
        node_names = graph.node_names
        _save_layer_attention_csv(attn_layer1, node_names, cpm_id, 'attn_weights_1', output_dir)
        _save_layer_attention_csv(attn_layer2, node_names, cpm_id, 'attn_weights_2', output_dir)

    if not found:
        print(f"No graph with cpm_id={cpm_id!r} found; no attention weights saved.")

# Function to compute and save model performance metrics
def compute_and_save_metrics(labels, outputs, output_dir, data_name):
    """Compute per-label binary metrics from logits and write them to CSV.

    Args:
        labels: 2-D array (num_samples, num_classes) of 0/1 targets.
        outputs: 2-D array of raw logits with the same shape as ``labels``.
        output_dir: Directory the CSV files are written to.
        data_name: Prefix for the output file names.

    Writes:
        ``{data_name}_roc_data_multi_attention.csv``: long-format rows of
            (Class, Reference, Predicted probability).
        ``{data_name}_metrics_multi_attention.csv``: one row per class plus an
            ``Average`` row (plain mean over classes) with Precision, Recall,
            F1 Score, AUC, Accuracy and Specificity.

    Binary predictions use a 0.5 probability threshold. AUC is undefined (NaN,
    with a scikit-learn warning) for a class whose labels are all one value.
    """
    num_classes = labels.shape[1]
    metrics = {
        'Class': [],
        'Precision': [],
        'Recall': [],
        'F1 Score': [],
        'AUC': [],
        'Accuracy': [],
        'Specificity': []
    }
    roc_data_long_format = {'Class': [], 'Reference': [], 'Predicted': []}

    # Sigmoid turns logits into probabilities. It does not depend on the class,
    # so compute it once rather than inside the per-class loop.
    probabilities = torch.sigmoid(torch.tensor(outputs)).numpy()

    for i in range(num_classes):
        class_name = f'Class_{i+1}'
        y_true = labels[:, i]
        y_score = probabilities[:, i]

        # ROC curve and AUC
        fpr, tpr, _ = roc_curve(y_true, y_score)
        roc_auc = auc(fpr, tpr)

        # ROC data in long format for this class
        roc_data_long_format['Class'].extend([class_name] * len(y_true))
        roc_data_long_format['Reference'].extend(y_true)
        roc_data_long_format['Predicted'].extend(y_score)

        # Threshold-based metrics
        pred_binary = (y_score > 0.5).astype(int)
        # zero_division=0 yields the same 0.0 sklearn falls back to, without warnings
        precision, recall, f1, _ = precision_recall_fscore_support(
            y_true, pred_binary, average='binary', zero_division=0)
        accuracy = accuracy_score(y_true, pred_binary)

        # labels=[0, 1] forces a 2x2 matrix even when only one class occurs in both
        # targets and predictions; otherwise .ravel() gives a single value and the
        # four-way unpack raises.
        tn, fp, fn, tp = confusion_matrix(y_true, pred_binary, labels=[0, 1]).ravel()
        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0  # Avoid division by zero

        metrics['Class'].append(class_name)
        metrics['Precision'].append(precision)
        metrics['Recall'].append(recall)
        metrics['F1 Score'].append(f1)
        metrics['AUC'].append(roc_auc)
        metrics['Accuracy'].append(accuracy)
        metrics['Specificity'].append(specificity)

    # Append the across-class average as a final row ('Average' in the Class column)
    for key, values in metrics.items():
        values.append('Average' if key == 'Class' else np.mean(values))

    # Save ROC data in long format to CSV
    roc_df_long = pd.DataFrame(roc_data_long_format)
    roc_df_long.to_csv(os.path.join(output_dir, f'{data_name}_roc_data_multi_attention.csv'), index=False)

    # Save metrics data to CSV
    metrics_df = pd.DataFrame(metrics)
    metrics_df.to_csv(os.path.join(output_dir, f'{data_name}_metrics_multi_attention.csv'), index=False)
    print(f"Metrics and ROC data saved to {output_dir}.")

# Input and output both live in ../Data relative to the current working directory
work_dir = os.getcwd()
input_data_dir = output_dir = os.path.join(work_dir, '../Data')

# Evaluate on different datasets
# Uncomment as needed
# evaluate_model(train_graphs, train_labels, model, output_dir, "train")
# evaluate_model(val_graphs, val_labels, model, output_dir, "validation")

# Evaluate the test split. Pass cpm_id (as in the commented call) to also save
# that graph's attention weights.
# evaluate_model(test_graphs, test_labels, model, output_dir, "test-0", cpm_id='CPM05651')
evaluate_model(test_graphs, test_labels, model, output_dir, data_name="test")
