In [1]:
import re
import torch
from torch.utils.data import Dataset, DataLoader


def clean_str(string, TREC=False):
    """
    Tokenization/string cleaning for all datasets except for SST.
    Every dataset is lower cased except for TREC
    """
    string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)
    string = re.sub(r"\'s", " \'s", string)
    string = re.sub(r"\'ve", " \'ve", string)
    string = re.sub(r"n\'t", " n\'t", string)
    string = re.sub(r"\'re", " \'re", string)
    string = re.sub(r"\'d", " \'d", string)
    string = re.sub(r"\'ll", " \'ll", string)
    string = re.sub(r",", " , ", string)
    string = re.sub(r"!", " ! ", string)
    string = re.sub(r"\(", " \( ", string)
    string = re.sub(r"\)", " \) ", string)
    string = re.sub(r"\?", " \? ", string)
    string = re.sub(r"\s{2,}", " ", string)
    return string.strip() if TREC else string.strip().lower()

In [2]:
def read_data(path):    
    features = []
    labels = []
    with open(path, encoding='utf8') as f:
        for line in f:
            y, _, X = line.partition(' ')
            y = int(y)
            
            # clean the text
            X = clean_str(X.strip())            
            # lower the word 
            X = X.lower() 
            
            features.append(X)
            labels.append(y)
        

    return features, labels

In [3]:
class FakeDataset(Dataset):
    def __init__(self, X, y, tokenizer):
        self.X = X
        self.y = y
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        X = self.X[index]
        y = self.y[index]
        '''
        Padding = ['longest', 'max_length', 'do_not_pad']
        '''
        inputs = self.tokenizer.encode_plus(
            X,
            None, 
            add_special_tokens = True,
            padding = 'max_length',
            return_token_type_ids = True,
            truncation = True # default false, cut the length of the text
        )
        '''
        The Var inputs, is basically a `dict`.
        - 1. Input_ids, the index of the word in the Bert model.
        - 2. Token_type_ids, The BERT model is to do classifcitation on pairs of sentences, i.e., document level. However, in this case all should be 1. 
        - 3. attention_mask, The attention mask is a binary indicating the position of the padded indices so that the model does not attend to them, e.g., if it is a real word, then the model will not attend to them.
        '''
        ids = inputs['input_ids']
        mask = inputs['attention_mask']

        return {
            'ids': torch.LongTensor(ids),
            'mask': torch.LongTensor(mask),
            'y': torch.LongTensor([y])
        }


In [4]:
def build_vocab(data_path, tokenizer):
    '''
    To do:
    1. Split the dataset: train, test, validation.
    2. Init the Fake dataset.
    3. Build the data Iterator. # maybe not, build the iter in the next function
    '''
    X, y = read_data(data_path)


    data_set = FakeDataset(X, y, tokenizer)

    
    return data_set


In [5]:
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
data_set = build_vocab('fake', tokenizer)

In [6]:
data_set[0]

{'ids': tensor([  101, 25086,  4456,  2817,  4751,  3643,  1997, 13441,  1996,  2047,
          2259,  2335, 15339, 19011,  1037,  2047,  2817,  4107,  2590,  2592,
          2000,  2273,  2040,  2024,  5307,  3697,  6567,  2055,  2129,  2000,
          7438, 25086,  4456,  1999,  2049,  2220,  5711,  1010,  2030,  3251,
          2000,  7438,  2009,  2012,  2035,  6950,  2628,  5022,  2005,  2184,
          2086,  1998,  2179,  2053,  4489,  1999,  2331,  6165,  2090,  2273,
          2040,  2020,  3856,  2012,  6721,  2000,  2031,  5970,  2030,  8249,
          1010,  2030,  2000, 11160,  2006,  3161,  8822,  1997,  1996,  4456,
          1010,  2007,  3949,  2069,  2065,  2009, 12506,  2331,  6165,  2013,
          1996,  4456,  2020,  2659,  2058,  2035,  2069,  2055,  1015,  3867,
          1997,  5022,  2184,  2086,  2044, 11616,  2021,  1996,  4295,  2001,
          2062,  3497,  2000,  5082,  1998,  3659,  1999,  1996,  2273,  2040,
         12132,  2005,  8822,  2738,  2084,  

In [7]:
def build_iter(data_set, batch_size):
    data_iter = DataLoader(data_set,
                           batch_size = batch_size, 
                           shuffle = True,
                           num_workers = 4)
    return data_iter

In [8]:
data_iter = build_iter(data_set, 1)

In [10]:
for _, data in enumerate(data_iter, 0):
    ids = torch.LongTensor(data['ids'])
    mask = torch.LongTensor(data['mask'])
    y = torch.LongTensor(data['y']).squeeze(1)
    break

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/opt/anaconda3/lib/python3.9/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/opt/anaconda3/lib/python3.9/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'FakeDataset' on <module '__main__' (built-in)>


KeyboardInterrupt: 