In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules
# (e.g. the local graph_sampler module) are reloaded before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
! pip install torch==2.1.0  torchvision==0.16.0 torchtext==0.16.0 torchaudio==2.1.0 --index-url https://download.pytorch.org/whl/cu121
#! pip install pyg_lib torch_scatter torch_sparse torch_cluster -f https://data.pyg.org/whl/torch-2.1.0+${CUDA}.html # torch_spline_conv
! pip install torch_geometric
! pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.1.0+cu121.html
#! pip install torch_sparse -f https://data.pyg.org/whl/torch-2.1.0+${CUDA}.html
#! pip install torch_scatter -f https://data.pyg.org/whl/torch-2.1.0+${CUDA}.html
#! pip install pyg_lib -f https://data.pyg.org/whl/torch-2.1.0+${CUDA}.html
! pip install sentence-transformers
! pip install torcheval
! pip install matplotlib
! pip install pandas
! pip install tensorboard
! pip install weaviate-client

! pip install -U pip setuptools wheel
! pip install -U spacy
! python -m spacy download en_core_web_sm

In [None]:
from graph_sampler import get_datasets, equal_edgeweight_hgt_sampler, get_minibatch_count, add_reverse_edge_original_attributes_and_label_inplace, get_hgt_linkloader, get_single_minibatch_count, sampler_for_init

# Load the train / validation / test datasets through the project's graph_sampler helper.
# NOTE(review): the exact meaning of get_edge_attr, filter_top_k and top_k is defined in
# graph_sampler.get_datasets, which is not visible here — confirm there before changing them.
train_data, val_data, test_data = get_datasets(get_edge_attr=False, filter_top_k=True, top_k=15)


In [None]:
# Display the training dataset's repr as this cell's output.
train_data

In [None]:
from torch_geometric.nn.kge import TransE

import math

import torch
import torch.nn.functional as F
from torch import Tensor

from torch_geometric.nn.kge import KGEModel

# adapted and taken from https://github.com/pyg-team/pytorch_geometric/blob/master/torch_geometric/nn/kge/transe.py

