In [None]:
# Colab-only: mount the user's Google Drive at /content/drive so files stored
# there can be read through the regular filesystem.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
%%capture
# %%capture hides the installer output of this cell.
# The four compiled companion packages are pulled from the wheel index built
# for torch 1.8.0 + CUDA 10.1; PyTorch Geometric itself is installed from the
# head of its git repository, i.e. without a version pin.
# NOTE(review): presumably the wheel index has to match the torch/CUDA build of
# the runtime -- confirm before re-running on a different Colab image.
! pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+cu101.html
! pip install torch-sparse -f https://pytorch-geometric.com/whl/torch-1.8.0+cu101.html
! pip install torch-cluster -f https://pytorch-geometric.com/whl/torch-1.8.0+cu101.html
! pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-1.8.0+cu101.html
! pip install git+https://github.com/rusty1s/pytorch_geometric.git

In [None]:
import os.path as osp
import torch
from torch.nn import Linear, LSTM
from sklearn.metrics import average_precision_score, roc_auc_score, accuracy_score
from torch_geometric.datasets import JODIEDataset
from torch_geometric.nn import TGNMemory, TransformerConv
from torch_geometric.nn.models.tgn import (LastNeighborLoader, IdentityMessage, LastAggregator)
from torch_geometric.data import InMemoryDataset, TemporalData, download_url
from torch_geometric.nn.inits import zeros
import pandas as pd
import numpy as np
import torch.nn.functional as F
from torch import nn
from torch.nn import Parameter
from attr import attrs, attrib, Factory
import warnings
from sklearn.preprocessing import OneHotEncoder
import os.path as osp
import torch
import pandas
from torch_geometric.data import InMemoryDataset, TemporalData, download_url

class genericDataset(InMemoryDataset):
    """Temporal interaction dataset backed by ``<root>/<name>/raw/<name>.csv``.

    The CSV is read positionally with its first line skipped:
    column 0 = source id, 1 = destination id, 2 = timestamp, 3 = label,
    columns 4 and onwards = message features.  Everything is stored as
    ``torch.long`` inside a single ``TemporalData`` object that is cached at
    ``<root>/<name>/processed/data.pt``.
    """

    def __init__(self, root, name, transform=None, pre_transform=None):
        # Set the (lower-cased) name first: the directory properties below
        # read it, and presumably the base constructor resolves them.
        self.name = name.lower()
        super().__init__(root, transform, pre_transform)

        loaded = torch.load(self.processed_paths[0])
        self.data, self.slices = loaded

    def _named_subdir(self, leaf):
        """Return ``<root>/<name>/<leaf>``."""
        return osp.join(self.root, self.name, leaf)

    @property
    def raw_dir(self):
        """Directory expected to contain the raw CSV."""
        return self._named_subdir('raw')

    @property
    def processed_dir(self):
        """Directory the processed tensor file is written to."""
        return self._named_subdir('processed')

    @property
    def raw_file_names(self):
        """Name of the raw CSV, derived from the dataset name."""
        return f'{self.name}.csv'

    @property
    def processed_file_names(self):
        """Name of the cached, processed file."""
        return 'data.pt'

    def download(self):
        """Do nothing: the raw CSV is expected to be in place already."""

    def process(self):
        """Convert the raw CSV into a ``TemporalData`` object and save it."""
        frame = pandas.read_csv(self.raw_paths[0], skiprows=1, header=None, index_col = None)
        print(frame.head())

        def as_long(selection):
            # pandas selection -> torch.long tensor
            return torch.from_numpy(selection.values).to(torch.long)

        src = as_long(frame.iloc[:, 0])
        dst = as_long(frame.iloc[:, 1])
        # Shift destination ids past the largest source id so that (for
        # non-negative ids) the two id ranges do not overlap.
        offset = int(src.max()) + 1
        dst += offset
        t = as_long(frame.iloc[:, 2])
        y = as_long(frame.iloc[:, 3])
        msg = as_long(frame.iloc[:, 4:])

        data = TemporalData(src=src, dst=dst, t=t, msg=msg, y=y)
        if self.pre_transform is not None:
            data = self.pre_transform(data)

        collated = self.collate([data])
        torch.save(collated, self.processed_paths[0])

    def __repr__(self):
        return f'{self.name.capitalize()}()'

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Root folder handed to genericDataset (located on the mounted Drive).
path = osp.join(
    '/content/drive/MyDrive/ceiling-graphs/TGN_ablation/Knowledge Graph',
    'Data',
    'JODIE',
)
dataset = genericDataset(path, name='wiki_large')

