In [1]:
import random
import argparse
import numpy as np
import networkx as nx

import pickle

import torch
import torch.nn.functional as F
import torch_sparse
from torch import Tensor
from torch.nn import Linear
from torch.utils.data import DataLoader
from torch_geometric.utils import negative_sampling, to_networkx
from typing import Union, Tuple
from torch_geometric.typing import OptPairTensor, Adj, OptTensor, Size
from torch_sparse import SparseTensor, matmul
from torch_geometric.nn.conv import MessagePassing

from ogb.linkproppred import PygLinkPropPredDataset, Evaluator

In [2]:
# Load the ogbl-ddi dataset
# dataset_name = 'ogbl-ddi'
dataset = PygLinkPropPredDataset(name='ogbl-ddi')
# dataset = PygLinkPropPredDataset(name='ogbl-ddi',
#                                      transform=T.ToSparseTensor())
print(f'The ogbl-ddi dataset has {len(dataset)} graph(s).')

The ogbl-ddi dataset has 1 graph(s).


In [3]:
with open('data/node_features_PCA.pickle', 'rb') as handle:
    node_features = pickle.load(handle)
with open('data/train_pos.pickle', 'rb') as handle:
    train_pos = pickle.load(handle)
with open('data/valid_pos.pickle', 'rb') as handle:
    valid_pos = pickle.load(handle)
with open('data/valid_neg.pickle', 'rb') as handle:
    valid_neg = pickle.load(handle)
with open('data/test_pos.pickle', 'rb') as handle:
    test_pos = pickle.load(handle)
with open('data/test_neg.pickle', 'rb') as handle:
    test_neg = pickle.load(handle)
with open('data/nodes.pickle', 'rb') as handle:
    nodes = pickle.load(handle)
with open('data/nodes_removed.pickle', 'rb') as handle:
    nodes_removed = pickle.load(handle)

In [4]:
print(f'Number of training positive edges: {train_pos.shape[0]}')
print(f'Number of validation positive edges: {valid_pos.shape[0]}')
print(f'Number of validation negative edges: {valid_neg.shape[0]}')
print(f'Number of test positive edges: {test_pos.shape[0]}')
print(f'Number of test negative edges: {test_neg.shape[0]}')

Number of training positive edges: 887234
Number of validation positive edges: 111117
Number of validation negative edges: 74182
Number of test positive edges: 114869
Number of test negative edges: 70962


In [5]:
nodes = torch.tensor(nodes)

In [6]:
data = dataset[0].subgraph(nodes)

In [7]:
def get_spd_matrix(G, S, max_spd=5):
    spd_matrix = np.zeros((G.number_of_nodes(), len(S)), dtype=np.float32)
    for i, node_S in enumerate(S):
        for node, length in nx.shortest_path_length(G, source=node_S).items():
            spd_matrix[node, i] = min(length, max_spd)
    return spd_matrix

In [8]:
class Logger(object):
    def __init__(self, runs, info=None):
        self.info = info
        self.results = [[] for _ in range(runs)]

    def add_result(self, run, result):
        assert len(result) == 2
        assert run >= 0 and run < len(self.results)
        self.results[run].append(result)

    def print_statistics(self, run=None):
        if run is not None:
            result = 100 * torch.tensor(self.results[run])
            argmax = result[:, 0].argmax().item()
            print(f'Run {run + 1:02d}:')
            print(f'Highest Valid: {result[:, 0].max():.2f}')
            print(f'   Final Test: {result[argmax, 1]:.2f}')
        else:
            result = 100 * torch.tensor(self.results)
            best_results = []
            for r in result:
                valid = r[:, 0].max().item()
                test = r[r[:, 0].argmax(), 1].item()
                best_results.append((valid, test))
            best_result = torch.tensor(best_results)
            print(f'All runs:')
            r = best_result[:, 0]
            print(f'Highest Valid: {r.mean():.4f} ± {r.std():.4f}')
            r = best_result[:, 1]
            print(f'   Final Test: {r.mean():.4f} ± {r.std():.4f}')


