In [1]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [2]:
import networkx as nx
import csv

def build_multidigraph_from_csv(csv_file):
    G = nx.MultiDiGraph()

    with open(csv_file, 'r') as file:
        reader = csv.DictReader(file)
        for row in reader:
            # Add nodes (if not already added)
            G.add_node(row['starter_ID'], name=row['starter_ID'])
            G.add_node(row['receiver_ID'], name=row['receiver_ID'])

            # Add directed edges with additional attributes
            # Each edge is unique and can represent a different type of interaction
            G.add_edge(
                row['starter_ID'], 
                row['receiver_ID'], 
                interaction_type=row['subtype_name'],
                relation_type=row['relation_type'],
                pathway_sources=row['pathway_source'],
                credibility=row['credibility']
            )

    return G

# Path to the CSV file
csv_file_path = 'relations_train_final.csv'  # Update this to the path of your relations_train.csv file

# Build the multidigraph
MDG = build_multidigraph_from_csv(csv_file_path)

# Print basic information about the multidigraph
print(f"Number of nodes: {MDG.number_of_nodes()}")
print(f"Number of edges: {MDG.number_of_edges()}")

# The multidigraph `MDG` is now constructed with nodes and multiple types of directed edges from the relations_train.csv
# You can now use `MDG` for further analysis or as input to clustering algorithms.



Number of nodes: 4816
Number of edges: 101782


In [3]:
import pandas as pd

def read_cluster_assignments(csv_file):
    df = pd.read_csv(csv_file)
    gene_to_cluster = {row['Gene']: row['Cluster'] for _, row in df.iterrows()}
    return gene_to_cluster

# Use this function to read your CSV file
gene_to_cluster = read_cluster_assignments('vgae_kmeans_gene_cluster_assignments.csv')


In [4]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import TransformerConv, GATConv
from torch_geometric.data import Data
from torch_geometric.utils import from_networkx


class GraphTransformer(torch.nn.Module):
    def __init__(self, num_nodes, embedding_dim, num_classes, cluster_feature_dim=8, num_hidden_units=32, num_heads=5):
        super(GraphTransformer, self).__init__()
        self.node_emb = torch.nn.Embedding(num_nodes, embedding_dim)
        self.cluster_emb = torch.nn.Embedding(num_clusters, cluster_feature_dim)  # Initialize cluster embeddings

        # Freeze the cluster embeddings
        self.cluster_emb.weight.requires_grad = False

        # First Graph Transformer layer
        self.conv1 = TransformerConv(embedding_dim + cluster_feature_dim, num_hidden_units, heads=num_heads, dropout=0.6, edge_dim=None)
        
        # Output layer
        self.conv2 = TransformerConv(num_hidden_units * num_heads, num_classes, heads=1, concat=True, dropout=0.6, edge_dim=None)

    def forward(self, data, cluster_indices):
        x = self.node_emb(data.node_index)
        cluster_x = self.cluster_emb(cluster_indices)  # Embedding for cluster
        x = torch.cat([x, cluster_x], dim=1)  # Concatenate node and cluster embeddings

        edge_index = data.edge_index
        edge_weight = None  # Update if you have edge weights

        x = F.dropout(x, p=0.6, training=self.training)
        x = F.elu(self.conv1(x, edge_index, edge_weight))
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index, edge_weight)

        return F.log_softmax(x, dim=1)


# Assuming MDG is your MultiDiGraph
train_data = from_networkx(MDG)

num_nodes = train_data.num_nodes
embedding_dim = 32  # Choose an appropriate embedding size
num_classes = 7   # Set the number of classes based on your edge types
num_clusters = 100

In [5]:
# Collect all unique interaction types from MDG
interaction_types = set()
for _, _, edge_data in MDG.edges(data=True):
    interaction_types.add(edge_data['interaction_type'])

# Update the mapping to include 'no interaction' class
interaction_type_to_label = {inter_type: i for i, inter_type in enumerate(interaction_types)}
num_classes = len(interaction_type_to_label)

def setup_edge_labels_with_no_interaction(MDG, interaction_type_to_label, data):
    num_nodes = len(MDG.nodes())
    edge_labels = torch.zeros((num_nodes, num_nodes), dtype=torch.long)

    for u, v, edge_data in MDG.edges(data=True):
        u_index = list(MDG.nodes()).index(u)
        v_index = list(MDG.nodes()).index(v)
        label = interaction_type_to_label[edge_data['interaction_type']]
        edge_labels[u_index, v_index] = label

    data.edge_label = edge_labels  # Directly modify the data object