class TransE(KGEModel):
    r"""TransE scoring head that operates on externally computed node embeddings.

    Adapted from the PyG implementation of `"Translating Embeddings for
    Modeling Multi-Relational Data" <https://proceedings.neurips.cc/paper/
    2013/file/1cecc7a77928ca8133fa24680a88d2f9-Paper.pdf>`_.  A relation is
    modelled as a translation between head and tail,

    .. math::
        \mathbf{e}_h + \mathbf{e}_r \approx \mathbf{e}_t,

    which gives the score

    .. math::
        d(h, r, t) = - {\| \mathbf{e}_h + \mathbf{e}_r - \mathbf{e}_t \|}_p

    Unlike the upstream class, :meth:`forward`, :meth:`get_embedding` and
    :meth:`loss` receive head/tail *embeddings* rather than node indices; only
    the relation embedding table ``rel_emb`` is looked up here.

    Args:
        num_nodes (int): Forwarded to :class:`KGEModel`.
        num_relations (int): Forwarded to :class:`KGEModel`; the relation
            indices passed as ``rel_type`` index into ``rel_emb``.
        hidden_channels (int): The embedding size.
        margin (float, optional): Margin of the ranking loss.
            (default: :obj:`1.0`)
        p_norm (float, optional): Order of the norm used both to normalise
            embeddings and to measure the translation error.
            (default: :obj:`1.0`)
        sparse (bool, optional): Forwarded to :class:`KGEModel`.
            (default: :obj:`False`)
    """
    def __init__(
        self,
        num_nodes: int,
        num_relations: int,
        hidden_channels: int,
        margin: float = 1.0,
        p_norm: float = 1.0,
        sparse: bool = False,
    ):
        super().__init__(num_nodes, num_relations, hidden_channels, sparse)

        self.margin = margin
        self.p_norm = p_norm

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise both embedding tables uniformly, then normalise the relation rows."""
        limit = 6. / math.sqrt(self.hidden_channels)
        # Node table first, relation table second (same order as the random draws are consumed).
        for table in (self.node_emb, self.rel_emb):
            torch.nn.init.uniform_(table.weight, -limit, limit)
        rel_weight = self.rel_emb.weight.data
        # Normalise the relation embeddings in place along the feature dimension.
        F.normalize(rel_weight, p=self.p_norm, dim=-1, out=rel_weight)

    def forward(
        self,
        head_embeddings: Tensor,
        rel_type,
        tail_embeddings: Tensor,
    ) -> Tensor:
        """Return the negative p-norm of ``head + relation - tail``.

        Head and tail embeddings are normalised along the last dimension
        before the translation error is computed; the relation embedding is
        looked up from ``rel_emb`` using ``rel_type``.
        """
        # Only the relation embeddings are learned here; node embeddings come from the GNN.
        relation = self.rel_emb(rel_type)

        unit_head = F.normalize(head_embeddings, p=self.p_norm, dim=-1)
        unit_tail = F.normalize(tail_embeddings, p=self.p_norm, dim=-1)

        translation_error = unit_head + relation - unit_tail
        return -translation_error.norm(p=self.p_norm, dim=-1)

    def get_embedding(self,
                      embedding,
                      rel_type,
                      have_head_or_tail
                      ):
        """Translate a normalised embedding by the relation vector.

        With ``have_head_or_tail == 'head'`` the relation is added; for any
        other value it is subtracted.
        """
        unit_embedding = F.normalize(embedding, p=self.p_norm, dim=-1)
        relation = self.rel_emb(rel_type)
        if have_head_or_tail != 'head':
            return unit_embedding - relation
        return unit_embedding + relation

    def loss(
        self,
        head_embeddings: Tensor,
        rel_type: Tensor,
        tail_embeddings: Tensor,
        labels: Tensor,  # 1 marks a positive pair, 0 a negative pair
    ) -> Tensor:
        """Margin ranking loss between the positive-pair and negative-pair scores.

        Rows with ``labels == 1`` are scored as positives and rows with
        ``labels == 0`` as negatives; the i-th positive score is ranked against
        the i-th negative score.
        """
        is_positive = labels == 1
        is_negative = labels == 0

        positive_scores = self(head_embeddings[is_positive], rel_type,
                               tail_embeddings[is_positive])
        negative_scores = self(head_embeddings[is_negative], rel_type,
                               tail_embeddings[is_negative])

        # A target of +1 asks for the first input (positives) to rank higher.
        wants_positive_higher = torch.ones_like(positive_scores)
        return F.margin_ranking_loss(
            positive_scores,
            negative_scores,
            target=wants_positive_higher,
            margin=self.margin,
        )

In [None]:
from torch_geometric.nn import HGTConv, Linear
import torch 

class HGT(torch.nn.Module):
    """Heterogeneous Graph Transformer encoder returning per-node-type embeddings.

    Every node type gets its own lazily sized input projection to
    ``hidden_channels`` followed by a ReLU; the result is passed through
    ``num_layers`` stacked :class:`HGTConv` layers.

    Note that ``self.lin`` (``hidden_channels -> out_channels``) is created
    but is not applied in :meth:`forward`, so the returned embeddings have
    ``hidden_channels`` features.

    Args:
        hidden_channels: Feature size of the projections and the conv layers.
        out_channels: Output size of the (unapplied) final linear layer.
        num_heads: Number of attention heads handed to each ``HGTConv``.
        num_layers: Number of ``HGTConv`` layers.
        node_types: Node type names; one input projection is built per entry.
        data_metadata: Graph metadata handed to each ``HGTConv``.
    """

    def __init__(self, hidden_channels, out_channels, num_heads, num_layers, node_types, data_metadata):
        super().__init__()

        # One input projection per node type; -1 lets the input size be inferred lazily.
        self.lin_dict = torch.nn.ModuleDict(
            {name: Linear(-1, hidden_channels) for name in node_types}
        )

        self.convs = torch.nn.ModuleList([
            HGTConv(hidden_channels, hidden_channels, data_metadata,
                    num_heads, group='sum')
            for _ in range(num_layers)
        ])

        self.lin = Linear(hidden_channels, out_channels)

    def forward(self, x_dict, edge_index_dict):
        """Project each node type's features, run the conv stack, return the embedding dict."""
        hidden = {}
        for name, features in x_dict.items():
            hidden[name] = self.lin_dict[name](features).relu_()

        for layer in self.convs:
            hidden = layer(hidden, edge_index_dict)

        return hidden
    
# if __name__ == '__main__':
    # model = HGT(hidden_channels=64, out_channels=4, num_heads=2, num_layers=1, node_types=data.node_types, data_metadata=data.metadata())

