In [1]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [2]:
import networkx as nx
import csv

def build_multidigraph_from_csv(csv_file):
    G = nx.MultiDiGraph()

    with open(csv_file, 'r') as file:
        reader = csv.DictReader(file)
        for row in reader:
            # Add nodes (if not already added)
            G.add_node(row['starter_ID'], name=row['starter_ID'])
            G.add_node(row['receiver_ID'], name=row['receiver_ID'])

            # Add directed edges with additional attributes
            # Each edge is unique and can represent a different type of interaction
            G.add_edge(
                row['starter_ID'], 
                row['receiver_ID'], 
                interaction_type=row['subtype_name'],
                relation_type=row['relation_type'],
                pathway_sources=row['pathway_source'],
                credibility=row['credibility']
            )

    return G

# Path to the CSV file
csv_file_path = 'relations_train_final.csv'  # Update this to the path of your relations_train.csv file

# Build the multidigraph
MDG = build_multidigraph_from_csv(csv_file_path)

# Print basic information about the multidigraph
print(f"Number of nodes: {MDG.number_of_nodes()}")
print(f"Number of edges: {MDG.number_of_edges()}")

# The multidigraph `MDG` is now constructed with nodes and multiple types of directed edges from the relations_train.csv
# You can now use `MDG` for further analysis or as input to clustering algorithms.



Number of nodes: 4816
Number of edges: 101782


In [3]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import TransformerConv, GATConv
from torch_geometric.data import Data
from torch_geometric.utils import from_networkx


class GraphTransformer(torch.nn.Module):
    def __init__(self, num_nodes, embedding_dim, num_classes, cluster_feature_dim=92, num_hidden_units=32, num_heads=5):
        super(GraphTransformer, self).__init__()
        self.node_emb = torch.nn.Embedding(num_nodes, embedding_dim)

        # First Graph Transformer layer
        self.conv1 = TransformerConv(embedding_dim + cluster_feature_dim, num_hidden_units, heads=num_heads, dropout=0.6, edge_dim=None)
        
        # Output layer
        self.conv2 = TransformerConv(num_hidden_units * num_heads, num_classes, heads=1, concat=True, dropout=0.6, edge_dim=None)

    def forward(self, data, cluster_features):
        x = self.node_emb(data.node_index)
        # cluster_features is a [num_nodes x 92] tensor containing the fuzzy membership scores
        x = torch.cat([x, cluster_features], dim=1)  # Concatenate node and cluster features

        edge_index = data.edge_index
        edge_weight = None  # Update if you have edge weights

        x = F.dropout(x, p=0.6, training=self.training)
        x = F.elu(self.conv1(x, edge_index, edge_weight))
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index, edge_weight)

        return F.log_softmax(x, dim=1)


# Assuming MDG is your MultiDiGraph
train_data = from_networkx(MDG)

num_nodes = train_data.num_nodes
embedding_dim = 32  # Choose an appropriate embedding size
num_classes = 7   # Set the number of classes based on your edge types
num_clusters = 100

In [4]:
# Collect all unique interaction types from MDG
interaction_types = set()
for _, _, edge_data in MDG.edges(data=True):
    interaction_types.add(edge_data['interaction_type'])

# Update the mapping to include 'no interaction' class
interaction_type_to_label = {inter_type: i for i, inter_type in enumerate(interaction_types)}
num_classes = len(interaction_type_to_label)

def setup_edge_labels_with_no_interaction(MDG, interaction_type_to_label, data):
    num_nodes = len(MDG.nodes())
    edge_labels = torch.zeros((num_nodes, num_nodes), dtype=torch.long)

    for u, v, edge_data in MDG.edges(data=True):
        u_index = list(MDG.nodes()).index(u)
        v_index = list(MDG.nodes()).index(v)
        label = interaction_type_to_label[edge_data['interaction_type']]
        edge_labels[u_index, v_index] = label

    data.edge_label = edge_labels  # Directly modify the data object

In [5]:
import pandas as pd

def read_cluster_assignments(csv_file):
    return pd.read_csv(csv_file).set_index('Gene')

# Use this function to read your CSV file
cluster_df = read_cluster_assignments('vgae_fcm_gene_cluster_assignments.csv')

def map_nodes_to_cluster_features(G, cluster_df):
    cluster_features = []
    for node in G.nodes():
        # Default to a vector of zeros if the gene is not in the DataFrame
        features = cluster_df.loc[node].values if node in cluster_df.index else np.zeros(cluster_df.shape[1])
        cluster_features.append(features)
    return torch.tensor(cluster_features, dtype=torch.float)


