In [1]:
import networkx as nx
import csv

def build_multidigraph_from_csv(csv_file):
    G = nx.MultiDiGraph()

    with open(csv_file, 'r') as file:
        reader = csv.DictReader(file)
        for row in reader:
            # Add nodes (if not already added)
            G.add_node(row['starter_ID'], name=row['starter_ID'])
            G.add_node(row['receiver_ID'], name=row['receiver_ID'])

            # Add directed edges with additional attributes
            # Each edge is unique and can represent a different type of interaction
            G.add_edge(
                row['starter_ID'], 
                row['receiver_ID'], 
                interaction_type=row['subtype_name'],
                relation_type=row['relation_type'],
                pathway_sources=row['pathway_source'],
                credibility=row['credibility']
            )

    return G

# Path to the CSV file
csv_file_path = 'relations_train_final.csv'  # Update this to the path of your relations_train.csv file

# Build the multidigraph
MDG = build_multidigraph_from_csv(csv_file_path)

# Print basic information about the multidigraph
print(f"Number of nodes: {MDG.number_of_nodes()}")
print(f"Number of edges: {MDG.number_of_edges()}")

# The multidigraph `MDG` is now constructed with nodes and multiple types of directed edges from the relations_train.csv
# You can now use `MDG` for further analysis or as input to clustering algorithms.



Number of nodes: 4816
Number of edges: 101782


In [2]:
import pandas as pd

def read_cluster_assignments(csv_file):
    df = pd.read_csv(csv_file)
    gene_to_cluster = {row['Gene']: row['Cluster'] for _, row in df.iterrows()}
    return gene_to_cluster

# Use this function to read your CSV file
gene_to_cluster = read_cluster_assignments('vgae_spectral_gene_cluster_assignments.csv')


In [3]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import TransformerConv, GATConv
from torch_geometric.data import Data
from torch_geometric.utils import from_networkx


class GraphTransformer(torch.nn.Module):
    def __init__(self, num_nodes, embedding_dim, num_classes, num_hidden_units=32, num_heads=5):
        super(GraphTransformer, self).__init__()
        self.node_emb = torch.nn.Embedding(num_nodes, embedding_dim)
        
        # First Graph Transformer layer
        self.conv1 = TransformerConv(embedding_dim, num_hidden_units, heads=num_heads, dropout=0.6, edge_dim=None)
        
        # Output layer
        self.conv2 = TransformerConv(num_hidden_units * num_heads, num_classes, heads=1, concat=True, dropout=0.6, edge_dim=None)

    def forward(self, data):
        x = self.node_emb(data.node_index)
        edge_index = data.edge_index
        edge_weight = None  # If you have edge weights, they should be used here

        x = F.dropout(x, p=0.6, training=self.training)
        x = F.elu(self.conv1(x, edge_index, edge_weight))
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index, edge_weight)

        return F.log_softmax(x, dim=1)


# Assuming MDG is your MultiDiGraph
train_data = from_networkx(MDG)
train_data.node_index = torch.arange(train_data.num_nodes)  # Node indexing

num_nodes = train_data.num_nodes
embedding_dim = 32  # Choose an appropriate embedding size
num_classes = 7   # Set the number of classes based on your edge types


In [4]:
# Collect all unique interaction types from MDG
interaction_types = set()
for _, _, edge_data in MDG.edges(data=True):
    interaction_types.add(edge_data['interaction_type'])

# Update the mapping to include 'no interaction' class
interaction_type_to_label = {inter_type: i for i, inter_type in enumerate(interaction_types)}
num_classes = len(interaction_type_to_label)

def setup_edge_labels_with_no_interaction(MDG, interaction_type_to_label, data):
    num_nodes = len(MDG.nodes())
    edge_labels = torch.zeros((num_nodes, num_nodes), dtype=torch.long)

    for u, v, edge_data in MDG.edges(data=True):
        u_index = list(MDG.nodes()).index(u)
        v_index = list(MDG.nodes()).index(v)
        label = interaction_type_to_label[edge_data['interaction_type']]
        edge_labels[u_index, v_index] = label

    data.edge_label = edge_labels  # Directly modify the data object

In [5]:
import torch.optim as optim

# Process training data
train_MDG = build_multidigraph_from_csv('relations_train_final.csv')
train_data = from_networkx(train_MDG)
train_data.node_index = torch.arange(train_data.num_nodes)
setup_edge_labels_with_no_interaction(train_MDG, interaction_type_to_label, train_data)

