In [17]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset
import pandas as pd
from torch.utils.data import Dataset, DataLoader
import torch
import torch.nn as nn

In [10]:
# Load the review table from the CSV into a pandas DataFrame; the rendered
# output shows a `label` column and a `review` text column.
data_preprocessed = pd.read_csv('waimai_10k.csv')
# Bare last expression so the notebook renders the frame.
data_preprocessed

Unnamed: 0,label,review
0,1,很快，好吃，味道足，量大
1,1,没有送水没有送水没有送水
2,1,非常快，态度好。
3,1,方便，快捷，味道可口，快递给力
4,1,菜味道很棒！送餐很及时！
...,...,...
11982,0,以前几乎天天吃，现在调料什么都不放，
11983,0,昨天订凉皮两份，什么调料都没有放，就放了点麻油，特别难吃，丢了一份，再也不想吃了
11984,0,"凉皮太辣,吃不下都"
11985,0,本来迟到了还自己点！！！


In [18]:
# Tokenizer for the "bert-base-chinese" checkpoint; shared as a notebook-level
# global by both collate_fn and process_func below.
tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")

In [25]:
class waimai_datset(Dataset):
    """Map-style dataset over a table of reviews.

    Wraps a pandas DataFrame that has a ``review`` column and a ``label``
    column and serves one ``(review, label)`` pair per row.
    """

    def __init__(self, data):
        # Only a reference to the table is stored; rows are read on access.
        self.data = data

    def __len__(self):
        """Return the number of rows in the wrapped table."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(review, label)`` pair at position ``idx``."""
        # iloc is positional, so the frame's index labels are irrelevant here.
        record = self.data.iloc[idx]
        return record['review'], record['label']
    
    
def collate_fn(batch):
    """Collate ``(text, label)`` samples into one tokenized batch.

    Args:
        batch: Sequence of ``(text, label)`` pairs, as produced by
            ``waimai_datset.__getitem__``.

    Returns:
        The output of the notebook-level ``tokenizer`` for the texts (every
        sequence padded/truncated to 128 tokens, returned as PyTorch tensors)
        with an extra ``'labels'`` entry holding the labels as a tensor.
    """
    texts = [sample[0] for sample in batch]
    labels = [sample[1] for sample in batch]
    encoded = tokenizer(
        texts,
        max_length=128,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )
    encoded['labels'] = torch.tensor(labels)
    return encoded

In [27]:
# Shuffled loader over the full review table; collate_fn tokenizes each batch
# of 8 samples.
train_loader = DataLoader(
    waimai_datset(data_preprocessed),
    batch_size=8,
    shuffle=True,
    collate_fn=collate_fn,
)
# Pull one batch to sanity-check the collated tensors.
next(iter(train_loader))

{'input_ids': tensor([[ 101, 7478, 2382,  ...,    0,    0,    0],
        [ 101, 3717, 4215,  ...,    0,    0,    0],
        [ 101, 6843, 7623,  ...,    0,    0,    0],
        ...,
        [ 101, 6843, 7623,  ...,    0,    0,    0],
        [ 101, 7659, 2094,  ...,    0,    0,    0],
        [ 101, 1922,  671,  ...,    0,    0,    0]]), 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]]), 'labels': tensor([0, 0, 0, 1, 0, 0, 0, 0])}

In [7]:
# Load the same CSV as a Hugging Face dataset (single 'train' split).
raw_dataset = load_dataset('csv', data_files='waimai_10k.csv', split='train')
# Dataset.filter returns a NEW dataset and leaves the original untouched, so
# the result has to be bound to a name; calling it without assignment (as this
# cell previously did) silently keeps rows whose review is missing.
dataset = raw_dataset.filter(lambda x: x['review'] is not None)
dataset

Filter: 100%|██████████| 11987/11987 [00:00<00:00, 540573.53 examples/s]


Dataset({
    features: ['label', 'review'],
    num_rows: 11987
})

In [28]:
# Spot-check a single record: integer indexing returns one row as a dict.
dataset[2]

{'label': 1, 'review': '非常快，态度好。'}

In [30]:
# Hold out 10% of the rows as a test split. The split is random, so a fixed
# seed is passed to make Restart & Run All produce the same train/test rows
# every time (previously the partition changed on each execution).
# NOTE: the variable name `datasets` matches the Hugging Face package name; it
# is kept because later cells reference it.
datasets = dataset.train_test_split(test_size=0.1, seed=42)
datasets

DatasetDict({
    train: Dataset({
        features: ['label', 'review'],
        num_rows: 10788
    })
    test: Dataset({
        features: ['label', 'review'],
        num_rows: 1199
    })
})

In [36]:
def process_func(examples):
    """Tokenize reviews and attach their labels under the ``labels`` key.

    Args:
        examples: Mapping with a ``'review'`` entry (passed straight to the
            notebook-level ``tokenizer``) and a ``'label'`` entry.

    Returns:
        The tokenizer output (padded/truncated to 128 tokens) with the
        original labels copied into a ``'labels'`` entry.
    """
    encoded = tokenizer(
        examples['review'],
        max_length=128,
        padding='max_length',
        truncation=True,
    )
    encoded['labels'] = examples['label']
    return encoded

# Apply process_func to every split; with batched=True the function receives
# groups of rows rather than one row at a time. The original `label` and
# `review` columns are kept alongside the new tokenizer columns (see output).
tokenized_datasets = datasets.map(process_func, batched=True)
# Preview the first two tokenized training rows.
tokenized_datasets['train'][:2]

{'label': [1, 0],
 'review': ['味道还行~豆腐脑的卤淀粉可能多了', '送餐速度一般一小时左右,等待太漫长啊,味道还不错'],
 'input_ids': [[101,
   1456,
   6887,
   6820,
   6121,
   172,
   6486,
   5576,
   5554,
   4638,
   1307,
   3895,
   5106,
   1377,
   5543,
   1914,
   749,
   102,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0],
  [101,
   6843,
   7623,
   6862,
   2428,
   671,
   5663,
   671,
   2207,
   3198,
  