<a href="https://colab.research.google.com/github/GitYCC/bert-minimal-tutorial/blob/master/chinese_sentiment_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Chinese Sentiment Classification

In [1]:
!git clone https://github.com/GitYCC/bert-minimal-tutorial.git

Cloning into 'bert-minimal-tutorial'...
remote: Enumerating objects: 23, done.
remote: Counting objects: 100% (23/23), done.
remote: Compressing objects: 100% (18/18), done.
remote: Total 23 (delta 7), reused 19 (delta 3), pack-reused 0
Unpacking objects: 100% (23/23), done.


In [2]:
%cd bert-minimal-tutorial

/content/bert-minimal-tutorial


In [3]:
!pip install -q -r requirements.txt

In [13]:
import os

import pandas as pd
import torch
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, random_split
from torch.nn.utils.rnn import pad_sequence
from transformers import BertTokenizer, BertForSequenceClassification
from tqdm import tqdm

from utils import RunningAverage

MODEL_NAME = 'bert-base-chinese'
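
`RunningAverage` comes from the repository's `utils` module, which is not reproduced in this notebook. The code here only calls `add`, `add_all`, `get`, and `flush` on it, so a minimal stand-in might look like the sketch below, assuming it simply tracks a sum and a count. The class name `RunningAverageSketch` and its internals are hypothetical; the real helper may differ, for example in what `get` returns when nothing has been added.

```python
class RunningAverageSketch:
    """Hypothetical stand-in for utils.RunningAverage (internals are assumed)."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def add(self, value):
        # Accumulate a single value, e.g. one batch loss.
        self.total += float(value)
        self.count += 1

    def add_all(self, values):
        # Accumulate many values at once, e.g. per-sample correctness flags.
        for value in values:
            self.add(value)

    def get(self):
        # Mean of everything added since the last flush (assumed 0.0 if empty).
        return self.total / self.count if self.count else 0.0

    def flush(self):
        # Reset so the next reporting window starts from scratch.
        self.total = 0.0
        self.count = 0


acc = RunningAverageSketch()
acc.add_all([True, True, False, True])
print(acc.get())  # 0.75
```

Feeding booleans to `add_all` turns the running mean into an accuracy, which is how the evaluation code in this notebook uses it.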

## Dataloader

In [5]:
df = pd.read_csv('data/chinese_sentiment_classification.csv')
df = df.sample(frac=1).reset_index(drop=True)  # shuffle

In [6]:
df

Unnamed: 0,label,text
0,0,氣味畢較大，都是塑料的多。a柱盲區比較大，要多注意轉彎的多候，躁音挺大的，但也不像網上說的那...
1,0,中門關閉吃力
2,1,空間很大，買的時候就是看重；了他的空間大，裝東西，裝貨物，承載量很大，非常的實用，
3,0,頭頂位置偏小日間行車燈打不開，導航不是太會弄，不好操作啊，還不能取消，缺點，換其他的導航又不會，
4,1,帥到沒有朋友
...,...,...
69995,0,離合。低速行駛時，是有點尷尬，高速實用
69996,0,不滿意的地方有兩個，第一:駕駛座不支持高低調節。第二:第三排確實小了點。如果可以在加長些就無...
69997,0,手動變速箱匹配不是太順暢，離合偏重，也不是電子轉向也偏重。
69998,1,最滿意的就是車子的實用性，當初花幾萬塊錢入手的車皮本來就沒打算有多高端的表現，所以智尚s35...


In [7]:
class MultiClassDataset(Dataset):
    """Tokenizes one review per item and returns BERT inputs plus its label."""

    def __init__(self, tokenizer, df, max_len=512):
        self.tokenizer = tokenizer
        # Column-wise extraction avoids the slow row-by-row df.iterrows() loop.
        self.texts = df['text'].tolist()
        self.labels = df['label'].tolist()
        self.max_len = max_len

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]

        # Truncate to max_len - 2 so there is room for [CLS] and [SEP].
        tokens = self.tokenizer.tokenize(text)
        tokens = tokens[:self.max_len-2]
        processed_tokens = ['[CLS]'] + tokens + ['[SEP]']

        input_ids = torch.tensor(self.tokenizer.convert_tokens_to_ids(processed_tokens))
        # Single-sentence input: every token is in segment 0 and is attended to.
        token_type_ids = torch.tensor([0] * len(processed_tokens))
        attention_mask = torch.tensor([1] * len(processed_tokens))

        label = torch.tensor(label)

        return input_ids, token_type_ids, attention_mask, label

    def __len__(self):
        return len(self.texts)


def create_mini_batch(samples):
    input_ids, token_type_ids, attention_mask, labels = list(zip(*samples))

    # zero-pad every sequence in the batch to the same length
    input_ids = pad_sequence(input_ids, batch_first=True)
    token_type_ids = pad_sequence(token_type_ids, batch_first=True)
    attention_mask = pad_sequence(attention_mask, batch_first=True)
 
    labels = torch.stack(labels)

    return input_ids, token_type_ids, attention_mask, labels
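
To see what `create_mini_batch` does to sequences of unequal length, here is a small sketch that runs `pad_sequence` on two toy tensors. The token ids are illustrative values, not the output of the real tokenizer.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Two toy "sentences" of different lengths (ids are illustrative only).
ids_a = torch.tensor([101, 7, 8, 9, 102])
ids_b = torch.tensor([101, 7, 102])

# pad_sequence pads with 0 by default; batch_first=True gives (batch, seq_len).
input_ids = pad_sequence([ids_a, ids_b], batch_first=True)
attention_mask = pad_sequence(
    [torch.ones_like(ids_a), torch.ones_like(ids_b)], batch_first=True
)

print(input_ids.tolist())       # [[101, 7, 8, 9, 102], [101, 7, 102, 0, 0]]
print(attention_mask.tolist())  # [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]]
```

Because the mask is padded with the same zeros, padded positions end up with `attention_mask == 0`, so the model does not attend to them.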

In [8]:
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)

dataset = MultiClassDataset(tokenizer, df)

CUT_RATIO = 0.8
train_size = int(CUT_RATIO * len(dataset))
valid_size = len(dataset) - train_size
train_dataset, valid_dataset = random_split(dataset, [train_size, valid_size])
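
Neither `df.sample(frac=1)` nor `random_split` is seeded in this notebook, so the train/validation membership changes from run to run. If a repeatable split is wanted, one option is to pass a seeded generator to `random_split`. The sketch below assumes an arbitrary seed of 42 and uses `range(10)` as a stand-in for the real dataset; `split` is a hypothetical helper.

```python
import torch
from torch.utils.data import random_split

toy_dataset = range(10)  # stand-in for the real dataset


def split(seed):
    # A fresh generator with a fixed seed makes the permutation repeatable.
    generator = torch.Generator().manual_seed(seed)
    return random_split(toy_dataset, [8, 2], generator=generator)


train_a, valid_a = split(42)
train_b, valid_b = split(42)

# Same seed, same membership.
print(list(valid_a.indices) == list(valid_b.indices))  # True
```

The pandas shuffle can be pinned the same way with `df.sample(frac=1, random_state=...)`.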

In [9]:
batch_size = 32

train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    collate_fn=create_mini_batch,
    shuffle=True
)
valid_loader = DataLoader(
    dataset=valid_dataset,
    batch_size=batch_size,
    collate_fn=create_mini_batch,
)

## Model

In [10]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'device: {device}')

model = BertForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=2,
    output_attentions=False,
    output_hidden_states=False,
    return_dict=True
)
model.to(device)

device: cuda


Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(21128, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

## Train

In [15]:
def train_batch(model, data, optimizer, device):
    model.train()
    input_ids, token_type_ids, attention_mask, labels = [d.to(device) for d in data]

    outputs = model(
        input_ids=input_ids,
        token_type_ids=token_type_ids,
        attention_mask=attention_mask,
        labels=labels
    )
    loss = outputs.loss

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

def evaluate(model, valid_loader, device):
    model.eval()

    loss = RunningAverage()
    acc = RunningAverage()

    with torch.no_grad():
        for data in tqdm(valid_loader, desc='evaluate'):
            input_ids, token_type_ids, attention_mask, labels = [d.to(device) for d in data]

            outputs = model(
                input_ids=input_ids,
                token_type_ids=token_type_ids,
                attention_mask=attention_mask,
                labels=labels
            )

            loss.add(outputs.loss.item())
            corrects = (outputs.logits.argmax(dim=-1) == labels).cpu().tolist()
            acc.add_all(corrects)

    return loss.get(), acc.get()
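
The accuracy in `evaluate` comes from comparing the arg-max of the logits with the labels. A minimal sketch of that step on made-up logits (four samples, two classes) is shown here; the numbers are invented for illustration.

```python
import torch

# Made-up logits for 4 samples and 2 classes.
logits = torch.tensor([[2.0, -1.0],
                       [0.1, 0.3],
                       [-0.5, 1.5],
                       [1.2, 0.4]])
labels = torch.tensor([0, 1, 0, 0])

# The predicted class is the index of the largest logit in each row.
predictions = logits.argmax(dim=-1)
corrects = (predictions == labels).tolist()
accuracy = sum(corrects) / len(corrects)

print(predictions.tolist())  # [0, 1, 1, 0]
print(corrects)              # [True, True, False, True]
print(accuracy)              # 0.75
```

Collecting the per-sample booleans, rather than a per-batch accuracy, keeps the final figure correct even when the last batch is smaller than the others.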

In [None]:
lr = 0.00001
max_iter = 100
show_per_iter = 10
valid_per_iter = 50
save_per_iter = 100
save_checkpoint_dir = 'models/'
model_prefix = 'cn_sentiment_class_'
reload_checkpoint = None

assert save_per_iter % valid_per_iter == 0

optimizer = optim.Adam(model.parameters(), lr=lr)

i = 1
is_running = True
train_loss = RunningAverage()
while is_running:
    for train_data in train_loader:
        loss = train_batch(model, train_data, optimizer, device)
        train_loss.add(loss)

        if i % show_per_iter == 0:
            print('train [{}]: loss={}'.format(i, train_loss.get()))
            train_loss.flush()

        if i % valid_per_iter == 0:
            valid_loss, valid_acc = evaluate(model, valid_loader, device)
            print(f'valid: loss={valid_loss}, acc={valid_acc}')

        if i % save_per_iter == 0:
            # The assert above guarantees valid_loss was refreshed on this iteration.
            path = os.path.join(save_checkpoint_dir, model_prefix + f'loss{valid_loss:.5}/')
            print(f'save model at {path}')
            model.save_pretrained(path)
        
        if i == max_iter:
            is_running = False
            break

        i += 1

train [10]: loss=0.13257569763809443
train [20]: loss=0.07600871361792087
train [30]: loss=0.10189774688333272
train [40]: loss=0.08387432228773832


evaluate:   0%|          | 0/438 [00:00<?, ?it/s]

train [50]: loss=0.11318453000858426


evaluate: 100%|██████████| 438/438 [01:22<00:00,  5.30it/s]


valid: loss=0.09331045881951372, acc=0.9674285714285714
train [60]: loss=0.16492213755846025
train [70]: loss=0.15872838124632835
train [80]: loss=0.13649716190993785
train [90]: loss=0.1124349880963564


evaluate:   0%|          | 1/438 [00:00<01:19,  5.49it/s]

train [100]: loss=0.10219061672687531


evaluate: 100%|██████████| 438/438 [01:22<00:00,  5.30it/s]


valid: loss=0.09794885570770256, acc=0.9635
save model at models/cn_sentiment_class_loss0.097949/
train [110]: loss=0.06595614422112703
train [120]: loss=0.09449335867539048
train [130]: loss=0.08556640562601388
train [140]: loss=0.06297143679112197


evaluate:   0%|          | 1/438 [00:00<01:21,  5.37it/s]

train [150]: loss=0.118714439868927


evaluate:  92%|█████████▏| 401/438 [01:15<00:07,  5.10it/s]
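
The training cell asserts `save_per_iter % valid_per_iter == 0` because the checkpoint directory name embeds the validation loss, so every save has to coincide with a validation pass. A small sketch of which iterations trigger which action under the settings above follows; `events_at` is a hypothetical helper that mirrors the three `if` checks and does not train anything.

```python
max_iter = 100
show_per_iter = 10
valid_per_iter = 50
save_per_iter = 100


def events_at(i):
    # Mirror the modulo checks of the training loop for iteration i.
    events = []
    if i % show_per_iter == 0:
        events.append('show')
    if i % valid_per_iter == 0:
        events.append('valid')
    if i % save_per_iter == 0:
        events.append('save')
    return events


for i in range(1, max_iter + 1):
    if 'valid' in events_at(i):
        print(i, events_at(i))
# 50 ['show', 'valid']
# 100 ['show', 'valid', 'save']
```

With these values, the training loss is printed every 10 iterations, validation runs at iterations 50 and 100, and one checkpoint is written at iteration 100.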