# Fine-tuning IndoBERT on SMSA

In [7]:
import os, sys
sys.path.insert(0,'/home/karissa/indonlu/')

import random
import numpy as np
import pandas as pd
import torch
from torch import optim
from tqdm import tqdm

from transformers import BertForSequenceClassification, BertTokenizer, BertConfig
from nltk.tokenize import TweetTokenizer, word_tokenize

from utils.forward_fn import forward_sequence_classification
from utils.metrics import document_sentiment_metrics_fn
from utils.data_utils import DocumentSentimentDataset, DocumentSentimentDataLoader

In [3]:
###
# common functions
###
def set_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and all CUDA devices) with the same value.

    Args:
        seed: Integer seed shared by all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # manual_seed_all covers every GPU, whereas torch.cuda.manual_seed only
    # seeds the current device. It is a silent no-op when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)
    
def count_param(module, trainable=False):
    """Return the number of scalar parameters held by ``module``.

    Args:
        module: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).
        trainable: When truthy, parameters whose ``requires_grad`` is false
            are left out of the total.

    Returns:
        The summed ``numel()`` of the selected parameters.
    """
    total = 0
    for param in module.parameters():
        # Skip frozen parameters only when the caller asked for trainable ones
        if trainable and not param.requires_grad:
            continue
        total += param.numel()
    return total
    
def get_lr(optimizer):
    """Return the ``'lr'`` entry of the optimizer's first parameter group.

    Returns ``None`` when the optimizer has no parameter groups.
    """
    groups = iter(optimizer.param_groups)
    try:
        first_group = next(groups)
    except StopIteration:
        return None
    return first_group['lr']

def metrics_to_string(metric_dict):
    """Render a metric mapping as space-separated ``name:value`` pairs.

    Each value is formatted with two decimal places; pairs keep the
    iteration order of ``metric_dict``.
    """
    return ' '.join(
        '{}:{:.2f}'.format(name, score) for name, score in metric_dict.items()
    )

In [4]:
# Seed the Python, NumPy and PyTorch generators before anything else runs
set_seed(seed=26092020)

# Load Model

In [8]:
# Tokenizer, config and weights are all loaded from the same pretrained checkpoint
pretrained_name = 'indobenchmark/indobert-base-p1'
tokenizer = BertTokenizer.from_pretrained(pretrained_name)

# Set the label count on the config to the one declared by the dataset class
config = BertConfig.from_pretrained(pretrained_name)
config.num_labels = DocumentSentimentDataset.NUM_LABELS

