In [65]:
import argparse
import torch
import torch.nn as nn
from sklearn.metrics import roc_auc_score, average_precision_score
import torch_geometric.transforms as T
from torch_geometric.utils import negative_sampling
from torch_geometric.nn import RGCNConv
import pandas as pd
import numpy as np
import tqdm
import torch
import torch.nn.functional as F
import tqdm
import gc
import warnings
warnings.filterwarnings('ignore')

# Run a full garbage-collection pass. The number echoed under the cell is the
# return value of gc.collect(), i.e. the count of unreachable objects it found.
gc.collect()

1807

In [55]:
# Load the serialized heterogeneous graph (the cell output shows a HeteroData
# with `customer` and `product` nodes) from disk.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code -- only load graph files that come from a trusted source.
hetero_data = torch.load('../data/model/hetero_graph_data.pt')
# Bare expression so the notebook displays the graph summary.
hetero_data

HeteroData(
  customer={
    x=[161086, 12],
    index=[161086],
  },
  product={
    x=[2708, 15],
    index=[2708],
  },
  (customer, order, product)={ edge_index=[2, 11892915] },
  (product, class, product)={ edge_index=[2, 1796620] },
  (product, rev_class, product)={ edge_index=[2, 1796620] },
  (product, rev_order, customer)={ edge_index=[2, 11892915] }
)

In [56]:
# Split the links into train / validation / test sets (20% validation, 20% test).
# The transform is applied to the *homogeneous* conversion of the graph, so each
# split is a single `Data` object carrying `node_type` / `edge_type` tensors
# (see the printed summaries below the cell).
# add_negative_train_samples=False: no negative edges are attached to the
# training split here; negative_sample() draws them separately.
# NOTE(review): `edge_types` / `rev_edge_types` name relations of the
# heterogeneous graph, but the input of this call is homogeneous -- confirm
# that these two arguments have any effect here.
# NOTE(review): with is_undirected=True on a graph that stores explicit reverse
# relations, verify that the `edge_type` ids of the message-passing edges in
# each split are still the intended ones.
train_data, val_data, test_data = T.RandomLinkSplit(
        num_val=0.2,
        num_test=0.2,
        is_undirected=True,
        add_negative_train_samples=False,
        disjoint_train_ratio=0,
        edge_types=[('customer', 'order', 'product'),
                    ('product', 'class', 'product')],
        rev_edge_types=[('product', 'rev_order', 'customer'), 
                        ('product', 'rev_class', 'product')]
    )(hetero_data.to_homogeneous())

# Show the tensor shapes of every split.
print(train_data)
print(val_data)
print(test_data)

Data(edge_index=[2, 16427442], x=[163794, 15], index=[163794], node_type=[163794], edge_type=[16427442], edge_label=[8213721], edge_label_index=[2, 8213721])
Data(edge_index=[2, 16427442], x=[163794, 15], index=[163794], node_type=[163794], edge_type=[16427442], edge_label=[5475814], edge_label_index=[2, 5475814])
Data(edge_index=[2, 21903256], x=[163794, 15], index=[163794], node_type=[163794], edge_type=[21903256], edge_label=[5475814], edge_label_index=[2, 5475814])


In [57]:
def negative_sample(data=None):
    """Build the supervision edges and labels for one training step.

    Draws negative (non-existent) edges, requesting as many as there are
    positive supervision edges, and appends them to the positives.

    Args:
        data: Graph split to sample from. It must provide ``edge_index``,
            ``num_nodes``, ``edge_label`` and ``edge_label_index``. When
            omitted, the module-level ``train_data`` is used, so the original
            zero-argument call keeps working while ``negative_sample(split)``
            is accepted as well.

    Returns:
        tuple: ``(edge_label, edge_label_index)``. The positive entries of
        ``data`` come first, followed by the sampled negatives; the labels of
        the negatives are zeros with the dtype/device of ``data.edge_label``.
    """
    if data is None:
        data = train_data

    # Request one negative edge per positive supervision edge of the split.
    neg_edge_index = negative_sampling(
        edge_index=data.edge_index, num_nodes=data.num_nodes,
        num_neg_samples=data.edge_label_index.size(1), method='sparse')

    edge_label_index = torch.cat(
        [data.edge_label_index, neg_edge_index],
        dim=-1,
    )
    # The label vector is sized from the edges actually returned by the
    # sampler, so labels and indices always stay aligned.
    edge_label = torch.cat([
        data.edge_label,
        data.edge_label.new_zeros(neg_edge_index.size(1))
    ], dim=0)

    return edge_label, edge_label_index