class SAGEConv(MessagePassing):
    r"""The GraphSAGE operator from the `"Inductive Representation Learning on
    Large Graphs" <https://arxiv.org/abs/1706.02216>`_ paper

    .. math::
        \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + \mathbf{W}_2 \cdot
        \mathrm{mean}_{j \in \mathcal{N(i)}} \mathbf{x}_j

    Args:
        in_channels (int or tuple): Size of each input sample. A tuple
            corresponds to the sizes of source and target dimensionalities.
        out_channels (int): Size of each output sample.
        normalize (bool, optional): If set to :obj:`True`, output features
            will be :math:`\ell_2`-normalized, *i.e.*,
            :math:`\frac{\mathbf{x}^{\prime}_i}
            {\| \mathbf{x}^{\prime}_i \|_2}`.
            (default: :obj:`False`)
        root_weight (bool, optional): If set to :obj:`False`, the layer will
            not add transformed root node features to the output.
            (default: :obj:`True`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """
    def __init__(self, in_channels: Union[int, Tuple[int, int]],
                 out_channels: int, normalize: bool = False,
                 root_weight: bool = True,
                 bias: bool = True, **kwargs):  # yapf: disable
        kwargs.setdefault('aggr', 'mean')
        super(SAGEConv, self).__init__(**kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalize = normalize
        self.root_weight = root_weight

        if isinstance(in_channels, int):
            in_channels = (in_channels, in_channels)

        self.lin_l = Linear(in_channels[0], out_channels, bias=bias)
        if self.root_weight:
            self.lin_r = Linear(in_channels[1], out_channels, bias=False)

        self.reset_parameters()

    def reset_parameters(self):
        self.lin_l.reset_parameters()
        if self.root_weight:
            self.lin_r.reset_parameters()


    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
                edge_attr: OptTensor = None, size: Size = None) -> Tensor:
        """"""
        if isinstance(x, Tensor):
            x: OptPairTensor = (x, x)

        # Node and edge feature dimensionalites need to match.
        if isinstance(edge_index, Tensor):
            assert edge_attr is not None
            assert x[0].size(-1) == edge_attr.size(-1)
        elif isinstance(edge_index, SparseTensor):
            assert x[0].size(-1) == edge_index.size(-1)

        # propagate_type: (x: OptPairTensor, edge_attr: OptTensor)
        out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=size)
        out = self.lin_l(out)

        x_r = x[1]
        if self.root_weight and x_r is not None:
            out += self.lin_r(x_r)

        if self.normalize:
            out = F.normalize(out, p=2., dim=-1)

        return out


    def message(self, x_j: Tensor, edge_attr: Tensor) -> Tensor:
        return F.relu(x_j + edge_attr)


    def __repr__(self):
        return '{}({}, {})'.format(self.__class__.__name__, self.in_channels,
                                   self.out_channels)


class GraphSAGE(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers, dropout):
        super(GraphSAGE,self).__init__()
        self.convs = torch.nn.ModuleList()
        self.convs.append(SAGEConv(in_channels, hidden_channels))
        for _ in range(num_layers - 2):
            self.convs.append(SAGEConv(hidden_channels, hidden_channels))
        self.convs.append(SAGEConv(hidden_channels, out_channels))
        self.dropout = dropout

    def reset_parameters(self):
        for conv in self.convs:
            conv.reset_parameters()

    def forward(self, x, adj_t, edge_attr, emb_ea):
        edge_attr = torch.mm(edge_attr, emb_ea)
        for conv in self.convs[:-1]:
            x = conv(x, adj_t, edge_attr)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.convs[-1](x, adj_t, edge_attr)
        return x


class LinkPredictor(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels, num_layers,
                 dropout):
        super(LinkPredictor, self).__init__()

        self.lins = torch.nn.ModuleList()
        self.lins.append(torch.nn.Linear(in_channels, hidden_channels))
        for _ in range(num_layers - 2):
            self.lins.append(torch.nn.Linear(hidden_channels, hidden_channels))
        self.lins.append(torch.nn.Linear(hidden_channels, out_channels))

        self.dropout = dropout

    def reset_parameters(self):
        for lin in self.lins:
            lin.reset_parameters()

    def forward(self, x_i, x_j):
        x = x_i * x_j
        for lin in self.lins[:-1]:
            x = lin(x)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lins[-1](x)
        return torch.sigmoid(x)


In [9]:
def train(model, predictor, edge_attr, x, emb_ea, adj_t, train_edge, optimizer, batch_size):
    edge_index = adj_t

    model.train()
    predictor.train()

    pos_train_edge = train_edge.to(x.device)

    total_loss = total_examples = 0
    for perm in DataLoader(range(pos_train_edge.size(0)), batch_size, shuffle=True):
        optimizer.zero_grad()

        h = model(x, adj_t, edge_attr, emb_ea)

        edge = pos_train_edge[perm].t()

        pos_out = predictor(h[edge[0]], h[edge[1]])
        pos_loss = -torch.log(pos_out + 1e-15).mean()

        edge = negative_sampling(edge_index, num_nodes=x.size(0),
                                 num_neg_samples=perm.size(0), method='dense')

        neg_out = predictor(h[edge[0]], h[edge[1]])
        neg_loss = -torch.log(1 - neg_out + 1e-15).mean()

        loss = pos_loss + neg_loss
        loss.backward()

        torch.nn.utils.clip_grad_norm_(x, 1.0)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        torch.nn.utils.clip_grad_norm_(predictor.parameters(), 1.0)

        optimizer.step()

        num_examples = pos_out.size(0)
        total_loss += loss.item() * num_examples
        total_examples += num_examples

    return total_loss / total_examples


