In [1]:
import os

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer

In [2]:
class AspectExtractionDataset(Dataset):
    """Word-level sequence-labelling dataset for aspect / sentiment extraction.

    The file at ``dataset_path`` is read as ``token<TAB>label`` lines; any line
    without a tab (e.g. a blank line) terminates the current sentence. Every
    label must be a key of ``LABEL2INDEX``.

    Each item is a triple of 1-D ``np.int64`` arrays:
      - subwords: the tokenizer ids of every word, concatenated (no special tokens);
      - subword_to_word_indices: for each sub-word, the index of the word it came from;
      - seq_label: one label id per word.
    """
    # Static constant variables: the label vocabulary shared by all instances.
    LABEL2INDEX = {'I-SENTIMENT': 0, 'O': 1, 'I-ASPECT': 2, 'B-SENTIMENT': 3, 'B-ASPECT': 4}
    # Derived from LABEL2INDEX so the inverse map and the count can never drift out of sync.
    INDEX2LABEL = {index: label for label, index in LABEL2INDEX.items()}
    NUM_LABELS = len(LABEL2INDEX)

    @staticmethod
    def load_dataset(path):
        """Parse a tab-separated token/label file into a list of sentences.

        Args:
            path: path to a UTF-8 text file of ``token<TAB>label`` lines, with
                sentences separated by lines that contain no tab.

        Returns:
            list of dicts ``{'sentence': [token, ...], 'seq_label': [label_id, ...]}``.
            Empty sentences (e.g. from repeated blank lines) are skipped, and a
            final sentence without a trailing blank line is kept.

        Raises:
            KeyError: a label is not in ``LABEL2INDEX``.
            ValueError: a line contains more than one tab.
        """
        dataset = []
        sentence = []
        seq_label = []

        def flush():
            # Close the current sentence; ignore empty ones produced by repeated separators.
            if sentence:
                dataset.append({'sentence': sentence, 'seq_label': seq_label})

        with open(path, 'r', encoding='utf-8') as f:
            for line in f:
                if '\t' in line:
                    # rstrip instead of line[:-1]: the last line may lack '\n', and
                    # CRLF files would otherwise leave '\r' glued to the label.
                    token, label = line.rstrip('\r\n').split('\t')
                    sentence.append(token)
                    seq_label.append(AspectExtractionDataset.LABEL2INDEX[label])
                else:
                    flush()
                    sentence = []
                    seq_label = []
        # The file may not end with a separator line.
        flush()
        return dataset

    def __init__(self, dataset_path, tokenizer):
        """
        Args:
            dataset_path: file to parse with ``load_dataset``.
            tokenizer: object exposing ``encode(text, add_special_tokens=...)`` that
                returns a list of integer ids (e.g. a Hugging Face tokenizer).
        """
        self.data = AspectExtractionDataset.load_dataset(dataset_path)
        self.tokenizer = tokenizer

    def __getitem__(self, index):
        """Tokenize sentence ``index`` word by word.

        Returns:
            (subwords, subword_to_word_indices, seq_label) as 1-D ``np.int64`` arrays;
            the first two have equal length, the last has one entry per word.
        """
        data = self.data[index]
        sentence, seq_label = data['sentence'], data['seq_label']
        subwords = []
        subword_to_word_indices = []
        for word_idx, word in enumerate(sentence):
            # Encode each word WITHOUT special tokens: otherwise [CLS]/[SEP] would be
            # inserted around every single word and attributed to it in the index map.
            subword_list = self.tokenizer.encode(word, add_special_tokens=False)
            subword_to_word_indices.extend([word_idx] * len(subword_list))
            subwords.extend(subword_list)
        # Explicit dtype: np.array([]) would otherwise be float64 for an empty sentence.
        return (
            np.array(subwords, dtype=np.int64),
            np.array(subword_to_word_indices, dtype=np.int64),
            np.array(seq_label, dtype=np.int64),
        )

    def __len__(self):
        """Number of sentences."""
        return len(self.data)
    
        
