In [1]:
from transformers import AutoTokenizer, AutoModel
from copy import deepcopy
import json
import torch
from tqdm import tqdm
import os
import torch.nn as nn
os.environ['TOKENIZERS_PARALLELISM'] = 'true'

In [2]:
def _load_json(path):
    """Parse the JSON file at ``path`` and close the handle as soon as parsing is done."""
    with open(path, 'r') as fp:
        return json.load(fp)


# Shared BERT tokenizer; the example classes and collate functions read this global.
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
# 1-hop MetaQA train / dev examples (each entry is later expanded into QA_Example kwargs).
train_examples = _load_json('data/MetaQA_data/qa_train_1hop.txt.json')
valid_examples = _load_json('data/MetaQA_data/qa_dev_1hop.txt.json')

In [4]:
# Bidirectional entity lookup tables, filled in by get_entity_dict() below.
ent2idx = {}
idx2ent = {}


def get_entity_dict():
    """Fill the module-level ``ent2idx`` / ``idx2ent`` maps from ``entities.dict``.

    Every line of the file is split on a tab into ``(index, entity)``; the index
    is stored as an int in both directions of the mapping.
    """
    # The dicts are only mutated (never rebound), so no ``global`` statement is needed.
    with open('data/MetaQA_data/entities.dict', 'r') as dict_file:
        for raw_line in dict_file:
            index_text, entity = raw_line.strip().split('\t')
            entity_index = int(index_text)
            ent2idx[entity] = entity_index
            idx2ent[entity_index] = entity


get_entity_dict()

In [5]:
class QA_Example:
    """One QA sample: a head entity, a question about it, and the answer entities.

    The tail text is stored on the instance but is not tokenized by ``vectorize``;
    only the answer ids are passed through.
    """

    def __init__(self, head, head_id, relation, tail, tail_id):
        self.head = head
        self.head_id = int(head_id)
        # The constructor argument is named ``relation`` but is exposed as ``question``.
        self.question = relation
        self.tail = tail
        self.tail_id = list(map(int, tail_id))

    def vectorize(self):
        """Tokenize ``(head, question)`` as a text pair with the global ``tokenizer``.

        Returns:
            dict with the pair's token ids, token type ids and attention mask
            (``max_length=48`` with truncation enabled) plus the answer id list.
        """
        encoded = tokenizer(
            text=self.head,
            text_pair=self.question,
            max_length=48,
            return_token_type_ids=True,
            truncation=True,
        )
        features = {
            'hq_token_ids': encoded['input_ids'],
            'hq_token_type_ids': encoded['token_type_ids'],
            'hq_mask': encoded['attention_mask'],
            'all_answer_id': self.tail_id,
        }
        return features
    
class Ent_Example:
    """A single entity name that can be tokenized on demand."""

    def __init__(self, entity):
        self.entity = entity

    def vectorize(self):
        """Tokenize the entity name with the global ``tokenizer``.

        Returns:
            dict with the entity's token ids, token type ids and attention mask
            (``max_length=16`` with truncation enabled).
        """
        encoded = tokenizer(
            text=self.entity,
            max_length=16,
            return_token_type_ids=True,
            truncation=True,
        )
        features = {
            'ent_token_ids': encoded['input_ids'],
            'ent_token_type_ids': encoded['token_type_ids'],
            'ent_mask': encoded['attention_mask'],
        }
        return features

In [6]:
class QA_Dataset(torch.utils.data.dataset.Dataset):
    """Map-style dataset that wraps raw QA records in ``QA_Example`` objects."""

    def __init__(self, examples):
        # Each record is unpacked into QA_Example keyword arguments.
        self.examples = [QA_Example(**record) for record in examples]

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, index):
        # Tokenization happens lazily, at item access time.
        example = self.examples[index]
        return example.vectorize()
    
class Ent_Dataset(torch.utils.data.dataset.Dataset):
    """Map-style dataset that wraps entity names in ``Ent_Example`` objects."""

    def __init__(self, examples):
        self.examples = [Ent_Example(entity) for entity in examples]

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, index):
        # Tokenization happens lazily, at item access time.
        example = self.examples[index]
        return example.vectorize()