# The dataset's first item is moved to the selected device once, up front.
data = dataset[0]
data = data.to(device)

In [None]:
# Smallest and largest destination-node id in the full event stream.
# NOTE(review): nothing in the visible cells reads these two values; they look
# like leftovers from a negative-sampling template -- confirm before removing.
min_dst_idx = int(data.dst.min())
max_dst_idx = int(data.dst.max())

# 70 / 15 / 15 split of the events into train / validation / test.
train_data, val_data, test_data = data.train_val_test_split(val_ratio=0.15,
                                                            test_ratio=0.15)

# Neighbor store over all nodes, created with size=10 on the same device.
neighbor_loader = LastNeighborLoader(data.num_nodes, size=10, device=device)

class GraphAttentionEmbedding(torch.nn.Module):
    """Attention layer over edges described by a time encoding plus a message.

    A ``TransformerConv`` with two heads of ``out_channels // 2`` channels each
    is applied; every edge carries the encoded time difference concatenated
    with its raw message as edge features.

    Args:
        in_channels: width of the incoming node features.
        out_channels: total width requested; split evenly over the two heads.
        msg_dim: width of the raw edge message.
        time_enc: time-encoding module; must expose ``out_channels``.
    """

    def __init__(self, in_channels, out_channels, msg_dim, time_enc):
        super().__init__()
        self.time_enc = time_enc
        self.conv = TransformerConv(
            in_channels,
            out_channels // 2,
            heads=2,
            dropout=0.1,
            edge_dim=msg_dim + time_enc.out_channels,
        )

    def forward(self, x, last_update, edge_index, t, msg):
        """Return the convolved node features.

        ``last_update`` is indexed with the first row of ``edge_index`` and the
        edge time ``t`` is subtracted from it; that difference is what gets
        time-encoded.
        """
        elapsed = (last_update[edge_index[0]] - t).to(x.dtype)
        edge_features = torch.cat((self.time_enc(elapsed), msg), dim=-1)
        return self.conv(x, edge_index, edge_features)

class LinkPredictor(torch.nn.Module):
    """Classify a (source, destination) embedding pair into ``out_channels`` classes.

    Both embeddings go through their own linear layer, the results are summed
    and passed through a ReLU, and a final linear layer followed by a
    log-softmax yields per-class log-probabilities (suitable for ``NLLLoss``).

    Args:
        in_channels: width of each input embedding.
        out_channels: number of classes scored.
    """

    def __init__(self, in_channels, out_channels):
        super(LinkPredictor, self).__init__()
        self.lin_src = Linear(in_channels, in_channels)
        self.lin_dst = Linear(in_channels, in_channels)
        self.lin_final = Linear(in_channels, out_channels)
        # Normalise over the class axis, which is always the last one.  The
        # previous dim=1 was only correct for 2-D (batch, class) inputs: it
        # raised on a single un-batched embedding and picked the wrong axis
        # for inputs with extra leading dimensions.
        self.m = nn.LogSoftmax(dim=-1)

    def forward(self, z_src, z_dst):
        """Return log-probabilities with shape ``(..., out_channels)``.

        ``z_src`` and ``z_dst`` must have matching shapes ``(..., in_channels)``.
        """
        h = self.lin_src(z_src) + self.lin_dst(z_dst)
        h = h.relu()
        return self.m(self.lin_final(h))

