In [None]:
# === Import Libraries ===
import numpy as np
import pandas as pd
import torch
# Evaluation metrics from scikit-learn
from sklearn.metrics import roc_auc_score, f1_score, precision_score, accuracy_score, average_precision_score
# PyTorch Geometric modules for graph data handling
from torch_geometric.data import Data
from torch_geometric.nn import SAGEConv
# PyTorch modules for model layers and loss functions
from torch.nn import BCEWithLogitsLoss, Linear, ModuleDict, LeakyReLU
import torch.nn.functional as F
# Preprocessing and utility libraries
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
import time

In [None]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# %% === Reproducibility: Set Random Seed for Determinism ===
seed = 42
# Seed every RNG the notebook draws from (numpy, torch CPU, all CUDA devices)
for seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    seed_fn(seed)
torch.backends.cudnn.deterministic = True  # force deterministic cuDNN kernels
torch.backends.cudnn.benchmark = False  # no kernel autotuning, which may select different kernels run-to-run

In [None]:
# %% === Load Node and Edge Data ===
tsv_options = {"sep": "\t"}  # both inputs are tab-separated
nodes = pd.read_csv("nodes.tsv", **tsv_options)  # node metadata
edges = pd.read_csv("edges.tsv", **tsv_options)  # edge list

In [None]:
# %% === Clean and Preprocess String Fields ===
# Strip whitespace from node and edge identifiers to ensure consistency
for column in ("source", "target", "metaedge"):
    edges[column] = edges[column].str.strip()
nodes["id"] = nodes["id"].str.strip()
# %% === Encode Node and Edge Categories ===
# Encode node types (e.g., gene, drug) into integer labels
le_node_kind = LabelEncoder()
nodes['kind_encoded'] = le_node_kind.fit_transform(nodes['kind'])
# Encode edge types (e.g., CbG, CrC) into integer labels
le_metaedge = LabelEncoder()
edges['metaedge_encoded'] = le_metaedge.fit_transform(edges['metaedge'])
# %% === Filter Nodes to Only Include Those Present in Edges ===
# Node IDs that appear as source or target of at least one edge
active_nodes = set(edges["source"]).union(edges["target"])
# Keep only nodes involved in at least one edge, and a single row per id, so that row i of
# `nodes` is exactly the node with index i (per-node features are later built row by row).
nodes = (
    nodes[nodes["id"].isin(active_nodes)]
    .drop_duplicates(subset="id", keep="first")
    .reset_index(drop=True)
)
# %% === Create Mapping from Node ID to Node Index ===
# Maps string node IDs to contiguous integer indices (row order of `nodes`)
node_id_map = {node_id: idx for idx, node_id in enumerate(nodes["id"])}
num_nodes = len(node_id_map)  # Total number of active nodes

# %% === Construct Edge Index and Edge Attribute Tensors ===
# Endpoints that are not in nodes.tsv map to NaN. Such edges are dropped from BOTH tensors
# with the same mask (instead of raising KeyError), so edge_index and edge_attr stay aligned.
src_idx = edges["source"].map(node_id_map)
tgt_idx = edges["target"].map(node_id_map)
known_edges = src_idx.notna() & tgt_idx.notna()
# edge_index has shape [2, num_edges]; each column is [source, target]
edge_index = torch.as_tensor(
    np.stack([
        src_idx[known_edges].to_numpy(dtype=np.int64),
        tgt_idx[known_edges].to_numpy(dtype=np.int64),
    ]),
    dtype=torch.long,
)
# Each edge gets an integer attribute corresponding to its metaedge (type)
edge_attr = torch.as_tensor(edges.loc[known_edges, 'metaedge_encoded'].to_numpy(), dtype=torch.long)

In [None]:
# === Function to Ensure edge_index and edge_attr Are Aligned ===
def ensure_alignment(edge_index, edge_attr):
    """
    Check that edge_index and edge_attr describe the same number of edges, then
    reorder both by ascending source node index (column 0 of edge_index).

    Args:
        edge_index: Tensor whose second dimension counts edges; row 0 holds source indices.
        edge_attr: Tensor whose first dimension counts edges.

    Returns:
        (edge_index, edge_attr), permuted with the same ordering.

    Raises:
        ValueError: if the two edge counts differ.
    """
    num_index_edges = edge_index.size(1)
    num_attr_edges = edge_attr.size(0)
    if num_index_edges != num_attr_edges:
        raise ValueError("edge_index and edge_attr must have matching dimensions.")
    # Apply one permutation to both tensors so each column keeps its attribute
    order = torch.argsort(edge_index[0])
    edge_index, edge_attr = edge_index[:, order], edge_attr[order]
    assert edge_index.size(1) == edge_attr.size(0), "Mismatch after sorting edge_index and edge_attr."
    return edge_index, edge_attr

In [None]:
# %% === Align edge tensors ===
edge_index, edge_attr = ensure_alignment(edge_index, edge_attr)
# Keep the graph on the same device as the models; otherwise GPU message passing fails
# with a device-mismatch error. The train/val/test splits derived from these inherit the device.
edge_index, edge_attr = edge_index.to(device), edge_attr.to(device)

# Initialize node features with Xavier-uniform random values (128 dims)
random_feature_dim = 128
node_features = torch.nn.init.xavier_uniform_(torch.rand(num_nodes, random_feature_dim))