In [7]:
def padding_batches(batch: list, pad_token_id, return_mask=False):
    """Right-pad a list of 1-D tensors to a common length and stack them.

    Args:
        batch: list of 1-D tensors; the longest one sets the padded width.
        pad_token_id: value written into the padding positions.
        return_mask: when True, also return a float mask holding 1 on real
            positions and 0 on padding.

    Returns:
        The padded ``LongTensor`` of shape ``(len(batch), max_len)``, or the
        tuple ``(padded, mask)`` when ``return_mask`` is True.
    """
    lengths = [seq.shape[0] for seq in batch]
    width = max(lengths)
    padded = torch.full((len(batch), width), pad_token_id, dtype=torch.long)
    mask = torch.zeros(len(batch), width)
    for row, (seq, length) in enumerate(zip(batch, lengths)):
        padded[row, :length].copy_(seq)
        if return_mask:
            mask[row, :length].fill_(1)
    return (padded, mask) if return_mask else padded

def _collate_fn(batch):
    """Collate vectorized QA examples into one padded batch.

    The attention mask is rebuilt from the padded lengths; the per-example
    ``hq_mask`` entry is not read. Answer ids stay a plain list of lists.
    """
    token_ids, attention_mask = padding_batches(
        [torch.LongTensor(item['hq_token_ids']) for item in batch],
        tokenizer.pad_token_id,
        return_mask=True)
    token_type_ids = padding_batches(
        [torch.LongTensor(item['hq_token_type_ids']) for item in batch], 0)
    return {
        'hq_token_ids': token_ids,
        'hq_mask': attention_mask,
        'hq_token_type_ids': token_type_ids,
        'all_answer_id': [item['all_answer_id'] for item in batch],
    }

def _collate_fn_ent(batch):
    """Collate vectorized entity examples into one padded batch.

    The attention mask is rebuilt from the padded lengths; the per-example
    ``ent_mask`` entry is not read.
    """
    token_ids, attention_mask = padding_batches(
        [torch.LongTensor(item['ent_token_ids']) for item in batch],
        tokenizer.pad_token_id,
        return_mask=True)
    token_type_ids = padding_batches(
        [torch.LongTensor(item['ent_token_type_ids']) for item in batch], 0)
    return {
        'ent_token_ids': token_ids,
        'ent_mask': attention_mask,
        'ent_token_type_ids': token_type_ids,
    }

In [8]:
def _make_qa_loader(dataset, shuffle, batch_size=64, num_workers=4):
    """Build a DataLoader over a QA dataset, batching with ``_collate_fn``."""
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=_collate_fn,
        num_workers=num_workers)


train_dataset = QA_Dataset(train_examples)
train_dataloader = _make_qa_loader(train_dataset, shuffle=True)
valid_dataset = QA_Dataset(valid_examples)
# Validation metrics are aggregated over every example, so shuffling adds nothing;
# a fixed order keeps evaluation runs directly comparable.
valid_dataloader = _make_qa_loader(valid_dataset, shuffle=False)

In [9]:
def get_entity_loader():
    """Build a non-shuffled DataLoader over every entity name in ``entities.dict``.

    Entities are kept in file order and batched 1024 at a time with
    ``_collate_fn_ent``.
    """
    # NOTE(review): downstream code compares embedding row indices with entity ids,
    # which presumably relies on line k of the file holding id k -- confirm.
    with open('data/MetaQA_data/entities.dict', 'r') as dict_file:
        # The second tab-separated field of each line is the entity name.
        ent_list = [row.strip().split('\t')[1] for row in dict_file]
    return torch.utils.data.DataLoader(
        Ent_Dataset(ent_list),
        batch_size=1024,
        shuffle=False,
        collate_fn=_collate_fn_ent,
        num_workers=4)


ent_dataloader = get_entity_loader()

In [10]:
@torch.no_grad()
def get_ent_embedding(model):
    """Encode every batch of the global ``ent_dataloader`` and concatenate the results.

    Each batch is moved to the GPU and passed through ``model.encode`` with
    ``model.ent_encoder``; gradients are disabled for the whole pass.
    """
    def _embed(ent_batch):
        on_gpu = move_to_cuda(ent_batch)
        return model.encode(model.ent_encoder,
                            input_ids=on_gpu['ent_token_ids'],
                            token_type_ids=on_gpu['ent_token_type_ids'],
                            attention_mask=on_gpu['ent_mask'])

    return torch.cat([_embed(ent_batch) for ent_batch in ent_dataloader])

