In [None]:
import os
from typing import Tuple

import torch
import torch.nn as nn
import torch.utils.data as data
import torch_geometric.nn as tgnn
from datasets import Dataset
from tqdm import tqdm

In [None]:
# Seed torch's global RNG so that weight initialisation and DataLoader shuffling
# are repeatable under "Restart Kernel -> Run All".
seed = 42
torch.manual_seed(seed)

# Optimisation hyperparameters.
lr = 0.005            # initial learning rate handed to Adam
weight_decay = 0.001  # initial weight decay handed to Adam
batch_size = 2048     # number of labelled edges per mini-batch
epochs = 40           # number of passes over the labelled training edges

## 0. load data
### 0.1 raw data

In [None]:
# Directory handed to `Dataset.load_from_disk`. Set the TOD_DATA_DIR environment
# variable to point at your own copy instead of editing the notebook; the default
# is the original author's local path and will not exist on other machines.
data_dir = os.environ.get(
    'TOD_DATA_DIR',
    '/Users/xiaoen/Documents/科研/论文/链接预测/TOD-Code/data/traindataset',
)
dataset = Dataset.load_from_disk(data_dir)

In [None]:
# Read the 'data' column once; its first entry is used as the training graph
# and its second entry as the test graph.
data_column = dataset['data']
train_dataset, test_dataset = data_column[0], data_column[1]

### 0.2 data features

In [None]:
# Unpack the training record into the five pieces consumed further down:
# node features, graph edges, the edges to score, their extra features,
# and the targets fed to the loss.
train_x, train_edge_index, train_label_edge_index, train_edge_feature, train_label = (
    train_dataset[field]
    for field in ('x_feature', 'edge_index', 'label_edge', 'label_edge_feature', 'label')
)

In [None]:
# Unpack the test record with the same field layout as the training record:
# node features, graph edges, the edges to score, their extra features,
# and the targets.
test_x, test_edge_index, test_label_edge_index, test_edge_feature, test_label = (
    test_dataset[field]
    for field in ('x_feature', 'edge_index', 'label_edge', 'label_edge_feature', 'label')
)

## 1. Definitions
### 1.1 GAT model