# One common width for the node memory, the time encoding and the embeddings.
memory_dim = 100
time_dim = 100
embedding_dim = 100

# Width of a raw event message, taken from the loaded data.
raw_msg_dim = data.msg.size(-1)

memory = TGNMemory(
    data.num_nodes,
    raw_msg_dim,
    memory_dim,
    time_dim,
    message_module=IdentityMessage(raw_msg_dim, memory_dim, time_dim),
    aggregator_module=LastAggregator(),
)
memory = memory.to(device)

# The embedding layer is handed the time encoder owned by `memory`.
gnn = GraphAttentionEmbedding(
    in_channels=memory_dim,
    out_channels=embedding_dim,
    msg_dim=raw_msg_dim,
    time_enc=memory.time_enc,
)
gnn = gnn.to(device)

# ********************Change here
# Number of classes the predictor scores.
# NOTE(review): NLLLoss needs every label in data.y to be < this value --
# confirm 25 matches the dataset in use.
num_relation_classes = 25
link_pred = LinkPredictor(in_channels=embedding_dim, out_channels=num_relation_classes).to(device)

# Union of sets, so a parameter reachable from more than one module (gnn reuses
# memory.time_enc) is handed to the optimizer only once.
trainable_params = (
    set(memory.parameters())
    | set(gnn.parameters())
    | set(link_pred.parameters())
)
optimizer = torch.optim.Adam(trainable_params, lr=0.0001)

criterion = torch.nn.NLLLoss()

# Scratch lookup table: global node index -> position inside the current batch.
# Created uninitialised; entries are filled per batch before they are read.
assoc = torch.empty(data.num_nodes, dtype=torch.long, device=device)

In [None]:
def metrics(logits, y):
    """Compute ranking metrics for one batch of class scores.

    Args:
        logits: 2-D tensor of scores, one row per sample and one column per
            class; a higher score means a better rank.
        y: tensor with one true class index per row of ``logits``.

    Returns:
        A CPU tensor ``[MRR, MR, Hits@1, Hits@3]``.  MRR and MR are averaged
        over the rows whose label occurs among the column indices, whereas
        the Hits@k values are divided by ``y.size(0)``.
    """
    num_samples = y.size(0)

    # Column indices of every row, ordered from best to worst score.
    ranking = torch.sort(logits, dim=1, descending=True)[1]
    hit_mask = ranking == y.view(-1, 1)

    # Last column of nonzero() is the 0-based position of the true class.
    positions = torch.nonzero(hit_mask, as_tuple=False)[:, -1]
    ranks = (positions + 1).to(torch.float)

    mean_rank = ranks.mean().item()
    mean_reciprocal_rank = ranks.reciprocal().mean().item()

    def hits_at(k):
        # Share of samples whose true class is among the k best-scored ones.
        return hit_mask[:, :k].sum().item() / num_samples

    return torch.tensor([mean_reciprocal_rank, mean_rank, hits_at(1), hits_at(3)])


def train():
    """Run one training epoch over ``train_data``.

    Memory and the neighbor loader are reset first, so each epoch starts from
    an empty state.  The events are visited through
    ``train_data.seq_batches(batch_size=200)``; for every batch the embeddings
    and the loss are computed *before* that batch's own events are written to
    the memory and the neighbor loader.

    Returns:
        tuple(float, float): the ``criterion`` loss averaged per event, and
        the fraction of training events whose arg-max class equals the label.
        Both values are weighted by event count, so a shorter final batch
        cannot skew them (accuracy used to be a plain mean of per-batch
        accuracies).
    """
    memory.train()
    gnn.train()
    link_pred.train()

    memory.reset_state()  # Start with a fresh memory.
    neighbor_loader.reset_state()  # Start with an empty graph.

    total_loss = 0.0
    total_correct = 0
    for batch in train_data.seq_batches(batch_size=200):
        optimizer.zero_grad()

        src, dst, t, msg, labels = batch.src, batch.dst, batch.t, batch.msg, batch.y

        # All nodes touched by this batch, extended by the neighbor loader.
        n_id = torch.cat([src, dst]).unique()
        n_id, edge_index, e_id = neighbor_loader(n_id)

        # Map global node ids to their row in `z` for this batch.
        assoc[n_id] = torch.arange(n_id.size(0), device=device)

        z, last_update = memory(n_id)
        # e_id indexes the full event stream, hence `data` rather than `batch`.
        z = gnn(z, last_update, edge_index, data.t[e_id], data.msg[e_id])

        log_prob_rel = link_pred(z[assoc[src]], z[assoc[dst]])

        loss = criterion(log_prob_rel, labels)

        # Only now make the current events visible to later batches.
        memory.update_state(src, dst, t, msg)
        neighbor_loader.insert(src, dst)

        loss.backward()
        optimizer.step()
        memory.detach()

        total_loss += float(loss) * batch.num_events
        # Count hits instead of averaging per-batch accuracies.
        total_correct += int((log_prob_rel.argmax(dim=1) == labels).sum())

    num_events = train_data.num_events
    return total_loss / num_events, total_correct / num_events