In [6]:
import torch.optim as optim
import numpy as np

# Function to map nodes in a graph to global indices
def map_nodes_to_global_indices(G, global_node_to_index):
    return [global_node_to_index[node] for node in G.nodes()]


# Process validation data
val_MDG = build_multidigraph_from_csv('cleaned_relations_val_final.csv')
val_data = from_networkx(val_MDG)

# Create a global node to index mapping
all_nodes = set(MDG.nodes()).union(set(val_MDG.nodes()))
global_node_to_index = {node: idx for idx, node in enumerate(all_nodes)}


setup_edge_labels_with_no_interaction(MDG, interaction_type_to_label, train_data)
train_data.node_index = torch.tensor(map_nodes_to_global_indices(MDG, global_node_to_index), dtype=torch.long)
val_data.node_index = torch.tensor(map_nodes_to_global_indices(val_MDG, global_node_to_index), dtype=torch.long)

# Now proceed with the rest of your data preparation
train_cluster_features = map_nodes_to_cluster_features(MDG, cluster_df)
val_cluster_features = map_nodes_to_cluster_features(val_MDG, cluster_df)
setup_edge_labels_with_no_interaction(val_MDG, interaction_type_to_label, val_data)

model = GraphTransformer(num_nodes, embedding_dim, num_classes)
optimizer = optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

  return torch.tensor(cluster_features, dtype=torch.float)


In [7]:
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from tqdm import tqdm

# Create a dataset of edges and their labels
edges = train_data.edge_index.t().tolist()  # List of [node_u, node_v]
labels = [train_data.edge_label[edge[0], edge[1]].item() for edge in edges]

edge_dataset = list(zip(edges, labels))
edge_loader = DataLoader(edge_dataset, batch_size=1024, shuffle=True)  # Adjust batch_size as needed

# Assuming val_data is structured similarly to train_data
val_edges = val_data.edge_index.t().tolist()
val_labels = [val_data.edge_label[edge[0], edge[1]].item() for edge in val_edges]

val_edge_dataset = list(zip(val_edges, val_labels))
val_edge_loader = DataLoader(val_edge_dataset, batch_size=1024, shuffle=False)  # You can adjust the batch size

from sklearn.metrics import f1_score, classification_report

inverse_interaction_type_to_label = {v: k for k, v in interaction_type_to_label.items()}

def validate(model, val_data, val_cluster_features, val_edge_loader, criterion, device):
    model.eval()
    total_cross_entropy_loss = 0
    all_predictions = []
    all_true_labels = []

    with torch.no_grad():
        for batch in tqdm(val_edge_loader, desc='Validating'):
            edge_tensors, label_batch = batch

            node_u_list, node_v_list = edge_tensors[0].to(device), edge_tensors[1].to(device)
            label_batch = label_batch.to(device)

            # Pass both val_data and val_cluster_features to the model
            output = model(val_data.to(device), val_cluster_features.to(device))
            edge_predictions = output[node_u_list].argmax(dim=1)

            cross_entropy_loss = criterion(output[node_u_list], label_batch)
            total_cross_entropy_loss += cross_entropy_loss.item()

            preds = edge_predictions.cpu().numpy()
            true_labels = label_batch.cpu().numpy()
            all_predictions.extend(preds)
            all_true_labels.extend(true_labels)

    avg_cross_entropy_loss = total_cross_entropy_loss / len(val_edge_loader)
    weighted_f1 = f1_score(all_true_labels, all_predictions, average='weighted')
    class_report = classification_report(all_true_labels, all_predictions, target_names=[inverse_interaction_type_to_label[i] for i in range(num_classes)], zero_division=0)

    print(f"Weighted F1 Score: {weighted_f1}\nClassification Report:\n{class_report}")
    print(f"Validation Cross-Entropy Loss: {avg_cross_entropy_loss}")

    return weighted_f1

In [8]:
# Instantiate the model
model = GraphTransformer(num_nodes, embedding_dim, num_classes)
optimizer = optim.Adam(model.parameters(), lr=0.0001)  # Learning rate
criterion = torch.nn.CrossEntropyLoss()  # Loss function
# Check if CUDA (GPU support) is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Move your model to the chosen device
model = model.to(device)