# One-hot encode node types based on their encoded 'kind'.
# node_id_map assigns indices by first appearance of each id in nodes["id"], so row i of
# `nodes` is node i only when every row has a distinct id; verify instead of silently truncating.
kind_codes = torch.as_tensor(nodes['kind_encoded'].to_numpy(), dtype=torch.long)
if kind_codes.size(0) != num_nodes:
    raise ValueError(
        f"nodes has {kind_codes.size(0)} rows but {num_nodes} distinct node ids; "
        "node-type features would be misaligned with node indices."
    )
kind_embeddings = torch.eye(len(le_node_kind.classes_))[kind_codes].float()
# Concatenate structural features with node kind embeddings and move them next to the models
node_features = torch.cat([node_features, kind_embeddings], dim=1).to(device)

In [None]:
# Multi-Task GraphSAGE Model
class MultiTaskGraphSAGE(torch.nn.Module):
    """
    Two-layer GraphSAGE encoder shared by all tasks, plus one MLP scoring head per task.

    Args:
        input_dim: Width of the input node features.
        hidden_dim: Width of node embeddings and of edge-type embeddings.
        task_outputs: Mapping task name -> output width of that task's head.
        num_edge_types: Number of edge-type codes accepted by the embedding table.
    """

    def __init__(self, input_dim, hidden_dim, task_outputs, num_edge_types):
        super().__init__()
        # Submodules are created in a fixed order so parameter initialisation draws from
        # the global RNG identically on every run.
        self.conv1 = SAGEConv(input_dim, hidden_dim, aggr='mean', bias=True)
        self.conv2 = SAGEConv(hidden_dim, hidden_dim, aggr='mean', bias=True)
        # Learnable embedding per edge type
        self.edge_type_embeddings = torch.nn.Embedding(num_edge_types, hidden_dim)
        # One head per task; input is [source emb | target emb | edge-type emb]
        heads = {}
        for task_name, out_dim in task_outputs.items():
            heads[task_name] = torch.nn.Sequential(
                Linear(3 * hidden_dim, hidden_dim),
                LeakyReLU(negative_slope=0.2),
                Linear(hidden_dim, out_dim),
            )
        self.task_heads = ModuleDict(heads)
        self.dropout = torch.nn.Dropout(0.3)
        self.leaky_relu = LeakyReLU(negative_slope=0.2)

    def forward(self, x, edge_index, edge_attr):
        """Encode nodes with two SAGE layers, then add the mean edge-type embedding to every node."""
        h = self.leaky_relu(self.conv1(x, edge_index))
        h = self.leaky_relu(self.conv2(h, edge_index))
        h = self.dropout(h)
        # A single vector (mean over all edge types present) shared by every node
        return h + self.edge_type_embeddings(edge_attr).mean(dim=0)

    def predict(self, embeddings, edge_pairs, edge_attr, task):
        """Return sigmoid scores of `task`'s head for each column (source, target) of edge_pairs."""
        assert edge_pairs.size(1) == edge_attr.size(0)
        src_idx, tgt_idx = edge_pairs[0], edge_pairs[1]
        pair_repr = torch.cat(
            [embeddings[src_idx], embeddings[tgt_idx], self.edge_type_embeddings(edge_attr)],
            dim=1,
        )
        return torch.sigmoid(self.task_heads[task](pair_repr))

In [None]:
# Teacher Model (inherits MultiTaskGraphSAGE)
class TeacherModel(MultiTaskGraphSAGE):
    """Teacher network for distillation; same architecture and arguments as MultiTaskGraphSAGE."""

    def __init__(self, input_dim, hidden_dim, task_outputs, num_edge_types):
        super().__init__(input_dim, hidden_dim, task_outputs, num_edge_types)

In [None]:
# Student Model (lightweight version of Teacher Model)
class StudentModel(torch.nn.Module):
    """
    Lightweight counterpart of the teacher: a single GraphSAGE layer (no activation or
    dropout) plus one MLP scoring head per task.

    Args:
        input_dim: Width of the input node features.
        hidden_dim: Width of node embeddings and of edge-type embeddings.
        task_outputs: Mapping task name -> output width of that task's head.
        num_edge_types: Number of edge-type codes accepted by the embedding table.
    """

    def __init__(self, input_dim, hidden_dim, task_outputs, num_edge_types):
        super().__init__()
        # Creation order is kept fixed so RNG-driven initialisation is reproducible.
        self.conv1 = SAGEConv(input_dim, hidden_dim, aggr='mean', bias=True)
        self.edge_type_embeddings = torch.nn.Embedding(num_edge_types, hidden_dim)
        heads = {}
        for task_name, out_dim in task_outputs.items():
            heads[task_name] = torch.nn.Sequential(
                Linear(3 * hidden_dim, hidden_dim),
                LeakyReLU(negative_slope=0.2),
                Linear(hidden_dim, out_dim),
            )
        self.task_heads = ModuleDict(heads)

    def forward(self, x, edge_index, edge_attr):
        """One SAGE convolution followed by adding the mean edge-type embedding to every node."""
        h = self.conv1(x, edge_index)
        return h + self.edge_type_embeddings(edge_attr).mean(dim=0)

    def predict(self, embeddings, edge_pairs, edge_attr, task):
        """Return sigmoid scores of `task`'s head for each column (source, target) of edge_pairs."""
        src_idx, tgt_idx = edge_pairs[0], edge_pairs[1]
        pair_repr = torch.cat(
            [embeddings[src_idx], embeddings[tgt_idx], self.edge_type_embeddings(edge_attr)],
            dim=1,
        )
        return torch.sigmoid(self.task_heads[task](pair_repr))