In [10]:
all_pos_valid_pred = []
all_neg_valid_pred = []
all_pos_test_pred = []
all_neg_test_pred = []
@torch.no_grad()
def test(model, predictor, edge_attr, x, emb_ea, adj_t, pos_valid, neg_valid, pos_test, neg_test, evaluator, batch_size):
    model.eval()
    predictor.eval()

    h = model(x, adj_t, edge_attr, emb_ea)

    pos_valid_edge = pos_valid.to(x.device)
    neg_valid_edge = neg_valid.to(x.device)
    pos_test_edge = pos_test.to(x.device)
    neg_test_edge = neg_test.to(x.device)

    pos_valid_preds = []
    for perm in DataLoader(range(pos_valid_edge.size(0)), batch_size):
        edge = pos_valid_edge[perm].t()
        pos_valid_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
    pos_valid_pred = torch.cat(pos_valid_preds, dim=0)

    neg_valid_preds = []
    for perm in DataLoader(range(neg_valid_edge.size(0)), batch_size):
        edge = neg_valid_edge[perm].t()
        neg_valid_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
    neg_valid_pred = torch.cat(neg_valid_preds, dim=0)

    pos_test_preds = []
    for perm in DataLoader(range(pos_test_edge.size(0)), batch_size):
        edge = pos_test_edge[perm].t()
        pos_test_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
    pos_test_pred = torch.cat(pos_test_preds, dim=0)

    neg_test_preds = []
    for perm in DataLoader(range(neg_test_edge.size(0)), batch_size):
        edge = neg_test_edge[perm].t()
        neg_test_preds += [predictor(h[edge[0]], h[edge[1]]).squeeze().cpu()]
    neg_test_pred = torch.cat(neg_test_preds, dim=0)

    results = {}
    for K in [20, 50, 100]:
        evaluator.K = K
        valid_hits = evaluator.eval({
            'y_pred_pos': pos_valid_pred,
            'y_pred_neg': neg_valid_pred,
        })[f'hits@{K}']
        test_hits = evaluator.eval({
            'y_pred_pos': pos_test_pred,
            'y_pred_neg': neg_test_pred,
        })[f'hits@{K}']
        
        if float(valid_hits) > 0.3 and float(test_hits) > 0.3:
            all_pos_valid_pred.append(pos_valid_pred)
            all_neg_valid_pred.append(neg_valid_pred)
            all_pos_test_pred.append(pos_test_pred)
            all_neg_test_pred.append(neg_test_pred)
        


        results[f'Hits@{K}'] = (valid_hits, test_hits)

    return results

In [11]:
node_features.size()

torch.Size([3539, 96])

In [13]:
device = torch.device('cpu')
num_layers = 3
num_samples = 5
node_emb = 96
hidden_channels = 96
dropout = 0.2
batch_size = 64 * 1024
lr = 0.003
epochs = 400
log_steps = 1
eval_steps = 1
runs = 1

In [14]:
edge_index = data.edge_index.to(device)

In [15]:
model = GraphSAGE(node_emb, hidden_channels, hidden_channels,
                      num_layers, dropout).to(device)
emb = torch.nn.Embedding.from_pretrained(node_features, freeze=False).to(device)
emb_ea = torch.nn.Embedding(num_samples, node_emb).to(device)
predictor = LinkPredictor(hidden_channels, hidden_channels, 1,
                              num_layers+1, dropout).to(device)

In [24]:
# model = GraphSAGE(node_emb, hidden_channels, hidden_channels, 
#                       num_layers, dropout).to(device)
# emb = torch.nn.Embedding(data.num_nodes, node_emb).to(device)
# emb_ea = torch.nn.Embedding(num_samples, node_emb).to(device)
# predictor = LinkPredictor(hidden_channels, hidden_channels, 1,
#                               num_layers+1, dropout).to(device)

In [16]:
emb.weight

Parameter containing:
tensor([[-7.4891e-02,  1.4025e-02,  1.2281e+00,  ..., -3.5495e-04,
         -8.6586e-04,  5.4261e-04],
        [ 3.2602e-01, -4.1421e-01,  9.6334e-01,  ...,  8.1558e-04,
         -5.5911e-04,  8.3106e-05],
        [ 3.2602e-01, -4.1421e-01,  9.6334e-01,  ...,  8.1558e-04,
         -5.5911e-04,  8.3106e-05],
        ...,
        [ 2.9628e-01,  4.8572e-01,  9.0430e-01,  ...,  3.3874e-05,
          1.8812e-03,  4.0076e-03],
        [ 1.8145e+00, -9.2635e-02,  2.6241e-01,  ...,  3.2484e-04,
         -7.3765e-04,  7.6157e-04],
        [ 4.0041e-01, -5.6188e-01,  1.1202e+00,  ...,  2.0947e-03,
         -9.5269e-04,  1.2110e-03]], requires_grad=True)

In [17]:
print('Number of parameters:',
          sum(p.numel() for p in list(model.parameters()) +
          list(predictor.parameters()) + list(emb.parameters()) + list(emb_ea.parameters())))

