In [1]:
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer

In [2]:
class PosTagDataset(Dataset):
    # Static constant variable
    LABEL2INDEX = {'SC': 0, 'NN': 1, 'CC': 2, 'UH': 3, 'NEG': 4, 'PR': 5, 'IN': 6, 'NND': 7, 'DT': 8, 'CD': 9, 'MD': 10, 'NNP': 11, 'VB': 12, 'RP': 13, 'Z': 14, 'JJ': 15, 'SYM': 16, 'X': 17, 'RB': 18, 'OD': 19, 'WH': 20, 'FW': 21, 'PRP': 22}
    INDEX2LABEL = {0: 'SC', 1: 'NN', 2: 'CC', 3: 'UH', 4: 'NEG', 5: 'PR', 6: 'IN', 7: 'NND', 8: 'DT', 9: 'CD', 10: 'MD', 11: 'NNP', 12: 'VB', 13: 'RP', 14: 'Z', 15: 'JJ', 16: 'SYM', 17: 'X', 18: 'RB', 19: 'OD', 20: 'WH', 21: 'FW', 22: 'PRP'}
    NUM_LABELS = 23
    
    @staticmethod
    def load_dataset(path):
        # Read file
        data = open(path,'r').readlines()

        # Prepare buffer
        dataset = []
        sentence = []
        seq_label = []
        for line in data:
            if len(line.strip()) > 0:
                token, label = line[:-1].split('\t')
                sentence.append(token)
                seq_label.append(PosTagDataset.LABEL2INDEX[label])
            else:
                dataset.append({
                    'sentence': sentence,
                    'seq_label': seq_label
                })
                sentence = []
                seq_label = []
        return dataset
    
    def __init__(self, dataset_path, tokenizer):
        self.data = PosTagDataset.load_dataset(dataset_path)
        self.tokenizer = tokenizer
        
    def __getitem__(self, index):
        data = self.data[index]
        sentence, seq_label = data['sentence'], data['seq_label']
        subwords = []
        subword_to_word_indices = []
        for word_idx, word in enumerate(sentence):
            subword_list = self.tokenizer.encode(word, add_special_tokens=True)
            subword_to_word_indices += [word_idx for i in range(len(subword_list))]
            subwords += subword_list
        return np.array(subwords), np.array(subword_to_word_indices), np.array(seq_label)
    
    def __len__(self):
        return len(self.data)
    
        
class PosTagDataLoader(DataLoader):
    def __init__(self, *args, **kwargs):
        super(PosTagDataLoader, self).__init__(*args, **kwargs)
        self.collate_fn = self._collate_fn
        
    def _collate_fn(self, batch):
        batch_size = len(batch)
        max_seq_len = max(map(lambda x: len(x[0]), batch))
        max_tgt_len = max(map(lambda x: len(x[2]), batch))
        
        subword_batch = np.zeros((batch_size, max_seq_len), dtype=np.int64)
        subword_to_word_indices_batch = np.zeros((batch_size, max_seq_len), dtype=np.int64)
        seq_label_batch = np.full((batch_size, max_tgt_len), -100, dtype=np.int64)
        
        for i, (subwords, subword_to_word_indices, seq_label) in enumerate(batch):
            subword_batch[i,:len(subwords)] = subwords
            subword_to_word_indices_batch[i,:len(subwords)] = subword_to_word_indices
            seq_label_batch[i,:len(seq_label)] = seq_label
            
        return subword_batch, subword_to_word_indices_batch, seq_label_batch

In [3]:
dataset_path = '../data/pos-idn/train_preprocess.txt'
pretrained_model = 'bert-base-uncased'
tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
dataset = PosTagDataset(dataset_path, tokenizer)

In [4]:
for subwords, subword_to_word_indices, seq_label in dataset:
    print(subwords)
    print(subword_to_word_indices)
    print(seq_label)
    break

[17710  2527  4895  8525  2243 25933 25804  2078 20739  2050 19330  4430
 29181  2050]
[0 0 1 1 1 2 2 2 3 3 3 3 3 3]
[ 1  0 12  1]


In [5]:
loader = PosTagDataLoader(dataset, batch_size=32, num_workers=32)

In [6]:
%%time
for i, (subwords, subword_to_word_indices, seq_label) in enumerate(loader):
    print(subwords, subword_to_word_indices, seq_label)
    if i == 5:
        break

[[17710  2527  4895 ...     0     0     0]
 [21877  5017 18447 ...     0     0     0]
 [ 2022  5677 22068 ...     0     0     0]
 ...
 [26209  2226  9389 ...     0     0     0]
 [11382  3736  2232 ...     0     0     0]
 [11382  3736  2232 ...     0     0     0]] [[0 0 1 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 1 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]] [[   1    0   12    1 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100
  -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100]
 [  11   11   11   12    1    0   12    1   15    0   12   18   15    6
     1   11   11   14 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100]
 [   9    1   12   18    9    1   12    6    1    1    1    2    1    6
    11   11   14 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100]
 [  11   11   12    9    1    2   12   12    9    1    1    6   11   11
    14 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100 -100]
 [   1    1    6   11   11   12    9