In [None]:
# Weighted Binary Cross Entropy Loss Wrapper
class WeightedBCEWithLogitsLoss(torch.nn.Module):
    """
    nn.Module wrapper around BCEWithLogitsLoss.

    Args:
        pos_weight: Optional positive-class weight forwarded to BCEWithLogitsLoss.
    Inputs are raw logits; the sigmoid is applied inside the loss.
    """

    def __init__(self, pos_weight=None):
        super().__init__()
        self.criterion = BCEWithLogitsLoss(pos_weight=pos_weight)

    def forward(self, inputs, targets):
        loss_fn = self.criterion
        return loss_fn(inputs, targets)

In [None]:
# Knowledge Distillation Loss using KL Divergence
def distillation_loss(y_student, y_teacher, temperature_initial=2.0, annealing_rate=0.9):
    """
    Temperature-scaled KL(teacher || student), multiplied by T^2 as is standard in KD.

    Args:
        y_student: Student logits; assumed shape (N, C). A last dimension of 1 (as produced
            by the single-output task heads) is treated as one binary logit.
        y_teacher: Teacher logits with the same shape as y_student.
        temperature_initial: Base temperature.
        annealing_rate: Applied once: the effective temperature is
            temperature_initial * annealing_rate (no per-epoch schedule happens here).

    Returns:
        Scalar tensor: KL divergence averaged over the batch ('batchmean'), times T^2.
    """
    temperature = temperature_initial * annealing_rate
    if y_student.size(-1) == 1 and y_teacher.size(-1) == 1:
        # A softmax over a single logit is identically 1, which made this loss always 0.
        # Expand a binary logit z to two-class logits [0, z]:
        # softmax([0, z]) == [1 - sigmoid(z), sigmoid(z)].
        y_student = torch.cat([torch.zeros_like(y_student), y_student], dim=-1)
        y_teacher = torch.cat([torch.zeros_like(y_teacher), y_teacher], dim=-1)
    # Compute soft targets from teacher outputs
    soft_targets = F.softmax(y_teacher / temperature, dim=-1)
    # Compute log probabilities for student outputs
    student_log_probs = F.log_softmax(y_student / temperature, dim=-1)
    return F.kl_div(student_log_probs, soft_targets, reduction='batchmean') * (temperature ** 2)

In [None]:
def sample_negative_edges(edge_index, num_neg_samples, num_nodes):
    """
    Samples a set of negative edges (non-existent links) by randomly choosing
    distinct node pairs (no self-loops) that are not already present in edge_index.

    Args:
        edge_index: Edge indices tensor; each column is a (source, target) pair.
        num_neg_samples: Number of distinct negative pairs to generate.
        num_nodes: Nodes are drawn uniformly from [0, num_nodes).

    Returns:
        Tensor: Negative edge indices of shape (2, num_neg_samples), dtype long, on the CPU.

    Raises:
        ValueError: if fewer than num_neg_samples free pairs exist (rejection sampling
            would otherwise loop forever).
    """
    num_neg_samples = int(num_neg_samples)
    if num_neg_samples <= 0:
        return torch.empty((2, 0), dtype=torch.long)

    # One bulk conversion instead of a per-element .item() (which syncs on GPU tensors)
    existing = set(map(tuple, edge_index.t().tolist()))
    occupied = sum(1 for s, d in existing if s != d and 0 <= s < num_nodes and 0 <= d < num_nodes)
    available = num_nodes * (num_nodes - 1) - occupied
    if num_neg_samples > available:
        raise ValueError(
            f"Requested {num_neg_samples} negative edges but only {available} "
            f"non-edge pairs exist among {num_nodes} nodes."
        )

    neg_edges = set()
    while len(neg_edges) < num_neg_samples:
        src = torch.randint(0, num_nodes, (num_neg_samples,))
        dst = torch.randint(0, num_nodes, (num_neg_samples,))
        candidates = {(s, d) for s, d in zip(src.tolist(), dst.tolist()) if s != d}
        new_samples = candidates - existing - neg_edges
        neg_edges.update(list(new_samples)[:num_neg_samples - len(neg_edges)])
    return torch.tensor(list(neg_edges), dtype=torch.long).t()

In [None]:
class WeightedLossCombiner(torch.nn.Module):
    """
    Weighted sum of a task loss and a distillation loss: alpha * loss + beta * dist_loss.

    Args:
        alpha (float): Weight for the task loss.
        beta (float): Weight for the distillation loss.
    """

    def __init__(self, alpha=0.3, beta=0.7):
        super().__init__()
        self.alpha, self.beta = alpha, beta

    def forward(self, loss, dist_loss):
        task_term = self.alpha * loss
        distill_term = self.beta * dist_loss
        return task_term + distill_term