In [58]:
# Parameters
device = 'cuda' if torch.cuda.is_available() else 'cpu'
node_types = hetero_data.node_types
num_relations = len(hetero_data.edge_types)
# Raw feature width of every node type, in the same order as `node_types`.
# (The original loop read hetero_data['customer'] on every iteration, so all
# node types were given the customer feature width.)
init_sizes = [hetero_data[node_type].x.shape[-1] for node_type in node_types]
# NOTE(review): the three assignments below had no right-hand side in the
# original cell (a SyntaxError). `out_channels` mirrors the literal 128 passed
# to RGCN_LP in train(); `in_feats` and `hidden_feats` are placeholder values
# that must be confirmed or tuned.
in_feats = 128
hidden_feats = 64
out_channels = 128

In [None]:
class RGCN_LP(nn.Module):
    """Two-layer R-GCN encoder with a small edge scorer for link prediction.

    Each node type first gets its own linear projection into a shared
    ``in_channels``-wide space, two ``RGCNConv`` layers then produce node
    embeddings, and a candidate edge is scored from the concatenation of its
    two endpoint embeddings (linear layer followed by a sigmoid).

    The class reads the module-level ``num_relations``, ``node_types`` and
    ``init_sizes`` variables defined in the parameter cell.

    Args:
        in_channels: Width of the shared space the raw features are projected to.
        hidden_channels: Output width of the first R-GCN layer.
        out_channels: Width of the final node embeddings.
    """

    def __init__(self, in_channels, hidden_channels, out_channels):
        super(RGCN_LP, self).__init__()
        self.conv1 = RGCNConv(in_channels, hidden_channels,
                              num_relations=num_relations, num_bases=30)
        self.conv2 = RGCNConv(hidden_channels, out_channels,
                              num_relations=num_relations, num_bases=30)
        # One input projection per node type, in the order of `node_types`.
        self.lins = torch.nn.ModuleList(
            [nn.Linear(init_sizes[i], in_channels)
             for i in range(len(node_types))]
        )

        # Edge scorer: [src embedding | dst embedding] -> probability.
        self.fc = nn.Sequential(
            nn.Linear(2 * out_channels, 1),
            nn.Sigmoid()
        )

    def trans_dimensions(self, data):
        """Project the raw node features of every type to ``in_channels``.

        Args:
            data: Either a heterogeneous graph indexed by node-type name
                (``data[node_type].x``) or a homogeneous graph that carries a
                per-node ``node_type`` id tensor next to ``x``.

        Returns:
            Tensor with one projected feature row per node.
        """
        node_type = getattr(data, 'node_type', None)
        if node_type is None:
            # Heterogeneous input: project each type and stack the results in
            # the order of `node_types`.
            res = []
            for name, lin in zip(node_types, self.lins):
                res.append(lin(data[name].x))
            return torch.cat(res, dim=0)

        # Homogeneous input (what train()/test() pass in): `x` holds all node
        # types in one matrix, so `data[node_type]` is not available.
        # Assumes to_homogeneous() numbers the types in the order of
        # `node_types` and right-pads narrower feature matrices, so the first
        # `lin.in_features` columns are the real features -- TODO confirm for
        # the installed torch_geometric version.
        x = data.x
        out = x.new_zeros(x.size(0), self.lins[0].out_features)
        for type_id, lin in enumerate(self.lins):
            mask = node_type == type_id
            out[mask] = lin(x[mask, :lin.in_features])
        return out

    def encode(self, data):
        """Compute node embeddings with the two R-GCN layers."""
        x = self.trans_dimensions(data)  # ! the effect of this projection still needs checking
        edge_index, edge_type = data.edge_index, data.edge_type
        x = self.conv1(x, edge_index, edge_type)
        x = F.relu(x)
        # x = F.dropout(x, p=0.5, training=self.training)
        x = self.conv2(x, edge_index, edge_type)
        return x

    def decode(self, z, edge_label_index):
        """Score candidate edges from node embeddings.

        Args:
            z: Node embedding matrix, indexed by node id.
            edge_label_index: Two-row index tensor; row 0 holds source node
                ids and row 1 destination node ids.

        Returns:
            Tensor of sigmoid scores, one row (of size 1) per candidate edge.
        """
        src = z[edge_label_index[0]]
        dst = z[edge_label_index[1]]
        x = torch.cat([src, dst], dim=-1)
        x = self.fc(x)
        return x

    def forward(self, data, edge_label_index):
        """Encode the graph, then score the edges in ``edge_label_index``."""
        z = self.encode(data)
        z = self.decode(z, edge_label_index)
        return z