# Process testidation data
val_MDG = build_multidigraph_from_csv('cleaned_relations_val_final.csv')
val_data = from_networkx(val_MDG)
val_data.node_index = torch.arange(val_data.num_nodes)
setup_edge_labels_with_no_interaction(val_MDG, interaction_type_to_label, val_data)

model = GraphTransformer(num_nodes, embedding_dim, num_classes)
optimizer = optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

In [6]:
test_MDG = build_multidigraph_from_csv('cleaned_relations_test_final.csv')
test_data = from_networkx(test_MDG)
test_data.node_index = torch.arange(test_data.num_nodes)
setup_edge_labels_with_no_interaction(test_MDG, interaction_type_to_label, test_data)

In [7]:
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from tqdm import tqdm

# Create a dataset of edges and their labels
edges = train_data.edge_index.t().tolist()  # List of [node_u, node_v]
labels = [train_data.edge_label[edge[0], edge[1]].item() for edge in edges]

edge_dataset = list(zip(edges, labels))
edge_loader = DataLoader(edge_dataset, batch_size=1024, shuffle=True)  # Adjust batch_size as needed

# Assuming val_data is structured similarly to train_data
val_edges = val_data.edge_index.t().tolist()
val_labels = [val_data.edge_label[edge[0], edge[1]].item() for edge in val_edges]

val_edge_dataset = list(zip(val_edges, val_labels))
val_edge_loader = DataLoader(val_edge_dataset, batch_size=1024, shuffle=False)  # You can adjust the batch size

test_edges = test_data.edge_index.t().tolist()
test_labels = [test_data.edge_label[edge[0], edge[1]].item() for edge in test_edges]

test_edge_dataset = list(zip(test_edges, test_labels))
test_edge_loader = DataLoader(test_edge_dataset, batch_size=1024, shuffle=False)  # You can adjust the batch size

from sklearn.metrics import f1_score, classification_report

inverse_interaction_type_to_label = {v: k for k, v in interaction_type_to_label.items()}

def validate():
    model.eval()
    total_loss = 0
    total_cross_entropy_loss = 0  # Initialize cross-entropy loss
    all_predictions = []
    all_true_labels = []

    with torch.no_grad():
        for batch in tqdm(val_edge_loader, desc='Validating'):
            edge_tensors, label_batch = batch
            node_u_list, node_v_list = edge_tensors[0].to(device), edge_tensors[1].to(device)
            label_batch = label_batch.to(device)

            output = model(val_data.to(device))
            edge_predictions = output[node_u_list].argmax(dim=1)  # Predicted classes

            # Calculate and accumulate cross-entropy loss
            cross_entropy_loss = criterion(output[node_u_list], label_batch)
            total_cross_entropy_loss += cross_entropy_loss.item()

            # Store predictions and true labels
            preds = edge_predictions.cpu().numpy()
            true_labels = label_batch.cpu().numpy()
            all_predictions.extend(preds)
            all_true_labels.extend(true_labels)

    # Calculate average cross-entropy loss over all batches
    avg_cross_entropy_loss = total_cross_entropy_loss / len(val_edge_loader)

    # Calculate F1 Score and classification report
    weighted_f1 = f1_score(all_true_labels, all_predictions, average='weighted')
    class_report = classification_report(all_true_labels, all_predictions, target_names=[inverse_interaction_type_to_label[i] for i in range(num_classes)], zero_division=0)

    print(f"Weighted F1 Score: {weighted_f1}\nClassification Report:\n{class_report}")
    print(f"Validation Cross-Entropy Loss: {avg_cross_entropy_loss}")

    return weighted_f1


In [8]:
# Instantiate the model
model = GraphTransformer(num_nodes, embedding_dim, num_classes)
optimizer = optim.Adam(model.parameters(), lr=0.0001)  # Learning rate
criterion = torch.nn.CrossEntropyLoss()  # Loss function
# Check if CUDA (GPU support) is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Move your model to the chosen device
model = model.to(device)

def train():
    model.train()
    total_loss = 0

    for batch in tqdm(edge_loader, desc='Training'):
        edge_tensors, label_batch = batch
        
        node_u_list, node_v_list = edge_tensors[0].to(device), edge_tensors[1].to(device)
        label_batch = label_batch.to(device)

        optimizer.zero_grad()
        output = model(train_data.to(device))

        edge_predictions = output[node_u_list]  # Corrected indexing
        
        loss = criterion(edge_predictions, label_batch)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(edge_loader)