Number of parameters: 423841


In [18]:
np.random.seed(0)
nx_graph = to_networkx(data, to_undirected=True)
node_mask = []
for _ in range(num_samples):
    node_mask.append(np.random.choice(500, size=200, replace=False))
node_mask = np.array(node_mask)
node_subset = np.random.choice(nx_graph.number_of_nodes(), size=500, replace=False)
spd = get_spd_matrix(G=nx_graph, S=node_subset, max_spd=5)
spd = torch.Tensor(spd).to(device)

In [19]:
edge_attr = spd[edge_index, :].mean(0)[:, node_mask].mean(2)

In [20]:
a_max = torch.max(edge_attr, dim=0, keepdim=True)[0]
a_min = torch.min(edge_attr, dim=0, keepdim=True)[0]
edge_attr = (edge_attr - a_min) / (a_max - a_min + 1e-6)

In [21]:
import argparse

In [22]:
parser = argparse.ArgumentParser(description='Link_Pred_DDI')
parser.add_argument('--device', type=int, default=0)
parser.add_argument('--num_layers', type=int, default=3)
parser.add_argument('--num_samples', type=int, default=5)
parser.add_argument('--node_emb', type=int, default=96)
parser.add_argument('--hidden_channels', type=int, default=96)
parser.add_argument('--dropout', type=float, default=0.2)
parser.add_argument('--batch_size', type=int, default=64 * 1024)
parser.add_argument('--lr', type=float, default=0.003)
parser.add_argument('--epochs', type=int, default=200)
parser.add_argument('--log_steps', type=int, default=1)
parser.add_argument('--eval_steps', type=int, default=1)
parser.add_argument('--runs', type=int, default=1)
parser.add_argument("-f", "--file", required=False)
args = parser.parse_args()
print(args)

Namespace(device=0, num_layers=3, num_samples=5, node_emb=96, hidden_channels=96, dropout=0.2, batch_size=65536, lr=0.003, epochs=200, log_steps=1, eval_steps=1, runs=1, file='C:\\Users\\AGM2\\AppData\\Roaming\\jupyter\\runtime\\kernel-e5cec5c6-faa2-4b69-ad5b-af2760c0f904.json')


In [24]:
evaluator = Evaluator(name='ogbl-ddi')
loggers = {
    'Hits@20': Logger(runs, args),
    'Hits@50': Logger(args.runs, args),
    'Hits@100': Logger(args.runs, args),
}

In [25]:
for run in range(runs):
    random.seed(run)
    torch.manual_seed(run)
#     torch.nn.init.xavier_uniform_(emb.weight)
#     torch.nn.init.xavier_uniform_(emb_ea.weight)
    model.reset_parameters()
    predictor.reset_parameters()
    optimizer = torch.optim.Adam(
        list(model.parameters()) + list(emb.parameters()) +
        list(emb_ea.parameters()) + list(predictor.parameters()), lr=lr)

    for epoch in range(1, 1 + epochs):
        loss = train(model, predictor, edge_attr, emb.weight, emb_ea.weight, edge_index, train_pos,
                      optimizer, batch_size)

        if epoch % eval_steps == 0:
            results = test(model, predictor, edge_attr, emb.weight, emb_ea.weight, edge_index, valid_pos, valid_neg, test_pos, test_neg,
                            evaluator, batch_size)
            for key, result in results.items():
                loggers[key].add_result(run, result)

            if epoch % log_steps == 0:
                for key, result in results.items():
                    valid_hits, test_hits = result
                    print(key)
                    print(f'Run: {run + 1:02d}, '
                          f'Epoch: {epoch:02d}, '
                          f'Loss: {loss:.4f}, '
                          f'Valid: {100 * valid_hits:.2f}%, '
                          f'Test: {100 * test_hits:.2f}%')
                print('---')

    for key in loggers.keys():
        print(key)
        loggers[key].print_statistics(run)

for key in loggers.keys():
    print(key)
    loggers[key].print_statistics()