In [None]:
def get_metrics(out, edge_label=None):
    """Compute ROC-AUC and average precision for predicted edge scores.

    Args:
        out: Predicted scores. Either a 1-D tensor with one score per edge
            (what the model's flattened output looks like) or a 2-D tensor
            whose column 1 holds the positive-class score (the layout the
            original implementation indexed).
        edge_label: Ground-truth labels aligned with ``out``. When omitted,
            the module-level ``test_data.edge_label`` is used, which keeps the
            original one-argument call working.

    Returns:
        tuple: ``(auc, ap)``.
    """
    if edge_label is None:
        edge_label = test_data.edge_label

    # Accept both score layouts; indexing `[:, 1]` on 1-D scores would raise.
    scores = out[:, 1] if out.dim() > 1 else out

    y_true = edge_label.detach().cpu().numpy()
    y_score = scores.detach().cpu().numpy()
    auc = roc_auc_score(y_true, y_score)
    ap = average_precision_score(y_true, y_score)
    return auc, ap

In [None]:
def train():
    """Train RGCN_LP for 100 epochs and report the selected test metrics.

    Every epoch draws fresh negative training edges, takes one optimizer
    step, and then evaluates on the validation and test splits via test().
    After the first ``min_epochs`` epochs, the test AUC/AP are recorded
    whenever the validation loss reaches a new minimum.

    Returns:
        tuple: ``(final_test_auc, final_test_ap)`` -- the test metrics at the
        epoch with the lowest validation loss (both 0 if no epoch qualified).
    """
    model = RGCN_LP(in_feats, hidden_feats, 128).to(device)
    optimizer = torch.optim.Adam(params=model.parameters(), lr=0.01)
    criterion = torch.nn.BCELoss().to(device)
    min_epochs = 10
    min_val_loss = float('inf')  # np.Inf no longer exists in NumPy >= 2.0
    final_test_auc = 0
    final_test_ap = 0

    # The model lives on `device`, so the graph splits have to be there too.
    train_graph = train_data.to(device)
    val_graph = val_data.to(device)
    test_graph = test_data.to(device)

    # The notebook does `import tqdm`, which binds the module (not callable).
    # Resolve the progress-bar callable whether `tqdm` names the module or
    # the function itself.
    progress = getattr(tqdm, 'tqdm', tqdm)

    model.train()
    for epoch in progress(range(100)):
        optimizer.zero_grad()
        # Called without arguments, negative_sample() reads the module-level
        # train_data; move its result to `device` explicitly.
        edge_label, edge_label_index = negative_sample()
        edge_label = edge_label.to(device)
        edge_label_index = edge_label_index.to(device)
        out = model(train_graph, edge_label_index).view(-1)
        loss = criterion(out, edge_label)
        loss.backward()
        optimizer.step()
        # validation
        val_loss, test_auc, test_ap = test(model, val_graph, test_graph)
        val_loss = float(val_loss)  # plain number for comparison/formatting
        if epoch + 1 > min_epochs and val_loss < min_val_loss:
            min_val_loss = val_loss
            final_test_auc = test_auc
            final_test_ap = test_ap

        print('epoch {:03d} train_loss {:.8f} val_loss {:.4f} test_auc {:.4f} test_ap {:.4f}'
              .format(epoch, loss.item(), val_loss, test_auc, test_ap))

    return final_test_auc, final_test_ap

@torch.no_grad()
def test(model, val_data, test_data):
    """Evaluate ``model`` on the validation and test splits without gradients.

    Args:
        model: Network called as ``model(data, edge_label_index)``.
        val_data: Validation split; provides ``edge_label_index`` and
            ``edge_label`` for the loss.
        test_data: Test split; provides ``edge_label_index`` and
            ``edge_label`` for the ranking metrics.

    Returns:
        tuple: ``(val_loss, auc, ap)`` -- the BCE loss on the validation edges
        and the AUC / average precision returned by get_metrics() for the
        test edges.
    """
    model.eval()

    # Validation loss on the labelled validation edges.
    bce = torch.nn.BCELoss().to(device)
    val_scores = model(val_data, val_data.edge_label_index).view(-1)
    val_loss = bce(val_scores, val_data.edge_label)

    # Scores for the test edges; training mode is restored right afterwards,
    # before the metrics are computed.
    test_scores = model(test_data, test_data.edge_label_index).view(-1)
    model.train()

    auc, ap = get_metrics(test_scores, test_data.edge_label)
    return val_loss, auc, ap