In [1]:
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer

In [2]:
class AspectBasedSentimentAnalysisDataset(Dataset):
    """Aspect-based sentiment examples read from a CSV file.

    The CSV must contain the columns ``review`` (text), ``category`` (a key of
    ``CATEGORY2INDEX``) and ``sentiment`` (a key of ``SENTIMENT2INDEX``).
    Each item is ``(subword_ids, category_index, sentiment_index)`` as numpy
    arrays, where ``subword_ids`` is the output of
    ``tokenizer.encode(review, add_special_tokens=True)``.
    """

    # Static label vocabularies. The inverse maps and the counts are derived
    # from the forward maps so they cannot drift out of sync (the category
    # count was previously hard-coded as 23 although only 10 are defined).
    CATEGORY2INDEX = {'wifi': 0, 'kebersihan': 1, 'bau': 2, 'service': 3, 'linen': 4, 'ac': 5, 'sunrise_meal': 6, 'general': 7, 'air_panas': 8, 'tv': 9}
    INDEX2CATEGORY = {index: category for category, index in CATEGORY2INDEX.items()}
    NUM_CATEGORIES = len(CATEGORY2INDEX)

    SENTIMENT2INDEX = {'neg': 0, 'pos': 1}
    INDEX2SENTIMENT = {index: sentiment for sentiment, index in SENTIMENT2INDEX.items()}
    NUM_SENTIMENTS = len(SENTIMENT2INDEX)

    @staticmethod
    def load_dataset(path):
        """Read the CSV at ``path`` and replace label strings with integer indices.

        ``path`` is passed straight to ``pandas.read_csv``.

        Raises:
            KeyError: if a row's category or sentiment is not in the vocabulary.
        """
        df = pd.read_csv(path)
        df['category'] = df['category'].apply(lambda cat: AspectBasedSentimentAnalysisDataset.CATEGORY2INDEX[cat])
        df['sentiment'] = df['sentiment'].apply(lambda sen: AspectBasedSentimentAnalysisDataset.SENTIMENT2INDEX[sen])
        return df

    def __init__(self, dataset_path, tokenizer):
        """Load and label-encode the CSV at ``dataset_path``.

        ``tokenizer`` must provide ``encode(text, add_special_tokens=True)``
        returning a sequence of integer ids.
        """
        self.data = AspectBasedSentimentAnalysisDataset.load_dataset(dataset_path)
        self.tokenizer = tokenizer

    def __getitem__(self, index):
        """Return ``(subword_ids, category_index, sentiment_index)`` for row ``index``.

        Positional indexing (``iloc``) makes an out-of-range index raise
        ``IndexError``. Plain ``for item in dataset`` iteration relies on that
        to stop; label-based ``loc`` raised ``KeyError`` and crashed instead.
        """
        row = self.data.iloc[index]
        review, category, sentiment = row['review'], row['category'], row['sentiment']
        subwords = self.tokenizer.encode(review, add_special_tokens=True)
        return np.array(subwords), np.array(category), np.array(sentiment)

    def __len__(self):
        """Number of rows (examples) in the loaded CSV."""
        return len(self.data)
    
        
class AspectBasedSentimentAnalysisDataLoader(DataLoader):
    """DataLoader that right-pads variable-length subword sequences into a batch.

    All arguments are forwarded unchanged to ``torch.utils.data.DataLoader``.
    Afterwards ``collate_fn`` is replaced by ``_collate_fn``, so any
    ``collate_fn`` supplied by the caller is ignored.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Installed after the parent constructor so it overrides its choice.
        self.collate_fn = self._collate_fn

    def _collate_fn(self, batch):
        """Stack ``(subwords, category, sentiment)`` items into int64 numpy arrays.

        Returns:
            ``(subword_batch, category_batch, sentiment_batch)`` with shapes
            ``(len(batch), longest_sequence)``, ``(len(batch), 1)`` and
            ``(len(batch), 1)``. Shorter sequences are right-padded with 0.
        """
        sequences, categories, sentiments = zip(*batch)
        rows = len(sequences)
        width = max(len(seq) for seq in sequences)

        subword_batch = np.zeros((rows, width), dtype=np.int64)
        for row, seq in enumerate(sequences):
            subword_batch[row, :len(seq)] = seq

        category_batch = np.array(categories, dtype=np.int64).reshape(rows, 1)
        sentiment_batch = np.array(sentiments, dtype=np.int64).reshape(rows, 1)
        return subword_batch, category_batch, sentiment_batch

In [3]:
# Training split and tokenizer checkpoint used throughout this notebook.
dataset_path = '../data/aspect-based-sentiment-analysis/train_preprocess.csv'
pretrained_model = 'bert-base-uncased'
# NOTE(review): the category labels are Indonesian (e.g. 'kebersihan'), so the
# reviews presumably are too; confirm an English uncased vocabulary is intended.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
dataset = AspectBasedSentimentAnalysisDataset(dataset_path=dataset_path, tokenizer=tokenizer)

In [4]:
# Show only the first encoded example: subword ids, category index, sentiment index.
for subwords, category, sentiment in dataset:
    print(subwords, category, sentiment, sep='\n')
    break

[  101  6369  4017  2273 20684  5654  9126  4907  2033 12274 19895  2319
  1012  1012  1012 15125  4609  7842  4017  4638  1999  9389  2403  1012
  4002  1010 16137 19190  5292  7946 12183  3070 12193  7367 10278  2050
  2322  2273  4183  1010  9587  8747 16510  2121 26068  3211  1012  1012
   102]
3
0


In [5]:
# NOTE(review): with num_workers > 0 under a 'spawn' start method, classes defined in the notebook may not be picklable by workers — set num_workers=0 if loading fails.
loader = AspectBasedSentimentAnalysisDataLoader(dataset=dataset, num_workers=4, batch_size=4)

In [6]:
%%time
for i, (subwords, category, sentiment) in enumerate(loader):
    print(subwords)
    print(category)
    print(sentiment)
    if i == 5:
        break

[[  101  6369  4017  2273 20684  5654  9126  4907  2033 12274 19895  2319
   1012  1012  1012 15125  4609  7842  4017  4638  1999  9389  2403  1012
   4002  1010 16137 19190  5292  7946 12183  3070 12193  7367 10278  2050
   2322  2273  4183  1010  9587  8747 16510  2121 26068  3211  1012  1012
    102     0     0     0     0     0     0     0     0     0     0     0
      0     0     0     0     0     0     0     0     0     0     0     0
      0     0     0     0     0     0     0     0     0     0     0     0
      0     0     0     0     0     0     0     0     0     0     0     0
      0     0     0     0     0     0     0     0     0     0     0     0
      0     0     0     0     0     0     0     0     0     0     0     0
      0     0     0     0     0     0     0     0     0]
 [  101 10514 16782  2532  2273 20684  5654  9126  1010 27829  2906  2158
   4305 12943  4817 12849  4263  1012  1012 11320 25153  6457 12943  4817
  16906  3388  1010 21877 22923  7229 21790  2243  1012