Hits@20
Run: 01, Epoch: 01, Loss: 1.3696, Valid: 2.07%, Test: 1.99%
Hits@50
Run: 01, Epoch: 01, Loss: 1.3696, Valid: 3.36%, Test: 2.87%
Hits@100
Run: 01, Epoch: 01, Loss: 1.3696, Valid: 5.10%, Test: 4.68%
---
Hits@20
Run: 01, Epoch: 02, Loss: 1.4234, Valid: 2.31%, Test: 2.00%
Hits@50
Run: 01, Epoch: 02, Loss: 1.4234, Valid: 3.86%, Test: 3.55%
Hits@100
Run: 01, Epoch: 02, Loss: 1.4234, Valid: 5.73%, Test: 5.31%
---
Hits@20
Run: 01, Epoch: 03, Loss: 1.0864, Valid: 0.23%, Test: 0.39%
Hits@50
Run: 01, Epoch: 03, Loss: 1.0864, Valid: 0.76%, Test: 0.92%
Hits@100
Run: 01, Epoch: 03, Loss: 1.0864, Valid: 1.87%, Test: 2.23%
---
Hits@20
Run: 01, Epoch: 04, Loss: 0.8934, Valid: 1.67%, Test: 3.25%
Hits@50
Run: 01, Epoch: 04, Loss: 0.8934, Valid: 3.07%, Test: 5.12%
Hits@100
Run: 01, Epoch: 04, Loss: 0.8934, Valid: 5.64%, Test: 7.63%
---
Hits@20
Run: 01, Epoch: 05, Loss: 0.8402, Valid: 2.86%, Test: 4.92%
Hits@50
Run: 01, Epoch: 05, Loss: 0.8402, Valid: 4.60%, Test: 7.86%
Hits@100
Run: 01, Epoch: 05,

Hits@20
Run: 01, Epoch: 40, Loss: 0.4053, Valid: 21.91%, Test: 18.59%
Hits@50
Run: 01, Epoch: 40, Loss: 0.4053, Valid: 32.74%, Test: 28.73%
Hits@100
Run: 01, Epoch: 40, Loss: 0.4053, Valid: 44.63%, Test: 37.48%
---
Hits@20
Run: 01, Epoch: 41, Loss: 0.4016, Valid: 28.08%, Test: 21.30%
Hits@50
Run: 01, Epoch: 41, Loss: 0.4016, Valid: 37.82%, Test: 31.37%
Hits@100
Run: 01, Epoch: 41, Loss: 0.4016, Valid: 47.78%, Test: 42.11%
---
Hits@20
Run: 01, Epoch: 42, Loss: 0.3996, Valid: 25.21%, Test: 20.98%
Hits@50
Run: 01, Epoch: 42, Loss: 0.3996, Valid: 36.12%, Test: 31.09%
Hits@100
Run: 01, Epoch: 42, Loss: 0.3996, Valid: 47.34%, Test: 38.67%
---
Hits@20
Run: 01, Epoch: 43, Loss: 0.3939, Valid: 23.81%, Test: 17.41%
Hits@50
Run: 01, Epoch: 43, Loss: 0.3939, Valid: 33.37%, Test: 30.18%
Hits@100
Run: 01, Epoch: 43, Loss: 0.3939, Valid: 44.52%, Test: 40.60%
---
Hits@20
Run: 01, Epoch: 44, Loss: 0.3909, Valid: 23.18%, Test: 18.27%
Hits@50
Run: 01, Epoch: 44, Loss: 0.3909, Valid: 31.14%, Test: 28.90%


Hits@20
Run: 01, Epoch: 79, Loss: 0.3215, Valid: 32.20%, Test: 12.96%
Hits@50
Run: 01, Epoch: 79, Loss: 0.3215, Valid: 42.93%, Test: 18.77%
Hits@100
Run: 01, Epoch: 79, Loss: 0.3215, Valid: 52.19%, Test: 37.00%
---
Hits@20
Run: 01, Epoch: 80, Loss: 0.3212, Valid: 33.33%, Test: 10.50%
Hits@50
Run: 01, Epoch: 80, Loss: 0.3212, Valid: 43.63%, Test: 17.32%
Hits@100
Run: 01, Epoch: 80, Loss: 0.3212, Valid: 51.16%, Test: 30.25%
---
Hits@20
Run: 01, Epoch: 81, Loss: 0.3215, Valid: 32.80%, Test: 13.86%
Hits@50
Run: 01, Epoch: 81, Loss: 0.3215, Valid: 42.84%, Test: 21.03%
Hits@100
Run: 01, Epoch: 81, Loss: 0.3215, Valid: 50.33%, Test: 35.71%
---
Hits@20
Run: 01, Epoch: 82, Loss: 0.3212, Valid: 32.41%, Test: 12.49%
Hits@50
Run: 01, Epoch: 82, Loss: 0.3212, Valid: 42.20%, Test: 20.60%
Hits@100
Run: 01, Epoch: 82, Loss: 0.3212, Valid: 49.43%, Test: 31.59%
---
Hits@20
Run: 01, Epoch: 83, Loss: 0.3193, Valid: 32.97%, Test: 12.72%
Hits@50
Run: 01, Epoch: 83, Loss: 0.3193, Valid: 44.68%, Test: 18.46%