@torch.no_grad()
def test(inference_data):
    """Evaluate the models on ``inference_data`` without gradient tracking.

    Neither the memory nor the neighbor loader is reset here, so evaluation
    continues from whatever state the preceding call left behind, and the
    evaluated events are themselves written into that state batch by batch.

    Args:
        inference_data: event collection offering ``seq_batches`` and
            ``num_events`` (e.g. the validation or test split).

    Returns:
        tuple(float, list[float]): the fraction of events whose arg-max class
        equals the label, and ``[MRR, MR, Hits@1, Hits@3]``.  Everything is
        weighted by event count (accuracy used to be a plain mean of
        per-batch accuracies, which a shorter final batch skewed).
    """
    memory.eval()
    gnn.eval()
    link_pred.eval()

    total_correct = 0
    result = torch.zeros(4, dtype=torch.float)
    for batch in inference_data.seq_batches(batch_size=200):

        src, dst, t, msg, label = batch.src, batch.dst, batch.t, batch.msg, batch.y

        n_id = torch.cat([src, dst]).unique()
        n_id, edge_index, e_id = neighbor_loader(n_id)
        # Map global node ids to their row in `z` for this batch.
        assoc[n_id] = torch.arange(n_id.size(0), device=device)

        z, last_update = memory(n_id)
        # e_id indexes the full event stream, hence `data` rather than `batch`.
        z = gnn(z, last_update, edge_index, data.t[e_id], data.msg[e_id])

        log_prob_rel = link_pred(z[assoc[src]], z[assoc[dst]])

        # metrics() returns per-batch means; re-weight by the batch size.
        result += metrics(log_prob_rel, label) * src.size(0)

        memory.update_state(src, dst, t, msg)
        neighbor_loader.insert(src, dst)

        total_correct += int((log_prob_rel.argmax(dim=1) == label).sum())

    num_events = inference_data.num_events
    result = result / num_events
    # Keep the former "nan for empty input" outcome instead of raising.
    accuracy = total_correct / num_events if num_events else float('nan')
    return accuracy, result.tolist()

In [None]:
# Shared layout of the two evaluation lines printed per epoch.
ranking_report = 'Epoch: {:02d}, MRR: {:.4f}, MR: {:.4f}, Hits@1: {:.4f}, Hits@3: {:.4f}'

# Epochs are numbered 1..99.
for epoch in range(1, 100):
    loss, train_acc = train()
    print(f'  Epoch: {epoch:02d}, Loss: {loss:.4f}, Acc: {train_acc:.4f}')

    # First evaluation line: validation split.  test() does not reset the
    # memory, so validation has to run before the test split.
    val_acc, res = test(val_data)
    print(ranking_report.format(epoch, *res))

    # Second evaluation line: test split.
    test_acc, res = test(test_data)
    print(ranking_report.format(epoch, *res))

    print()

  Epoch: 01, Loss: 2.6416, Acc: 0.2515