def move_to_cuda(sample):
    """Recursively copy every tensor inside ``sample`` to the current CUDA device.

    Dicts and lists are rebuilt with their contents converted; tuples are
    rebuilt as lists; any other value is returned unchanged. A ``sample`` of
    length 0 yields an empty dict.
    """
    def _transfer(obj):
        if torch.is_tensor(obj):
            return obj.cuda(non_blocking=True)
        if isinstance(obj, dict):
            return {key: _transfer(value) for key, value in obj.items()}
        if isinstance(obj, (list, tuple)):
            # Tuples come back as lists, exactly like the list case.
            return [_transfer(element) for element in obj]
        return obj

    return _transfer(sample) if len(sample) != 0 else {}

def get_one_hot_label(answer_ids, ent_num=43234, alpha_smooth=0):
    """Build a multi-hot label matrix over all entities, with optional smoothing.

    Args:
        answer_ids: one iterable of answer entity ids per example.
        ent_num: number of label columns.
        alpha_smooth: smoothing amount; answer cells get ``1 - alpha_smooth``
            and every other cell gets ``alpha_smooth / ent_num``.

    Returns:
        Float tensor of shape ``(len(answer_ids), ent_num)``.
    """
    background = alpha_smooth / ent_num
    label = torch.zeros(len(answer_ids), ent_num) + background
    # Flatten all (example row, answer column) pairs so they can be written at once.
    pairs = [(row, int(answer))
             for row, answers in enumerate(answer_ids)
             for answer in answers]
    if pairs:
        rows = torch.tensor([row for row, _ in pairs], dtype=torch.long)
        cols = torch.tensor([col for _, col in pairs], dtype=torch.long)
        label[rows, cols] = 1 - alpha_smooth
    return label

In [11]:
class QAModel(nn.Module):
    """Dual-encoder QA model that scores a question against every entity.

    Two ``bert-base-uncased`` encoders are loaded and overwritten with state
    dicts from the given checkpoint files: one encodes entities, the other
    encodes (head, question) pairs. Entity embeddings are computed once at
    construction time and cached in ``self.ent_embed``.

    Args:
        ent_checkpoint: path of the state dict loaded into ``ent_encoder``.
        qa_checkpoint: path of the state dict loaded into ``qa_encoder``.
    """

    def __init__(self,
                 ent_checkpoint='checkpoint/MetaQA_2023-04-23-0538.24/epoch2tail_bert.mdl',
                 qa_checkpoint='checkpoint/MetaQA_2023-04-23-0538.24/epoch2hr_bert.mdl'):
        super().__init__()
        self.ent_encoder = AutoModel.from_pretrained('bert-base-uncased').cuda()
        self.ent_encoder.load_state_dict(torch.load(ent_checkpoint))
        self.qa_encoder = AutoModel.from_pretrained('bert-base-uncased').cuda()
        self.qa_encoder.load_state_dict(torch.load(qa_checkpoint))
        self.softmax = nn.Softmax(dim=1)
        # Computed once under no_grad, so the entity side receives no gradients
        # and is not refreshed afterwards.
        self.ent_embed = get_ent_embedding(self)

    def encode(self, model, input_ids, token_type_ids, attention_mask):
        """Run ``model`` on the inputs and pool its last hidden state into one vector per row."""
        embed = model(input_ids=input_ids,
                      token_type_ids=token_type_ids,
                      attention_mask=attention_mask,
                      return_dict=True)
        embed = embed.last_hidden_state
        embed = self.convert_features(attention_mask, embed)
        return embed

    def forward(self, qa_batch_dict):
        """Return a softmax distribution over all entities for each question in the batch.

        Args:
            qa_batch_dict: dict with ``hq_token_ids``, ``hq_token_type_ids`` and
                ``hq_mask`` tensors.

        Returns:
            Tensor of shape ``(batch, num_entities)`` whose rows sum to 1.
        """
        qa_embed = self.encode(self.qa_encoder,
                               input_ids=qa_batch_dict['hq_token_ids'],
                               token_type_ids=qa_batch_dict['hq_token_type_ids'],
                               attention_mask=qa_batch_dict['hq_mask'])

        # (batch, dim) @ (dim, num_entities) yields the same dot products as
        # broadcasting a (batch, 1, dim) * (num_entities, dim) product and summing,
        # but never materializes the (batch, num_entities, dim) intermediate.
        score = qa_embed @ self.ent_embed.t()
        return self.softmax(score)

    def convert_features(self, mask, last_hidden_state, convert_type='mean'):
        """Pool token states into one L2-normalized vector per sequence.

        ``convert_type='first'`` takes the first token's state; any other value
        takes the mask-weighted mean over the sequence dimension.
        """
        if convert_type=='first':
            output_vector = last_hidden_state[:,0,:]
        else:
            input_mask_expanded = mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
            sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1)
            # Clamp guards against division by zero for an all-zero mask row.
            sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-4)
            output_vector = sum_embeddings / sum_mask
        output_vector = nn.functional.normalize(output_vector, dim=1)
        return output_vector