In [None]:
def batch_sampling(edge_pairs, edge_attr, batch_size, device):
    """
    Draw a random mini-batch of edges without replacement and move it to `device`.

    Args:
        edge_pairs: Tensor with one (source, target) column per edge.
        edge_attr: Tensor with one attribute entry per edge along its first dimension.
        batch_size: Maximum number of edges to draw; if it exceeds the edge count,
            all edges are returned in shuffled order.
        device: Device the sampled tensors are moved to.

    Returns:
        Tuple[Tensor, Tensor]: (sampled edge_pairs, sampled edge_attr), both using the
        same random selection of B = min(batch_size, num_edges) edges.
    """
    chosen = torch.randperm(edge_pairs.size(1))[:batch_size]
    sampled_pairs = edge_pairs[:, chosen]
    sampled_attr = edge_attr[chosen]
    return sampled_pairs.to(device), sampled_attr.to(device)

In [None]:
def split_data(edge_index, edge_attr, val_ratio=0.2, test_ratio=0.2, random_state=42):
    """
    Randomly split edges (and their attributes) into train / validation / test sets.

    A first split separates the training edges from the held-out fraction
    (val_ratio + test_ratio); a second split divides the held-out edges between
    validation and test in proportion val_ratio : test_ratio. Both splits use
    random_state, so the result is reproducible.

    Parameters:
        edge_index: Tensor whose columns are the N edges (source, target).
        edge_attr: Tensor with one attribute per edge.
        val_ratio: Fraction of all edges used for validation.
        test_ratio: Fraction of all edges used for testing.
        random_state: Seed passed to both train_test_split calls (default: 42).

    Returns:
        dict: {'train', 'val', 'test'} -> {'edge_index': Tensor, 'edge_attr': Tensor}.

    Notes:
        - The split is purely random over edge positions; it is NOT stratified by
          edge type or label.
    """
    held_out_ratio = val_ratio + test_ratio
    all_ids = list(range(edge_index.size(1)))
    train_ids, held_out_ids = train_test_split(all_ids, test_size=held_out_ratio, random_state=random_state)
    val_ids, test_ids = train_test_split(
        held_out_ids, test_size=1 - val_ratio / held_out_ratio, random_state=random_state
    )

    def subset(ids):
        # Select the same edge positions from both tensors so they stay aligned
        return {'edge_index': edge_index[:, ids], 'edge_attr': edge_attr[ids]}

    return {'train': subset(train_ids), 'val': subset(val_ids), 'test': subset(test_ids)}

In [None]:
val_ratio = 0.2
test_ratio = 0.2

split = split_data(edge_index, edge_attr, val_ratio=val_ratio, test_ratio=test_ratio)

# Unpack the three splits into (edge_index, edge_attr) pairs
(train_edge_index, train_edge_attr), (val_edge_index, val_edge_attr), (test_edge_index, test_edge_attr) = (
    (split[split_name]['edge_index'], split[split_name]['edge_attr'])
    for split_name in ('train', 'val', 'test')
)

In [None]:
def train_teacher(model, node_features, train_edge_index, train_edge_attr, task_edges, optimizer, criterion, num_nodes, device, batch_size, iterations=10):
    """
    Trains the teacher model for multi-task link prediction over a knowledge graph.

    Each iteration recomputes node embeddings on the training graph. For every task it
    scores the positive training edges and an equal number of sampled negative edges in
    mini-batches, sums the classification loss, and backpropagates it. One optimizer step
    (after gradient clipping) is taken per iteration.

    Args:
        model: Teacher model exposing forward(), edge_type_embeddings and task_heads.
        node_features: Input node feature matrix.
        train_edge_index: Edge index tensor used for message passing.
        train_edge_attr: Per-edge type codes aligned with train_edge_index.
        task_edges: Dict mapping each task name to a dict with a 'train' edge index tensor
            and, optionally, 'train_attr' (per-edge type codes aligned with 'train').
            Without 'train_attr', train_edge_attr is used and must align with 'train'.
        optimizer: Optimizer for updating the model's parameters.
        criterion: Binary classification loss that takes raw logits
            (e.g. WeightedBCEWithLogitsLoss).
        num_nodes: Number of nodes; negatives are sampled from [0, num_nodes).
        device: Device the per-task edge tensors are moved to.
        batch_size: Number of positive edges per training batch.
        iterations: Number of optimisation steps (default: 10).

    Returns:
        float: Sum over tasks of the per-task loss, averaged over iterations.

    Raises:
        ValueError: if a task's edge attributes are not aligned with its edges.
    """
    def edge_logits(embeddings, pairs, attr, task):
        # Same computation as model.predict() but without the final sigmoid:
        # BCEWithLogitsLoss applies the sigmoid itself, so passing probabilities would squash twice.
        pair_features = torch.cat(
            [embeddings[pairs[0]], embeddings[pairs[1]], model.edge_type_embeddings(attr)], dim=1
        )
        return model.task_heads[task](pair_features)

    model.train()
    total_loss = 0.0  # accumulated across ALL iterations so the return value is a true average
    for _ in range(iterations):
        optimizer.zero_grad()
        embeddings = model(node_features, train_edge_index, train_edge_attr)
        for task, edges in task_edges.items():
            positive_pairs = edges['train'].to(device)
            # Edge attributes are per EDGE, so they must line up with the columns of positive_pairs
            edge_attr = edges.get('train_attr', train_edge_attr).to(device)
            if edge_attr.size(0) != positive_pairs.size(1):
                raise ValueError(
                    f"Task {task!r}: {edge_attr.size(0)} edge attributes for "
                    f"{positive_pairs.size(1)} training edges."
                )
            if positive_pairs.size(1) == 0:
                continue  # nothing to learn for this task

            task_loss = 0
            for start in range(0, positive_pairs.size(1), batch_size):
                batch_pos_pairs = positive_pairs[:, start:start + batch_size]
                batch_edge_attr = edge_attr[start:start + batch_size]
                # One negative per positive, so the classes are balanced (no pos_weight needed)
                batch_neg_pairs = sample_negative_edges(batch_pos_pairs, batch_pos_pairs.size(1), num_nodes).to(device)
                pos_logits = edge_logits(embeddings, batch_pos_pairs, batch_edge_attr, task)
                # Negatives reuse the positives' edge types
                neg_logits = edge_logits(embeddings, batch_neg_pairs, batch_edge_attr, task)
                all_logits = torch.cat([pos_logits, neg_logits])
                all_labels = torch.cat([torch.ones_like(pos_logits), torch.zeros_like(neg_logits)])
                task_loss = task_loss + criterion(all_logits, all_labels)
            # The shared embedding graph is reused by the next task's backward pass
            task_loss.backward(retain_graph=True)
            total_loss += task_loss.item()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

    return total_loss / iterations if iterations > 0 else 0.0

