In [1]:
from transformers import AutoTokenizer, AutoModel
from copy import deepcopy
import json
import torch
from tqdm import tqdm
import os
import torch.nn as nn
import random
os.environ['TOKENIZERS_PARALLELISM'] = 'true'

In [2]:
# Shared BERT tokenizer used by every example class and collate function below.
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
# Load the training examples inside a context manager so the file handle is
# closed deterministically instead of being left to the garbage collector.
with open('data/MetaQA/train.txt.json', 'r') as train_file:
    train_examples = json.load(train_file)

In [3]:
# (head_id, relation_id) -> list of tail ids observed in the train/valid triples.
hr_t_dict = {}
# (relation_id, tail_id) -> list of head ids observed in the train/valid triples.
rt_h_dict = {}


def get_triple_dict():
    """Populate ``hr_t_dict`` and ``rt_h_dict`` from the train and valid triple files.

    Every line of ``data/MetaQA/train.txt`` and ``data/MetaQA/valid.txt`` is
    parsed as three tab-separated integers ``head, relation, tail``.  The two
    module-level dictionaries are updated in place; duplicates are kept and
    nothing is returned.  Inverse triples are not indexed.

    Raises:
        ValueError: if a line does not hold exactly three integer fields.
    """
    raw_lines = []
    for triple_path in ('data/MetaQA/train.txt', 'data/MetaQA/valid.txt'):
        with open(triple_path, 'r') as triple_file:
            raw_lines.extend(triple_file.readlines())

    for raw_line in raw_lines:
        head_idx, rel_idx, tail_idx = (int(field) for field in raw_line.strip().split('\t'))
        hr_t_dict.setdefault((head_idx, rel_idx), []).append(tail_idx)
        rt_h_dict.setdefault((rel_idx, tail_idx), []).append(head_idx)
                
# Fill hr_t_dict / rt_h_dict once; corrupt_examples reads hr_t_dict to reject true tails.
get_triple_dict()

In [4]:
# Entity name -> integer id, and the inverse mapping.
ent2idx = {}
idx2ent = {}


def get_entity_dict():
    """Populate ``ent2idx`` and ``idx2ent`` from ``data/MetaQA_data/entities.dict``.

    Each line is expected to be ``<id>\\t<entity name>``.  Both module-level
    dictionaries are updated in place and nothing is returned.

    Raises:
        ValueError: if a line does not split into exactly two tab-separated
            fields or its id is not an integer.
    """
    with open('data/MetaQA_data/entities.dict', 'r') as dict_file:
        for row in dict_file.readlines():
            raw_idx, name = row.strip().split('\t')
            entity_id = int(raw_idx)
            ent2idx[name] = entity_id
            idx2ent[entity_id] = name
# Load the entity vocabulary; corrupt_examples samples ids from it.
get_entity_dict()
# Relation name -> integer id, and the inverse mapping.
rel2idx = {}
idx2rel = {}


def get_relation_dict():
    """Populate ``rel2idx`` and ``idx2rel`` from ``data/MetaQA_data/relations.dict``.

    Each line is expected to be ``<id>\\t<relation name>``.  Both module-level
    dictionaries are updated in place and nothing is returned.

    Raises:
        ValueError: if a line does not split into exactly two tab-separated
            fields or its id is not an integer.
    """
    with open('data/MetaQA_data/relations.dict', 'r') as dict_file:
        for row in dict_file.readlines():
            raw_idx, name = row.strip().split('\t')
            relation_id = int(raw_idx)
            rel2idx[name] = relation_id
            idx2rel[relation_id] = name
# Load the relation vocabulary; corrupt_examples and get_rel_embedding read rel2idx.
get_relation_dict()