def train(model, train_data, train_cluster_features, edge_loader, optimizer, criterion, device):
    model.train()
    total_loss = 0

    for batch in tqdm(edge_loader, desc='Training'):
        edge_tensors, label_batch = batch
        
        node_u_list, node_v_list = edge_tensors[0].to(device), edge_tensors[1].to(device)
        label_batch = label_batch.to(device)
        optimizer.zero_grad()

        # Pass both train_data and train_cluster_features to the model
        output = model(train_data.to(device), train_cluster_features.to(device))
        edge_predictions = output[node_u_list]  # Corrected indexing
        
        loss = criterion(edge_predictions, label_batch)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    return total_loss / len(edge_loader)


Using device: cuda


In [9]:
# Early stopping and scheduler parameters
patience = 10
best_val_loss = float('inf')
epochs_no_improve = 0

scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5)

best_f1_score = 0
best_model_state = None

for epoch in range(1000):
    # Call train function with necessary parameters
    train_loss = train(model, train_data, train_cluster_features, edge_loader, optimizer, criterion, device)

    # Call validate function with necessary parameters
    val_f1_score = validate(model, val_data, val_cluster_features, val_edge_loader, criterion, device)

    print(f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val F1 Score: {val_f1_score:.4f}')

    # Step the scheduler based on validation F1 score
    scheduler.step(val_f1_score)

    # Check for improvement in validation F1 score
    if val_f1_score > best_f1_score:
        best_f1_score = val_f1_score
        epochs_no_improve = 0
        best_model_state = model.state_dict()  # Save the best model state
    else:
        epochs_no_improve += 1

    # Early stopping
    if epochs_no_improve >= patience:
        print("Early stopping triggered")
        break

# Save the best model state
if best_model_state is not None:
    torch.save(best_model_state, 'GTCFFuzzy_32_32_5_best.pth')
    print("Best model saved.")
else:
    print("No model improvement was observed.")


Training: 100%|██████████| 100/100 [01:12<00:00,  1.37it/s]
Validating: 100%|██████████| 18/18 [00:02<00:00,  7.42it/s]


Weighted F1 Score: 0.15171183995448803
Classification Report:
                     precision    recall  f1-score   support

           compound       0.04      0.24      0.07       694
         expression       0.03      0.08      0.05       722
         inhibition       0.10      0.09      0.09      1745
binding/association       0.02      0.10      0.04       458
    phosphorylation       0.03      0.10      0.04      1024
        no_relation       0.33      0.23      0.27      6143
         activation       0.42      0.06      0.11      7342

           accuracy                           0.13     18128
          macro avg       0.14      0.13      0.10     18128
       weighted avg       0.30      0.13      0.15     18128

Validation Cross-Entropy Loss: 1.9766252173317804
Epoch 1, Train Loss: 2.0380, Val F1 Score: 0.1517


Training: 100%|██████████| 100/100 [01:12<00:00,  1.37it/s]
Validating: 100%|██████████| 18/18 [00:02<00:00,  8.00it/s]


Weighted F1 Score: 0.1683099398166673
Classification Report:
                     precision    recall  f1-score   support

           compound       0.04      0.24      0.07       694
         expression       0.04      0.07      0.05       722
         inhibition       0.10      0.06      0.08      1745
binding/association       0.03      0.10      0.04       458
    phosphorylation       0.02      0.06      0.03      1024
        no_relation       0.33      0.29      0.31      6143
         activation       0.41      0.07      0.12      7342

           accuracy                           0.15     18128
          macro avg       0.14      0.13      0.10     18128
       weighted avg       0.29      0.15      0.17     18128

Validation Cross-Entropy Loss: 1.9606052372190688
Epoch 2, Train Loss: 2.0132, Val F1 Score: 0.1683


Training: 100%|██████████| 100/100 [01:12<00:00,  1.37it/s]
Validating: 100%|██████████| 18/18 [00:02<00:00,  7.65it/s]


Weighted F1 Score: 0.18162084894128164
Classification Report:
                     precision    recall  f1-score   support

           compound       0.05      0.28      0.08       694
         expression       0.01      0.02      0.02       722
         inhibition       0.15      0.06      0.08      1745
binding/association       0.02      0.05      0.03       458
    phosphorylation       0.02      0.05      0.03      1024
        no_relation       0.33      0.35      0.34      6143
         activation       0.42      0.08      0.13      7342

           accuracy                           0.17     18128
          macro avg       0.14      0.13      0.10     18128
       weighted avg       0.30      0.17      0.18     18128

Validation Cross-Entropy Loss: 1.944985224141015
Epoch 3, Train Loss: 1.9941, Val F1 Score: 0.1816


Training: 100%|██████████| 100/100 [01:12<00:00,  1.37it/s]
Validating: 100%|██████████| 18/18 [00:02<00:00,  7.66it/s]


Weighted F1 Score: 0.1976746806298375
Classification Report:
                     precision    recall  f1-score   support

           compound       0.05      0.29      0.08       694
         expression       0.01      0.01      0.01       722
         inhibition       0.12      0.04      0.06      1745
binding/association       0.02      0.04      0.03       458
    phosphorylation       0.02      0.04      0.02      1024
        no_relation       0.34      0.41      0.37      6143
         activation       0.46      0.09      0.15      7342

           accuracy                           0.20     18128
          macro avg       0.14      0.13      0.10     18128
       weighted avg       0.31      0.20      0.20     18128

Validation Cross-Entropy Loss: 1.9298885001076593
Epoch 4, Train Loss: 1.9721, Val F1 Score: 0.1977


Training: 100%|██████████| 100/100 [01:12<00:00,  1.38it/s]
Validating: 100%|██████████| 18/18 [00:02<00:00,  7.44it/s]


Weighted F1 Score: 0.2105401363286851
Classification Report:
                     precision    recall  f1-score   support

           compound       0.05      0.33      0.09       694
         expression       0.01      0.01      0.01       722
         inhibition       0.13      0.04      0.06      1745
binding/association       0.02      0.04      0.03       458
    phosphorylation       0.01      0.02      0.02      1024
        no_relation       0.34      0.46      0.39      6143
         activation       0.48      0.10      0.16      7342

           accuracy                           0.22     18128
          macro avg       0.15      0.14      0.11     18128
       weighted avg       0.32      0.22      0.21     18128

Validation Cross-Entropy Loss: 1.9151354829470317
Epoch 5, Train Loss: 1.9517, Val F1 Score: 0.2105


Training: 100%|██████████| 100/100 [01:13<00:00,  1.36it/s]
Validating: 100%|██████████| 18/18 [00:02<00:00,  8.36it/s]


Weighted F1 Score: 0.22487051143288644
Classification Report:
                     precision    recall  f1-score   support

           compound       0.05      0.29      0.09       694
         expression       0.00      0.00      0.00       722
         inhibition       0.22      0.03      0.06      1745
binding/association       0.03      0.02      0.03       458
    phosphorylation       0.01      0.02      0.02      1024
        no_relation       0.34      0.54      0.42      6143
         activation       0.46      0.11      0.18      7342

           accuracy                           0.24     18128
          macro avg       0.16      0.15      0.11     18128
       weighted avg       0.33      0.24      0.22     18128

Validation Cross-Entropy Loss: 1.900701814227634
Epoch 6, Train Loss: 1.9310, Val F1 Score: 0.2249


Training: 100%|██████████| 100/100 [01:13<00:00,  1.36it/s]
Validating: 100%|██████████| 18/18 [00:02<00:00,  7.48it/s]


Weighted F1 Score: 0.2239954531715866
Classification Report:
                     precision    recall  f1-score   support

           compound       0.04      0.23      0.07       694
         expression       0.00      0.00      0.00       722
         inhibition       0.08      0.01      0.01      1745
binding/association       0.02      0.02      0.02       458
    phosphorylation       0.00      0.00      0.00      1024
        no_relation       0.34      0.59      0.43      6143
         activation       0.46      0.11      0.18      7342

           accuracy                           0.25     18128
          macro avg       0.14      0.14      0.10     18128
       weighted avg       0.31      0.25      0.22     18128

Validation Cross-Entropy Loss: 1.8867918716536627
Epoch 7, Train Loss: 1.9106, Val F1 Score: 0.2240


Training: 100%|██████████| 100/100 [01:13<00:00,  1.37it/s]
Validating: 100%|██████████| 18/18 [00:02<00:00,  7.86it/s]


Weighted F1 Score: 0.23282051892506284
Classification Report:
                     precision    recall  f1-score   support

           compound       0.04      0.19      0.07       694
         expression       0.00      0.00      0.00       722
         inhibition       0.14      0.01      0.01      1745
binding/association       0.01      0.01      0.01       458
    phosphorylation       0.00      0.00      0.00      1024
        no_relation       0.34      0.62      0.44      6143
         activation       0.47      0.12      0.20      7342

           accuracy                           0.27     18128
          macro avg       0.14      0.14      0.10     18128
       weighted avg       0.32      0.27      0.23     18128

Validation Cross-Entropy Loss: 1.8731168177392747
Epoch 8, Train Loss: 1.8971, Val F1 Score: 0.2328


Training: 100%|██████████| 100/100 [01:13<00:00,  1.36it/s]
Validating: 100%|██████████| 18/18 [00:02<00:00,  8.27it/s]


Weighted F1 Score: 0.2378958987114396
Classification Report:
                     precision    recall  f1-score   support

           compound       0.06      0.19      0.09       694
         expression       0.00      0.00      0.00       722
         inhibition       0.16      0.01      0.01      1745
binding/association       0.01      0.01      0.01       458
    phosphorylation       0.00      0.00      0.00      1024
        no_relation       0.34      0.68      0.45      6143
         activation       0.47      0.12      0.20      7342

           accuracy                           0.29     18128
          macro avg       0.15      0.14      0.11     18128
       weighted avg       0.33      0.29      0.24     18128

Validation Cross-Entropy Loss: 1.8599629004796345
Epoch 9, Train Loss: 1.8768, Val F1 Score: 0.2379


Training: 100%|██████████| 100/100 [01:12<00:00,  1.38it/s]
Validating: 100%|██████████| 18/18 [00:02<00:00,  7.57it/s]


Weighted F1 Score: 0.2479807736317506
Classification Report:
                     precision    recall  f1-score   support

           compound       0.06      0.15      0.08       694
         expression       0.00      0.00      0.00       722
         inhibition       0.18      0.01      0.01      1745
binding/association       0.01      0.00      0.00       458
    phosphorylation       0.00      0.00      0.00      1024
        no_relation       0.34      0.72      0.46      6143
         activation       0.49      0.14      0.21      7342

           accuracy                           0.31     18128
          macro avg       0.15      0.14      0.11     18128
       weighted avg       0.33      0.31      0.25     18128

Validation Cross-Entropy Loss: 1.8470862706502278
Epoch 10, Train Loss: 1.8606, Val F1 Score: 0.2480


Training: 100%|██████████| 100/100 [01:13<00:00,  1.37it/s]
Validating: 100%|██████████| 18/18 [00:02<00:00,  7.55it/s]


Weighted F1 Score: 0.2451516719013837
Classification Report:
                     precision    recall  f1-score   support

           compound       0.06      0.13      0.08       694
         expression       0.00      0.00      0.00       722
         inhibition       0.12      0.00      0.01      1745
binding/association       0.01      0.00      0.00       458
    phosphorylation       0.00      0.00      0.00      1024
        no_relation       0.34      0.74      0.47      6143
         activation       0.49      0.13      0.20      7342

           accuracy                           0.31     18128
          macro avg       0.15      0.14      0.11     18128
       weighted avg       0.33      0.31      0.25     18128

Validation Cross-Entropy Loss: 1.8346104423205059
Epoch 11, Train Loss: 1.8435, Val F1 Score: 0.2452


Training: 100%|██████████| 100/100 [01:12<00:00,  1.38it/s]
Validating: 100%|██████████| 18/18 [00:02<00:00,  7.56it/s]


Weighted F1 Score: 0.24236224249153626
Classification Report:
                     precision    recall  f1-score   support

           compound       0.03      0.05      0.04       694
         expression       0.00      0.00      0.00       722
         inhibition       0.00      0.00      0.00      1745
binding/association       0.01      0.00      0.00       458
    phosphorylation       0.00      0.00      0.00      1024
        no_relation       0.34      0.78      0.48      6143
         activation       0.48      0.12      0.20      7342

           accuracy                           0.32     18128
          macro avg       0.12      0.14      0.10     18128
       weighted avg       0.31      0.32      0.24     18128

Validation Cross-Entropy Loss: 1.8226220077938504
Epoch 12, Train Loss: 1.8293, Val F1 Score: 0.2424


Training: 100%|██████████| 100/100 [01:12<00:00,  1.38it/s]
Validating: 100%|██████████| 18/18 [00:02<00:00,  7.60it/s]


Weighted F1 Score: 0.2418830277522208
Classification Report:
                     precision    recall  f1-score   support

           compound       0.04      0.05      0.05       694
         expression       0.00      0.00      0.00       722
         inhibition       0.00      0.00      0.00      1745
binding/association       0.01      0.00      0.00       458
    phosphorylation       0.00      0.00      0.00      1024
        no_relation       0.34      0.81      0.48      6143
         activation       0.47      0.12      0.19      7342

           accuracy                           0.32     18128
          macro avg       0.12      0.14      0.10     18128
       weighted avg       0.31      0.32      0.24     18128

Validation Cross-Entropy Loss: 1.8111601803037856
Epoch 13, Train Loss: 1.8156, Val F1 Score: 0.2419


Training:  11%|█         | 11/100 [00:08<01:08,  1.30it/s]


KeyboardInterrupt: 