import os

import torch
from torch.utils.data import Dataset
from transformers import AutoTokenizer

from log import logger
from utils.reader_utils import get_ner_reader, extract_spans, _assign_ner_tags

# Avoid fork-related deadlock warnings when the fast tokenizer is used
# together with multi-worker DataLoaders.
os.environ["TOKENIZERS_PARALLELISM"] = "false"


class CoNLLReader(Dataset):
    """Reads CoNLL-formatted NER data and encodes it with a pretrained subword tokenizer."""

    def __init__(self, max_instances=-1, max_length=50, target_vocab=None, pretrained_dir='', encoder_model='xlm-roberta-large'):
        # max_instances == -1 reads the whole file; max_length caps the subword sequence length.
        self._max_instances = max_instances
        self._max_length = max_length

        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_dir + encoder_model)
        self.pad_token = self.tokenizer.special_tokens_map['pad_token']
        self.pad_token_id = self.tokenizer.get_vocab()[self.pad_token]
        self.sep_token = self.tokenizer.special_tokens_map['sep_token']

        self.label_to_id = {} if target_vocab is None else target_vocab
        self.instances = []

    def get_target_size(self):
        return len(set(self.label_to_id.values()))

    def get_target_vocab(self):
        return self.label_to_id

    def __len__(self):
        return len(self.instances)

    def __getitem__(self, item):
        return self.instances[item]

    def read_data(self, data):
        dataset_name = data if isinstance(data, str) else 'dataframe'
        logger.info('Reading file {}'.format(dataset_name))
        instance_idx = 0

        for fields, metadata in get_ner_reader(data=data):
            # Stop once the requested number of instances has been read.
            if self._max_instances != -1 and instance_idx >= self._max_instances:
                break
            sentence_str, tokens_sub_rep, token_masks_rep, coded_ner_, gold_spans_, mask = self.parse_line_for_ner(fields=fields)

            tokens_tensor = torch.tensor(tokens_sub_rep, dtype=torch.long)
            tag_tensor = torch.tensor(coded_ner_, dtype=torch.long).unsqueeze(0)
            token_masks_rep = torch.tensor(token_masks_rep)
            mask_rep = torch.tensor(mask)

            self.instances.append((tokens_tensor, mask_rep, token_masks_rep, gold_spans_, tag_tensor))
            instance_idx += 1
        logger.info('Finished reading {:d} instances from file {}'.format(len(self.instances), dataset_name))

    def parse_line_for_ner(self, fields):
        tokens_, ner_tags = fields[0], fields[-1]
        sentence_str, tokens_sub_rep, ner_tags_rep, token_masks_rep, mask = self.parse_tokens_for_ner(tokens_, ner_tags)
        gold_spans_ = extract_spans(ner_tags_rep)
        # Map tags to ids; tags missing from the vocabulary fall back to the 'O' id.
        coded_ner_ = [self.label_to_id.get(tag, self.label_to_id['O']) for tag in ner_tags_rep]

        return sentence_str, tokens_sub_rep, token_masks_rep, coded_ner_, gold_spans_, mask

    def parse_tokens_for_ner(self, tokens_, ner_tags):
        sentence_str = ''
        # Open the subword sequence with a pad-token placeholder that is
        # labelled 'O' and excluded from the token mask.
        tokens_sub_rep, ner_tags_rep = [self.pad_token_id], ['O']
        token_masks_rep = [False]
        for idx, token in enumerate(tokens_):
            # Truncate sentences whose subword representation exceeds max_length.
            if self._max_length != -1 and len(tokens_sub_rep) > self._max_length:
                break
            sentence_str += ' ' + ' '.join(self.tokenizer.tokenize(token.lower()))
            rep_ = self.tokenizer(token.lower())['input_ids']
            rep_ = rep_[1:-1]  # drop the special tokens the tokenizer wraps around a single word
            tokens_sub_rep.extend(rep_)

            # If the token carries a NER tag, the first subword keeps that tag
            # (the B- tag for a span start) and the remaining subwords get the
            # corresponding I- tag.
            ner_tag = ner_tags[idx]
            tags, masks = _assign_ner_tags(ner_tag, rep_)

            ner_tags_rep.extend(tags)
            token_masks_rep.extend(masks)

        # Close the sequence with the same pad placeholder.
        tokens_sub_rep.append(self.pad_token_id)
        ner_tags_rep.append('O')
        token_masks_rep.append(False)
        mask = [True] * len(tokens_sub_rep)
        return sentence_str, tokens_sub_rep, ner_tags_rep, token_masks_rep, mask
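

# A minimal usage sketch, assuming a CoNLL-formatted file at 'train.conll'
# and an illustrative tag-to-id vocabulary (both are assumptions, not part
# of this module). Instances are variable-length tensors, so batching with
# a DataLoader would additionally need a padding collate_fn.
if __name__ == '__main__':
    target_vocab = {'O': 0, 'B-PER': 1, 'I-PER': 2}  # hypothetical tag-to-id map
    reader = CoNLLReader(max_instances=100, max_length=128, target_vocab=target_vocab)
    reader.read_data('train.conll')  # hypothetical path to a CoNLL-formatted file

    # Each instance mirrors the tuple appended in read_data.
    tokens, mask, token_masks, gold_spans, tags = reader[0]
    print(tokens.shape, tags.shape, gold_spans)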