In [5]:
def reverse(example):
    """Return a new example dict with head and tail swapped.

    The relation name gets the suffix ``'_reverse'``.  The ``'label'`` of the
    input is carried over when present, otherwise the reversed example is
    labelled ``1``.  The input dict is not modified.
    """
    return {
        'head': example['tail'],
        'head_id': example['tail_id'],
        'tail': example['head'],
        'tail_id': example['head_id'],
        'relation': example['relation'] + '_reverse',
        'label': example.get('label', 1),
    }

def corrupt_examples(example, n=1):
    """Build ``n`` negative examples by replacing the tail with a random entity.

    Candidate tail ids are drawn uniformly from ``0 .. len(ent2idx) - 1`` and
    rejected when they are a known true tail for the example's
    ``(head_id, relation)`` pair according to ``hr_t_dict``.  The same id may be
    drawn more than once, and the head entity itself is not excluded.

    Args:
        example: dict with at least ``'head'``, ``'head_id'`` and ``'relation'``.
        n: number of corrupted examples to generate; values <= 0 yield ``[]``.

    Returns:
        A list of ``n`` dicts with keys ``head``, ``head_id``, ``tail``,
        ``tail_id``, ``label`` (always ``0``) and ``relation``.

    Raises:
        KeyError: if the relation is missing from ``rel2idx`` or the
            ``(head_id, relation_id)`` pair is missing from ``hr_t_dict``.
        ValueError: if no entity id is left to sample, i.e. every id in range
            is already a true tail (the rejection loop would never terminate).
    """
    head_idx = int(example['head_id'])
    rel_idx = rel2idx[example['relation']]
    # hr_t_dict stores lists; a set makes every rejection test O(1).
    known_tails = set(hr_t_dict[(head_idx, rel_idx)])
    num_entities = len(ent2idx)

    if n > 0:
        blocked = sum(1 for tail_idx in known_tails if 0 <= tail_idx < num_entities)
        if blocked >= num_entities:
            raise ValueError(
                f'no corrupt tail available for head {head_idx} / relation {rel_idx}: '
                f'all {num_entities} entity ids are true tails'
            )

    corrupt_ts = []
    while len(corrupt_ts) < n:
        candidate = random.randint(0, num_entities - 1)
        if candidate not in known_tails:
            corrupt_ts.append(candidate)

    return [
        {
            'head': example['head'],
            'head_id': example['head_id'],
            'tail': idx2ent[corrupt_t],
            'tail_id': corrupt_t,
            'label': 0,
            'relation': example['relation'],
        }
        for corrupt_t in corrupt_ts
    ]


In [6]:
class KG_Example:
    """One labelled (head, relation, tail) triple with its textual surface forms.

    ``head_id`` and ``tail_id`` are coerced to ``int``; ``label`` defaults to 1.
    """

    def __init__(self, head, head_id, relation, tail, tail_id, label=1):
        self.head, self.head_id = head, int(head_id)
        self.relation = relation
        self.tail, self.tail_id = tail, int(tail_id)
        self.label = label

    def vectorize(self):
        """Tokenize head, relation and tail with the module-level ``tokenizer``.

        Head and tail are truncated to 16 tokens, the relation to 48.  Returns a
        flat dict holding the ids, the raw relation string, the label and, for
        each of the three fields, its input ids, token type ids and attention
        mask (unpadded Python lists as returned by the tokenizer).
        """
        def encode(text, limit):
            # Same tokenizer settings for every field; only the length cap varies.
            return tokenizer(text=text,
                             max_length=limit,
                             return_token_type_ids=True,
                             truncation=True)

        head_enc = encode(self.head, 16)
        rel_enc = encode(self.relation, 48)
        tail_enc = encode(self.tail, 16)

        features = {}
        features['head_id'] = self.head_id
        features['h_token_ids'] = head_enc['input_ids']
        features['h_token_type_ids'] = head_enc['token_type_ids']
        features['h_mask'] = head_enc['attention_mask']
        features['relation'] = self.relation
        features['r_token_ids'] = rel_enc['input_ids']
        features['r_token_type_ids'] = rel_enc['token_type_ids']
        features['r_mask'] = rel_enc['attention_mask']
        features['tail_id'] = self.tail_id
        features['t_token_ids'] = tail_enc['input_ids']
        features['t_token_type_ids'] = tail_enc['token_type_ids']
        features['t_mask'] = tail_enc['attention_mask']
        features['label'] = self.label
        return features
    
