# Finetuning WReTe
WReTe is a Textual entailment dataset with a pair of input sentences and 2 possible labels: `Entail_or_Paraphrase` and `NotEntail`

In [1]:
import os, sys
sys.path.append('../')
os.chdir('../')

import random
import numpy as np
import pandas as pd
import torch
from torch import optim
import torch.nn.functional as F
from tqdm import tqdm

from transformers import BertForSequenceClassification, BertConfig, BertTokenizer
from nltk.tokenize import TweetTokenizer

from utils.forward_fn import forward_sequence_classification
from utils.metrics import document_sentiment_metrics_fn
from utils.data_utils import EntailmentDataset, EntailmentDataLoader

In [2]:
###
# common functions
###
def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    
def count_param(module, trainable=False):
    if trainable:
        return sum(p.numel() for p in module.parameters() if p.requires_grad)
    else:
        return sum(p.numel() for p in module.parameters())
    
def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']

def metrics_to_string(metric_dict):
    string_list = []
    for key, value in metric_dict.items():
        string_list.append('{}:{:.2f}'.format(key, value))
    return ' '.join(string_list)

In [3]:
# Set random seed
set_seed(26092020)

# Load Model

In [4]:
# Load Tokenizer and Config
tokenizer = BertTokenizer.from_pretrained('indobenchmark/indobert-base-p1')
config = BertConfig.from_pretrained('indobenchmark/indobert-base-p1')
config.num_labels = EntailmentDataset.NUM_LABELS