In [None]:
def train_student(model, teacher_model, node_features, train_edge_index, train_edge_attr, task_edges, optimizer, criterion, distillation_criterion, num_nodes, device, loss_combiner, batch_size, iterations=10):
    """
    Trains the student with a mix of supervised link-prediction loss and knowledge
    distillation from a frozen teacher.

    Each iteration recomputes student embeddings (with gradients) and teacher embeddings
    (without). For every task and mini-batch of positive edges, it samples the same number
    of negatives and computes the classification loss on the student's logits. It also
    computes a distillation loss between student and teacher logits, combines both with
    loss_combiner, and backpropagates. One optimizer step is taken per iteration.

    Args:
        model: The student model to be trained (exposes edge_type_embeddings and task_heads).
        teacher_model: The teacher model used for distillation (kept in eval mode).
        node_features: Input node feature matrix.
        train_edge_index: Edge index tensor used for message passing.
        train_edge_attr: Per-edge type codes aligned with train_edge_index.
        task_edges: Dict mapping each task name to a dict with a 'train' edge index tensor
            and, optionally, 'train_attr' aligned with it. Without 'train_attr',
            train_edge_attr is used and must align with 'train'.
        optimizer: Optimizer for the student model.
        criterion: Binary classification loss that takes raw logits.
        distillation_criterion: Loss between student logits and teacher logits
            (e.g. distillation_loss).
        num_nodes: Number of nodes; negatives are sampled from [0, num_nodes).
        device: Device the per-task edge tensors are moved to.
        loss_combiner: Callable combining (task loss, distillation loss).
        batch_size: Number of positive edges per training batch.
        iterations: Number of optimisation steps (default: 10).

    Returns:
        float: Sum over tasks of the combined loss, averaged over iterations.

    Raises:
        ValueError: if a task's edge attributes are not aligned with its edges.
    """
    def edge_logits(net, embeddings, pairs, attr, task):
        # Same computation as net.predict() minus the final sigmoid: the classification
        # criterion is BCE-with-logits and distillation is defined on logits.
        pair_features = torch.cat(
            [embeddings[pairs[0]], embeddings[pairs[1]], net.edge_type_embeddings(attr)], dim=1
        )
        return net.task_heads[task](pair_features)

    model.train()
    teacher_model.eval()
    total_loss = 0.0  # accumulated across ALL iterations so the return value is a true average
    for _ in range(iterations):
        optimizer.zero_grad()
        embeddings = model(node_features, train_edge_index, train_edge_attr)
        with torch.no_grad():
            teacher_embeddings = teacher_model(node_features, train_edge_index, train_edge_attr)

        for task, edges in task_edges.items():
            positive_pairs = edges['train'].to(device)
            # Edge attributes are per EDGE, so they must line up with the columns of positive_pairs
            edge_attr = edges.get('train_attr', train_edge_attr).to(device)
            if edge_attr.size(0) != positive_pairs.size(1):
                raise ValueError(
                    f"Task {task!r}: {edge_attr.size(0)} edge attributes for "
                    f"{positive_pairs.size(1)} training edges."
                )
            if positive_pairs.size(1) == 0:
                continue  # nothing to learn for this task

            task_loss = 0
            for start in range(0, positive_pairs.size(1), batch_size):
                batch_pos_pairs = positive_pairs[:, start:start + batch_size]
                batch_edge_attr = edge_attr[start:start + batch_size]
                # One negative per positive, so the classes are balanced (no pos_weight needed)
                batch_neg_pairs = sample_negative_edges(batch_pos_pairs, batch_pos_pairs.size(1), num_nodes).to(device)
                pos_logits = edge_logits(model, embeddings, batch_pos_pairs, batch_edge_attr, task)
                # Negatives reuse the positives' edge types
                neg_logits = edge_logits(model, embeddings, batch_neg_pairs, batch_edge_attr, task)

                all_logits = torch.cat([pos_logits, neg_logits])
                all_labels = torch.cat([torch.ones_like(pos_logits), torch.zeros_like(neg_logits)])
                loss_task = criterion(all_logits, all_labels)
                with torch.no_grad():
                    teacher_logits = torch.cat([
                        edge_logits(teacher_model, teacher_embeddings, batch_pos_pairs, batch_edge_attr, task),
                        edge_logits(teacher_model, teacher_embeddings, batch_neg_pairs, batch_edge_attr, task),
                    ])

                loss_distill = distillation_criterion(all_logits, teacher_logits)
                task_loss = task_loss + loss_combiner(loss_task, loss_distill)
            # The shared embedding graph is reused by the next task's backward pass
            task_loss.backward(retain_graph=True)
            total_loss += task_loss.item()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

    return total_loss / iterations if iterations > 0 else 0.0