class Ent_Example:
    """A single entity name that can be tokenized for the entity encoder."""

    def __init__(self, entity):
        self.entity = entity

    def vectorize(self):
        """Tokenize the entity name (truncated to 16 tokens) with the module-level ``tokenizer``.

        Returns a dict with the unpadded input ids, token type ids and
        attention mask under the ``ent_*`` keys.
        """
        encoded = tokenizer(text=self.entity,
                            max_length=16,
                            return_token_type_ids=True,
                            truncation=True)
        features = {}
        features['ent_token_ids'] = encoded['input_ids']
        features['ent_token_type_ids'] = encoded['token_type_ids']
        features['ent_mask'] = encoded['attention_mask']
        return features

In [7]:
class KG_Dataset(torch.utils.data.dataset.Dataset):
    """Map-style dataset of positive triples interleaved with sampled negatives.

    For each input example the dataset stores, in order: the example itself and
    ``corrupt_num`` corrupted copies from ``corrupt_examples``; when
    ``add_reverse`` is true, the reversed example and ``corrupt_num`` corrupted
    copies of it follow.  Items are tokenized lazily in ``__getitem__``.
    """

    def __init__(self, examples, corrupt_num=2, add_reverse=False):
        self.examples = []
        for position in tqdm(range(len(examples))):
            self._append_with_negatives(examples[position], corrupt_num)
            if add_reverse:
                self._append_with_negatives(reverse(examples[position]), corrupt_num)

    def _append_with_negatives(self, positive, corrupt_num):
        """Store ``positive`` followed by its ``corrupt_num`` corrupted variants."""
        self.examples.append(KG_Example(**positive))
        negatives = corrupt_examples(positive, n=corrupt_num)
        self.examples.extend(KG_Example(**negative) for negative in negatives)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, index):
        # Tokenization happens per access rather than at construction time.
        return self.examples[index].vectorize()
    
class Ent_Dataset(torch.utils.data.dataset.Dataset):
    """Map-style dataset wrapping a sequence of entity names as ``Ent_Example`` items."""

    def __init__(self, examples):
        self.examples = [Ent_Example(entity_name) for entity_name in examples]

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, index):
        # Tokenization happens per access rather than at construction time.
        return self.examples[index].vectorize()

In [8]:
def padding_batches(batch: list, pad_token_id, return_mask=False):
    """Right-pad a list of 1-D tensors to a common length.

    Args:
        batch: list of 1-D tensors of possibly different lengths.
        pad_token_id: value written into the padded positions.
        return_mask: when true, also return a float mask with 1 on real
            positions and 0 on padding.

    Returns:
        A ``(len(batch), max_len)`` int64 tensor, or a ``(tokens, mask)`` pair
        when ``return_mask`` is true.
    """
    longest = max(sequence.shape[0] for sequence in batch)
    num_rows = len(batch)
    padded = torch.full((num_rows, longest), pad_token_id, dtype=torch.long)
    mask = torch.zeros(num_rows, longest)
    for row, sequence in enumerate(batch):
        width = len(sequence)
        padded[row, :width] = sequence
        if return_mask:
            mask[row, :width] = 1
    return (padded, mask) if return_mask else padded

