# RGCN

In [3]:
#!/usr/bin/env python
# coding: utf-8

# # RGCN

# In[1]:


import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn.conv import MessagePassing
from transformers import BertTokenizer, BertModel



# from utils import uniform
class RGCN(torch.nn.Module):
    """R-GCN encoder with a DistMult scoring function.

    Node features come either from a trainable ``nn.Embedding``
    (``emb_type == "scratch"``) or from mean-pooled BERT encodings of the
    entity texts read from ``file_path`` (any other ``emb_type``).

    Args:
        num_entities: Number of rows of the entity embedding table.
        num_relations: Number of relation types (the convolutions are built
            for ``2 * num_relations`` edge types).
        id2entity: Mapping from entity id to the key used in the text table.
        id2relation: Mapping from relation id to the key/text of the relation.
        num_bases: Number of bases handed to ``RGCNConv``.
        dropout: Dropout probability applied after the first convolution.
        device: Device on which BERT and the relation embedding are created.
        file_path: Directory containing ``entity2text.txt`` /
            ``entity2textlong.txt`` (only read in BERT mode).
        emb_type: ``"scratch"`` for learned embeddings, anything else for BERT.
        hid_dim: Feature size of embeddings and convolutions.
        long_text: In BERT mode, read ``entity2textlong.txt`` instead of
            ``entity2text.txt``.
        freeze_bert: ``True`` runs BERT under ``torch.no_grad()``, ``False``
            keeps gradients. ``None`` (default) keeps the historical behaviour
            of consulting a module-level ``args.test`` if one exists and
            otherwise freezes BERT.
    """

    def __init__(self, num_entities, num_relations, id2entity, id2relation, num_bases, dropout, device, file_path, emb_type="scratch", hid_dim=768, long_text=False, freeze_bert=None):
        super(RGCN, self).__init__()

        self.emb_type = emb_type
        self.id2entity = id2entity
        self.id2relation = id2relation
        self.device = device
        self.hid_dim = hid_dim
        self.freeze_bert = freeze_bert

        if emb_type == "scratch":
            self.entity_embedding = nn.Embedding(num_entities, self.hid_dim)
        else:
            self.tokenizer = BertTokenizer.from_pretrained('prajjwal1/bert-tiny')
            self.bert = BertModel.from_pretrained('prajjwal1/bert-tiny', output_hidden_states=True).to(device)

            if long_text:
                text_file, text_column = "/entity2textlong.txt", "desc"
            else:
                text_file, text_column = "/entity2text.txt", "name"
            self.entity2text_dict = pd.read_csv(
                file_path + text_file, sep="\t", header=None,
                names=["id", text_column], dtype=str,
            ).set_index("id").to_dict()[text_column]

            print("Successfully Loaded BERT")

        # Allocate directly on ``device``. Calling ``.to(device)`` on an
        # ``nn.Parameter`` returns a plain tensor whenever a copy is made, and
        # that tensor would not be registered with the module (so it would be
        # missing from ``parameters()`` and ``state_dict()``).
        self.relation_embedding = nn.Parameter(torch.empty(num_relations, self.hid_dim, device=device))
        nn.init.xavier_uniform_(self.relation_embedding, gain=nn.init.calculate_gain('relu'))

        self.conv1 = RGCNConv(
            self.hid_dim, self.hid_dim, num_relations * 2, num_bases=num_bases)
        # NOTE: conv2 is constructed (and therefore part of the state dict)
        # but forward() currently applies conv1 only.
        self.conv2 = RGCNConv(
            self.hid_dim, self.hid_dim, num_relations * 2, num_bases=num_bases)

        self.dropout_ratio = dropout

    def forward(self, entity, edge_index, edge_type, edge_norm):
        """Return node representations after one R-GCN layer, ReLU and dropout."""
        if self.emb_type == "scratch":
            x = self.entity_embedding(entity)
        else:
            texts = [self.entity2text_dict[self.id2entity[idx.item()]] for idx in entity]
            print("Obtained texts")
            name_last_layers, name_encodings = self.get_model_outputs(texts)
            print("Obtained encodings")
            x = self.output2embedding(name_last_layers, name_encodings, only_attention_mask=True)
            print("Successfully obtained embedding from BERT")

        x = F.relu(self.conv1(x, edge_index, edge_type, edge_norm))
        x = F.dropout(x, p=self.dropout_ratio, training=self.training)
        return x

    def distmult(self, embedding, triplets):
        """Score ``(subject, relation, object)`` rows as ``sum(s * r * o)``."""
        s = embedding[triplets[:, 0]]
        r = self._relation_vectors(triplets[:, 1])
        o = embedding[triplets[:, 2]]
        return torch.sum(s * r * o, dim=1)

    def _relation_vectors(self, relations):
        """Return one relation vector per entry of ``relations``."""
        if self.emb_type == "scratch":
            # No tokenizer / BERT / text table exists in this mode.
            return self.relation_embedding[relations]

        texts = [self._relation_text(relation.item()) for relation in relations]
        name_last_layers, name_encodings = self.get_model_outputs(texts)
        return self.output2embedding(name_last_layers, name_encodings, only_attention_mask=True)

    def _relation_text(self, relation_id):
        """Return the text to encode for ``relation_id``."""
        name = self.id2relation[relation_id]
        # The text table is keyed by entity ids; when the relation key has no
        # entry there, encode the relation identifier itself instead of
        # raising KeyError.
        return self.entity2text_dict.get(name, name)

    def score_loss(self, embedding, triplets, target):
        """Binary cross-entropy (with logits) between DistMult scores and ``target``."""
        score = self.distmult(embedding, triplets)

        return F.binary_cross_entropy_with_logits(score, target)

    def reg_loss(self, embedding):
        """Mean squared magnitude of ``embedding`` plus that of the relation embedding."""
        return torch.mean(embedding.pow(2)) + torch.mean(self.relation_embedding.pow(2))

    def _bert_is_frozen(self):
        """Decide whether BERT should run without gradient tracking."""
        if self.freeze_bert is not None:
            return bool(self.freeze_bert)
        # Backward compatibility: this used to read the script-level ``args``
        # directly, which raised NameError when the class was imported.
        cli_args = globals().get("args")
        return bool(getattr(cli_args, "test", True))

    def get_model_outputs(self, texts):
        """Tokenize ``texts`` and run them through BERT.

        Returns:
            ``(bert_outputs, encodings)`` where ``encodings`` holds the
            tokenizer tensors moved to ``self.device``.
        """
        # WARN: Long sentences are truncated by the tokenizer (truncation=True).
        # https://stackoverflow.com/questions/58636587/how-to-use-bert-for-long-text-classification/63413589#63413589
        encodings = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

        for item in encodings:
            encodings[item] = encodings[item].to(self.device)

        self.bert = self.bert.to(self.device)

        if self._bert_is_frozen():
            with torch.no_grad():
                last_layers = self.bert(**encodings)
        else:
            # Respect whatever grad mode the caller has set up.
            last_layers = self.bert(**encodings)

        return last_layers, encodings

    def output2embedding(self, last_layers, encodings, only_attention_mask=False):
        """Average token embeddings into one vector per sentence.

        BERT inputs are padded to a common length, so the output contains
        embeddings for padding tokens
        (https://huggingface.co/docs/transformers/pad_truncation).
        With ``only_attention_mask=True`` those positions are excluded from
        the average; otherwise a plain mean over all positions is returned.
        """
        token_embeddings = last_layers[0]
        if not only_attention_mask:
            return token_embeddings.mean(1)

        mask = encodings["attention_mask"]
        masked_sum = (token_embeddings * mask[:, :, None].expand(token_embeddings.shape)).sum(1)
        return masked_sum.div(mask.sum(1, keepdim=True))