In [12]:
# Build the model. QAModel.__init__ already places both encoders on the GPU and
# precomputes the entity embeddings, so construction is the expensive step here.
model = QAModel().cuda()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.bias', 'cls.seq_relationship.we

In [13]:
@torch.no_grad()
def evaluate_accuracy(model, dataloader):
    """Fraction of examples whose top-scoring entity is one of the gold answers.

    Args:
        model: callable returning per-entity scores of shape (batch, num_entities).
        dataloader: yields batches carrying an ``all_answer_id`` list of answer-id lists.

    Returns:
        Accuracy as a float; 0.0 when the dataloader yields no examples.
    """
    correct = 0
    total = 0
    for batch in tqdm(dataloader):
        # One device-to-host transfer per batch, so the membership tests below run
        # on plain ints instead of syncing with the GPU once per comparison.
        predictions = model(move_to_cuda(batch)).argmax(dim=1).tolist()
        for predicted, answers in zip(predictions, batch['all_answer_id']):
            if predicted in answers:
                correct += 1
            total += 1
    # Guard against an empty dataloader instead of raising ZeroDivisionError.
    return correct / total if total else 0.0

In [26]:
@torch.no_grad()
def evaluate_mrr(model, dataloader):
    """Mean reciprocal rank of the best-ranked gold answer over the dataloader.

    For each example, an answer's rank is the number of entities whose score is
    greater than or equal to that answer's score; the smallest such rank among
    the gold answers is used. The running sum is built from tensors, so the
    returned value is a tensor rather than a Python float.
    """
    reciprocal_sum = 0
    num_examples = 0
    for batch in tqdm(dataloader):
        scores = model(move_to_cuda(batch))
        for row_scores, answers in zip(scores, batch['all_answer_id']):
            best_rank = min((row_scores >= row_scores[x]).sum() for x in answers)
            reciprocal_sum += 1 / best_rank
            num_examples += 1
    return reciprocal_sum / num_examples

In [35]:
# Fine-tune the question encoder for `epochs` passes, validating after each one.
# Binary cross-entropy is applied to the model's softmax output against the
# multi-hot answer labels built by get_one_hot_label.
loss_f = nn.BCELoss()
# Only qa_encoder parameters are handed to the optimizer; ent_encoder is not updated.
optimizer = torch.optim.AdamW(model.qa_encoder.parameters(), lr=5e-5, weight_decay=1e-4)
epochs = 1
# Latest validation metrics; they stay at 0 in the progress bar until the first epoch ends.
acc = 0
mrr = 0
for epoch in range(epochs):
    # NOTE(review): model.train() also switches ent_encoder to train mode, while
    # model.ent_embed was computed once in QAModel.__init__ and is not refreshed here.
    model.train()
    with tqdm(total=len(train_dataloader)) as pbar:
        total_loss = 0
        for i, batch in enumerate(train_dataloader):
            x = model(move_to_cuda(batch))
            # Labels are built on the CPU and moved to the GPU just before the loss.
            label = get_one_hot_label(batch['all_answer_id'], alpha_smooth=0.0)
            optimizer.zero_grad()
            loss = loss_f(x, label.cuda())
            loss.backward()
            total_loss += loss.item()
            # Running mean of the loss over the batches seen so far in this epoch.
            avg_loss = total_loss / (i + 1)
            optimizer.step()
            pbar.set_description(f"acc:{acc:.2f}, mrr:{mrr:.2f}, loss:{loss.item():.4f}, avg loss:{avg_loss:.3f}")
            pbar.update(1)
    # Switch to eval mode before computing validation metrics.
    model.eval()
    acc = evaluate_accuracy(model, valid_dataloader)
    mrr = evaluate_mrr(model, valid_dataloader)
print(f"acc:{acc:.2f}")
# Save the full model state dict under a file name tagged with the final accuracy.
torch.save(model.state_dict(), f'qa1hopmodel_acc{acc:.2f}')

acc:0.00, mrr:0.00, loss:0.0005, avg loss:0.000:   1%|          | 26/3266 [00:05<12:08,  4.45it/s]


KeyboardInterrupt: 