In [1]:
import json
from torch.utils.data import Dataset, DataLoader, Sampler, IterableDataset
import numpy as np
import torch
from tqdm import *
from transformers import DistilBertTokenizer, DistilBertModel
from torch.nn.utils.rnn import pad_sequence
# from allennlp.nn.util import batched_index_select
# from allennlp.modules import FeedForward
# import torch.nn.functional as F

In [2]:
# from allennlp.modules import FeedForward

In [3]:
from torch import nn

In [4]:
from collections import Counter

In [5]:
ner_list = ["N/A", 'FAC', 'WEA', 'LOC', 'VEH', 'GPE', 'ORG', 'PER']
re_list = ["N/A", 'ART', 'ORG-AFF', 'GEN-AFF', 'PHYS', 'PER-SOC', 'PART-WHOLE']

In [19]:
class JREDataset(Dataset):
    def __init__(self, data_path, tokenizer, max_span_len, ner_list, re_list):
        super().__init__()
        with open(data_path, 'r') as f:
            self.raw_data = [sent for sent in json.load(f) if len(sent['tokens']) > 0]
        self.ner_id2label = ner_list
        self.ner_label2id = {j:i for i, j in enumerate(ner_list)}
        self.re_id2label = re_list
        self.re_label2id = {j:i for i, j in enumerate(re_list)}
        self.data = []
        c_ner, c_re, c_span_len = Counter(), Counter(), Counter()
        self.c_ori_ner = 0
        for l in tqdm(self.raw_data):
            sub_token_mapping = []  # (index, len)
            refined_tokens = []
            cnt = 0
            # first re-tokenize the tokens with BertTokenizer
            self.c_ori_ner += len(l['entities'])
            for t in l['tokens']:
                subtokens = tokenizer.tokenize(t)
                tmp_len = len(subtokens)
                refined_tokens += subtokens
                sub_token_mapping.append((cnt, tmp_len))
                cnt += tmp_len
            
            refined_entities = {(sub_token_mapping[e[0]][0], sub_token_mapping[e[1]-1][0] + sub_token_mapping[e[1]-1][1]): e[2] for e in l['entities']}
            refined_relations = {(sub_token_mapping[r[0]][0], sub_token_mapping[r[1]-1][0] + sub_token_mapping[r[1]-1][1], \
                sub_token_mapping[r[2]][0], sub_token_mapping[r[3]-1][0] + sub_token_mapping[r[3]-1][1]): r[4] for r in l['relations']}
            c_span_len += Counter([j - i for i, j in refined_entities])
            spans, spans_label = [], []
            # span2id = {}
            cnt = 0
            for i in range(len(refined_tokens)):
                for j in range(i + 1, min(len(refined_tokens), i + max_span_len + 1)):
                    spans.append((i, j))
                    # span2id[(i, j)] = cnt
                    cnt += 1
                    spans_label.append(self.ner_label2id[refined_entities.get((i, j), 'N/A')])
            entity_pairs, entity_pairs_label = [], []
            for s, s_ in refined_entities.items():
                for t, t_ in refined_entities.items():
                    # if j >= i:
                    #     break
                    # print(s, s_, t, t_)
                    entity_pairs.append((s[0], s[1], t[0], t[1], self.ner_label2id[s_], self.ner_label2id[t_]))
                    entity_pairs_label.append(self.re_label2id[refined_relations.get((s[0], s[1], t[0], t[1]), 'N/A')])
            refined_tokens_ids = tokenizer.convert_tokens_to_ids(refined_tokens)
            self.data.append({
                'tokens': refined_tokens,
                'tokens_id': [tokenizer.cls_token_id, *refined_tokens_ids, tokenizer.sep_token_id],
                'entities': refined_entities,
                'relations': refined_relations,
                'spans': spans,
                'spans_label': spans_label,
                'entity_pairs': entity_pairs,
                'entity_pairs_label': entity_pairs_label
            })
            c_ner += Counter(list(refined_entities.values()))
            c_re += Counter(list(refined_relations.values()))
        n_ner = Counter([len(i['entities']) for i in self.data])
        n_re = Counter([len(i['relations']) for i in self.data])
        print("entity label stats:", dict(c_ner))
        print("relation label stats:", dict(c_re))
        print("span len stats:", dict(c_span_len))
        print("num of entity stats:", dict(n_ner))
        print("num of relation stats:", dict(n_re))
            # need stats
            # for i in 
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx]

In [7]:
tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

In [20]:
testset = JREDataset('data/test.ACE05.json', tokenizer, 8, ner_list, re_list)