Using device: cuda


In [9]:
# Early stopping parameters
patience = 5  # Number of epochs to wait for improvement before stopping
best_val_loss = float('inf')
epochs_no_improve = 0

# Set up the learning rate scheduler
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5)

best_f1_score = 0
epochs_no_improve = 0
best_model_state = None

for epoch in range(1000):
    train_loss = train()  # Your training function
    val_f1_score = validate()  # Your validation function returning F1 score

    print(f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val F1 Score: {val_f1_score:.4f}')

    # Step the scheduler with the F1 score
    scheduler.step(val_f1_score)

    # Check for improvement
    if val_f1_score > best_f1_score:
        best_f1_score = val_f1_score
        epochs_no_improve = 0
        best_model_state = model.state_dict()  # Update the best model state
    else:
        epochs_no_improve += 1

    if epochs_no_improve >= patience:
        print("Early stopping triggered")
        break

# Save the best model state (outside of the loop)
if best_model_state is not None:
    torch.save(best_model_state, 'GT_32_32_5_-3.pth')
    print("Best model saved.")
else:
    print("No model improvement was observed.")

Training: 100%|██████████| 100/100 [01:10<00:00,  1.42it/s]
Validating: 100%|██████████| 18/18 [00:02<00:00,  7.80it/s]


Weighted F1 Score: 0.24081056742507587
Classification Report:
                     precision    recall  f1-score   support

        no_relation       0.34      0.44      0.38      6143
         expression       0.01      0.00      0.01       722
    phosphorylation       0.00      0.00      0.00      1024
         inhibition       0.00      0.00      0.00      1745
           compound       0.03      0.21      0.05       694
binding/association       0.02      0.01      0.01       458
         activation       0.35      0.22      0.27      7342

           accuracy                           0.25     18128
          macro avg       0.11      0.13      0.10     18128
       weighted avg       0.26      0.25      0.24     18128

Validation Cross-Entropy Loss: 1.8602794607480366
Epoch 1, Train Loss: 2.0141, Val F1 Score: 0.2408


Training: 100%|██████████| 100/100 [01:09<00:00,  1.44it/s]
Validating: 100%|██████████| 18/18 [00:02<00:00,  8.44it/s]


Weighted F1 Score: 0.2394147642278179
Classification Report:
                     precision    recall  f1-score   support

        no_relation       0.34      0.58      0.43      6143
         expression       0.00      0.00      0.00       722
    phosphorylation       0.00      0.00      0.00      1024
         inhibition       0.00      0.00      0.00      1745
           compound       0.02      0.11      0.04       694
binding/association       0.00      0.00      0.00       458
         activation       0.33      0.18      0.23      7342

           accuracy                           0.27     18128
          macro avg       0.10      0.12      0.10     18128
       weighted avg       0.25      0.27      0.24     18128

Validation Cross-Entropy Loss: 1.7992704576916165
Epoch 2, Train Loss: 1.8647, Val F1 Score: 0.2394


Training: 100%|██████████| 100/100 [01:09<00:00,  1.43it/s]
Validating: 100%|██████████| 18/18 [00:02<00:00,  7.70it/s]


Weighted F1 Score: 0.25689393428596985
Classification Report:
                     precision    recall  f1-score   support

        no_relation       0.34      0.67      0.45      6143
         expression       0.00      0.00      0.00       722
    phosphorylation       0.00      0.00      0.00      1024
         inhibition       0.00      0.00      0.00      1745
           compound       0.02      0.06      0.03       694
binding/association       0.00      0.00      0.00       458
         activation       0.36      0.20      0.25      7342

           accuracy                           0.31     18128
          macro avg       0.10      0.13      0.11     18128
       weighted avg       0.26      0.31      0.26     18128

Validation Cross-Entropy Loss: 1.7694940699471369
Epoch 3, Train Loss: 1.8037, Val F1 Score: 0.2569


Training: 100%|██████████| 100/100 [01:10<00:00,  1.41it/s]
Validating: 100%|██████████| 18/18 [00:02<00:00,  7.84it/s]


Weighted F1 Score: 0.25122254656789195
Classification Report:
                     precision    recall  f1-score   support

        no_relation       0.34      0.70      0.46      6143
         expression       0.00      0.00      0.00       722
    phosphorylation       0.00      0.00      0.00      1024
         inhibition       0.00      0.00      0.00      1745
           compound       0.02      0.05      0.03       694
binding/association       0.00      0.00      0.00       458
         activation       0.35      0.18      0.23      7342

           accuracy                           0.31     18128
          macro avg       0.10      0.13      0.10     18128
       weighted avg       0.26      0.31      0.25     18128

Validation Cross-Entropy Loss: 1.7569267021285162
Epoch 4, Train Loss: 1.7608, Val F1 Score: 0.2512


Training: 100%|██████████| 100/100 [01:10<00:00,  1.42it/s]
Validating: 100%|██████████| 18/18 [00:02<00:00,  7.70it/s]


Weighted F1 Score: 0.24801680817634586
Classification Report:
                     precision    recall  f1-score   support

        no_relation       0.34      0.76      0.47      6143
         expression       0.00      0.00      0.00       722
    phosphorylation       0.00      0.00      0.00      1024
         inhibition       0.00      0.00      0.00      1745
           compound       0.03      0.05      0.03       694
binding/association       0.00      0.00      0.00       458
         activation       0.36      0.15      0.22      7342

           accuracy                           0.32     18128
          macro avg       0.10      0.14      0.10     18128
       weighted avg       0.26      0.32      0.25     18128

Validation Cross-Entropy Loss: 1.7473703490363226
Epoch 5, Train Loss: 1.7320, Val F1 Score: 0.2480


Training: 100%|██████████| 100/100 [01:10<00:00,  1.42it/s]
Validating: 100%|██████████| 18/18 [00:02<00:00,  8.14it/s]


Weighted F1 Score: 0.24603719627419549
Classification Report:
                     precision    recall  f1-score   support

        no_relation       0.34      0.79      0.47      6143
         expression       0.00      0.00      0.00       722
    phosphorylation       0.00      0.00      0.00      1024
         inhibition       0.00      0.00      0.00      1745
           compound       0.02      0.03      0.02       694
binding/association       0.00      0.00      0.00       458
         activation       0.37      0.14      0.21      7342

           accuracy                           0.33     18128
          macro avg       0.10      0.14      0.10     18128
       weighted avg       0.27      0.33      0.25     18128

Validation Cross-Entropy Loss: 1.7394248247146606
Epoch 6, Train Loss: 1.7109, Val F1 Score: 0.2460


Training: 100%|██████████| 100/100 [01:10<00:00,  1.43it/s]
Validating: 100%|██████████| 18/18 [00:02<00:00,  7.94it/s]


Weighted F1 Score: 0.24642270263564023
Classification Report:
                     precision    recall  f1-score   support

        no_relation       0.34      0.80      0.48      6143
         expression       0.00      0.00      0.00       722
    phosphorylation       0.00      0.00      0.00      1024
         inhibition       0.00      0.00      0.00      1745
           compound       0.01      0.01      0.01       694
binding/association       0.00      0.00      0.00       458
         activation       0.37      0.14      0.21      7342

           accuracy                           0.33     18128
          macro avg       0.10      0.14      0.10     18128
       weighted avg       0.27      0.33      0.25     18128

Validation Cross-Entropy Loss: 1.7327967882156372
Epoch 7, Train Loss: 1.6962, Val F1 Score: 0.2464


Training: 100%|██████████| 100/100 [01:10<00:00,  1.42it/s]
Validating: 100%|██████████| 18/18 [00:02<00:00,  8.21it/s]

Weighted F1 Score: 0.2496369349263342
Classification Report:
                     precision    recall  f1-score   support

        no_relation       0.34      0.83      0.48      6143
         expression       0.00      0.00      0.00       722
    phosphorylation       0.00      0.00      0.00      1024
         inhibition       0.00      0.00      0.00      1745
           compound       0.01      0.01      0.01       694
binding/association       0.00      0.00      0.00       458
         activation       0.40      0.14      0.21      7342

           accuracy                           0.34     18128
          macro avg       0.11      0.14      0.10     18128
       weighted avg       0.28      0.34      0.25     18128

Validation Cross-Entropy Loss: 1.7293875416119893
Epoch 8, Train Loss: 1.6803, Val F1 Score: 0.2496
Early stopping triggered
Best model saved.