Epoch: 01, MRR: 0.7651, MR: 2.1428, Hits@1: 0.6426, Hits@3: 0.8830
Epoch: 01, MRR: 0.6828, MR: 2.8621, Hits@1: 0.5546, Hits@3: 0.7542

  Epoch: 02, Loss: 2.1021, Acc: 0.3298
Epoch: 02, MRR: 0.7732, MR: 2.0541, Hits@1: 0.6501, Hits@3: 0.8841
Epoch: 02, MRR: 0.6824, MR: 2.8317, Hits@1: 0.5572, Hits@3: 0.7554

  Epoch: 03, Loss: 1.9861, Acc: 0.3447
Epoch: 03, MRR: 0.7695, MR: 2.0461, Hits@1: 0.6423, Hits@3: 0.8852
Epoch: 03, MRR: 0.6840, MR: 2.7664, Hits@1: 0.5545, Hits@3: 0.7586

  Epoch: 04, Loss: 1.9228, Acc: 0.3439
Epoch: 04, MRR: 0.7696, MR: 2.0532, Hits@1: 0.6435, Hits@3: 0.8873
Epoch: 04, MRR: 0.6870, MR: 2.7432, Hits@1: 0.5555, Hits@3: 0.7836

  Epoch: 05, Loss: 1.8615, Acc: 0.3546
Epoch: 05, MRR: 0.7720, MR: 2.0386, Hits@1: 0.6469, Hits@3: 0.8890
Epoch: 05, MRR: 0.6945, MR: 2.6925, Hits@1: 0.5596, Hits@3: 0.8018

  Epoch: 06, Loss: 1.7969, Acc: 0.3873
Epoch: 06, MRR: 0.7833, MR: 1.9937, Hits@1: 0.6647, Hits@3: 0.8953
Epoch: 06, MRR: 0.7077, 

In [None]:
# Scratch example: log-softmax over a random 3 x 5 batch plus three targets,
# used by the following cell to walk through the rank computation.
# NOTE: `input` shadows the Python builtin of the same name in this kernel.
m = nn.LogSoftmax(dim=1)
input = torch.randn(3, 5).requires_grad_()
target = torch.tensor([1, 0, 4])
logs = m(input)

print(target, logs, sep='\n')
      
# tensor([1, 0, 4])
# tensor([[-1.6392, -2.3468, -1.0096, -1.4705, -2.1542], 4
#         [-2.8297, -1.4175, -2.2059, -1.4539, -1.0361], 3
#         [-3.0506, -1.9580, -0.7626, -1.2412, -2.8820]] 3

tensor([1, 0, 4])
tensor([[-0.6964, -2.1560, -4.0682, -1.8711, -1.5381],
        [-2.4510, -2.4992, -0.9120, -0.9522, -3.1229],
        [-1.3784, -2.1717, -2.7236, -0.8271, -2.0318]],
       grad_fn=<LogSoftmaxBackward>)


In [None]:
# Step-by-step version of the rank computation on the toy batch from the
# previous cell (`logs`, `target`).
_, perm = torch.sort(logs, dim=1, descending=True)  # class indices, best score first
mask = perm == target.view(-1, 1)                   # True where the target class sits
nnz = torch.nonzero(mask, as_tuple=False)           # rows of (sample, 0-based position)

print(perm)
print(mask)
print(nnz)
# Mean 1-based rank of the target class.
print((nnz[:, -1] + 1).to(torch.float).mean().item())

# mask = (y.view(-1, 1) == perm)

# nnz = mask.nonzero(as_tuple=False)
# mrr = (1 / (nnz[:, -1] + 1).to(torch.float)).mean().item()

tensor([[0, 4, 3, 1, 2],
        [2, 3, 0, 1, 4],
        [3, 0, 4, 1, 2]])
tensor([[False, False, False,  True, False],
        [False, False,  True, False, False],
        [False, False,  True, False, False]])
tensor([[0, 3],
        [1, 2],
        [2, 2]])
3.3333332538604736