100%|████████████████████████████████████████████████████████████████████████████| 2050/2050 [00:01<00:00, 1404.10it/s]

entity label stats: {'GPE': 1020, 'ORG': 837, 'PER': 2967, 'FAC': 291, 'LOC': 136, 'WEA': 109, 'VEH': 116}
relation label stats: {'PART-WHOLE': 182, 'PHYS': 278, 'ORG-AFF': 359, 'ART': 151, 'PER-SOC': 77, 'GEN-AFF': 104}
span len stats: {1: 4434, 2: 547, 3: 302, 4: 118, 5: 47, 6: 17, 8: 3, 13: 2, 7: 4, 11: 2}
num of entity stats: {0: 533, 3: 233, 5: 137, 4: 167, 2: 280, 6: 97, 7: 79, 8: 43, 11: 18, 1: 389, 9: 33, 10: 16, 13: 5, 12: 9, 15: 2, 14: 4, 16: 3, 17: 2}
num of relation stats: {0: 1453, 1: 295, 3: 73, 2: 167, 5: 15, 4: 35, 6: 4, 7: 4, 9: 1, 11: 1, 8: 2}





In [21]:
validset = JREDataset("data/valid.ACE05.json", tokenizer, 8, ner_list, re_list)

100%|████████████████████████████████████████████████████████████████████████████| 2424/2424 [00:01<00:00, 1466.39it/s]

entity label stats: {'GPE': 1265, 'ORG': 989, 'PER': 3431, 'FAC': 249, 'LOC': 156, 'VEH': 125, 'WEA': 123}
relation label stats: {'ORG-AFF': 365, 'PART-WHOLE': 162, 'PHYS': 278, 'ART': 96, 'PER-SOC': 106, 'GEN-AFF': 124}
span len stats: {1: 5151, 2: 629, 4: 118, 3: 357, 5: 44, 8: 4, 6: 21, 7: 6, 12: 3, 10: 1, 15: 1, 11: 2, 9: 1}
num of entity stats: {0: 635, 2: 339, 13: 5, 9: 32, 8: 51, 5: 166, 3: 287, 7: 82, 4: 199, 6: 110, 1: 455, 11: 22, 10: 27, 12: 4, 17: 1, 15: 3, 22: 1, 14: 3, 18: 1, 19: 1}
num of relation stats: {0: 1793, 6: 10, 1: 357, 4: 26, 3: 72, 2: 149, 7: 2, 5: 13, 8: 1, 9: 1}





In [22]:
trainset = JREDataset("data/train.ACE05.json", tokenizer, 8, ner_list, re_list)

100%|██████████████████████████████████████████████████████████████████████████| 10051/10051 [00:06<00:00, 1469.74it/s]

entity label stats: {'GPE': 5169, 'VEH': 678, 'PER': 14415, 'ORG': 3781, 'LOC': 827, 'WEA': 679, 'FAC': 921}
relation label stats: {'PART-WHOLE': 775, 'ART': 491, 'ORG-AFF': 1472, 'GEN-AFF': 511, 'PHYS': 1097, 'PER-SOC': 438}
span len stats: {1: 21184, 2: 2758, 3: 1568, 6: 95, 8: 24, 5: 204, 4: 548, 7: 46, 10: 10, 9: 10, 14: 2, 13: 4, 11: 5, 12: 8, 18: 1, 15: 2, 19: 1}
num of entity stats: {0: 2574, 12: 41, 11: 76, 10: 105, 9: 164, 3: 1129, 5: 653, 4: 808, 6: 476, 7: 371, 8: 218, 2: 1320, 1: 2060, 13: 21, 15: 6, 16: 7, 17: 5, 14: 12, 28: 1, 23: 1, 18: 2, 20: 1}
num of relation stats: {0: 7408, 4: 141, 3: 294, 1: 1443, 7: 10, 2: 666, 5: 60, 8: 2, 6: 23, 9: 1, 10: 3}





In [11]:
def my_collate_fn(batch):
    seq = pad_sequence([torch.tensor(i['tokens_id']) for i in batch], padding_value=0, batch_first=True).to('cpu')
    mask = (seq > 0).float()
    span_index = torch.zeros(len(batch), 2)
    cur_len = 0
    span_batch = []
    span_label_batch = []
    for i, j in enumerate(batch):
        span_batch += j['spans']
        span_index[i][0] = cur_len
        tmp_len = len(j['spans'])
        span_index[i][1] = tmp_len
        cur_len += tmp_len
        span_label_batch += j['spans_label']
    span_batch = torch.tensor(span_batch).to('cpu').long()
    span_index = span_index.to('cpu').long()
    span_label_batch = torch.tensor(span_label_batch).to('cpu')
    # spans = torch.tensor(inputs['spans'])
    return seq, mask, span_batch, span_label_batch, span_index