In [None]:
def validate_model(model, node_features, val_edge_index, val_edge_attr, task_edges, num_nodes, iterations=10):
    """
    Evaluates a trained model on validation edges for each task using multiple performance metrics.

    Each repetition scores all validation positives of a task plus a freshly sampled,
    equally sized set of negatives. The reported std therefore reflects the variance
    introduced by negative sampling.

    Args:
        model: Trained model to evaluate (its predict() returns probabilities).
        node_features: Node feature matrix.
        val_edge_index: Edge index tensor of the graph used for message passing.
        val_edge_attr: Per-edge type codes aligned with val_edge_index.
        task_edges: Dict mapping each task name to a dict with a 'val' edge index tensor
            and, optionally, 'val_attr' aligned with it. Without 'val_attr',
            val_edge_attr is used and must align with 'val'.
        num_nodes: Number of nodes; negatives are sampled from [0, num_nodes).
        iterations: Number of evaluation repetitions (default: 10).

    Returns:
        Dict: task -> {auc, aupr, accuracy, f1, precision} x {_mean, _std}.
        Metrics for a task without validation edges are NaN.

    Raises:
        ValueError: if a task's edge attributes are not aligned with its edges.

    NOTE(review): embeddings are propagated over val_edge_index, i.e. a graph that contains
    the very edges being scored; propagating over the training graph would avoid this leakage.
    """
    model.eval()
    # Follow the model's device instead of relying on a notebook-global `device`
    device = next(model.parameters()).device
    metric_names = ('auc', 'aupr', 'accuracy', 'f1', 'precision')
    all_metrics = {task: {name: [] for name in metric_names} for task in task_edges}
    node_features = node_features.to(device)
    val_edge_index = val_edge_index.to(device)
    val_edge_attr = val_edge_attr.to(device)

    with torch.no_grad():
        for _ in range(iterations):
            embeddings = model(node_features, val_edge_index, val_edge_attr)
            for task, edges in task_edges.items():
                pos_pairs = edges['val'].to(device)
                # Attributes are per EDGE and must line up with the columns of pos_pairs
                pos_attr = edges.get('val_attr', val_edge_attr).to(device)
                if pos_attr.size(0) != pos_pairs.size(1):
                    raise ValueError(
                        f"Task {task!r}: {pos_attr.size(0)} edge attributes for "
                        f"{pos_pairs.size(1)} validation edges."
                    )
                if pos_pairs.size(1) == 0:
                    continue
                neg_pairs = sample_negative_edges(pos_pairs, pos_pairs.size(1), num_nodes).to(device)
                pos_out = model.predict(embeddings, pos_pairs, pos_attr, task)
                # Negatives reuse the positives' edge types (same convention as training)
                neg_out = model.predict(embeddings, neg_pairs, pos_attr, task)
                # Bring scores to host memory as a flat array for scikit-learn
                preds = torch.cat([pos_out, neg_out]).reshape(-1).cpu().numpy()
                labels = np.concatenate([np.ones(pos_out.size(0)), np.zeros(neg_out.size(0))])
                binary_preds = (preds >= 0.5).astype(int)
                has_both_classes = len(np.unique(labels)) > 1
                task_metrics = all_metrics[task]
                task_metrics['auc'].append(roc_auc_score(labels, preds) if has_both_classes else 0.0)
                task_metrics['aupr'].append(average_precision_score(labels, preds) if has_both_classes else 0.0)
                task_metrics['accuracy'].append(accuracy_score(labels, binary_preds))
                task_metrics['f1'].append(f1_score(labels, binary_preds))
                task_metrics['precision'].append(precision_score(labels, binary_preds, zero_division=1))

    final_metrics = {}
    for task, metrics in all_metrics.items():
        final_metrics[task] = {}
        for name in metric_names:
            values = metrics[name]
            final_metrics[task][f'{name}_mean'] = float(np.mean(values)) if values else float('nan')
            final_metrics[task][f'{name}_std'] = float(np.std(values)) if values else float('nan')

    return final_metrics

