In [2]:
%pip install torch torch_geometric

Collecting torch_geometric
  Using cached torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
Collecting aiohttp (from torch_geometric)
  Using cached aiohttp-3.12.13-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.6 kB)
Collecting aiohappyeyeballs>=2.5.0 (from aiohttp->torch_geometric)
  Using cached aiohappyeyeballs-2.6.1-py3-none-any.whl.metadata (5.9 kB)
Collecting aiosignal>=1.1.2 (from aiohttp->torch_geometric)
  Using cached aiosignal-1.3.2-py2.py3-none-any.whl.metadata (3.8 kB)
Collecting frozenlist>=1.1.1 (from aiohttp->torch_geometric)
  Using cached frozenlist-1.7.0-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (18 kB)
Collecting multidict<7.0,>=4.5 (from aiohttp->torch_geometric)
  Using cached multidict-6.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.3 kB)
Collecting propcache>=0.2.0 (from aiohttp->torch_geometric)
  Using cached propcache-0.3.2-cp311-cp311-

In [6]:
"""CROSS-TASK EVALUATION: TRAIN ON LINK PREDICTION -> TEST ON NODE CLASSIFICATION"""
import torch
import torch.nn.functional as F
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import JumpingKnowledge, GCNConv, MixHopConv
import torch_geometric.transforms as T
import time
import matplotlib.pyplot as plt
from sklearn.metrics import roc_auc_score, accuracy_score
import os

def set_seed(seed=42):
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, embedding_dim, num_layers=2, dropout=0.5):
        super(GCN, self).__init__()
        self.dropout = dropout
        self.convs = torch.nn.ModuleList()
        self.convs.append(GCNConv(in_channels, hidden_channels))
        for _ in range(num_layers - 2):
            self.convs.append(GCNConv(hidden_channels, hidden_channels))
        self.convs.append(GCNConv(hidden_channels, embedding_dim))
    def forward(self, x, edge_index):
        for i, conv in enumerate(self.convs):
            x = conv(x, edge_index)
            if i < len(self.convs) - 1:
                x = F.relu(x)
                x = F.dropout(x, p=self.dropout, training=self.training)
        return x