In [None]:
import torch
# from models.TransE import TransE
# from models.DistMult import DistMult
# from models.HGT import HGT
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class Model(torch.nn.Module):
    """GNN encoder combined with a knowledge-graph-embedding head for link scoring.

    The GNN produces a dict of node embeddings per node type; the head scores
    (head node, relation, tail node) triples built from those embeddings.

    Args:
        gnn: Module called as ``gnn(x_dict, edge_index_dict)`` that returns a
            dict mapping node type to an embedding tensor.
        head: Name of the scoring head. Only ``'TransE'`` is supported; any
            other value raises ``NotImplementedError``.
        node_types: Node type names.
        edge_types: Edge type triples ``(src_type, relation, dst_type)``.
            Triples whose relation name starts with ``'rev_'`` get no relation
            embedding. The list passed in is NOT modified.
        ggn_output_dim: Embedding size produced by the GNN (and used by the head).
        pnorm: Norm order forwarded to the head.
    """

    def __init__(self, gnn : torch.nn.Module, head : str, node_types, edge_types, ggn_output_dim, pnorm=1):
        super().__init__()
        # Not used by forward(); left in place so the module's parameter set is unchanged.
        self.node_type_embedding = torch.nn.Embedding(len(node_types), ggn_output_dim)

        # Build a filtered copy instead of calling list.remove() while iterating:
        # removing during iteration skips the element that follows each removal
        # (so adjacent 'rev_*' entries survived), and it also mutated the caller's
        # list, which may be shared with other objects built from the same metadata.
        self.edge_types = [
            edge_type for edge_type in edge_types
            if not edge_type[1].startswith('rev_')
        ]

        # Map each remaining edge type to the row of its relation embedding.
        self.edgeindex_lookup = {
            edge_type: torch.tensor(i) for i, edge_type in enumerate(self.edge_types)
        }

        if head == 'TransE':
            # KGE head with its own loss function; sized to the GNN output dimension.
            self.head = TransE(len(node_types), len(self.edge_types), ggn_output_dim,
                               p_norm=pnorm, margin=0.5)
        else:
            raise NotImplementedError(f'unsupported head: {head!r}')

        self.gnn = gnn

    def forward(self, hetero_data1, target_edge_type, edge_label_index, edge_label, hetero_data2=None, get_head_fn='loss'):
        """Encode the sampled subgraph(s) and apply the head to the labelled edges.

        Args:
            hetero_data1: Object exposing ``x_dict`` and ``edge_index_dict``. Supplies
                the head nodes, and also the tail nodes when ``hetero_data2`` is None.
            target_edge_type: ``(src_type, relation, dst_type)`` triple; must be a key
                of ``edgeindex_lookup``.
            edge_label_index: Index tensor whose row 0 selects head nodes and row 1
                selects tail nodes from the embedding tensors.
            edge_label: Labels forwarded to ``head.loss``.
            hetero_data2: Optional second subgraph supplying the tail nodes. Required
                when source and destination node types differ, forbidden otherwise.
            get_head_fn: ``'loss'`` to return the head's loss, ``'forward'`` to
                return its scores.

        Raises:
            ValueError: If ``get_head_fn`` is neither ``'loss'`` nor ``'forward'``.
            AssertionError: If the number of data objects does not fit the edge type.
        """
        if get_head_fn not in ('loss', 'forward'):
            # Previously an unknown mode fell through and returned None.
            raise ValueError(f"get_head_fn must be 'loss' or 'forward', got {get_head_fn!r}")

        src_type, dst_type = target_edge_type[0], target_edge_type[2]

        if hetero_data2 is not None:
            assert src_type != dst_type, 'when passing two data objects, the edge type has to contain two different node types'
            head_embeddings = self.gnn(hetero_data1.x_dict, hetero_data1.edge_index_dict)[src_type][edge_label_index[0,:]]
            tail_embeddings = self.gnn(hetero_data2.x_dict, hetero_data2.edge_index_dict)[dst_type][edge_label_index[1,:]]
        else:
            assert src_type == dst_type, 'when passing one data object, the edge type has to contain the same node types'
            embeddings = self.gnn(hetero_data1.x_dict, hetero_data1.edge_index_dict)
            head_embeddings = embeddings[src_type][edge_label_index[0,:]]
            tail_embeddings = embeddings[dst_type][edge_label_index[1,:]]

        # Follow the embeddings' device instead of relying on a module-level `device`.
        edgeindex = self.edgeindex_lookup[target_edge_type].to(head_embeddings.device)

        if get_head_fn == 'loss':
            return self.head.loss(head_embeddings, edgeindex, tail_embeddings, edge_label)
        return self.head(head_embeddings, edgeindex, tail_embeddings)