def _collate_fn(batch):
    """Collate vectorized ``KG_Example`` dicts into padded batch tensors.

    Token ids are padded with ``tokenizer.pad_token_id`` and come with an
    attention mask; token type ids are padded with 0.  Labels are stacked into
    a ``LongTensor``.
    """
    def column(key):
        # One LongTensor per sample for the requested feature.
        return [torch.LongTensor(sample[key]) for sample in batch]

    head_ids, head_mask = padding_batches(column('h_token_ids'),
                                          tokenizer.pad_token_id,
                                          return_mask=True)
    head_types = padding_batches(column('h_token_type_ids'), 0)
    rel_ids, rel_mask = padding_batches(column('r_token_ids'),
                                        tokenizer.pad_token_id,
                                        return_mask=True)
    rel_types = padding_batches(column('r_token_type_ids'), 0)
    tail_ids, tail_mask = padding_batches(column('t_token_ids'),
                                          tokenizer.pad_token_id,
                                          return_mask=True)
    tail_types = padding_batches(column('t_token_type_ids'), 0)

    return {
        'h_token_ids': head_ids,
        'h_mask': head_mask,
        'h_token_type_ids': head_types,
        'r_token_ids': rel_ids,
        'r_mask': rel_mask,
        'r_token_type_ids': rel_types,
        't_token_ids': tail_ids,
        't_mask': tail_mask,
        't_token_type_ids': tail_types,
        'label': torch.LongTensor([sample['label'] for sample in batch]),
    }

def _collate_fn_ent(batch):
    """Collate vectorized ``Ent_Example`` dicts into padded batch tensors.

    Token ids are padded with ``tokenizer.pad_token_id`` and come with an
    attention mask; token type ids are padded with 0.
    """
    id_rows = [torch.LongTensor(sample['ent_token_ids']) for sample in batch]
    ent_ids, ent_mask = padding_batches(id_rows, tokenizer.pad_token_id, return_mask=True)
    type_rows = [torch.LongTensor(sample['ent_token_type_ids']) for sample in batch]
    ent_types = padding_batches(type_rows, 0)
    return {
        'ent_token_ids': ent_ids,
        'ent_mask': ent_mask,
        'ent_token_type_ids': ent_types,
    }

In [9]:
# Training set: every positive triple is followed by two sampled negatives;
# reversed triples are not added.
train_dataset = KG_Dataset(train_examples, add_reverse=False, corrupt_num=2)
# Shuffled mini-batches of 64, padded per batch by _collate_fn, loaded by 4 workers.
train_dataloader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=64,
    shuffle=True,
    num_workers=4,
    collate_fn=_collate_fn,
)

100%|██████████| 133582/133582 [00:01<00:00, 83110.09it/s] 


In [10]:
def get_entity_loader():
    """Build a non-shuffled DataLoader over every entity name in ``entities.dict``.

    Entities are ordered by their integer id, so the position of an entity in
    the loader (and therefore its row in the concatenated embedding matrix
    built from it) follows id order rather than whatever order the file
    happens to list them in.  For a file already sorted by id the result is
    unchanged.

    Returns:
        A ``DataLoader`` yielding batches of 1024 entities collated by
        ``_collate_fn_ent``.
    """
    indexed_names = []
    with open('data/MetaQA_data/entities.dict', 'r') as dict_file:
        for row in dict_file:
            fields = row.strip().split('\t')
            indexed_names.append((int(fields[0]), fields[1]))
    # Stable sort: entries sharing an id keep their file order.
    indexed_names.sort(key=lambda pair: pair[0])
    ent_list = [name for _, name in indexed_names]

    ent_dataset = Ent_Dataset(ent_list)
    ent_dataloader = torch.utils.data.DataLoader(
        ent_dataset,
        batch_size=1024,
        shuffle=False,  # order must be preserved so rows can be looked up by position
        collate_fn=_collate_fn_ent,
        num_workers=4)
    return ent_dataloader
ent_dataloader = get_entity_loader()