class AspectExtractionDataLoader(DataLoader):
    """``DataLoader`` that right-pads variable-length aspect-extraction samples.

    Each sample is expected to be a ``(subwords, subword_to_word_indices, seq_label)``
    triple of 1-D integer arrays, as returned by ``AspectExtractionDataset``.
    Unless the caller passes its own ``collate_fn=`` keyword, batches are built by
    ``_collate_fn`` and come back as three ``np.int64`` arrays.
    """
    def __init__(self, *args, **kwargs):
        super(AspectExtractionDataLoader, self).__init__(*args, **kwargs)
        # Only install the padding collate when the caller did not supply one by keyword;
        # overwriting unconditionally silently discarded a user-provided collate_fn.
        if kwargs.get('collate_fn') is None:
            # _collate_fn is a staticmethod, so this is a plain function: worker
            # processes receive it without a reference to the loader itself.
            self.collate_fn = self._collate_fn

    @staticmethod
    def _collate_fn(batch):
        """Pad a list of samples into fixed-size ``np.int64`` arrays.

        Args:
            batch: non-empty list of ``(subwords, subword_to_word_indices, seq_label)``
                triples; ``subwords`` and ``subword_to_word_indices`` have equal length.

        Returns:
            subword_batch: (batch_size, max_seq_len) sub-word ids, right-padded with 0
                (assumed to be the tokenizer's pad id — TODO confirm per tokenizer).
            subword_to_word_indices_batch: (batch_size, max_seq_len) word index of each
                sub-word, right-padded with 0.
            seq_label_batch: (batch_size, max_tgt_len) label ids, right-padded with -100
                (presumably the ignore_index of the downstream loss — confirm).
        """
        batch_size = len(batch)
        max_seq_len = max(len(sample[0]) for sample in batch)
        max_tgt_len = max(len(sample[2]) for sample in batch)

        subword_batch = np.zeros((batch_size, max_seq_len), dtype=np.int64)
        subword_to_word_indices_batch = np.zeros((batch_size, max_seq_len), dtype=np.int64)
        seq_label_batch = np.full((batch_size, max_tgt_len), -100, dtype=np.int64)

        for i, (subwords, subword_to_word_indices, seq_label) in enumerate(batch):
            subword_batch[i, :len(subwords)] = subwords
            subword_to_word_indices_batch[i, :len(subwords)] = subword_to_word_indices
            seq_label_batch[i, :len(seq_label)] = seq_label

        return subword_batch, subword_to_word_indices_batch, seq_label_batch

In [3]:
# Tokenizer checkpoint and training split used by the cells below.
pretrained_model = 'bert-base-uncased'
dataset_path = '../data/aspect-extraction-review-airy/train.txt'

tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
dataset = AspectExtractionDataset(dataset_path, tokenizer=tokenizer)

In [4]:
# Peek at the first encoded example: sub-word ids, the word index of each
# sub-word, and the word-level label ids.
for first_example in dataset:
    subwords, subword_to_word_indices, seq_label = x = first_example
    print(subwords, subword_to_word_indices, seq_label, sep='\n')
    break

[27829  2906  2360  2050 15262  6358  9305  2050  4487  9353 14841 23597
  2022 12881  5575  5332 15502  1012  4907 26536  2050 15536  8873 12849
  2638  5705  2072 13970 24388 17079  4014  1012]
[ 0  0  1  1  2  3  3  3  4  5  6  6  7  7  7  7  8  9 10 11 11 12 12 13
 13 13 13 14 14 15 15 16]
[1 1 1 1 1 4 3 0 0 1 1 1 4 2 3 0 1]


In [14]:
# Cap the worker count at the machine's CPU count (os.cpu_count() may return None):
# spawning more worker processes than cores only oversubscribes the CPU.
NUM_WORKERS = min(32, os.cpu_count() or 1)
loader = AspectExtractionDataLoader(dataset, batch_size=1024, num_workers=NUM_WORKERS)

In [15]:
%%time
# Dump the first 11 padded batches (i = 0..10) to inspect the padding.
for i, batch in enumerate(loader):
    subwords, subword_to_word_indices, seq_label = batch
    print(*batch)
    if i >= 10:
        break

[[27829  2906  2360 ...     0     0     0]
 [ 8915  8737  4017 ...     0     0     0]
 [ 7929  2063  9748 ...     0     0     0]
 ...
 [26209  2226  8945 ...     0     0     0]
 [27829  2906  2022 ...     0     0     0]
 [20377  2050 16137 ...     0     0     0]] [[0 0 1 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 1 ... 0 0 0]
 ...
 [0 0 1 ... 0 0 0]
 [0 0 1 ... 0 0 0]
 [0 0 1 ... 0 0 0]] [[   1    1    1 ... -100 -100 -100]
 [   4    3    1 ... -100 -100 -100]
 [   3    0    1 ... -100 -100 -100]
 ...
 [   3    1    1 ... -100 -100 -100]
 [   4    3    1 ... -100 -100 -100]
 [   1    1    1 ... -100 -100 -100]]
[[19379  4430 26927 ...     0     0     0]
 [ 4487 11493  2072 ...     0     0     0]
 [18178  5480  1999 ...     0     0     0]
 ...
 [14841 23597  7367 ...     0     0     0]
 [ 8915  8737  4017 ...     0     0     0]
 [ 7658  3148 13970 ...     0     0     0]] [[0 0 1 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 1 ... 0 0 0]
 ...
 [0 0 1 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 1 ... 0 0 0]] [[  