In [None]:
def test_model(model, node_features, test_edge_index, test_edge_attr, task_edges, num_nodes, iterations=10):
    """
    Evaluates a trained model on held-out test edges and measures inference latency.

    Each repetition scores all test positives of a task plus a freshly sampled, equally
    sized set of negatives.

    Args:
        model: Trained model to evaluate (its predict() returns probabilities).
        node_features: Node feature matrix.
        test_edge_index: Edge index tensor of the graph used for message passing.
        test_edge_attr: Per-edge type codes aligned with test_edge_index.
        task_edges: Dict mapping each task name to a dict with a 'test' edge index tensor
            and, optionally, 'test_attr' aligned with it. Without 'test_attr',
            test_edge_attr is used and must align with 'test'.
        num_nodes: Number of nodes; negatives are sampled from [0, num_nodes).
        iterations: Number of evaluation repetitions (default: 10).

    Returns:
        Dict: task -> {auc, aupr, accuracy, f1, precision} x {_mean, _std} (NaN for a task
        without test edges), plus an 'inference_time' entry holding
        'avg_inference_time' / 'std_inference_time' in seconds. Only the model calls
        (embedding pass and scoring) are timed; negative sampling and metric computation
        are excluded.

    Raises:
        ValueError: if a task's edge attributes are not aligned with its edges.
    """
    model.eval()
    # Follow the model's device instead of relying on a notebook-global `device`
    device = next(model.parameters()).device
    metric_names = ('auc', 'aupr', 'accuracy', 'f1', 'precision')
    all_metrics = {task: {name: [] for name in metric_names} for task in task_edges}
    inference_times = []
    node_features = node_features.to(device)
    test_edge_index = test_edge_index.to(device)
    test_edge_attr = test_edge_attr.to(device)

    def synchronize():
        # CUDA kernels run asynchronously; wait for them so wall-clock timings are meaningful
        if device.type == 'cuda':
            torch.cuda.synchronize(device)

    with torch.no_grad():
        for _ in range(iterations):
            elapsed = 0.0
            synchronize()
            tic = time.perf_counter()
            embeddings = model(node_features, test_edge_index, test_edge_attr)
            synchronize()
            elapsed += time.perf_counter() - tic

            for task, edges in task_edges.items():
                pos_pairs = edges['test'].to(device)
                # Attributes are per EDGE and must line up with the columns of pos_pairs
                pos_attr = edges.get('test_attr', test_edge_attr).to(device)
                if pos_attr.size(0) != pos_pairs.size(1):
                    raise ValueError(
                        f"Task {task!r}: {pos_attr.size(0)} edge attributes for "
                        f"{pos_pairs.size(1)} test edges."
                    )
                if pos_pairs.size(1) == 0:
                    continue
                neg_pairs = sample_negative_edges(pos_pairs, pos_pairs.size(1), num_nodes).to(device)

                synchronize()
                tic = time.perf_counter()
                pos_out = model.predict(embeddings, pos_pairs, pos_attr, task)
                # Negatives reuse the positives' edge types (same convention as training)
                neg_out = model.predict(embeddings, neg_pairs, pos_attr, task)
                synchronize()
                elapsed += time.perf_counter() - tic

                # Bring scores to host memory as a flat array for scikit-learn
                preds = torch.cat([pos_out, neg_out]).reshape(-1).cpu().numpy()
                labels = np.concatenate([np.ones(pos_out.size(0)), np.zeros(neg_out.size(0))])
                binary_preds = (preds >= 0.5).astype(int)
                has_both_classes = len(np.unique(labels)) > 1
                task_metrics = all_metrics[task]
                task_metrics['auc'].append(roc_auc_score(labels, preds) if has_both_classes else 0.0)
                task_metrics['aupr'].append(average_precision_score(labels, preds) if has_both_classes else 0.0)
                task_metrics['accuracy'].append(accuracy_score(labels, binary_preds))
                task_metrics['f1'].append(f1_score(labels, binary_preds))
                task_metrics['precision'].append(precision_score(labels, binary_preds, zero_division=1))
            inference_times.append(elapsed)

    final_metrics = {}
    for task, metrics in all_metrics.items():
        final_metrics[task] = {}
        for name in metric_names:
            values = metrics[name]
            final_metrics[task][f'{name}_mean'] = float(np.mean(values)) if values else float('nan')
            final_metrics[task][f'{name}_std'] = float(np.std(values)) if values else float('nan')

    final_metrics['inference_time'] = {
        'avg_inference_time': float(np.mean(inference_times)) if inference_times else float('nan'),
        'std_inference_time': float(np.std(inference_times)) if inference_times else float('nan'),
    }

    return final_metrics

In [None]:
# Link-prediction tasks, named by the metaedge abbreviation whose edges each head predicts
tasks = ["CaD", "CrC", "CbG", "DaG"]

known_metaedges = list(le_metaedge.classes_)
missing_tasks = [task for task in tasks if task not in known_metaedges]
if missing_tasks:
    raise ValueError(
        f"Tasks {missing_tasks} do not match any metaedge in edges.tsv; "
        f"available metaedges: {known_metaedges}"
    )

# Give each task only the edges of its own metaedge type, per split, together with their
# per-edge metaedge codes ('<split>_attr', aligned column-for-column with '<split>').
# Without this filter every head would be trained and evaluated on the same (all-type) edges.
split_tensors = {
    'train': (train_edge_index, train_edge_attr),
    'val': (val_edge_index, val_edge_attr),
    'test': (test_edge_index, test_edge_attr),
}
task_edges = {}
for task in tasks:
    task_code = known_metaedges.index(task)  # LabelEncoder code == position in classes_
    task_edges[task] = {}
    for split_name, (split_index, split_attr) in split_tensors.items():
        in_task = split_attr == task_code
        task_edges[task][split_name] = split_index[:, in_task].clone()
        task_edges[task][f'{split_name}_attr'] = split_attr[in_task].clone()

hidden_dim_t = 128
hidden_dim_s = 64
task_outputs = {task: 1 for task in tasks}
num_edge_types = len(le_metaedge.classes_)

teacher_model = TeacherModel(input_dim=node_features.shape[1], hidden_dim=hidden_dim_t, task_outputs=task_outputs, num_edge_types=num_edge_types).to(device)