Hits@20
Run: 01, Epoch: 117, Loss: 0.2863, Valid: 42.11%, Test: 13.63%
Hits@50
Run: 01, Epoch: 117, Loss: 0.2863, Valid: 48.68%, Test: 21.49%
Hits@100
Run: 01, Epoch: 117, Loss: 0.2863, Valid: 54.33%, Test: 32.06%
---
Hits@20
Run: 01, Epoch: 118, Loss: 0.2874, Valid: 41.11%, Test: 9.02%
Hits@50
Run: 01, Epoch: 118, Loss: 0.2874, Valid: 49.47%, Test: 18.96%
Hits@100
Run: 01, Epoch: 118, Loss: 0.2874, Valid: 55.61%, Test: 36.47%
---
Hits@20
Run: 01, Epoch: 119, Loss: 0.2854, Valid: 40.97%, Test: 12.98%
Hits@50
Run: 01, Epoch: 119, Loss: 0.2854, Valid: 49.69%, Test: 22.05%
Hits@100
Run: 01, Epoch: 119, Loss: 0.2854, Valid: 56.57%, Test: 38.66%
---
Hits@20
Run: 01, Epoch: 120, Loss: 0.2866, Valid: 44.63%, Test: 11.64%
Hits@50
Run: 01, Epoch: 120, Loss: 0.2866, Valid: 52.06%, Test: 22.10%
Hits@100
Run: 01, Epoch: 120, Loss: 0.2866, Valid: 57.78%, Test: 37.01%
---
Hits@20
Run: 01, Epoch: 121, Loss: 0.2850, Valid: 41.62%, Test: 10.18%
Hits@50
Run: 01, Epoch: 121, Loss: 0.2850, Valid: 48.44%, 

Hits@20
Run: 01, Epoch: 155, Loss: 0.2644, Valid: 46.88%, Test: 7.74%
Hits@50
Run: 01, Epoch: 155, Loss: 0.2644, Valid: 56.20%, Test: 15.12%
Hits@100
Run: 01, Epoch: 155, Loss: 0.2644, Valid: 60.24%, Test: 29.20%
---
Hits@20
Run: 01, Epoch: 156, Loss: 0.2653, Valid: 45.41%, Test: 9.41%
Hits@50
Run: 01, Epoch: 156, Loss: 0.2653, Valid: 52.12%, Test: 15.25%
Hits@100
Run: 01, Epoch: 156, Loss: 0.2653, Valid: 58.30%, Test: 24.74%
---
Hits@20
Run: 01, Epoch: 157, Loss: 0.2628, Valid: 47.38%, Test: 7.45%
Hits@50
Run: 01, Epoch: 157, Loss: 0.2628, Valid: 55.35%, Test: 15.93%
Hits@100
Run: 01, Epoch: 157, Loss: 0.2628, Valid: 60.65%, Test: 28.24%
---
Hits@20
Run: 01, Epoch: 158, Loss: 0.2617, Valid: 46.03%, Test: 9.31%
Hits@50
Run: 01, Epoch: 158, Loss: 0.2617, Valid: 55.71%, Test: 18.17%
Hits@100
Run: 01, Epoch: 158, Loss: 0.2617, Valid: 59.96%, Test: 28.70%
---
Hits@20
Run: 01, Epoch: 159, Loss: 0.2630, Valid: 47.77%, Test: 8.07%
Hits@50
Run: 01, Epoch: 159, Loss: 0.2630, Valid: 54.56%, Test

Hits@20
Run: 01, Epoch: 193, Loss: 0.2498, Valid: 49.77%, Test: 5.25%
Hits@50
Run: 01, Epoch: 193, Loss: 0.2498, Valid: 59.06%, Test: 14.41%
Hits@100
Run: 01, Epoch: 193, Loss: 0.2498, Valid: 63.36%, Test: 39.50%
---
Hits@20
Run: 01, Epoch: 194, Loss: 0.2486, Valid: 54.82%, Test: 5.72%
Hits@50
Run: 01, Epoch: 194, Loss: 0.2486, Valid: 60.63%, Test: 20.77%
Hits@100
Run: 01, Epoch: 194, Loss: 0.2486, Valid: 63.83%, Test: 63.50%
---
Hits@20
Run: 01, Epoch: 195, Loss: 0.2478, Valid: 50.10%, Test: 5.15%
Hits@50
Run: 01, Epoch: 195, Loss: 0.2478, Valid: 59.43%, Test: 13.09%
Hits@100
Run: 01, Epoch: 195, Loss: 0.2478, Valid: 62.94%, Test: 31.95%
---
Hits@20
Run: 01, Epoch: 196, Loss: 0.2502, Valid: 51.87%, Test: 4.62%
Hits@50
Run: 01, Epoch: 196, Loss: 0.2502, Valid: 59.12%, Test: 18.01%
Hits@100
Run: 01, Epoch: 196, Loss: 0.2502, Valid: 63.45%, Test: 56.06%
---
Hits@20
Run: 01, Epoch: 197, Loss: 0.2483, Valid: 49.41%, Test: 4.19%
Hits@50
Run: 01, Epoch: 197, Loss: 0.2483, Valid: 58.95%, Test