# Graph metadata handed to the GNN and the model; index 1 holds the edge types.
metadata = train_data.metadata()
# Register one ('<type>', 'self_loop', '<type>') edge type per node type.
for node_type in train_data.node_types:
    self_loop_edge_type = (node_type, 'self_loop', node_type)
    metadata[1].append(self_loop_edge_type)

# Hyperparameters (also embedded in the run-folder name by the training cell).
head = 'TransE'
pnorm = 2
num_layers = 3
num_heads = 8
hidden_channels = 256
out_channels = 256

# The encoder is built with out_channels for both of its sizes.
gnn = HGT(
    hidden_channels=out_channels,
    out_channels=out_channels,
    num_heads=num_heads,
    num_layers=num_layers,
    node_types=train_data.node_types,
    data_metadata=metadata,
)

model = Model(
    gnn,
    head=head,
    node_types=metadata[0],
    edge_types=metadata[1],
    ggn_output_dim=out_channels,
    pnorm=pnorm,
)
#torch_geometric.compile(model, dynamic=True)
model.to(device)


In [None]:
from tqdm.auto import tqdm
from datetime import datetime
import os 
from pathlib import Path

from torch.utils.tensorboard import SummaryWriter

# ---- configuration ----------------------------------------------------------
batch_size = 32
learning_rate = 2e-4
ACCUMULATION_STEPS = 3    # minibatches whose gradients are summed before one optimizer step
VAL_EVERY = 300           # run a short validation pass every this many minibatches
VAL_BATCHES = 3           # validation minibatches drawn per validation pass
CHECKPOINT_EVERY = 10000  # write a checkpoint every this many minibatches
ROOT_FOLDER = 'cache/'

num_node_types = len(train_data.node_types)
print('num_node_types', num_node_types)
one_hop_neighbors = (20 * batch_size)//num_node_types # per relationship type
two_hop_neighbors = (20 * 8 * batch_size)//num_node_types # per relationship type
three_hop_neighbors = (20 * 8 * 3 * batch_size)//num_node_types # per relationship type
num_neighbors = [one_hop_neighbors, two_hop_neighbors, three_hop_neighbors] # three_hop_neighbors
# num_neighbors [36, 363, 1454]

print('num_neighbors', num_neighbors)
print('avg_num_neighbors', [num_neighbors[0]/batch_size,num_neighbors[1]/batch_size,  num_neighbors[2]/batch_size if len(num_neighbors)==3 else 0 ])

train_sampler = equal_edgeweight_hgt_sampler(train_data, batch_size, True, 'triplet', 1, num_neighbors, num_workers=0, prefetch_factor=None, pin_memory=True)
val_sampler = equal_edgeweight_hgt_sampler(val_data, batch_size, True, 'triplet', 1, num_neighbors, num_workers=0, prefetch_factor=None, pin_memory=True)

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) #2e-15

# ---- logging / output locations ----------------------------------------------
neighbors = '_'.join([str(n) for n in num_neighbors])
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
summary_folder = Path(ROOT_FOLDER+f'runs/all_hgt_{timestamp}_mg05_norm{pnorm}_lr{learning_rate}_bs{batch_size}_neigh_{neighbors}_h_{head}_hid_{hidden_channels}_out_{out_channels}_numh_{num_heads}_numl_{num_layers}'.replace('.', '_'))
print('summary_folder', summary_folder)
writer = SummaryWriter(summary_folder)
print('writer',summary_folder)
# Checkpoints are written next to the TensorBoard event files of this run.
run_folder = Path(summary_folder)


def compute_batch_loss(same_nodetype, target_edge_type, minibatch, log_malformed=False):
    """Move one sampled minibatch to `device` and return the model's loss for it.

    The sampler yields a 5-tuple when both ends of the target edge share a node
    type and a 7-tuple (two subgraphs) otherwise. A two-subgraph minibatch that
    does not unpack into 7 elements is skipped: None is returned, and the error
    is printed when `log_malformed` is true.
    """
    if same_nodetype:
        graph, edge_label_index, edge_label, _input_edge_ids, _global_node_ids = minibatch
        return model(graph.to(device), target_edge_type, edge_label_index.to(device), edge_label.to(device))

    try:
        graph_src, graph_dst, edge_label_index, edge_label, _input_edge_id, _global_src_ids, _global_dst_ids = minibatch
    except ValueError as err:
        # for skill qual edges sometimes for some reason only 5 instead of 7 elements returned
        if log_malformed:
            print('value error', err)
        return None
    return model(graph_src.to(device), target_edge_type, edge_label_index.to(device), edge_label.to(device), graph_dst.to(device))