class RGCNConv(MessagePassing):
    r"""The relational graph convolutional operator from the `"Modeling
    Relational Data with Graph Convolutional Networks"
    <https://arxiv.org/abs/1703.06103>`_ paper

    .. math::
        \mathbf{x}^{\prime}_i = \mathbf{\Theta}_{\textrm{root}} \cdot
        \mathbf{x}_i + \sum_{r \in \mathcal{R}} \sum_{j \in \mathcal{N}_r(i)}
        \frac{1}{|\mathcal{N}_r(i)|} \mathbf{\Theta}_r \cdot \mathbf{x}_j,

    where :math:`\mathcal{R}` denotes the set of relations, *i.e.* edge types.
    Edge type needs to be a one-dimensional :obj:`torch.long` tensor which
    stores a relation identifier
    :math:`\in \{ 0, \ldots, |\mathcal{R}| - 1\}` for each edge.

    Args:
        in_channels (int): Size of each input sample.
        out_channels (int): Size of each output sample.
        num_relations (int): Number of relations.
        num_bases (int): Number of bases used for basis-decomposition.
        root_weight (bool, optional): If set to :obj:`False`, the layer will
            not add transformed root node features to the output.
            (default: :obj:`True`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """

    def __init__(self, in_channels, out_channels, num_relations, num_bases,
                 root_weight=True, bias=True, **kwargs):
        super(RGCNConv, self).__init__(aggr='mean', **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.num_relations = num_relations
        self.num_bases = num_bases

        # Per-relation weights are ``att @ basis`` (see ``message``).
        # All tensors are uninitialised here and filled by reset_parameters().
        self.basis = nn.Parameter(torch.empty(num_bases, in_channels, out_channels))
        self.att = nn.Parameter(torch.empty(num_relations, num_bases))

        if root_weight:
            self.root = nn.Parameter(torch.empty(in_channels, out_channels))
        else:
            self.register_parameter('root', None)

        if bias:
            self.bias = nn.Parameter(torch.empty(out_channels))
        else:
            self.register_parameter('bias', None)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every parameter uniformly with a shared bound.

        The bound is derived from ``num_bases * in_channels``; ``uniform``
        skips parameters that are ``None``.
        """
        size = self.num_bases * self.in_channels
        for param in (self.basis, self.att, self.root, self.bias):
            uniform(size, param)

    def forward(self, x, edge_index, edge_type, edge_norm=None, size=None):
        """Run message passing over ``edge_index``.

        ``x``, ``edge_type`` and ``edge_norm`` are forwarded to
        :meth:`message` / :meth:`update` through ``propagate``.
        """
        return self.propagate(edge_index, size=size, x=x, edge_type=edge_type,
                              edge_norm=edge_norm)

    def message(self, x_j, edge_index_j, edge_type, edge_norm):
        """Compute one message per edge using that edge's relation weight."""
        # Combine the bases into one flattened weight matrix per relation.
        w = torch.matmul(self.att, self.basis.view(self.num_bases, -1))

        # If no node features are given, we implement a simple embedding
        # lookup based on the node index and its edge type.
        if x_j is None:
            w = w.view(-1, self.out_channels)
            index = edge_type * self.in_channels + edge_index_j
            out = torch.index_select(w, 0, index)
        else:
            w = w.view(self.num_relations, self.in_channels, self.out_channels)
            w = torch.index_select(w, 0, edge_type)
            out = torch.bmm(x_j.unsqueeze(1), w).squeeze(-2)

        return out if edge_norm is None else out * edge_norm.view(-1, 1)

    def update(self, aggr_out, x):
        """Add the optional root transform and bias to the aggregated messages."""
        # Start from the aggregate so the result is defined even when the
        # layer was built with root_weight=False (previously ``out`` was left
        # unbound in that case).
        out = aggr_out
        if self.root is not None:
            if x is None:
                out = out + self.root
            else:
                out = out + torch.matmul(x, self.root)

        if self.bias is not None:
            out = out + self.bias
        return out

    def __repr__(self):
        return '{}({}, {}, num_relations={})'.format(
            self.__class__.__name__, self.in_channels, self.out_channels,
            self.num_relations)


# # Utils

# In[2]:


import os
import math
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
from torch_scatter import scatter_add
from torch_geometric.data import Data

def uniform(size, tensor):
    """Fill ``tensor`` in place with values drawn from U(-1/sqrt(size), 1/sqrt(size)).

    ``tensor`` may be ``None``, in which case nothing is filled.
    """
    limit = 1.0 / math.sqrt(size)
    if tensor is None:
        return
    tensor.data.uniform_(-limit, limit)

def load_data(file_path, max_relations):
    '''
        Load the id dictionaries and the train/valid/test triplets of a dataset.

        argument:
            file_path: dataset directory, e.g. ./data/FB15k-237. It must
                contain entities.dict, relations.dict, train.txt, valid.txt
                and test.txt.
            max_relations: keep only the first `max_relations` lines of
                relations.dict; a negative value keeps all of them.

        return:
            entity2id, id2entity, relation2id, id2relation,
            train_triplets, valid_triplets, test_triplets
    '''
    print("load data from {}".format(file_path))

    entity2id = {}
    id2entity = {}
    with open(os.path.join(file_path, 'entities.dict')) as f:
        for line in f:
            eid, entity = line.strip().split('\t')
            entity2id[entity] = int(eid)
            id2entity[int(eid)] = entity

    relation2id = {}
    id2relation = {}
    with open(os.path.join(file_path, 'relations.dict')) as f:
        for i, line in enumerate(f):
            # For relation sampling
            if max_relations > -1 and i >= max_relations:
                print(f"Discard relations after {i-1}")
                break
            rid, relation = line.strip().split('\t')
            relation2id[relation] = int(rid)
            # Map the id to the relation name, not to the `entity` variable
            # left over from the loop above.
            id2relation[int(rid)] = relation

    # read_triplets drops triplets that mention an unknown entity/relation.
    train_triplets = read_triplets(os.path.join(file_path, 'train.txt'), entity2id, relation2id)
    valid_triplets = read_triplets(os.path.join(file_path, 'valid.txt'), entity2id, relation2id)
    test_triplets = read_triplets(os.path.join(file_path, 'test.txt'), entity2id, relation2id)

    print('num_entity: {}'.format(len(entity2id)))
    print('num_relation: {}'.format(len(relation2id)))
    print('num_train_triples: {}'.format(len(train_triplets)))
    print('num_valid_triples: {}'.format(len(valid_triplets)))
    print('num_test_triples: {}'.format(len(test_triplets)))

    return entity2id, id2entity, relation2id, id2relation, train_triplets, valid_triplets, test_triplets

def read_triplets(file_path, entity2id, relation2id):
    """Read tab-separated ``head<TAB>relation<TAB>tail`` lines as id triplets.

    Lines mentioning an entity or relation that is missing from the given
    dictionaries are skipped (this is what makes relation sampling work), and
    blank lines are ignored.

    Returns:
        ``np.ndarray`` of shape ``(num_triplets, 3)`` and dtype ``int64``.
        The shape is ``(0, 3)`` when nothing was kept, so callers can still
        transpose/concatenate the result.
    """
    triplets = []

    with open(file_path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            head, relation, tail = line.split('\t')
            # For sampling
            if head not in entity2id or tail not in entity2id or relation not in relation2id:
                continue
            triplets.append((entity2id[head], relation2id[relation], entity2id[tail]))

    # reshape keeps the (N, 3) layout even for an empty list, which
    # np.array([]) alone would turn into a float array of shape (0,).
    return np.array(triplets, dtype=np.int64).reshape(-1, 3)

def sample_edge_uniform(n_triples, sample_size):
    """Draw ``sample_size`` distinct edge indices uniformly from ``range(n_triples)``."""
    candidate_ids = np.arange(n_triples)
    picked = np.random.choice(candidate_ids, sample_size, replace=False)
    return picked

def negative_sampling(pos_samples, num_entity, negative_rate):
    """Append corrupted copies of ``pos_samples`` and build the matching labels.

    Every positive triplet is repeated ``negative_rate`` times; in each copy
    either the subject (column 0) or the object (column 2) is replaced by a
    randomly drawn entity id, chosen by a fair coin flip.

    Returns:
        ``(samples, labels)``: the positives followed by the negatives, and a
        float32 vector that is 1 for positives and 0 for negatives.
    """
    n_pos = len(pos_samples)
    n_neg = n_pos * negative_rate

    labels = np.zeros(n_pos * (negative_rate + 1), dtype=np.float32)
    labels[:n_pos] = 1

    corrupted = np.tile(pos_samples, (negative_rate, 1))

    # Keep this draw order (entities first, then the coin flips) so seeded
    # runs stay reproducible.
    replacements = np.random.choice(num_entity, size=n_neg)
    coin = np.random.uniform(size=n_neg)

    replace_subject = coin > 0.5
    replace_object = ~replace_subject
    corrupted[replace_subject, 0] = replacements[replace_subject]
    corrupted[replace_object, 2] = replacements[replace_object]

    return np.concatenate((pos_samples, corrupted)), labels

def edge_normalization(edge_type, edge_index, num_entity, num_relation):
    '''
        Edge normalization trick

        Each edge is weighted by 1 / (number of edges that share its
        edge_index[0] node and its edge type).

        - edge_type: (num_edge), long, values in [0, 2 * num_relation)
        - edge_index: (2, num_edge), long
        - deg: (num_node, 2 * num_relation) per-node, per-relation edge counts
        - edge_norm: (num_edge), float
    '''
    src = edge_index[0]

    # Count edges per (node, relation) pair directly. This avoids building a
    # dense (num_edge, 2 * num_relation) one-hot matrix and keeps every
    # intermediate on the device of the inputs.
    deg = torch.zeros(num_entity, 2 * num_relation, dtype=torch.float, device=edge_type.device)
    ones = torch.ones(src.numel(), dtype=torch.float, device=edge_type.device)
    deg.index_put_((src, edge_type), ones, accumulate=True)

    # Every looked-up count is at least 1 because the edge counts itself.
    edge_norm = 1 / deg[src, edge_type]

    return edge_norm

def generate_sampled_graph_and_labels(triplets, sample_size, split_size, num_entity, num_rels, negative_rate):
    """
        Build one training batch: a sampled sub-graph plus labelled samples.

        Steps: sample `sample_size` edges uniformly, relabel the entities they
        touch to a compact 0..k-1 range, generate negatives for the relabeled
        edges, then keep a `split_size` fraction of the sampled edges (in both
        directions) as the graph structure.

        The returned `Data` object carries: edge_index, entity (the original
        ids of the relabeled entities), edge_type, edge_norm, samples, labels.
    """
    chosen = triplets[sample_edge_uniform(len(triplets), sample_size)]
    heads, rels, tails = chosen.transpose()

    # Relabel entity ids to positions inside `uniq_entity`.
    uniq_entity, local_ids = np.unique((heads, tails), return_inverse=True)
    heads, tails = np.reshape(local_ids, (2, -1))
    relabeled_edges = np.stack((heads, rels, tails)).transpose()

    # Negative sampling
    samples, labels = negative_sampling(relabeled_edges, len(uniq_entity), negative_rate)

    # Only a fraction of the sampled edges forms the graph structure; the
    # remaining ones act as unseen positive samples.
    n_graph_edges = int(sample_size * split_size)
    graph_split_ids = np.random.choice(np.arange(sample_size),
                                       size=n_graph_edges, replace=False)

    def _selected(column):
        return torch.tensor(column[graph_split_ids], dtype=torch.long).contiguous()

    src = _selected(heads)
    dst = _selected(tails)
    rel = _selected(rels)

    # Bi-directional graph: reversed edges use relation ids shifted by num_rels.
    edge_index = torch.stack((torch.cat((src, dst)), torch.cat((dst, src))))
    edge_type = torch.cat((rel, rel + num_rels))

    data = Data(edge_index=edge_index)
    data.entity = torch.from_numpy(uniq_entity)
    data.edge_type = edge_type
    data.edge_norm = edge_normalization(edge_type, edge_index, len(uniq_entity), num_rels)
    data.samples = torch.from_numpy(samples)
    data.labels = torch.from_numpy(labels)

    return data

def build_test_graph(num_nodes, num_rels, triplets):
    """Build the bi-directional evaluation graph over all ``num_nodes`` entities.

    Every triplet contributes a forward edge and a reversed edge whose
    relation id is shifted by ``num_rels``. The returned ``Data`` object
    carries edge_index, entity (``0..num_nodes-1``), edge_type and edge_norm.
    """
    heads, rels, tails = (torch.from_numpy(column) for column in triplets.transpose())

    sources = torch.cat((heads, tails))
    targets = torch.cat((tails, heads))
    edge_type = torch.cat((rels, rels + num_rels))
    edge_index = torch.stack((sources, targets))

    data = Data(edge_index=edge_index)
    data.entity = torch.from_numpy(np.arange(num_nodes))
    data.edge_type = edge_type
    data.edge_norm = edge_normalization(edge_type, edge_index, num_nodes, num_rels)

    return data

def sort_and_rank(score, target):
    """Return, per row of ``score``, the 0-based rank of column ``target``.

    Rows are sorted in descending order; the rank is the position at which
    the ``target`` column index appears in that ordering.
    """
    order = torch.sort(score, dim=1, descending=True)[1]
    hits = torch.nonzero(order == target.view(-1, 1))
    return hits[:, 1].view(-1)

# return MRR (filtered), and Hits @ (1, 3, 10)
def _filtered_rank(embedding, w, anchor, relation, target, known_pairs, known_answers):
    """Rank ``target`` among all entities not already known to answer the query.

    ``known_pairs[i]`` is the (anchor, relation) pair of a known triplet and
    ``known_answers[i]`` the entity completing it. Entities that complete the
    query ``(anchor, relation)`` are removed from the candidate list and
    ``target`` is appended as the last candidate. Returns the 0-based rank of
    that last candidate as a 1-element tensor.
    """
    num_entity = len(embedding)

    query = torch.stack((anchor, relation))
    matches = (known_pairs == query).all(dim=1)
    filtered = known_answers[matches]
    # Ids outside the embedding table cannot be candidates anyway.
    filtered = filtered[filtered < num_entity]

    keep = torch.ones(num_entity, dtype=torch.bool, device=filtered.device)
    keep[filtered] = False
    candidates = torch.cat((torch.nonzero(keep).view(-1), target.view(-1)))
    candidates = candidates.to(embedding.device)

    # DistMult score of (anchor, relation, candidate) for every candidate.
    emb_ar = (embedding[anchor] * w[relation]).view(-1, 1, 1)
    emb_c = embedding[candidates].transpose(0, 1).unsqueeze(1)
    score = torch.sigmoid(torch.sum(torch.bmm(emb_ar, emb_c), dim=0))

    target_position = torch.tensor(len(candidates) - 1, device=score.device)
    return sort_and_rank(score, target_position)


def calc_mrr(embedding, w, test_triplets, all_triplets, hits=()):
    """Print and return the filtered MRR; also print filtered Hits@k.

    For every test triplet both the object and the subject are ranked against
    all entities, after removing entities that form a known triplet in
    ``all_triplets`` with the same (subject, relation) / (object, relation).

    Args:
        embedding: Entity embeddings, one row per entity id.
        w: Relation embeddings, one row per relation id.
        test_triplets: Long tensor of (subject, relation, object) rows.
        all_triplets: Long tensor of known triplets used for filtering; it
            must live on the same device as ``test_triplets``.
        hits: Iterable of k values for which Hits@k is printed.

    Returns:
        The filtered MRR as a Python float.
    """
    with torch.no_grad():
        subjects = all_triplets[:, 0]
        objects = all_triplets[:, 2]
        head_relation_triplets = all_triplets[:, :2]
        tail_relation_triplets = all_triplets[:, [2, 1]]

        ranks_s = []
        ranks_o = []

        for test_triplet in tqdm(test_triplets):
            subject = test_triplet[0]
            relation = test_triplet[1]
            object_ = test_triplet[2]

            # Perturb object
            ranks_s.append(_filtered_rank(
                embedding, w, subject, relation, object_,
                head_relation_triplets, objects))

            # Perturb subject
            ranks_o.append(_filtered_rank(
                embedding, w, object_, relation, subject,
                tail_relation_triplets, subjects))

        ranks = torch.cat([torch.cat(ranks_s), torch.cat(ranks_o)])
        ranks += 1  # change to 1-indexed

        mrr = torch.mean(1.0 / ranks.float())
        print("MRR (filtered): {:.6f}".format(mrr.item()))

        for hit in hits:
            avg_count = torch.mean((ranks <= hit).float())
            print("Hits (filtered) @ {}: {:.6f}".format(hit, avg_count.item()))

    return mrr.item()


# # Main

# In[6]:


import argparse
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm, trange

# from utils import load_data, generate_sampled_graph_and_labels, build_test_graph, calc_mrr
# from models import RGCN

def train(train_triplets, model, use_cuda, batch_size, split_size, negative_sample, reg_ratio, num_entities, num_relations):
    """Sample one training graph and return its loss tensor.

    The loss is the model's score loss on the sampled positives/negatives
    plus ``reg_ratio`` times the model's regularisation loss. No backward
    pass or optimiser step happens here.
    """
    batch = generate_sampled_graph_and_labels(
        train_triplets, batch_size, split_size,
        num_entities, num_relations, negative_sample,
    )

    if use_cuda:
        batch.to(torch.device('cuda'))

    node_embedding = model(batch.entity, batch.edge_index, batch.edge_type, batch.edge_norm)
    fit_term = model.score_loss(node_embedding, batch.samples, batch.labels)
    penalty = model.reg_loss(node_embedding)

    return fit_term + reg_ratio * penalty

def valid(valid_triplets, model, test_graph, all_triplets):
    """Encode ``test_graph`` with ``model`` and return the filtered MRR on ``valid_triplets``."""
    graph = test_graph
    entity_embedding = model(graph.entity, graph.edge_index, graph.edge_type, graph.edge_norm)
    return calc_mrr(
        entity_embedding,
        model.relation_embedding,
        valid_triplets,
        all_triplets,
        hits=[1, 3, 10],
    )

def test(test_triplets, model, test_graph, all_triplets):
    """Encode ``test_graph`` with ``model`` and return the filtered MRR on ``test_triplets``.

    All tensors handed to ``calc_mrr`` are moved to the CPU first.
    """
    graph = test_graph
    entity_embedding = model(graph.entity, graph.edge_index, graph.edge_type, graph.edge_norm)
    return calc_mrr(
        entity_embedding.cpu(),
        model.relation_embedding.cpu(),
        test_triplets.cpu(),
        all_triplets.cpu(),
        hits=[1, 3, 10],
    )

def main(args):
    """Train an RGCN on ``args.dataset`` and evaluate the best checkpoint.

    Each epoch trains on one freshly sampled sub-graph. Whenever the training
    loss improves on the best value seen so far the model is saved to
    ``best_mrr_model.pth``; after training that checkpoint is reloaded and
    scored with ``test``.

    Args:
        args: Parsed command-line namespace (see the argument parser at the
            bottom of the file).
    """
    use_cuda = args.gpu >= 0 and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    print(device)

    file_path = './data/' + args.dataset

    entity2id, id2entity, relation2id, id2relation, train_triplets, valid_triplets, test_triplets = load_data(file_path, args.max_relations)
    all_triplets = torch.LongTensor(np.concatenate((train_triplets, valid_triplets, test_triplets))).to(device)
    test_triplets = torch.LongTensor(test_triplets).to(device)

    num_entities = len(entity2id)
    num_relations = len(relation2id)

    model = RGCN(
        num_entities, num_relations, id2entity, id2relation,
        num_bases=args.n_bases, dropout=args.dropout,
        device=device, file_path=file_path,
        emb_type=args.emb_type, hid_dim=args.hid_dim, long_text=args.long_text
    )
    model = model.to(device)

    # Honour --lr; fall back to the previously hard-coded value for callers
    # whose namespace has no `lr` attribute.
    optimizer = torch.optim.Adam(model.parameters(), lr=getattr(args, "lr", 0.01))

    checkpoint_path = 'best_mrr_model.pth'
    best_loss = float("inf")

    for epoch in trange(1, (args.n_epochs + 1), desc='Epochs', position=0):

        model.train()
        optimizer.zero_grad()

        loss = train(train_triplets, model, use_cuda, batch_size=args.graph_batch_size, split_size=args.graph_split_size,
            negative_sample=args.negative_sample, reg_ratio=args.regularization, num_entities=num_entities, num_relations=num_relations)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_norm)
        optimizer.step()

        # Track the best loss (the old code assigned the other way round, so
        # the threshold never moved and every epoch overwrote the checkpoint).
        loss_value = loss.item()
        if loss_value < best_loss:
            best_loss = loss_value
            torch.save({'state_dict': model.state_dict(), 'epoch': epoch}, checkpoint_path)

    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])

    print(test_triplets.shape)
    print(all_triplets.shape)

    print("Everyting in GPU")
    model = model.eval()

    # NOTE: evaluation is restricted to the first 1000 training triplets for
    # the graph, the first 1000 test triplets, and the first 1000 known
    # triplets for filtering.
    test_graph = build_test_graph(num_entities, num_relations, train_triplets[:1000, :])
    test_graph = test_graph.to(device)

    # No gradients are needed for evaluation; skip building the autograd graph.
    with torch.no_grad():
        test(test_triplets[:1000], model, test_graph, all_triplets[:1000])

