https://medium.com/@pytorch_geometric/link-prediction-on-heterogeneous-graphs-with-pyg-6d5c29677c70

这里，我们尝试，不要初始特征，让初始特征变成embedding。

In [1]:
import os.path as osp
import torch
from torch_geometric.utils import negative_sampling
from torch_geometric.datasets import Planetoid
import torch_geometric.transforms as T
from torch_geometric.utils import train_test_split_edges
import torch.nn.functional as F
from sklearn.metrics import roc_auc_score

In [2]:
import pandas as pd
import numpy as np

In [3]:
df_ori = pd.read_csv("sample.csv").head(10000)
mapping = {
    ci: idx for idx, ci in enumerate(sorted(list(set(df_ori.company_id.to_list() + df_ori.outcompany_id.to_list()))))
}
for col in df_ori:
    df_ori[col] = df_ori[col].map(mapping)

In [4]:
from torch_geometric.data import Data

data = Data(
    num_nodes = len(mapping),
    edge_index=torch.tensor(
        df_ori.T.to_numpy(), 
        dtype = torch.long
    )
)

In [5]:
data.is_directed()

True

In [6]:
n_nodes = data.num_nodes
hidden_channels = 64

In [7]:
transform = T.RandomLinkSplit(
    num_val=0.1,
    num_test=0.1,
    disjoint_train_ratio=0.3,
    neg_sampling_ratio=2.0,
    add_negative_train_samples=False,    
)
train_data, val_data, test_data = transform(data)

In [8]:
from torch_geometric.loader import LinkNeighborLoader
train_loader = LinkNeighborLoader(
    data=train_data,
    num_neighbors=[-1, -1],
    neg_sampling_ratio=2.0,
    batch_size=128,
    shuffle=True,
)