In [11]:
@torch.no_grad()
def get_ent_embedding(model):
    """Encode every batch of the module-level ``ent_dataloader`` with the entity encoder.

    Runs without gradient tracking.  Each batch is moved to the GPU, encoded
    through ``model.encode`` using ``model.ent_encoder``, and the per-batch
    results are concatenated along the first dimension in loader order.
    """
    def embed(cpu_batch):
        gpu_batch = move_to_cuda(cpu_batch)
        return model.encode(model.ent_encoder,
                            input_ids=gpu_batch['ent_token_ids'],
                            token_type_ids=gpu_batch['ent_token_type_ids'],
                            attention_mask=gpu_batch['ent_mask'])

    return torch.cat([embed(ent_batch) for ent_batch in ent_dataloader])

def move_to_cuda(sample):
    """Recursively move every tensor inside ``sample`` to the current CUDA device.

    Dicts, lists and tuples are traversed and rebuilt with the same container
    type (a tuple stays a tuple); any other non-tensor value is returned
    unchanged.  Tensors are transferred with ``non_blocking=True``.

    Args:
        sample: a sized container (typically the dict produced by a collate
            function).  Any empty container yields ``{}``.

    Returns:
        A structure mirroring ``sample`` with tensors on the GPU.
    """
    if len(sample) == 0:
        return {}

    def _move_to_cuda(maybe_tensor):
        if torch.is_tensor(maybe_tensor):
            return maybe_tensor.cuda(non_blocking=True)
        if isinstance(maybe_tensor, dict):
            return {key: _move_to_cuda(value) for key, value in maybe_tensor.items()}
        if isinstance(maybe_tensor, list):
            return [_move_to_cuda(item) for item in maybe_tensor]
        if isinstance(maybe_tensor, tuple):
            # Preserve the tuple type instead of silently turning it into a list.
            return tuple(_move_to_cuda(item) for item in maybe_tensor)
        return maybe_tensor

    return _move_to_cuda(sample)

In [74]:
class KGModel(nn.Module):
    """Text-based triple scorer: two BERT encoders combined with a ComplEx score.

    ``ent_encoder`` embeds head and tail text, ``r_encoder`` (initialised as a
    deep copy of it) embeds relation text.  Each embedding is split in half
    along dim 1 into a real and an imaginary part for scoring.  ``softmax`` and
    ``fc`` are registered but not used by any method of this class (the ``fc``
    projection in ``encode`` is commented out).
    """

    def __init__(self):
        super().__init__()
        pretrained = AutoModel.from_pretrained('bert-base-uncased')
        self.ent_encoder = pretrained.cuda()
        self.r_encoder = deepcopy(self.ent_encoder)
        self.softmax = nn.Softmax(dim=1)
        self.fc = nn.Linear(768, 200)

    def encode(self, model, input_ids, token_type_ids, attention_mask):
        """Run ``model`` on one padded batch and pool its last hidden state.

        Pooling is delegated to ``convert_features`` with its default mode.
        """
        outputs = model(input_ids=input_ids,
                        token_type_ids=token_type_ids,
                        attention_mask=attention_mask,
                        return_dict=True)
        pooled = self.convert_features(attention_mask, outputs.last_hidden_state)
        # pooled = self.fc(pooled)
        return pooled

    def ComplEx_score(self, h_embed, r_embed, t_embed):
        """Return ``sigmoid`` of the ComplEx trilinear score for each row.

        Every input is split into two equal halves along dim 1 (real part,
        imaginary part); inputs broadcast against each other along dim 0.
        The result has one value per row.
        """
        def split_complex(embed):
            real, imag = torch.stack(embed.chunk(2, dim=1))
            return real, imag

        reh, imh = split_complex(h_embed)
        rer, imr = split_complex(r_embed)
        ret, imt = split_complex(t_embed)
        score = (reh*rer*ret + imh*rer*imt + reh*imr*imt - imh*imr*ret).sum(dim=1)
        return torch.sigmoid(score)

    def forward(self, batch_dict):
        """Score a collated batch of triples; returns one probability per example."""
        def embed(encoder, prefix_ids, prefix_types, prefix_mask):
            return self.encode(encoder,
                               input_ids=batch_dict[prefix_ids],
                               token_type_ids=batch_dict[prefix_types],
                               attention_mask=batch_dict[prefix_mask])

        h_embed = embed(self.ent_encoder, 'h_token_ids', 'h_token_type_ids', 'h_mask')
        r_embed = embed(self.r_encoder, 'r_token_ids', 'r_token_type_ids', 'r_mask')
        t_embed = embed(self.ent_encoder, 't_token_ids', 't_token_type_ids', 't_mask')
        return self.ComplEx_score(h_embed, r_embed, t_embed)

    def convert_features(self, mask, last_hidden_state, convert_type='cls'):
        """Pool token states into one vector per sequence.

        ``'cls'`` takes the state at position 0; any other value computes the
        mask-weighted mean over the sequence dimension (the divisor is clamped
        to at least 1e-4).
        """
        if convert_type == 'cls':
            return last_hidden_state[:, 0, :]
        weights = mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
        token_sum = torch.sum(last_hidden_state * weights, 1)
        token_count = torch.clamp(weights.sum(1), min=1e-4)
        # output_vector = nn.functional.normalize(output_vector, dim=1)
        return token_sum / token_count

