In [1]:
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer

In [2]:
class NerDataset(Dataset):
    # Static constant variable
    LABEL2INDEX = {'I-PERSON': 0, 'B-ORGANISATION': 1, 'I-ORGANISATION': 2, 'B-PLACE': 3, 'I-PLACE': 4, 'O': 5, 'B-PERSON': 6}
    INDEX2LABEL = {0: 'I-PERSON', 1: 'B-ORGANISATION', 2: 'I-ORGANISATION', 3: 'B-PLACE', 4: 'I-PLACE', 5: 'O', 6: 'B-PERSON'}
    NUM_LABELS = 7
    
    @staticmethod
    def load_dataset(path):
        # Read file
        data = open(path,'r').readlines()

        # Prepare buffer
        dataset = []
        sentence = []
        seq_label = []
        for line in data:
            if len(line.strip()) > 0:
                token, label = line[:-1].split('\t')
                sentence.append(token)
                seq_label.append(NerDataset.LABEL2INDEX[label])
            else:
                dataset.append({
                    'sentence': sentence,
                    'seq_label': seq_label
                })
                sentence = []
                seq_label = []
        return dataset
    
    def __init__(self, dataset_path, tokenizer):
        self.data = NerDataset.load_dataset(dataset_path)
        self.tokenizer = tokenizer
        
    def __getitem__(self, index):
        data = self.data[index]
        sentence, seq_label = data['sentence'], data['seq_label']
        subwords = []
        subword_to_word_indices = []
        for word_idx, word in enumerate(sentence):
            subword_list = self.tokenizer.encode(word, add_special_tokens=True)
            subword_to_word_indices += [word_idx for i in range(len(subword_list))]
            subwords += subword_list
        return np.array(subwords), np.array(subword_to_word_indices), np.array(seq_label)
    
    def __len__(self):
        return len(self.data)
    
        
class NerDataLoader(DataLoader):
    def __init__(self, *args, **kwargs):
        super(NerDataLoader, self).__init__(*args, **kwargs)
        self.collate_fn = self._collate_fn
        
    def _collate_fn(self, batch):
        batch_size = len(batch)
        max_seq_len = max(map(lambda x: len(x[0]), batch))
        max_tgt_len = max(map(lambda x: len(x[2]), batch))
        
        subword_batch = np.zeros((batch_size, max_seq_len), dtype=np.int64)
        subword_to_word_indices_batch = np.zeros((batch_size, max_seq_len), dtype=np.int64)
        seq_label_batch = np.full((batch_size, max_tgt_len), -100, dtype=np.int64)
        
        for i, (subwords, subword_to_word_indices, seq_label) in enumerate(batch):
            subword_batch[i,:len(subwords)] = subwords
            subword_to_word_indices_batch[i,:len(subwords)] = subword_to_word_indices
            seq_label_batch[i,:len(seq_label)] = seq_label
            
        return subword_batch, subword_to_word_indices_batch, seq_label_batch

In [3]:
dataset_path = '../data/ner-grit/train_preprocess_0.txt'
pretrained_model = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
dataset = NerDataset(dataset_path, tokenizer)

In [4]:
for subwords, subword_to_word_indices, seq_label in dataset:
    print(subwords)
    print(subword_to_word_indices)
    print(seq_label)
    break

[ 3347  2226 11687  2050 11937 17157  2722 15488  2050 11265  4590  2072
  1018 10930  6292 11905 13320  2022  6820 24206 15125  2050  2273 27875
  2072 15488  2226 11265  4590  2072  1018 10930  6292 11905 13320  1012]
[ 0  0  1  1  2  2  3  4  4  5  5  5  6  7  7  7  7  8  8  8  9  9 10 10
 10 11 11 12 12 12 13 14 14 14 14 15]
[5 5 5 5 3 4 4 4 5 5 5 3 4 4 4 5]


In [5]:
loader = NerDataLoader(dataset, batch_size=32, num_workers=32)

In [6]:
%%time
for i, (subwords, subword_to_word_indices, seq_label) in enumerate(loader):
    print(subwords, subword_to_word_indices, seq_label)
    if i == 5:
        break

[[ 3347  2226 11687 ...     0     0     0]
 [ 2022 13639 10286 ...     0     0     0]
 [12183 22134 19817 ...     0     0     0]
 ...
 [11050  2078  7367 ...     0     0     0]
 [24595  1038 15006 ...     0     0     0]
 [ 7367 16078  4886 ...     0     0     0]] [[0 0 1 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 1 ... 0 0 0]
 ...
 [0 0 1 ... 0 0 0]
 [0 1 1 ... 0 0 0]
 [0 0 0 ... 0 0 0]] [[   5    5    5 ... -100 -100 -100]
 [   5    5    5 ... -100 -100 -100]
 [   5    5    5 ... -100 -100 -100]
 ...
 [   6    5    5 ... -100 -100 -100]
 [   6    0    5 ... -100 -100 -100]
 [   5    5    5 ... -100 -100 -100]]
[[ 2022 22381  2863 ...     0     0     0]
 [ 7370  2041  2905 ...     0     0     0]
 [ 7279  4305  3089 ...     0     0     0]
 ...
 [12183 22134 14255 ...     0     0     0]
 [28144  5162 25060 ...     0     0     0]
 [ 2474 12193  2566 ...     0     0     0]] [[0 0 0 ... 0 0 0]
 [0 1 2 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 1 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 1 ... 0 0 0]] [[  