In [4]:
import torch
import torch.nn.functional as F
from torch import nn
from torch_geometric.nn import SAGEConv
from torch_geometric.datasets import Planetoid
from torch_geometric.utils import negative_sampling
from sklearn.metrics import roc_auc_score, average_precision_score
from torch_geometric.transforms import RandomLinkSplit
from torch_geometric.data import Data
from utils.reorganized_preprocessing import get_edges_and_indices
import pandas as pd
# Name of the dataset directory under data/raw/ that the loading cell reads from.
dataset = 'icews18'

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')


In [5]:
def create_x(train_data, val_data, test_data, num_node_features=512,
             index_padding=3, seed=None):
    """Attach one shared random node-feature matrix to all three splits.

    A single ``torch.randn`` matrix is created and assigned, as the very same
    tensor object, to ``.x`` of the train, validation and test data objects.
    The splits are modified in place and also returned.

    Args:
        train_data: Object with an ``edge_index`` tensor; its largest node
            index determines how many feature rows are allocated.
        val_data: Validation split; receives the same ``x``.
        test_data: Test split; receives the same ``x``.
        num_node_features: Width of the random feature vectors.
        index_padding: Number added to the largest node index in
            ``train_data.edge_index`` to obtain the row count. The default of
            3 keeps the original behaviour.
            NOTE(review): presumably +1 for the index-to-count conversion plus
            slack for 1-indexed ids or nodes absent from the training edges --
            confirm against ``get_edges_and_indices``.
        seed: Optional integer. When given, features are drawn from a
            dedicated ``torch.Generator`` seeded with it, so repeated calls
            yield the same matrix and the global RNG state is left untouched.
            When ``None`` the global RNG is used, as before.

    Returns:
        The tuple ``(train_data, val_data, test_data)``.
    """
    generator = None
    if seed is not None:
        generator = torch.Generator().manual_seed(seed)

    # Only the training edges are inspected when sizing the matrix.
    num_rows = train_data.edge_index.max().item() + index_padding
    x = torch.randn(num_rows, num_node_features, generator=generator)

    for split in (train_data, val_data, test_data):
        split.x = x
        split.num_node_features = num_node_features

    return train_data, val_data, test_data
    

In [6]:
# Load the three relation tables for the selected dataset. Column names are
# supplied explicitly via `names`.
# NOTE(review): the variable and column names (user / artist / friend / tag)
# do not match the actor_* files being read; they look carried over from a
# different dataset. They are kept because the split-building cell passes
# `user_artist` on.
user_artist = pd.read_csv(
    f'../../../../data/raw/{dataset}/1-indexed/actor_actor.csv',
    encoding='utf-8',
    names=['userID','artistID', 'weight'],
)
user_friend = pd.read_csv(
    f'../../../../data/raw/{dataset}/1-indexed/actor_action.csv',
    encoding='utf-8',
    names=['userID', 'friendID'],
)
artist_tag = pd.read_csv(
    f'../../../../data/raw/{dataset}/1-indexed/actor_sector.csv',
    encoding='utf-8',
    names=['artistID', 'tagID'],
)

# Largest id seen in each column (the directory name suggests 1-indexed ids,
# in which case these are also counts -- TODO confirm).
num_user = user_artist['userID'].max()
num_artist = user_artist['artistID'].max()
num_tag = artist_tag['tagID'].max()

In [7]:
# Build the train / validation / test objects (plus their index sets) from the
# actor-actor table.
# NOTE(review): the meaning of remove_fraction=1.0 is defined in
# utils.reorganized_preprocessing -- confirm it is the intended setting.
(
    train_data, val_data, test_data,
    train_idx, val_idx, test_idx,
) = get_edges_and_indices(user_artist, remove_fraction=1.0)

# Give all three splits the same random node-feature matrix.
train_data, val_data, test_data = create_x(train_data, val_data, test_data)



In [8]:
# Show the training split's summary (edge counts, label tensors, feature shape)
# as a sanity check before building the model.
train_data

Data(edge_index=[2, 57814], pos_edge_label=[16224], pos_edge_label_index=[2, 16224], neg_edge_label=[28907], neg_edge_label_index=[2, 28907], x=[6225, 512], num_node_features=512)

In [9]:
class GraphSAGEEncoder(nn.Module):
    """Two stacked ``SAGEConv`` layers that turn node features into embeddings.

    Args:
        in_channels: Size of each input node-feature vector.
        hidden_channels: Output size of both convolution layers.
    """

    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        self.conv1 = SAGEConv(in_channels, hidden_channels)
        self.conv2 = SAGEConv(hidden_channels, hidden_channels)

    def forward(self, x, edge_index):
        """Return node embeddings for features ``x`` over graph ``edge_index``.

        A ReLU follows the first convolution; the output of the second
        convolution is returned without an activation.
        """
        hidden = self.conv1(x, edge_index)
        hidden = hidden.relu()
        return self.conv2(hidden, edge_index)