Hits@20
Run: 01, Epoch: 231, Loss: 0.2405, Valid: 54.54%, Test: 7.68%
Hits@50
Run: 01, Epoch: 231, Loss: 0.2405, Valid: 61.28%, Test: 36.13%
Hits@100
Run: 01, Epoch: 231, Loss: 0.2405, Valid: 65.04%, Test: 75.51%
---
Hits@20
Run: 01, Epoch: 232, Loss: 0.2394, Valid: 52.79%, Test: 6.20%
Hits@50
Run: 01, Epoch: 232, Loss: 0.2394, Valid: 61.16%, Test: 29.74%
Hits@100
Run: 01, Epoch: 232, Loss: 0.2394, Valid: 65.31%, Test: 72.34%
---
Hits@20
Run: 01, Epoch: 233, Loss: 0.2383, Valid: 54.79%, Test: 5.05%
Hits@50
Run: 01, Epoch: 233, Loss: 0.2383, Valid: 61.54%, Test: 15.94%
Hits@100
Run: 01, Epoch: 233, Loss: 0.2383, Valid: 64.77%, Test: 66.12%
---
Hits@20
Run: 01, Epoch: 234, Loss: 0.2376, Valid: 56.60%, Test: 6.64%
Hits@50
Run: 01, Epoch: 234, Loss: 0.2376, Valid: 61.39%, Test: 36.41%
Hits@100
Run: 01, Epoch: 234, Loss: 0.2376, Valid: 65.07%, Test: 75.51%
---
Hits@20
Run: 01, Epoch: 235, Loss: 0.2374, Valid: 55.49%, Test: 5.96%
Hits@50
Run: 01, Epoch: 235, Loss: 0.2374, Valid: 61.12%, Test

Hits@20
Run: 01, Epoch: 269, Loss: 0.2308, Valid: 56.92%, Test: 8.82%
Hits@50
Run: 01, Epoch: 269, Loss: 0.2308, Valid: 62.50%, Test: 45.52%
Hits@100
Run: 01, Epoch: 269, Loss: 0.2308, Valid: 65.74%, Test: 78.02%
---
Hits@20
Run: 01, Epoch: 270, Loss: 0.2304, Valid: 57.89%, Test: 7.73%
Hits@50
Run: 01, Epoch: 270, Loss: 0.2304, Valid: 62.81%, Test: 33.28%
Hits@100
Run: 01, Epoch: 270, Loss: 0.2304, Valid: 66.09%, Test: 77.14%
---
Hits@20
Run: 01, Epoch: 271, Loss: 0.2332, Valid: 58.51%, Test: 5.15%
Hits@50
Run: 01, Epoch: 271, Loss: 0.2332, Valid: 62.76%, Test: 51.12%
Hits@100
Run: 01, Epoch: 271, Loss: 0.2332, Valid: 65.83%, Test: 76.68%
---
Hits@20
Run: 01, Epoch: 272, Loss: 0.2317, Valid: 57.26%, Test: 9.29%
Hits@50
Run: 01, Epoch: 272, Loss: 0.2317, Valid: 62.20%, Test: 46.28%
Hits@100
Run: 01, Epoch: 272, Loss: 0.2317, Valid: 66.41%, Test: 76.33%
---
Hits@20
Run: 01, Epoch: 273, Loss: 0.2303, Valid: 57.07%, Test: 9.97%
Hits@50
Run: 01, Epoch: 273, Loss: 0.2303, Valid: 62.85%, Test

Hits@20
Run: 01, Epoch: 307, Loss: 0.2242, Valid: 57.40%, Test: 8.01%
Hits@50
Run: 01, Epoch: 307, Loss: 0.2242, Valid: 64.35%, Test: 38.14%
Hits@100
Run: 01, Epoch: 307, Loss: 0.2242, Valid: 67.41%, Test: 77.93%
---
Hits@20
Run: 01, Epoch: 308, Loss: 0.2259, Valid: 54.49%, Test: 6.47%
Hits@50
Run: 01, Epoch: 308, Loss: 0.2259, Valid: 63.97%, Test: 34.05%
Hits@100
Run: 01, Epoch: 308, Loss: 0.2259, Valid: 67.17%, Test: 76.39%
---
Hits@20
Run: 01, Epoch: 309, Loss: 0.2242, Valid: 58.78%, Test: 14.93%
Hits@50
Run: 01, Epoch: 309, Loss: 0.2242, Valid: 64.91%, Test: 64.65%
Hits@100
Run: 01, Epoch: 309, Loss: 0.2242, Valid: 67.80%, Test: 80.97%
---
Hits@20
Run: 01, Epoch: 310, Loss: 0.2251, Valid: 54.06%, Test: 9.28%
Hits@50
Run: 01, Epoch: 310, Loss: 0.2251, Valid: 62.89%, Test: 41.24%
Hits@100
Run: 01, Epoch: 310, Loss: 0.2251, Valid: 66.29%, Test: 72.08%
---
Hits@20
Run: 01, Epoch: 311, Loss: 0.2249, Valid: 56.60%, Test: 8.09%
Hits@50
Run: 01, Epoch: 311, Loss: 0.2249, Valid: 63.53%, Tes