def run_validation():
    """Average the loss over up to VAL_BATCHES validation minibatches.

    Returns the mean over the minibatches that were actually evaluated, or None
    if every drawn minibatch was skipped. The model is always switched back to
    training mode, even when an exception is raised during validation.
    """
    global val_sampler
    loss_sum, evaluated = 0.0, 0
    model.eval()
    try:
        with torch.no_grad():
            for _ in range(VAL_BATCHES):
                try:
                    val_same_nodetype, val_edge_type, val_minibatch = next(val_sampler)
                except StopIteration:
                    # NOTE(review): relies on iter(val_sampler) restarting the sampler.
                    # For a plain generator iter() returns the same exhausted object and
                    # the next() below raises StopIteration again — confirm in graph_sampler.
                    val_sampler = iter(val_sampler)
                    val_same_nodetype, val_edge_type, val_minibatch = next(val_sampler)

                batch_loss = compute_batch_loss(val_same_nodetype, val_edge_type, val_minibatch)
                if batch_loss is None:
                    continue
                loss_sum += batch_loss.item()
                evaluated += 1
    finally:
        model.train()
    # Divide by the number of evaluated batches, not by VAL_BATCHES, so skipped
    # minibatches do not drag the reported mean towards zero.
    return loss_sum / evaluated if evaluated else None


model.train()
start_epoch = 1
total_minibatches = get_minibatch_count(train_data, batch_size)
try:
    for epoch in range(start_epoch, start_epoch+1):
        # Gradients are accumulated over ACCUMULATION_STEPS successful backward passes.
        # Counting successes (rather than testing i % 3) means a skipped minibatch can
        # neither suppress an optimizer step nor cause accumulated gradients to be discarded.
        optimizer.zero_grad()
        accumulated = 0
        for i, (same_nodetype, target_edge_type, minibatch) in tqdm(enumerate(train_sampler), total=total_minibatches):
            is_last = i == total_minibatches-1

            try:
                print(target_edge_type)
                loss = compute_batch_loss(same_nodetype, target_edge_type, minibatch, log_malformed=True)
                if loss is None:
                    continue

                loss.backward()
                accumulated += 1
                if accumulated == ACCUMULATION_STEPS:
                    optimizer.step()
                    optimizer.zero_grad()
                    accumulated = 0

                total_samples_seen = i * batch_size
                writer.add_scalar('Loss/train', loss.item(), total_samples_seen)

                if is_last:
                    print(f'{i} loss: {loss.item():.4f}')
                    writer.add_scalar('Epoch Loss/train', loss.item(), total_samples_seen)

                # print loss and minibatch in the same line
                print(f'{i} loss: {loss.item():.4f}', end='\r')

                if i % VAL_EVERY == 0 or is_last:
                    val_loss = run_validation()
                    if val_loss is not None:
                        if i == 0:
                            writer.add_scalar('Epoch Loss/val', val_loss, total_samples_seen)
                            writer.add_scalar('Loss/val', val_loss, total_samples_seen)
                        elif is_last:
                            writer.add_scalar('Epoch Loss/val', val_loss, total_samples_seen)
                        else:
                            writer.add_scalar('Loss/val', val_loss, total_samples_seen)

                        print(f'val_loss: {val_loss:.4f}', end='\r')

                writer.flush()

                if i % CHECKPOINT_EVERY == 0 or is_last:
                    # The 'models' directory is still created as before, although
                    # nothing in this cell writes into it.
                    os.makedirs('models', exist_ok=True)
                    os.makedirs(run_folder, exist_ok=True)

                    print('saving model to', run_folder)
                    # save model and optimizer
                    is_epoch = f'Ep{epoch}_' if is_last else ''
                    torch.save({
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        }, run_folder/f'{is_epoch}model_samplesseen{total_samples_seen}.pt')

            except IndexError as err:
                # Best effort: a minibatch that triggers an IndexError is skipped,
                # but the error is reported instead of being dropped silently.
                print('indexerror', err)
finally:
    # Close the event file even if training is interrupted or fails.
    writer.close()