# Dataset with DataCollator

In [2]:
import torch
from datasets import load_dataset
from transformers import DataCollatorWithPadding

In [9]:
# Read the review CSV through the generic 'csv' builder and take its 'train'
# split, so `dataset` is a single Dataset rather than a dict of splits.
dataset = load_dataset(
    'csv',
    data_files='data/waimai_10k.csv',
    split='train',
)
dataset

Dataset({
    features: ['label', 'review'],
    num_rows: 11987
})

In [10]:
def _has_review(example):
    """Return True when the row's 'review' field is not None."""
    return example['review'] is not None


# Drop rows whose review text is missing before they reach the tokenizer.
dataset = dataset.filter(_has_review)
dataset

Filter:   0%|          | 0/11987 [00:00<?, ? examples/s]

Dataset({
    features: ['label', 'review'],
    num_rows: 11987
})

In [11]:
from transformers import AutoTokenizer
# Build the tokenizer from the rbt3 checkpoint referenced by a relative path.
rbt3_checkpoint = '../model/rbt3'
tokenizer = AutoTokenizer.from_pretrained(rbt3_checkpoint)

In [12]:
def process_function(examples, *, max_length=128):
    """Tokenize a batch of reviews and attach their labels.

    Args:
        examples: Batch mapping with a 'review' entry (the texts handed to the
            tokenizer) and a 'label' entry (copied through unchanged).
        max_length: Truncation limit forwarded to the tokenizer. Keyword-only
            and defaulting to the previously hard-coded 128, so
            ``dataset.map(process_function, batched=True)`` behaves as before;
            override it with ``fn_kwargs={'max_length': ...}``.

    Returns:
        The tokenizer's output with an added 'labels' entry holding the
        batch's 'label' values.
    """
    # No padding argument is passed: sequences are only truncated here.
    tokenized_examples = tokenizer(examples['review'], max_length=max_length, truncation=True)
    tokenized_examples['labels'] = examples['label']
    return tokenized_examples

In [19]:
# Tokenize batch-wise and drop every original column, so only the fields
# returned by process_function remain in the mapped dataset.
original_columns = dataset.column_names
tokenized_dataset = dataset.map(
    process_function,
    batched=True,
    remove_columns=original_columns,
)
tokenized_dataset

Dataset({
    features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
    num_rows: 11987
})

In [20]:
# Collator that pads the examples of a batch at collate time, using this
# tokenizer's padding configuration.
collator = DataCollatorWithPadding(tokenizer)

In [21]:
from torch.utils.data import DataLoader

In [25]:
# Batches of 4 are assembled by `collator`. Shuffling draws from a dedicated,
# seeded generator so a fresh Restart & Run All reproduces the same batch
# order (and therefore the same outputs in the cells below).
shuffle_generator = torch.Generator().manual_seed(42)
dl = DataLoader(
    tokenized_dataset,
    batch_size=4,
    collate_fn=collator,
    shuffle=True,
    generator=shuffle_generator,
)

In [26]:
# Peek at the first (index, batch) pair to inspect the collated tensors.
indexed_batches = enumerate(dl)
next(indexed_batches)

(0,
 {'input_ids': tensor([[ 101, 6843, 7623, 2923, 2571, 4638,  511, 2218, 3221, 1922, 7410, 1391,
           749,  100,  100, 5101, 7649, 4294, 1166, 1914,  100, 1126,  725, 3766,
          3300, 5831,  100,  100,  102,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0],
         [ 101, 5023, 4638, 3198, 7313, 3300, 4157,  719, 8024, 1456, 6887, 2523,
           671, 5663,  511,  102,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
             0,    0,    0,

In [27]:
# Print the input_ids shape of the first eleven batches to show how the
# padded sequence length varies from batch to batch.
for num, batch in enumerate(dl, start=1):
    print(batch['input_ids'].size())
    # Stop right after the eleventh batch, without fetching a twelfth.
    if num > 10:
        break

torch.Size([4, 43])
torch.Size([4, 36])
torch.Size([4, 22])
torch.Size([4, 31])
torch.Size([4, 78])
torch.Size([4, 24])
torch.Size([4, 92])
torch.Size([4, 44])
torch.Size([4, 30])
torch.Size([4, 25])
torch.Size([4, 26])