In [75]:
# KGModel.__init__ already moves the two encoders to the GPU; this .cuda() call
# also moves the remaining submodules (e.g. the fc layer).
model = KGModel().cuda()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [76]:
@torch.no_grad()
def get_rel_embedding(model):
    """Encode every relation name in ``rel2idx`` with the relation encoder.

    Relation names are ordered by their integer id before tokenization, so the
    row order of the result follows id order rather than the insertion order
    of ``rel2idx``.  For a dictionary already inserted in id order the result
    is unchanged.  Runs without gradient tracking and in a single batch.

    Returns:
        The pooled relation embeddings from ``model.encode``, one row per
        relation, on the GPU.
    """
    relation_names = sorted(rel2idx, key=rel2idx.get)
    tokenized_rel = tokenizer(text=relation_names,
                              max_length=48,
                              return_token_type_ids=True,
                              truncation=True)

    def pad(field, padding_value):
        # Turn the ragged Python lists into one right-padded LongTensor.
        rows = [torch.LongTensor(row) for row in tokenized_rel[field]]
        return torch.nn.utils.rnn.pad_sequence(rows,
                                               padding_value=padding_value,
                                               batch_first=True)

    input_ids = pad('input_ids', tokenizer.pad_token_id)
    attention_mask = pad('attention_mask', 0)
    token_type_ids = pad('token_type_ids', tokenizer.pad_token_type_id)

    rel_embed = model.encode(model.r_encoder,
                             input_ids=input_ids.cuda(),
                             token_type_ids=token_type_ids.cuda(),
                             attention_mask=attention_mask.cuda())
    return rel_embed

In [77]:
@torch.no_grad()
def evaluate_mr(model, dataset, ent_embed, rel_embed):
    total_mr = 0
    total_mrr = 0
    n = 0
    for data in tqdm(dataset):
        h, r, t = data
        h_embed = ent_embed[h][None, :]
        r_embed = rel_embed[r][None, :]
        pred = model.ComplEx_score(h_embed, r_embed, ent_embed)
        mr = (pred >= pred[t]).sum()
        total_mr += mr.item()
        total_mrr += 1 / mr.item()
        n += 1
    return total_mr / n, total_mrr / n

In [16]:
def get_evaluate_dataset():
    """Read ``data/MetaQA/valid.txt`` into a list of ``[head, relation, tail]`` id triples.

    Each line must hold exactly three tab-separated integers.

    Raises:
        ValueError: if a line has a different number of fields or a
            non-integer field.
    """
    triples = []
    with open('data/MetaQA/valid.txt', 'r') as valid_file:
        for row in valid_file.readlines():
            head_idx, rel_idx, tail_idx = map(int, row.strip().split('\t'))
            triples.append([head_idx, rel_idx, tail_idx])
    return triples