student_model = StudentModel(input_dim=node_features.shape[1], hidden_dim=hidden_dim_s, task_outputs=task_outputs, num_edge_types=num_edge_types).to(device)

criterion = WeightedBCEWithLogitsLoss()

optimizer_teacher = torch.optim.Adam(teacher_model.parameters(), lr=0.001, weight_decay=1e-5)
optimizer_student = torch.optim.Adam(student_model.parameters(), lr=0.001, weight_decay=1e-5)

epochs = 100
batch_size = 128
loss_combiner = WeightedLossCombiner().to(device)

In [None]:
print("Training Teacher Model...")
for epoch in range(epochs):
    loss = train_teacher(teacher_model, node_features, train_edge_index, train_edge_attr, task_edges, optimizer_teacher, criterion, num_nodes, device, batch_size, iterations=1)
    if (epoch + 1) % 2 == 0:
        # validate_model takes num_nodes directly after task_edges (there is no split-name argument)
        val_metrics = validate_model(teacher_model, node_features, val_edge_index, val_edge_attr, task_edges, num_nodes, iterations=10)
        print(f"Epoch {epoch + 1}/{epochs} - train loss: {loss:.4f}")
        for task, metric in val_metrics.items():
            print(f"Validation {task}: AUC: {metric['auc_mean']:.4f} ± {metric['auc_std']:.4f}, "
                  f"AUPR: {metric['aupr_mean']:.4f} ± {metric['aupr_std']:.4f}, "
                  f"Accuracy: {metric['accuracy_mean']:.4f} ± {metric['accuracy_std']:.4f}, "
                  f"F1: {metric['f1_mean']:.4f} ± {metric['f1_std']:.4f}, "
                  f"Precision: {metric['precision_mean']:.4f} ± {metric['precision_std']:.4f}")

In [None]:
test_metrics = test_model(teacher_model, node_features, test_edge_index, test_edge_attr, task_edges, num_nodes, iterations=10)

print("\nTeacher Model Test Results:")
for task, metric in test_metrics.items():
    # test_model also returns an 'inference_time' entry (timing stats, no AUC etc.);
    # it is reported separately below.
    if task == 'inference_time':
        continue
    print(f"Test Result {task}: AUC: {metric['auc_mean']:.4f} ± {metric['auc_std']:.4f}, "
          f"AUPR: {metric['aupr_mean']:.4f} ± {metric['aupr_std']:.4f}, "
          f"Accuracy: {metric['accuracy_mean']:.4f} ± {metric['accuracy_std']:.4f}, "
          f"F1: {metric['f1_mean']:.4f} ± {metric['f1_std']:.4f}, "
          f"Precision: {metric['precision_mean']:.4f} ± {metric['precision_std']:.4f}")

inference_time = test_metrics['inference_time']
avg_inference_time = inference_time['avg_inference_time']
std_inference_time = inference_time['std_inference_time']

print(f"\nAverage Inference Time: {avg_inference_time:.4f} ± {std_inference_time:.4f}")

In [None]:
print("\nTraining Student Model WITH Knowledge Distillation...")
epochs = 50
for epoch in range(epochs):
    loss = train_student(
        student_model, teacher_model, node_features, train_edge_index, train_edge_attr, task_edges,
        optimizer_student, criterion, distillation_loss, num_nodes, device, loss_combiner, batch_size,
        iterations=10,
    )
    if (epoch + 1) % 2 != 0:
        continue  # validate every second epoch only
    val_metrics = validate_model(student_model, node_features, val_edge_index, val_edge_attr, task_edges, num_nodes, iterations=10)
    for task, metric in val_metrics.items():
        print(f"Validation {task}: AUC: {metric['auc_mean']:.4f} ± {metric['auc_std']:.4f}, "
              f"AUPR: {metric['aupr_mean']:.4f} ± {metric['aupr_std']:.4f}, "
              f"Accuracy: {metric['accuracy_mean']:.4f} ± {metric['accuracy_std']:.4f}, "
              f"F1: {metric['f1_mean']:.4f} ± {metric['f1_std']:.4f}, "
              f"Precision: {metric['precision_mean']:.4f} ± {metric['precision_std']:.4f}")

In [None]:
test_metrics = test_model(student_model, node_features, test_edge_index, test_edge_attr, task_edges, num_nodes, iterations=10)
print("\nStudent Model Test Results:")
for task, metric in test_metrics.items():
    # test_model also returns an 'inference_time' entry (timing stats, no AUC etc.);
    # it is reported separately below.
    if task == 'inference_time':
        continue
    print(f"Test Result {task}: AUC: {metric['auc_mean']:.4f} ± {metric['auc_std']:.4f}, "
          f"AUPR: {metric['aupr_mean']:.4f} ± {metric['aupr_std']:.4f}, "
          f"Accuracy: {metric['accuracy_mean']:.4f} ± {metric['accuracy_std']:.4f}, "
          f"F1: {metric['f1_mean']:.4f} ± {metric['f1_std']:.4f}, "
          f"Precision: {metric['precision_mean']:.4f} ± {metric['precision_std']:.4f}")

inference_time = test_metrics['inference_time']
avg_inference_time = inference_time['avg_inference_time']
std_inference_time = inference_time['std_inference_time']

print(f"\nAverage Inference Time: {avg_inference_time:.4f} ± {std_inference_time:.4f}")