In [12]:
test_loader = DataLoader(testset, batch_size=16, shuffle=True, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=my_collate_fn, prefetch_factor=2)

In [13]:
valid_loader = DataLoader(testset, batch_size=16, shuffle=True, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=my_collate_fn, prefetch_factor=2)

In [14]:
train_loader = DataLoader(trainset, batch_size=32, shuffle=True, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=my_collate_fn, prefetch_factor=2)

In [15]:
encoder = DistilBertModel.from_pretrained('distilbert-base-uncased')

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_transform.weight', 'vocab_projector.bias', 'vocab_layer_norm.bias', 'vocab_transform.bias']
- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [16]:
class DistilBertNER(nn.Module):
    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder.to('cpu')
        self.ner_classifier = nn.Sequential(
            nn.Linear(768 * 2 + 0, 768),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(768, 8)
        )
    def forward(self, seq, mask, span_batch, span_index):
        # print(seq.shape)
        embs = self.encoder(seq, attention_mask=mask)[0][:, 1:, :]
        emb_ls = []
        for i, (j, k) in enumerate(span_index):
            # print(i, j, k)
            # j, k : the start index and length for spans of this batch
            spans = span_batch[j: j+k]  # (n_span, 2)
            x = embs[i, spans[:, 0], :]  # (n_span, 768)
            y = embs[i, spans[:, 1], :]
            # y = embs[i, span_batch[j+k][0]: span_batch[j+k][0] + span_batch[j+k][1], :]
            emb_ls.append(torch.cat((x, y), dim=-1))
        emb_ls = torch.cat(emb_ls, dim=0)
        return self.ner_classifier(emb_ls)

In [17]:
model = DistilBertNER(encoder)

In [18]:

criterion = torch.nn.CrossEntropyLoss()

In [47]:
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)

In [27]:
for epoch in range(5):
    for step, i in tqdm(enumerate(train_loader)):
        # break
        seq, mask, span_batch, span_label_batch, span_index = i
        res = model(seq, mask, span_batch, span_index)
        loss = criterion(res, span_label_batch)    
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        model.zero_grad()
        # print(loss)
        # bs = len(span_label_batch)
        '''
        if step % 10 == 0:
            print(loss)
            valid = torch.sum(span_label_batch != 0)
            predicted = torch.argmax(res, dim=1)
            true_valid = torch.sum((span_label_batch == predicted)*(span_label_batch != 0))
        # print(res, span_label_batch)
            print(true_valid/valid, true_valid, valid)
        '''
    print("cur results on dev:", evaluate(model, valid_loader, validset.c_ori_ner))
print("on test set:", evaluate(model, test_loader, testset.c_ori_ner))

0it [00:00, ?it/s]
  0%|                                                                                          | 0/129 [00:00<?, ?it/s]

evaluating...


 64%|███████████████████████████████████████████████████▍                             | 82/129 [01:33<00:40,  1.05s/it]

NameError: name 'cor' is not defined

In [28]:
def evaluate(model, loader, n_total_ner):
    print("evaluating...")
    # c_time = time.time()
    a_cor, a_tot = 0, n_total_ner
    a_pre = 0
    l_cor = 0
    l_tot = 0
    l_pred = 0
    # l_total_cand = 0
    model.eval()
    for l in tqdm(loader):
        seq, mask, span_batch, span_label_batch, span_index = i
        with torch.no_grad():
            res = model(seq, mask, span_batch, span_index)
            l_tot += torch.sum(span_label_batch != 0)
            a_tot += span_label_batch.shape[0]
            predicted = torch.argmax(res, dim=1)
            l_pred += torch.sum(predicted != 0)
            a_cor += torch.sum(span_label_batch == predicted)
            l_cor += torch.sum((span_label_batch == predicted)*(span_label_batch != 0))
        
    acc = a_cor / a_tot
    print('all accuracy: %4f'%acc)
    print('for valid spans: cor: %d, pred: %d, tot: %d, cand tot: %d'%(l_cor, l_pred, n_total_ner, l_tot))
    p = l_cor / l_pred
    r = l_cor / n_total_ner
    f1 = 2 * (p * r) / (p + r + 1e-6)
    print('P: %.5f, R: %.5f, F1: %.5f'%(p, r, f1))
    
    model.train()
    return f1