class LinkPredictor(nn.Module):
    """MLP that scores a candidate edge from the embeddings of its endpoints.

    Args:
        hidden_channels: Size of the incoming node embeddings and of the
            hidden layer.
        dropout: Dropout probability applied after the hidden ReLU.
    """

    def __init__(self, hidden_channels, dropout=0.5):
        super().__init__()
        layers = [
            nn.Linear(hidden_channels, hidden_channels),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_channels, 1),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x_i, x_j):
        """Return a flat tensor of edge probabilities in ``(0, 1)``.

        The two endpoint embeddings are combined by element-wise product, fed
        through the MLP, squashed with a sigmoid and flattened to 1-D.
        """
        pair_features = x_i * x_j
        logits = self.mlp(pair_features)
        return logits.sigmoid().view(-1)

In [10]:
@torch.no_grad()
def test(model, predictor, data):
    """Evaluate link prediction on one split and return ``(auc, ap)``.

    Args:
        model: Encoder called as ``model(data.x, data.edge_index)`` to produce
            node embeddings.
        predictor: Module called as ``predictor(z_src, z_dst)`` that returns
            one score per candidate edge.
        data: Split providing ``x``, ``edge_index``, ``pos_edge_label_index``
            and ``neg_edge_label_index``.

    Returns:
        Tuple ``(auc, ap)``: ROC-AUC and average precision of the scores over
        the positive and negative candidate edges.
    """
    model.eval()
    predictor.eval()

    z = model(data.x, data.edge_index)

    pos_edge = data.pos_edge_label_index
    neg_edge = data.neg_edge_label_index

    # Positives first, negatives second; the label vector below relies on
    # this ordering.
    edge = torch.cat([pos_edge, neg_edge], dim=1)
    scores = predictor(z[edge[0]], z[edge[1]]).cpu().numpy()

    # Labels are only consumed by scikit-learn, so they are built directly on
    # the CPU instead of being sent to the module-level device and back.
    labels = torch.cat([
        torch.ones(pos_edge.size(1)),
        torch.zeros(neg_edge.size(1)),
    ]).numpy()

    auc = roc_auc_score(labels, scores)
    ap = average_precision_score(labels, scores)
    return auc, ap


def train(model, predictor, train_data, optimizer):
    """Run one full-batch optimisation step and return the loss as a float.

    Args:
        model: Encoder called as ``model(train_data.x, train_data.edge_index)``.
        predictor: Module called as ``predictor(z_src, z_dst)`` that returns
            one probability per candidate edge.
        train_data: Split providing ``x``, ``edge_index``, ``num_nodes`` and
            ``pos_edge_label_index``.
        optimizer: Optimizer holding the parameters of ``model`` and
            ``predictor``.

    Returns:
        The binary cross-entropy loss of this step as a Python float.
    """
    model.train()
    predictor.train()

    # Clear gradients before the forward pass so that anything accumulated
    # before this call cannot leak into the update.
    optimizer.zero_grad()

    z = model(train_data.x, train_data.edge_index)

    pos_edge = train_data.pos_edge_label_index
    # Fresh negatives every call, requested in the same number as positives.
    neg_edge = negative_sampling(
        edge_index=train_data.edge_index,
        num_nodes=train_data.num_nodes,
        num_neg_samples=pos_edge.size(1)
    )

    edge = torch.cat([pos_edge, neg_edge], dim=1)
    pred = predictor(z[edge[0]], z[edge[1]])

    # Build the labels on the device (and dtype) of the predictions rather
    # than on a module-level global, so the loss works wherever the data is.
    # Sized from the tensors actually obtained, not from the requested count.
    labels = torch.cat([
        torch.ones(pos_edge.size(1), dtype=pred.dtype, device=pred.device),
        torch.zeros(neg_edge.size(1), dtype=pred.dtype, device=pred.device),
    ])

    loss = F.binary_cross_entropy(pred, labels)
    loss.backward()
    optimizer.step()
    return loss.item()




In [11]:
# Encoder and edge scorer share an embedding width of 64. The encoder is built
# first so parameter initialisation consumes the RNG in the same order.
model = GraphSAGEEncoder(train_data.num_node_features, 64).to(device)
predictor = LinkPredictor(64).to(device)

# A single Adam optimizer updates both modules jointly.
optimizer = torch.optim.Adam(
    [*model.parameters(), *predictor.parameters()],
    lr=0.01,
)


In [12]:
# Move every split to the target device once, instead of repeating the
# (loop-invariant) transfer of the training split on each epoch.
train_data = train_data.to(device)
val_data = val_data.to(device)
test_data = test_data.to(device)

# Train for 50 epochs, reporting validation metrics every 10th epoch.
for epoch in range(1, 51):
    loss = train(model, predictor, train_data, optimizer)
    if epoch % 10 == 0:
        val_auc, val_ap = test(model, predictor, val_data)
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val AUC: {val_auc:.4f}, AP: {val_ap:.4f}')

# Final test evaluation
test_auc, test_ap = test(model, predictor, test_data)
print(f'Test AUC: {test_auc:.4f}, Test AP: {test_ap:.4f}')

Epoch: 010, Loss: 0.4376, Val AUC: 0.9109, AP: 0.9229
Epoch: 020, Loss: 0.2100, Val AUC: 0.9270, AP: 0.9424
Epoch: 030, Loss: 0.1095, Val AUC: 0.9343, AP: 0.9496
Epoch: 040, Loss: 0.0658, Val AUC: 0.9379, AP: 0.9525
Epoch: 050, Loss: 0.0516, Val AUC: 0.9403, AP: 0.9548
Test AUC: 0.9437, Test AP: 0.9578