In [None]:
class GAT(nn.Module):
    """Two-layer GATv2 encoder that turns node features into node embeddings.

    Args:
        dim_in: Size of each input node feature vector.
        hidden_size: Per-head output size of the first attention layer.
        output_size: Size of the embeddings returned by ``forward``.
        dropout: Probability used for the dropout inside both GATv2 layers and
            for the feature dropout applied after each layer.
        heads: Number of attention heads in the first layer. The second layer
            is sized to accept ``hidden_size * heads`` input channels.
    """

    def __init__(
            self,
            dim_in: int,
            hidden_size: int,
            output_size: int,
            dropout: float = 0.1,
            heads: int = 2
    ):
        super().__init__()
        self.dim_in = dim_in
        self.dropout = dropout
        self.heads = heads

        self.conv1 = tgnn.GATv2Conv(
            in_channels=dim_in,
            out_channels=hidden_size,
            heads=heads,
            dropout=self.dropout,
            residual=True
        )

        # The input width must track the head count of conv1, so it is derived
        # from `heads` rather than repeated as a literal.
        self.conv2 = tgnn.GATv2Conv(
            in_channels=hidden_size * heads,
            out_channels=output_size,
            heads=1,
            dropout=self.dropout,
            residual=True
        )

        # Exposed so downstream modules can size their input layers.
        self.dim_out = output_size

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Run both attention layers, applying feature dropout after each one.

        Args:
            x: Node feature tensor passed to the first GATv2 layer.
            edge_index: Graph connectivity passed to both GATv2 layers.

        Returns:
            The output of the second layer after dropout.
        """
        h = self.conv1(x, edge_index)
        h = nn.functional.dropout(h, p=self.dropout, training=self.training)

        h = self.conv2(h, edge_index)
        h = nn.functional.dropout(h, p=self.dropout, training=self.training)

        return h

### 1.2 Link Prediction model

In [None]:
class LinkPredModel(nn.Module):
    """Link predictor: a GAT encoder followed by a small MLP over edge features.

    For every candidate edge the embeddings of its two endpoints are multiplied
    element-wise, concatenated with the edge's own feature vector, and passed
    through a batch-normalised MLP that emits one logit per edge.

    Args:
        dim_in: Size of each input node feature vector.
        gnn_hidden_size: ``hidden_size`` forwarded to the GAT encoder.
        gnn_output_size: ``output_size`` forwarded to the GAT encoder.
        hidden_dims: Widths of the two MLP hidden layers. Only the first two
            entries are used.
        dropout: Dropout probability for both the encoder and the MLP.
        edge_feature_dim: Width of the per-edge feature vector that is
            concatenated to the endpoint interaction (defaults to 16, the value
            that used to be hard-coded).
    """

    def __init__(
            self,
            dim_in: int,
            gnn_hidden_size: int,
            gnn_output_size: int,
            hidden_dims: Tuple[int, ...] = (16, 4),
            dropout: float = 0.1,
            edge_feature_dim: int = 16,
    ):
        super().__init__()

        self.gnn = GAT(
            dim_in=dim_in,
            hidden_size=gnn_hidden_size,
            output_size=gnn_output_size,
            dropout=dropout
        )

        # MLP input = endpoint interaction (encoder width) + raw edge features.
        self.input_dim = self.gnn.dim_out + edge_feature_dim
        self.dropout = dropout

        # MLP layers
        self.in_layer = nn.Linear(self.input_dim, hidden_dims[0])
        self.hidden_layer = nn.Linear(hidden_dims[0], hidden_dims[1])
        self.out_layer = nn.Linear(hidden_dims[1], 1)

        # Activation and batch normalisation
        self.lrelu = nn.LeakyReLU(0.01)
        self.bn0 = nn.BatchNorm1d(self.input_dim)
        self.bn1 = nn.BatchNorm1d(hidden_dims[0])
        self.bn2 = nn.BatchNorm1d(hidden_dims[1])

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor, label_edge_index: torch.Tensor, label_edge_feature: torch.Tensor) -> torch.Tensor:
        """Score every candidate edge in ``label_edge_index``.

        Args:
            x: Node features for the whole graph.
            edge_index: Connectivity used for message passing in the encoder.
            label_edge_index: 2-row index tensor; row 0 holds the source node
                of each candidate edge and row 1 the destination node.
            label_edge_feature: One feature row per candidate edge; its width
                must equal ``edge_feature_dim``.

        Returns:
            Raw logits with a trailing dimension of size 1 (one per candidate
            edge); no sigmoid is applied here.
        """
        h = self.gnn(x, edge_index)

        # Look up the embeddings of both endpoints of each candidate edge.
        h_src = h[label_edge_index[0, :]]
        h_dst = h[label_edge_index[1, :]]

        # Element-wise product as the endpoint interaction feature.
        src_dst_mult = h_src * h_dst

        all_features = torch.cat([src_dst_mult, label_edge_feature], dim=1)

        # Normalise the concatenated inputs before the first linear layer.
        _out = self.bn0(all_features)

        _out = self.in_layer(_out)
        _out = self.bn1(_out)
        _out = self.lrelu(_out)
        _out = nn.functional.dropout(_out, p=self.dropout, training=self.training)

        _out = self.hidden_layer(_out)
        _out = self.bn2(_out)
        _out = self.lrelu(_out)
        _out = nn.functional.dropout(_out, p=self.dropout, training=self.training)

        _out = self.out_layer(_out)

        return _out

## 2. Preparing components
### 2.0 GPU

In [None]:
# Prefer Apple's Metal backend (the original target), then CUDA, and fall back
# to the CPU so the notebook still runs on machines without an accelerator.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

### 2.1 model

In [None]:
# Build the link predictor first, then move its parameters to the chosen device.
LP = LinkPredModel(dim_in=100, gnn_hidden_size=16, gnn_output_size=10, dropout=0.1)
LP = LP.to(device)

### 2.2 loss

In [None]:
# Binary cross-entropy computed directly on raw logits, averaged over the batch.
loss_fn = nn.BCEWithLogitsLoss(reduction="mean")

### 2.3 loader

In [None]:
class GraphDataset(data.Dataset):
    """Map-style dataset over labelled candidate edges.

    Item ``i`` is the triple ``(edge, edge_feature, label)`` for the ``i``-th
    candidate edge, where ``edge`` holds that edge's two node indices.
    """

    def __init__(self, _label_edge_index, _label_edge_feature, _label):
        # The edge list is stored transposed, so column ``i`` holds edge ``i``.
        edge_pairs = torch.tensor(_label_edge_index, dtype=torch.int64)
        self.label_edge_index = edge_pairs.T
        self.label_edge_feature = torch.tensor(_label_edge_feature, dtype=torch.float)
        self.label = torch.tensor(_label, dtype=torch.float)

    def __len__(self):
        # Number of edges equals the length of either row of the index tensor.
        first_row = self.label_edge_index[0]
        return len(first_row)

    def __getitem__(self, idx):
        edge = self.label_edge_index[:, idx]
        feature = self.label_edge_feature[idx]
        target = self.label[idx]
        return edge, feature, target

In [None]:
def get_loader(_x, _edge_index, _label_edge_index, _label_edge_feature, _label,_batch_size=batch_size, _shuffle=True):
    """Convert one graph split to tensors and wrap its labelled edges in a DataLoader.

    Args:
        _x: Node features, converted to a float tensor.
        _edge_index: Graph edges, converted to an int64 tensor and transposed.
        _label_edge_index: Candidate edges to score, forwarded to ``GraphDataset``.
        _label_edge_feature: Per-candidate-edge features, forwarded to ``GraphDataset``.
        _label: Per-candidate-edge targets, forwarded to ``GraphDataset``.
        _batch_size: Number of candidate edges per mini-batch.
        _shuffle: Whether the loader reshuffles the candidate edges every epoch.
            Defaults to ``True`` (the previous fixed behaviour); pass ``False``
            for evaluation when predictions must stay aligned with the input order.

    Returns:
        A tuple ``(x, edge_index, loader)``: the two whole-graph tensors and a
        DataLoader that yields ``(edge, edge_feature, label)`` batches.
    """
    _x = torch.tensor(_x, dtype=torch.float)
    _edge_index = torch.tensor(_edge_index, dtype=torch.int64).T
    _dataset = GraphDataset(_label_edge_index, _label_edge_feature, _label)
    graph_loader = data.DataLoader(_dataset, batch_size=_batch_size, shuffle=_shuffle)
    return _x, _edge_index, graph_loader

## 3. train

In [None]:
# Switch the model to training mode so that dropout is active and the
# BatchNorm layers normalise with per-batch statistics.
LP.train()

In [None]:
# Turn the training split into whole-graph tensors plus a DataLoader over its
# labelled edges.
d_x, d_edge_index, loader = get_loader(
    _x=train_x,
    _edge_index=train_edge_index,
    _label_edge_index=train_label_edge_index,
    _label_edge_feature=train_edge_feature,
    _label=train_label,
)

In [None]:
# Make this cell self-contained: guarantee training mode even if the model was
# switched to eval mode before the cell is re-run.
LP.train()

# The node features and graph structure are the same for every mini-batch, so
# transfer them to the device once instead of on every optimisation step.
x_on_device = d_x.to(device)
edge_index_on_device = d_edge_index.to(device)

# Work on local copies of the schedule values so the configured `lr` and
# `weight_decay` are not overwritten; re-running this cell then starts from the
# configured values rather than from already-decayed ones.
current_lr = lr
current_weight_decay = weight_decay

# Create the optimizer once. Re-creating Adam at the start of every epoch would
# discard its running moment estimates.
optimizer = torch.optim.Adam(LP.parameters(), lr=current_lr, weight_decay=current_weight_decay)

for epoch in range(epochs):
    # Halve the learning rate and the weight decay every 10 epochs (not at
    # epoch 0), updating the existing optimizer in place.
    if epoch != 0 and epoch % 10 == 0:
        current_lr = current_lr * 0.5
        current_weight_decay = current_weight_decay * 0.5
        for param_group in optimizer.param_groups:
            param_group['lr'] = current_lr
            param_group['weight_decay'] = current_weight_decay

    train_bar = tqdm(loader)
    running_loss = []
    for d_label_edge_index, d_label_edge_feature, d_label in train_bar:
        optimizer.zero_grad()
        # The loader stacks per-edge index pairs row-wise; transpose so that
        # row 0 holds sources and row 1 holds destinations, as the model expects.
        out = LP(
            x_on_device,
            edge_index_on_device,
            d_label_edge_index.t().to(device),
            d_label_edge_feature.to(device),
        )
        # Drop only the trailing logit dimension so the batch axis always
        # matches the label tensor.
        loss = loss_fn(out.squeeze(-1), d_label.float().to(device))
        running_loss.append(loss.item())
        loss.backward()
        optimizer.step()
        # Show the running mean of the batch losses for the current epoch.
        train_bar.set_description(f'Epoch {epoch}, loss: {sum(running_loss)/len(running_loss):.4f}')