In [17]:
# Validation triples as [head_id, relation_id, tail_id] lists, consumed by evaluate_mr.
valid_d = get_evaluate_dataset()

In [66]:
# batch_dict = move_to_cuda(batch)
# # def ComplEx_score(mo, h_embed, r_embed, t_embed):
# #     reh, imh = torch.stack(h_embed.chunk(2, dim=1))
# #     rer, imr = torch.stack(r_embed.chunk(2, dim=1))
# #     ret, imt = torch.stack(t_embed.chunk(2, dim=1))
# #     score = (reh*rer*ret + imh*rer*imt + reh*imr*imt - imh*imr*ret).sum(dim=1)
# #     pred = torch.sigmoid(score)
# #     return pred


# h_embed = model.encode(model.ent_encoder,
#                         input_ids=batch_dict['h_token_ids'],
#                         token_type_ids=batch_dict['h_token_type_ids'],
#                         attention_mask=batch_dict['h_mask'])
# r_embed = model.encode(model.r_encoder,
#                         input_ids=batch_dict['r_token_ids'],
#                         token_type_ids=batch_dict['r_token_type_ids'],
#                         attention_mask=batch_dict['r_mask'])
# t_embed = model.encode(model.ent_encoder,
#                         input_ids=batch_dict['t_token_ids'],
#                         token_type_ids=batch_dict['t_token_type_ids'],
#                         attention_mask=batch_dict['t_mask'])
# rel_embed[:,:5]

In [81]:
# Binary cross-entropy on the sigmoid probabilities returned by KGModel.forward.
loss_f = nn.BCELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5, weight_decay=1e-4)
epochs = 5
for epoch in range(epochs):
    model.train()
    with tqdm(total=len(train_dataloader)) as pbar:
        total_loss = 0
        for i, batch in enumerate(train_dataloader):
            # Move the batch once and reuse it for both inputs and labels.
            cuda_batch = move_to_cuda(batch)
            optimizer.zero_grad()
            x = model(cuda_batch)
            # BCELoss needs float targets; the collate function yields a LongTensor.
            label = cuda_batch['label'].float()
            loss = loss_f(x, label)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            avg_loss = total_loss / (i + 1)
            pbar.set_description(f"loss:{loss.item():.4f}, avg loss:{avg_loss:.3f}")
            pbar.update(1)
    print('Evaluating Model')
    model.eval()
    # Embeddings are recomputed after every epoch because the encoders changed.
    ent_embed = get_ent_embedding(model)
    rel_embed = get_rel_embedding(model)
    # acc = evaluate_accuracy(model, valid_dataloader)
    mr, mrr = evaluate_mr(model, valid_d, ent_embed, rel_embed)
    print(f'mr: {mr:.2f}, mrr: {mrr:.2f}')
# Create the checkpoint directory up front so a missing folder cannot discard
# the trained weights after the last epoch.
os.makedirs('save_model', exist_ok=True)
torch.save(model.state_dict(), f'save_model/kgmodel_mr{mr:.2f}')

loss:0.2673, avg loss:0.286:  85%|████████▍ | 5292/6262 [10:47<01:58,  8.18it/s]


KeyboardInterrupt: 

In [None]:
1800 0.265

In [78]:
# Re-encode all entities and relations with the current model weights and
# report (mean rank, mean reciprocal rank) on the validation triples.
# NOTE(review): this cell's execution count (78) is lower than the training
# cell's (81), so the recorded output presumably predates that training run
# - re-run after training to confirm.
ent_embed = get_ent_embedding(model)
rel_embed = get_rel_embedding(model)
evaluate_mr(model, valid_d, ent_embed, rel_embed)

100%|██████████| 4053/4053 [00:07<00:00, 523.03it/s]


(23240.984702689366, 9.267469350604106e-05)