# Instantiate the sequence classifier from the checkpoint using that config
model = BertForSequenceClassification.from_pretrained(pretrained_name, config=config)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at indobenchmark/indobert-base-p1 and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [9]:
# Display the module tree of the loaded classifier as the cell output
model

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(50000, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [10]:
# Total parameter count of the model (trainable filter left at its default, off)
count_param(module=model)

124443651

# Prepare Dataset

In [11]:
# Locations of the three SMSA splits; the test file name indicates its labels are masked
train_dataset_path, valid_dataset_path, test_dataset_path = (
    '../dataset/smsa_doc-sentiment-prosa/train_preprocess.tsv',
    '../dataset/smsa_doc-sentiment-prosa/valid_preprocess.tsv',
    '../dataset/smsa_doc-sentiment-prosa/test_preprocess_masked_label.tsv',
)

In [12]:
# Build one dataset per split, all sharing the tokenizer and lowercase=True
train_dataset, valid_dataset, test_dataset = (
    DocumentSentimentDataset(path, tokenizer, lowercase=True)
    for path in (train_dataset_path, valid_dataset_path, test_dataset_path)
)

# Loader settings common to every split; only the training loader shuffles
loader_kwargs = dict(max_seq_len=512, batch_size=16, num_workers=16)
train_loader = DocumentSentimentDataLoader(dataset=train_dataset, shuffle=True, **loader_kwargs)
valid_loader = DocumentSentimentDataLoader(dataset=valid_dataset, shuffle=False, **loader_kwargs)
test_loader = DocumentSentimentDataLoader(dataset=test_dataset, shuffle=False, **loader_kwargs)

In [13]:
# Label-to-index and index-to-label tables declared on the dataset class
w2i = DocumentSentimentDataset.LABEL2INDEX
i2w = DocumentSentimentDataset.INDEX2LABEL
print(w2i, i2w, sep='\n')

{'positive': 0, 'neutral': 1, 'negative': 2}
{0: 'positive', 1: 'neutral', 2: 'negative'}


# Fine-tuning & Evaluation

In [14]:
# Move the model to the GPU first, then build the optimizer, so the optimizer
# is constructed over the parameters in the location where they are trained.
model = model.cuda()
optimizer = optim.Adam(model.parameters(), lr=5e-6)

In [16]:
# Train
# Each epoch runs one optimisation pass over train_loader and then scores the
# model on valid_loader. Autograd mode is scoped with context managers instead
# of being toggled globally, so the kernel's grad setting is left untouched
# once this cell finishes.
# NOTE(review): no checkpoint is saved here, so later cells use the model as it
# stands after the final epoch, not the epoch with the best validation score.
n_epochs = 8
for epoch in range(n_epochs):
    model.train()

    total_train_loss = 0
    list_hyp, list_label = [], []

    train_pbar = tqdm(train_loader, leave=True, total=len(train_loader))
    # enable_grad keeps training working even if autograd was disabled globally
    # by an earlier cell run
    with torch.enable_grad():
        for i, batch_data in enumerate(train_pbar):
            # Forward model; the last element of the batch is not passed to the helper
            loss, batch_hyp, batch_label = forward_sequence_classification(model, batch_data[:-1], i2w=i2w, device='cuda')

            # Update model
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            tr_loss = loss.item()
            total_train_loss = total_train_loss + tr_loss

            # Accumulate predictions and labels for the epoch-level metrics
            list_hyp += batch_hyp
            list_label += batch_label

            # Running mean of the loss over the batches seen so far
            train_pbar.set_description("(Epoch {}) TRAIN LOSS:{:.4f} LR:{:.8f}".format((epoch+1),
                total_train_loss/(i+1), get_lr(optimizer)))

    # Calculate train metric
    metrics = document_sentiment_metrics_fn(list_hyp, list_label)
    print("(Epoch {}) TRAIN LOSS:{:.4f} {} LR:{:.8f}".format((epoch+1),
        total_train_loss/(i+1), metrics_to_string(metrics), get_lr(optimizer)))

    # Evaluate on validation
    model.eval()

    total_loss = 0
    list_hyp, list_label = [], []

    pbar = tqdm(valid_loader, leave=True, total=len(valid_loader))
    with torch.no_grad():
        for i, batch_data in enumerate(pbar):
            loss, batch_hyp, batch_label = forward_sequence_classification(model, batch_data[:-1], i2w=i2w, device='cuda')

            # Calculate total loss
            valid_loss = loss.item()
            total_loss = total_loss + valid_loss

            # Calculate evaluation metrics over everything seen so far; this
            # is recomputed per batch only to feed the progress-bar text
            list_hyp += batch_hyp
            list_label += batch_label
            metrics = document_sentiment_metrics_fn(list_hyp, list_label)

            pbar.set_description("VALID LOSS:{:.4f} {}".format(total_loss/(i+1), metrics_to_string(metrics)))

    metrics = document_sentiment_metrics_fn(list_hyp, list_label)
    print("(Epoch {}) VALID LOSS:{:.4f} {}".format((epoch+1),
        total_loss/(i+1), metrics_to_string(metrics)))

(Epoch 1) TRAIN LOSS:0.1252 LR:0.00000500: 100%|██████████| 688/688 [02:19<00:00,  4.95it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

(Epoch 1) TRAIN LOSS:0.1252 ACC:0.96 F1:0.95 REC:0.94 PRE:0.95 LR:0.00000500


VALID LOSS:0.1839 ACC:0.93 F1:0.91 REC:0.90 PRE:0.91: 100%|██████████| 79/79 [00:05<00:00, 14.12it/s]
  0%|          | 0/688 [00:00<?, ?it/s]

(Epoch 1) VALID LOSS:0.1839 ACC:0.93 F1:0.91 REC:0.90 PRE:0.91


(Epoch 2) TRAIN LOSS:0.0797 LR:0.00000500: 100%|██████████| 688/688 [02:18<00:00,  4.97it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

(Epoch 2) TRAIN LOSS:0.0797 ACC:0.97 F1:0.97 REC:0.97 PRE:0.97 LR:0.00000500


VALID LOSS:0.1770 ACC:0.94 F1:0.92 REC:0.91 PRE:0.93: 100%|██████████| 79/79 [00:05<00:00, 14.12it/s]
  0%|          | 0/688 [00:00<?, ?it/s]

(Epoch 2) VALID LOSS:0.1770 ACC:0.94 F1:0.92 REC:0.91 PRE:0.93


(Epoch 3) TRAIN LOSS:0.0459 LR:0.00000500: 100%|██████████| 688/688 [02:18<00:00,  4.96it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

(Epoch 3) TRAIN LOSS:0.0459 ACC:0.99 F1:0.98 REC:0.98 PRE:0.98 LR:0.00000500


VALID LOSS:0.2139 ACC:0.94 F1:0.91 REC:0.92 PRE:0.91: 100%|██████████| 79/79 [00:05<00:00, 13.61it/s]
  0%|          | 0/688 [00:00<?, ?it/s]

(Epoch 3) VALID LOSS:0.2139 ACC:0.94 F1:0.91 REC:0.92 PRE:0.91


(Epoch 4) TRAIN LOSS:0.0292 LR:0.00000500: 100%|██████████| 688/688 [02:18<00:00,  4.98it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

(Epoch 4) TRAIN LOSS:0.0292 ACC:0.99 F1:0.99 REC:0.99 PRE:0.99 LR:0.00000500


VALID LOSS:0.2508 ACC:0.94 F1:0.91 REC:0.90 PRE:0.92: 100%|██████████| 79/79 [00:05<00:00, 14.20it/s]
  0%|          | 0/688 [00:00<?, ?it/s]

(Epoch 4) VALID LOSS:0.2508 ACC:0.94 F1:0.91 REC:0.90 PRE:0.92


(Epoch 5) TRAIN LOSS:0.0205 LR:0.00000500: 100%|██████████| 688/688 [02:17<00:00,  5.00it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

(Epoch 5) TRAIN LOSS:0.0205 ACC:0.99 F1:0.99 REC:0.99 PRE:0.99 LR:0.00000500


VALID LOSS:0.2399 ACC:0.94 F1:0.91 REC:0.91 PRE:0.92: 100%|██████████| 79/79 [00:05<00:00, 14.06it/s]
  0%|          | 0/688 [00:00<?, ?it/s]

(Epoch 5) VALID LOSS:0.2399 ACC:0.94 F1:0.91 REC:0.91 PRE:0.92


(Epoch 6) TRAIN LOSS:0.0155 LR:0.00000500: 100%|██████████| 688/688 [02:19<00:00,  4.93it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

(Epoch 6) TRAIN LOSS:0.0155 ACC:1.00 F1:0.99 REC:0.99 PRE:0.99 LR:0.00000500


VALID LOSS:0.2668 ACC:0.93 F1:0.91 REC:0.91 PRE:0.91: 100%|██████████| 79/79 [00:05<00:00, 14.71it/s]
  0%|          | 0/688 [00:00<?, ?it/s]

(Epoch 6) VALID LOSS:0.2668 ACC:0.93 F1:0.91 REC:0.91 PRE:0.91


(Epoch 7) TRAIN LOSS:0.0096 LR:0.00000500: 100%|██████████| 688/688 [02:18<00:00,  4.98it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

(Epoch 7) TRAIN LOSS:0.0096 ACC:1.00 F1:1.00 REC:1.00 PRE:1.00 LR:0.00000500


VALID LOSS:0.2924 ACC:0.94 F1:0.91 REC:0.90 PRE:0.92: 100%|██████████| 79/79 [00:05<00:00, 14.17it/s]
  0%|          | 0/688 [00:00<?, ?it/s]

(Epoch 7) VALID LOSS:0.2924 ACC:0.94 F1:0.91 REC:0.90 PRE:0.92


(Epoch 8) TRAIN LOSS:0.0098 LR:0.00000500: 100%|██████████| 688/688 [02:17<00:00,  5.00it/s]
  0%|          | 0/79 [00:00<?, ?it/s]

(Epoch 8) TRAIN LOSS:0.0098 ACC:1.00 F1:1.00 REC:1.00 PRE:1.00 LR:0.00000500


VALID LOSS:0.2950 ACC:0.93 F1:0.91 REC:0.91 PRE:0.91: 100%|██████████| 79/79 [00:05<00:00, 14.66it/s]


(Epoch 8) VALID LOSS:0.2950 ACC:0.93 F1:0.91 REC:0.91 PRE:0.91


In [17]:
# Evaluate on test
# Predict a label for every test example and write the result to prediction.csv.
model.eval()

list_hyp = []

pbar = tqdm(test_loader, leave=True, total=len(test_loader))
# no_grad is scoped to the prediction loop rather than switching autograd off
# for the whole kernel
with torch.no_grad():
    for batch_data in pbar:
        # Only the predicted labels are kept; the loss and label outputs are discarded
        _, batch_hyp, _ = forward_sequence_classification(model, batch_data[:-1], i2w=i2w, device='cuda')
        list_hyp += batch_hyp

# Save prediction
# reset_index() turns the row position into an explicit 'index' column
df = pd.DataFrame({'label':list_hyp}).reset_index()
df.to_csv('prediction.csv', index=False)

print(df)

100%|██████████| 32/32 [00:01<00:00, 20.91it/s]

     index     label
0        0  negative
1        1  negative
2        2  negative
3        3  negative
4        4  negative
..     ...       ...
495    495   neutral
496    496   neutral
497    497  positive
498    498  positive
499    499  positive

[500 rows x 2 columns]