In [9]:
from torch_geometric.nn import SAGEConv, GCNConv, GATConv
import torch.nn.functional as F
class GNN(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = GCNConv(in_channels, 128)
        self.conv2 = GCNConv(128, out_channels)
    def forward(self, x, edge_index) :
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return x

class Classifier(torch.nn.Module):
    def forward(self, x_from, x_to,):
        return (x_from * x_to).sum(dim=-1)

class Model(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.emb = torch.nn.Embedding(n_nodes, in_channels)
        self.gnn = GNN(in_channels, out_channels)
        self.classifier = Classifier()
    def forward(self, data):
        x_out = self.gnn(
            self.emb(data.n_id), 
            data.edge_label_index
        )
        pred = self.classifier(
            x_out[data.edge_label_index[0]], ## 边的起始点。
            x_out[data.edge_label_index[-1]] ## 边的终结点。
        )
        return pred
        
model = Model(in_channels=hidden_channels, out_channels=64)

weights = model.emb.weight.detach().numpy()
pd.DataFrame(weights, columns = [f"col_{i}" for i in range(weights.shape[1])])

Unnamed: 0,col_0,col_1,col_2,col_3,col_4,col_5,col_6,col_7,col_8,col_9,...,col_54,col_55,col_56,col_57,col_58,col_59,col_60,col_61,col_62,col_63
0,0.673093,0.084448,-0.912523,-0.302799,-0.300071,0.148199,-0.028175,0.053041,-0.229686,1.127988,...,0.389781,-0.856533,-0.449655,-0.391108,-0.030861,-0.618855,-0.510815,-2.121732,-1.414026,-0.864564
1,-0.574482,-0.730405,-1.120228,-0.373633,0.562438,0.208483,1.358662,0.381027,-1.304625,-0.547326,...,0.581155,-1.152089,0.408686,1.157466,-0.394275,-0.680969,-0.432629,-1.179956,0.032631,0.759917
2,0.076137,0.144809,-0.573109,-0.006288,1.116953,-0.340451,0.098864,-0.990979,-0.678256,1.452438,...,-0.943300,-0.206855,-0.152801,-0.351507,-0.295181,2.062279,0.487851,1.243069,0.213492,0.315391
3,0.324575,0.281586,-0.484736,0.201174,1.713745,-0.227687,0.478808,0.529208,-0.407379,1.539424,...,0.758681,-0.591228,1.039135,1.514455,-0.248104,1.121391,-0.796028,-1.197440,0.133543,-1.358330
4,1.282399,-1.414037,-0.428685,-1.069791,-1.809425,0.252699,-0.454228,0.073880,-1.262417,0.529732,...,-0.370654,0.499203,0.049004,0.070215,-0.888624,-1.435283,-0.642115,-0.910152,1.684550,-0.439363
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
15805,0.379300,-0.222644,0.508146,0.337009,0.059634,-0.379873,0.820893,-0.447036,1.869698,2.545470,...,0.081788,2.005803,0.533785,1.160504,1.374256,1.726806,1.054099,1.478007,-0.359950,1.342865
15806,0.442348,1.286383,-0.647498,-0.411736,-0.126161,-0.816695,1.205098,0.301594,-1.502892,0.319701,...,-1.506700,0.357390,0.290817,-1.124178,0.556477,-0.579499,-0.711793,-0.069983,0.243077,-0.096684
15807,0.464442,-1.507516,-0.167446,1.531604,-0.052680,1.639775,-0.034988,0.807626,-0.684235,0.160630,...,0.157994,-1.391207,0.472000,-1.011287,-2.118191,0.999438,-0.361703,-0.741744,-0.797778,0.090970
15808,-0.236571,0.081072,1.797068,1.227335,0.070513,0.748204,-0.406728,-0.753957,2.189181,0.600518,...,0.736173,-0.734787,-0.141054,-0.682156,-2.210488,0.407994,-1.195266,-0.274191,-0.078927,0.015276


In [10]:
# !pip install torch-scatter torch-sparse -f https://data.pyg.org/whl/torch-2.1.0+cu111.html
# import torch_geometric
# torch_geometric.typing.WITH_TORCH_SPARSE

In [11]:
import tqdm
import torch.nn.functional as F
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: '{device}'")
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
for epoch in range(1, 10):
    total_loss = total_examples = 0
    for sampled_data in tqdm.tqdm(train_loader):        
        optimizer.zero_grad()
        sampled_data.to(device)
        pred = model(sampled_data)
        ground_truth = sampled_data.edge_label
        loss = F.binary_cross_entropy_with_logits(pred, ground_truth)
        loss.backward()
        optimizer.step()
        total_loss += float(loss) * pred.numel()
        total_examples += pred.numel()
    
#     break
    
    print(f"Epoch: {epoch:03d}, Loss: {total_loss / total_examples:.4f}")
    

Device: 'cpu'


100%|██████████████████████████████████████████| 44/44 [00:00<00:00, 102.11it/s]


Epoch: 001, Loss: 8.0700


100%|██████████████████████████████████████████| 44/44 [00:00<00:00, 113.10it/s]


Epoch: 002, Loss: 2.4143


100%|██████████████████████████████████████████| 44/44 [00:00<00:00, 106.35it/s]


Epoch: 003, Loss: 1.4118


100%|██████████████████████████████████████████| 44/44 [00:00<00:00, 119.72it/s]


Epoch: 004, Loss: 1.1143


100%|██████████████████████████████████████████| 44/44 [00:00<00:00, 119.66it/s]


Epoch: 005, Loss: 0.9752


100%|██████████████████████████████████████████| 44/44 [00:00<00:00, 115.49it/s]


Epoch: 006, Loss: 0.8981


100%|██████████████████████████████████████████| 44/44 [00:00<00:00, 114.37it/s]


Epoch: 007, Loss: 0.8461


100%|██████████████████████████████████████████| 44/44 [00:00<00:00, 119.29it/s]


Epoch: 008, Loss: 0.8138


100%|██████████████████████████████████████████| 44/44 [00:00<00:00, 115.31it/s]

Epoch: 009, Loss: 0.7880





In [12]:
# !pip install torch-scatter torch-sparse torch-cluster torch-spline-conv torch-geometric

In [13]:
test_loader = LinkNeighborLoader(
    data=val_data,
    num_neighbors=[10, 5],
    neg_sampling_ratio=2.0,
    batch_size=128,
    shuffle=True,
)

In [14]:
from sklearn.metrics import roc_auc_score
preds = []
ground_truths = []
for sampled_data in tqdm.tqdm(test_loader):
    with torch.no_grad():
        sampled_data.to(device)
        preds.append(model(sampled_data))
        ground_truths.append(sampled_data.edge_label)
pred = torch.cat(preds, dim=0).cpu().numpy()
ground_truth = torch.cat(ground_truths, dim=0).cpu().numpy()
auc = roc_auc_score(ground_truth, pred)
print()
print(f"Validation AUC: {auc:.4f}")

100%|██████████████████████████████████████████| 63/63 [00:00<00:00, 286.98it/s]


Validation AUC: 0.5604





In [15]:
model.emb

Embedding(15810, 64)

In [16]:
weights = model.emb.weight.detach().numpy()
pd.DataFrame(weights, columns = [f"col_{i}" for i in range(weights.shape[1])])

Unnamed: 0,col_0,col_1,col_2,col_3,col_4,col_5,col_6,col_7,col_8,col_9,...,col_54,col_55,col_56,col_57,col_58,col_59,col_60,col_61,col_62,col_63
0,0.669065,0.077524,-0.920855,-0.296226,-0.308748,0.141634,-0.039478,0.058258,-0.227015,1.121350,...,0.386672,-0.846345,-0.444276,-0.387542,-0.021972,-0.622189,-0.528684,-2.114470,-1.407983,-0.875418
1,-0.568167,-0.745433,-1.132718,-0.362444,0.577328,0.219588,1.362991,0.384137,-1.301211,-0.551781,...,0.571873,-1.143167,0.395442,1.108930,-0.391104,-0.658486,-0.430655,-1.178692,0.044567,0.753442
2,0.081950,0.135864,-0.585260,0.004492,1.116560,-0.346169,0.094464,-0.982604,-0.709313,1.446680,...,-0.933353,-0.199456,-0.158422,-0.338428,-0.307981,2.049353,0.476412,1.242506,0.227790,0.304634
3,0.300844,0.289991,-0.469499,0.176004,1.711455,-0.227131,0.482339,0.546135,-0.389559,1.548811,...,0.757717,-0.582577,1.016637,1.515538,-0.243395,1.137132,-0.793814,-1.195644,0.138800,-1.340938
4,1.297667,-1.411987,-0.440514,-1.057410,-1.804011,0.261665,-0.455122,0.069703,-1.262537,0.542551,...,-0.366746,0.493233,0.063764,0.059564,-0.875817,-1.428237,-0.638704,-0.912958,1.682740,-0.427236
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
15805,0.384092,-0.241503,0.525978,0.367077,0.048377,-0.368888,0.824629,-0.423710,1.851531,2.530833,...,0.067745,1.978428,0.563053,1.178492,1.357648,1.709931,1.043315,1.479352,-0.354581,1.350115
15806,0.453145,1.276647,-0.671884,-0.393321,-0.115286,-0.802233,1.186930,0.316034,-1.516867,0.303651,...,-1.495722,0.344023,0.272564,-1.114990,0.556404,-0.564846,-0.701774,-0.043099,0.268811,-0.103387
15807,0.481912,-1.498320,-0.182975,1.523987,-0.049871,1.635256,-0.031260,0.800568,-0.678790,0.184470,...,0.161116,-1.374936,0.477708,-1.027672,-2.113470,0.986944,-0.354879,-0.741725,-0.784854,0.108101
15808,-0.215172,0.061211,1.794581,1.211783,0.033802,0.735371,-0.399514,-0.724306,2.170383,0.595437,...,0.753187,-0.733203,-0.104428,-0.662712,-2.195806,0.437034,-1.165709,-0.238060,-0.112751,0.029646