if __name__ == '__main__':

    def _parse_bool(text):
        """Parse a command-line boolean.

        ``type=bool`` would turn every non-empty string, including "False",
        into True.
        """
        value = text.strip().lower()
        if value in ("1", "true", "t", "yes", "y"):
            return True
        if value in ("", "0", "false", "f", "no", "n"):
            return False
        raise argparse.ArgumentTypeError("expected a boolean, got {!r}".format(text))

    parser = argparse.ArgumentParser(description='RGCN')

    # Jupyter passes the kernel connection file as -f; accept and ignore it.
    parser.add_argument("-f")
    parser.add_argument("--test", type=_parse_bool, default=True)
    parser.add_argument("--dataset", type=str, default="FB15k-237")
    parser.add_argument("--emb_type", type=str, default="bert")
    parser.add_argument("--long_text", action="store_true")
    parser.add_argument("--max_relations", type=int, default=-1)
    parser.add_argument("--graph-batch-size", type=int, default=1024)
    parser.add_argument("--graph-split-size", type=float, default=0.5)
    parser.add_argument("--negative-sample", type=int, default=1)
    parser.add_argument("--n-epochs", type=int, default=1)
    parser.add_argument("--evaluate-every", type=int, default=10)

    parser.add_argument("--dropout", type=float, default=0.2)
    parser.add_argument("--gpu", type=int, default=0)
    parser.add_argument("--lr", type=float, default=1e-2)
    parser.add_argument("--n-bases", type=int, default=4)

    # hid_dim must match the hidden size of the BERT checkpoint in BERT mode
    # (768 was used with bert-base-cased).
    parser.add_argument("--hid_dim", type=int, default=128)
    parser.add_argument("--regularization", type=float, default=1e-2)
    parser.add_argument("--grad-norm", type=float, default=1.0)

    # `args` stays a module-level name: RGCN consults args.test when it is
    # not given an explicit freeze setting.
    args = parser.parse_args()
    print(args)

    main(args)