# Instantiate model
model = BertForSequenceClassification.from_pretrained('indobenchmark/indobert-base-p1', config=config)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at indobenchmark/indobert-base-p1 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
model

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(50000, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [6]:
count_param(model)

124442882

# Prepare Dataset

In [7]:
train_dataset_path = './dataset/wrete_entailment-ui/train_preprocess.csv'
valid_dataset_path = './dataset/wrete_entailment-ui/valid_preprocess.csv'
test_dataset_path = './dataset/wrete_entailment-ui/test_preprocess_masked_label.csv'

In [8]:
train_dataset = EntailmentDataset(train_dataset_path, tokenizer, lowercase=True)
valid_dataset = EntailmentDataset(valid_dataset_path, tokenizer, lowercase=True)
test_dataset = EntailmentDataset(test_dataset_path, tokenizer, lowercase=True)

train_loader = EntailmentDataLoader(dataset=train_dataset, max_seq_len=512, batch_size=32, num_workers=16, shuffle=True)  
valid_loader = EntailmentDataLoader(dataset=valid_dataset, max_seq_len=512, batch_size=32, num_workers=16, shuffle=False)  
test_loader = EntailmentDataLoader(dataset=test_dataset, max_seq_len=512, batch_size=32, num_workers=16, shuffle=False)

In [9]:
w2i, i2w = EntailmentDataset.LABEL2INDEX, EntailmentDataset.INDEX2LABEL
print(w2i)
print(i2w)

{'NotEntail': 0, 'Entail_or_Paraphrase': 1}
{0: 'NotEntail', 1: 'Entail_or_Paraphrase'}


# Test model on sample sentences

In [10]:
text_A = 'Elektron hanya menduduki 0,06 % massa total atom .'
text_B = 'Elektron hanya mengambil 0,06 % massa total atom .'

encoded_inputs = tokenizer.encode_plus(text_A, text_B, add_special_tokens=True, return_token_type_ids=True)
subwords, token_type_ids = encoded_inputs["input_ids"], encoded_inputs["token_type_ids"]

subwords = torch.LongTensor(subwords).view(1, -1).to(model.device)
token_type_ids = torch.LongTensor(token_type_ids).view(1, -1).to(model.device)

logits = model(subwords, token_type_ids)[0]
label = torch.topk(logits, k=1, dim=-1)[1].squeeze().item()

print(f'Text A: {text_A}')
print(f'Text B: {text_B}')
print(f'Label : {i2w[label]} ({F.softmax(logits, dim=-1).squeeze()[label] * 100:.3f}%)')

Text A: Elektron hanya menduduki 0,06 % massa total atom .
Text B: Elektron hanya mengambil 0,06 % massa total atom .
Label : Entail_or_Paraphrase (54.222%)


In [11]:
text_A = 'Sekarang , tidak ada yang tahu pasti kapan sejarah dimulai .'
text_B = 'Sejarah dimulai pada awal penciptaan manusia .'

encoded_inputs = tokenizer.encode_plus(text_A, text_B, add_special_tokens=True, return_token_type_ids=True)
subwords, token_type_ids = encoded_inputs["input_ids"], encoded_inputs["token_type_ids"]

subwords = torch.LongTensor(subwords).view(1, -1).to(model.device)
token_type_ids = torch.LongTensor(token_type_ids).view(1, -1).to(model.device)

logits = model(subwords, token_type_ids)[0]
label = torch.topk(logits, k=1, dim=-1)[1].squeeze().item()

print(f'Text A: {text_A}')
print(f'Text B: {text_B}')
print(f'Label : {i2w[label]} ({F.softmax(logits, dim=-1).squeeze()[label] * 100:.3f}%)')

Text A: Sekarang , tidak ada yang tahu pasti kapan sejarah dimulai .
Text B: Sejarah dimulai pada awal penciptaan manusia .
Label : NotEntail (50.702%)


In [12]:
text_A = 'Hatikvah , arti harafiahnya adalah “Harapan” , merupakan Lagu Kebangsaan Israel .'
text_B = 'Hatikvah merupakan lagu kebangsaan Israel .'

encoded_inputs = tokenizer.encode_plus(text_A, text_B, add_special_tokens=True, return_token_type_ids=True)
subwords, token_type_ids = encoded_inputs["input_ids"], encoded_inputs["token_type_ids"]

subwords = torch.LongTensor(subwords).view(1, -1).to(model.device)
token_type_ids = torch.LongTensor(token_type_ids).view(1, -1).to(model.device)

logits = model(subwords, token_type_ids)[0]
label = torch.topk(logits, k=1, dim=-1)[1].squeeze().item()

print(f'Text A: {text_A}')
print(f'Text B: {text_B}')
print(f'Label : {i2w[label]} ({F.softmax(logits, dim=-1).squeeze()[label] * 100:.3f}%)')

Text A: Hatikvah , arti harafiahnya adalah “Harapan” , merupakan Lagu Kebangsaan Israel .
Text B: Hatikvah merupakan lagu kebangsaan Israel .
Label : Entail_or_Paraphrase (56.224%)


# Fine Tuning & Evaluation

In [13]:
optimizer = optim.Adam(model.parameters(), lr=5e-6)
model = model.cuda()

In [14]:
# Train
n_epochs = 10
for epoch in range(n_epochs):
    model.train()
    torch.set_grad_enabled(True)
 
    total_train_loss = 0
    list_hyp, list_label = [], []

    train_pbar = tqdm(train_loader, leave=True, total=len(train_loader))
    for i, batch_data in enumerate(train_pbar):
        # Forward model
        loss, batch_hyp, batch_label = forward_sequence_classification(model, batch_data[:-1], i2w=i2w, device='cuda')

        # Update model
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        tr_loss = loss.item()
        total_train_loss = total_train_loss + tr_loss

        # Calculate metrics
        list_hyp += batch_hyp
        list_label += batch_label

        train_pbar.set_description("(Epoch {}) TRAIN LOSS:{:.4f} LR:{:.8f}".format((epoch+1),
            total_train_loss/(i+1), get_lr(optimizer)))

    # Calculate train metric
    metrics = document_sentiment_metrics_fn(list_hyp, list_label)
    print("(Epoch {}) TRAIN LOSS:{:.4f} {} LR:{:.8f}".format((epoch+1),
        total_train_loss/(i+1), metrics_to_string(metrics), get_lr(optimizer)))

    # Evaluate on validation
    model.eval()
    torch.set_grad_enabled(False)
    
    total_loss, total_correct, total_labels = 0, 0, 0
    list_hyp, list_label = [], []

    pbar = tqdm(valid_loader, leave=True, total=len(valid_loader))
    for i, batch_data in enumerate(pbar):
        batch_seq = batch_data[-1]        
        loss, batch_hyp, batch_label = forward_sequence_classification(model, batch_data[:-1], i2w=i2w, device='cuda')
        
        # Calculate total loss
        valid_loss = loss.item()
        total_loss = total_loss + valid_loss

        # Calculate evaluation metrics
        list_hyp += batch_hyp
        list_label += batch_label
        metrics = document_sentiment_metrics_fn(list_hyp, list_label)

        pbar.set_description("VALID LOSS:{:.4f} {}".format(total_loss/(i+1), metrics_to_string(metrics)))
        
    metrics = document_sentiment_metrics_fn(list_hyp, list_label)
    print("(Epoch {}) VALID LOSS:{:.4f} {}".format((epoch+1),
        total_loss/(i+1), metrics_to_string(metrics)))

(Epoch 1) TRAIN LOSS:0.6466 LR:0.00000500: 100%|██████████| 10/10 [00:03<00:00,  2.58it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 1) TRAIN LOSS:0.6466 ACC:0.62 F1:0.48 REC:0.53 PRE:0.59 LR:0.00000500


VALID LOSS:0.5841 ACC:0.74 F1:0.67 REC:0.67 PRE:0.79: 100%|██████████| 2/2 [00:00<00:00,  2.91it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

(Epoch 1) VALID LOSS:0.5841 ACC:0.74 F1:0.67 REC:0.67 PRE:0.79


(Epoch 2) TRAIN LOSS:0.5701 LR:0.00000500: 100%|██████████| 10/10 [00:03<00:00,  2.68it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 2) TRAIN LOSS:0.5701 ACC:0.78 F1:0.76 REC:0.75 PRE:0.78 LR:0.00000500


VALID LOSS:0.5407 ACC:0.74 F1:0.67 REC:0.67 PRE:0.79: 100%|██████████| 2/2 [00:00<00:00,  3.01it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

(Epoch 2) VALID LOSS:0.5407 ACC:0.74 F1:0.67 REC:0.67 PRE:0.79


(Epoch 3) TRAIN LOSS:0.5009 LR:0.00000500: 100%|██████████| 10/10 [00:03<00:00,  2.61it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 3) TRAIN LOSS:0.5009 ACC:0.77 F1:0.72 REC:0.71 PRE:0.82 LR:0.00000500


VALID LOSS:0.5060 ACC:0.80 F1:0.77 REC:0.76 PRE:0.81: 100%|██████████| 2/2 [00:00<00:00,  2.88it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

(Epoch 3) VALID LOSS:0.5060 ACC:0.80 F1:0.77 REC:0.76 PRE:0.81


(Epoch 4) TRAIN LOSS:0.4329 LR:0.00000500: 100%|██████████| 10/10 [00:03<00:00,  2.57it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 4) TRAIN LOSS:0.4329 ACC:0.84 F1:0.83 REC:0.82 PRE:0.85 LR:0.00000500


VALID LOSS:0.4807 ACC:0.80 F1:0.77 REC:0.76 PRE:0.81: 100%|██████████| 2/2 [00:00<00:00,  2.82it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

(Epoch 4) VALID LOSS:0.4807 ACC:0.80 F1:0.77 REC:0.76 PRE:0.81


(Epoch 5) TRAIN LOSS:0.3646 LR:0.00000500: 100%|██████████| 10/10 [00:03<00:00,  2.56it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 5) TRAIN LOSS:0.3646 ACC:0.87 F1:0.86 REC:0.84 PRE:0.89 LR:0.00000500


VALID LOSS:0.4757 ACC:0.80 F1:0.78 REC:0.77 PRE:0.80: 100%|██████████| 2/2 [00:00<00:00,  2.98it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

(Epoch 5) VALID LOSS:0.4757 ACC:0.80 F1:0.78 REC:0.77 PRE:0.80


(Epoch 6) TRAIN LOSS:0.3007 LR:0.00000500: 100%|██████████| 10/10 [00:03<00:00,  2.61it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 6) TRAIN LOSS:0.3007 ACC:0.90 F1:0.89 REC:0.88 PRE:0.91 LR:0.00000500


VALID LOSS:0.4795 ACC:0.78 F1:0.76 REC:0.75 PRE:0.77: 100%|██████████| 2/2 [00:00<00:00,  2.92it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

(Epoch 6) VALID LOSS:0.4795 ACC:0.78 F1:0.76 REC:0.75 PRE:0.77


(Epoch 7) TRAIN LOSS:0.2402 LR:0.00000500: 100%|██████████| 10/10 [00:04<00:00,  2.49it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 7) TRAIN LOSS:0.2402 ACC:0.92 F1:0.91 REC:0.91 PRE:0.93 LR:0.00000500


VALID LOSS:0.4980 ACC:0.78 F1:0.76 REC:0.75 PRE:0.77: 100%|██████████| 2/2 [00:00<00:00,  2.94it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

(Epoch 7) VALID LOSS:0.4980 ACC:0.78 F1:0.76 REC:0.75 PRE:0.77


(Epoch 8) TRAIN LOSS:0.1765 LR:0.00000500: 100%|██████████| 10/10 [00:03<00:00,  2.50it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 8) TRAIN LOSS:0.1765 ACC:0.95 F1:0.95 REC:0.95 PRE:0.95 LR:0.00000500


VALID LOSS:0.4805 ACC:0.84 F1:0.81 REC:0.79 PRE:0.90: 100%|██████████| 2/2 [00:00<00:00,  2.90it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

(Epoch 8) VALID LOSS:0.4805 ACC:0.84 F1:0.81 REC:0.79 PRE:0.90


(Epoch 9) TRAIN LOSS:0.1731 LR:0.00000500: 100%|██████████| 10/10 [00:03<00:00,  2.54it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 9) TRAIN LOSS:0.1731 ACC:0.94 F1:0.93 REC:0.92 PRE:0.94 LR:0.00000500


VALID LOSS:0.6017 ACC:0.78 F1:0.76 REC:0.75 PRE:0.77: 100%|██████████| 2/2 [00:00<00:00,  2.93it/s]
  0%|          | 0/10 [00:00<?, ?it/s]

(Epoch 9) VALID LOSS:0.6017 ACC:0.78 F1:0.76 REC:0.75 PRE:0.77


(Epoch 10) TRAIN LOSS:0.1322 LR:0.00000500: 100%|██████████| 10/10 [00:03<00:00,  2.66it/s]
  0%|          | 0/2 [00:00<?, ?it/s]

(Epoch 10) TRAIN LOSS:0.1322 ACC:0.95 F1:0.95 REC:0.94 PRE:0.95 LR:0.00000500


VALID LOSS:0.5735 ACC:0.82 F1:0.80 REC:0.78 PRE:0.83: 100%|██████████| 2/2 [00:00<00:00,  2.96it/s]

(Epoch 10) VALID LOSS:0.5735 ACC:0.82 F1:0.80 REC:0.78 PRE:0.83





In [15]:
# Evaluate on test
model.eval()
torch.set_grad_enabled(False)

total_loss, total_correct, total_labels = 0, 0, 0
list_hyp, list_label = [], []

pbar = tqdm(test_loader, leave=True, total=len(test_loader))
for i, batch_data in enumerate(pbar):
    _, batch_hyp, _ = forward_sequence_classification(model, batch_data[:-1], i2w=i2w, device='cuda')
    list_hyp += batch_hyp

# Save prediction
df = pd.DataFrame({'label':list_hyp}).reset_index()
df.to_csv('pred.txt', index=False)

print(df)

100%|██████████| 4/4 [00:00<00:00,  4.91it/s]

    index                 label
0       0  Entail_or_Paraphrase
1       1  Entail_or_Paraphrase
2       2  Entail_or_Paraphrase
3       3             NotEntail
4       4  Entail_or_Paraphrase
..    ...                   ...
95     95  Entail_or_Paraphrase
96     96             NotEntail
97     97  Entail_or_Paraphrase
98     98  Entail_or_Paraphrase
99     99  Entail_or_Paraphrase

[100 rows x 2 columns]





# Test fine-tuned model on sample sentences

In [16]:
text_A = 'Elektron hanya menduduki 0,06 % massa total atom .'
text_B = 'Elektron hanya mengambil 0,06 % massa total atom .'

encoded_inputs = tokenizer.encode_plus(text_A, text_B, add_special_tokens=True, return_token_type_ids=True)
subwords, token_type_ids = encoded_inputs["input_ids"], encoded_inputs["token_type_ids"]

subwords = torch.LongTensor(subwords).view(1, -1).to(model.device)
token_type_ids = torch.LongTensor(token_type_ids).view(1, -1).to(model.device)

logits = model(subwords, token_type_ids=token_type_ids)[0]
label = torch.topk(logits, k=1, dim=-1)[1].squeeze().item()

print(f'Text A: {text_A}')
print(f'Text B: {text_B}')
print(f'Label : {i2w[label]} ({F.softmax(logits, dim=-1).squeeze()[label] * 100:.3f}%)')

Text A: Elektron hanya menduduki 0,06 % massa total atom .
Text B: Elektron hanya mengambil 0,06 % massa total atom .
Label : Entail_or_Paraphrase (97.744%)


In [17]:
text_A = 'Sekarang , tidak ada yang tahu pasti kapan sejarah dimulai .'
text_B = 'Sejarah dimulai pada awal penciptaan manusia .'

encoded_inputs = tokenizer.encode_plus(text_A, text_B, add_special_tokens=True, return_token_type_ids=True)
subwords, token_type_ids = encoded_inputs["input_ids"], encoded_inputs["token_type_ids"]

subwords = torch.LongTensor(subwords).view(1, -1).to(model.device)
token_type_ids = torch.LongTensor(token_type_ids).view(1, -1).to(model.device)

logits = model(subwords, token_type_ids)[0]
label = torch.topk(logits, k=1, dim=-1)[1].squeeze().item()

print(f'Text A: {text_A}')
print(f'Text B: {text_B}')
print(f'Label : {i2w[label]} ({F.softmax(logits, dim=-1).squeeze()[label] * 100:.3f}%)')

Text A: Sekarang , tidak ada yang tahu pasti kapan sejarah dimulai .
Text B: Sejarah dimulai pada awal penciptaan manusia .
Label : NotEntail (97.298%)


In [18]:
text_A = 'Hatikvah , arti harafiahnya adalah “Harapan” , merupakan Lagu Kebangsaan Israel .'
text_B = 'Hatikvah merupakan lagu kebangsaan Israel .'

encoded_inputs = tokenizer.encode_plus(text_A, text_B, add_special_tokens=True, return_token_type_ids=True)
subwords, token_type_ids = encoded_inputs["input_ids"], encoded_inputs["token_type_ids"]

subwords = torch.LongTensor(subwords).view(1, -1).to(model.device)
token_type_ids = torch.LongTensor(token_type_ids).view(1, -1).to(model.device)

logits = model(subwords, token_type_ids=token_type_ids)[0]
label = torch.topk(logits, k=1, dim=-1)[1].squeeze().item()

print(f'Text A: {text_A}')
print(f'Text B: {text_B}')
print(f'Label : {i2w[label]} ({F.softmax(logits, dim=-1).squeeze()[label] * 100:.3f}%)')

Text A: Hatikvah , arti harafiahnya adalah “Harapan” , merupakan Lagu Kebangsaan Israel .
Text B: Hatikvah merupakan lagu kebangsaan Israel .
Label : Entail_or_Paraphrase (97.814%)