In [6]:
def map_nodes_to_clusters(G, gene_to_cluster):
    # Assuming gene_to_cluster is the dictionary from gene to cluster ID
    node_to_cluster = {}
    for node in G.nodes():
        cluster_id = gene_to_cluster.get(node, 99)  # default_cluster_id can be set to a specific value
        node_to_cluster[node] = cluster_id

    return node_to_cluster

node_to_cluster = map_nodes_to_clusters(MDG, gene_to_cluster)

In [7]:
import torch.optim as optim

# Process training data
val_MDG = build_multidigraph_from_csv('cleaned_relations_val_final.csv')

# Create a global node to index mapping
all_nodes = set(MDG.nodes()).union(set(val_MDG.nodes()))
global_node_to_index = {node: idx for idx, node in enumerate(all_nodes)}

# Function to map nodes in a graph to global indices
def map_nodes_to_global_indices(G, global_node_to_index):
    return [global_node_to_index[node] for node in G.nodes()]

val_data = from_networkx(val_MDG)

train_data.node_index = torch.tensor(map_nodes_to_global_indices(MDG, global_node_to_index), dtype=torch.long)

cluster_indices = [node_to_cluster[node] for node in MDG.nodes()]
cluster_indices_tensor = torch.tensor(cluster_indices, dtype=torch.long)

val_data.node_index = torch.tensor(map_nodes_to_global_indices(val_MDG, global_node_to_index), dtype=torch.long)
setup_edge_labels_with_no_interaction(val_MDG, interaction_type_to_label, val_data)
setup_edge_labels_with_no_interaction(MDG, interaction_type_to_label, train_data)
val_node_to_cluster = map_nodes_to_clusters(val_MDG, gene_to_cluster)

# Convert the node-cluster mapping to a tensor for the validation set
val_cluster_indices = [val_node_to_cluster[node] for node in val_MDG.nodes()]
val_cluster_indices_tensor = torch.tensor(val_cluster_indices, dtype=torch.long)


model = GraphTransformer(num_nodes, embedding_dim, num_classes)
optimizer = optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

In [8]:
import torch.optim as optim
import torch.nn.functional as F
from torch_geometric.loader import DataLoader
from tqdm import tqdm

# Create a dataset of edges and their labels
edges = train_data.edge_index.t().tolist()  # List of [node_u, node_v]
labels = [train_data.edge_label[edge[0], edge[1]].item() for edge in edges]

edge_dataset = list(zip(edges, labels))
edge_loader = DataLoader(edge_dataset, batch_size=1024, shuffle=True)  # Adjust batch_size as needed

# Assuming val_data is structured similarly to train_data
val_edges = val_data.edge_index.t().tolist()
val_labels = [val_data.edge_label[edge[0], edge[1]].item() for edge in val_edges]

val_edge_dataset = list(zip(val_edges, val_labels))
val_edge_loader = DataLoader(val_edge_dataset, batch_size=1024, shuffle=False)  # You can adjust the batch size

from sklearn.metrics import f1_score, classification_report

inverse_interaction_type_to_label = {v: k for k, v in interaction_type_to_label.items()}

def validate(val_data, val_cluster_indices_tensor, device, model, criterion):
    model.eval()
    total_cross_entropy_loss = 0  # Initialize cross-entropy loss
    all_predictions = []
    all_true_labels = []

    with torch.no_grad():
        for batch in tqdm(val_edge_loader, desc='Validating'):
            edge_tensors, label_batch = batch
            node_u_list, node_v_list = edge_tensors[0].to(device), edge_tensors[1].to(device)
            label_batch = label_batch.to(device)

            # Pass both val_data and val_cluster_indices_tensor to the model
            output = model(val_data.to(device), val_cluster_indices_tensor.to(device))
            edge_predictions = output[node_u_list].argmax(dim=1)  # Predicted classes

            # Calculate and accumulate cross-entropy loss
            cross_entropy_loss = criterion(output[node_u_list], label_batch)
            total_cross_entropy_loss += cross_entropy_loss.item()

            # Store predictions and true labels
            preds = edge_predictions.cpu().numpy()
            true_labels = label_batch.cpu().numpy()
            all_predictions.extend(preds)
            all_true_labels.extend(true_labels)

    # Calculate average cross-entropy loss over all batches
    avg_cross_entropy_loss = total_cross_entropy_loss / len(val_edge_loader)

    # Calculate F1 Score and classification report
    weighted_f1 = f1_score(all_true_labels, all_predictions, average='weighted')
    class_report = classification_report(all_true_labels, all_predictions, target_names=[inverse_interaction_type_to_label[i] for i in range(num_classes)], zero_division=0)

    print(f"Weighted F1 Score: {weighted_f1}\nClassification Report:\n{class_report}")
    print(f"Validation Cross-Entropy Loss: {avg_cross_entropy_loss}")

    return weighted_f1