Namespace(dataset='FB15k-237', dropout=0.2, emb_type='bert', evaluate_every=10, f='/home/roy206/.local/share/jupyter/runtime/kernel-3578a416-edee-49c1-ab3a-2347c8e30dfa.json', gpu=0, grad_norm=1.0, graph_batch_size=1024, graph_split_size=0.5, hid_dim=128, long_text=False, lr=0.01, max_relations=-1, n_bases=4, n_epochs=1, negative_sample=1, regularization=0.01, test=True)
cuda
load data from ./data/FB15k-237
num_entity: 14541
num_relation: 237
num_train_triples: 272115
num_valid_triples: 17535
num_test_triples: 20466


Some weights of the model checkpoint at prajjwal1/bert-tiny were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Epochs:   0%|          | 0/1 [00:00<?, ?it/s]Asking to truncate to max_length but no maximum length is provided and the m

Successfully Loaded BERT
Obtained texts
Obtained encodings
Successfully obtained embedding from BERT


Epochs: 100%|██████████| 1/1 [00:00<00:00,  1.74it/s]


torch.Size([20466, 3])
torch.Size([310116, 3])
Everyting in GPU
Obtained texts


  0%|          | 2/1000 [00:00<01:08, 14.67it/s]

Obtained encodings
Successfully obtained embedding from BERT


100%|██████████| 1000/1000 [00:49<00:00, 20.34it/s]

MRR (filtered): 0.001092
Hits (filtered) @ 1: 0.000500
Hits (filtered) @ 3: 0.000500
Hits (filtered) @ 10: 0.001500





In [None]:
!ls