In [None]:
import copy, os
import random

from tqdm import tqdm

import torch
from torch.utils.data import TensorDataset, DataLoader

from torch_geometric.nn import to_hetero
from torch_geometric.transforms import Compose, ToUndirected
from torch_geometric.loader import NeighborLoader, LinkNeighborLoader
from torch_geometric.utils import negative_sampling

# project-specific
import lib
from lib.model import SupervisedNodePredictions, SupervisedEdgePredictions, SupervisedMTL, train_node_readout, make_gae, make_gmae, Readout, test_node_readout, get_x_dict, GraphSAGE, pretrain_gmae, train_edge_readout, test_edge_readout

from lib.dataset import load_data, to_inductive

Load data

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed Python's and torch's RNGs so loader shuffling, negative sampling and weight
# initialisation are reproducible on "Restart & Run All" (torch.manual_seed also seeds CUDA;
# PyG's negative_sampling draws from Python's `random`).
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

root_path = 'OGBN-MAG/'
# merge=False: for relations whose source and destination node types coincide
# (e.g. paper-cites-paper) the reverse edges get their own `rev_*` edge type instead of
# being merged into the original relation.
transform = Compose([ToUndirected(merge=False)])
# NOTE(review): presumably selects metapath2vec embeddings as input features -- confirm in lib.dataset.
preprocess = 'metapath2vec'
data = load_data(root_path, transform=transform, preprocess=preprocess)

# Graph restricted to training papers (node types not listed in the dict are kept in full).
# Only this graph is shown to the unsupervised encoder, so non-training papers stay unseen.
train_data = data.subgraph({
    "paper": data["paper"].train_mask.nonzero(as_tuple=False).view(-1)
})
num_classes = int(data["paper"].y.max()) + 1



Pretrain the unsupervised encoder

In [None]:
# Self-supervised pretraining on `train_data`, i.e. with all non-training papers removed.
# NOTE(review): `pretrain_gmae` presumably trains a graph masked autoencoder and returns its
# encoder module -- confirm in lib.model.
encoder = pretrain_gmae(train_data)

# Persist the pretrained weights so the downstream readout cells can be rerun separately.
encoder_weights = encoder.state_dict()
torch.save(encoder_weights, "gmae_encoder")

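Embed the papers with the pretrained (frozen) encoder. Embeddings of the training subgraph form the training set of the node-classification readout, and embeddings computed on the full graph are used to validate it. The edge-prediction readout further below computes its own embeddings on the full graph.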

In [None]:
def dataset_to_loader(d, mask=None, batch_size=16, shuffle=True):
    """Embed the papers of graph ``d`` with the frozen pretrained encoder and wrap them in a DataLoader.

    Args:
        d: heterogeneous graph; its node features (via ``get_x_dict``) and ``edge_index_dict``
            are fed to the module-level ``encoder``.
        mask: optional boolean mask over ``d["paper"]`` selecting which papers to keep
            (e.g. a validation mask). ``None`` keeps every paper.
        batch_size: DataLoader batch size.
        shuffle: whether the DataLoader reshuffles every epoch.

    Returns:
        DataLoader yielding ``(z_paper, y_paper)`` batches of CPU tensors.
    """
    # Switch to inference mode *before* embedding. Calling eval() in a later cell is too late,
    # because the embeddings are computed here (matters if the encoder has dropout/masking).
    encoder.eval()
    # Node features and edges must both live on the encoder's device.
    x_dict = {k: v.to(device) for k, v in get_x_dict(d).items()}
    edge_index_dict = {k: v.to(device) for k, v in d.edge_index_dict.items()}
    with torch.no_grad():
        z_dict = encoder(x_dict, edge_index_dict)
    z_paper = z_dict["paper"].cpu()
    y_paper = d["paper"].y.cpu()
    if mask is not None:
        mask = mask.cpu()
        z_paper, y_paper = z_paper[mask], y_paper[mask]
    return DataLoader(TensorDataset(z_paper, y_paper), batch_size=batch_size, shuffle=shuffle)


train_loader = dataset_to_loader(train_data)
# Message passing may use the whole graph, but accuracy must only be measured on validation
# papers -- scoring every paper would include the training papers and inflate the metric.
val_loader = dataset_to_loader(data, mask=data["paper"].val_mask, shuffle=False)


Train the node-classification readout on the frozen paper embeddings

In [None]:
# Only the readout is optimised here: the loaders hold paper embeddings that were computed
# under torch.no_grad(), so the encoder is not updated.
readout = Readout(num_classes).to(device)
encoder.eval()

num_readout_epochs = 5
for epoch in range(num_readout_epochs):
    loss = train_node_readout(readout, train_loader)
    acc = test_node_readout(readout, val_loader)
    print(f"Epoch {epoch:02d} | Loss: {loss:.4f} | Val Acc: {acc:.4f}")

readout_weights = readout.state_dict()
torch.save(readout_weights, "gmae_node_classification_readout")


Train the edge readout: predict `paper —has_topic→ field_of_study` links from frozen encoder embeddings

In [None]:
edge_type = ("paper", "has_topic", "field_of_study")
edge_index = data[edge_type].edge_index  # row 0: paper ids, row 1: field_of_study ids
paper_idx = edge_index[0]

# An edge belongs to a split when its source paper does; validation papers act as the test split.
paper_train_mask = data["paper"].train_mask
paper_test_mask = data["paper"].val_mask
train_edge_mask = paper_train_mask[paper_idx]
test_edge_mask = paper_test_mask[paper_idx]
train_edge_index = edge_index[:, train_edge_mask]
# The encoder still sees every edge during message passing (full edge_index_dict below), but
# the *scored* test positives must be held out. Using all edges made training edges double as
# test positives and inflated the test accuracy.
test_edge_index = edge_index[:, test_edge_mask]

# Embed all nodes once with the frozen encoder. Inputs go to the encoder's device (as in the
# node-readout path); embeddings come back to the CPU for the TensorDatasets.
encoder.eval()
x_dict = {k: v.to(device) for k, v in get_x_dict(data).items()}
edge_index_dict = {k: v.to(device) for k, v in data.edge_index_dict.items()}
with torch.no_grad():
    z_dict = encoder(x_dict, edge_index_dict)
z_paper = z_dict["paper"].cpu()
z_fos = z_dict["field_of_study"].cpu()
readout = Readout(1).to(device)  # single out channel -- edge score


def edge_index_to_loader(edge_index, z_paper, z_fos, batch_size=1024, exclude_edge_index=None, shuffle=True):
    """Build a balanced binary link-prediction dataset of ``(z_paper, z_fos, label)`` triples.

    Every column of ``edge_index`` is a positive (label 1). Negatives (label 0) are drawn with
    ``negative_sampling`` over the paper x field_of_study grid, in the same number as positives
    (approximately, see below).

    Args:
        edge_index: ``[2, E]`` positive edges; row 0 indexes ``z_paper``, row 1 indexes ``z_fos``.
        z_paper: paper embeddings, one row per paper.
        z_fos: field_of_study embeddings, one row per field.
        batch_size: DataLoader batch size.
        exclude_edge_index: optional ``[2, E']`` known edges that negatives must also avoid
            (e.g. all graph edges, so test negatives are never real training edges).
            ``None`` only avoids ``edge_index`` itself.
        shuffle: whether the DataLoader reshuffles every epoch.

    Returns:
        DataLoader over ``TensorDataset(z_src, z_dst, y)`` with float32 labels in {0, 1}.
    """
    num_pos = edge_index.size(1)
    known_edges = edge_index if exclude_edge_index is None else torch.cat([edge_index, exclude_edge_index], dim=1)
    neg_edge_index = negative_sampling(
        known_edges,
        num_nodes=(z_paper.size(0), z_fos.size(0)),  # bipartite: (num_src, num_dst)
        num_neg_samples=num_pos,
    )
    # negative_sampling is approximate and may return fewer than num_pos samples; size the
    # labels from what was actually drawn so features and labels stay aligned.
    num_neg = neg_edge_index.size(1)
    z_src = torch.cat([z_paper[edge_index[0]], z_paper[neg_edge_index[0]]], dim=0)
    z_dst = torch.cat([z_fos[edge_index[1]], z_fos[neg_edge_index[1]]], dim=0)
    y = torch.cat([
        torch.ones(num_pos, dtype=torch.float32),
        torch.zeros(num_neg, dtype=torch.float32),
    ], dim=0)
    return DataLoader(TensorDataset(z_src, z_dst, y), batch_size=batch_size, shuffle=shuffle)


print("building edge datasets...")
train_loader = edge_index_to_loader(train_edge_index, z_paper, z_fos)
# Test negatives must not collide with any real edge, training edges included.
test_loader = edge_index_to_loader(test_edge_index, z_paper, z_fos, exclude_edge_index=edge_index, shuffle=False)

print("training edge predictor...")
for epoch in range(2):
    loss = train_edge_readout(readout, train_loader)
    acc = test_edge_readout(readout, test_loader)
    print(f"Epoch {epoch:02d} | Loss: {loss:.4f} | Test Acc: {acc:.4f}")

torch.save(readout.state_dict(), "gmae_edge_prediction_readout")


# Train supervised versions of the network for comparison
## Node classification

In [None]:
target_type = "paper"
# Training uses an inductive copy of the graph (lib.dataset.to_inductive); validation uses the
# full graph. The deepcopy keeps `data` intact even if to_inductive edits its argument.
data_inductive = to_inductive(copy.deepcopy(data), target_type)

train_loader = NeighborLoader(
    data_inductive,
    input_nodes=(target_type, data_inductive[target_type].train_mask),
    num_neighbors=[15, 10],  # neighbours sampled per node, for hop 1 and hop 2
    batch_size=2048,
    shuffle=True,
)
val_loader = NeighborLoader(
    data,
    input_nodes=(target_type, data[target_type].val_mask),
    num_neighbors=[15, 10],
    batch_size=2048,
)

hidden_dim = 128
# Count classes on the full graph. If to_inductive removes nodes, the inductive graph may lack
# the highest label id, leaving the classifier with too few outputs for validation labels.
num_classes = int(data[target_type].y.max()) + 1

model = GraphSAGE(in_channels=hidden_dim, out_channels=hidden_dim, num_classes=num_classes)
model = to_hetero(model, data_inductive.metadata(), aggr='sum')
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.003)
pipeline = SupervisedNodePredictions(
    model=model,
    device=device,
    optimizer=optimizer,
    target_type=target_type
)

best_acc = 0.0
save_path = 'supervised/models/node/'
# torch.save does not create missing directories.
os.makedirs(save_path, exist_ok=True)
for epoch in range(10):
    loss = pipeline.train(train_loader)
    acc = pipeline.test(val_loader)

    print(f"Epoch {epoch:02d} | "
          f"Loss: {loss:.4f} | "
          f"Val Acc: {acc:.4f}")

    torch.save({
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "epoch": epoch
    }, save_path + f"checkpoint_{epoch}.pt")

    if acc > best_acc:
        best_acc = acc
        torch.save(model.state_dict(), save_path + "best.pt")

## Link predictions

In [None]:
target_edge_type = ('paper', 'has_topic', 'field_of_study')

# NOTE(review): the supervision edges given as edge_label_index are also part of each loader's
# message-passing graph (as, presumably, are their `rev_*` counterparts added by ToUndirected),
# so the model may see the very edge it is asked to score. Consider disjoint message-passing /
# supervision edge sets (e.g. RandomLinkSplit with disjoint_train_ratio); confirm what
# to_inductive already removes.
train_loader = LinkNeighborLoader(
    data_inductive,
    num_neighbors=[15, 10],
    edge_label_index=(target_edge_type, data_inductive[target_edge_type].edge_index),
    neg_sampling_ratio=1.0,
    batch_size=2048,
    shuffle=True,
)

# Validation edges: has_topic edges whose source paper is a validation paper.
edge_index_all = data[target_edge_type].edge_index
src_papers = edge_index_all[0]
val_edge_mask = data['paper'].val_mask[src_papers]
val_edge_index = edge_index_all[:, val_edge_mask]

val_loader = LinkNeighborLoader(
    data,
    num_neighbors=[15, 10],
    edge_label_index=(target_edge_type, val_edge_index),
    neg_sampling_ratio=1.0,
    batch_size=2048,
    shuffle=False,
)

hidden_dim = 128

model = GraphSAGE(in_channels=hidden_dim, out_channels=hidden_dim)
model = to_hetero(model, data_inductive.metadata(), aggr='sum')
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.003)
pipeline = SupervisedEdgePredictions(
    model=model,
    device=device,
    optimizer=optimizer,
    target_edge_type=target_edge_type
)

best_auc = 0.0
save_path = "supervised/models/edge/"
os.makedirs(save_path, exist_ok=True)

max_steps = 10_000
eval_every = 500
save_every = 500

# Without at least one batch per epoch, `step` never advances and the while-loop never ends.
if len(train_loader) == 0:
    raise ValueError("train_loader yields no batches; the step loop would never terminate")

step = 0
epoch = 0

# NOTE(review): confirm pipeline.train_on_batch re-enables train mode after pipeline.test.
pbar = tqdm(total=max_steps)
try:
    while step < max_steps:
        # pbar.write instead of print keeps the progress bar from being broken up.
        pbar.write(f"=== Epoch {epoch} ===")
        for batch in train_loader:
            step += 1
            loss = pipeline.train_on_batch(batch)
            # Count the step before any early exit so the bar actually reaches max_steps.
            pbar.update(1)

            if step % eval_every == 0:
                metrics = pipeline.test(val_loader)
                auc = metrics["AUC"]
                pbar.write(f"AUC: {auc:.4f}")

                if auc > best_auc:
                    best_auc = auc
                    torch.save(model.state_dict(), save_path + "best.pt")

            if step % save_every == 0:
                torch.save({
                    "model_state": model.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                    "step": step,
                    "epoch": epoch,
                }, save_path + f"checkpoint_{step}.pt")

            if step >= max_steps:
                break
        epoch += 1
finally:
    pbar.close()

## Multi-Task Learning

In [None]:
target_node_type = "paper"
# Defined here (same values as the earlier cells) so this cell does not depend on leaked state.
# `data_inductive` still comes from the node-classification cell.
target_edge_type = ('paper', 'has_topic', 'field_of_study')
hidden_dim = 128
num_classes = int(data[target_node_type].y.max()) + 1

# NOTE(review): as in the link-prediction cell, supervision edges are also in the
# message-passing graph; confirm this leakage is acceptable for the comparison.
train_loader = LinkNeighborLoader(
    data_inductive,
    num_neighbors=[15, 10],
    edge_label_index=(target_edge_type, data_inductive[target_edge_type].edge_index),
    neg_sampling_ratio=1.0,
    batch_size=2048,
    shuffle=True
)

# Validation edges: has_topic edges whose source paper is a validation paper.
edge_index_all = data[target_edge_type].edge_index
src_papers = edge_index_all[0]
val_edge_mask = data['paper'].val_mask[src_papers]
val_edge_index = edge_index_all[:, val_edge_mask]

val_loader = LinkNeighborLoader(
    data,
    num_neighbors=[15, 10],
    edge_label_index=(target_edge_type, val_edge_index),
    neg_sampling_ratio=1.0,
    batch_size=2048,
    shuffle=False
)

model = GraphSAGE(in_channels=hidden_dim, out_channels=hidden_dim, num_classes=num_classes)
model = to_hetero(model, data_inductive.metadata(), aggr='sum')
model = model.to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.003)

pipeline = SupervisedMTL(
    model=model,
    device=device,
    optimizer=optimizer,
    target_node_type=target_node_type,
    target_edge_type=target_edge_type
)

# Model selection uses the combined validation loss reported by pipeline.test.
best_loss = float("inf")
save_path = 'supervised/models/mtl/'
os.makedirs(save_path, exist_ok=True)

max_steps = 10_000
eval_every = 500
save_every = 500

# Without at least one batch per epoch, `step` never advances and the while-loop never ends.
if len(train_loader) == 0:
    raise ValueError("train_loader yields no batches; the step loop would never terminate")

step = 0
epoch = 0

pbar = tqdm(total=max_steps)
try:
    while step < max_steps:
        # pbar.write instead of print keeps the progress bar from being broken up.
        pbar.write(f"=== Epoch {epoch} ===")
        for batch in train_loader:
            step += 1
            pipeline.train_on_batch(batch)
            # Count the step before any early exit so the bar actually reaches max_steps.
            pbar.update(1)

            if step % eval_every == 0:
                metrics = pipeline.test(val_loader)
                loss = metrics["loss_total"]

                pbar.write(
                    f"AUC: {metrics['AUC']:.4f} | "
                    f"Accuracy: {metrics['Accuracy']:.4f} | "
                    f"Edge loss: {metrics['loss_edge']:.4f} | "
                    f"Node loss: {metrics['loss_node']:.4f} | "
                    f"Total Loss: {loss:.4f}"
                )

                if loss < best_loss:
                    best_loss = loss
                    torch.save(model.state_dict(), save_path + "best.pt")

            if step % save_every == 0:
                torch.save({
                    "model_state": model.state_dict(),
                    "optimizer_state": optimizer.state_dict(),
                    "step": step,
                    "epoch": epoch,
                }, save_path + f"checkpoint_{step}.pt")

            if step >= max_steps:
                break
        epoch += 1
finally:
    pbar.close()