Hits@20
Run: 01, Epoch: 345, Loss: 0.2179, Valid: 60.81%, Test: 13.64%
Hits@50
Run: 01, Epoch: 345, Loss: 0.2179, Valid: 65.40%, Test: 59.85%
Hits@100
Run: 01, Epoch: 345, Loss: 0.2179, Valid: 68.02%, Test: 83.52%
---
Hits@20
Run: 01, Epoch: 346, Loss: 0.2203, Valid: 59.37%, Test: 16.69%
Hits@50
Run: 01, Epoch: 346, Loss: 0.2203, Valid: 64.71%, Test: 65.48%
Hits@100
Run: 01, Epoch: 346, Loss: 0.2203, Valid: 67.06%, Test: 82.83%
---
Hits@20
Run: 01, Epoch: 347, Loss: 0.2205, Valid: 59.44%, Test: 26.96%
Hits@50
Run: 01, Epoch: 347, Loss: 0.2205, Valid: 64.80%, Test: 67.36%
Hits@100
Run: 01, Epoch: 347, Loss: 0.2205, Valid: 68.10%, Test: 83.58%
---
Hits@20
Run: 01, Epoch: 348, Loss: 0.2207, Valid: 60.33%, Test: 5.72%
Hits@50
Run: 01, Epoch: 348, Loss: 0.2207, Valid: 65.51%, Test: 47.46%
Hits@100
Run: 01, Epoch: 348, Loss: 0.2207, Valid: 67.84%, Test: 83.62%
---
Hits@20
Run: 01, Epoch: 349, Loss: 0.2188, Valid: 60.76%, Test: 17.00%
Hits@50
Run: 01, Epoch: 349, Loss: 0.2188, Valid: 65.46%, 

Hits@20
Run: 01, Epoch: 383, Loss: 0.2148, Valid: 61.40%, Test: 21.85%
Hits@50
Run: 01, Epoch: 383, Loss: 0.2148, Valid: 66.11%, Test: 73.68%
Hits@100
Run: 01, Epoch: 383, Loss: 0.2148, Valid: 68.66%, Test: 86.29%
---
Hits@20
Run: 01, Epoch: 384, Loss: 0.2142, Valid: 59.61%, Test: 15.16%
Hits@50
Run: 01, Epoch: 384, Loss: 0.2142, Valid: 66.28%, Test: 58.73%
Hits@100
Run: 01, Epoch: 384, Loss: 0.2142, Valid: 68.71%, Test: 85.61%
---
Hits@20
Run: 01, Epoch: 385, Loss: 0.2144, Valid: 61.44%, Test: 12.66%
Hits@50
Run: 01, Epoch: 385, Loss: 0.2144, Valid: 66.08%, Test: 60.13%
Hits@100
Run: 01, Epoch: 385, Loss: 0.2144, Valid: 68.45%, Test: 84.82%
---
Hits@20
Run: 01, Epoch: 386, Loss: 0.2149, Valid: 62.55%, Test: 17.73%
Hits@50
Run: 01, Epoch: 386, Loss: 0.2149, Valid: 66.69%, Test: 73.67%
Hits@100
Run: 01, Epoch: 386, Loss: 0.2149, Valid: 68.78%, Test: 85.96%
---
Hits@20
Run: 01, Epoch: 387, Loss: 0.2155, Valid: 59.93%, Test: 20.35%
Hits@50
Run: 01, Epoch: 387, Loss: 0.2155, Valid: 65.55%,

In [26]:
with open('data/6/all_pos_valid_pred.pickle', 'wb') as handle:
    pickle.dump(all_pos_valid_pred, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open('data/6/all_neg_valid_pred.pickle', 'wb') as handle:
    pickle.dump(all_neg_valid_pred, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open('data/6/all_pos_test_pred.pickle', 'wb') as handle:
    pickle.dump(all_pos_test_pred, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open('data/6/all_neg_test_pred.pickle', 'wb') as handle:
    pickle.dump(all_neg_test_pred, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [24]:
pos_valid_pred, neg_valid_pred, pos_test_pred, neg_test_pred

tensor([0.5568, 0.6987, 0.7311,  ..., 0.7074, 0.6788, 0.7220])

In [28]:
with open('data/4/pos_valid_pred.pickle', 'wb') as handle:
    pickle.dump(pos_valid_pred, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open('data/4/neg_valid_pred.pickle', 'wb') as handle:
    pickle.dump(neg_valid_pred, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open('data/4/pos_test_pred.pickle', 'wb') as handle:
    pickle.dump(pos_test_pred, handle, protocol=pickle.HIGHEST_PROTOCOL)
with open('data/4/neg_test_pred.pickle', 'wb') as handle:
    pickle.dump(neg_test_pred, handle, protocol=pickle.HIGHEST_PROTOCOL)