In [9]:
# Instantiate the model
model = GraphTransformer(num_nodes, embedding_dim, num_classes)
optimizer = optim.Adam(model.parameters(), lr=0.0001)  # Learning rate
criterion = torch.nn.CrossEntropyLoss()  # Loss function
# Check if CUDA (GPU support) is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Move your model to the chosen device
model = model.to(device)

def train(train_data, cluster_indices_tensor, model, optimizer, criterion, edge_loader, device):
    model.train()
    total_loss = 0

    for batch in tqdm(edge_loader, desc='Training'):
        edge_tensors, label_batch = batch
        
        node_u_list, node_v_list = edge_tensors[0].to(device), edge_tensors[1].to(device)
        label_batch = label_batch.to(device)

        optimizer.zero_grad()

        # Pass both train_data and cluster_indices_tensor to the model
        output = model(train_data.to(device), cluster_indices_tensor.to(device))

        edge_predictions = output[node_u_list]  # Corrected indexing
        
        loss = criterion(edge_predictions, label_batch)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(edge_loader)

Using device: cuda


In [10]:
# Early stopping and scheduler parameters
patience = 5
best_val_loss = float('inf')
epochs_no_improve = 0

scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5)

best_f1_score = 0
best_model_state = None

for epoch in range(1000):
    # Call train function with necessary parameters
    train_loss = train(train_data, cluster_indices_tensor, model, optimizer, criterion, edge_loader, device)

    # Call validate function with necessary parameters
    val_f1_score = validate(val_data, val_cluster_indices_tensor, device, model, criterion)

    print(f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Val F1 Score: {val_f1_score:.4f}')

    # Step the scheduler based on validation F1 score
    scheduler.step(val_f1_score)

    # Check for improvement in validation F1 score
    if val_f1_score > best_f1_score:
        best_f1_score = val_f1_score
        epochs_no_improve = 0
        best_model_state = model.state_dict()  # Save the best model state
    else:
        epochs_no_improve += 1

    # Early stopping
    if epochs_no_improve >= patience:
        print("Early stopping triggered")
        break

# Save the best model state
if best_model_state is not None:
    torch.save(best_model_state, 'new_GT_Freeze_Kmeans_32_32_5_best.pth')
    print("Best model saved.")
else:
    print("No model improvement was observed.")

Training: 100%|██████████| 100/100 [01:47<00:00,  1.07s/it]
Validating: 100%|██████████| 18/18 [00:02<00:00,  6.63it/s]


Weighted F1 Score: 0.2852809388645003
Classification Report:
                     precision    recall  f1-score   support

         inhibition       0.00      0.00      0.00      1745
        no_relation       0.34      0.17      0.23      6143
           compound       0.05      0.16      0.07       694
         expression       0.03      0.01      0.01       722
binding/association       0.02      0.07      0.03       458
    phosphorylation       0.04      0.03      0.03      1024
         activation       0.43      0.60      0.50      7342

           accuracy                           0.31     18128
          macro avg       0.13      0.15      0.13     18128
       weighted avg       0.30      0.31      0.29     18128

Validation Cross-Entropy Loss: 1.82952794763777
Epoch 1, Train Loss: 2.0879, Val F1 Score: 0.2853


Training: 100%|██████████| 100/100 [02:18<00:00,  1.39s/it]
Validating: 100%|██████████| 18/18 [00:04<00:00,  3.92it/s]


Weighted F1 Score: 0.3195036404296214
Classification Report:
                     precision    recall  f1-score   support

         inhibition       0.00      0.00      0.00      1745
        no_relation       0.34      0.31      0.33      6143
           compound       0.04      0.08      0.06       694
         expression       0.00      0.00      0.00       722
binding/association       0.03      0.04      0.03       458
    phosphorylation       0.05      0.01      0.02      1024
         activation       0.44      0.60      0.51      7342

           accuracy                           0.35     18128
          macro avg       0.13      0.15      0.13     18128
       weighted avg       0.30      0.35      0.32     18128

Validation Cross-Entropy Loss: 1.7523737682236566
Epoch 2, Train Loss: 1.8893, Val F1 Score: 0.3195


Training: 100%|██████████| 100/100 [02:29<00:00,  1.49s/it]
Validating: 100%|██████████| 18/18 [00:04<00:00,  3.91it/s]


Weighted F1 Score: 0.3362972586028845
Classification Report:
                     precision    recall  f1-score   support

         inhibition       0.00      0.00      0.00      1745
        no_relation       0.34      0.43      0.38      6143
           compound       0.02      0.03      0.03       694
         expression       0.00      0.00      0.00       722
binding/association       0.06      0.04      0.05       458
    phosphorylation       0.00      0.00      0.00      1024
         activation       0.46      0.57      0.51      7342

           accuracy                           0.38     18128
          macro avg       0.13      0.15      0.14     18128
       weighted avg       0.30      0.38      0.34     18128

Validation Cross-Entropy Loss: 1.7197898493872747
Epoch 3, Train Loss: 1.8059, Val F1 Score: 0.3363


Training: 100%|██████████| 100/100 [02:28<00:00,  1.49s/it]
Validating: 100%|██████████| 18/18 [00:04<00:00,  3.84it/s]


Weighted F1 Score: 0.34733644045542766
Classification Report:
                     precision    recall  f1-score   support

         inhibition       0.00      0.00      0.00      1745
        no_relation       0.35      0.46      0.39      6143
           compound       0.05      0.03      0.04       694
         expression       0.00      0.00      0.00       722
binding/association       0.10      0.03      0.05       458
    phosphorylation       0.00      0.00      0.00      1024
         activation       0.47      0.59      0.52      7342

           accuracy                           0.40     18128
          macro avg       0.14      0.16      0.14     18128
       weighted avg       0.31      0.40      0.35     18128

Validation Cross-Entropy Loss: 1.7017865180969238
Epoch 4, Train Loss: 1.7582, Val F1 Score: 0.3473


Training: 100%|██████████| 100/100 [01:32<00:00,  1.08it/s]
Validating: 100%|██████████| 18/18 [00:02<00:00,  7.55it/s]


Weighted F1 Score: 0.35090284845573017
Classification Report:
                     precision    recall  f1-score   support

         inhibition       0.00      0.00      0.00      1745
        no_relation       0.34      0.46      0.39      6143
           compound       0.07      0.03      0.04       694
         expression       0.00      0.00      0.00       722
binding/association       0.00      0.00      0.00       458
    phosphorylation       0.00      0.00      0.00      1024
         activation       0.47      0.61      0.53      7342

           accuracy                           0.41     18128
          macro avg       0.13      0.16      0.14     18128
       weighted avg       0.31      0.41      0.35     18128

Validation Cross-Entropy Loss: 1.6894561383459303
Epoch 5, Train Loss: 1.7338, Val F1 Score: 0.3509


Training: 100%|██████████| 100/100 [01:13<00:00,  1.35it/s]
Validating: 100%|██████████| 18/18 [00:02<00:00,  6.83it/s]


Weighted F1 Score: 0.3491482820507393
Classification Report:
                     precision    recall  f1-score   support

         inhibition       0.00      0.00      0.00      1745
        no_relation       0.34      0.46      0.39      6143
           compound       0.05      0.02      0.03       694
         expression       0.00      0.00      0.00       722
binding/association       0.00      0.00      0.00       458
    phosphorylation       0.00      0.00      0.00      1024
         activation       0.47      0.61      0.53      7342

           accuracy                           0.40     18128
          macro avg       0.12      0.16      0.14     18128
       weighted avg       0.31      0.40      0.35     18128

Validation Cross-Entropy Loss: 1.680410901705424
Epoch 6, Train Loss: 1.7079, Val F1 Score: 0.3491


Training: 100%|██████████| 100/100 [01:13<00:00,  1.36it/s]
Validating: 100%|██████████| 18/18 [00:02<00:00,  7.00it/s]


Weighted F1 Score: 0.34974702368828403
Classification Report:
                     precision    recall  f1-score   support

         inhibition       0.00      0.00      0.00      1745
        no_relation       0.34      0.47      0.39      6143
           compound       0.07      0.02      0.03       694
         expression       0.00      0.00      0.00       722
binding/association       0.00      0.00      0.00       458
    phosphorylation       0.00      0.00      0.00      1024
         activation       0.47      0.61      0.53      7342

           accuracy                           0.41     18128
          macro avg       0.13      0.16      0.14     18128
       weighted avg       0.31      0.41      0.35     18128

Validation Cross-Entropy Loss: 1.6727469960848491
Epoch 7, Train Loss: 1.6929, Val F1 Score: 0.3497


Training: 100%|██████████| 100/100 [01:12<00:00,  1.37it/s]
Validating: 100%|██████████| 18/18 [00:02<00:00,  7.23it/s]


Weighted F1 Score: 0.35098125551107806
Classification Report:
                     precision    recall  f1-score   support

         inhibition       0.00      0.00      0.00      1745
        no_relation       0.34      0.46      0.39      6143
           compound       0.07      0.02      0.03       694
         expression       0.00      0.00      0.00       722
binding/association       0.00      0.00      0.00       458
    phosphorylation       0.00      0.00      0.00      1024
         activation       0.47      0.63      0.54      7342

           accuracy                           0.41     18128
          macro avg       0.13      0.16      0.14     18128
       weighted avg       0.31      0.41      0.35     18128

Validation Cross-Entropy Loss: 1.6629633969730802
Epoch 8, Train Loss: 1.6758, Val F1 Score: 0.3510


Training: 100%|██████████| 100/100 [01:14<00:00,  1.34it/s]
Validating: 100%|██████████| 18/18 [00:02<00:00,  7.17it/s]


Weighted F1 Score: 0.3522964724586283
Classification Report:
                     precision    recall  f1-score   support

         inhibition       0.00      0.00      0.00      1745
        no_relation       0.34      0.45      0.39      6143
           compound       0.09      0.02      0.03       694
         expression       0.00      0.00      0.00       722
binding/association       0.00      0.00      0.00       458
    phosphorylation       0.00      0.00      0.00      1024
         activation       0.47      0.64      0.54      7342

           accuracy                           0.41     18128
          macro avg       0.13      0.16      0.14     18128
       weighted avg       0.31      0.41      0.35     18128

Validation Cross-Entropy Loss: 1.655489398373498
Epoch 9, Train Loss: 1.6580, Val F1 Score: 0.3523


Training: 100%|██████████| 100/100 [01:14<00:00,  1.35it/s]
Validating: 100%|██████████| 18/18 [00:02<00:00,  7.33it/s]


Weighted F1 Score: 0.3521901120691977
Classification Report:
                     precision    recall  f1-score   support

         inhibition       0.00      0.00      0.00      1745
        no_relation       0.34      0.46      0.39      6143
           compound       0.09      0.02      0.03       694
         expression       0.00      0.00      0.00       722
binding/association       0.00      0.00      0.00       458
    phosphorylation       0.00      0.00      0.00      1024
         activation       0.47      0.62      0.54      7342

           accuracy                           0.41     18128
          macro avg       0.13      0.16      0.14     18128
       weighted avg       0.31      0.41      0.35     18128

Validation Cross-Entropy Loss: 1.650462223423852
Epoch 10, Train Loss: 1.6450, Val F1 Score: 0.3522


Training: 100%|██████████| 100/100 [01:13<00:00,  1.37it/s]
Validating: 100%|██████████| 18/18 [00:02<00:00,  7.68it/s]


Weighted F1 Score: 0.35680455021488106
Classification Report:
                     precision    recall  f1-score   support

         inhibition       0.00      0.00      0.00      1745
        no_relation       0.34      0.47      0.39      6143
           compound       0.10      0.02      0.03       694
         expression       0.00      0.00      0.00       722
binding/association       0.00      0.00      0.00       458
    phosphorylation       0.31      0.05      0.09      1024
         activation       0.48      0.61      0.54      7342

           accuracy                           0.41     18128
          macro avg       0.18      0.16      0.15     18128
       weighted avg       0.33      0.41      0.36     18128

Validation Cross-Entropy Loss: 1.641876094871097
Epoch 11, Train Loss: 1.6338, Val F1 Score: 0.3568


Training: 100%|██████████| 100/100 [01:12<00:00,  1.38it/s]
Validating: 100%|██████████| 18/18 [00:02<00:00,  8.11it/s]


Weighted F1 Score: 0.3523364918950407
Classification Report:
                     precision    recall  f1-score   support

         inhibition       0.00      0.00      0.00      1745
        no_relation       0.34      0.41      0.38      6143
           compound       0.06      0.02      0.03       694
         expression       0.00      0.00      0.00       722
binding/association       0.00      0.00      0.00       458
    phosphorylation       0.29      0.05      0.09      1024
         activation       0.46      0.65      0.54      7342

           accuracy                           0.41     18128
          macro avg       0.16      0.16      0.15     18128
       weighted avg       0.32      0.41      0.35     18128

Validation Cross-Entropy Loss: 1.6338738467958238
Epoch 12, Train Loss: 1.6231, Val F1 Score: 0.3523


Training: 100%|██████████| 100/100 [01:13<00:00,  1.36it/s]
Validating: 100%|██████████| 18/18 [00:02<00:00,  7.87it/s]


Weighted F1 Score: 0.3557224697191228
Classification Report:
                     precision    recall  f1-score   support

         inhibition       0.09      0.00      0.00      1745
        no_relation       0.34      0.42      0.38      6143
           compound       0.16      0.02      0.03       694
         expression       0.00      0.00      0.00       722
binding/association       0.00      0.00      0.00       458
    phosphorylation       0.31      0.11      0.16      1024
         activation       0.46      0.64      0.54      7342

           accuracy                           0.41     18128
          macro avg       0.20      0.17      0.16     18128
       weighted avg       0.34      0.41      0.36     18128

Validation Cross-Entropy Loss: 1.6282564004262288
Epoch 13, Train Loss: 1.6133, Val F1 Score: 0.3557


Training: 100%|██████████| 100/100 [01:13<00:00,  1.37it/s]
Validating: 100%|██████████| 18/18 [00:02<00:00,  7.84it/s]


Weighted F1 Score: 0.35240440123800376
Classification Report:
                     precision    recall  f1-score   support

         inhibition       0.09      0.00      0.00      1745
        no_relation       0.34      0.45      0.39      6143
           compound       0.14      0.02      0.03       694
         expression       0.00      0.00      0.00       722
binding/association       0.00      0.00      0.00       458
    phosphorylation       0.31      0.11      0.16      1024
         activation       0.46      0.60      0.52      7342

           accuracy                           0.40     18128
          macro avg       0.19      0.17      0.16     18128
       weighted avg       0.33      0.40      0.35     18128

Validation Cross-Entropy Loss: 1.6234848499298096
Epoch 14, Train Loss: 1.6042, Val F1 Score: 0.3524


Training: 100%|██████████| 100/100 [01:13<00:00,  1.35it/s]
Validating: 100%|██████████| 18/18 [00:02<00:00,  7.87it/s]


Weighted F1 Score: 0.3527075435294619
Classification Report:
                     precision    recall  f1-score   support

         inhibition       0.09      0.00      0.00      1745
        no_relation       0.34      0.41      0.37      6143
           compound       0.14      0.02      0.03       694
         expression       0.00      0.00      0.00       722
binding/association       0.00      0.00      0.00       458
    phosphorylation       0.31      0.11      0.16      1024
         activation       0.45      0.64      0.53      7342

           accuracy                           0.41     18128
          macro avg       0.19      0.17      0.16     18128
       weighted avg       0.33      0.41      0.35     18128

Validation Cross-Entropy Loss: 1.6152955624792311
Epoch 15, Train Loss: 1.5960, Val F1 Score: 0.3527


Training: 100%|██████████| 100/100 [01:13<00:00,  1.37it/s]
Validating: 100%|██████████| 18/18 [00:02<00:00,  7.93it/s]

Weighted F1 Score: 0.3515829949280002
Classification Report:
                     precision    recall  f1-score   support

         inhibition       0.09      0.00      0.00      1745
        no_relation       0.35      0.38      0.36      6143
           compound       0.14      0.02      0.03       694
         expression       0.00      0.00      0.00       722
binding/association       0.00      0.00      0.00       458
    phosphorylation       0.30      0.11      0.16      1024
         activation       0.45      0.67      0.54      7342

           accuracy                           0.41     18128
          macro avg       0.19      0.17      0.16     18128
       weighted avg       0.33      0.41      0.35     18128

Validation Cross-Entropy Loss: 1.6086680955357022
Epoch 16, Train Loss: 1.5871, Val F1 Score: 0.3516
Early stopping triggered
Best model saved.