class JKNetMax(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, embedding_dim, num_layers=6, dropout=0.5):
        super(JKNetMax, self).__init__()
        self.dropout = dropout
        self.convs = torch.nn.ModuleList()
        self.convs.append(GCNConv(in_channels, 16)) 
        for _ in range(num_layers - 1):
            self.convs.append(GCNConv(16, 16))
        self.jump = JumpingKnowledge(mode='max', channels=16, num_layers=num_layers)
        self.final_lin = torch.nn.Linear(16, embedding_dim)
    def forward(self, x, edge_index):
        xs = []
        for conv in self.convs:
            x = conv(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
            xs.append(x)
        x_jump = self.jump(xs)
        return self.final_lin(x_jump)

class MixHopEncoder(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, embedding_dim, dropout=0.5, powers=[0, 1, 2]):
        super(MixHopEncoder, self).__init__()
        self.dropout = dropout
        self.conv1 = MixHopConv(in_channels, hidden_channels, powers=powers)
        self.final_lin = torch.nn.Linear(hidden_channels * len(powers), embedding_dim)
    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.final_lin(x)
        return x

class AlternatingGCNMixHopJKNet(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, embedding_dim, num_layers=4, hop_hidden_channels=16, mode='max', dropout=0.5, mixhop_powers=[0,1,2]):
        super(AlternatingGCNMixHopJKNet, self).__init__()
        self.dropout = dropout
        self.num_layers = num_layers
        self.convs = torch.nn.ModuleList()
        self.lins = torch.nn.ModuleList()
        for i in range(num_layers):
            if i % 2 == 0:
                in_c = in_channels if i == 0 else hidden_channels
                self.convs.append(GCNConv(in_c, hidden_channels))
                self.lins.append(torch.nn.Identity())
            else:
                self.convs.append(MixHopConv(hidden_channels, hop_hidden_channels, powers=mixhop_powers))
                self.lins.append(torch.nn.Linear(hop_hidden_channels * len(mixhop_powers), hidden_channels))
        self.jump = JumpingKnowledge(mode=mode, channels=hidden_channels, num_layers=num_layers)
        if mode == 'cat':
            self.final_layer = torch.nn.Linear(num_layers * hidden_channels, embedding_dim)
        else:
            self.final_layer = torch.nn.Linear(hidden_channels, embedding_dim)
    def forward(self, x, edge_index):
        xs = []
        for i in range(self.num_layers):
            if i % 2 == 0:
                x = self.convs[i](x, edge_index)
                x = F.relu(x)
            else:
                x = self.convs[i](x, edge_index)
                x = self.lins[i](x)
                x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
            xs.append(x)
        x_jump = self.jump(xs)
        x_final = self.final_layer(x_jump)
        return x_final


class LinkPredictor(torch.nn.Module):
    def __init__(self, encoder, embedding_dim):
        super().__init__()
        self.encoder = encoder
        self.decoder = torch.nn.Sequential(
            torch.nn.Linear(2 * embedding_dim, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 1)
        )
    def forward(self, x, edge_index, edge_label_index):
        z = self.encoder(x, edge_index)
        src_emb = z[edge_label_index[0]]
        dst_emb = z[edge_label_index[1]]
        return self.decoder(torch.cat([src_emb, dst_emb], dim=-1)).squeeze()

def run_link_prediction_training(encoder_class_lambda, model_name, seed, embedding_dim):
    print(f"\n--- STAGE 1: Training {model_name} on Link Prediction ---")
    set_seed(seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataset = Planetoid(root='./Cora', name='Cora', transform=T.NormalizeFeatures())
    data = dataset[0]
    transform = T.RandomLinkSplit(num_val=0.05, num_test=0.1, is_undirected=True, add_negative_train_samples=True)
    train_data, val_data, test_data = transform(data.clone())
    
    encoder = encoder_class_lambda(dataset.num_features, embedding_dim).to(device)
    model = LinkPredictor(encoder, embedding_dim).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0005, weight_decay=5e-4)

    best_val_auc, final_test_auc, patience = 0, 0, 10000
    wait = 0
    
    for epoch in range(1, 10000):
        model.train()
        optimizer.zero_grad()
        out = model(train_data.x.to(device), train_data.edge_index.to(device), train_data.edge_label_index.to(device))
        loss = F.binary_cross_entropy_with_logits(out, train_data.edge_label.to(device))
        loss.backward()
        optimizer.step()
        
        model.eval()
        with torch.no_grad():
            out = model(train_data.x.to(device), train_data.edge_index.to(device), val_data.edge_label_index.to(device)).sigmoid()
            val_auc = roc_auc_score(val_data.edge_label.cpu(), out.cpu())

        if val_auc > best_val_auc:
            best_val_auc = val_auc
            with torch.no_grad():
                out = model(train_data.x.to(device), train_data.edge_index.to(device), test_data.edge_label_index.to(device)).sigmoid()
                final_test_auc = roc_auc_score(test_data.edge_label.cpu(), out.cpu())
            torch.save(model.encoder.state_dict(), f'best_encoder_{model_name}.pt')
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                print(f"Early stopping at epoch {epoch}.")
                break
        
    best_encoder = encoder_class_lambda(dataset.num_features, embedding_dim).to(device)
    if os.path.exists(f'best_encoder_{model_name}.pt'):
        best_encoder.load_state_dict(torch.load(f'best_encoder_{model_name}.pt'))
    return best_encoder, final_test_auc



def evaluate_on_node_classification(encoder, model_name, embedding_dim, seed):
    print(f"\n--- STAGE 2: Evaluating {model_name}'s Embeddings on Node Classification ---")
    set_seed(seed)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataset = Planetoid(root='./Cora', name='Cora', transform=T.NormalizeFeatures())
    data = dataset[0].to(device)
    
    encoder.eval()
    for param in encoder.parameters():
        param.requires_grad = False
    
    with torch.no_grad():
        z = encoder(data.x, data.edge_index)
        
    classifier = torch.nn.Linear(embedding_dim, dataset.num_classes).to(device)
    optimizer = torch.optim.Adam(classifier.parameters(), lr=0.005, weight_decay=5e-4)
    
    best_val_acc = 0
    final_test_acc = 0
    patience, wait = 200, 0

    for epoch in range(1, 5000):
        classifier.train()
        optimizer.zero_grad()
        out = classifier(z)
        loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        optimizer.step()
        
        classifier.eval()
        with torch.no_grad():
            pred = classifier(z).argmax(dim=-1)
            val_acc = accuracy_score(data.y[data.val_mask].cpu(), pred[data.val_mask].cpu())
        
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            final_test_acc = accuracy_score(data.y[data.test_mask].cpu(), pred[data.test_mask].cpu())
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                break
                
    return final_test_acc


if __name__ == '__main__':
    seed = 123
    
    model_configs = [
        {
            'name': 'GCN',
            'encoder_lambda': lambda in_c, emb_dim: GCN(in_c, hidden_channels=32, embedding_dim=emb_dim, num_layers=2, dropout=0.5),
            'embedding_dim': 32
        },
        {
            'name': 'JKNetMax',
            'encoder_lambda': lambda in_c, emb_dim: JKNetMax(in_c, hidden_channels=32, embedding_dim=emb_dim, num_layers=6, dropout=0.5),
            'embedding_dim': 32 
        },
        {
            'name': 'MixHop',
            'encoder_lambda': lambda in_c, emb_dim: MixHopEncoder(in_c, hidden_channels=32, embedding_dim=emb_dim, dropout=0.5),
            'embedding_dim': 96 
        },
        {
            'name': 'AltGCNMixHopJKN',
            'encoder_lambda': lambda in_c, emb_dim: AlternatingGCNMixHopJKNet(in_c, hidden_channels=32, embedding_dim=emb_dim, num_layers=3, hop_hidden_channels=32, mode='max', dropout=0.4),
            'embedding_dim': 32
        }
    ]
    
    results = {}

    for config in model_configs:
        print(f"\n{'='*20} Starting Pipeline for {config['name']} {'='*20}")
        try:
            trained_encoder, test_auc = run_link_prediction_training(
                config['encoder_lambda'], 
                config['name'], 
                seed, 
                config['embedding_dim']
            )
            
            test_acc = evaluate_on_node_classification(
                trained_encoder, 
                config['name'], 
                config['embedding_dim'], 
                seed
            )
            
            results[config['name']] = {'Link Prediction AUC': test_auc, 'Node Classification Accuracy': test_acc}
        except (RuntimeError, TypeError) as e:
            print(f"\nERROR during pipeline for {config['name']}: {e}")
            print("This run was skipped due to a parameter mismatch in the __main__ configuration.")
            results[config['name']] = {'Link Prediction AUC': 0.0, 'Node Classification Accuracy': 0.0}


    print("\n\n--- FINAL CROSS-TASK EVALUATION SUMMARY ---")
    for model_name, metrics in results.items():
        print(f"\nModel: {model_name}")
        print(f"\tEvaluated on Node Classification, Test Accuracy: {metrics['Node Classification Accuracy']:.4f}")



--- STAGE 1: Training GCN on Link Prediction ---

--- STAGE 2: Evaluating GCN's Embeddings on Node Classification ---


--- STAGE 1: Training JKNetMax on Link Prediction ---

--- STAGE 2: Evaluating JKNetMax's Embeddings on Node Classification ---


--- STAGE 1: Training MixHop on Link Prediction ---

--- STAGE 2: Evaluating MixHop's Embeddings on Node Classification ---


--- STAGE 1: Training AltGCNMixHopJKN on Link Prediction ---

--- STAGE 2: Evaluating AltGCNMixHopJKN's Embeddings on Node Classification ---


--- FINAL CROSS-TASK EVALUATION SUMMARY ---

Model: GCN
	Evaluated on Node Classification, Test Accuracy: 0.2380

Model: JKNetMax
	Evaluated on Node Classification, Test Accuracy: 0.3020

Model: MixHop
	Evaluated on Node Classification, Test Accuracy: 0.2860

Model: AltGCNMixHopJKN
	Evaluated on Node Classification, Test Accuracy: 0.2340
